题目链接: (bzoj)https://www.lydsy.com/JudgeOnline/problem.php?id=4006
(luogu)https://www.luogu.org/problemnew/show/P3264
题解: 终于写出来斯坦纳树了。。
我一直不明白的地方是: spfa那种转移为什么是直接加边权?为什么没有一些特殊情况(如从根转移到儿子)不是加边权?后来觉得大概是因为那种特殊情况如果出现,则一定会在枚举子集的转移中被转移到。
做法就是,先对每个特殊点的子集求出来最小斯坦纳树,然后设(dp[S])表示颜色集合(S)内的最小答案,那么(dp[S])可以直接等于它所对应的关键点集合的斯坦纳树,也可以由好几个子集合并过来,枚举子集转移即可。
时间复杂度(O(ShortestPath(n,m) imes 2^p+n3^p))
这里貌似SPFA比Dijkstra略快一些。(我在洛谷上开O2,spfa 3234ms, Dijkstra 6695ms, 不开O2 spfa T成65, Dijkstra T成40)
代码
SPFA
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<queue>
using namespace std;
const int N = 1e3;
const int M = 3e3;
const int NN = 10;
const int INF = 707406378;
struct Edge
{
int v,w,nxt;
} e[(M<<1)+3];
int fe[N+3];
int ky[NN+3];
int clrset[(1<<NN)+3];
int clr[NN+3];
int dp[N+3][(1<<NN)+3];
int ans[(1<<NN)+3];
bool inq[M+3];
int que[M+3];
int n,m,nn,en;
void addedge(int u,int v,int w)
{
en++; e[en].v = v; e[en].w = w;
e[en].nxt = fe[u]; fe[u] = en;
}
void update(int &x,int y) {x = x<y?x:y;}
void SPFA(int sta)
{
int head = 1,tail = 1;
for(int i=1; i<=n; i++)
{
if(dp[i][sta]<INF)
{
que[tail] = i; tail++; if(tail>n+1) tail = 1;
inq[i] = true;
}
}
while(head!=tail)
{
int u = que[head]; head++; if(head>n+1) head = 1;
for(int i=fe[u]; i; i=e[i].nxt)
{
int v = e[i].v;
if(dp[u][sta]+e[i].w<dp[v][sta])
{
dp[v][sta] = dp[u][sta]+e[i].w;
if(!inq[v])
{
que[tail] = v; tail++; if(tail>n+1) tail = 1;
inq[v] = true;
}
}
}
inq[u] = false;
}
}
int main()
{
scanf("%d%d%d",&n,&m,&nn);
for(int i=1; i<=m; i++)
{
int x,y,z; scanf("%d%d%d",&x,&y,&z);
addedge(x,y,z); addedge(y,x,z);
}
for(int i=0; i<nn; i++)
{
scanf("%d%d",&clr[i],&ky[i]); clr[i]--;
clrset[1<<clr[i]] |= (1<<i);
}
memset(dp,42,sizeof(dp));
for(int i=0; i<nn; i++) dp[ky[i]][(1<<i)] = 0;
for(int i=1; i<(1<<nn); i++)
{
for(int j=(i-1)&i; j; j=(j-1)&i)
{
for(int k=1; k<=n; k++)
{
dp[k][i] = min(dp[k][i],dp[k][i^j]+dp[k][j]);
}
}
SPFA(i);
}
for(int i=1; i<(1<<nn); i<<=1)
{
for(int j=0; j<(1<<nn); j++)
{
if(j&i)
{
clrset[j] |= clrset[i];
}
}
}
for(int i=1; i<(1<<nn); i++)
{
ans[i] = INF;
for(int j=1; j<=n; j++)
{
update(ans[i],dp[j][clrset[i]]);
}
for(int j=(i-1)&i; j; j=(j-1)&i)
{
update(ans[i],ans[j]+ans[i^j]);
}
}
printf("%d
",ans[(1<<nn)-1]);
return 0;
}
Dijkstra
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<queue>
using namespace std;
const int N = 1e3;
const int M = 3e3;
const int NN = 10;
const int INF = 707406378;
struct Edge
{
int v,w,nxt;
} e[(M<<1)+3];
struct DijNode
{
int u,dis;
DijNode() {}
DijNode(int _u,int _dis) {u = _u,dis = _dis;}
bool operator <(const DijNode &arg) const {return dis>arg.dis;}
};
int fe[N+3];
bool vis[N+3];
int ky[NN+3];
int clrset[(1<<NN)+3];
int clr[NN+3];
int dp[N+3][(1<<NN)+3];
int ans[(1<<NN)+3];
priority_queue<DijNode> que;
int n,m,nn,en;
void addedge(int u,int v,int w)
{
en++; e[en].v = v; e[en].w = w;
e[en].nxt = fe[u]; fe[u] = en;
}
void update(int &x,int y) {x = min(x,y);}
void Dijkstra(int sta)
{
while(!que.empty())
{
DijNode tmp = que.top(); que.pop(); int u = tmp.u;
if(tmp.dis!=dp[u][sta]) continue;
vis[u] = true;
for(int i=fe[u]; i; i=e[i].nxt)
{
int v = e[i].v;
if(vis[v]==false && dp[u][sta]+e[i].w<dp[v][sta])
{
dp[v][sta] = dp[u][sta]+e[i].w;
que.push(DijNode(v,dp[v][sta]));
}
}
}
for(int i=1; i<=n; i++) vis[i] = false;
}
int main()
{
scanf("%d%d%d",&n,&m,&nn);
for(int i=1; i<=m; i++)
{
int x,y,z; scanf("%d%d%d",&x,&y,&z);
addedge(x,y,z); addedge(y,x,z);
}
for(int i=0; i<nn; i++)
{
scanf("%d%d",&clr[i],&ky[i]); clr[i]--;
clrset[1<<clr[i]] |= (1<<i);
}
memset(dp,42,sizeof(dp));
for(int i=0; i<nn; i++) dp[ky[i]][(1<<i)] = 0;
for(int i=1; i<(1<<nn); i++)
{
for(int j=(i-1)&i; j; j=(j-1)&i)
{
for(int k=1; k<=n; k++)
{
dp[k][i] = min(dp[k][i],dp[k][i^j]+dp[k][j]);
}
}
for(int j=1; j<=n; j++)
{
if(dp[j][i]!=INF)
{
que.push(DijNode(j,dp[j][i]));
}
}
Dijkstra(i);
}
for(int i=1; i<(1<<nn); i<<=1)
{
for(int j=0; j<(1<<nn); j++)
{
if(j&i)
{
clrset[j] |= clrset[i];
}
}
}
for(int i=1; i<(1<<nn); i++)
{
ans[i] = INF;
for(int j=1; j<=n; j++)
{
update(ans[i],dp[j][clrset[i]]);
}
for(int j=(i-1)&i; j; j=(j-1)&i)
{
update(ans[i],ans[j]+ans[i^j]);
}
}
printf("%d
",ans[(1<<nn)-1]);
return 0;
}