没想到我也能做出来果然(网络的力量)强大啊。
#include<bits/stdc++.h>
#define maxx 305000
#define mod 51061
#define ll unsigned int
using namespace std;
ll c[maxx][2],val[maxx],fa[maxx],sum[maxx],mul[maxx],add[maxx],siz[maxx];
bool rev[maxx];
struct node{
ll st[maxx];
ll tp;
bool empty(){return tp==0;}
void pop(){st[tp--]=0;}
void push(ll x){st[++tp]=x;}
ll top(){return st[tp];}
}s;
void cal(ll x,ll c,ll j)
{
if(!x)return;
val[x]=(val[x]*c+j)%mod;
sum[x]=(sum[x]*c+j*siz[x])%mod;
add[x]=(add[x]*c+j)%mod;
mul[x]=(mul[x]*c)%mod;
}
bool pdroot(ll x){return (c[fa[x]][0]==x||c[fa[x]][1]==x);}
void put_up(ll x)
{ sum[x]=(sum[c[x][0]]+sum[c[x][1]]+val[x])%mod;
siz[x]=siz[c[x][0]]+siz[c[x][1]]+1;
}
void zychange(ll x)
{
rev[x]^=1;
swap(c[x][0],c[x][1]);
}
void put_down(ll x)
{
if(rev[x])
{
if(c[x][0])zychange(c[x][0]);
if(c[x][1])zychange(c[x][1]);
rev[x]=0;
}
cal(c[x][0],mul[x],add[x]),cal(c[x][1],mul[x],add[x]);
mul[x]=1,add[x]=0;
}
void rot(ll x)
{
ll y=fa[x],z=fa[y];
ll l=(c[y][1]==x),r=l^1;
if(pdroot(y))c[z][c[z][1]==y]=x;
fa[x]=z;
fa[y]=x;
fa[c[x][r]]=y;
c[y][l]=c[x][r];c[x][r]=y;//注意 这个语句顺序不能换!!(卡了我半小时……)
put_up(y);
put_up(x);
}
void splay(ll x)
{
ll y=x;
s.push(y);
while(pdroot(y))
{
y=fa[y];
s.push(y);
}
while(!s.empty())
{
put_down(s.top());
s.pop();
}
while(pdroot(x))
{
ll y=fa[x],z=fa[y];
ll l=(c[y][1]==x),l1=(c[z][1]==y);
if(pdroot(y))
{
if(l^l1)rot(x);
else rot(y);
}
rot(x);
}
put_up(x);
}
void access(ll x)
{
for(ll y=0;x;x=fa[y=x])
{
splay(x);
c[x][1]=y;
put_up(x);
}
}
void makeroot(ll x)
{
access(x);
splay(x);
zychange(x);
}
void split(ll x,ll y)
{
makeroot(x);
access(y);
splay(y);
}
int findroot(ll x)
{
access(x);
splay(x);
while(c[x][0])
{
put_down(x);
x=c[x][0];
}
return x;
}
void link(ll x,ll y)
{
makeroot(x);
if(findroot(y)!=x)fa[x]=y;
}
void cut(ll x,ll y)
{
makeroot(x);
if(findroot(y)==x&&fa[x]==y&&!c[x][1])
{
fa[x]=c[y][0]=0;
put_up(y);
}
}
int main()
{
ll n;
char ch[20];
ll x,y;
scanf("%d",&n);
ll q;
ll u1,u2,v1,c,v2;
scanf("%d",&q);
for(ll i=1;i<=n;i++) val[i]=mul[i]=siz[i]=sum[i]=1;
for(ll i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
link(x,y);
}
for(q;q;q--)
{
scanf("%s",ch);
if(ch[0]=='+')
{
scanf("%d%d%d",&u1,&v1,&c);
split(v1,u1);
cal(u1,1,c);
}
if(ch[0]=='*')
{
scanf("%d%d%d",&u1,&v1,&c);
split(v1,u1);
cal(u1,c,0);
}
if(ch[0]=='-')
{
scanf("%d%d%d%d",&u1,&v1,&u2,&v2);
cut(u1,v1);
link(u2,v2);
}
if(ch[0]=='/')
{
scanf("%d%d",&u1,&v1);
split(v1,u1);
printf("%d
",sum[u1]);
}
}
return 0;
}