分块DP,真是奇妙的想法。
(dp[i][j][k])表示(i)到(j)走恰好(k)步的最短路。
我们可以用(Floyd)来处理。
然后我们再遍历整个数组,求(min),让(dp[i][j][k])表示(i)到(j)走至少(k)步的最短路。
但是我们不能开(dp[55][55][10005])的数组,时间空间都会炸,怎么办呢?
贪心地一想,如果我们的第(100)步处理好后,(dp[i][j][100])已经是最优的了,那我们(dp[i][j][200])可以直接用(dp[i][j][100])来推导出来
定义:(dp2[i][j][k])表示(i)到(j)走至少(k*100)步的最短路。
(dp/dp2)的数组更新基本相同。
统计答案时,如果(k_i leq 100),直接输出(dp1[s_i][t_i][k_i])
如果(100 leq k_i),(v1 = k_i / 100, v2 = k_i Mod 100),枚举中间点(p) , (Ans = min(dp[s_i][p][v2] +dp2[p][t_i][v1], dp2[s_i][p][v1] + dp[p][t_i][v2]))
注意(k_i = 100k) 时, (v2 = 100 , v1 = 100(k-1)) ,这样比直接(Ans = dp2[s_i][t_i][k])考虑的情况更全面。
现在想一想,为什么偏偏是(100)呢?
因为(sqrt{max(k_i)} = sqrt{10000} = 100),这就是分块了。
最后,贪心认为(dp[i][j][100])已经是最优的了,是为什么呢?
理论上我们要做到(dp[i][j][10000])来证明最优,但是我们实际只要考虑到(dp[i][j][200])即可,
反正卡着时间和空间,考虑到最大就可以了。(O(50 * 50 * k))
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
using namespace std;
const int inf = 1e9+7;
int n,m,q;
int dp1[55][55][205],dp2[55][55][205];
int main(){
int T; scanf("%d",&T);
while(T --){
scanf("%d%d",&n,&m);
for(int i = 1; i <= 50; ++ i)
for(int j = 1; j <= 50; ++ j)
for(int k = 1; k <= 200; ++ k)
dp1[i][j][k] = dp2[i][j][k] = inf;
for(int i = 1; i <= m; ++ i){
int x,y,z; scanf("%d%d%d",&x,&y,&z);
dp1[x][y][1] = min(dp1[x][y][1],z);
}
for(int p = 2; p <= 200; ++ p)
for(int k = 1; k <= n; ++ k)
for(int i = 1; i <= n; ++ i)
for(int j = 1; j <= n; ++ j){
dp1[i][j][p] = min(dp1[i][j][p], dp1[i][k][p - 1] + dp1[k][j][1]);
}
for(int p = 199; p >= 1; -- p)
for(int i = 1; i <= n; ++ i)
for(int j = 1; j <= n; ++ j){
dp1[i][j][p] = min(dp1[i][j][p], dp1[i][j][p + 1]);
}
for(int i = 1; i <= n; ++ i)
for(int j = 1; j <= n; ++ j)
dp2[i][j][1] = dp1[i][j][100];
for(int p = 2; p <= 100; ++ p)
for(int k = 1; k <= n; ++ k)
for(int i = 1; i <= n; ++ i)
for(int j = 1; j <= n; ++ j){
dp2[i][j][p] = min(dp2[i][j][p], dp2[i][k][p - 1] + dp2[k][j][1]);
}
scanf("%d",&q);
while(q --){
int s,t,k; scanf("%d%d%d",&s,&t,&k);
int v1 = (k - 1) / 100, v2 = k - v1 * 100;
int ans = inf;
if(k <= 100) ans = dp1[s][t][k];
else {
for(int i = 1; i <= n; ++ i)
ans = min(ans, dp1[s][i][v2] + dp2[i][t][v1]), ans = min(ans, dp1[s][i][v2] + dp2[i][t][v1]);
//lalala
}
if(ans >= inf) printf("-1
");
else printf("%d
",ans);
}
}
return 0;
}