• CF1111E Tree 动态规划+LCT


    这个题的思路非常好啊.    

    我们可以把 $k$ 个点拿出来,那么就是求将 $k$ 个点划分成不大于 $m$ 个集合的方案数.   

    令 $f[i][j]$ 表示将前 $i$ 个点划分到 $j$ 个集合中的方案数.  

    那么有 $f[i][j]=f[i-1][j-1]+f[i-1][j]*(j-fail[i])$,其中 $fail[i]$ 代表 $i$ 到根这条路径上祖先数量.             

    而 $fail[i]$ 的求解方式有:虚数统计/树上数据结构维护路径和,这里选择了用 LCT 来维护. 

    code: 

    #include <cstdio> 
    #include <cstring> 
    #include <vector>
    #include <algorithm>  
    #define N 100007   
    #define ll long long 
    #define mod 1000000007 
    #define setIO(s) freopen(s".in","r",stdin) 
    using namespace std;   
    namespace LCT 
    {    
        #define lson t[x].ch[0] 
        #define rson t[x].ch[1] 
        struct node 
        {
            int ch[2],f,rev,sum,val;       
        }t[N];     
        int sta[N];  
        int get(int x) 
        {
            return t[t[x].f].ch[1]==x; 
        }                                    
        int isrt(int x) 
        {
            return !(t[t[x].f].ch[0]==x||t[t[x].f].ch[1]==x); 
        }
        void pushup(int x) 
        {
            t[x].sum=t[lson].sum+t[rson].sum+t[x].val; 
        }  
        void mark(int x) 
        {  
            t[x].rev^=1; 
            swap(lson,rson);    
        }  
        void pushdown(int x) 
        { 
            if(t[x].rev) 
            {
                if(lson) mark(lson); 
                if(rson) mark(rson); 
                t[x].rev=0; 
            }
        }
        void rotate(int x) 
        {
            int old=t[x].f,fold=t[old].f,which=get(x);            
            if(!isrt(old)) t[fold].ch[t[fold].ch[1]==old]=x;  
            t[old].ch[which]=t[x].ch[which^1],t[t[old].ch[which]].f=old;  
            t[x].ch[which^1]=old,t[old].f=x,t[x].f=fold; 
            pushup(old),pushup(x); 
        }    
        void splay(int x) 
        { 
            int v=0,u=x,fa;           
            for(sta[++v]=u;!isrt(u);u=t[u].f) sta[++v]=t[u].f;            
            for(;v;--v) pushdown(sta[v]); 
            for(u=t[u].f;(fa=t[x].f)!=u;rotate(x)) 
            {
                if(t[fa].f!=u)
                    rotate(get(fa)==get(x)?fa:x);  
            }
        }
        void Access(int x) 
        {
            for(int y=0;x;y=x,x=t[x].f) 
            {
                splay(x); 
                rson=y; 
                pushup(x);  
            }
        }
        void makeroot(int x) 
        {
            Access(x),splay(x),mark(x); 
        }   
        void split(int x,int y) 
        {
            makeroot(x),Access(y),splay(y);  
        }
        void add(int x,int v) 
        {
            Access(x),splay(x); 
            t[x].val+=v,pushup(x);        
        }
        int query(int x) 
        {           
            Access(x),splay(x);  
            return t[x].sum;   
        }
        #undef lson 
        #undef rson 
    }; 
    int n,edges; 
    int hd[N],to[N<<1],nex[N<<1],f[N],A[N],dp[N][302];     
    void add(int u,int v) 
    {
        nex[++edges]=hd[u],hd[u]=edges,to[edges]=v;    
    }
    void dfs(int u,int ff) 
    {
        LCT::t[u].f=ff;  
        for(int i=hd[u];i;i=nex[i]) 
        {
            int v=to[i]; 
            if(v==ff) continue;  
            dfs(v,u); 
        }  
    }
    int main() 
    { 
        // setIO("input");      
        int i,j,q; 
        scanf("%d%d",&n,&q);  
        for(i=1;i<n;++i) 
        {   
            int x,y; 
            scanf("%d%d",&x,&y);  
            add(x,y),add(y,x);   
        }  
        dfs(1,0); 
        for(i=1;i<=q;++i) 
        {
            int k,m,r,flag=0; 
            scanf("%d%d%d",&k,&m,&r); 
            LCT::makeroot(r);   
            for(j=1;j<=k;++j) 
            {
                scanf("%d",&A[j]);    
                LCT::add(A[j],1);              
            }   
            for(j=1;j<=k;++j) 
            {      
                f[j]=LCT::query(A[j])-1;        
                if(f[j]>m) flag=1;             
            } 
            for(j=1;j<=k;++j)  LCT::add(A[j],-1); 
            if(flag)   printf("0
    "); 
            else 
            {        
                sort(f+1,f+1+k);    
                dp[1][1]=1;   
                for(j=2;j<=k;++j) 
                {      
                    for(int p=1;p<=min(j,m);++p) 
                    {
                        dp[j][p]=0; 
                        if(p<f[j])   dp[j][p]=dp[j-1][p-1];  
                        else dp[j][p]=(ll)(dp[j-1][p-1]+1ll*(p-f[j])*dp[j-1][p]%mod)%mod;               
                    }
                }
                int ans=0;   
                for(j=1;j<=m;++j)  ans=(ans+dp[k][j])%mod;  
                printf("%d
    ",ans); 
            }
        }
        return 0; 
    }
    

      

  • 相关阅读:
    python 打印对象的所有属性值
    selenium+python测试
    java连接3种数据库 JdbcLinkDB --201801
    又来折腾--正则表达式
    Excel 将A表的基础数据拼接到B表中来-三种方法: ctrl+回车, VLOOKUP()函数,宏
    Excel如何快速统计一列中相同数值出现的个数--数据透视表
    Jmeter、Postman 、 loadrunner SoapUI 接口测试工具
    delphi 获取时间戳 如何得到 和 js 中 new Date().getTime();的 相同?
    IIS部署项目
    C#使用log4net记录日志
  • 原文地址:https://www.cnblogs.com/guangheli/p/12091045.html
Copyright © 2020-2023  润新知