感觉这题可以模板化。
听说spfa死了,所以要练堆优化dijkstra。
首先对$x_{1},y_{1},x_{2},y_{2}$各跑一遍最短路,然后扫一遍所有边看看是不是同时在两个点的最短路里面,如果是的话就把这条边加到一张新图中去,因为最短路一定没有环,所以最后造出来的这张新图一定是一个$DAG$,dp一遍求最长链即为答案。
考虑一下怎么判断一条边是否在最短路里,设这条边连接的两个点是$x$,$y$,边权是$v$,如果它在最短路里面,那么有$dis(x_{1}, x) + v + dis(y_{1}, y) == dis(x_{1}, y_{1})$并且$dis(x_{2}, x) + v + dis(y_{2}, y) == dis(x_{2}, y_{2})$,注意第二个条件中$x$和$y$可以交换。加边的时候注意维持一下$DAG$的形态,可以把$x$和$y$到$x_{1}$的距离小的向距离大的连边。
时间复杂度$O(nlogn)$,堆优化dij是瓶颈。
感觉写得很长。
Code:
#include <cstdio> #include <cstring> #include <queue> #include <iostream> using namespace std; typedef pair <int, int> pin; const int N = 1505; const int M = 3e6 + 5; const int inf = 0x3f3f3f3f; int n, m, inx[M], iny[M], inv[M], deg[N], f[N], ans = 0; int c1, c2, c3, c4, tot = 0, head[N], dis[N], d[4][N]; bool vis[N]; struct Edge { int to, nxt, val; } e[M << 1]; inline void add(int from, int to, int val) { e[++tot].to = to; e[tot].val = val; e[tot].nxt = head[from]; head[from] = tot; } inline void read(int &X) { X = 0; char ch = 0; int op = 1; for(; ch > '9'|| ch < '0'; ch = getchar()) if(ch == '-') op = -1; for(; ch >= '0' && ch <= '9'; ch = getchar()) X = (X << 3) + (X << 1) + ch - 48; X *= op; } inline void swap(int &x, int &y) { int t = x; x = y; y = t; } priority_queue <pin> Q; void dij(int st) { memset(dis, 0x3f, sizeof(dis)); memset(vis, 0, sizeof(vis)); Q.push(pin(dis[st] = 0, st)); for(; !Q.empty(); ) { int x = Q.top().second; Q.pop(); if(vis[x]) continue; vis[x] = 1; for(int i = head[x]; i; i = e[i].nxt) { int y = e[i].to; if(dis[y] > dis[x] + e[i].val) { dis[y] = dis[x] + e[i].val; Q.push(pin(-dis[y], y)); } } } } inline void chkMax(int &x, int y) { if(y > x) x = y; } int dfs(int x) { if(vis[x]) return f[x]; vis[x] = 1; int res = 0; for(int i = head[x]; i; i = e[i].nxt) { int y = e[i].to; chkMax(res, dfs(y) + e[i].val); } return f[x] = res; } int main() { read(n), read(m), read(c1), read(c2), read(c3), read(c4); for(int i = 1; i <= m; i++) { read(inx[i]), read(iny[i]), read(inv[i]); add(inx[i], iny[i], inv[i]), add(iny[i], inx[i], inv[i]); } dij(c1); memcpy(d[0], dis, sizeof(d[0])); dij(c2); memcpy(d[1], dis, sizeof(d[1])); dij(c3); memcpy(d[2], dis, sizeof(d[2])); dij(c4); memcpy(d[3], dis, sizeof(d[3])); /* for(int i = 1; i <= n; i++) printf("%d ", d[0][i]); printf(" "); for(int i = 1; i <= n; i++) printf("%d ", d[1][i]); printf(" "); for(int i = 1; i <= n; i++) printf("%d ", d[2][i]); printf(" "); for(int i = 1; i <= n; i++) printf("%d ", d[3][i]); printf(" "); */ tot = 0; memset(head, 0, sizeof(head)); for(int i = 1; i <= m; i++) { int x = inx[i], y = iny[i], v = inv[i]; if(d[0][x] + v + d[1][y] == d[0][c2]) if(d[2][y] + v + d[3][x] == d[2][c4] || d[2][x] + v + d[3][y] == d[2][c4]) { if(d[0][x] < d[0][y]) { add(x, y, v); deg[y]++; } else { add(y, x, v); deg[x]++; } } swap(x, y); if(d[0][x] + v + d[1][y] == d[0][c2]) if(d[2][y] + v + d[3][x] == d[2][c4] || d[2][x] + v + d[3][y] == d[2][c4]) { if(d[0][x] < d[0][y]) { add(x, y, v); deg[y]++; } else { add(y, x, v); deg[x]++; } } } memset(vis, 0, sizeof(vis)); for(int i = 1; i <= n; i++) if(deg[i] == 0 && !vis[i]) dfs(i); /* for(int i = 1; i <= n; i++) printf("%d ", f[i]); printf(" "); */ for(int i = 1; i <= n; i++) chkMax(ans, f[i]); printf("%d ", ans); return 0; }