一看好像会做的样子,就去做了一下,结果
猝不及防地T掉了
赶紧查了一下,没有死循环,复杂度也是对的,无果,于是翻了题解
题解没看懂,但是找到了标程,然后发现我被卡常了。。。
而且好像当时还过了前10个点啊。。这要真的是比赛稳稳的FST啊
小技巧:
逆元只需要求inv[i]和inv[i!],可以预处理出来
令md=1e9+7
则inv[1]=1
除此外inv[i]=(md-md/i)*inv[md%i]%md
令inv2[i]=inv[i!]
则inv2[n]=pow(n!,md-2)
除此外inv2[i]=inv2[i+1]*(i+1)%mod;
1 #include<cstdio> 2 #include<algorithm> 3 #include<cstring> 4 #include<map> 5 #define md 1000000007 6 using namespace std; 7 typedef long long LL; 8 LL poww(LL a,LL b) 9 { 10 LL base=a,ans=1; 11 while(b) 12 { 13 if(b&1) ans=(ans*base)%md; 14 b>>=1; 15 base=(base*base)%md; 16 } 17 return ans; 18 } 19 LL inv[1000100],inv2[1000100],jc[1000100],sum,sumx,num[30],ans,n; 20 char s1[1000100],s2[1000100]; 21 void addx(LL x) 22 //num[x]++,同时维护sum,sumx 23 { 24 sum=sum*jc[num[x]]%md; 25 num[x]++;sumx++; 26 sum=sum*sumx%md; 27 sum=sum*inv2[num[x]]%md; 28 } 29 void delx(LL x) 30 { 31 sum=sum*jc[num[x]]%md; 32 sum=sum*inv[sumx]%md;#include #include #include #include33 sumx--;num[x]--; 34 sum=sum*inv2[num[x]]%md; 35 } 36 //s1的所有排列中小于s2的个数-s1的所有排列中小于s1的个数+1 37 int main() 38 { 39 LL i,j; 40 scanf("%s",s1+1); 41 scanf("%s",s2+1);n=strlen(s2+1); 42 jc[0]=1; 43 for(i=1;i<=1000000;i++) jc[i]=jc[i-1]*i%md; 44 inv[1]=1; 45 for(i=2;i<=1000000;i++) inv[i]=(md-md/i)*inv[md%i]%md; 46 for(i=0;i<=1000000;i++) inv2[i]=poww(jc[i],md-2); 47 for(i=1;i<=n;i++) num[s1[i]-'a']++; 48 sum=jc[n];sumx=n; 49 for(i=0;i<26;i++) sum=sum*inv2[num[i]]%md; 50 for(i=1;i<=n;i++) 51 { 52 for(j=0;j<s2[i]-'a';j++) 53 if(num[j]) 54 { 55 delx(j); 56 ans=(ans+sum)%md; 57 addx(j); 58 } 59 if(num[s2[i]-'a']) delx(s2[i]-'a'); 60 else break; 61 } 62 for(i=0;i<26;i++) num[i]=0; 63 for(i=1;i<=n;i++) num[s1[i]-'a']++; 64 sum=jc[n];sumx=n; 65 for(i=0;i<26;i++) sum=sum*inv2[num[i]]%md; 66 for(i=1;i<=n;i++) 67 { 68 for(j=0;j<s1[i]-'a';j++) 69 if(num[j]) 70 { 71 delx(j); 72 ans=(ans-sum+md)%md; 73 addx(j); 74 } 75 if(num[s1[i]-'a']) delx(s1[i]-'a'); 76 else break; 77 } 78 ans=(ans-1+md)%md; 79 printf("%lld",ans); 80 return 0; 81 }
原来的代码(假的)
1 #include<cstdio> 2 #include<algorithm> 3 #include<cstring> 4 #include<map> 5 #define md 1000000007 6 using namespace std; 7 typedef long long LL; 8 LL poww(LL a,LL b) 9 { 10 LL base=a,ans=1; 11 while(b) 12 { 13 if(b&1) ans=(ans*base)%md; 14 b>>=1; 15 base=(base*base)%md; 16 } 17 return ans; 18 } 19 LL inv[1000100],inv2[1000100],jc[1000100],sum,sumx,num[30],ans,n; 20 char s1[1000100],s2[1000100]; 21 void addx(LL x) 22 //num[x]++,同时维护sum,sumx 23 { 24 sum=sum*jc[num[x]]%md; 25 num[x]++;sumx++; 26 sum=sum*sumx%md; 27 sum=sum*inv2[num[x]]%md; 28 } 29 void delx(LL x) 30 { 31 sum=sum*jc[num[x]]%md; 32 sum=sum*inv[sumx]%md; 33 sumx--;num[x]--; 34 sum=sum*inv2[num[x]]%md; 35 } 36 //s1的所有排列中小于s2的个数-s1的所有排列中小于s1的个数+1 37 int main() 38 { 39 LL i,j; 40 scanf("%s",s1+1); 41 scanf("%s",s2+1);n=strlen(s2+1); 42 jc[0]=1; 43 for(i=1;i<=1000000;i++) jc[i]=jc[i-1]*i%md; 44 inv[1]=1; 45 for(i=2;i<=1000000;i++) inv[i]=(md-md/i)*inv[md%i]%md; 46 for(i=0;i<=1000000;i++) inv2[i]=poww(jc[i],md-2); 47 for(i=1;i<=n;i++) num[s1[i]-'a']++; 48 sum=jc[n];sumx=n; 49 for(i=0;i<26;i++) sum=sum*inv2[num[i]]%md; 50 for(i=1;i<=n;i++) 51 { 52 for(j=0;j<s2[i]-'a';j++) 53 if(num[j]) 54 { 55 delx(j); 56 ans=(ans+sum)%md; 57 addx(j); 58 } 59 if(num[s2[i]-'a']) delx(s2[i]-'a'); 60 else break; 61 } 62 for(i=0;i<26;i++) num[i]=0; 63 for(i=1;i<=n;i++) num[s1[i]-'a']++; 64 sum=jc[n];sumx=n; 65 for(i=0;i<26;i++) sum=sum*inv2[num[i]]%md; 66 for(i=1;i<=n;i++) 67 { 68 for(j=0;j<s1[i]-'a';j++) 69 if(num[j]) 70 { 71 delx(j); 72 ans=(ans-sum+md)%md; 73 addx(j); 74 } 75 if(num[s1[i]-'a']) delx(s1[i]-'a'); 76 else break; 77 } 78 ans=(ans-1+md)%md; 79 printf("%lld",ans); 80 return 0; 81 }