题意:
题解:
去年D1T2,呵呵,被pei死,菜鸡的我考场上第一次做图论题,愉快爆零;
链剖+转化等式;
链剖就是求个lca和路径长度;
首先考虑链的情况:
先假设s在t的上面,那么对于一个点i,能对它产生贡献的路径一定满足i-s==w[i] && t>=i,于是我们设k[i]=i-w[i],cnt[i]表示以i为起点的路径条数,那么从上往下遍历,到点i先计算以点i为起点的贡献,再计算点i的答案,即访问cnt[k[i]],注意到t>=i,所以到i后cnt[以i为终点的起点]--,这里用vector存一下即可;
对于s在t下面的类似,只是k[i]=deep[i]+w[i];
正解和链很类似:
考虑将一条路径划分为两条链:s->lca,t->lca,然后我们就可以用类似于链的做法;
对于s->lca的一条链,我们只考虑s->t这条路径对s->lca这条链上的点的贡献,则对i有贡献需满足deep[s]-deep[i]w[i],移项之后又有k[i]=deep[i]+w[i]deep[s],于是在s上打+1标记,在lca上打-1标记 ,选1为根dfs,对每个点统计一下答案即可,但是有一点需要注意,与链的情况不同的是,这里的deep[s]是s的深度,可能会把别的子树的贡献计入答案,于是我们进入这个点时记一下这个点cnt[k[i]]的last,然后回溯后cnt[k[i]]-last就是对这个点产生的贡献。
t->lca的情况类似,只是k[i]=deep[i]-w[i]==deep[t]-len(len为路径长度)。
总结:对某个等式进行分析时,对等式进行化归,相关联的变量可划到等式的一边,从而找出各变量之间的线性关系。
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#define ll long long
#define N 300010
#define M 300000
using namespace std;
int n,m,e_num,deep;
int nxt[N<<1],to[N<<1],h[N];
int fa[N],dep[N],top[N],siz[N],son[N],w[N],val[N],ans[N],cnt[N<<1];
struct Node {int s,t,lca,len;}p[N];
vector<int> v1[N],v2[N],v3[N];
int gi() {
int x=0,o=1; char ch=getchar();
while(ch!='-' && (ch<'0' || ch>'9')) ch=getchar();
if(ch=='-') o=-1,ch=getchar();
while(ch>='0' && ch<='9') x=x*10+ch-'0',ch=getchar();
return o*x;
}
void add(int x, int y) {
nxt[++e_num]=h[x],to[e_num]=y,h[x]=e_num;
}
void dfs1(int u) {
siz[u]=1;
for(int i=h[u]; i; i=nxt[i]) {
int v=to[i];
if(v==fa[u]) continue;
fa[v]=u,dep[v]=dep[u]+1;
dfs1(v);
if(siz[v]>siz[son[u]]) son[u]=v;
siz[u]+=siz[v];
}
}
void dfs2(int u) {
if(son[u]) top[son[u]]=top[u],dfs2(son[u]);
for(int i=h[u]; i; i=nxt[i]) {
int v=to[i];
if(v==fa[u] || v==son[u]) continue;
top[v]=v,dfs2(v);
}
}
int lca(int x, int y) {
while(top[x]!=top[y]) {
if(dep[top[x]]>dep[top[y]]) x=fa[top[x]];
else y=fa[top[y]];//hehe,chain divide was wrong
}
if(dep[x]<dep[y]) return x;
else return y;
}
void sol1(int u) {
int last,num=dep[u]+w[u];
if(num<=deep) last=cnt[num];
for(int i=h[u]; i; i=nxt[i]) {
if(to[i]!=fa[u]) sol1(to[i]);
}
cnt[dep[u]]+=val[u];
if(num<=deep) ans[u]+=cnt[num]-last;
for(int i=0; i<v1[u].size(); i++) cnt[v1[u][i]]--;
}
void sol2(int u) {
int last,num=dep[u]-w[u]+M;
last=cnt[num];
for(int i=h[u]; i; i=nxt[i]) {
if(to[i]!=fa[u]) sol2(to[i]);
}
for(int i=0; i<v2[u].size(); i++) cnt[v2[u][i]+M]++;
ans[u]+=cnt[num]-last;
for(int i=0; i<v3[u].size(); i++) cnt[v3[u][i]+M]--;
}
int main() {
n=gi(),m=gi();
for(int i=1; i<n; i++) {
int x=gi(),y=gi();
add(x,y),add(y,x);
}
fa[1]=1,dep[1]=1,top[1]=1;
dfs1(1),dfs2(1);
for(int i=1; i<=n; i++) w[i]=gi(),deep=max(deep,dep[i]);
for(int i=1; i<=m; i++) {
p[i].s=gi(),p[i].t=gi();
p[i].lca=lca(p[i].s,p[i].t);
p[i].len=dep[p[i].s]+dep[p[i].t]-2*dep[p[i].lca];
val[p[i].s]++;
v1[p[i].lca].push_back(dep[p[i].s]);
v2[p[i].t].push_back(dep[p[i].t]-p[i].len);
v3[p[i].lca].push_back(dep[p[i].t]-p[i].len);
}
sol1(1);
memset(cnt,0,sizeof(cnt));
sol2(1);
for(int i=1; i<=m; i++) {
if(dep[p[i].s]==dep[p[i].lca]+w[p[i].lca]) ans[p[i].lca]--;
}
for(int i=1; i<=n; i++) printf("%d ", ans[i]);
return 0;
}
附上部分分的代码,没有暴力跳路径和T=1的点
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<cmath>
#define ll long long
#define N 300010
using namespace std;
int n,m,e_num;
int nxt[N<<1],to[N<<1],h[N],siz[N],dep[N],w[N],cnt[N],inx[N],v1[N],v2[100010][100],res[N];
bool flg1=1,flg2=1,flg3=1,flg4=1,flg5=1;
struct Node {int s,t;}p[N];
int gi() {
int x=0,o=1; char ch=getchar();
while(ch!='-' && (ch<'0' || ch>'9')) ch=getchar();
if(ch=='-') o=-1,ch=getchar();
while(ch>='0' && ch<='9') x=x*10+ch-'0',ch=getchar();
return o*x;
}
void add(int x, int y) {
nxt[++e_num]=h[x],to[e_num]=y,h[x]=e_num;
}
void dfs(int u, int fa) {
for(int i=h[u]; i; i=nxt[i]) {
int v=to[i];
if(v==fa) continue;
dep[v]=dep[u]+1;
dfs(v,u);
siz[u]+=siz[v];
}
}
void work1() {
for(int i=1; i<=m; i++) cnt[p[i].s]++;
for(int i=1; i<=n; i++) {
if(!w[i]) printf("%d ", cnt[i]);
else printf("0 ");
}
}
void work2() {
for(int i=1; i<=m; i++) cnt[p[i].s]++;
for(int i=1; i<=n; i++) {
printf("%d ", cnt[i]);
}
}
void work3() {
for(int i=1; i<=m; i++) {
if(p[i].s<=p[i].t) {
v1[p[i].s]++,v2[p[i].t][++v2[p[i].t][0]]=p[i].s;
}
}
for(int i=1; i<=n; i++) {
cnt[i]+=v1[i];
if(i-w[i]>=1) res[i]+=cnt[i-w[i]];
for(int j=1; j<=v2[i][0]; j++) cnt[v2[i][j]]--;
}
memset(v1,0,sizeof(v1));
memset(v2,0,sizeof(v2));
memset(cnt,0,sizeof(cnt));
for(int i=1; i<=m; i++) {
if(p[i].s>p[i].t)
v1[p[i].s]++,v2[p[i].t][++v2[p[i].t][0]]=p[i].s;
}
for(int i=n; i>=1; i--) {
cnt[i]+=v1[i];
if(i+w[i]<=n) res[i]+=cnt[i+w[i]];
for(int j=1; j<=v2[i][0]; j++) cnt[v2[i][j]]--;
}
for(int i=1; i<=n; i++)
printf("%d ", res[i]);
}
void work4() {
for(int i=1; i<=m; i++) siz[p[i].t]++;
dfs(1,0);
for(int i=1; i<=n; i++) {
if(dep[i]==w[i]) cnt[i]=siz[i];
printf("%d ", cnt[i]);
}
}
int main() {
n=gi(),m=gi();
for(int i=1; i<n; i++) {
int x=gi(),y=gi();
add(x,y),add(y,x);
inx[x]++,inx[y]++;
if(inx[x]>2 || inx[y]>2) flg3=0;
}
for(int i=1; i<=n; i++) {
w[i]=gi();
if(w[i]) flg2=0;
}
for(int i=1; i<=m; i++) {
int s=gi(),t=gi();
if(s!=t) flg1=0;
if(s!=1) flg4=0;
if(t!=1) flg5=0;
p[i]=(Node){s,t};
}
if(flg1) work1();
if(flg2) work2();
if(flg3) work3();
if(flg4) work4();
return 0;
}