题目描述:
有n个函数,分别为F1,F2,...,Fn。定义Fi(x)=Ai*x^2+Bi*x+Ci (x∈N*)。给定这些Ai、Bi和Ci,请求出所有函数的所有函数值中最小的m个(如有重复的要输出多个)。
输入样例:
3 10 4 5 3 3 4 5 1 7 1
输出样例:
9 12 12 19 25 29 31 44 45 54
题目分析:
(MLE,TLE做法)看到题目我们第一眼肯定想到的是将每个函数枚举m次,然后对这n*m个数进行排序,输出前m个。代码如下:
1 #include<cstdio> 2 #include<algorithm> 3 #define MAXSIZE 100000000+20 4 #define MAXN 10000+20 5 int a[MAXN],b[MAXN],c[MAXN]; 6 int num[MAXSIZE]; 7 8 int main(){ 9 int n,m,k=1; 10 scanf("%d%d",&n,&m); 11 for(int i=1;i<=n;i++) 12 scanf("%d%d%d",&a[i],&b[i],&c[i]); 13 for(int i=1;i<=n;i++) 14 for(int j=1;j<=m;j++) 15 num[k++]=a[i]*j*j+b[i]*j+c[i]; 16 std::sort(num+1,num+k); 17 for(int i=1;i<=m;i++) 18 printf("%d ",num[i]); 19 return 0; 20 }
这样做的正确性是显而易见的:题目要求给出最小的M个数,这最小的M个数是由这N个函数生成的,最坏情况就是这M个数都由一个函数生成。那么即使这样,x的最大值也只是等于M,因为我们都知道:当二次函数y=ax2+bx+c(a>0)时,抛物线开口朝上,y随x增大而增大。所以当x1=i,x2=j(i,j∈N+,i<j)时y1<y2,所以我们对于同一个函数来说,x越小生成的数越小,这点在AC的算法中有很重要的作用。
通过上面所说的,我们可以得知x最大值也只是等于M,也就是说我们的枚举区间就是1~m。我们需要对每个函数都枚举一遍,所以这一步的时间复杂度是O(n*m),最高是一亿,暂时还没超时,但是这是无序的,也就是说我们要将这n*m个数排序。
假设我们用的是最快的快排,没被卡快排,那么这一步的时间复杂度就是O(n*m*log2(n*m));总体时间复杂度就是O(n*m*log2(n*m)),对于n=10000,m=10000,这个时间复杂度明显超时,况且我们还要花O(n*m)空间开销,这也是吃不消的。
(AC做法)因为全部都算出来再排序输出必然超时,所以我们必须想办法减少算出来的值。在做出TLE且MLE的做法的时候,我们就证明了对于同一个函数而言,x越小生成的数越小,那我们不难画出下面这个表:
与x+1的关系 | x的值 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
函数 | < | < | < | < | < | < | < | < | |
f1 | < | < | < | < | < | < | < | < | |
f2 | < | < | < | < | < | < | < | < | |
f3 | < | < | < | < | < | < | < | < | |
f4 | < | < | < | < | < | < | < | < | |
f5 | < | < | < | < | < | < | < | < | |
f6 | < | < | < | < | < | < | < | < | |
f7 | < | < | < | < | < | < | < | < | |
... | < | < | < | < | < | < | < | < |
假设xi=1,xj=1时fi<fj
则xi=1,xj=xj+1时fi<fj
则当只存在fi和fj时xi=1时fi为最小值
稍加思考,我们可以想到下面一种贪心算法:
一开始先取每个函数中x=1时的值,输出这中间最小的,然后将最小的更新为x=2时的情况,再输出最小的......
说到最小,就想到了用最小堆实现,代码如下:
1 #include<cstdio> 2 #include<algorithm> 3 #define min(a,b) (((a)<(b))?(a):(b)) 4 using std::swap; 5 struct h{ 6 int data; 7 int place; 8 }heap[10000+20]; 9 bool operator < (const h &X,const h &Y){ //重载小于号 10 return X.data<Y.data; 11 } 12 int a[10000+20],b[10000+20],c[10000+20]; 13 int x[10000+20]; 14 int n; 15 void put(int i){ 16 while(i*2+1<=n){ 17 if(min(heap[i*2],heap[i*2+1])<heap[i]){ //比较左儿子和右儿子 18 if(heap[i*2]<heap[i*2+1]){ 19 swap(heap[i*2],heap[i]); 20 i*=2; 21 }else{ 22 swap(heap[i*2+1],heap[i]); 23 i*=2; 24 i++; 25 } 26 }else 27 break; 28 }if(i*2<=n&&heap[i*2]<heap[i]){ 29 swap(heap[i*2],heap[i]); 30 i*=2; 31 } 32 } 33 34 int main(){ 35 int m; 36 scanf("%d%d",&n,&m); 37 for(int i=1;i<=n;i++) 38 scanf("%d%d%d",&a[i],&b[i],&c[i]); 39 for(int i=1;i<=n;i++){ 40 x[i]=1; //保存每个函数的x用到了几 41 heap[i].data=a[i]+b[i]+c[i]; //当x=1的情况 42 heap[i].place=i; 43 } 44 for(int i=n;i>=1;i--){ 45 put(i); //建堆 46 } 47 for(int i=1;i<=m;i++){ 48 printf("%d ",heap[1].data); 49 int &k=heap[1].place; 50 heap[1].data=heap[1].data+2*a[k]*x[k]+a[k]+b[k]; //自己去算,不想讲,初二数学--完全平方式 51 x[heap[1].place]++; 52 put(1); 53 } 54 return 0; 55 }
时间复杂度O((n+m)log2n),最坏情况20000*log2(10000)≈280000不会超时
空间复杂度O(n)最坏10000,约0.04MB,离上限128MB很远。