容易发现这个问题可以递归处理：先把递推式列出来，再用记忆化搜索求解即可。
#include <bits/stdc++.h>

// Sum of all pairwise vertex distances in recursively assembled trees,
// reported modulo 1e9+7.
//
// Tree 0 is a single vertex (sz[0] == 1).  Tree i (1 <= i <= m) is assembled
// from a copy of tree a[i] and a copy of tree b[i]:
//   * vertices of the a[i] copy keep their numbers 0 .. sz[a[i]] - 1,
//   * vertices of the b[i] copy are shifted up by sz[a[i]],
//   * vertex c[i] of the a[i] copy and vertex d[i] of the b[i] copy are
//     joined by an edge of length l[i].
// Vertex counts can grow to about 2^60, so nothing is materialised; the
// answers come from two memoised recurrences over the construction.

using LL = long long;
using PLL = std::pair<LL, LL>;

constexpr int N = 60 + 7;       // capacity of every per-tree table
constexpr LL MOD = 1000000007;  // all reported sums are taken modulo this

LL m;                             // number of assembled trees in the current test
LL a[N], b[N], c[N], d[N], l[N];  // construction record of tree i (see header)
LL ans[N];                        // ans[i]: sum of pairwise distances in tree i, mod MOD
LL sz[N];                         // sz[i]: exact vertex count of tree i

std::map<LL, LL> memo1[N];   // memo1[id][u]      -> getAllToPoint(id, u)
std::map<PLL, LL> memo2[N];  // memo2[id][{u, v}] -> getPointDis(id, u, v), with u < v

// Returns x + y reduced into [0, MOD), assuming both operands already are.
static inline LL addMod(LL x, LL y) {
    x += y;
    return x >= MOD ? x - MOD : x;
}

// Distance (mod MOD) between vertices u and v of tree `id`.
//
// If both vertices lie in the same half, the query is forwarded to that
// half (with the b-side offset removed).  Otherwise the path must use the
// joining edge: u -> c[id] inside the a-half, the edge of length l[id], and
// d[id] -> v inside the b-half.
LL getPointDis(LL id, LL u, LL v) {
    if (u == v) return 0;
    if (u > v) std::swap(u, v);  // canonical key order for the memo table

    const PLL key(u, v);
    const auto hit = memo2[id].find(key);
    if (hit != memo2[id].end()) return hit->second;

    const LL left = sz[a[id]];  // vertices numbered below `left` belong to the a-half
    LL dist;
    if (v < left) {
        dist = getPointDis(a[id], u, v);
    } else if (u >= left) {
        dist = getPointDis(b[id], u - left, v - left);
    } else {
        dist = (getPointDis(a[id], u, c[id]) +
                getPointDis(b[id], v - left, d[id]) + l[id]) % MOD;
    }
    memo2[id].emplace(key, dist);
    return dist;
}

// Sum (mod MOD) of the distances from every vertex of tree `id` to vertex u.
//
// Vertices in u's own half contribute the same quantity one level down.
// Every vertex of the other half first walks to that half's attachment
// vertex, then crosses the joining edge, then walks from the near
// attachment vertex to u; the last two legs are shared by all of them and
// are therefore multiplied by the size of the other half.
LL getAllToPoint(LL id, LL u) {
    if (id == 0) return 0;  // single vertex: no distances to add

    const auto hit = memo1[id].find(u);
    if (hit != memo1[id].end()) return hit->second;

    const LL left = sz[a[id]];
    LL total;
    if (u < left) {
        total = (getAllToPoint(b[id], d[id]) +
                 sz[b[id]] % MOD * (l[id] + getPointDis(a[id], u, c[id])) % MOD) % MOD;
        total = addMod(total, getAllToPoint(a[id], u));
    } else {
        total = (getAllToPoint(a[id], c[id]) +
                 sz[a[id]] % MOD * (l[id] + getPointDis(b[id], u - left, d[id])) % MOD) % MOD;
        total = addMod(total, getAllToPoint(b[id], u - left));
    }
    memo1[id].emplace(u, total);
    return total;
}

// Resets the memo tables, sizes and answers before a new test case.
void init() {
    for (int i = 0; i < N; i++) {
        memo1[i].clear();
        memo2[i].clear();
    }
    std::fill(std::begin(sz), std::end(sz), 0LL);
    std::fill(std::begin(ans), std::end(ans), 0LL);
}

// Reads test cases until input is exhausted.  Each case is m followed by m
// records "a b c d l"; for each case the m answers are printed separated by
// single spaces.  Malformed input is reported on stderr with exit status 1.
int main() {
    // Compare with 1 rather than EOF: a non-numeric token makes scanf return 0
    // without consuming anything, which would otherwise spin forever.
    while (std::scanf("%lld", &m) == 1) {
        if (m >= N) {
            std::fprintf(stderr, "invalid input: m = %lld exceeds table capacity %d\n", m, N - 1);
            return 1;
        }

        init();
        sz[0] = 1;
        for (int i = 1; i <= m; i++) {
            if (std::scanf("%lld%lld%lld%lld%lld", &a[i], &b[i], &c[i], &d[i], &l[i]) != 5) {
                std::fprintf(stderr, "invalid input: truncated record for tree %d\n", i);
                return 1;
            }
            // A tree may only be assembled from earlier trees; anything else
            // would index outside the tables or make the recurrences loop.
            if (a[i] < 0 || a[i] >= i || b[i] < 0 || b[i] >= i) {
                std::fprintf(stderr, "invalid input: tree %d refers to a tree that is not built yet\n", i);
                return 1;
            }
            sz[i] = sz[a[i]] + sz[b[i]];
        }

        for (int i = 1; i <= m; i++) {
            const LL na = sz[a[i]] % MOD;
            const LL nb = sz[b[i]] % MOD;

            // Pairs inside one half are already counted by the halves' answers.
            LL total = (ans[a[i]] + ans[b[i]]) % MOD;
            // Cross pairs (x in a-half, y in b-half) cost
            // dist(x, c) + l + dist(d, y); summed over all such pairs:
            total = addMod(total, nb * getAllToPoint(a[i], c[i]) % MOD);
            total = addMod(total, na * getAllToPoint(b[i], d[i]) % MOD);
            total = addMod(total, na * nb % MOD * l[i] % MOD);
            ans[i] = total;
        }

        // NOTE(review): answers are space-separated and no newline is written
        // between test cases -- confirm this matches the expected output format.
        for (int i = 1; i <= m; i++) {
            std::printf("%lld ", ans[i]);
        }
    }
    return 0;
}