// A straightforward ("bare") link-cut tree (LCT) exercise.
#include <iostream>
#include <cstdio>
using namespace std;
typedef long long ll;

// Capacity of every per-node array. Node ids used by main are 1..n;
// index 0 acts as the "no node" sentinel.
const int MAXN = 100005;

// n: node count, m: operation count; uu/vv/ww/xx: operands read by main.
int n, m, uu, vv, ww, xx;

int val[MAXN];    // value stored at the node (kept mod `mod`)
int sum[MAXN];    // sum of val over the node's splay subtree (mod `mod`)
int siz[MAXN];    // number of nodes in the node's splay subtree
int tagm[MAXN];   // pending multiply tag for the children (1 = nothing pending)
int taga[MAXN];   // pending add tag, applied after tagm (0 = nothing pending)
int rev[MAXN];    // pending reversal: children not yet swapped
int ch[MAXN][2];  // splay children: [0] left, [1] right (0 = none)
int fa[MAXN];     // splay parent, or path-parent pointer when a splay root

char ss[15];      // operation token read by main

const int mod=51061;
// True when x is the root of its splay tree: its parent link (if any) is a
// path-parent pointer, i.e. the node it points to does not list x as a child.
bool isRoot(int x){
    const int p = fa[x];
    return !(ch[p][0] == x || ch[p][1] == x);
}
// Side of its splay parent that x hangs on: true = right child,
// false = left child (also false when the parent does not list x at all).
bool getW(int x){
    const int p = fa[x];
    return x == ch[p][1];
}
void pushMul(int x, int k){
sum[x] = (ll)sum[x] * k % mod;
val[x] = (ll)val[x] * k % mod;
tagm[x] = (ll)tagm[x] * k % mod;
taga[x] = (ll)taga[x] * k % mod;
}
void pushAdd(int x, int k){
sum[x] = (sum[x] + (ll)siz[x]*k) % mod;
val[x] = (val[x] + k) % mod;
taga[x] = (taga[x] + k) % mod;
}
// Deliver x's pending tags to its splay children, in order: multiply, then add,
// then reversal.
// Multiply goes first because the tags encode "value * tagm + taga"
// (pushMul also scales taga). Null children (index 0) are skipped so the
// sentinel node never accumulates tags.
void pushDown(int x){
    const int l = ch[x][0];
    const int r = ch[x][1];
    if(tagm[x] != 1){
        if(l) pushMul(l, tagm[x]);
        if(r) pushMul(r, tagm[x]);
        tagm[x] = 1;
    }
    // Test the add tag itself. tagm[x] is always 1 here, so testing it would
    // make this branch unconditional.
    if(taga[x]){
        if(l) pushAdd(l, taga[x]);
        if(r) pushAdd(r, taga[x]);
        taga[x] = 0;
    }
    // rev[x] means x's own children are not swapped yet: swap them now and
    // hand the pending reversal down to each child.
    if(rev[x]){
        std::swap(ch[x][0], ch[x][1]);
        if(l) rev[l] ^= 1;
        if(r) rev[r] ^= 1;
        rev[x] = 0;
    }
}
// Push pending tags on every node from x's splay-tree root down to x
// (top-most first), so x's ancestors have correct children before rotating.
// The walk is iterative: the chain of ancestors is collected, then replayed
// in reverse. Its length is bounded by the node count, which the 100005-sized
// global arrays already cap.
void xf(int x){
    static int path[100005];
    int top = 0;
    path[top++] = x;
    for(int y = x; !isRoot(y); y = fa[y])
        path[top++] = fa[y];
    while(top > 0)
        pushDown(path[--top]);
}
void upd(int x){
siz[x] = siz[ch[x][0]] + siz[ch[x][1]] + 1;
sum[x] = (sum[ch[x][0]] + sum[ch[x][1]] + val[x]) % mod;
}
void rotate(int x){
int old=fa[x], oldf=fa[old], w=getW(x);
if(!isRoot(old)) ch[oldf][ch[oldf][1]==old] = x;
ch[old][w] = ch[x][w^1]; ch[x][w^1] = old;
fa[ch[old][w]] = old; fa[ch[x][w^1]] = x; fa[x] = oldf;
upd(old); upd(x);
}
// Bring x to the root of its splay tree, after pushing tags along its path.
// When x has a grandparent, first rotate the parent if both links point the
// same way (zig-zig), else rotate x (zig-zag). Then rotate x itself.
void splay(int x){
    xf(x);
    for(; !isRoot(x); rotate(x)){
        const int f = fa[x];
        if(!isRoot(f))
            rotate(getW(x) == getW(f) ? f : x);
    }
}
// Make the path from the represented-tree root down to x a single splay tree.
// Walk up through path-parent links. At each node, splay it and replace its
// right child with the splay tree built so far.
void access(int x){
    for(int prev = 0; x != 0; prev = x, x = fa[x]){
        splay(x);
        ch[x][1] = prev;
        upd(x);
    }
}
// Re-root x's represented tree at x.
// Expose the root-to-x path, splay x to its top, then flag the whole path for
// reversal (the child swap itself happens lazily in pushDown).
void makeRoot(int x)
{
    access(x);
    splay(x);
    rev[x] ^= 1;
}
// Add the tree edge (u, v): re-root u's tree at u, then hang it below v
// through a path-parent pointer. Assumes u and v are in different trees.
void link(int u, int v)
{
    makeRoot(u);
    fa[u] = v;
}
// Isolate the u-v path (u and v assumed connected) as one splay tree rooted at v.
// Afterwards sum[v]/siz[v] aggregate exactly that path, and a tag applied to v
// covers exactly the path's nodes.
void splitLine(int u, int v)
{
    makeRoot(u);   // u becomes the represented root...
    access(v);     // ...so root-to-v is the u-v path
    splay(v);
}
void cut(int u, int v){
splitLine(u, v);
ch[v][0] = fa[u] = 0;
}
int main(){
cin>>n>>m;
for(int i=1; i<=n; i++)
tagm[i] = val[i] = siz[i] = 1;
for(int i=1; i<n; i++){
scanf("%d %d", &uu, &vv);
link(uu, vv);
}
while(m--){
scanf("%s", ss);
if(ss[0]=='+'){
scanf("%d %d %d", &uu, &vv, &ww);
splitLine(uu, vv);
pushAdd(vv, ww);
}
else if(ss[0]=='-'){
scanf("%d %d %d %d", &uu, &vv, &ww, &xx);
cut(uu, vv);
link(ww, xx);
}
else if(ss[0]=='*'){
scanf("%d %d %d", &uu, &vv, &ww);
splitLine(uu, vv);
pushMul(vv, ww);
}
else{
scanf("%d %d", &uu, &vv);
splitLine(uu, vv);
printf("%d
", sum[vv]);
}
}
return 0;
}