【SHTSC2014】概率充电器(charger)
Description
著名的电子产品品牌SHOI刚刚发布了引领世界潮流的下一代电子产品——概率充电器:
“采用全新纳米级加工技术,实现元件与导线能否通电完全由真随机数决定!SHOI概率充电器,您生活不可或缺的必需品!能充上电吗?现在就试试看吧!”
SHOI概率充电器由n-1条导线连通了n个充电元件。进行充电时,每条导线是否可以导电以概率决定,每一个充电元件自身是否直接进行充电也由概率决定。随后电能可以从直接充电的元件经过通电的导线使得其他充电元件进行间接充电。
作为SHOI公司的忠实客户,你无法抑制自己购买SHOI产品的冲动。在排了一个星期的长队之后终于入手了最新型号的SHOI概率充电器。你迫不及待地将SHOI概率充电器插入电源——这时你突然想知道,进入充电状态的元件个数的期望是多少呢?
Input
第一行一个整数:n。概率充电器的充电元件个数。充电元件由1-n编号。
之后的n-1行每行三个整数a, b, p,描述了一根导线连接了编号为a和b的充电元件,通电概率为p%。
第n+2行n个整数:qi。表示i号元件直接充电的概率为qi%。
Output
输出一行一个实数,为能进入充电状态的元件个数的期望,四舍五入到小数点后6位小数。
Sample Input
输入1:
3
1 2 50
1 3 50
50 0 0
输入2:
5
1 2 90
1 3 80
1 4 70
1 5 60
100 10 20 30 40
Sample Output
输出1:
1.000000
输出2:
4.300000
Data Constraint
对于30%的数据,n≤5000。
对于100%的数据,n≤500000,0≤p,qi≤100。
题解
看到给出的图是一棵树,很显然这道题是一个树形的期望DP
我们设第(i)个元件进入充电状态的概率,仔细读题之后,发现我们要求的就是一个很简单的式子:$$sum_{i=1}^{n} p[i]$$
对于概率,我们需要知道两个很基础的公式(设有事件(A),发生的概率为(P_{A});有事件(B),发生的概率为(P_{B}))
- 事件(A)和事件(B)至少发生一件事的概率为:事件(A)发生事件(B)不发生的概率+事件(A)不发生事件(B)发生的概率+事件(A)和事件(B)都发生的概率
转化成式子如下(设事件(A)和事件(B)至少发生一件事的概率为(P')):
- 知道发生其中一件事的概率和两件事至少发生一件的概率,求另一件事发生的概率为(假设知道(P_{B}),概率表示方式跟公式1一样):
然后回过头来分析题目,一个元件只有以下三种进入充电状态的方式:
- 父亲节点充电通过导线导过来
- 自己充电
- 儿子节点充电通过导线导过来
我们发现后面两种情况比较好处理,自己充电的概率题目已经给出了,儿子导过来只需要用公式1递归求出即可,从父亲节点导过来的情况有些困难
但是我们想想树的性质,树的根节点是没有父亲的(bushi),那我们让每个节点懂当一次没父亲的根节点不就行了吗?
于是这道题的做法就显现出来了:换根期望DP
首先我们以节点1为根做一遍(dfs)进行统计,然后这时候算出来的(p[1])是没有问题的,然后我们考虑怎么用当前算出的节点概率来更新他的子节点,如图(假设当前在处理节点3):
当进行完第一次(dfs)时,对节点3产生贡献的只有两条红色边以及以两条红色边到达的点为根的子树(即黄色圈圈起部分)
但是我们将节点3变成根时,如图:
对节点3产生贡献的多了蓝色边以及以蓝色边到达的点为根的子树(即绿色圈圈起部分)
于是我们可以利用公式2算出多贡献的部分的概率,然后再用公式1进行统计就行了(这里有点懵的可以看看代码)
但是由于GMOJ的评测机很辣鸡,动不动就会爆栈,所以要通过GMOJ的数据要写人工栈或(bfs)(但是洛谷和LOJ用(dfs)可以正常跑过)
CODE1(dfs版)
#include<cstdio>
#include<string>
#define R register int
#define N 500005
#define ll long long
using namespace std;
struct G{int to,next;double w;}e[N<<1];
int n,cnt,head[N];
double q[N],ans,p[N];
int max(int a,int b) {return a>b?a:b;}
int min(int a,int b) {return a<b?a:b;}
void read(int &x)
{
x=0;int f=1;char ch=getchar();
while (!isdigit(ch)) {if (ch=='-') f=-1;ch=getchar();}
while (isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=getchar();x*=f;
}
void add(int u,int v,double w)
{
e[++cnt].to=v;e[cnt].w=w;
e[cnt].next=head[u];head[u]=cnt;
}
void dfs1(int u,int fa)
{
for (R i=head[u];i;i=e[i].next)
{
int v=e[i].to;if (v==fa) continue;dfs1(v,u);
q[u]=q[u]+q[v]*e[i].w-q[u]*q[v]*e[i].w;
}
}
void dfs2(int u,int fa)
{
ans+=p[u];
for (R i=head[u];i;i=e[i].next)
{
int v=e[i].to;if (v==fa) continue;
double npv=1-q[v]*e[i].w;
if (!npv) p[v]=1;
else
{
double k=(p[u]-q[v]*e[i].w)/npv;p[v]=q[v]+k*e[i].w-q[v]*k*e[i].w;
}
dfs2(v,u);
}
}
int main()
{
freopen("charger.in","r",stdin);
freopen("charger.out","w",stdout);
read(n);
for (R x,y,i=1;i<n;++i)
{
read(x);read(y);double z;scanf("%lf",&z);z/=100;
add(x,y,z);add(y,x,z);
}
for (R i=1;i<=n;++i)
scanf("%lf",&q[i]),q[i]/=100;
dfs1(1,0);p[1]=q[1];dfs2(1,0);printf("%.6lf
",ans);
return 0;
}
CODE2(人工栈+bfs版)
#include<cstdio>
#include<string>
#include<queue>
#define R register int
#define N 500005
#define ll long long
using namespace std;
struct G{int to,next;double w;}e[N<<1];
struct stack{int num;double w;}zhan[N<<1];
int n,cnt,head[N],fa[N],tot;
double q[N],ans,p[N],zhannum[N];
int max(int a,int b) {return a>b?a:b;}
int min(int a,int b) {return a<b?a:b;}
void read(int &x)
{
x=0;int f=1;char ch=getchar();
while (!isdigit(ch)) {if (ch=='-') f=-1;ch=getchar();}
while (isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=getchar();x*=f;
}
void add(int u,int v,double w)
{
e[++cnt].to=v;e[cnt].w=w;
e[cnt].next=head[u];head[u]=cnt;
}
void bfs1()
{
queue<int>d;d.push(1);
while (!d.empty())
{
int u=d.front();d.pop();
for (R i=head[u];i;i=e[i].next)
{
int v=e[i].to;if (v==fa[u]) continue;
fa[v]=u;d.push(v);zhan[++tot].num=v;zhan[tot].w=e[i].w;
}
}
for (R i=tot;i;--i)
q[fa[zhan[i].num]]=q[fa[zhan[i].num]]+q[zhan[i].num]*zhan[i].w-q[fa[zhan[i].num]]*q[zhan[i].num]*zhan[i].w;
}
void bfs2()
{
queue<int>d;d.push(1);
while (!d.empty())
{
int u=d.front();d.pop();ans+=p[u];
for (R i=head[u];i;i=e[i].next)
{
int v=e[i].to;if (v==fa[u]) continue;
double npv=1-q[v]*e[i].w;d.push(v);
if (!npv) p[v]=1;
else
{
double k=(p[u]-q[v]*e[i].w)/npv;p[v]=q[v]+k*e[i].w-q[v]*k*e[i].w;
}
}
}
}
int main()
{
freopen("charger.in","r",stdin);
freopen("charger.out","w",stdout);
read(n);
for (R x,y,i=1;i<n;++i)
{
read(x);read(y);double z;scanf("%lf",&z);z/=100;
add(x,y,z);add(y,x,z);
}
for (R i=1;i<=n;++i)
scanf("%lf",&q[i]),q[i]/=100;
bfs1();p[1]=q[1];bfs2();printf("%.6lf
",ans);
return 0;
}