求树上点权积为立方数的路径数。
显然,分解质因数后,若所有的质因子出现的次数都%3==0,则该数是立方数。
于是在模意义下暴力统计即可。
当然,为了不MLE/TLE,我们不能存一个30长度的数组,而要压成一个long long。
存储状态用map即可,貌似哈希表可以随便卡掉……?
手动开栈……当然这样有可能MLE,所以还得改一些BFS……
<法一>map:
#pragma comment(linker, "/STACK:1024000000,1024000000") #include<cstdio> #include<algorithm> #include<cstring> #include<map> #include<queue> using namespace std; #define MAXN 50001 typedef pair<int,int> Point; typedef long long ll; struct Point2 { int x;ll y;int z; Point2(){} Point2(const int &a,const ll &b,const int &c){x=a;y=b;z=c;} }; #define MOD 50021 queue<Point2>q; map<ll,int>ma; int n,m,ans,ev[MAXN][31]; ll pr[MAXN]; int v[MAXN<<1],first[MAXN],next[MAXN<<1],en; void AddEdge(const int &U,const int &V) { v[++en]=V; next[en]=first[U]; first[U]=en; } bool centroid[MAXN]; int size[MAXN]; int calc_sizes(int U,int Fa) { int res=1; for(int i=first[U];i;i=next[i]) if(v[i]!=Fa&&!centroid[v[i]]) res+=calc_sizes(v[i],U); return size[U]=res; } Point calc_centroid(int U,int Fa,int nn) { Point res=make_pair(2147483647,-1); int sum=1,maxv=0; for(int i=first[U];i;i=next[i]) if(v[i]!=Fa&&!centroid[v[i]]) { res=min(res,calc_centroid(v[i],U,nn)); maxv=max(maxv,size[v[i]]); sum+=size[v[i]]; } maxv=max(maxv,nn-sum); res=min(res,make_pair(maxv,U)); return res; } int en2,en3; ll td[MAXN],ds[MAXN],base[31]; void calc_dis(int U,int Fa) { td[en2]=0; for(int j=1;j<=m;++j) td[en2]+=ev[U][j]*base[j-1]; ++ma[td[en2]]; q.push(Point2(U,td[en2++],Fa)); while(!q.empty()) { Point2 Node=q.front(); q.pop(); for(int i=first[Node.x];i;i=next[i]) if(v[i]!=Node.z&&!centroid[v[i]]) { ll t=Node.y; td[en2]=0; for(int j=1;j<=m;++j) { td[en2]+=((ev[v[i]][j]+t%3)%3)*base[j-1]; t/=3; } ++ma[td[en2]]; q.push(Point2(v[i],td[en2++],Node.x)); } } } int calc_pairs(ll dis[],int En,int S) { int res=0; for(int i=0;i<En;++i) { ll che=0,t=dis[i]; for(int j=1;j<=m;++j) { che+=((6-t%3-ev[S][j])%3)*base[j-1]; t/=3; } res+=ma[che]; } return res; } void solve(int U) { calc_sizes(U,-1); int s=calc_centroid(U,-1,size[U]).second; centroid[s]=1; for(int i=first[s];i;i=next[i]) if(!centroid[v[i]]) solve(v[i]); en3=0; ds[en3++]=0; ma.clear(); for(int i=first[s];i;i=next[i]) if(!centroid[v[i]]) { en2=0; calc_dis(v[i],s); ans-=calc_pairs(td,en2,s); memcpy(ds+en3,td,en2*sizeof(ll)); en3+=en2; } ans+=calc_pairs(ds,en3,s); centroid[s]=0; } void init() { memset(first,0,(n+1)*sizeof(int)); memset(ev,0,sizeof(ev)); en=ans=0; } int main() { ll t; int a,b; base[0]=1; for(int i=1;i<=30;++i) base[i]=base[i-1]*3; while(scanf("%d%d",&n,&m)!=EOF) { init(); for(int i=1;i<=m;++i) scanf("%I64d",&pr[i]); for(int i=1;i<=n;++i) { bool flag=1; scanf("%I64d",&t); for(int j=1;j<=m;++j) { while(t%pr[j]==0&&t) { ++ev[i][j]; t/=pr[j]; } ev[i][j]%=3; if(ev[i][j]) flag=0; } ans+=flag; } for(int i=1;i<n;++i) { scanf("%d%d",&a,&b); AddEdge(a,b); AddEdge(b,a); } solve(1); printf("%d ",ans); } return 0; }
Update:<法二>哈希表,新的姿势,超快的说
#pragma comment(linker, "/STACK:1024000000,1024000000") #include<cstdio> #include<algorithm> #include<cstring> #include<queue> using namespace std; #define MAXN 50001 #define MOD 50021 typedef pair<int,int> Point; typedef long long ll; struct Point2 { int x;ll y;int z; Point2(){} Point2(const int &a,const ll &b,const int &c){x=a;y=b;z=c;} }; struct HashTable { ll v[MOD]; int en,first[MOD],next[MOD]; HashTable(){en=0;memset(first,-1,sizeof(first));} void clear(){en=0;memset(first,-1,sizeof(first));} void insert(const ll &V) { int U=(int)(V%MOD); v[en]=V; next[en]=first[U]; first[U]=en++; } int count(const ll &V) { int U=(int)(V%MOD),res=0; for(int i=first[U];i!=-1;i=next[i]) if(v[i]==V) ++res; return res; } }T; queue<Point2>q; int n,m,ans,ev[MAXN][31]; ll pr[MAXN]; int v[MAXN<<1],first[MAXN],next[MAXN<<1],en; void AddEdge(const int &U,const int &V) { v[++en]=V; next[en]=first[U]; first[U]=en; } bool centroid[MAXN]; int size[MAXN]; int calc_sizes(int U,int Fa) { int res=1; for(int i=first[U];i;i=next[i]) if(v[i]!=Fa&&!centroid[v[i]]) res+=calc_sizes(v[i],U); return size[U]=res; } Point calc_centroid(int U,int Fa,int nn) { Point res=make_pair(2147483647,-1); int sum=1,maxv=0; for(int i=first[U];i;i=next[i]) if(v[i]!=Fa&&!centroid[v[i]]) { res=min(res,calc_centroid(v[i],U,nn)); maxv=max(maxv,size[v[i]]); sum+=size[v[i]]; } maxv=max(maxv,nn-sum); res=min(res,make_pair(maxv,U)); return res; } int En,last; ll dis[MAXN],base[31]; void calc_dis(int U,int Fa) { dis[En]=0; for(int j=1;j<=m;++j) dis[En]+=ev[U][j]*base[j-1]; q.push(Point2(U,dis[En++],Fa)); while(!q.empty()) { Point2 Node=q.front(); q.pop(); for(int i=first[Node.x];i;i=next[i]) if(v[i]!=Node.z&&!centroid[v[i]]) { ll t=Node.y; dis[En]=0; for(int j=1;j<=m;++j) { dis[En]+=((ev[v[i]][j]+t%3)%3)*base[j-1]; t/=3; } q.push(Point2(v[i],dis[En++],Node.x)); } } } void calc_pairs(int s) { for(int i=last;i<En;++i) { ll che=0,t=dis[i]; for(int j=1;j<=m;++j) { che+=((6-t%3-ev[s][j])%3)*base[j-1]; t/=3; } ans+=T.count(che); } for(int i=last;i<En;++i) T.insert(dis[i]); } void solve(int U) { calc_sizes(U,-1); int s=calc_centroid(U,-1,size[U]).second; centroid[s]=1; for(int i=first[s];i;i=next[i]) if(!centroid[v[i]]) solve(v[i]); En=0; dis[En++]=0; T.insert(0); for(int i=first[s];i;i=next[i]) if(!centroid[v[i]]) { last=En; calc_dis(v[i],s); calc_pairs(s); } for(int i=0;i<En;++i) T.first[dis[i]%MOD]=-1; T.en=0; centroid[s]=0; } void init() { memset(first,0,(n+1)*sizeof(int)); memset(ev,0,sizeof(ev)); en=ans=0; } int main() { ll t; int a,b; base[0]=1; for(int i=1;i<=30;++i) base[i]=base[i-1]*3; while(scanf("%d%d",&n,&m)!=EOF) { init(); for(int i=1;i<=m;++i) scanf("%I64d",&pr[i]); for(int i=1;i<=n;++i) { bool flag=1; scanf("%I64d",&t); for(int j=1;j<=m;++j) { while(t%pr[j]==0&&t) { ++ev[i][j]; t/=pr[j]; } ev[i][j]%=3; if(ev[i][j]) flag=0; } ans+=flag; } for(int i=1;i<n;++i) { scanf("%d%d",&a,&b); AddEdge(a,b); AddEdge(b,a); } solve(1); printf("%d ",ans); } return 0; }
<法三>哈希表,旧的姿势,比法一快,比法二慢
#pragma comment(linker, "/STACK:1024000000,1024000000") #include<cstdio> #include<algorithm> #include<cstring> #include<queue> using namespace std; #define MAXN 50001 #define MOD 50021 typedef pair<int,int> Point; typedef long long ll; struct Point2 { int x;ll y;int z; Point2(){} Point2(const int &a,const ll &b,const int &c){x=a;y=b;z=c;} }; struct HashTable { ll v[MOD]; int en,first[MOD],next[MOD]; HashTable(){en=0;memset(first,-1,sizeof(first));} void clear(){en=0;memset(first,-1,sizeof(first));} void insert(const ll &V) { int U=(int)(V%MOD); v[en]=V; next[en]=first[U]; first[U]=en++; } int count(const ll &V) { int U=(int)(V%MOD),res=0; for(int i=first[U];i!=-1;i=next[i]) if(v[i]==V) ++res; return res; } }T; queue<Point2>q; int n,m,ans,ev[MAXN][31]; ll pr[MAXN]; int v[MAXN<<1],first[MAXN],next[MAXN<<1],en; void AddEdge(const int &U,const int &V) { v[++en]=V; next[en]=first[U]; first[U]=en; } bool centroid[MAXN]; int size[MAXN]; int calc_sizes(int U,int Fa) { int res=1; for(int i=first[U];i;i=next[i]) if(v[i]!=Fa&&!centroid[v[i]]) res+=calc_sizes(v[i],U); return size[U]=res; } Point calc_centroid(int U,int Fa,int nn) { Point res=make_pair(2147483647,-1); int sum=1,maxv=0; for(int i=first[U];i;i=next[i]) if(v[i]!=Fa&&!centroid[v[i]]) { res=min(res,calc_centroid(v[i],U,nn)); maxv=max(maxv,size[v[i]]); sum+=size[v[i]]; } maxv=max(maxv,nn-sum); res=min(res,make_pair(maxv,U)); return res; } int en2,en3; ll td[MAXN],ds[MAXN],base[31]; void calc_dis(int U,int Fa) { td[en2]=0; for(int j=1;j<=m;++j) td[en2]+=ev[U][j]*base[j-1]; T.insert(td[en2]); q.push(Point2(U,td[en2++],Fa)); while(!q.empty()) { Point2 Node=q.front(); q.pop(); for(int i=first[Node.x];i;i=next[i]) if(v[i]!=Node.z&&!centroid[v[i]]) { ll t=Node.y; td[en2]=0; for(int j=1;j<=m;++j) { td[en2]+=((ev[v[i]][j]+t%3)%3)*base[j-1]; t/=3; } T.insert(td[en2]); q.push(Point2(v[i],td[en2++],Node.x)); } } } int calc_pairs(ll dis[],int En,int S) { int res=0; for(int i=0;i<En;++i) { ll che=0,t=dis[i]; for(int j=1;j<=m;++j) { che+=((6-t%3-ev[S][j])%3)*base[j-1]; t/=3; } res+=T.count(che); } return res; } void solve(int U) { calc_sizes(U,-1); int s=calc_centroid(U,-1,size[U]).second; centroid[s]=1; for(int i=first[s];i;i=next[i]) if(!centroid[v[i]]) solve(v[i]); en3=0; ds[en3++]=0; T.clear(); for(int i=first[s];i;i=next[i]) if(!centroid[v[i]]) { en2=0; calc_dis(v[i],s); ans-=calc_pairs(td,en2,s); memcpy(ds+en3,td,en2*sizeof(ll)); en3+=en2; } ans+=calc_pairs(ds,en3,s); centroid[s]=0; } void init() { memset(first,0,(n+1)*sizeof(int)); memset(ev,0,sizeof(ev)); en=ans=0; } int main() { ll t; int a,b; base[0]=1; for(int i=1;i<=30;++i) base[i]=base[i-1]*3; while(scanf("%d%d",&n,&m)!=EOF) { init(); for(int i=1;i<=m;++i) scanf("%I64d",&pr[i]); for(int i=1;i<=n;++i) { bool flag=1; scanf("%I64d",&t); for(int j=1;j<=m;++j) { while(t%pr[j]==0&&t) { ++ev[i][j]; t/=pr[j]; } ev[i][j]%=3; if(ev[i][j]) flag=0; } ans+=flag; } for(int i=1;i<n;++i) { scanf("%d%d",&a,&b); AddEdge(a,b); AddEdge(b,a); } solve(1); printf("%d ",ans); } return 0; }