题目描述
master 对树上的求和非常感兴趣。他生成了一棵有根树,并且希望多次询问这棵树上一段路径上所有节点深度的k 次方和,而且每次的k可能是不同的。此处节点深度的定义是这个节点到根的路径上的边数。他把这个问题交给了pupil,但pupil 并不会这么复杂的操作,你能帮他解决吗?
输入格式
第一行包含一个正整数n,表示树的节点数。
之后n-1行每行两个空格隔开的正整数i, j,表示树上的一条连接点ii 和点jj 的边。
之后一行一个正整数m,表示询问的数量。
之后每行三个空格隔开的正整数i, j, k,表示询问从点i到点j的路径上所有节点深度的k次方和。由于这个结果可能非常大,输出其对998244353取模的结果。
树的节点从1开始标号,其中1号节点为树的根。
输出格式
对于每组数据输出一行一个正整数表示取模后的结果。
数据范围
对于30%的数据,1≤n,n≤100
对于60% 的数据,1≤n,m≤1000。
对于100%的数据,1≤n,m≤300000,1≤k≤50。
考虑30分的做法。先处理出深度,设dep(x)表示x的深度,那么dep(x)=dep(fa[x])+1。对于每次询问,我们可以暴力从询问的两个点LCA走,每次求出当前点的深度的k次方并加入答案中即可。求LCA可以用Tarjan做到O(N+M),每次询问可以做到O(NK),总共就是O(MNK+N+M)≈O(MNK),如果用倍增或者树剖求LCA就是O(M(NK+logN))。这种三次方的级别也就只能过30分了......
考虑60分的做法。根据k的范围为1~50,我们可以预处理出每个点的深度的1~50次方,设idep(x,i)表示x的深度的i次方,那么idep(x,i)=idep(x,i-1) * dep(x),初始化idep(x,0)=1,这个复杂度为O(NK)。然后每次询问还是一个个往上走并加起来即可。若用Tarjan求LCA,总共就是O(NK+MN+N+M)≈O((K+M)N),否则为O((K+M)N+MlogN),降到了平方级别。
然后考虑100分做法。既然每次都要一个个加上,我们为什么不用树上前缀和直接加一加减一减得出答案呢?设val(x,i)为根节点到x的路径上所有点深度的k次方和,那么val(x,i)=val(fa[x],i)+idep(x,i)。时间复杂度还是O(NK)。接下来对于每个询问,首先求出两个点u,v的LCA,设LCA(u,v)=w,那么ans=val(u,k)+val(v,k)-val(w,k)-val(fa[w],k),最后的val(fa[w],k)是因为LCA只计算一次。若用Tarjan求LCA,每次询问的复杂度就是O(1),总共就是O(NK+M+N)≈O(NK+M),否则每次询问就是O(logN),总共就是O(NK+MlogN)。表面上还是平方级别,但已经相比之前60分做法优化掉了一项NM,剩下的NK和MlogN由于K≤50,可以看成是log级别的,所以就是O(NlogN)级别的算法,分肯定拿满。
附上代码,LCA是用树剖求的
#include<iostream>
#include<cstring>
#include<cstdio>
#define maxn 300001
#define maxk 51
#define p 998244353
using namespace std;
struct edge{
int to,next;
edge(){}
edge(const int &_to,const int &_next){ to=_to,next=_next; }
}e[maxn<<1];
int head[maxn],k;
long long dep[maxn],idep[maxn][maxk],val[maxn][maxk];
int size[maxn],fa[maxn],son[maxn],top[maxn];
int n,m;
inline int read(){
register int x(0),f(1); register char c(getchar());
while(c<'0'||'9'<c){ if(c=='-') f=-1; c=getchar(); }
while('0'<=c&&c<='9') x=(x<<1)+(x<<3)+(c^48),c=getchar();
return x*f;
}
inline void add(const int &u,const int &v){
e[k]=edge(v,head[u]);
head[u]=k++;
}
void dfs_getson(int u){
size[u]=1;
for(register int i=head[u];~i;i=e[i].next){
int v=e[i].to;
if(v==fa[u]) continue;
dep[v]=dep[u]+1,fa[v]=u;
idep[v][0]=1;
for(register int j=1;j<maxk;j++) idep[v][j]=idep[v][j-1]*dep[v]%p;
for(register int j=1;j<maxk;j++) val[v][j]=(val[u][j]+idep[v][j])%p;
dfs_getson(v);
size[u]+=size[v];
if(size[v]>size[son[u]]) son[u]=v;
}
}
void dfs_rewrite(int u,int tp){
top[u]=tp;
if(son[u]) dfs_rewrite(son[u],tp);
for(register int i=head[u];~i;i=e[i].next){
int v=e[i].to;
if(v!=son[u]&&v!=fa[u]) dfs_rewrite(v,v);
}
}
inline int lca(int u,int v){
while(top[u]!=top[v]){
if(dep[top[u]]>dep[top[v]]) swap(u,v);
v=fa[top[v]];
}
if(dep[u]>dep[v]) swap(u,v);
return u;
}
int main(){
memset(head,-1,sizeof head);
n=read();
for(register int i=1;i<n;i++){
int u=read(),v=read();
add(u,v),add(v,u);
}
dfs_getson(1);
dfs_rewrite(1,1);
m=read();
while(m--){
int u=read(),v=read(),t=read();
int w=lca(u,v);
printf("%lld
",(val[u][t]+val[v][t]+(p<<1)-val[w][t]-val[fa[w]][t])%p);
}
return 0;
}
*最后的地方相减时要判负数......或者直接加上两个p得了