给定一棵树,每个点有一个颜色。现在我们想知道树上所有链(任意两点之间的简单路径,也包括单个点)所对应的不同颜色序列一共有多少种。链是有方向的:从 u 走到 v 与从 v 走到 u 读出的序列都要统计,内容相同的序列只算一种。
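关键问题是如何处理“转折”:以某个点为根时,一条链可能先向上再向下,而不是一条从祖先到后代的直链。题目给出了一个条件:度为 1 的节点(叶子)不超过 20 个。因此我们可以分别以每个叶子为根进行搜索,只考虑从祖先到后代的直链,从而避免处理转折;因为在某个根下发生转折的链,在以另一个叶子为根时必然是一条直链,已经在那一次被计算进去了。具体做法是把每次搜索得到的有根树都插入同一个广义后缀自动机,最后统计自动机中本质不同的子串个数。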
#include <iostream>
#include <cstdio>
#include <cmath>
#include <string>
#include <cstring>
#include <algorithm>
#include <limits>
#include <vector>
#include <stack>
#include <queue>
#include <set>
#include <map>
#include <bitset>
#include <unordered_map>
#include <unordered_set>

// Generic contest-template macros. Most are unused by this program.
// NOTE(review): `pi` and `e` are object-like macros, so no identifier in this
// file may be spelled `pi` or `e`.
#define lowbit(x) ( x&(-x) )
#define pi 3.141592653589793
#define e 2.718281828459045
#define INF 0x3f3f3f3f
#define HalF (l + r)>>1
#define lsn rt<<1
#define rsn rt<<1|1
#define Lson lsn, l, mid
#define Rson rsn, mid+1, r
#define QL Lson, ql, qr
#define QR Rson, ql, qr
#define myself rt, l, r
#define pii pair<int, int>
#define MP(a, b) make_pair(a, b)
using namespace std;
typedef unsigned long long ull;
typedef unsigned int uit;
typedef long long ll;

// Capacity of the per-vertex arrays; the automaton below gets twice this.
const int maxS = 2e6 + 7;

// N: number of vertices. C: second integer of the input header; it is read
// but never used afterwards.
int N, C;

// NOTE(review): `Trie` / `t` are never referenced anywhere in this file.
struct Trie
{
    int c, fa;
    int nex[11];
} t[maxS];

// Suffix automaton shared by every rooted traversal of the tree.
// State 1 is the root and index 0 doubles as the "no state" link target.
// Transitions are indexed by colour, so colours are assumed to lie in
// [0, 11) -- TODO confirm against the problem limits.
struct SAM
{
    struct state
    {
        int len, link, next[11];
    } st[maxS << 1];

    // Index of the next unused state; the allocated states are 1 .. siz-1.
    int siz = 1;

    // Sets up the root state. Transition tables are not cleared here; they
    // rely on the zero-initialisation of the global `sam` object.
    void init()
    {
        st[1].len = 0;
        st[1].link = 0;
        siz = 2;
    }

    // Appends colour `c` to the string represented by state `last` and
    // returns the state allocated for the extended string.
    //
    // A fresh state is allocated even when `last` already has a transition
    // on `c`. In that case the first loop is skipped and the new state ends
    // up with the same `len` as its suffix link, so it adds nothing to the
    // final sum of len - len(link).
    //
    // No capacity check is performed: each call allocates up to two states.
    int extend(int c, int last)
    {
        const int cur = siz++;
        st[cur].len = st[last].len + 1;

        int p = last;
        while (p != 0 && !st[p].next[c])
        {
            st[p].next[c] = cur;
            p = st[p].link;
        }

        if (p == 0)
        {
            st[cur].link = 1;
            return cur;
        }

        const int q = st[p].next[c];
        if (st[p].len + 1 == st[q].len)
        {
            st[cur].link = q;
            return cur;
        }

        // q represents strings longer than len(p) + 1: split off a clone
        // that keeps q's transitions and link but has the shorter length.
        const int clone = siz++;
        st[clone].len = st[p].len + 1;
        memcpy(st[clone].next, st[q].next, sizeof(st[q].next));
        st[clone].link = st[q].link;
        while (p != 0 && st[p].next[c] == q)
        {
            st[p].next[c] = clone;
            p = st[p].link;
        }
        st[q].link = st[cur].link = clone;
        return cur;
    }
} sam;

// Colour of each vertex (1-based).
int col[maxS];

// Undirected graph stored as linked adjacency lists threaded through `edge`.
namespace Graph
{
    const int maxN = maxS;
    int head[maxN], cnt, du[maxN];

    struct Eddge
    {
        int nex, to;
        Eddge(int a=-1, int b=0):nex(a), to(b) {}
    } edge[maxN << 1];

    // Adds the directed edge u -> v and bumps the degree counter of v.
    inline void addEddge(int u, int v)
    {
        edge[cnt] = Eddge(head[u], v);
        head[u] = cnt++;
        du[v]++;
    }

    // Adds the undirected edge {u, v}; afterwards du[] holds vertex degrees.
    inline void _add(int u, int v)
    {
        addEddge(u, v);
        addEddge(v, u);
    }

    // Resets adjacency lists and degrees for vertices 1..N.
    inline void init()
    {
        cnt = 0;
        for(int i=1; i<=N; i++)
        {
            head[i] = -1;
            du[i] = 0;
        }
    }
};
using namespace Graph;

// pos[v]: automaton state reached by the colour string from the current BFS
//         root down to v. pos[0] stands for the empty string.
// que / top / tail: array-backed BFS queue.
// fa[v]: BFS parent of v (0 for the root).
int pos[maxS], que[maxN], top, tail, fa[maxN];

// Treats the tree as rooted at `u` and inserts every root-to-vertex colour
// string into the shared automaton, in BFS order, by extending the parent's
// state with the vertex's own colour.
void bfs(int u)
{
    fa[u] = 0;
    top = tail = 0;
    que[tail++] = u;
    pos[0] = 1;  // virtual parent 0 maps to the automaton root
    while(top < tail)
    {
        u = que[top++];
        pos[u] = sam.extend(col[u], pos[fa[u]]);
        for(int i=head[u], v; ~i; i=edge[i].nex)
        {
            v = edge[i].to;
            if(v == fa[u]) continue;  // do not walk back up to the parent
            fa[v] = u;
            que[tail++] = v;
        }
    }
}

// Reads "N C", then N colours, then N-1 edges; prints the number of distinct
// colour sequences found along directed paths of the tree.
// Returns 1 without printing anything if the input ends early or is malformed.
int main()
{
    sam.init();
    if(scanf("%d%d", &N, &C) != 2) return 1;
    for(int i=1; i<=N; i++)
    {
        if(scanf("%d", &col[i]) != 1) return 1;
    }
    init();
    for(int i=1, u, v; i<N; i++)
    {
        if(scanf("%d%d", &u, &v) != 2) return 1;
        _add(u, v);
    }

    // Root the tree at every degree-1 vertex in turn.
    for(int i=1; i<=N; i++)
    {
        if(du[i] == 1) bfs(i);
    }
    // A single-vertex tree has no degree-1 vertex; insert it explicitly so
    // its one-colour sequence is counted.
    if(N == 1) bfs(1);

    // Every non-root state contributes len - len(link) distinct strings.
    // Allocated states are 1 .. siz-1, so stop before siz.
    ll ans = 0;
    for(int i=2; i<sam.siz; i++)
    {
        ans += sam.st[i].len - sam.st[sam.st[i].link].len;
    }
    printf("%lld ", ans);
    return 0;
}