题目描述
Manao came up with a solution that produces the correct answers but is too slow. You are given the pseudocode of his solution, where the function getAnswer calculates the answer to the problem:
getAnswer(a[1..n], b[1..len], h)
answer = 0
for i = 1 to n-len+1
answer = answer + f(a[i..i+len-1], b, h, 1)
return answer
f(s[1..len], b[1..len], h, index)
if index = len+1 then
return 1
for i = 1 to len
if s[index] + b[i] >= h
mem = b[i]
b[i] = 0
res = f(s, b, h, index + 1)
b[i] = mem
if res > 0
return 1
return 0
Your task is to help Manao optimize his algorithm.
题目大意
当两个长度相同的数组存在一种两两匹配方式,使得每一对的和不小于 (h),则称之为匹配。
有一个长为 (m) 的数组 (B),求 (A) 有多少长度为 (m) 的子序列(连续)中的数和 (B) 中的数匹配。
思路
两个数组按照以下方式匹配一定是最优的:
令 (A) 的子序列为 (S)
(B) 最大( o S) 最小,(B) 次大( o S) 次小,...,(B) 最小( o S) 最大
因为假设 (S) 有两数 (s_1 le s_2),(B) 有两数 (b_1 le b_2)
不按大对小的匹配方式则有
且此时大对小的匹配方式不满足
若 (s_1+b_2 < h ext{③})
则 ( ext{①-③} implies b_1 > b_2) 矛盾
同理若 (s_2+b_1 < h ext{④})
则 ( ext{①-④} implies s_1 > s_2) 矛盾
所以 (B) 中的最大至少有 (m) 个数在 (S) 中能匹配,(B) 中的次大至少有 (m-1) 个数在 (S) 中能匹配,...,(B) 中的最小至少有 (1) 个数在 (S) 中能匹配
将 (B) 从小到大排序,用线段树维护 (B) 中每个数能被多少个数匹配,将初始可匹配数设为 -1,-2,...,-m
每次将 (a_i) 加入时,找到第一个 (a_i+b_j ge h) 的位置 (j),并将 (B) 中 ([j,m]) 的可匹配数加 1
若线段树维护的可匹配数最小值大于 0,这个子序列就是合法的
#include <functional>
#include <algorithm>
#include <cstdio>
using namespace std;
const int maxn = 1.5e5 + 10;
int n,m,h,a[maxn],b[maxn],minv[maxn<<2],laz[maxn<<2];
inline void pushdown(int root) {
if (laz[root]) {
laz[root<<1] += laz[root];
laz[root<<1|1] += laz[root];
minv[root<<1] += laz[root];
minv[root<<1|1] += laz[root];
laz[root] = 0;
}
}
inline void pushup(int root) { minv[root] = min(minv[root<<1],minv[root<<1|1]); }
inline void update(int ul,int ur,int x,int l = 1,int r = m,int root = 1) {
if (l > ur || r < ul) return;
if (ul <= l && r <= ur) return laz[root] += x,minv[root] += x,void();
int mid = l+r>>1;
pushdown(root);
update(ul,ur,x,l,mid,root<<1);
update(ul,ur,x,mid+1,r,root<<1|1);
pushup(root);
}
int main() {
scanf("%d%d%d",&n,&m,&h);
for (int i = 1;i <= m;i++) { scanf("%d",&b[i]); update(i,m,-1); }
sort(b+1,b+m+1);
for (int i = 1;i <= n;i++) {
scanf("%d",&a[i]);
a[i] = lower_bound(b+1,b+m+1,h-a[i])-b;
if (i <= m) update(a[i],m,1);
}
int ans = minv[1] >= 0;
for (int i = m+1;i <= n;i++) {
update(a[i],m,1);
update(a[i-m],m,-1);
ans += minv[1] >= 0;
}
printf("%d",ans);
return 0;
}