题目大意
用N个不同的字符(编号1 – N),组成一个字符串,有如下要求:
(1) 对于编号为i的字符,如果2 * i > n,则该字符可以作为结尾字符。如果不作为结尾字符而是中间的字符,则该字符后面可以接任意字符。
(2) 对于编号为i的字符,如果2 * i <= n,则该字符不可以作为结尾字符。作为中间字符,那么后面接的字符编号一定要 >= 2 * i。
问有多少长度为M且符合条件的字符串,由于数据很大,只需要输出该数Mod 10^9 + 7的结果。
例如:N = 2,M = 3。则abb, bab, bbb是符合条件的字符串,剩下的均为不符合条件的字符串。
Input
输入2个数,N, M中间用空格分割,N为不同字符的数量,M为字符串的长度。(2 <= N <= 10^6, 2 <= M <= 10^18)
Output
输出符合条件的字符串的数量。由于数据很大,只需要输出该数Mod 10^9 + 7的结果。
Input示例
6 3
Output示例
73
Solution
先尝试分解每个字符串,可以发现每个合法字符串可以被这样的“链”构造,且划分方式唯一
—–这样的“链”:后一个字符串的编号大于等于2*前一个字符串(最后一个字母编号i*2>n)
则问题被分成两块
- g(x)表示长度为x的合法链的个数,求g(x)
- v(x)表示长度为x的合法字符串数,求v(x)
显然v(x) = g(1) * v(x – 1) + g(2) * v(x – 2) +…+g(p) * v(x – p)
可以用矩阵乘法优化 复杂度O((logn)^3 * logm)
对于g的转移,设g(x,len)为以x为开头,长度为len的链的方案数
则有g(x,len)=g(2*x,len-1)+g(2*x+1,len-1)+…+g(n,len-1) (有优化过V3,不过我没打)
这道题就解了。
Code
1 #include<cstdio> 2 #include<cstdlib> 3 #include<cstring> 4 #include<algorithm> 5 #define fo(i,a,b) for(int i=a;i<=b;i++) 6 #define fd(i,a,b) for(int i=a;i>=b;i--) 7 typedef long long LL; 8 typedef double DB; 9 using namespace std; 10 LL read() { 12 LL x=0,f=1;char ch=getchar(); 13 while(ch<'0'||ch>'9')f=(ch=='-')?-1:f,ch=getchar(); 14 while(ch>='0'&&ch<='9')x=x*10+(ch-'0'),ch=getchar();return f*x; 15 } 16 const int mo=1e9+7,N=1e6+50; 17 int n,s[N][25]; 18 LL m; 19 struct Matrix { 21 int a[25][25]; 22 void clear(int n) {fo(i,1,n)fo(j,1,n)a[i][j]=0;} 23 void init(int n) {clear(n);fo(i,1,n)a[i][i]=1;} 24 }a; 25 void mul(Matrix &c,Matrix a,Matrix b) { 27 c.clear(20); 28 fo(i,1,20)fo(k,1,20)fo(j,1,20)(c.a[i][j]+=(LL)a.a[i][k]*b.a[k][j]%mo)%=mo; 29 } 30 Matrix ksm(Matrix x,LL y) { 32 Matrix ans; 33 ans.init(20); 34 while(y) { 36 if(y&1)mul(ans,ans,x); 37 mul(x,x,x),y>>=1; 38 } 39 return ans; 40 } 41 int main() { 43 n=(int)read(),m=read(); 44 fd(i,n,1) 45 if(i*2>n)s[i][1]=s[i+1][1]+1; 46 else fo(j,1,20)s[i][j]=s[i*2][j-1]+s[i+1][j],s[i][j]-=(s[i][j]>=mo)?mo:0; 47 fo(i,1,20)a.a[20-i+1][20]=s[1][i]; 48 fo(i,1,19)a.a[i+1][i]=1; 49 a=ksm(a,m); 50 printf("%d",a.a[20][20]); 51 return 0; 52 }