题目大意
有一个(1001 imes n)的的网格,每个格子有(q)的概率是安全的,(1-q)的概率是危险的。
定义一个矩形是合法的当且仅当:
- 这个矩形中每个格子都是安全的
- 必须紧贴网格的下边界
问你最大的合法子矩形大小为(k)的概率是多少。
(nleq {10}^9,kleq 1000)
吉老师:这题本来是(kleq 20000)
题解
一道好题。
我们计算最大子矩形不超过(i)的答案(s_i),那么答案就是(s_k-s_{k-1})。
显然最后一行连续的安全格子不会超过(k)个。
设(g_{i,j})表示长度为(j),高度为(i)的海域全部是安全的,剩下的部分未知,最大子矩形(leq k)的概率。
设(h_{i,j})表示长度为(j),高度为(i+1)的海域中,前(i)行全部是安全的,剩下的未知且((i+1,j))是危险的,最大子矩形(leq k)的概率。
边界:
那么我们从(k-1)到(1)DP,对于(i)行(j)列,枚举第(i+1)行的下一个危险的格子在哪个地方,然后转移:
因为第(i)行的宽度不会超过(lfloorfrac{k}{i} floor),所以的暴力的时间复杂度是(sum_{i=1}^k{lfloorfrac{k}{i} floor}^2=O(k^2))。
这已经足够了,但我们可以做的更好。
设
那么
时间复杂度是(sum_{i=1}^klfloorfrac{k}{i} floorloglfloorfrac{k}{i} floor=O(klog^2k))
设(f_i)为前(i)列最大子矩形(leq k)的概率,那么
这就是一个常系数线性递推。
时间复杂度:
- 暴力:(O(nk)),(70)pts
- 矩阵快速幂:(O(k^3log n)),(90)pts
- 特征多项式+暴力:(O(k^2log n)),(100)pts
- 特征多项式+NTT取模:(O(klog klog n)),(100)pts
这里简单讲一下最后一个做法
矩阵快速幂是给你一个矩阵(A),求((A^n)_{1,1})
设矩阵的大小为(k)。
根据Cayley-Hamilton定理,(|lambda I-A|)是一个关于(lambda)的(k)次多项式,记为(g(lambda))。对于任意矩阵(A),有(g(A)=0)
对于常系数线性递推的矩阵,设(f_i=sum_{j=1}^kf_{i-j}a_j),(g(lambda)=lambda^k-sum_{i=1}^{k}a_{i}lambda^{k-i})。
所以我们只需要求(A^nmod g(A))。可以用快速幂(倍增取模)求解。
然后还要求出(f_1ldots f_k),可以通过其他方法计算(多项式求逆或者题目给你了)。
最后一次卷积可以得到答案。
如果要求(f_{n-k+1}ldots f_n),那就把(f_1ldots f_{2k})带进去卷积。
总时间复杂度:(O(klog^2k+klog klog n))
代码
暴力取模
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
#include<cmath>
#include<functional>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
void sort(int &a,int &b)
{
if(a>b)
swap(a,b);
}
void open(const char *s)
{
#ifndef ONLINE_JUDGE
char str[100];
sprintf(str,"%s.in",s);
freopen(str,"r",stdin);
sprintf(str,"%s.out",s);
freopen(str,"w",stdout);
#endif
}
int rd()
{
int s=0,c;
while((c=getchar())<'0'||c>'9');
do
{
s=s*10+c-'0';
}
while((c=getchar())>='0'&&c<='9');
return s;
}
int upmin(int &a,int b)
{
if(b<a)
{
a=b;
return 1;
}
return 0;
}
int upmax(int &a,int b)
{
if(b>a)
{
a=b;
return 1;
}
return 0;
}
ll p=998244353;
void add(ll &a,ll b)
{
a=(a+b)%p;
}
ll fp(ll a,ll b)
{
ll s=1;
for(;b;b>>=1,a=a*a%p)
if(b&1)
s=s*a%p;
return s;
}
ll inv(ll a)
{
return fp(a,p-2);
}
ll pw1[1010];
ll pw2[1010];
ll q;
ll q2;
ll g[1010][1010];
ll h[1010][1010];
ll f[2010];
ll a[2010];
ll c[2010];
ll d[2010];
ll final[2010];
void mul(ll *a,ll *b,ll *e,int len)
{
static ll c[2010];
int i,j;
for(i=0;i<=2*len;i++)
c[i]=0;
for(i=0;i<=len;i++)
for(j=0;j<=len;j++)
add(c[i+j],a[i]*b[j]);
for(i=2*len;i>=len;i--)
{
ll v=c[i]*inv(e[len]);
if(v)
for(j=0;j<=len;j++)
c[i-len+j]=(c[i-len+j]-e[j]*v)%p;
}
for(i=0;i<=len;i++)
a[i]=c[i];
}
ll solve(int n,int k)
{
if(!k)
return fp(q2,n);
memset(g,0,sizeof g);
memset(h,0,sizeof h);
g[k][1]=q2*pw1[k]%p;
g[k][0]=1;
int i,j,l;
for(i=k-1;i>=1;i--)
{
int m=k/i;
g[i][0]=1;
h[i][0]=1;
for(j=0;j<=m;j++)
{
for(l=j+1;l<=m;l++)
add(h[i][l],h[i][j]*g[i+1][l-j-1]%p*q2%p*pw1[i]%p);
for(l=j;l<=m;l++)
if(l)
add(g[i][l],h[i][j]*g[i+1][l-j]%p);
}
}
memset(f,0,sizeof f);
f[0]=1;
for(i=1;i<=2*(k+1);i++)
for(j=0;j<i&&j<=k;j++)
add(f[i],f[i-j-1]*q2%p*g[1][j]);
if(n<=2*(k+1))
{
ll s=0;
for(i=0;i<=n&&i<=k;i++)
add(s,f[n-i]*g[1][i]);
return s;
}
int len=k+1;
for(i=0;i<len;i++)
a[i]=-q2*g[1][len-i-1]%p;
a[len]=1;
memset(c,0,sizeof c);
c[1]=1;
memset(d,0,sizeof d);
d[0]=1;
int m=n-k-1;
while(m)
{
if(m&1)
mul(d,c,a,len);
mul(c,c,a,len);
m>>=1;
}
memset(final,0,sizeof final);
for(i=1;i<=k+1;i++)
for(j=0;j<=k;j++)
add(final[i],d[j]*f[i+j]);
ll s=0;
for(i=1;i<=k+1;i++)
add(s,final[i]*g[1][k+1-i]);
return s;
}
int main()
{
open("bzoj4944");
int n,k,x,y;
scanf("%d%d%d%d",&n,&k,&x,&y);
q=x*inv(y)%p;
q2=(y-x)*inv(y)%p;
pw1[0]=pw2[0]=1;
int i;
for(i=1;i<=k;i++)
{
pw1[i]=pw1[i-1]*q%p;
pw2[i]=pw2[i-1]*q2%p;
}
ll ans1=solve(n,k);
ll ans2=solve(n,k-1);
ll ans=((ans1-ans2)%p+p)%p;
printf("%lld
",ans);
return 0;
}
NTT取模
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
#include<cmath>
#include<functional>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
void sort(int &a,int &b)
{
if(a>b)
swap(a,b);
}
void open(const char *s)
{
#ifndef ONLINE_JUDGE
char str[100];
sprintf(str,"%s.in",s);
freopen(str,"r",stdin);
sprintf(str,"%s.out",s);
freopen(str,"w",stdout);
#endif
}
int rd()
{
int s=0,c;
while((c=getchar())<'0'||c>'9');
do
{
s=s*10+c-'0';
}
while((c=getchar())>='0'&&c<='9');
return s;
}
int upmin(int &a,int b)
{
if(b<a)
{
a=b;
return 1;
}
return 0;
}
int upmax(int &a,int b)
{
if(b>a)
{
a=b;
return 1;
}
return 0;
}
const ll p=998244353;
const int maxn=300000;
ll fp(ll a,ll b)
{
ll s=1;
for(;b;b>>=1,a=a*a%p)
if(b&1)
s=s*a%p;
return s;
}
namespace ntt
{
const ll g=3;
ll w1[maxn];
ll w2[maxn];
int rev[maxn];
int n;
void init(int m)
{
n=1;
while(n<m)
n<<=1;
int i;
for(i=2;i<=n;i<<=1)
{
w1[i]=fp(g,(p-1)/i);
w2[i]=fp(w1[i],p-2);
}
rev[0]=0;
for(i=1;i<n;i++)
rev[i]=(rev[i>>1]>>1)|((i&1)*(n>>1));
}
void ntt(ll *a,int t)
{
int i,j,k;
ll u,v,w,wn;
for(i=0;i<n;i++)
if(rev[i]<i)
swap(a[i],a[rev[i]]);
for(i=2;i<=n;i<<=1)
{
wn=(t==1?w1[i]:w2[i]);
for(j=0;j<n;j+=i)
{
w=1;
for(k=j;k<j+i/2;k++)
{
u=a[k];
v=a[k+i/2]*w%p;
a[k]=(u+v)%p;
a[k+i/2]=(u-v)%p;
w=w*wn%p;
}
}
}
if(t==-1)
{
u=fp(n,p-2);
for(i=0;i<n;i++)
a[i]=a[i]*u%p;
}
}
ll x[maxn];
ll y[maxn];
ll z[maxn];
void copy_clear(ll *a,ll *b,int m)
{
int i;
for(i=0;i<m;i++)
a[i]=b[i];
for(i=m;i<n;i++)
a[i]=0;
}
void copy(ll *a,ll *b,int m)
{
int i;
for(i=0;i<m;i++)
a[i]=b[i];
}
void mul(ll *a,ll *b,ll *c,int m)
{
init(m<<1);
copy_clear(x,a,m);
copy_clear(y,b,m);
ntt(x,1);
ntt(y,1);
int i;
for(i=0;i<n;i++)
x[i]=x[i]*y[i]%p;
ntt(x,-1);
copy(c,x,m);
}
void inverse(ll *a,ll *b,int m)
{
if(m==1)
{
b[0]=fp(a[0],p-2);
return;
}
inverse(a,b,m>>1);
init(m<<1);
copy_clear(x,a,m);
copy_clear(y,b,m>>1);
ntt(x,1);
ntt(y,1);
int i;
for(i=0;i<n;i++)
x[i]=y[i]*(2-x[i]*y[i]%p)%p;
ntt(x,-1);
copy(b,x,m);
}
ll c[maxn],d[maxn],e[maxn],f[maxn];
void sqrt(ll *a,ll *b,int m)
{
if(m==1)
{
if(a[0]==1)
b[0]=1;
else if(a[0]==0)
b[0]=0;
else
//我也不会
;
return;
}
sqrt(a,b,m>>1);
// copy_clear(c,b,m>>1);
int i;
for(i=m;i<m<<1;i++)
b[i]=0;
inverse(b,d,m);
init(m<<1);
for(i=m;i<m<<1;i++)
b[i]=d[i]=0;
ll inv2=fp(2,p-2);
copy_clear(x,a,m);
ntt(x,1);
ntt(d,1);
for(i=0;i<n;i++)
x[i]=x[i]*d[i]%p;
ntt(x,-1);
for(i=0;i<m;i++)
b[i]=((b[i]+x[i])%p*inv2)%p;
}
void derivative(ll *a,ll *b,int m)
{
int i;
for(i=0;i<m-1;i++)
b[i]=(i+1)*a[i+1]%p;
b[m-1]=0;
}
void differential(ll *a,ll *b,int m)
{
// int i;
// for(i=m-1;i>=1;i--)
// b[i]=a[i-1]*inv[i]%p;
b[0]=0;
}
void ln(ll *a,ll *b,int m)
{
static ll c[maxn],d[maxn];
derivative(a,c,m);
inverse(a,d,m);
init(m<<1);
int i;
for(i=m;i<n;i++)
c[i]=d[i]=0;
ntt(c,1);
ntt(d,1);
for(i=0;i<n;i++)
c[i]=c[i]*d[i]%p;
ntt(c,-1);
differential(c,b,m);
}
void exp(ll *a,ll *b,int m)
{
if(m==1)
{
b[0]=1;
return;
}
exp(a,b,m>>1);
int i;
for(i=m>>1;i<m;i++)
b[i]=0;
ln(b,y,m);
init(m<<1);
copy_clear(x,a,m);
x[0]++;
for(i=0;i<m;i++)
x[i]=(x[i]-y[i])%p;
copy_clear(y,b,m);
ntt(x,1);
ntt(y,1);
for(i=0;i<n;i++)
x[i]=x[i]*y[i]%p;
ntt(x,-1);
copy(b,x,m);
}
void module(ll *a,ll *b,ll *c,int n1,int n2)
{
int k=1;
while(k<=n1-n2+1)
k<<=1;
int i;
for(i=0;i<=n1;i++)
d[i]=a[i];
for(i=0;i<=n2;i++)
e[i]=b[i];
reverse(d,d+n1+1);
reverse(e,e+n2+1);
for(i=n1-n2+1;i<k<<1;i++)
d[i]=e[i]=0;
inverse(e,f,k);
for(i=n1-n2+1;i<k<<1;i++)
f[i]=0;
init(k<<1);
ntt::ntt(d,1);
ntt::ntt(f,1);
for(i=0;i<n;i++)
e[i]=d[i]*f[i]%p;
ntt::ntt(e,-1);
for(i=0;i<=n1-n2;i++)
c[i]=e[i];
reverse(c,c+n1-n2+1);
}
};
void add(ll &a,ll b)
{
a=(a+b)%p;
}
ll inv(ll a)
{
return fp(a,p-2);
}
ll pw1[maxn];
ll pw2[maxn];
ll q;
ll q2;
ll f[maxn];
ll a[maxn];
ll c[maxn];
ll d[maxn];
ll final[maxn];
ll g[2][maxn];
ll h[maxn];
ll e[maxn];
void mul(ll *a,ll *b,ll *c,int n)
{
static ll d[maxn],e[maxn];
int k=1;
while(k<=n)
k<<=1;
ntt::init(k<<1);
int i;
for(i=0;i<k<<1;i++)
d[i]=e[i]=0;
for(i=0;i<=n;i++)
{
d[i]=a[i];
e[i]=b[i];
}
ntt::ntt(d,1);
ntt::ntt(e,1);
for(i=0;i<k<<1;i++)
d[i]=d[i]*e[i]%p;
ntt::ntt(d,-1);
//d=a*b
for(i=0;i<k<<1;i++)
e[i]=0;
int n2=(k<<1)-1;
while(!d[n2])
n2--;
ntt::module(d,c,e,n2,n);
for(i=0;i<n;i++)
a[i]=d[i];
for(i=0;i<k;i++)
d[i]=c[i];
for(i=k;i<k<<1;i++)
d[i]=0;
ntt::init(k<<1);
ntt::ntt(d,1);
ntt::ntt(e,1);
for(i=0;i<k<<1;i++)
d[i]=d[i]*e[i]%p;
ntt::ntt(d,-1);
for(i=0;i<n;i++)
a[i]=(a[i]-d[i])%p;
}
void powmod(ll *a,ll *b,ll *c,int m,int n)
{
if(!n)
return;
powmod(a,b,c,m,n>>1);
mul(a,a,c,m);
if(n&1)
mul(a,b,c,m);
}
ll solve(int n,int k)
{
memset(g,0,sizeof g);
memset(h,0,sizeof h);
int now=0;
g[now][1]=q2*pw1[k]%p;
g[now][0]=1;
h[0]=1;
int i,j;
for(i=k-1;i>=1;i--)
{
now^=1;
int m=k/i;
ll c=q2*pw1[i]%p;
int len=1;
while(len<=m)
len<<=1;
for(j=1;j<len;j++)
e[j]=-c*g[now^1][j-1];
e[0]=1;
ntt::inverse(e,h,len);
for(j=m+1;j<len<<1;j++)
h[j]=0;
ntt::init(len<<1);
ntt::ntt(g[now^1],1);
ntt::ntt(h,1);
for(j=0;j<len<<1;j++)
g[now][j]=g[now^1][j]*h[j]%p;
ntt::ntt(g[now],-1);
for(j=m+1;j<len<<1;j++)
g[now][j]=0;
}
memset(a,0,sizeof a);
for(i=0;i<=k;i++)
a[i+1]=-g[now][i]*q2%p;
a[0]=1;
int len=1;
while(len<=k+1)
len<<=1;
ntt::inverse(a,f,len<<1);
if(n<=2*(k+1))
{
ll s=0;
for(i=0;i<=n&&i<=k;i++)
add(s,f[n-i]*g[now][i]);
return s;
}
memset(a,0,sizeof a);
memset(c,0,sizeof c);
memset(d,0,sizeof d);
for(i=0;i<=k;i++)
a[i]=-g[now][k-i]*q2%p;
a[k+1]=1;
if(k)
c[1]=1;
else
c[0]=-a[0];
d[0]=1;
int m=n-k;
powmod(d,c,a,k+1,m);
// while(m)
// {
// if(m&1)
// mul(d,c,a,k+1);
// mul(c,c,a,k+1);
// m>>=1;
//// for(i=0;i<=k;i++)
//// printf("%lld ",(d[i]+p)%p);
//// printf("
");
// }
reverse(d,d+k+1);
ntt::init(len<<2);
ntt::ntt(d,1);
ntt::ntt(f,1);
for(i=0;i<len<<2;i++)
final[i]=d[i]*f[i]%p;
ntt::ntt(final,-1);
ll s=0;
for(i=0;i<=k;i++)
add(s,g[now][i]*final[2*k-i]);
return s;
// for(i=0;i<=k;i++)
// g[now][i]=(g[now][i]+p)%p;
// memset(f,0,sizeof f);
// f[0]=1;
// for(i=1;i<=2*(k+1);i++)
// for(j=0;j<i&&j<=k;j++)
// add(f[i],f[i-j-1]*q2%p*g[now][j]);
// if(n<=2*(k+1))
// {
// ll s=0;
// for(i=0;i<=n&&i<=k;i++)
// add(s,f[n-i]*g[now][i]);
// return s;
// }
// int len=k+1;
// for(i=0;i<len;i++)
// a[i]=-q2*g[now][len-i-1]%p;
// a[len]=1;
// memset(c,0,sizeof c);
// c[1]=1;
// memset(d,0,sizeof d);
// d[0]=1;
// int m=n-k-1;
// while(m)
// {
// if(m&1)
// mul(d,c,a,len);
// mul(c,c,a,len);
// m>>=1;
// }
// memset(final,0,sizeof final);
// for(i=1;i<=k+1;i++)
// for(j=0;j<=k;j++)
// add(final[i],d[j]*f[i+j]);
// ll s=0;
// for(i=1;i<=k+1;i++)
// add(s,final[i]*g[now][k+1-i]);
// return s;
}
int main()
{
open("bzoj4944");
int n,k,x,y;
scanf("%d%d%d%d",&n,&k,&x,&y);
q=x*inv(y)%p;
q2=(y-x)*inv(y)%p;
pw1[0]=pw2[0]=1;
int i;
for(i=1;i<=k;i++)
{
pw1[i]=pw1[i-1]*q%p;
pw2[i]=pw2[i-1]*q2%p;
}
ll ans1=solve(n,k);
ll ans2=solve(n,k-1);
ll ans=((ans1-ans2)%p+p)%p;
printf("%lld
",ans);
return 0;
}