题目
点这里看题目。
分析
首先发现,对于((a,b,c))的合法三元组,(c)一定在(a)的子树内,并且(b)也是(c)的祖先。那么我们只需要考虑(b)的位置。如果(b)是(a)的子孙,那么(c)一定就是(b)的子孙,此时的贡献是(siz(b)-1)(我们以下都用(siz(u))表示(u)的子树大小)。如果(b)是(a)的祖先,那么贡献就是(siz(a)-1)。
发现,当(b)是(a)的祖先的时候,贡献只与(k)、(siz(a))和(a)的深度有关,因此可以直接计算。
因此我们只需要考虑,对每个询问求出:
[sum_{dis(p,v)le k} (siz(v)-1)
]
这是一个长链剖分可做的问题。但是由于我很蒻,因此我不会。
考虑到 " (a)和(b)距离不超过(k) " 的限制,我们可以发现,(b)到(a)的深度差不会超过(k)。因此,知道了(a)的深度和(k),我们就可以知道合法的(b)的最大的深度,我们称这个最大深度为一个询问的 " 可用深度 " 。
因此,我们将所有点按照深度从小到大,插入贡献。对于一个询问,当深度小于等于它的 " 可用深度 " 的点都被插入后,我们就可以计算它此时的答案——也就是它子树内的贡献和。
可以发现,此时查询既满足贡献不遗漏,也满足不会有多余的贡献。
子树内的贡献和,可以通过将树展开为 DFS 序的方法,转化为区间求和,用 BIT 可以维护。点和询问都可以在排序后使用指针扫过去。
时间复杂度(O((n+q)log_2n))。
代码
#include <cstdio>
#include <algorithm>
using namespace std;
typedef long long LL;
const int MAXN = 3e5 + 5;
template<typename _T>
void read( _T &x )
{
x = 0;char s = getchar();int f = 1;
while( s > '9' || s < '0' ){if( s == '-' ) f = -1; s = getchar();}
while( s >= '0' && s <= '9' ){x = ( x << 3 ) + ( x << 1 ) + ( s - '0' ), s = getchar();}
x *= f;
}
template<typename _T>
void write( _T x )
{
if( x < 0 ){ putchar( '-' ); x = ( ~ x ) + 1; }
if( 9 < x ){ write( x / 10 ); }
putchar( x % 10 + '0' );
}
template<typename _T>
_T MIN( const _T a, const _T b )
{
return a < b ? a : b;
}
template<typename _T>
_T MAX( const _T a, const _T b )
{
return a > b ? a : b;
}
struct edge
{
int to, nxt;
}Graph[MAXN << 1];
struct Query
{
int dd, id, A, K;
bool operator < ( const Query &b ) const { return dd < b.dd; }
};
Query q[MAXN];
LL BIT[MAXN], ans[MAXN];
int pts[MAXN];
int head[MAXN], dep[MAXN], pos[MAXN], siz[MAXN];
int N, Q, cnt, ID;
int lowbit( const int x ) { return x & ( -x ); }
void update( int x, int v ) { for( ; x <= N ; x += lowbit( x ) ) BIT[x] += v; }
LL getSum( int x ) { LL ret = 0; while( x ) ret += BIT[x], x -= lowbit( x ); return ret; }
LL query( int l, int r ) { return getSum( r ) - getSum( l - 1 ); }
bool cmp( const int &x, const int &y ) { return dep[x] < dep[y]; }
void addEdge( const int from, const int to )
{
Graph[++ cnt].to = to, Graph[cnt].nxt = head[from];
head[from] = cnt;
}
void DFS( const int u, const int fa )
{
dep[u] = dep[fa] + 1, siz[u] = 1, pos[u] = ++ ID;
for( int i = head[u], v ; i ; i = Graph[i].nxt )
if( ( v = Graph[i].to ) ^ fa )
DFS( v, u ), siz[u] += siz[v];
}
signed main()
{
int height = 0;
read( N ), read( Q );
for( int i = 1, a, b ; i < N ; i ++ ) read( a ), read( b ), addEdge( a, b ), addEdge( b, a );
DFS( 1, 0 );
for( int i = 1 ; i <= N ; i ++ ) height = MAX( height, dep[i] ), pts[i] = i;
for( int i = 1 ; i <= Q ; i ++ )
read( q[i].A ), read( q[i].K ), q[i].dd = MIN( dep[q[i].A] + q[i].K, height ), q[i].id = i;
std :: sort( q + 1, q + 1 + Q );
std :: sort( pts + 1, pts + 1 + N, cmp );
int rig = 1, ptr = 1;
for( int i = 1 ; i <= N ; )
{
for( ; rig <= N && dep[pts[rig]] == dep[pts[i]] ; rig ++ );
for( ; i < rig ; i ++ ) update( pos[pts[i]], siz[pts[i]] - 1 );
if( q[ptr].dd == dep[pts[i - 1]] )
{
for( ; ptr <= Q && q[ptr].dd == dep[pts[i - 1]] ; ptr ++ )
{
ans[q[ptr].id] = query( pos[q[ptr].A] + 1, pos[q[ptr].A] + siz[q[ptr].A] - 1 );
ans[q[ptr].id] += 1ll * MIN( q[ptr].K, dep[q[ptr].A] - 1 ) * ( siz[q[ptr].A] - 1 );
}
}
}
for( int i = 1 ; i <= Q ; i ++ ) write( ans[i] ), putchar( '
' );
return 0;
}