题面
( ext{Description:})
给一张 (n) 割点 (m) 条边的 ({ m DAG}) ,保证点 (1)不存在入边,现在需要在({ m DAG}) 中加入一条不在原图中的边 ({ m (x,y)}) , 求这个有向图以 (1) 为根的树形图个数对 (1e9+7) 取模的结果
(n<=100000,m<=200000)
( ext{Solution:})
考虑没有加边,是 ({ m DAG}) 的情况,那么答案就是每个点的入度之积,相当于每个点选一个指向自己的点作为父亲,有入度种选法,有乘法原理:
[{
m Ans} = prod in[x]
]
现在有一条边加进来,如果每个点还是像原来那么随便,可能会成环,选考虑减去不合法情况。
对于一个环,它所能“贡献”的不合法数量为
[prod in[x] (x在环上)
]
原因是 (x) 的父亲只选环上的边,其它的乱选,这样一定成环,一定不合法。
所以我们把所有环的 "贡献" 减去即可。
建反图(原来怎么乘过来,现在怎么减回去),直接在 ({ m DAG}) 上 ({ m dp}) 求非法"贡献"。
转移:
[f[x] = frac{sum f[y]}{in[x]} (y连向x)
]
#include <vector>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <bitset>
#include <iostream>
#include <assert.h>
#include <algorithm>
using namespace std;
#define LL long long
#define debug(...) fprintf(stderr, __VA_ARGS__)
#define GO debug("GO
")
struct Stream {
template<typename T>
inline T rint() {
register T x = 0, f = 1; register char c;
while (!isdigit(c = getchar())) if (c == '-') f = -1;
while (x = (x << 1) + (x << 3) + (c ^ 48), isdigit(c = getchar()));
return x * f;
}
template <typename _Tp>
Stream& operator>> (_Tp& x)
{ x = rint<_Tp>(); return *this; }
Stream& operator<< (int x)
{ printf("%d", x); return *this;}
Stream& operator<< (LL x)
{ printf("%lld", x); return *this;}
Stream& operator<< (char ch)
{ putchar(ch); return *this; }
} xin, xout;
template<typename T> inline void chkmin(T &a, T b) { a > b ? a = b : 0; }
template<typename T> inline void chkmax(T &a, T b) { a < b ? a = b : 0; }
const int N = 1e5 + 10, P = 1e9 + 7;
LL sum = 1, ans = 1;
bitset<N> vis;
int n, m, S, T, in[N];
LL f[N];
vector<int> G[N];
LL qpow(LL a, LL b) {
LL res = 1;
for (; b; b >>= 1, a = a * a % P)
if (b & 1) res = res * a % P;
return res;
}
void DFS(int u) {
if (vis[u])
return;
if (u == T) {
f[u] = sum * qpow(in[u], P - 2) % P;
return;
}
vis[u] = 1;
for (vector<int>::iterator it = G[u].begin(); it != G[u].end(); ++ it) {
DFS(*it);
f[u] = (f[u] + f[*it]) % P;
}
f[u] = (f[u] * qpow(in[u], P - 2)) % P;
}
int main() {
#ifndef ONLINE_JUDGE
freopen("xhc.in", "r", stdin);
freopen("xhc.out", "w", stdout);
#endif
xin >> n >> m >> S >> T;
for (register int i = 1; i <= m; ++ i) {
int u, v;
xin >> u >> v;
G[v].push_back(u);
in[v]++;
}
in[1]++;
for (int i = 1; i <= n; ++ i) {
sum = sum * in[i] % P;
ans = ans * (in[i] + (i == T)) % P;
}
DFS(S);
xout << (ans - f[S] + P) % P << '
';
}