题意:
给定一棵树,每个节点上有一种苹果(共K种),问树上有多少个有序节点对(u,v)(允许u=v),使得u到v的路径上包含了所有K种苹果。
思路:
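点分治:对于每个节点,用一个K位二进制掩码(状压)记录从当前重心到该节点的路径上出现过的苹果种类;因为K<=10,掩码最多只有2^10=1024种。然后处理每个重心所在连通块内的点:枚举每个点的掩码x,再枚举x的每个子集s,累加掩码恰好等于s的补集的点的个数(这些点与x按位或后恰为全集)。最后用容斥减去两个端点落在重心同一棵子树内的点对,因为它们的真实路径并不经过重心,会在递归中单独统计。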
枚举一个数的所有子集,例如1010,它的子集包括1010、1000、0010、0000。这里有个技巧(注意下面的循环在s变为0时就停止,不会枚举到空集0000,因此空集的情况,即对方掩码本身就是全集,需要单独加上cnt[(1<<k)-1]):
for(int s = x; s; s = (s - 1) & x){ res += 1ll*cnt[((1<<k)-1) ^ s]; }
// #pragma GCC optimize(3)
// #pragma comment(linker, "/STACK:102400000,102400000")  // c++
// #pragma GCC diagnostic error "-std=c++11"
// #pragma comment(linker, "/stack:200000000")
// #pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
#include <algorithm>
#include <iterator>
#include <iostream>
#include <cstring>
#include <cstdlib>
#include <iomanip>
#include <bitset>
#include <cctype>
#include <cstdio>
#include <string>
#include <vector>
#include <stack>
#include <cmath>
#include <queue>
#include <list>
#include <map>
#include <set>
#include <cassert>
using namespace std;
#define lson (l , mid , rt << 1)
#define rson (mid + 1 , r , rt << 1 | 1)
#define debug(x) cerr << #x << " = " << x << " ";
#define pb push_back
#define pq priority_queue
typedef long long ll;
typedef unsigned long long ull;
// typedef __int128 bll;
typedef pair<ll, ll> pll;
typedef pair<int, int> pii;
typedef pair<int, pii> p3;
// priority_queue<int> q;                              // max-heap
// priority_queue<int, vector<int>, greater<int> > q;  // min-heap
#define fi first
#define se second
// #define endl '\n'
#define OKC ios::sync_with_stdio(false);cin.tie(0)
#define FT(A, B, C) for (int A = B; A <= C; ++A)  // inclusive-range loop shorthand
#define REP(i, j, k) for (int i = j; i < k; ++i)  // half-open-range loop shorthand
// No trailing ';' so these can be used inside larger expressions.
#define max3(a, b, c) max(max(a, b), c)
#define min3(a, b, c) min(min(a, b), c)
// priority_queue<int, vector<int>, greater<int> > que;
const ll mos = 0x7FFFFFFF;           // 2147483647 (INT_MAX)
const ll nmos = 0x80000000;          // NOTE: as a long long this is +2147483648, not -2147483648
const int inf = 0x3f3f3f3f;
const ll inff = 0x3f3f3f3f3f3f3f3f;  // about 4.56e18
const int mod = 1e9 + 7;
const double esp = 1e-8;
const double PI = acos(-1.0);
const double PHI = 0.61803399;       // golden-section ratio
const double tPHI = 0.38196601;      // 1 - PHI
/// Reads a decimal integer from stdin (via getchar) into x and returns it.
/// Non-digit characters before the number are skipped; a '-' among them makes
/// the result negative. Stops at EOF instead of spinning forever; if no digit
/// is found before EOF, x is 0.
template<typename T> inline T read(T& x) {
    x = 0;
    int f = 0;
    int ch = getchar();  // int, not char, so EOF is distinguishable from a byte
    while ((ch < '0' || ch > '9') && ch != EOF) f |= (ch == '-'), ch = getchar();
    while (ch >= '0' && ch <= '9') x = x * 10 + ch - '0', ch = getchar();
    return x = f ? -x : x;
}
/*-----------------------showtime----------------------*/
const int maxn = 50009;
// a: apple kind of each node; g: "already chosen as centroid" flag;
// dp: subtree sizes; cnt: mask counters used by cal(); mp: adjacency lists.
int a[maxn], g[maxn], dp[maxn], cnt[maxn];
vector<int> mp[maxn];
int n,k; ll ans = 0; void dfs(int u,int fa){ dp[u] = 1; for(int i=0; i<mp[u].size(); i++){ int v = mp[u][i]; if(g[v] || fa == v)continue; dfs(v, u); dp[u] += dp[v]; } } pii findg(int u,int fa, int sz){ int mx = 0; pii tmp = pii(inf, u); for(int i=0; i<mp[u].size(); i++){ int v = mp[u][i]; if(g[v] || fa == v)continue; tmp = min(tmp, findg(v,u,sz)); mx = max(mx, dp[v]); } mx = max(mx, sz - dp[u]); return min(tmp, pii(mx, u)); } void route(int u, int fa, vector<int>& ve, int sta){ sta = ((1<<a[u]) | sta); ve.pb(sta); for(int i=0; i<mp[u].size(); i++){ int v = mp[u][i]; if(v == fa || g[v])continue; route(v, u, ve, sta); } } ll cal(vector<int> &ve){ // memset(cnt, 0, sizeof(cnt)); for(int i=0; i<2000; i++) cnt[i] = 0; for(int i=0; i<ve.size(); i++){ cnt[ve[i]] ++; } /* Hash[it]-=1; ans+=Hash[(1<<m)-1]; for(int j=it;j;j=(j-1)&it){ ans+=Hash[((1<<m)-1)^j]; } Hash[it]+=1; */ ll res = 0; for(int i=0; i<ve.size(); i++){ int x = ve[i]; cnt[ve[i]]--; res += 1ll*cnt[(1<<k)-1]; for(int s = x; s; s = (s - 1) & x){ res += 1ll*cnt[((1<<k)-1) ^ s]; } cnt[ve[i]]++; } return res; } void divide(int u){ dfs(u,-1); int rt = findg(u, -1, dp[u]).se; g[rt] = 1; for(int i=0; i<mp[rt].size(); i++){ int v = mp[rt][i]; if(g[v])continue; divide(v); } vector<int>all; all.pb((1<<a[rt])); for(int i=0; i<mp[rt].size(); i++){ vector<int>ve; int v = mp[rt][i]; if(g[v])continue; route(v, -1, ve, (1<<a[rt])); ans -= 1ll*cal(ve); all.insert(all.end(),ve.begin(),ve.end()); } ans += 1ll*cal(all); g[rt] = 0; } int main(){ while(~scanf("%d%d", &n, &k)){ for(int i=1; i<=n; i++) scanf("%d", &a[i]), a[i]--; for(int i=1; i<=n; i++) mp[i].clear(); for(int i=1; i< n; i++) { int u,v; scanf("%d%d", &u, &v); mp[u].pb(v); mp[v].pb(u); } if(k == 1) { ans = 1ll*n*n; printf("%lld ", ans); continue; } // memset(g,0,sizeof(g)); ans = 0; divide(1); printf("%lld ", ans); } return 0 ; }