点分治+trie树+迷之复杂度分析
调代码时候的鬼畜错误也是不想说撒了
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef pair<int,int>pii; 4 #define fir first 5 #define sec second 6 #define mp make_pair 7 #define maxn 100005 8 #define maxnode 3000005 9 int cnt,v[maxn<<1],next[maxn<<1],first[maxn]; 10 int trie[maxnode][2],mark[maxnode],tot; 11 int Max,n,K,pp,root,sum,f[maxn],size[maxn],vis[maxn],lik[maxn],val[maxn]; 12 pii poi[maxn]; 13 void add(int st,int end){ 14 v[++cnt]=end; 15 next[cnt]=first[st]; 16 first[st]=cnt; 17 } 18 void getroot(int x,int fa){ 19 size[x]=1;f[x]=0; 20 for(int e=first[x];e;e=next[e]){ 21 if(v[e]!=fa&&!vis[v[e]]){ 22 getroot(v[e],x); 23 size[x]+=size[v[e]]; 24 f[x]=max(f[x],size[v[e]]); 25 } 26 } 27 f[x]=max(f[x],sum-size[x]); 28 if(f[x]<f[root])root=x; 29 } 30 void clear(){ 31 for(int i=0;i<=tot;i++){ 32 trie[i][0]=trie[i][1]=0; 33 mark[i]=0; 34 } 35 tot=0; 36 } 37 void insert(int x,int ks){ 38 int p=0; 39 for(int i=30;i>=0;i--){ 40 int s=(x>>i)&1; 41 if(!trie[p][s])trie[p][s]=++tot; 42 mark[p]=max(mark[p],ks); 43 p=trie[p][s]; 44 } 45 mark[p]=max(mark[p],ks); 46 } 47 int query(int x,int ks){ 48 int p=0,ans=0; 49 for(int i=30;i>=0;i--){ 50 int s=(x>>i)&1; 51 if(trie[p][s^1]&&mark[trie[p][s^1]]+ks>=K)p=trie[p][s^1],ans+=(1<<i); 52 else if(trie[p][s]&&mark[trie[p][s]]+ks>=K)p=trie[p][s]; 53 else return -1; 54 } 55 return ans; 56 } 57 void dfs(int x,int fa,int sx,int ks){ 58 Max=max(Max,query(sx,ks)); 59 poi[++pp]=mp(sx^val[root],ks+lik[root]); 60 for(int e=first[x];e;e=next[e]) 61 if(v[e]!=fa&&!vis[v[e]]) 62 dfs(v[e],x,sx ^ val[v[e]],ks + lik[v[e]]); 63 } 64 void work(int x){ 65 vis[x]=1; 66 clear(); 67 insert(val[x],lik[x]);//ziji 68 for(int e=first[x];e;e=next[e]){ 69 if(!vis[v[e]]){ 70 pp=0; 71 dfs(v[e],x,val[v[e]],lik[v[e]]); 72 for(int i=1;i<=pp;i++) 73 insert(poi[i].fir,poi[i].sec); 74 } 75 } 76 Max=max(Max,query(0,0));//以x为端点 77 int tmp=sum; 78 for(int e=first[x];e;e=next[e]) 79 if(!vis[v[e]]){ 80 root=0,sum=size[v[e]]; 81 if(sum>tmp/2)sum=tmp-size[x]; 82 getroot(v[e],0),work(root); 83 } 84 } 85 int main(){ 86 scanf("%d%d",&n,&K); 87 for(int i=1;i<=n;i++)scanf("%d",&lik[i]); 88 for(int i=1;i<=n;i++)scanf("%d",&val[i]); 89 int a,b; 90 for(int i=1;i<n;i++){ 91 scanf("%d%d",&a,&b); 92 add(a,b);add(b,a); 93 } 94 Max=-1,f[0]=n + 1, sum=n; 95 getroot(1,0),work(root); 96 printf("%d ",Max); 97 }