http://uoj.ac/problem/279
先判断答案为0的情况,(d(i,i)
eq 0),(d(i,j)
eq d(j,i)),(d(i,j)>d(i,k)+d(k,j)),(d(i,j)>k)。
对于(d(i,j)>0)的情况,如果存在(k
eq i,j)且满足(d(i,j)=d(i,k)+d(k,j)),那么i和j的边就可以取d(i,j)~k的所有权值,答案乘上(k-d(i,j)+1)即可。
如果存在(d(i,j)=0)的情况,用并查集把最短距离为0的点缩起来,形成许多连通块。两个连通块之前会有很多边,且这些边的d值相同。把连通块看成一个点,那么就成了上一行的问题中有d值相同重边的情况;设ij之间的重边数为a,这个情况中如果存在上一行所述的k,答案乘上((k-d(i,j)+1)^a);如果不存在上一行所述的k,答案乘上((k-d(i,j)+1)^a-(k-d(i,j))^a)。
对于每个连通块内的方案数,假设连通块大小为n,设边权为0的边连通整个连通块的方案数为(f(n)),不管连不连通所有边权任意的方案数为(g(n)),因为这个连通块是个完全图,所以(g(n)=(k+1)^{frac{n(n-1)}2}),同时:$$f(n)=g(n)-sum_{i=1}^{n-1}f(i)g(n-i)
egin{pmatrix}n-1i-1 end{pmatrix} k^{i(n-i)}$$
先定住一个连通块内的点x,转移时枚举包括x的连通块的大小(f(i))。
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N = 403;
const int p = 998244353;
int fa[N], d[N][N], n, k, F[N], G[N], jc[N], jcn[N], cnt[N][N], sum[N];
inline int ipow(int a, int b) {
int r = 1, w = a;
while (b) {
if (b & 1) r = 1ll * r * w % p;
b >>= 1;
w = 1ll * w * w % p;
}
return r;
}
inline int C(int a, int b) {return 1ll * jc[a] * jcn[b] % p * jcn[a - b] % p;}
inline void sub(int &a, int b) {a -= b; if (a < 0) a += p;}
int find(int x) {return fa[x] == x ? x : fa[x] = find(fa[x]);}
inline void merge(int a, int b) {
if ((a = find(a)) != (b = find(b)))
fa[a] = b;
}
int ans = 1;
int main() {
scanf("%d%d", &n, &k);
for (int i = 1; i <= n; ++i)
for (int j = 1; j <= n; ++j) {
scanf("%d", &d[i][j]);
if (d[i][j] > k) {puts("0"); return 0;}
}
for (int i = 1; i <= n; ++i) {
if (d[i][i] != 0) {puts("0"); return 0;}
for (int j = i + 1; j <= n; ++j)
if (d[i][j] != d[j][i])
{puts("0"); return 0;}
}
for (int i = 1; i <= n; ++i)
for (int j = 1; j <= n; ++j)
for (int kk = 1; kk <= n; ++kk)
if (d[i][j] > d[i][kk] + d[kk][j])
{puts("0"); return 0;}
jc[0] = jcn[0] = 1;
for (int i = 1; i <= n; ++i) {
jc[i] = 1ll * jc[i - 1] * i % p;
jcn[i] = ipow(jc[i], p - 2);
}
for (int i = 1; i <= n; ++i) {
F[i] = G[i] = ipow(k + 1, i * (i - 1) / 2);
for (int j = 1; j < i; ++j)
sub(F[i], 1ll * F[j] * G[i - j] % p * C(i - 1, j - 1) % p * ipow(k, j * (i - j)) % p);
}
for (int i = 1; i <= n; ++i) fa[i] = i;
for (int i = 1; i <= n; ++i)
for (int j = i + 1; j <= n; ++j)
if (d[i][j] == 0)
merge(i, j);
int u, v;
for (int i = 1; i <= n; ++i)
for (int j = i + 1; j <= n; ++j)
if ((u = find(i)) != (v = find(j)))
++cnt[u][v], ++cnt[v][u];
int a;
for (int i = 1; i <= n; ++i)
for (int j = i + 1; j <= n; ++j)
if (a = cnt[i][j]) {
bool flag = false;
for (int kk = 1; kk <= n; ++kk)
if (cnt[i][kk] && cnt[kk][j] && d[i][j] == d[i][kk] + d[kk][j]) {
flag = true;
break;
}
if (flag) ans = 1ll * ans * ipow(k - d[i][j] + 1, a) % p;
else ans = 1ll * ans * ((ipow(k - d[i][j] + 1, a) - ipow(k - d[i][j], a) + p) % p) % p;
}
for (int i = 1; i <= n; ++i) ++sum[find(i)];
for (int i = 1; i <= n; ++i)
if (sum[i]) ans = 1ll * ans * F[sum[i]] % p;
printf("%d
", ans);
return 0;
}