• [CEOI2004]锯木厂选址 斜率优化DP


    斜率优化DP

    先考虑朴素DP方程,

    f[i][k]代表第k个厂建在i棵树那里的最小代价,最后答案为f[n+1][3];

    f[i][k]=min(f[j][k-1] + 把j+1~i的树都运到i的代价)

    首先注意到“把j+1~i的树都运到i的代价”不太方便表达,每次都暴力计算显然是无法承受的,

    于是考虑前缀和优化,观察到先运到下一棵树那里,等一会再运下去,和直接运下去是等效的。

    设sum[i]代表1 ~ i的树都运到i的代价,

    于是根据前缀和思想,猜想我们可以用1 ~ r 的代价与 1 ~ l-1的代价获取l ~ r的代价,

    所以要做的就是吧1 ~ l-1 对 1 ~ r产生的贡献给算出来,然后减掉,

    考虑先把1 ~ l-1的树都运到l-1,所以这部分的代价是sum[l-1],

    然后再把树一次性运到r,那么代价是sum_weight[l-1] * (sum_len[r] - sum_len[l-1]);

                                        总的重量 * 现在要再次运的路程

    这里为了表示方便,用$sw$代表sum_weight,用$sl$代表sum_len;


    于是用$sum[r]$ 减去这两部分代价就可以得到$l ~ r$ 的代价(把$l ~ r$的树都运到$r$)

    代价(l ~ r)$ =  sum[r] - sum[l-1] - sw[l-1] * (sl[r] - sl[l-1]);$

    那么如何计算sum ?

    也是一样的思想,用前面的推后面的,先得到前面的代价,再加上新增的代价即可

    $sum[i]=sum[i-1] + swt[i-1] * len[i-1];$//len[i-1]代表i-1到i的距离

    于是我们就得到了DP方程:

    当$k==1$时,$f[i][k]=sum[i]$;

    else

            $f[i][k]=min(f[j][k-1] + sum[i] - sum[j] - sw[j] * (sl[i] - sl[j]));$

    但是可以发现,由于k最大就是3,而且3必须是n+1才可以取,

    而且当$k==1$时,$f[i][k]$就等于$sum[i]$,

    所以考虑优化维数:

    当$k==1$时,不用求,因为有$sum$了

    当$k==2$时,调用的$f[j][k-1]$替换为$sum[j]$,并且还可以发现由于后面有一个$-sum[j]$,所以可以直接消掉

    当$k==3$时,由于只有$n+1$可以取,所以直接在外面多写一个循环,相当于最后统计答案即可

    转移方式同朴素方程

    但是这样是$n^2$的DP,而$n$有20000,那怎么办呢?

    考虑斜率优化。

    首先我们用暴力打表可以发现,决策是单调的,

    打表代码(朴素DP):

     1 #include<bits/stdc++.h>
     2 using namespace std;
     3 #define R register int
     4 #define AC 20100
     5 int n, ans;
     6 int sum[AC], sum_weight[AC], sum_len[AC], f[AC];
     7 int weight[AC], len[AC];
     8 inline int read()
     9 {
    10     int x = 0; char c = getchar();
    11     while(c < '0' || c > '9') c = getchar();
    12     while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    13     return x; 
    14 }
    15 
    16 void pre()
    17 {
    18     n = read();
    19     for(R i = 1; i <= n; i ++)
    20         weight[i] = read(), len[i] = read();
    21 }
    22 
    23 void getsum()
    24 {
    25     for(R i = 1; i <= n + 1; i ++)//山脚的也要求
    26     {
    27         sum_len[i] = sum_len[i - 1] + len[i - 1];
    28         sum_weight[i] = sum_weight[i - 1] + weight[i];
    29         sum[i] = sum[i - 1] + sum_weight[i - 1] * len[i - 1];
    30     //    printf("%d : %d
    ",i,sum[i]);
    31     }
    32 }
    33 
    34 void work()
    35 {
    36     for(R i = 1; i <= n; i ++)
    37     {
    38         int tmp = 0;
    39         f[i] = INT_MAX;
    40         for(R j = 1;j < i;j ++)
    41         {
    42             if(sum[i] - sum_weight[j] * (sum_len[i] - sum_len[j]) < f[i])
    43             {
    44                 f[i] = sum[i] - sum_weight[j] * (sum_len[i] - sum_len[j]);
    45                 tmp = j;
    46             } 
    47         }
    48         printf("%d --- > %d
    ", tmp, i);//打表验证决策单调性
    49     }
    50     ans = INT_MAX;
    51     for(R i = 2; i <= n; i ++)//注意应该是n+1,因为山脚是在下面
    52         ans = min(ans, f[i] + sum[n + 1] - sum[i] - sum_weight[i] * (sum_len[n + 1] - sum_len[i]));
    53     for(R i = 2; i <= n; i ++) printf("%d : %d
    ", i, f[i]);
    54     printf("%d
    ", ans);
    55 }
    56 
    57 int main()
    58 {
    59     freopen("in.in", "r", stdin);
    60     freopen("out.out", "w", stdout);
    61     pre();
    62     getsum();
    63     work();
    64     fclose(stdin);
    65     fclose(stdout);
    66     return 0;
    67 }
    View Code

    于是我们推斜率优化方程:

    设有 $k < j < i$,且$j$优于$k$(相当于$j$是后面来的),则有:

    $sum[i] - sw[j] * (sl[i] - sl[j]) < sum[i] - sw[k] * (sl[i] - sl[k])$

    $sw[j] * (sl[i] - sl[j]) > sw[k] * (sl[i] - sl[k])$

    $sw[j] * sl[i] - sw[j] * sl[j] >  sw[k] * sl[i] - sw[k] * sl[k]$

    $sw[k] * sl[k] - sw[j] * sl[j] > sw[k] * sl[i] - sw[j] * sl[i]$

    $sw[k] * sl[k] - sw[j] * sl[j] > sl[i] * (sw[k] - sw[j])$

    $frac{(sw[k] * sl[k] - sw[j] * sl[j])} {(sw[k] - sw[j])} < sl[i]$ //注意sw[k] - sw[j]小于0,要变号

    所以令$K = frac{(sw[k] * sl[k] - sw[j] * sl[j])}{(sw[k] - sw[j])}$;

    则    while(head < tail && k(q[head],q[head+1]) < sum_len[i])  ++head;

    while(head < tail && k(q[tail-1],q[tail]) > k(q[tail],i)) --tail;

    最后上代码:

     1 #include<bits/stdc++.h>
     2 using namespace std;
     3 #define R register int
     4 #define AC 20100
     5 int n, ans;
     6 int sum[AC], sum_weight[AC], sum_len[AC], f[AC];
     7 int weight[AC], len[AC];
     8 int q[AC], head, tail;
     9 inline int read()
    10 {
    11     int x = 0; char c = getchar();
    12     while(c < '0' || c > '9') c = getchar();
    13     while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    14     return x; 
    15 }
    16 
    17 inline double k(int x, int y)
    18 {
    19     double a = sum_weight[x] * sum_len[x] - sum_weight[y] * sum_len[y];
    20     double b = sum_weight[x] - sum_weight[y];
    21     return a / b;
    22 }
    23 
    24 void pre()
    25 {
    26     n = read();
    27     for(R i = 1; i <= n; i ++) weight[i] = read(), len[i] = read();
    28 }
    29 
    30 void getsum()
    31 {
    32     for(R i = 1; i <= n + 1; i ++)//山脚的也要求
    33     {
    34         sum_len[i] = sum_len[i - 1] + len[i - 1];
    35         sum_weight[i] = sum_weight[i - 1] + weight[i];
    36         sum[i] = sum[i - 1] + sum_weight[i - 1] * len[i - 1];
    37     //    printf("%d : %d
    ",i,sum[i]);
    38     }
    39 }
    40 
    41 void work()
    42 {
    43     head=1;
    44     for(R i = 1; i <= n; i ++)
    45     {
    46         f[i] = INT_MAX;
    47         
    48         /*int tmp = 0;
    49         f[i] = INT_MAX;
    50         for(R j = 1; j < i; j ++)
    51         {
    52             if(sum[i] - sum_weight[j] * (sum_len[i] - sum_len[j]) < f[i])
    53             {
    54                 f[i] = sum[i] - sum_weight[j] * (sum_len[i] - sum_len[j]);
    55                 tmp = j;
    56             } 
    57         }
    58         printf("%d --- > %d
    ", tmp, i);//打表验证决策单调性*/
    59         
    60         while(head < tail && k(q[head], q[head + 1]) < sum_len[i]) ++ head;
    61         int now = q[head];
    62     //    printf("%d --- > %d
    ",now,i);
    63         f[i] = sum[i] - sum_weight[now] * (sum_len[i] - sum_len[now]);
    64         while(head < tail && k(q[tail - 1], q[tail]) > k(q[tail], i)) -- tail;
    65         q[++tail] = i;
    66     }
    67     ans = INT_MAX;
    68     for(R i = 2; i <= n; i ++)//注意应该是n+1,因为山脚是在下面,注意要从2开始,因为这是在枚举第2个厂在哪
    69         ans = min(ans, f[i] + sum[n + 1] - sum[i] - sum_weight[i] * (sum_len[n + 1] - sum_len[i]));
    70     printf("%d
    ", ans);
    71 }
    72 
    73 int main()
    74 {
    75 //    freopen("in.in", "r", stdin);
    76     pre();
    77     getsum();
    78     work();
    79 //    fclose(stdin);
    80     return 0;
    81 }
    View Code

     ---------------2018.10.12--------------优化了代码格式

  • 相关阅读:
    Open source cryptocurrency exchange
    Salted Password Hashing
    95. Unique Binary Search Trees II
    714. Best Time to Buy and Sell Stock with Transaction Fee
    680. Valid Palindrome II
    Java compiler level does not match the version of the installed Java project facet.
    eclipse自动编译
    Exception in thread "main" java.lang.StackOverflowError(栈溢出)
    博客背景美化——动态雪花飘落
    java九九乘法表
  • 原文地址:https://www.cnblogs.com/ww3113306/p/8906890.html
Copyright © 2020-2023  润新知