题解
首先当(k=1)的时候肥肠简单
就是按照(x)从小到大排序
每处理到一个(x),就把(1 o x)的路径上的点都+1
然后查询(y)的时候就查询(1 o y)的点权和
那么(k>1)的时候也一样
对于深度为(i)的节点,给ta加上(i^k-(i-1)^k)即可
最后查询的时候查询点u的贡献就是$val[u] imes ( dep[u]k-(dep[u]-1)k ) ( 可以用线段树维护 先记录每个节点的)(dep[u]k-(dep[u]-1)k)(值 然后每次就是给一些节点的这个东西前面加一个系数 所以线段树维护一个区间)(dep[u]k-(dep[u]-1)k)(之和来方便维护区间)val[u] imes (dep[u]k-(dep[u]-1)k)$之和
代码
#include<vector>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
const int M = 50005 ;
const int mod = 998244353 ;
using namespace std ;
inline int read() {
char c = getchar() ; int x = 0 , w = 1 ;
while(c>'9'||c<'0') { if(c=='-') w = -1 ; c = getchar() ; }
while(c>='0'&&c<='9') { x = x*10+c-'0' ; c = getchar() ; }
return x*w ;
}
int n , m , k , cnt ;
int tp[M] , fa[M] , p[M] , size[M] ;
int ans[M] , dep[M] , w[M] , son[M] ;
int id[M] , val[M] , top[M] ;
struct Q { int x , y , idx ; } q[M] ;
vector < int > vec[M] ;
inline bool operator < (Q a , Q b) { return a.x < b.x ; }
struct Node { int Bsm , sum , Tag ; } t[M * 4] ;
inline int Fpw(int Base , int k) {
int temp = 1 ;
while(k) {
if(k & 1) temp = 1LL * temp * Base % mod ;
Base = 1LL * Base * Base % mod ; k >>= 1 ;
}
return temp ;
}
void dfs1(int u , int father) {
dep[u] = dep[father] + 1 ; size[u] = 1 ; int mx = -1 ;
for(int i = 0 , v ; i < vec[u].size() ; i ++) {
v = vec[u][i] ; dfs1(v , u) ;
size[u] += size[v] ; if(size[v] > mx) mx = size[v] , son[u] = v ;
}
}
void dfs2(int u , int topf) {
top[u] = topf ; id[u] = ++ cnt ; val[cnt] = w[u] ;
if(!son[u]) return ; dfs2(son[u] , topf) ;
for(int i = 0 , v ; i < vec[u].size() ; i ++) {
v = vec[u][i] ;
if(!id[v]) dfs2(v , v) ;
}
}
# define ls (now << 1)
# define rs (now << 1 | 1)
inline void pushup(int now) {
t[now].Bsm = (t[ls].Bsm + t[rs].Bsm) % mod ;
t[now].sum = (t[ls].sum + t[rs].sum) % mod ;
}
void build(int l , int r , int now) {
if(l == r) { t[now].Bsm = val[l] ; return ; }
int mid = (l + r) >> 1;
build(l , mid , ls) ; build(mid + 1 , r , rs) ;
pushup(now) ;
}
inline void update(int now , int k) {
t[now].sum = (t[now].sum + 1LL * k * t[now].Bsm % mod) % mod ;
t[now].Tag = (t[now].Tag + k) % mod ;
}
inline void pushdown(int now) {
if(t[now].Tag) {
update(ls , t[now].Tag) ;
update(rs , t[now].Tag) ;
t[now].Tag = 0 ;
}
}
void Change(int L , int R , int l , int r , int now) {
if(l > R || r < L) return ;
if(l >= L && r <= R) { update(now , 1) ; return ; }
int mid = (l + r) >> 1 ;
pushdown(now) ;
if(mid >= R) Change(L , R , l , mid , ls) ;
else if(mid < L) Change(L , R , mid + 1 , r , rs) ;
else Change(L , mid , l , mid , ls) , Change(mid + 1 , R , mid + 1 , r , rs) ;
pushup(now) ;
}
int qry(int L , int R , int l , int r , int now) {
if(l > R || r < L) return 0 ;
if(l >= L && r <= R) return t[now].sum ;
int mid = (l + r) >> 1 ;
pushdown(now) ;
if(mid >= R) return qry(L , R , l , mid , ls) ;
else if(mid < L) return qry(L , R , mid + 1 , r , rs) ;
else return ( qry(L , mid , l , mid , ls) + qry(mid + 1 , R , mid + 1 , r , rs) ) % mod ;
}
inline void Change(int x , int y) {
while(top[x] != top[y]) {
if(dep[top[x]] < dep[top[y]]) swap(x , y) ;
Change(id[top[x]] , id[x] , 1 , n , 1) ;
x = fa[top[x]] ;
}
if(dep[x] > dep[y]) swap(x , y) ;
Change(id[x] , id[y] , 1 , n , 1) ;
}
inline int query(int x , int y) {
int ret = 0 ;
while(top[x] != top[y]) {
if(dep[top[x]] < dep[top[y]]) swap(x , y) ;
ret = (ret + qry(id[top[x]] , id[x] , 1 , n , 1)) % mod ;
x = fa[top[x]] ;
}
if(dep[x] > dep[y]) swap(x , y) ;
ret = (ret + qry(id[x] , id[y] , 1 , n , 1)) % mod ;
return ret ;
}
# undef ls
# undef rs
int main() {
n = read() ; m = read() ; k = read() % (mod - 1) ;
for(int i = 1 ; i <= n ; i ++) {
tp[i] = Fpw(i , k) ;
p[i] = (tp[i] - tp[i - 1]) % mod ;
p[i] = (p[i] + mod) % mod ;
}
for(int i = 2 ; i <= n ; i ++) {
fa[i] = read() ;
vec[fa[i]].push_back(i) ;
}
dfs1(1 , 0) ;
for(int i = 1 ; i <= n ; i ++)
w[i] = p[dep[i]] ;
dfs2(1 , 1) ;
build(1 , n , 1) ;
for(int i = 1 ; i <= m ; i ++)
q[i].x = read() , q[i].y = read() , q[i].idx = i ;
sort(q + 1 , q + m + 1) ;
for(int i = 1 ; i <= m ; i ++) {
for(int j = q[i - 1].x + 1 ; j <= q[i].x ; j ++)
Change(1 , j) ;
ans[q[i].idx] = query(1 , q[i].y) ;
}
for(int i = 1 ; i <= m ; i ++) printf("%d
",ans[i]) ;
return 0 ;
}