• 2017 国庆湖南 Day5


    期望得分:76+80+30=186

    实际得分:72+10+0=82

    先看第一问:

    本题不是求方案数,所以我们不关心 选的数是什么以及的选的顺序

    只关心选了某个数后,对当前gcd的影响

    预处理

    cnt[i] 表示 i的倍数有多少个

    g[i][j] 表示gcd(i,第j张卡片上的数)

    dp[i][j] 表示已经选了i个数,gcd=j 的 概率

    再选k,要么gcd不变,要么变小

    1、gcd不变 

    即k是j的倍数,因为已经选了i个且都是j的倍数,所以在剩下的n-i 个数中,还有 cnt[j]-i 个数可以选

    所以状态转移方程:dp[i+1][j]+=dp[i][j]*(cnt[j]-i)/(n-i)

    2、gcd变小  

    枚举要选的是第h个数 ,h满足gcd(a[h],j)!=j

    (a[h] 表示第h张卡片上的数)

    那么gcd会变为g[j][h]

    因为 当gcd=1 的时候游戏结束,即 gcd=1 不能用来转移

    所以 当gcd=1时,直接累计进答案,不更新dp

    所以状态转移方程:dp[i+1][g[j][h]+=dp[i][j]/(n-i),g[j][h]!=1

    答案的累计:

    1、dp 过程中 gcd=1

    只有 选了偶数个数之后,gcd=1,先手才赢

    所以 在dp过程中,若i是奇数,ans+=dp[i][j]/(n-i)

    (因为是在由i推出去的时候 累计答案,所以i是奇数)

    2、dp完之后,没有牌选了

    若n是奇数,则先手胜

    所以若n是奇数,ans+=dp[n][i] 

    第二问:

    就是裸地SG函数

    sg[i][j] 表示 已经选了i个数,gcd=j 是必胜态(1)还是必败态(0)

    根据

    必胜态的后继状态至少有一个是必败态

    必败态的后继状态全是必胜态

    用 & 运算符可以方便的记录

    记忆化搜索

    边界:sg[n][i]=0,sg[i][1]=1

    因为 选了n个数且j!=1 之后,对方败

    当gcd=1 之后,对方胜

    为什么要用对方的状态?(以下可能表述不清)

    因为边界是在dfs 最前面判断的,而且是从选了0张牌开始

    己方选了x张牌之后的状态,随dfs到了下一层里,即到了对方选的哪儿

    如果己方选了n张牌且gcd!=1,己方赢,但sg[n][]的状态是到下一层dfs里判断的

    主客交换,对方输,所以sg[n][]=0

    sg[i][1] 同理

    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    
    #define N 301
    #define K 1001
    
    using namespace std;
    
    const double eps=1e-8;
    
    int n,m,a[N];
    
    int cnt[K],g[K][N];
    
    double dp[N][K];
    
    int sg[N][K];
    
    int getgcd(int a,int b) { return !b ? a : getgcd(b,a%b); }
    
    void init()
    {
        scanf("%d",&n);
        for(int i=1;i<=n;i++) scanf("%d",&a[i]),m=max(m,a[i]); 
    }
    
    void pre()
    {
        for(int i=1;i<=n;i++) g[0][i]=a[i];
        for(int i=1;i<=m;i++)
            for(int j=1;j<=n;j++)
                cnt[i]+=(a[j]%i==0),g[i][j]=getgcd(i,a[j]);    
    }
    
    void getprobability()
    {
        double ans=0.0;
        dp[0][0]=1.0;
        for(int i=0;i<n;i++)
            for(int j=0;j<=m;j++)
                if(dp[i][j]>eps) 
                {
                    dp[i+1][j]+=dp[i][j]*(cnt[j]-i)/(n-i);
                    for(int k=1;k<=n;k++)
                        if(g[j][k]!=j)
                        {
                            if(g[j][k]!=1) dp[i+1][g[j][k]]+=dp[i][j]/(n-i);
                            else ans+=(i&1)*dp[i][j]/(n-i);
                        }    
                }
        if(n&1)
            for(int i=0;i<=m;i++) ans+=dp[n][i];
        printf("%.9lf",ans);
    }
    
    int dfs(int x,int gcd)
    {
        if(sg[x][gcd]!=-1) return sg[x][gcd];
        bool win=true;
        if(cnt[gcd]>x) win&=dfs(x+1,gcd);
        for(int i=1;i<=n;i++)
            if(g[gcd][i]!=gcd) win&=dfs(x+1,g[gcd][i]);
        return sg[x][gcd]=!win;
    }
    
    void getsg()
    {
        memset(sg,-1,sizeof(sg));
        for(int i=0;i<=m;i++) sg[n][i]=0;
        for(int i=0;i<=n;i++) sg[i][1]=1;
        if(dfs(0,0)) printf("1.000000000");
        else printf("0.000000000");
    }
    
    int main()
    {
        freopen("cards.in","r",stdin);
        freopen("cards.out","w",stdout);
        init();
        pre();
        getprobability(); 
        printf(" ");
        getsg();
    }
    View Code

    80分暴力:

    删边转化成倒着加边

    每次 加一条边,两个端点重新做树形DP,得到合并之后的树的权值

    用并查集维护连通块

    一个连通块就是一棵树,答案就是所有 连通块的权值的乘积

    维护乘积 乘一下再除一下就好了,考场上智商全掉了 用的线段树

    100分做法:

    上述做法慢就慢在每次加一条边,两个端点重新做树形DP

    这里有一个结论:

    设树S1最大权值路径的两端点为u1,u2

    设树S2最大权值路径的两端点为v1,v2

    那么树S1和树S2合并之后

    最大权值路径的两端点一定是u1,u2,v1,v2中的两个

    结论的简单证明:

    设合并之后的最大权值路径的两端点为k1、k2

    1、k1、k2 = u1、u2  或 k1、k2=v1、v2 ,显然成立

    2、k1 = u1或u2,k2=v1或v2

    如下图所示

    若选的最长权值路径为路径P+路径L1

    根据dfs求树的直径的原理可推得,

    w——v1 和 w——v2 中必有一条是从w出发的最大权值路径

    假设是w——v1

    那么选路径P+路径L2 更优

     

    有了上述结论

    那么我们每次合并只需要计算4条路径 、原来两棵树 的权值取最大

    我么需要维护

    val[i] 表示 当前i号连通块(树) 的最大权值

    endpoint[i][2] 表示 i号连通块对应val[i] 的两端点

    每次用最大的路径来更新这两个数组

    每次的答案=原答案/val[S1]/val[S2]*合并之后的最大权值

    如何计算路径权值?

    dfs 一遍记录树上前缀和len[]

    dis(u,v)=len[u]+len[v]-len[lca]+lca的权值

    #include<cstdio>
    #include<iostream>
    #include<algorithm> 
    
    using namespace std;
    
    #define N 100001
    
    const int mod=1e9+7;
    
    int n,cnt;
    int cut[N],e[N][2];
    
    int front[N],to[N<<1],nxt[N<<1],tot;
    
    int len[N],id[N];
    int fa[N][18];
    
    int a[N],val[N],ans[N];
    int endpoint[N][2];
    
    int F[N];
    
    void read(int &x)
    {
        x=0; char c=getchar();
        while(!isdigit(c)) c=getchar();
        while(isdigit(c)) { x=x*10+c-'0'; c=getchar();  }
    }
    
    void add(int u,int v)
    {
        to[++tot]=v; nxt[tot]=front[u]; front[u]=tot;
        to[++tot]=u; nxt[tot]=front[v]; front[v]=tot;
    }
    
    void init()
    {
        read(n); ans[n]=1;
        for(int i=1;i<=n;i++) 
        {
            read(val[i]);
            ans[n]=1ll*ans[n]*val[i]%mod;
            endpoint[i][0]=endpoint[i][1]=i;
            F[i]=i; a[i]=val[i];
        }
        int u,v;
        for(int i=1;i<n;i++)
        {
            read(u); read(v);
            add(u,v);
            e[i][0]=u; e[i][1]=v;
        }
        for(int i=1;i<n;i++) read(cut[i]);
    }
    
    void dfs(int x,int f)
    {
        fa[x][0]=f;
        len[x]=len[f]+a[x];
        id[x]=++cnt;
        for(int i=front[x];i;i=nxt[i])
            if(to[i]!=f) dfs(to[i],x);
    }
    
    void prelca()
    {
        for(int j=1;j<18;++j)
            for(int i=1;i<=n;i++)
                fa[i][j]=fa[fa[i][j-1]][j-1];
    }
    
    int getlca(int u,int v)
    {
        if(id[u]<id[v]) swap(u,v);
        for(int i=17;i>=0;i--)
            if(id[fa[u][i]]>id[v]) u=fa[u][i];
        return fa[u][0];
    }
    
    int getlength(int u,int v)
    {
        int lca=getlca(u,v);
        return len[u]+len[v]-2*len[lca]+a[lca];
    }
    
    int find(int i) { return F[i]==i ? i : F[i]=find(F[i]); }
    
    int Pow(int a,int b)
    {
        int res=1;
        for(;b;a=1ll*a*a%mod,b>>=1)
            if(b&1) res=1ll*res*a%mod;
        return res;
    }
    
    void solve()
    {
        int u,v; int product=ans[n],mx; 
        int l,e1,e2;
        for(int i=n-1;i;i--)
        {
            u=e[cut[i]][0],v=e[cut[i]][1];
            u=find(u); v=find(v);
            if(val[u]>val[v]) mx=val[u], e1=endpoint[u][0], e2=endpoint[u][1];
            else mx=val[v], e1=endpoint[v][0], e2=endpoint[v][1];
            for(int j=0;j<2;j++)
                for(int k=0;k<2;k++)
                {
                    l=getlength(endpoint[u][j],endpoint[v][k]);
                    if(l>mx)
                    {
                        mx=l;
                        e1=endpoint[u][j]; e2=endpoint[v][k];
                    }
                }
            product=1ll*product*Pow(val[u],mod-2)%mod;
            product=1ll*product*Pow(val[v],mod-2)%mod;
            product=1ll*product*mx%mod;
            ans[i]=product;
            F[u]=F[v];
            endpoint[v][0]=e1,endpoint[v][1]=e2;
            val[v]=mx;
        }
        for(int i=1;i<=n;i++) printf("%d
    ",ans[i]);
    }
    
    int main()
    {
        freopen("forest.in","r",stdin);
        freopen("forest.out","w",stdout);
        init();
        dfs(1,0);
        prelca();
        solve();
    }
    View Code

    80分暴力 

    #include<cstdio>
    #include<iostream>
    #include<algorithm>
    
    using namespace std;
    
    #define N 100001
    
    #define lowbit(x) x&-x
    
    const int mod=1e9+7;
    
    int val[N],e[N][2],cut[N];
    
    int front[N],to[N<<1],nxt[N<<1];
    
    int tmp,tot,n;
    
    int f[N][2],out[N];
    
    int F[N];
    
    int st[4],ans1,ans2;
    
    int g[N<<2];
    
    void read(int &x)
    {
        x=0; char c=getchar();
        while(!isdigit(c)) c=getchar();
        while(isdigit(c)) { x=x*10+c-'0'; c=getchar(); }
    }
    
    void add(int u,int v)
    {
        to[++tot]=v; nxt[tot]=front[u]; front[u]=tot;
        to[++tot]=u; nxt[tot]=front[v]; front[v]=tot;
    }
    
    void build(int k,int l,int r)
    {
        g[k]=val[l];
        if(l==r) return;
        int mid=l+r>>1;
        build(k<<1,l,mid); build(k<<1|1,mid+1,r);
        g[k]=1ll*g[k<<1]*g[k<<1|1]%mod;
    }
    
    void change(int k,int l,int r,int pos,int w)
    {
        if(l==r) { g[k]=w; return; }
        int mid=l+r>>1;
        if(pos<=mid) change(k<<1,l,mid,pos,w);
        else change(k<<1|1,mid+1,r,pos,w);
        g[k]=1; 
        if(g[k<<1]!=-1) g[k]=1ll*g[k]*g[k<<1]%mod;
        if(g[k<<1|1]!=-1) g[k]=1ll*g[k]*g[k<<1|1]%mod;
    }
    
    void init()
    {
        read(n); int m1=0,m2=0; out[n]=1;
        for(int i=1;i<=n;i++) 
        {
            read(val[i]); out[n]=1ll*out[n]*val[i]%mod;
            F[i]=i;
            if(val[i]>=m1) m2=m1,m1=val[i];
            else if(val[i]>m2) m2=val[i];
        }
        int u,v;
        for(int i=1;i<n;i++) read(e[i][0]),read(e[i][1]);
        for(int i=1;i<n;i++) read(cut[i]);
        build(1,1,n);
    }
    
    void dfs(int x,int fa)
    {
        bool leave=true;
        for(int i=front[x];i;i=nxt[i])
            if(to[i]!=fa) 
            {
                leave=false;
                dfs(to[i],x);
                if(f[to[i]][0]>=f[x][0]) f[x][1]=f[x][0],f[x][0]=f[to[i]][0];
                else if(f[to[i]][0]>f[x][1]) f[x][1]=f[to[i]][0];
                f[to[i]][0]=f[to[i]][1]=0;
            }
        f[x][0]+=val[x];
        tmp=max(tmp,f[x][0]+f[x][1]);
        if(!leave) f[x][1]+=val[x];
    }
    
    int find(int i) { return F[i]==i ? i : F[i]=find(F[i]); }
    
    void solve()
    {
        int res1,res2,res;
        int u,v;
        for(int i=n-1;i;i--) 
        {
            u=e[cut[i]][0]; v=e[cut[i]][1];
            res=0; 
            tmp=0; dfs(u,0); res=max(res,tmp); res1=f[u][0]; f[u][0]=f[u][1]=0;
            tmp=0; dfs(v,0); res=max(res,tmp); res2=f[v][0]; f[v][0]=f[v][1]=0;
            change(1,1,n,find(v),-1); F[find(v)]=find(u); 
            change(1,1,n,F[u],max(res,res1+res2));
            out[i]=g[1];
            add(u,v); 
        }
        for(int i=1;i<=n;i++) printf("%d
    ",out[i]);
    }
    
    int main()
    {
        freopen("forest.in","r",stdin);
        freopen("forest.out","w",stdout);
        init();
        solve();
    }
    View Code

    std:

    # include<iostream>
    # include<cstdio>
    # include<cstring>
    # include<cstdlib>
    using namespace std;
    const int pp=1000000007;
    int c[2008][2008],f[2008],p[2008],ni[2008];
    int n,m,k,nn;
    inline int power(int x,int n)
    {
        int ans=1,tmp=x;
        while (n)
        {
              if (n&1) ans=(long long)ans*tmp%pp;
              tmp=(long long)tmp*tmp%pp;n>>=1;
        }    
        return ans;
    }
    void Count_c()
    {
         for (int i=0;i<=nn;i++) c[i][0]=1;
         for (int i=1;i<=nn;i++)
          for (int j=1;j<=i;j++)
          {
              c[i][j]=c[i-1][j-1]+c[i-1][j];
              if (c[i][j]>=pp) c[i][j]-=pp;
          }
    }
    void Count_p()
    {
         int mm=(m-2)*n;
         for (int i=0;i<=nn;i++)
          p[i]=power(i,mm);
    }
    void Count_f()
    {
         f[0]=0;f[1]=1;
         for (int i=2;i<=nn;i++)
         {
             f[i]=power(i,n);
             for (int j=1;j<i;j++)
             {
                 f[i]-=(long long)f[j]*c[i][j]%pp;
                 if (f[i]<=-pp) f[i]+=pp;
             }
             if (f[i]<0) f[i]+=pp;
         }
    }
    void Count_ni()
    {
         ni[1]=1;
         for (int i=2;i<=nn;i++)
         ni[i]=power(i,pp-2);
    }
    int main()
    {
        freopen("photo.in","r",stdin);
        freopen("photo.out","w",stdout);
        scanf("%d%d%d",&n,&m,&k);
        nn=min(n,k);
        if (m==1)
           printf("%d
    ",power(k,n));
        else
        {
            Count_c();
            Count_p();
            Count_f();
            Count_ni();
            long long tmp=1,tmp1=1,sum=0,sum1;
            for (int s=1;s<=nn;s++)
            {
                tmp=tmp*ni[s]%pp;
                tmp=tmp*(k-s+1)%pp;
                tmp1=1;sum1=0;
                for (int j=0;j<=s;j++)
                {
                    sum1+=tmp1*c[s][s-j]%pp*p[s-j]%pp;
                    if (sum1>=pp) sum1-=pp;
                    tmp1=tmp1*ni[j+1]%pp; 
                    if (k-s<j+1) break;
                    tmp1=tmp1*(k-s-j)%pp;
                }
                sum+=tmp*f[s]%pp*f[s]%pp*sum1%pp;
                if (sum>=pp) sum-=pp;
            }
            printf("%d
    ",sum);
        }
        fclose(stdin);
        fclose(stdout);
        return 0;
    }
    View Code
  • 相关阅读:
    匿名函数
    内置函数
    基础函数--3
    基础函数(2)
    基础函数(1)
    文件的相关操作
    知识点补充,set集合,深浅copy
    is 和 ==的区别
    Django-form组件中过滤当前用户信息
    Django的常用模块引入整理
  • 原文地址:https://www.cnblogs.com/TheRoadToTheGold/p/7687578.html
Copyright © 2020-2023  润新知