「清华集训2016」连通子树 (点分治+dfs序dp+虚树)
丧心病狂系列
首先对于会影响答案的点构建虚树,然后点分治+dfs序dp常见套路。。。
点分治+dfs序dp好题:HDU 5909
由于构建虚树之后\(x,y\)之间的点随便选联通块的方案还需要预处理,最好是倍增吧。。
底层是子树随便选,点之间是倍增处理,都需要换根\(\text{dp}\)预处理
算法嵌套大赛
#include<bits/stdc++.h>
using namespace std;
#define reg register
typedef long long ll;
#define rep(i,a,b) for(int i=a,i##end=b;i<=i##end;++i)
#define drep(i,a,b) for(int i=a,i##end=b;i>=i##end;--i)
#define pb push_back
template <class T> inline void cmin(T &a,T b){ if(a>b) a=b; }
template <class T> inline void cmax(T &a,T b){ if(a<b) a=b; }
char IO;
template <class T=int> T rd(){
T s=0;
int f=0;
while(!isdigit(IO=getchar())) if(IO=='-') f=1;
do s=(s<<1)+(s<<3)+(IO^'0');
while(isdigit(IO=getchar()));
return f?-s:s;
}
const int N=1e5+10,P=1e9+7;
int n,m;
vector <int> G[N],V[N],E[N];
int col[N],a,ca,b,cb,c,cc,fa[N][20],dep[N];
int L[N],dfn;
struct Node{
ll s,ls,rs,sum;
Node operator + (const Node __) const {
Node res;
res.s=s*__.s%P;
res.ls=(ls+s*__.ls)%P;
res.rs=(__.s*rs+__.rs)%P;
res.sum=(sum+__.sum+rs*__.ls)%P;
return res;
}
}s[N][20],tmp[N];//倍增,整段选了,区间左边连续选,右边连续选,整段连续随便选
int tmp2[N];
ll dp[N],g[N],up[N],all[N],Idp[N],Iup[N];
// dp子树随便选,up外面随便选
ll qpow(ll x,ll k) {
ll res=1;
for(;k;k>>=1,x=x*x%P) if(k&1) res=res*x%P;
return res;
}
Node Que(int x,int f) {
Node res=(Node){1,0,0,0};
drep(i,18,0) if(dep[fa[x][i]]>dep[f]) res=s[x][i]+res,x=fa[x][i];
return res;
} // x,f路径上的点随便选,不包括x,f
void pre_dfs(int u,int f) {
L[u]=++dfn,dep[u]=dep[fa[u][0]=f]+1;
rep(i,1,18) fa[u][i]=fa[fa[u][i-1]][i-1];
dp[u]=1;
for(int v:G[u]) if(v!=f) {
pre_dfs(v,u);
dp[u]=dp[u]*(dp[v]+1)%P;
g[u]=(g[u]+g[v]+dp[v])%P;
}
}
void redfs(int u,int f) {
if(f) {
ll t=dp[f]*qpow(dp[u]+1,P-2)%P;
s[u][0]=(Node){t,t,t,((t+g[f]-g[u]-dp[u])%P+P)%P};
rep(i,1,18) if(fa[u][i]) s[u][i]=s[fa[u][i-1]][i-1]+s[u][i-1];
}
up[u]=f?(all[f]*qpow(dp[u]+1,P-2)%P):1;
all[u]=up[u]*dp[u]%P;
for(int v:G[u]) if(v!=f) redfs(v,u);
Idp[u]=qpow(dp[u]+1,P-2),Iup[u]=qpow(up[u],P-2);
} // 换根dp预处理
int line[N],cnt,stk[N],top;
int LCA(int x,int y) {
if(dep[x]<dep[y]) swap(x,y);
for(int del=dep[x]-dep[y],i=0;(1<<i)<=del;++i) if(del&(1<<i)) x=fa[x][i];
if(x==y) return x;
drep(i,18,0) if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
void Insert(int u){
int lca=LCA(stk[top],u);
if(lca==stk[top]) { E[u].clear(); stk[++top]=u; return; }
while(top>1 && L[stk[top-1]]>L[lca]) {
E[stk[top-1]].pb(stk[top]);
top--;
}
if(stk[top-1]!=lca) E[lca].clear();
E[lca].pb(stk[top--]);
if(stk[top]!=lca) stk[++top]=lca;
stk[++top]=u,E[u].clear();
}
void Construct(){
sort(line+1,line+cnt+1,[&](const int x,const int y){ return L[x]<L[y]; });
stk[top=1]=1,E[1].clear();
rep(i,1,cnt) if(line[i]!=1) Insert(line[i]);
while(top>1) E[stk[top-1]].pb(stk[top]),top--;
}// 构建虚树
ll ans=0,f[N][6][6][6];
int vis[N],nfa[N]; //虚树上的father
int Up(int u,int f) {
drep(i,18,0) if(dep[fa[u][i]]>dep[f]) u=fa[u][i];
return u;
} // 找到u在f下面最高的点
void dfs(int u,int fa) {
nfa[u]=fa;
vis[u]=0;
if(ca==0 && cb==0 && cc==0) ans=(ans+g[u])%P;
for(int v:E[u]) {
int t=Up(v,u);
if(ca==0 && cb==0 && cc==0) ans=((ans-g[t]-dp[t])%P+P)%P;
}
for(int v:E[u]) {
tmp[v]=Que(v,u),tmp2[v]=Up(v,u);
dfs(v,u);
Node t=Que(v,u);
if(ca==0 && cb==0 && cc==0) ans=(ans+t.sum)%P; // 各种神奇的特判,对拍一年,我死了
E[v].pb(u);
}
}
int mi,rt,sz[N];
void FindRt(int n,int u,int f) {
int ma=n-sz[u];
for(int v:E[u]) if(!vis[v] && v!=f) FindRt(n,v,u),cmax(ma,sz[v]);
if(mi>ma) mi=ma,rt=u;
}
int TL[N],TR[N],nca,ncb,ncc;// 虚树上的dfs序
void dfs2(int u,int f) {
nca+=(col[u]==a),ncb+=(col[u]==b),ncc+=(col[u]==c);
nfa[u]=f;
TL[u]=++cnt,sz[u]=1,line[cnt]=u;
for(int v:E[u]) if(!vis[v] && v!=f) {
dfs2(v,u);
sz[u]+=sz[v];
}
TR[u]=cnt;
}
void Calc(int u){
if(nca<ca || ncb<cb || ncc<cc) return;
rep(i,1,cnt) rep(j,0,ca) rep(k,0,cb) rep(d,0,cc) f[i][j][k][d]=0;
ll x=all[u];
for(int v:E[u]) {
if(dep[v]<dep[u]) x=x*Iup[u]%P;
else x=x*Idp[tmp2[v]]%P;
if(vis[v]) {
if(dep[v]<dep[u]) x=x*(tmp[u].rs+1)%P;
else x=x*(tmp[v].ls+1)%P;
}
}
f[1][col[u]==a][col[u]==b][col[u]==c]=x;
rep(i,2,cnt) {
int u=line[i];
ll x=all[u];
for(int v:E[u]) {
if(dep[v]<dep[u]) x=x*Iup[u]%P;
else x=x*Idp[tmp2[v]]%P;
if(vis[v]) {
if(dep[v]<dep[u]) x=x*(tmp[u].rs+1)%P;
else x=x*(tmp[v].ls+1)%P;
}
}
if(dep[nfa[u]]<dep[u]) {
Node t=tmp[u];
rep(na,0,ca) rep(nb,0,cb) rep(nc,0,cc) if(f[i-1][na][nb][nc]) {
f[TR[u]][na][nb][nc]=(f[TR[u]][na][nb][nc]+f[i-1][na][nb][nc]*(t.ls+1))%P;
f[i][na+(col[u]==a)][nb+(col[u]==b)][nc+(col[u]==c)]=(f[i][na+(col[u]==a)][nb+(col[u]==b)][nc+(col[u]==c)]+f[i-1][na][nb][nc]*t.s%P*x)%P;
}
} else {
Node t=tmp[nfa[u]];
rep(na,0,ca) rep(nb,0,cb) rep(nc,0,cc) if(f[i-1][na][nb][nc]) {
f[TR[u]][na][nb][nc]=(f[TR[u]][na][nb][nc]+f[i-1][na][nb][nc]*(t.rs+1))%P;
f[i][na+(col[u]==a)][nb+(col[u]==b)][nc+(col[u]==c)]=(f[i][na+(col[u]==a)][nb+(col[u]==b)][nc+(col[u]==c)]+f[i-1][na][nb][nc]*t.s%P*x)%P;
}
}
} // dfs序dp
ans=(ans+f[cnt][ca][cb][cc])%P;
}
void Divide(int u) {
nca=ncb=ncc=0;
cnt=0,dfs2(u,0);
Calc(u),vis[u]=1;
for(int v:E[u]) if(!vis[v]) {
mi=1e9,FindRt(sz[v],v,u);
Divide(rt);
}
}
int main(){
n=rd(),m=rd();
rep(i,1,n) {
col[i]=rd();
V[col[i]].pb(i);
}
rep(i,2,n) {
int u=rd(),v=rd();
G[u].pb(v),G[v].pb(u);
}
pre_dfs(1,0),redfs(1,0);
rep(kase,1,m) {
a=rd(),ca=rd(),b=rd(),cb=rd(),c=rd(),cc=rd();
cnt=0;
if(ca>(int)V[a].size() || cb>(int)V[b].size() || cc>(int)V[c].size()) { puts("0"); continue; }
for(int v:V[a]) line[++cnt]=v;
for(int v:V[b]) line[++cnt]=v;
for(int v:V[c]) line[++cnt]=v;
Construct();
ans=0,dfs(1,0);
Divide(1);
printf("%lld\n",ans);
}
}