北京市商汤科技开发有限公司建立了新的 AI 人工智能产业园,这个产业园区里有 nn 个路口,由 n - 1n−1 条道路连通。第 ii 条道路连接路口 u_iui 和 v_ivi。
每个路口都布有一台信号发射器,信号频段是 11 到 mm 之间的一个整数。
道路所连接的两个路口的发射信号叠加可能会影响道路的正常运行。具体地,如果第 ii 条道路连接的两个路口发射信号的频段分别为 aa 和 bb,那么 gcd(a, b)gcd(a,b) 不能恰好等于道路的保留频段 w_iwi。每条道路的保留频段是唯一的,即不会与其余任何道路的保留频段相同。
你现在需要确定每个路口发射信号的频段,使其符合要求。
在开始之前,你想先算出共有多少种合法的方案。
由于答案可能很大,输出对 10 ^ 9 + 7109+7 取模的值作为答案。
输入格式
第一行,两个正整数 n, mn,m 分别代表路口数量和信号频段上限。
接下来 n - 1n−1 行,每行描述一条道路。第 ii 行有三个整数 u_i, v_i, w_iui,vi,wi 意义如上所述。
保证 1 le n le m, 1 le u_i, v_i le n, 1 le w_i le m1≤n≤m,1≤ui,vi≤n,1≤wi≤m。
输出格式
输出一个整数,代表合法方案的数量对 10 ^ 9 + 7109+7 取模的值。
数据范围
- m le 10 ^ 3m≤103
样例解释
所有合法的方案为 (2, 2, 1), (2, 2, 3), (3, 3, 1), (3, 3, 2), (3, 3, 3)(2,2,1),(2,2,3),(3,3,1),(3,3,2),(3,3,3)。
样例输入
3 3 1 2 1 1 3 2
样例输出
5
思路:
紧紧抓住题目给的一个条件,每一条边的权值w是不同的。
我们定义状态 dp[i][j] 代表 第i个节点赋值为j可以和它的子树所有节点赋值不冲突时的可能方案数。
对于对一个节点 枚举j从1~m来转移,
对于每一个j,我们枚举节点的所有子节点, 对于节点与其子节点的权值w,枚举w的倍数,求与j的gcd是否为w,
与倍数t的gcd如果为w,该子节点对farther节点的贡献就应该减去 dp[v_son_id][t]
一个节点的所有子节点的总贡献应该是相乘的关系。
再用一个数组sum[i] 来维护i节点的赋值为1~m所有值得可能方案数的sum和。父节点要用到。
细节见代码:
#include <iostream> #include <cstdio> #include <cstring> #include <algorithm> #include <cmath> #include <queue> #include <stack> #include <map> #include <set> #include <vector> #include <iomanip> #define ALL(x) (x).begin(), (x).end() #define rt return #define dll(x) scanf("%I64d",&x) #define xll(x) printf("%I64d ",x) #define sz(a) int(a.size()) #define all(a) a.begin(), a.end() #define rep(i,x,n) for(int i=x;i<n;i++) #define repd(i,x,n) for(int i=x;i<=n;i++) #define pii pair<int,int> #define pll pair<long long ,long long> #define gbtb ios::sync_with_stdio(false),cin.tie(0),cout.tie(0) #define MS0(X) memset((X), 0, sizeof((X))) #define MSC0(X) memset((X), ' ', sizeof((X))) #define pb push_back #define mp make_pair #define fi first #define se second #define eps 1e-6 #define gg(x) getInt(&x) #define db(x) cout<<"== [ "<<x<<" ] =="<<endl; using namespace std; typedef long long ll; ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;} ll lcm(ll a, ll b) {return a / gcd(a, b) * b;} ll powmod(ll a, ll b, ll MOD) {ll ans = 1; while (b) {if (b % 2)ans = ans * a % MOD; a = a * a % MOD; b /= 2;} return ans;} inline void getInt(int* p); const int maxn = 1000010; const int inf = 0x3f3f3f3f; /*** TEMPLATE CODE * * STARTS HERE ***/ int n, m; const ll mod = 1e9 + 7; ll dp[1010][1011]; struct node { int next; int w; node() {} node(int nn, int ww) { next = nn; w = ww; } }; std::vector<node> v[1010]; ll sum[2020]; void dfs(int x, int pre) { for (auto temp : v[x]) { if (temp.next != pre) { dfs(temp.next, x); } } repd(i, 1, m) { ll num = 1ll; for (auto temp : v[x]) { if (temp.next == pre) continue; ll tot = sum[temp.next]; if (temp.next != pre) { for (ll j = temp.w; j <= m; j += temp.w) { if (gcd(j, i) == temp.w) { tot = (tot - dp[temp.next][j] + mod) % mod; } } } num *= tot; num %= mod; } dp[x][i] = (dp[x][i] + num) % mod; } repd(i, 1, m) { sum[x] += dp[x][i]; sum[x] %= mod; } } int main() { //freopen("D:\common_text\code_stream\in.txt","r",stdin); //freopen("D:\common_text\code_stream\out.txt","w",stdout); gbtb; int x, y, w; cin >> n >> m; repd(i, 2, n) { cin >> x >> y >> w; v[x].push_back(node(y, w)); v[y].push_back(node(x, w)); } dfs(1, 0); // repd(i,1,m) // { // db(dp[x][1]); // } cout << sum[1] << endl; return 0; } inline void getInt(int* p) { char ch; do { ch = getchar(); } while (ch == ' ' || ch == ' '); if (ch == '-') { *p = -(getchar() - '0'); while ((ch = getchar()) >= '0' && ch <= '9') { *p = *p * 10 - ch + '0'; } } else { *p = ch - '0'; while ((ch = getchar()) >= '0' && ch <= '9') { *p = *p * 10 + ch - '0'; } } }