题意:(floyd)被写成了这个样子:
for i from 1 to n
for j from 1 to n
for k from 1 to n
dis[i][j] <- min(dis[i][j], dis[i][k] + dis[k][j])
求最后有多少个位置仍然是相等的。
解法1:这个错误的东西其实就是枚举每个点然后去更新其他的点,那么我们模拟这个过程去做,对于一个点对((i,j)),我们需要知道是否存在一个中转点(k)使得((i,k))是最短路边,((k,j))是最短路边,并且((i,k)+(k,j))为最短路,所以做法就是对于每个根建出最短路的(DAG),拓扑对每个点求出每个点会有多少点能够到达它,判断((i,j))是不是最短路时,用三个(bitset)与一下就行了。
解法2:直接去跑错误的做法,先把原图上的边权=最短路的边权加入到边集中,然后再对每个起点跑最短路,对于起点(s),可以发现它只有第一跳是可以向前跳的之外其他的都只能向后跳,在做(spfa)的时候记录一下就可以了,在做的过程中发现有当前最短路已经完成的时候就把边加进去。
#include <bits/stdc++.h>
#define N 2009
#define mm make_pair
using namespace std;
typedef long long ll;
const ll mod=99824353;
int tot,n,m,head[N],dis[N][N],p[N][N];
bool vis[N];
inline ll rd(){
ll x=0;char c=getchar();bool f=0;
while(!isdigit(c)){if(c=='-')f=1;c=getchar();}
while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
return f?-x:x;
}
struct edge{
int n,to,l;
}e[1000009];
inline void add(int u,int v,int l){
e[++tot].n=head[u];
e[tot].to=v;
head[u]=tot;
e[tot].l=l;
}
void spfa1(int id){
queue<int>q;
memset(dis[id],0x3f,sizeof(dis[id]));
dis[id][id]=0;
q.push(id);
while(!q.empty()){
int u=q.front();q.pop();vis[u]=0;
for(int i=head[u];i;i=e[i].n){
int v=e[i].to;
if(dis[id][v]>dis[id][u]+e[i].l){
dis[id][v]=dis[id][u]+e[i].l;
if(!vis[v]){
vis[v]=1;
q.push(v);
}
}
}
}
}
void spfa2(int id){
queue<pair<int,int> >q;
for(int i=head[id];i;i=e[i].n)q.push(mm(0,e[i].to));
while(!q.empty()){
pair<int,int> nw=q.front();q.pop();
int u=nw.second;vis[u]=0;
for(int i=head[u];i;i=e[i].n){
int v=e[i].to;
if(v<nw.first)continue;
if(dis[id][v]==p[id][v]||dis[id][v]>1e9)continue;
if(dis[id][v]==dis[id][u]+e[i].l){
p[id][v]=dis[id][v];
add(id,v,dis[id][v]);
if(!vis[v]){
vis[v]=1;
q.push(mm(v,v));
}
}
}
}
}
int main(){
n=rd();m=rd();
memset(p,0x3f,sizeof(p));
int u,v,w;
for(int i=1;i<=m;++i){
u=rd();v=rd();w=rd();
p[u][v]=w;
add(u,v,w);
}
for(int i=1;i<=n;++i)p[i][i]=0;
for(int i=1;i<=n;++i)spfa1(i);
memset(head,0,sizeof(head));tot=0;
for(int i=1;i<=n;++i)
for(int j=1;j<=n;++j)if(i!=j&&p[i][j]<1e9&&p[i][j]==dis[i][j]){
add(i,j,dis[i][j]);
}
for(int i=1;i<=n;++i)spfa2(i);
int ans=0;
for(int i=1;i<=n;++i)
for(int j=1;j<=n;++j)if(dis[i][j]==p[i][j])ans++;
printf("%d
",ans);
return 0;
}