「LibreOJ NOI Round #2」不等关系
传送门
题解
首先考虑,我们可以通过把(<)合并成一段,然后这个(>)不是很好考虑,所以往容斥方面想.
容斥系数为(pm 1),所以考虑设(f_i)表示前(i)段合法的情况,有:
[f_i=-sum_{j=0}^{i-1}f_jinom{n-s_j}{s_i-s_j}
]
这个东西可以分治(NTT),就是把(inom{n-s_j}{s_i-s_j})拆开,然后对应的改成位置即可.
具体实现可以通过把对应的位置如果不是('>')改成(0)即可.
代码
#include<stdio.h>
#include<stdlib.h>
#include<string.h>
#include<math.h>
#include<algorithm>
#include<queue>
#include<set>
#include<map>
#include<iostream>
using namespace std;
#define ll long long
#define REP(a,b,c) for(int a=b;a<=c;a++)
#define re register
#define file(a) freopen(a".in","r",stdin);freopen(a".out","w",stdout)
typedef pair<int,int> pii;
#define mp make_pair
inline int gi()
{
int f=1,sum=0;char ch=getchar();
while(ch>'9' || ch<'0'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0' && ch<='9'){sum=(sum<<3)+(sum<<1)+ch-'0';ch=getchar();}
return f*sum;
}
const int N=400010,Mod=998244353,G=3;
char s[N];
int n,f[N],g[N],A[N],B[N],rev[N],fl[N],ifac[N],fac[N],inv[N],ans;
int qpow(int a,int b){int ret=1;while(b){if(b&1)ret=1ll*ret*a%Mod;b>>=1;a=1ll*a*a%Mod;}return ret;}
void ntt(int *a,int limit,int opt)
{
for(int i=0;i<limit;i++)
if(i<rev[i])swap(a[i],a[rev[i]]);
for(int i=1;i<limit;i<<=1)
{
int Rt=qpow(G,(Mod-1)/(i<<1));
for(int R=i<<1,j=0;j<limit;j+=R)
{
int W=1;
for(int k=0;k<i;k++,W=1ll*W*Rt%Mod)
{
int X=a[j+k],Y=1ll*W*a[i+j+k]%Mod;
a[j+k]=(X+Y)%Mod;
a[i+j+k]=(X-Y+Mod)%Mod;
}
}
}
if(opt==-1)
{
reverse(a+1,a+limit);int Inv=qpow(limit,Mod-2);
for(int i=0;i<limit;i++)a[i]=1ll*a[i]*Inv%Mod;
}
}
void cdq(int l,int r)
{
if(l==r)
{
if(!l)f[l]=1;
else if(!fl[l])f[l]=0;
else f[l]=1ll*f[l]*(Mod-ifac[n-l])%Mod;
return;
}
int mid=(l+r)>>1;
cdq(l,mid);
for(int i=l;i<=mid;i++)A[i-l]=1ll*f[i]*fac[n-i]%Mod;
for(int i=1;i<=r-l+1;i++)B[i]=ifac[i];
int L=0,limit=1;
while(limit<=(r-l+1+mid-l+1))limit<<=1,L++;
for(int i=0;i<limit;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
ntt(A,limit,1);ntt(B,limit,1);
for(int i=0;i<limit;i++)A[i]=1ll*A[i]*B[i]%Mod;
ntt(A,limit,-1);
for(int i=mid+1;i<=r;i++)f[i]=(f[i]+A[i-l])%Mod;
for(int i=0;i<limit;i++)A[i]=B[i]=0;
cdq(mid+1,r);
}
int main()
{
scanf("%s",s+1);n=strlen(s+1)+1;
fac[0]=fac[1]=ifac[0]=ifac[1]=inv[0]=inv[1]=1;
for(int i=2;i<=n;i++)
{
fac[i]=1ll*fac[i-1]*i%Mod;
inv[i]=1ll*(Mod-Mod/i)*inv[Mod%i]%Mod;
ifac[i]=1ll*ifac[i-1]*inv[i]%Mod;
}
int cnt=0;
for(int i=1;i<n;i++)
if(s[i]=='>')fl[i]=1,cnt++;
cdq(0,n);
for(int i=0;i<=n;i++)ans=(ans+f[i])%Mod;
if(cnt&1)ans=Mod-ans;
printf("%d
",ans);
return 0;
}