题目
题目链接:https://www.luogu.com.cn/problem/P7324
定义二元操作符 <
:对于两个长度都为 (n) 的数组 (A, B)(下标从 (1) 到 (n)),(A)<
(B) 的结果也是一个长度为 (n) 的数组,记为 (C)。则有 (C[i] = min(A[i], B[i]))((1 le i le n))。
定义二元操作符 >
:对于两个长度都为 (n) 的数组 (A, B)(下标从 (1) 到 (n)),(A)>
(B) 的结果也是一个长度为 (n) 的数组,记为 (C)。则有 (C[i] = max(A[i], B[i]))((1 le i le n))。
现在有 (m)((1 le m le 10))个长度均为 (n) 的整数数组 (A_0, A_1, ldots , A_{m-1})。给定一个待计算的表达式 (E),其满足 (E) 中出现的每个操作数都是 (A_0, A_1, ldots , A_{m-1}) 其中之一,且 (E) 中只包含 <
和 >;
两种操作符(<
和 >
的运算优先级相同),因此该表达式的结果值也将是一个长度为 (n) 的数组。
特殊地,表达式 (E) 中还可能出现操作符 ?
,它表示该运算符可能是 <
也可能是 >
。因此若表达式中有 (t) 个 ?
,则该表达式可生成 (2^t) 个可求确定值的表达式,从而可以得到 (2^t) 个结果值,你的任务就是求出这 (2^t) 个结果值(每个结果都是一个数组)中所有的元素的和。你只需要给出所有元素之和对 ({10}^9 + 7) 取模后的值。
(n,|S|leq 5 imes 10^4;mleq 10)。
思路
首先根据给出的表达式建立表达式树。
数组的每一位之间是互相独立的,所以考虑在表达式树上预处理出来所有情况,然后就可以枚举每一位计算答案。
但是如果直接枚举大小排列的话复杂度是 (O(m!|S|)),显然不可接受。
观察到本题只关心数字之间的大小关系,有一个经典的 trick 是考虑一个阈值 (k),把不超过 (k) 的设为 (0),超过 (k) 的设为 (1),然后进行完操作只需要看得到的是 (0) 还是 (1) 就可以判断答案与 (k) 的大小关系。
所以考虑设 (f[x][s][0/1]) 表示表达式树上点 (x) 为根的子树内,(m) 个表达式大小关系状态为 (s),且进行玩子树的表达式后答案在小的那部分 / 大的那部分的方案数。
其中一个状态 (s) 二进制下第 (i) 位表示第 (i) 个序列前这一位(上文提到每一个序列的每一位是独立的)是在小的部分 / 大的部分。
考虑转移,记左右儿子分别为 (lc,rc),拿当前节点的字符为 <
举例,有
>
的话就把 (min) 改为 (max);?
就两个都要转移。
统计答案的部分就直接枚举每一位,然后把 (m) 个序列这一位的数字从小到大排序,依次枚举每一位,假设把这一位划进小的部分前的状态为 (S),划进去后状态为 (T),那么贡献即为
时间复杂度 (O(2^m|S|+nmlog m))。
代码
#include <bits/stdc++.h>
using namespace std;
const int N=50010,M=10,MOD=1e9+7;
int n,m,rt,len,ans,I,a[M][N],b[M],ch[N][2],f[N][1<<M][2];
char s[N];
stack<int> st;
bool cmp(int x,int y)
{
return a[x][I]<a[y][I];
}
void build()
{
for (int i=len;i>=1;i--)
{
if (s[i]!='(') st.push(i);
if (s[i]=='(')
{
int x=st.top(); st.pop();
while (s[st.top()]!=')')
{
int y=st.top(); st.pop();
int z=st.top(); st.pop();
ch[y][0]=z; ch[y][1]=x; x=y;
}
st.pop(); st.push(x);
}
}
rt=st.top();
}
void dfs(int x)
{
if (isdigit(s[x]))
{
int num=s[x]-48;
for (int i=0;i<(1<<m);i++)
f[x][i][(i>>num)&1]=1;
return;
}
int lc=ch[x][0],rc=ch[x][1];
dfs(lc); dfs(rc);
for (int i=0;i<(1<<m);i++)
for (int j=0;j<=1;j++)
for (int k=0;k<=1;k++)
{
int sum=1LL*f[lc][i][j]*f[rc][i][k]%MOD;
if (s[x]=='<' || s[x]=='?') f[x][i][min(j,k)]=(f[x][i][min(j,k)]+sum)%MOD;
if (s[x]=='>' || s[x]=='?') f[x][i][max(j,k)]=(f[x][i][max(j,k)]+sum)%MOD;
}
}
int main()
{
scanf("%d%d",&n,&m);
for (int i=0;i<m;i++)
for (int j=1;j<=n;j++)
scanf("%d",&a[i][j]);
scanf("%s",s+2);
len=strlen(s+2)+1;
s[1]='('; s[++len]=')';
build();
dfs(rt);
for (int i=0;i<m;i++) b[i]=i;
for (int i=1;i<=n;i++)
{
I=i;
sort(b,b+m,cmp);
for (int j=0,S=(1<<m)-1;j<m;j++)
{
ans=(ans+1LL*(f[rt][S^(1<<b[j])][0]-f[rt][S][0])*a[b[j]][i])%MOD;
S^=(1<<b[j]);
}
}
printf("%d
",(ans+MOD)%MOD);
return 0;
}