题目描述
树链是指树里的一条路径。美团外卖的形象代言人袋鼠先生最近在研究一个特殊的最长树链问题。现在树中的每个点都有一个正整数值,他想在树中找出最长的树链,使得这条树链上所有对应点的值的最大公约数大于1。请求出这条树链的长度。
输入描述:
第1行:整数n(1 ≤ n ≤ 100000),表示点的个数。
第2~n行:每行两个整数x,y表示xy之间有边,数据保证给出的是一棵树。
第n+1行:n个整数,依次表示点1~n对应的权值(1 ≤ 权值 ≤ 1,000,000,000)。
输出描述:
满足最长路径的长度
输入例子:
4
1 2
1 3
2 4
6 4 5 2
输出例子:
3
题解:也许当时我并没有报名参加比赛是个错误的决定?
看起来10^9很虚,但是我们分解质因数只需要预处理出3*10^4里的质数,实际上只有3000多个,3000*n完全不虚,所以我们可以先将n个数全都分解质因数。
然后我们枚举每个质数,枚举所有包含这些质数的点,看一下这些点在树上能形成的最长的链有多长。具体做法是我们将这些点按照深度从大到小排序,然后更新每个点父亲的子树中到父亲的最长链、次长链分别是多长,乱搞一下就行了。
T4帮同学1A了,但是一到1000分的题才拿了700多分,其他题看都没看,我只能默默在这边打辅助了~
如果你把这个代码交上去你就废了呵呵呵~
#include <cstdio> #include <cstring> #include <iostream> #include <algorithm> #include <cmath> #include <vector> #include <map> using namespace std; int n,m,cnt,ans,tot,mor; const int maxn=100010; int to[maxn<<1],nxt[maxn<<1],head[maxn],v[maxn],d1[maxn],d2[maxn],p[103600],fa[maxn],dep[maxn],pri[3600]; bool np[33000]; vector<int> s[103600]; map<int,int> mp; int rd() { int ret=0; char gc=getchar(); while(gc<'0'||gc>'9') gc=getchar(); while(gc>='0'&&gc<='9') ret=ret*10+gc-'0',gc=getchar(); return ret; } void dfs(int x) { for(int i=head[x];i!=-1;i=nxt[i]) if(to[i]!=fa[x]) fa[to[i]]=x,dep[to[i]]=dep[x]+1,dfs(to[i]); } bool cmp1(int a,int b) { return s[a].size()>s[b].size(); } bool cmp2(int a,int b) { return dep[a]>dep[b]; } void updata(int a,int b) { if(d1[b]>d1[a]) d2[a]=d1[a],d1[a]=d1[b]; else d2[a]=max(d2[a],d1[b]); } void add(int a,int b) { to[cnt]=b,nxt[cnt]=head[a],head[a]=cnt++; } int main() { n=rd(); int i,j,a,b; memset(head,-1,sizeof(head)); for(i=1;i<n;i++) a=rd(),b=rd(),add(a,b),add(b,a); for(i=1;i<=n;i++) v[i]=rd(),m=max(m,v[i]); m=ceil(sqrt(1.0*m)); for(i=2;i<=m;i++) { if(!np[i]) pri[++tot]=i,mp[i]=tot; for(j=1;j<=tot&&i*pri[j]<=m;j++) { np[i*pri[j]]=1; if(i%pri[j]==0) break; } } dep[1]=1,dfs(1),mor=tot; for(i=1;i<=n;i++) { for(j=1;j<=tot&&pri[j]*pri[j]<=v[i];j++) { if(v[i]%pri[j]==0) { s[j].push_back(i); while(v[i]%pri[j]==0) v[i]/=pri[j]; } } if(v[i]>1) { if(mp.find(v[i])==mp.end()) mp[v[i]]=++mor; s[mp[v[i]]].push_back(i); } } ans=1; for(i=1;i<=mor;i++) p[i]=i; sort(p+1,p+mor+1,cmp1); for(i=1;i<=mor;i++) { b=p[i]; if(s[b].size()<=ans) break; sort(s[b].begin(),s[b].end(),cmp2); for(j=0;j<s[b].size();j++) a=s[b][j],d1[a]++,d2[a]++,ans=max(ans,d1[a]+d2[a]-1),updata(fa[a],a); for(j=0;j<s[b].size();j++) a=s[b][j],d1[a]=d2[a]=d1[fa[a]]=d2[fa[a]]=0; } printf("%d",ans); return 0; }