正难则反的思想还是不能灵活应用啊
题意:给你n个点,每个点有一个权值,接着是n-1有向条边形成一颗有根树,问你有多少对点的权值乘积小于等于给定的值k,其中这对点必须是孩子节点与祖先的关系
我们反向思考,可以知道任一点都只对其每个祖先有贡献。所以我们可以转化为求每个点与其每个祖先的乘积小于等于给定的值k的对数。
我们dfs遍历这颗树使用树状数组维护,dfs遍历孩子就添点回溯就删点,接着对每个点计算树状数组里不大于(k/此点)的个数。注意值太大我们需要离散化,而且我们可以把每个点m与k/m都离散化出来,然后就是每次树状数组更新这个点的个数(加减1),查询不大于(k/此点)的个数
#include<set> #include<map> #include<queue> #include<stack> #include<cmath> #include<vector> #include<string> #include<cstdio> #include<cstring> #include<stdlib.h> #include<iostream> #include<algorithm> using namespace std; #define eps 1E-8 /*注意可能会有输出-0.000*/ #define Sgn(x) (x<-eps? -1 :x<eps? 0:1)//x为两个浮点数差的比较,注意返回整型 #define Cvs(x) (x > 0.0 ? x+eps : x-eps)//浮点数转化 #define zero(x) (((x)>0?(x):-(x))<eps)//判断是否等于0 #define mul(a,b) (a<<b) #define dir(a,b) (a>>b) typedef long long ll; typedef unsigned long long ull; const int Inf=1<<28; const double Pi=acos(-1.0); const int Mod=1e9+7; const int Max=200010; struct node { int pos; ll val; } nod[Max]; int bit[Max],vis[Max],n; ll k,ans,rop[Max]; int head[Max],nnext[Max],to[Max],e; bool cmp(struct node p1,struct node p2) { if(p1.val==p2.val) return p1.pos<p2.pos; return p1.val<p2.val; } void Add(int u,int v) { to[e]=v; nnext[e]=head[u]; head[u]=e++; return; } int lowbit(int x) { return x&(-x); } void AddBit(int x,int y) { while(x<=(n<<1)) { bit[x]+=y; x+=lowbit(x); } return; } int Sum(int x) { int sum=0; while(x) { sum+=bit[x]; x-=lowbit(x); } return sum; } void dfs(int son,int fat) { AddBit(rop[son],1);//遍历孩子添边 for(int i=head[son]; i!=-1; i=nnext[i]) { if(to[i]!=fat) { ans+=Sum(rop[to[i]+n]); dfs(to[i],son); } } AddBit(rop[son],-1);//回溯删边 return; } int main() { int t; scanf("%d",&t); while(t--) { memset(vis,0,sizeof(vis)); scanf("%d %I64d",&n,&k); for(int i=1; i<=n; ++i) { scanf("%I64d",&nod[i].val); nod[i].pos=i; } for(int i=1; i<=n; ++i)//和k/num[u]一起离散化 { if(nod[i].val)//关键注意除数为0 nod[i+n].val=k/nod[i].val; else nod[i+n].val=0ll; nod[i+n].pos=i+n; } sort(nod+1,nod+1+(n<<1),cmp); for(int i=1; i<=(n<<1); ++i)//离散化 rop[nod[i].pos]=i; e=0; memset(head,-1,sizeof(head)); int u,v; for(int i=1; i<n; ++i) { scanf("%d %d",&u,&v); Add(u,v); vis[v]++; } ans=0ll; memset(bit,0,sizeof(bit)); for(int i=1; i<=n; ++i) { if(!vis[i])//找根 { dfs(i,i); } } printf("%I64d ",ans); } return 0; }