[题目链接]
https://www.lydsy.com/JudgeOnline/problem.php?id=1415
[算法]
首先BFS预处理出点与点之间的最短路 , 求出每次聪聪的下一步将会往哪走
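然后 , 用f[i][j]表示聪聪在i , 可可在j时 , 聪聪抓到可可所需时间的期望 : 若i = j , 则f[i][j] = 0 ; 若聪聪走一步或两步就能到达j , 则f[i][j] = 1 ; 否则设聪聪连走两步后到达p , 有 f[i][j] = 1 + (f[p][j] + Σ f[p][v]) / (deg[j] + 1) , 其中v取遍j的所有相邻点 , 记忆化搜索(期望DP)即可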
时间复杂度 : O(N(N + M)) , 当M与N同阶时即为O(N ^ 2)
[代码]
// Expected number of time units until the chaser (start vertex s) catches the
// target (start vertex t) on an undirected graph with n vertices and m edges.
//
// Each time unit the chaser first takes up to two greedy steps toward the
// target (always to the neighbour closest to the target, smallest index on
// ties); if it has not arrived, the target then stays put or moves to one of
// its neighbours, each option with probability 1 / (deg + 1).
//
// Input  (stdin):  n m, then s t, then m lines "u v" (1-based vertex ids).
// Output (stdout): the expectation with three decimals.
#include <algorithm>
#include <cctype>
#include <cstdio>

const int MAXN = 1010;        // array capacity: vertex ids must be in [1, MAXN - 1]
const int inf = 2000000000;   // "unreachable" distance marker

int n , m , s , t , tot;
int deg[MAXN] , head[MAXN];              // degree / head of adjacency list per vertex
int nxt[MAXN][MAXN] , dist[MAXN][MAXN];  // nxt[i][j]: chaser's next vertex from i toward j
bool visited[MAXN][MAXN];                // memo flag for f
double f[MAXN][MAXN];                    // f[i][j]: expectation with chaser at i, target at j

// Forward-star adjacency list; slot 0 is the list terminator.
struct edge
{
    int to , nxt;
} e[MAXN << 1];

template <typename T> inline void chkmax(T &x , T y) { x = std::max(x , y); }
template <typename T> inline void chkmin(T &x , T y) { x = std::min(x , y); }

// Reads one (optionally negative) decimal integer from stdin into x.
// Leaves x == 0 when end of input is reached before any digit.
template <typename T> inline void read(T &x)
{
    T sign = 1;
    x = 0;
    // Keep the result of getchar() in an int so EOF is distinguishable and
    // isdigit() never receives a negative char value.
    int c = std::getchar();
    while (c != EOF && !std::isdigit(c))
    {
        if (c == '-') sign = -sign;
        c = std::getchar();
    }
    while (c != EOF && std::isdigit(c))
    {
        x = x * 10 + (c - '0');
        c = std::getchar();
    }
    x *= sign;
}

// Appends the directed edge u -> v to u's adjacency list.
inline void addedge(int u , int v)
{
    ++tot;
    e[tot].to = v;
    e[tot].nxt = head[u];
    head[u] = tot;
}

// Breadth-first search from src; fills dist[src][0..n].
// dist[src][0] is set to inf as well, so index 0 can serve as "no vertex".
inline void bfs(int src)
{
    static int q[MAXN];
    for (int i = 0; i <= n; i++) dist[src][i] = inf;
    dist[src][src] = 0;
    int l = 1 , r = 1;
    q[1] = src;
    while (l <= r)
    {
        const int cur = q[l++];
        for (int i = head[cur]; i; i = e[i].nxt)
        {
            const int v = e[i].to;
            if (dist[src][cur] + 1 < dist[src][v])
            {
                dist[src][v] = dist[src][cur] + 1;
                q[++r] = v;
            }
        }
    }
}

// Fills nxt[i][j] for all pairs: the neighbour of i with the smallest distance
// to j, preferring the smallest index on ties. Stays 0 when i has no neighbour
// from which j is reachable. Requires bfs() to have run for every vertex.
static void build_next_steps()
{
    for (int i = 1; i <= n; i++)
    {
        for (int j = 1; j <= n; j++)
        {
            int best = 0;
            int best_dist = inf;
            for (int k = head[i]; k; k = e[k].nxt)
            {
                const int v = e[k].to;
                const int d = dist[v][j];
                if (d < best_dist || (d == best_dist && v < best))
                {
                    best = v;
                    best_dist = d;
                }
            }
            nxt[i][j] = best;
        }
    }
}

// Memoised expectation for the state (chaser at `chaser`, target at `target`).
// Precondition: target is reachable from chaser; otherwise nxt is 0 and the
// recursion would not terminate (main() rejects that case up front).
inline double dp(int chaser , int target)
{
    if (visited[chaser][target]) return f[chaser][target];
    visited[chaser][target] = true;

    if (chaser == target) return f[chaser][target] = 0;

    const int first = nxt[chaser][target];
    if (first == target) return f[chaser][target] = 1;

    const int second = nxt[first][target];
    if (second == target) return f[chaser][target] = 1;

    // The target picks uniformly among its neighbours and staying in place.
    const double p = 1.0 / (deg[target] + 1);
    double expected = 0;
    for (int i = head[target]; i; i = e[i].nxt)
    {
        expected += p * (dp(second , e[i].to) + 1);
    }
    expected += p * (dp(second , target) + 1);
    return f[chaser][target] = expected;
}

int main()
{
    read(n); read(m); read(s); read(t);
    // Everything below indexes fixed-size arrays, so reject out-of-range sizes.
    // e[] has 2 * MAXN slots and slot 0 is unused, hence m must stay below MAXN.
    if (n < 1 || n >= MAXN || m < 0 || m >= MAXN || s < 1 || s > n || t < 1 || t > n)
    {
        std::fprintf(stderr , "invalid input: n, m, s or t out of range\n");
        return 1;
    }
    for (int i = 1; i <= m; i++)
    {
        int u , v;
        read(u); read(v);
        if (u < 1 || u > n || v < 1 || v > n)
        {
            std::fprintf(stderr , "invalid input: edge endpoint out of range\n");
            return 1;
        }
        addedge(u , v);
        addedge(v , u);
        ++deg[u];
        ++deg[v];
    }
    for (int i = 1; i <= n; i++) bfs(i);
    if (dist[s][t] == inf)
    {
        // The chaser can never reach the target, so no finite expectation exists.
        std::fprintf(stderr , "invalid input: t is not reachable from s\n");
        return 1;
    }
    build_next_steps();
    std::printf("%.3f\n" , dp(s , t));
    return 0;
}