• Codeforces 148D Bag of mice:概率dp 记忆化搜索


    题目链接:http://codeforces.com/problemset/problem/148/D

    题意:

      一个袋子中有w只白老鼠,b只黑老鼠。

      公主和龙轮流从袋子里随机抓一只老鼠出来,不放回,公主先拿。

      公主每次抓一只出来。龙每次在抓一只出来之后,会随机有一只老鼠跳出来(被龙吓的了。。。)。

      先抓到白老鼠的人赢。若两人最后都没有抓到白老鼠,则龙赢。

      问你公主赢的概率。

    题解:

      表示状态:

        dp[i][j] = probability to win(当前公主先手,公主赢的概率)

        i:剩i只白老鼠

        j:剩j只黑老鼠

      找出答案:

        ans = dp[w][b]

      边界条件:

        if i==0 dp[i][j] = 0 (没有白老鼠了,不可能赢)

        else if j==0 dp[i][j] = 1 (有且只有白老鼠,一定赢)

        else if j==1 dp[i][j] = i/(i+1) (如果公主拿了黑老鼠,那么龙一定会拿到白老鼠,公主输。所以公主一下就要拿到白老鼠)

      如何转移:

        对于dp[i][j],有两种赢的方法:

          (1)公主在这个回合一次就抓到了白老鼠。

          (2)公主和龙都各抓了一只黑老鼠,然后公主在下一个回合赢了。

        P(一次就抓到了白老鼠) = i/(i+j)

        P(进入下个回合,即两人都抓到黑老鼠) = P(公主抓到黑老鼠) * P(龙抓到黑老鼠) = j/(i+j) * (j-1)/(i+j-1)

        所以dp[i][j] = P(一次就抓到了白老鼠) + P(进入下个回合) * P(在下个回合赢)

        

        那么考虑下个回合可能的状态。

        因为公主和龙都已经抓走了两只黑老鼠,那么下个回合取决于跳出来的老鼠,有三种可能:

          (1)跳出来白老鼠

          (2)跳出来黑老鼠

          (3)老鼠已经抓完了,没有老鼠跳出来

        对于情况(3),原状态(i,j)只可能为:(1,1) , (0,2) , (2,0),均包含在边界条件中,所以不作考虑。

        剩下两种情况的可能性:

          (1)P(跳出来白老鼠) = i/(i+j-2) (i>=1 and j>=2)

          (2)P(跳出来黑老鼠) = (j-2)/(i+j-2) (j>=3)

        所以P(在下个回合赢) = P(跳出来白老鼠) * dp[i-1][j-2] + P(跳出来黑老鼠) * dp[i][j-3]

        总方程:

          nex = 0

          if i>=1 and j>=2 nex += i/(i+j-2)*dp[i-1][j-2]

          if j>=3 nex += (j-2)/(i+j-2)*dp[i][j-3]

          dp[i][j] = i/(i+j) + j/(i+j) * (j-1)/(i+j-1) * nex

      另外,这道题的题解有两个版本,一种记忆化搜索,一种for循环版,都差不多。

    AC Code(记忆化搜索):

     1 // state expression:
     2 // dp[i][j] = probability to win
     3 // i: i white mice
     4 // j: j black mice
     5 //
     6 // find the answer:
     7 // ans = dp[w][b]
     8 //
     9 // transferring:
    10 // if i>=1 and j>=2 nex += i/(i+j-2)*dp[i-1][j-2]
    11 // if j>=3 nex += (j-2)/(i+j-2)*dp[i][j-3]
    12 // dp[i][j] = i/(i+j) + j/(i+j) * (j-1)/(i+j-1) * nex
    13 //
    14 // boundary:
    15 // if i==0 dp[i][j] = 0
    16 // if j==0 dp[i][j] = 1
    17 // if j==1 dp[i][j] = i/(i+1)
    18 #include <iostream>
    19 #include <stdio.h>
    20 #include <string.h>
    21 #define MAX_N 1005
    22 
    23 using namespace std;
    24 
    25 int w,b;
    26 bool vis[MAX_N][MAX_N];
    27 double ans;
    28 double dp[MAX_N][MAX_N];
    29 
    30 double dfs(int i,int j)
    31 {
    32     if(vis[i][j]) return dp[i][j];
    33     vis[i][j]=true;
    34     if(i==0) return dp[i][j]=0;
    35     if(j==0) return dp[i][j]=1;
    36     if(j==1) return dp[i][j]=(double)i/(i+1);
    37     double nex=0;
    38     nex+=(double)i/(i+j-2)*dfs(i-1,j-2);
    39     if(j>=3) nex+=(double)(j-2)/(i+j-2)*dfs(i,j-3);
    40     return dp[i][j]=(double)i/(i+j)+(double)j/(i+j)*(j-1)/(i+j-1)*nex;
    41 }
    42 
    43 void read()
    44 {
    45     cin>>w>>b;
    46 }
    47 
    48 void solve()
    49 {
    50     memset(vis,false,sizeof(vis));
    51     ans=dfs(w,b);
    52 }
    53 
    54 void print()
    55 {
    56     printf("%.9f
    ",ans);
    57 }
    58 
    59 int main()
    60 {
    61     read();
    62     solve();
    63     print();
    64 }

    AC Code(for循环):

     1 #include <iostream>
     2 #include <stdio.h>
     3 #include <string.h>
     4 #define MAX_N 1005
     5 
     6 using namespace std;
     7 
     8 int w,b;
     9 double ans;
    10 double dp[MAX_N][MAX_N];
    11 
    12 void read()
    13 {
    14     cin>>w>>b;
    15 }
    16 
    17 void solve()
    18 {
    19     memset(dp,0,sizeof(dp));
    20     for(int i=0;i<=w;i++)
    21     {
    22         for(int j=0;j<=b;j++)
    23         {
    24             if(i==0)
    25             {
    26                 dp[i][j]=0;
    27                 continue;
    28             }
    29             if(j==0)
    30             {
    31                 dp[i][j]=1;
    32                 continue;
    33             }
    34             if(j==1)
    35             {
    36                 dp[i][j]=(double)i/(i+1);
    37                 continue;
    38             }
    39             double nex=(double)i/(i+j-2)*dp[i-1][j-2];
    40             if(j>=3) nex+=(double)(j-2)/(i+j-2)*dp[i][j-3];
    41             dp[i][j]=(double)i/(i+j)+(double)j/(i+j)*(j-1)/(i+j-1)*nex;
    42         }
    43     }
    44 }
    45 
    46 void print()
    47 {
    48     printf("%.9f
    ",dp[w][b]);
    49 }
    50 
    51 int main()
    52 {
    53     read();
    54     solve();
    55     print();
    56 }
  • 相关阅读:
    deepcopy list,dict
    朴素贝叶斯
    COMP6714 week2a skipTo()
    batch normalization / layer normalization
    self-attention Transformer
    44. 通配符匹配
    FOJ 10月赛题 FOJ2198~2204
    CF #323 DIV2 D题
    HDU 5467
    CF #321 (Div. 2) E
  • 原文地址:https://www.cnblogs.com/Leohh/p/7468561.html
Copyright © 2020-2023  润新知