• 洛谷 P3676 小清新数据结构题


    https://www.luogu.org/problemnew/show/P3676

    这题被我当成动态dp去做了,码了4k,搞了一个换根的动态dp

      1 #include<cstdio>
      2 #include<algorithm>
      3 #include<cstring>
      4 using namespace std;
      5 typedef long long ll;
      6 struct E
      7 {
      8     int to,nxt;
      9 }e[400011];
     10 int f1[200011],ne;
     11 struct P1
     12 {
     13     int len;ll a,b,c,d,e,f;
     14     //长度,(点权)和,后缀和之和,后缀和的平方之和,(答案)和
     15     //前缀和之和,前缀和的平方之和
     16 };
     17 struct P2
     18 {
     19     ll a,b;
     20     //点权和,答案和
     21 };
     22 ll a[200101];
     23 int sz[200101],hson[200101],ff[200101];
     24 int b[200101],pl[200101];
     25 int n,m;
     26 inline void merge(P1 &c,const P1 &a,const P1 &b)
     27 {
     28     c.len=a.len+b.len;
     29     c.a=a.a+b.a;
     30     c.b=b.b+a.b+b.a*a.len;
     31     c.c=b.c+b.a*b.a*a.len+2*a.b*b.a+a.c;
     32     c.d=a.d+b.d;
     33     c.e=a.e+b.e+a.a*b.len;
     34     c.f=a.f+a.a*a.a*b.len+2*b.e*a.a+b.f;
     35 }
     36 inline void initnode(P1 &c,const P2 &a)
     37 {
     38     c.len=1;c.a=c.b=c.e=a.a;c.c=c.f=a.a*a.a;c.d=a.b;
     39 }
     40 namespace S
     41 {
     42 #define lc (num<<1)
     43 #define rc (num<<1|1)
     44     P1 d[800101];
     45     inline void upd(int num){merge(d[num],d[lc],d[rc]);}
     46     P1 x;int L;
     47     void _setx(int l,int r,int num)
     48     {
     49         if(l==r)
     50         {
     51             d[num]=x;
     52             return;
     53         }
     54         int mid=(l+r)>>1;
     55         if(L<=mid)    _setx(l,mid,lc);
     56         else    _setx(mid+1,r,rc);
     57         upd(num);
     58     }
     59     P1 getx(int L,int R,int l,int r,int num)
     60     {
     61         if(L<=l&&r<=R)    return d[num];
     62         int mid=(l+r)>>1;
     63         if(L<=mid&&mid<R)
     64         {
     65             P1 x;
     66             merge(x,getx(L,R,l,mid,lc),getx(L,R,mid+1,r,rc));
     67             return x;
     68         }
     69         else if(L<=mid)
     70             return getx(L,R,l,mid,lc);
     71         else if(mid<R)
     72             return getx(L,R,mid+1,r,rc);
     73         else
     74             exit(-1);
     75     }
     76 }
     77 void dfs1(int u,int fa)
     78 {
     79     sz[u]=1;
     80     for(int v,k=f1[u];k;k=e[k].nxt)
     81         if(e[k].to!=fa)
     82         {
     83             v=e[k].to;
     84             ff[v]=u;
     85             dfs1(v,u);
     86             sz[u]+=sz[v];
     87             if(sz[v]>sz[hson[u]])    hson[u]=v;
     88         }
     89 }
     90 P2 d1[200101];//d1[i]维护i节点及其轻儿子的贡献
     91 P2 d2[200101];//d2[i]维护i节点(是重链顶)所在重链的值
     92 int tp[200101],dwn[200101];//链顶,链底
     93 inline void updd1(int x)
     94 {
     95     initnode(S::x,d1[x]);S::L=pl[x];S::_setx(1,n,1);
     96 }
     97 void dfs2(int u,int fa)
     98 {
     99     d1[u].a=a[u];
    100     b[++b[0]]=u;pl[u]=b[0];
    101     tp[u]=(u==hson[fa])?tp[fa]:u;
    102     if(hson[u])    dfs2(hson[u],u);
    103     dwn[u]=hson[u]?dwn[hson[u]]:u;
    104     int v,k;
    105     for(k=f1[u];k;k=e[k].nxt)
    106         if(e[k].to!=fa&&e[k].to!=hson[u])
    107         {
    108             v=e[k].to;
    109             dfs2(v,u);
    110             d1[u].b+=d2[v].b;
    111             d1[u].a+=d2[v].a;
    112         }
    113     updd1(u);
    114     if(u==tp[u])
    115     {
    116         P1 t=S::getx(pl[u],pl[dwn[u]],1,n,1);
    117         d2[u].a=t.a;d2[u].b=t.d+t.c;
    118     }
    119 }
    120 inline ll getsize(int x)
    121 {
    122     return S::getx(pl[x],pl[dwn[x]],1,n,1).a;
    123 }
    124 int main()
    125 {
    126     int i,x,y,idx;ll z,ans,szall;P1 t;
    127     scanf("%d%d",&n,&m);
    128     for(i=1;i<n;++i)
    129     {
    130         scanf("%d%d",&x,&y);
    131         e[++ne].to=y;e[ne].nxt=f1[x];f1[x]=ne;
    132         e[++ne].to=x;e[ne].nxt=f1[y];f1[y]=ne;
    133     }
    134     for(i=1;i<=n;++i)    scanf("%lld",a+i);
    135     dfs1(1,0);
    136     dfs2(1,0);
    137     while(m--)
    138     {
    139         scanf("%d",&idx);
    140         if(idx==1)
    141         {
    142             scanf("%d%lld",&x,&z);
    143             d1[x].a-=a[x];a[x]=z;d1[x].a+=z;
    144             while(x)
    145             {
    146                 updd1(x);
    147                 x=tp[x];y=ff[x];
    148                 t=S::getx(pl[x],pl[dwn[x]],1,n,1);
    149                 d1[y].a-=d2[x].a;d1[y].b-=d2[x].b;
    150                 d2[x].a=t.a;d2[x].b=t.d+t.c;
    151                 d1[y].a+=d2[x].a;d1[y].b+=d2[x].b;
    152                 x=y;
    153             }
    154             //printf("3t%d
    ",d2[1].b);
    155         }
    156         else
    157         {
    158             scanf("%d",&x);
    159             ans=d2[1].b;
    160             szall=getsize(1);
    161             if(x!=tp[x])
    162             {
    163                 y=tp[x];
    164                 z=d1[y].a;
    165                 d1[y].a+=szall-getsize(y);
    166                 updd1(y);
    167                 if(y!=dwn[y])
    168                 {
    169                     t=S::getx(pl[y]+1,pl[dwn[y]],1,n,1);
    170                     ans-=t.c;
    171                 }
    172                 if(x!=dwn[y])
    173                 {
    174                     t=S::getx(pl[x]+1,pl[dwn[y]],1,n,1);
    175                     ans+=t.c;
    176                 }
    177                 t=S::getx(pl[y],pl[x]-1,1,n,1);
    178                 ans+=t.f;
    179                 d1[y].a=z;
    180                 updd1(y);
    181                 x=y;
    182             }
    183             while(x!=1)
    184             {
    185                 y=ff[x];
    186                 z=getsize(x);
    187                 ans-=z*z;
    188                 z=szall-z;
    189                 ans+=z*z;
    190                 x=y;
    191                 if(x!=tp[x])
    192                 {
    193                     y=tp[x];
    194                     z=d1[y].a;
    195                     d1[y].a+=szall-getsize(y);
    196                     updd1(y);
    197                     if(y!=dwn[y])
    198                     {
    199                         t=S::getx(pl[y]+1,pl[dwn[y]],1,n,1);
    200                         ans-=t.c;
    201                     }
    202                     if(x!=dwn[y])
    203                     {
    204                         t=S::getx(pl[x]+1,pl[dwn[y]],1,n,1);
    205                         ans+=t.c;
    206                     }
    207                     t=S::getx(pl[y],pl[x]-1,1,n,1);
    208                     ans+=t.f;
    209                     d1[y].a=z;
    210                     updd1(y);
    211                     x=y;
    212                 }
    213             }
    214             printf("%lld
    ",ans);
    215         }
    216     }
    217     return 0;
    218 }
    View Code

    码完一看题解,???好像画风不太对??

    所以还是无视上面那个代码吧...

    正常得多的做法:

      1 #include<cstdio>
      2 #include<algorithm>
      3 using namespace std;
      4 typedef long long ll;
      5 struct E
      6 {
      7     int to,nxt;
      8 }e[400011];
      9 int f1[200011],ne;
     10 int n,m;
     11 struct S
     12 {
     13 #define lowbit(x) ((x)&(-x))
     14     ll d1[200011],d2[200011];
     15     void _add(int p,ll x,ll *d)
     16     {
     17         for(;p<=n;p+=lowbit(p))
     18             d[p]+=x;
     19     }
     20     ll _sum(int p,ll *d)
     21     {
     22         ll ans=0;
     23         for(;p>0;p-=lowbit(p))
     24             ans+=d[p];
     25         return ans;
     26     }
     27     void add(int l,int r,ll x)
     28     {
     29         _add(l,x,d1);
     30         _add(r+1,-x,d1);
     31         _add(l,x*l,d2);
     32         _add(r+1,-x*(r+1),d2);
     33     }
     34     ll sum(int l,int r)
     35     {
     36         return (r+1)*_sum(r,d1)-_sum(r,d2)
     37             -l*_sum(l-1,d1)+_sum(l-1,d2);
     38     }
     39 }s1;
     40 int b[200011],pl[200011];
     41 ll a[200011],a2[200011];
     42 int sz[200011],hson[200011],tp[200011];
     43 ll dep[200011];
     44 int ff[200011];
     45 void dfs1(int u,int fa)
     46 {
     47     sz[u]=1;
     48     for(int k=f1[u];k;k=e[k].nxt)
     49         if(e[k].to!=fa)
     50         {
     51             ff[e[k].to]=u;
     52             dep[e[k].to]=dep[u]+1;
     53             dfs1(e[k].to,u);
     54             sz[u]+=sz[e[k].to];
     55             if(sz[e[k].to]>sz[hson[u]])    hson[u]=e[k].to;
     56         }
     57 }
     58 void dfs2(int u,int fa)
     59 {
     60     b[++b[0]]=u;pl[u]=b[0];
     61     tp[u]=u==hson[fa]?tp[fa]:u;
     62     a2[u]=a[u];
     63     if(hson[u])
     64     {
     65         dfs2(hson[u],u);
     66         a2[u]+=a2[hson[u]];
     67     }
     68     for(int k=f1[u];k;k=e[k].nxt)
     69         if(e[k].to!=fa&&e[k].to!=hson[u])
     70         {
     71             dfs2(e[k].to,u);
     72             a2[u]+=a2[e[k].to];
     73         }
     74 }
     75 inline ll gsum1(int x)//x到1的路径和
     76 {
     77     int y;ll an=0;
     78     for(;x;x=ff[y])
     79     {
     80         y=tp[x];
     81         an+=s1.sum(pl[y],pl[x]);
     82     }
     83     return an;
     84 }
     85 inline void add1(int x,ll z)//x到1加上z
     86 {
     87     int y;
     88     for(;x;x=ff[y])
     89     {
     90         y=tp[x];
     91         s1.add(pl[y],pl[x],z);
     92     }
     93 }
     94 ll anss;
     95 int main()
     96 {
     97     ll ans,z,t;
     98     int i,x,y,idx;
     99     scanf("%d%d",&n,&m);
    100     for(i=1;i<n;++i)
    101     {
    102         scanf("%d%d",&x,&y);
    103         e[++ne].to=y;e[ne].nxt=f1[x];f1[x]=ne;
    104         e[++ne].to=x;e[ne].nxt=f1[y];f1[y]=ne;
    105     }
    106     for(i=1;i<=n;++i)
    107         scanf("%lld",a+i);
    108     dfs1(1,0);
    109     dfs2(1,0);
    110     for(i=1;i<=n;++i)
    111     {
    112         s1.add(pl[i],pl[i],a2[i]);
    113         anss+=a2[i]*a2[i];
    114     }
    115     while(m--)
    116     {
    117         scanf("%d",&idx);
    118         if(idx==1)
    119         {
    120             scanf("%d%lld",&x,&z);
    121             z=z-a[x];a[x]+=z;
    122             anss+=z*z*(dep[x]+1);
    123             anss+=2*gsum1(x)*z;
    124             add1(x,z);
    125         }
    126         else
    127         {
    128             scanf("%d",&x);
    129             ans=anss;
    130             t=gsum1(1);
    131             ans+=dep[x]*t*t;
    132             ans-=2*t*(gsum1(x)-t);
    133             printf("%lld
    ",ans);
    134         }
    135     }
    136     return 0;
    137 }
    View Code
  • 相关阅读:
    匿名对象
    JAVA中的方法重载 (参数个数不同,顺序不同,类型不同)
    构造方法的返回值和void 的区别
    一些小算法技巧
    Java基础总结(一)
    Struts2 Intercepter 笔记
    js Dom 编程
    The Bug and Exception of Hibernate
    包--R In Action
    --三种方法查询人所在部门平均工资
  • 原文地址:https://www.cnblogs.com/hehe54321/p/10198047.html
Copyright © 2020-2023  润新知