教训:以后传数组(包括下面这种内含大数组的结构体 `Ans`)一定要用指针或引用,避免整块拷贝。
考虑点分治。
在处理每个分治中心时,先把通过同一种颜色的边连到中心的子树合并在一起;两条链在中心拼接时,这种颜色的同色段在两边各被算了一次,因此拼接时要减去一次该颜色的权值。
之后,再把不同颜色的组依次合并(此时不需要减去)。
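合并两组时,需要在总边数落在 $[l, r]$ 内的前提下求最大拼接值,这是一个滑动窗口最大值问题,用单调队列可以做到单次合并 $O(\mathrm{len}_1 + \mathrm{len}_2)$,其中 $\mathrm{len}_1, \mathrm{len}_2$ 为两组的最大深度。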
如果按照最大深度从小到大的顺序来合并,那么每次都是把较深的组并入较浅的已合并部分,单次代价与较深那一组的长度同阶。
因此,合并总长度为 $n$ 的若干序列,总代价(势能)不超过 $2n$。
因此,结合点分治的 $O(\log n)$ 层,统计部分的总复杂度为 $O(n \log n)$。
然而每层还需要排序,所以实际复杂度为 $O(n \log^2 n)$;不过排序的常数很小,足以通过。
#include <set>
#include <vector>
#include <cstdio>
#include <cstring>
#include <climits>
#include <iostream>
#include <algorithm>

// Maximum-weight path on an edge-coloured tree whose edge count lies in
// [l, r], solved with centroid decomposition.
//
// Input (stdin): n m l r, then cv[1..m] (the weight of each colour), then
// n - 1 edges "u v colour". A path is charged cv[c] once for every maximal
// run of consecutive edges that share colour c.
// Output (stdout): the best such weight followed by a space, or -2000000000
// if no path with an admissible edge count was found.
//
// Half-path weights are stored in int; they are assumed to fit in int and to
// stay above INT_MIN -- TODO confirm against the problem's limits. Joined
// weights are computed in long long.

namespace remoon {

// Reads one decimal integer from stdin. Any '-' seen while skipping the
// non-digit prefix makes the result negative. Returns 0 at end of input
// instead of spinning forever on EOF.
inline int read() {
    int value = 0, sign = 1;
    int ch = getchar();
    while ((ch < '0' || ch > '9') && ch != EOF) {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value * sign;
}

int wr[50], rw;  // digit buffer for write()

// Writes integer o to stdout followed by the separator c. Digits are taken
// from the signed remainder, so even the most negative value of T prints
// correctly (negating it first would overflow).
template <typename T>
inline void write(T o, char c = ' ') {
    if (o < 0) putchar('-');
    do {
        const int digit = static_cast<int>(o % 10);
        wr[++rw] = digit < 0 ? -digit : digit;
        o /= 10;
    } while (o != 0);
    while (rw) putchar(wr[rw--] + '0');
    putchar(c);
}

// a = max(a, b).
template <typename T>
inline void cmax(T &a, T b) {
    if (a < b) a = b;
}

}  // namespace remoon

using namespace std;
using namespace remoon;

// Capacity of every per-vertex, per-edge, per-depth and per-colour array.
// Edges are stored in both directions, so the edge arrays need 2 * (n - 1)
// slots.
const int sid = 400050;

// "No half-path of this depth recorded". Kept strictly below every real
// half-path weight (see the int-range assumption above) and always skipped
// explicitly instead of being added into a joined weight.
const int NEG = INT_MIN;

int n, m, l, r;                 // vertex count, colour count, edge-count window
int rt, asz, cnp;               // centroid candidate, component size, edge counter
long long ans = -2000000000LL;  // best path weight found so far

// over[v]: v has already served as a centroid (removed from the tree).
// son[v]:  size of the largest piece left if v were removed.
// sz[v]:   subtree size from the latest traversal.
// dep[v]:  number of edges from the current centroid to v.
int over[sid], son[sid], sz[sid], dep[sid];

// Forward-star adjacency: cap[u] is u's first edge; nxt, node and col give an
// edge's successor, endpoint and colour. cv[c] is the weight of colour c.
int cap[sid], nxt[sid], node[sid], col[sid], cv[sid];

// Appends the directed edge u -> v with colour c.
inline void addedge(int u, int v, int c) {
    nxt[++cnp] = cap[u];
    cap[u] = cnp;
    node[cnp] = v;
    col[cnp] = c;
}

// Best half-path weight per depth around the current centroid: f[d] is the
// largest weight of a path of d edges going down from the centroid (NEG if
// none) and end bounds the depths in use; every f[d] with d > end is NEG.
struct Ans {
    int f[sid], end;
} nowc, prec, now;

// Records a complete path weight as an answer candidate.
inline void offer(long long w) {
    if (ans < w) ans = w;
}

// Empties a, resetting only the slots it may have used.
inline void init(Ans &a) {
    for (int i = 0; i <= a.end; ++i) a.f[i] = NEG;
    a.end = 0;
}

// Merges b into a (element-wise max). Each merged half-path whose edge count
// lies in [l, r] is offered as well: it ends at the centroid and is a
// complete path on its own.
inline void upd(Ans &a, const Ans &b) {
    const int len = max(a.end, b.end);
    a.end = len;
    for (int i = 0; i <= len; ++i) {
        cmax(a.f[i], b.f[i]);
        if (l <= i && i <= r && a.f[i] != NEG) offer(a.f[i]);
    }
}

int q[sid];  // deque of depth indices into b.f, used by qry()

// Joins half-paths of a with half-paths of b (taken from different child
// subtrees) so that the total edge count lies in [l, r], offering the best
// joined weight. opt is subtracted from every join: callers pass cv[C] when
// both halves leave the centroid through colour C, because that run was then
// charged once in each half.
inline void qry(const Ans &a, const Ans &b, int opt = 0) {
    if (l == 1 && r == n - 1) {
        // Two half-paths from different subtrees have at least 2 and at most
        // n - 1 edges together, so the window excludes nothing: join maxima.
        int mx1 = NEG, mx2 = NEG;
        for (int i = 0; i <= a.end; ++i) cmax(mx1, a.f[i]);
        for (int j = 0; j <= b.end; ++j) cmax(mx2, b.f[j]);
        if (mx1 != NEG && mx2 != NEG) offer(static_cast<long long>(mx1) + mx2 - opt);
        return;
    }
    // Sliding-window maximum over b's depths [l - i, r - i] as i grows.
    // q[fr..to] holds depths in decreasing order with strictly decreasing
    // b.f values, so q[fr] is the best depth still inside the window.
    int fr = 1, to = 0;
    for (int j = min(r, b.end); j >= l; --j) {
        while (fr <= to && b.f[j] >= b.f[q[to]]) --to;
        q[++to] = j;
    }
    for (int i = 0; i <= a.end; ++i) {
        const int lo = l - i;  // new lower edge of the window
        if (lo >= 0 && lo <= b.end) {
            while (fr <= to && b.f[lo] >= b.f[q[to]]) --to;
            q[++to] = lo;
        }
        while (fr <= to && q[fr] > r - i) ++fr;  // drop depths above r - i
        if (fr <= to && a.f[i] != NEG && b.f[q[fr]] != NEG)
            offer(static_cast<long long>(a.f[i]) + b.f[q[fr]] - opt);
    }
}

int vis[sid], tim;              // vis[c] == tim: colour c already seen in this pass
vector<pair<int, int>> all;     // (deepest subtree depth, colour) per colour group
vector<pair<int, int>> c[sid];  // c[colour]: (max depth, child root) per subtree

// Computes subtree sizes of the component containing o (asz vertices) and
// leaves in rt the vertex whose largest remaining piece is smallest, i.e. the
// centroid of that component.
inline void grt(int o, int fa) {
    sz[o] = 1;
    son[o] = 0;
    for (int i = cap[o]; i; i = nxt[i]) {
        const int v = node[i];
        if (over[v] || v == fa) continue;
        grt(v, o);
        sz[o] += sz[v];
        cmax(son[o], sz[v]);
    }
    cmax(son[o], asz - sz[o]);
    if (son[o] < son[rt]) rt = o;
}

// Fills dep[] below o (dep[o] must already be set) and returns the maximum
// depth reached in that subtree.
inline int gh(int o, int fa) {
    int deepest = dep[o];
    for (int i = cap[o]; i; i = nxt[i]) {
        const int v = node[i];
        if (v == fa || over[v]) continue;
        dep[v] = dep[o] + 1;
        cmax(deepest, gh(v, o));
    }
    return deepest;
}

// Walks the subtree below o (entered from fa) and records into `now` the
// weight of each centroid -> vertex path by depth. val is the weight of the
// path centroid -> o and lst the colour of the edge entering o; a colour is
// charged only when it differs from the previous edge's. Also refreshes sz[]
// for the recursive split that follows.
inline void gt(int o, int fa, int val, int lst) {
    sz[o] = 1;
    for (int i = cap[o]; i; i = nxt[i]) {
        const int v = node[i];
        if (over[v] || v == fa) continue;
        const int C = col[i];
        const int extra = (C == lst) ? 0 : cv[C];
        gt(v, o, val + extra, C);
        sz[o] += sz[v];
    }
    cmax(now.end, dep[o]);
    cmax(now.f[dep[o]], val);
}

// Counts every path through centroid o, then recurses into the components
// left after removing o.
inline void solve(int o) {
    over[o] = 1;

    // Pass 1: group o's child subtrees by the colour of their edge to o and
    // record each subtree's maximum depth.
    ++tim;
    for (int i = cap[o]; i; i = nxt[i]) {
        const int v = node[i];
        if (over[v]) continue;
        const int C = col[i];
        if (vis[C] != tim) c[C].clear();
        vis[C] = tim;
        dep[v] = 1;
        c[C].emplace_back(gh(v, o), v);
    }

    // Pass 2: sort each colour group by depth and list the groups by their
    // deepest subtree, so every merge below adds the deeper side last.
    ++tim;
    all.clear();
    for (int i = cap[o]; i; i = nxt[i]) {
        const int v = node[i];
        if (over[v]) continue;
        const int C = col[i];
        if (vis[C] == tim) continue;
        vis[C] = tim;
        sort(c[C].begin(), c[C].end());
        all.emplace_back(c[C].back().first, C);
    }
    sort(all.begin(), all.end());

    // Pass 3: join within each colour group (subtracting the shared first
    // run), then join each group with all shallower groups.
    init(prec);
    for (const auto &group : all) {
        const int C = group.second;
        init(nowc);
        for (const auto &sub : c[C]) {
            init(now);
            gt(sub.second, o, cv[C], C);
            qry(now, nowc, cv[C]);
            upd(nowc, now);
        }
        qry(prec, nowc);
        upd(prec, nowc);
    }

    // Recurse into each remaining component through its own centroid.
    for (int i = cap[o]; i; i = nxt[i]) {
        const int v = node[i];
        if (over[v]) continue;
        asz = sz[v];
        rt = 0;
        grt(v, o);
        solve(rt);
    }
}

int main() {
    n = read();
    m = read();
    l = read();
    r = read();
    for (int i = 1; i <= m; ++i) cv[i] = read();
    for (int i = 2; i <= n; ++i) {
        const int u = read();
        const int v = read();
        const int w = read();
        addedge(u, v, w);
        addedge(v, u, w);
    }
    fill(now.f, now.f + sid, NEG);
    fill(prec.f, prec.f + sid, NEG);
    fill(nowc.f, nowc.f + sid, NEG);
    asz = n;
    son[0] = n;  // dummy vertex 0: any real vertex has son[] < n and replaces it
    grt(1, 0);
    solve(rt);
    write(ans);
    return 0;
}