题意:给定序列,问最多可以分成多少段序列使得每段序列不超过L且异或和不超过X
首先对于区间异或和,很容易想到前缀异或和去优化使其可以在O(1)时间内求出区间异或和,然后我们就可以写出一个n²暴力
#include <map> #include <set> #include <ctime> #include <cmath> #include <queue> #include <stack> #include <vector> #include <string> #include <bitset> #include <cstdio> #include <cstdlib> #include <cstring> #include <sstream> #include <iostream> #include <algorithm> #include <functional> using namespace std; #define For(i, x, y) for(int i=x;i<=y;i++) #define _For(i, x, y) for(int i=x;i>=y;i--) #define Mem(f, x) memset(f,x,sizeof(f)) #define Sca(x) scanf("%d", &x) #define Sca2(x,y) scanf("%d%d",&x,&y) #define Sca3(x,y,z) scanf("%d%d%d",&x,&y,&z) #define Scl(x) scanf("%lld",&x) #define Pri(x) printf("%d ", x) #define Prl(x) printf("%lld ",x) #define CLR(u) for(int i=0;i<=N;i++)u[i].clear(); #define LL long long #define ULL unsigned long long #define mp make_pair #define PII pair<int,int> #define PIL pair<int,long long> #define PLL pair<long long,long long> #define pb push_back #define fi first #define se second typedef vector<int> VI; int read(){int x = 0,f = 1;char c = getchar();while (c<'0' || c>'9'){if (c == '-') f = -1;c = getchar();} while (c >= '0'&&c <= '9'){x = x * 10 + c - '0';c = getchar();}return x*f;} const double PI = acos(-1.0); const double eps = 1e-9; const int maxn = 1e5 + 10; const int INF = 0x3f3f3f3f; const int mod = 268435456; LL N,X,L,P,Q; LL a[maxn],dp[maxn]; LL pre[maxn]; LL sum(int i,int j){ return pre[j] ^ pre[i - 1]; } int main(){ int T; Sca(T); while(T--){ scanf("%lld%lld%lld",&N,&X,&L); scanf("%lld%lld%lld",&a[1],&P,&Q); for(int i = 2; i <= N ; i ++){ a[i] = ((a[i - 1] * P) + Q) % mod; } pre[0] = 0; for(int i = 1; i <= N ; i ++) pre[i] = pre[i - 1] ^ a[i]; for(int i = 0; i <= N ; i ++) dp[i] = 0; for(int i = 1; i <= N; i ++){ for(int j = max(0LL,i - L); j < i ; j ++){ if(sum(j + 1,i) <= X) dp[i] = max(dp[i],dp[j] + 1); } } Prl(dp[N]); } return 0; }
我们可以发现对于pre相同的下标而言,dp的大小呈单调性,即i > j 且pre[i] = pre[j] 则dp[i] > dp[j],由于i,j之间异或和为0,显然dp[i] - dp[j] >= 1
那么对于前面长度L的区间,我们可以考虑用字典树优化,用01字典树维护每个前缀和的dp最大值,由于满足单调性,对于字典树上的删除我们只需要维护每个节点出现的次数,因为只要字典树上还存在当前节点(出现次数不为0),就意味着当前最大值不会变(最大值永远越后面的越大)
对于查询的时候就需要讨论,如果当前位X为0,说明查询的pre当前位上也是0,需要走当前位与他相同的路径,如果X为1,那么可以走与当前位相反的路径使得该位和X一样为1,或者走与其相同的路径使得该位为0,倘若走0的路径,那么直接取子树的最大值不用继续往下走,因为下面无论怎么走都一定比X小
#include <map> #include <set> #include <ctime> #include <cmath> #include <queue> #include <stack> #include <vector> #include <string> #include <bitset> #include <cstdio> #include <cstdlib> #include <cstring> #include <sstream> #include <iostream> #include <algorithm> #include <functional> using namespace std; #define For(i, x, y) for(int i=x;i<=y;i++) #define _For(i, x, y) for(int i=x;i>=y;i--) #define Mem(f, x) memset(f,x,sizeof(f)) #define Sca(x) scanf("%d", &x) #define Sca2(x,y) scanf("%d%d",&x,&y) #define Sca3(x,y,z) scanf("%d%d%d",&x,&y,&z) #define Scl(x) scanf("%lld",&x) #define Pri(x) printf("%d ", x) #define Prl(x) printf("%lld ",x) #define CLR(u) for(int i=0;i<=N;i++)u[i].clear(); #define LL long long #define ULL unsigned long long #define mp make_pair #define PII pair<int,int> #define PIL pair<int,long long> #define PLL pair<long long,long long> #define pb push_back #define fi first #define se second typedef vector<int> VI; int read(){int x = 0,f = 1;char c = getchar();while (c<'0' || c>'9'){if (c == '-') f = -1;c = getchar();} while (c >= '0'&&c <= '9'){x = x * 10 + c - '0';c = getchar();}return x*f;} const int maxn = 1e5 + 10; const int maxm = 5e6 + 10; const LL INF = 1e18; const LL mod = 268435456; LL N,X,L,P,Q; LL a[maxn],dp[maxn],pre[maxn]; int nxt[maxm][2],cnt,num[maxm]; LL val[maxm]; void insert(int j){ LL x = pre[j],v = dp[j]; int p = 1; for(int i = 32; i >= 0; i --){ int id = (x >> i) & 1; if(!nxt[p][id]){ nxt[p][id] = ++cnt; val[cnt] = -INF; num[cnt] = nxt[cnt][0] = nxt[cnt][1] = 0; } p = nxt[p][id]; val[p] = max(val[p],v); num[p]++; } } void del(int p,int i,LL x){ if(i == -1){if(!num[p]) val[p] = -INF;return;} int id = (x >> i) & 1; num[nxt[p][id]]--; del(nxt[p][id],i - 1,x); val[p] = val[nxt[p][id]]; if(nxt[p][id ^ 1] && num[nxt[p][id ^ 1]] > 0) val[p] = max(val[nxt[p][0]],val[nxt[p][1]]); } LL query(LL x){ int p = 1; LL ans = -INF; for(int i = 32; i >= 0 ; i --){ int id = (x >> i) & 1; if((X >> i) & 1){ if(nxt[p][id] && num[nxt[p][id]]){ ans = max(ans,val[nxt[p][id]]); } if(nxt[p][id ^ 1] && num[nxt[p][id ^ 1]]){ p = nxt[p][id ^ 1]; } }else{ if(!nxt[p][id] || !num[nxt[p][id]]) return ans; p = nxt[p][id]; } } ans = max(ans,val[p]); return ans; return val[p]; } int main(){ int T; Sca(T); cnt = 1; while(T--){ for(int i = 0 ; i <= cnt; i ++){val[i] = -INF; nxt[i][0] = nxt[i][1] = num[i] = 0;} scanf("%lld%lld%lld",&N,&X,&L); cnt = 1; scanf("%lld%lld%lld",&a[1],&P,&Q); for(int i = 2; i <= N ; i ++) a[i] = ((a[i - 1] * P) + Q) % mod; pre[0] = dp[0] = 0; insert(0); for(int i = 1; i <= N ; i ++) pre[i] = pre[i - 1] ^ a[i]; // For(i,1,N) cout << pre[i] << " "; // cout << endl; for(int i = 1; i <= N; i ++){ if(i - L - 1 >= 0 && dp[i - L - 1] >= 0) del(1,32,pre[i - L - 1]); dp[i] = query(pre[i]) + 1; if(dp[i] > 0) insert(i); } if(dp[N] < 0) dp[N] = 0; Prl(dp[N]); } return 0; }