题目大意:给一个整数序列,统计四元组(a,b,c,d)的个数,满足条件1:a<>b<>c<>d;条件2:<a,b>组成一个顺序对,<c,d>组成一个逆序对。(a、b、c、d均为下标)
代码如下:从所有四元组中减去不满足条件的四元组。用顺序对数乘以逆序对数得到只满足条件2的四元组数目sum,从sum减去不满足条件1的四元组数目便是答案。sum中只有四种不满足条件1的情况,即:a=c、a=d、b=c、b=d。四种情况的含义分别为顺序对的左端点与逆序对的左端点相同、顺序对的左端点与逆序对的右端点相同、顺序对的右端点与逆序对的左端点相同、顺序对的右端点与逆序对的右端点相同。维护四个数组A、B、C、D,数组中的位置 i 分别表示以 i 为右端点的顺序、逆序以及以 i 为左端点的顺序、逆序对数,便可计算出相应情况下多余的四元组数目。
这四个数组是通过维护树状数组获得的,但是通过维护线段树却超时啦。。。
代码如下:
# include<iostream> # include<cstdio> # include<map> # include<vector> # include<cstring> # include<algorithm> using namespace std; # define LL long long const int N=50000; int a[N+5]; map<int,int>mp; vector<int>v; int vis[N+5]; int cnt[N+5]; int A[N+5];//以i为右端点的顺序对个数 int B[N+5];//以i为右端点的逆序对个数 int C[N+5];//以i为左端点的顺序对个数 int D[N+5];//以i为左端点的逆序对个数 int lowbit(int x) { return x&(-x); } int query(int x) { int res=0; while(x>=1){ res+=cnt[x]; x-=lowbit(x); } return res; } void update(int x,int n) { while(x<=n){ ++cnt[x]; x+=lowbit(x); } } int main() { int n; while(~scanf("%d",&n)) { v.clear(); mp.clear(); for(int i=1;i<=n;++i){ scanf("%d",a+i); v.push_back(a[i]); } sort(v.begin(),v.end()); int m=unique(v.begin(),v.end())-v.begin(); for(int i=0;i<m;++i) mp[v[i]]=i+1; memset(cnt,0,sizeof(cnt)); memset(vis,0,sizeof(vis)); LL sum1=0,sum2=0; for(int i=1;i<=n;++i){ A[i]=query(mp[a[i]]-1); B[i]=i-1-A[i]-vis[mp[a[i]]]; update(mp[a[i]],m+1); ++vis[mp[a[i]]]; sum1+=A[i]; sum2+=B[i]; } memset(cnt,0,sizeof(cnt)); memset(vis,0,sizeof(vis)); for(int i=n;i>=1;--i){ D[i]=query(mp[a[i]]-1); C[i]=n-i-D[i]-vis[mp[a[i]]]; update(mp[a[i]],m+1); ++vis[mp[a[i]]]; } LL sum3=0; for(int i=1;i<=n;++i){ sum3+=(LL)C[i]*(LL)D[i]; sum3+=(LL)C[i]*(LL)B[i]; sum3+=(LL)A[i]*(LL)D[i]; sum3+=(LL)A[i]*(LL)B[i]; } printf("%lld ",sum1*sum2-sum3); } return 0; }
线段树的超时代码:
# include<iostream> # include<cstdio> # include<map> # include<vector> # include<cstring> # include<algorithm> using namespace std; # define LL long long # define mid (l+(r-l)/2) const int N=50000; int a[N+5]; vector<int>v; int A[N+5];//以i为右端点的顺序对个数 int B[N+5];//以i为右端点的逆序对个数 int C[N+5];//以i为左端点的顺序对个数 int D[N+5];//以i为左端点的逆序对个数 int tr[N*4+5]; map<int,int>mp; void pushUp(int rt) { tr[rt]=tr[rt<<1]+tr[rt<<1|1]; } void build(int rt,int l,int r) { tr[rt]=0; if(l==r) return ; build(rt<<1,l,mid); build(rt<<1|1,mid+1,r); } void update(int rt,int l,int r,int pos) { if(l==r) ++tr[rt]; else{ if(pos<=mid) update(rt<<1,l,mid,pos); else update(rt<<1|1,mid+1,r,pos); pushUp(rt); } } int query(int rt,int l,int r,int L,int R) { if(L<=l&&r<=R) return tr[rt]; int res=0; if(L<=mid) res+=query(rt<<1,l,mid,L,R); if(R>mid) res+=query(rt<<1|1,mid+1,r,L,R); return res; } int main() { int n; while(~scanf("%d",&n)) { mp.clear(); v.clear(); for(int i=1;i<=n;++i){ scanf("%d",a+i); v.push_back(a[i]); } sort(v.begin(),v.end()); int len=unique(v.begin(),v.end())-v.begin(); for(int i=0;i<len;++i) mp[v[i]]=i+1; build(1,0,len+1); LL sum1=0,sum2=0; for(int i=1;i<=n;++i){ A[i]=query(1,0,len+1,0,mp[a[i]]-1); B[i]=query(1,0,len+1,mp[a[i]]+1,len+1); update(1,0,len+1,mp[a[i]]); sum1+=A[i]; sum2+=B[i]; } build(1,0,len+1); for(int i=n;i>=1;--i){ C[i]=query(1,0,len+1,mp[a[i]]+1,len+1); D[i]=query(1,0,len+1,0,mp[a[i]]-1); update(1,0,len+1,mp[a[i]]); } LL sum3=0; for(int i=1;i<=n;++i){ sum3+=(LL)C[i]*(LL)D[i]; sum3+=(LL)C[i]*(LL)B[i]; sum3+=(LL)A[i]*(LL)D[i]; sum3+=(LL)A[i]*(LL)B[i]; } printf("%lld ",sum1*sum2-sum3); } return 0; }