(又是一道树套树……自己真是玩疯了……)
(题意略)
从网上也看过题解,好像解法很多……比如CDQ+树状数组,树状数组套主席树,树状数组套平衡树……我用的是树状数组套splay。
(我会说是因为我不会写CDQ和树状数组套主席树么= =)
(不得不吐槽,为啥splay这么快= =)
思路很直接:我写的是在线算法。在删除一个元素之前,先统计它前面比它大的数的个数,以及它后面比它小的数的个数,再把答案减去这两个数之和即可。其中,位置区间用树状数组来划分,树状数组的每个节点上挂一棵平衡树,用来统计该节点覆盖的区间内比某个数小/大的数的个数。
鉴于这题卡常,我就加了快读和各种inline卡常大法,然后卡了不下5次评测机才过……(COGS垃圾评测机)
顺便一提,求初始逆序对可以用归并排序或树状数组,我用的是后者。
贴个代码(貌似这是全程非递归):
#include<cstdio>
#include<cstring>
#include<algorithm>
// Size of the subtree rooted at x; a null pointer counts as an empty subtree.
#define siz(x) ((x)?(x)->size:0)
// Lowest set bit of x: the step used by every Fenwick-tree index walk below.
#define lowbit(x) ((x)&(-(x)))
using namespace std;

// Hand-rolled integer I/O on top of getchar/putchar.
namespace mine{
    // Reads an optionally negative decimal integer from stdin into `value`.
    // Leading blanks, tabs and line breaks are skipped; parsing stops at the
    // first non-digit character (which is consumed).
    template<class T>inline void readint(T &value){
        int ch;
        do ch=getchar(); while(ch==' '||ch=='\n'||ch=='\r'||ch=='\t');
        bool negative=false;
        if(ch=='-'){
            negative=true;
            ch=getchar();
        }
        value=0;
        for(;ch>='0'&&ch<='9';ch=getchar())value=value*10+(ch^48);
        if(negative)value=-value;
    }

    // Writes `value` to stdout in decimal, with a leading '-' when negative.
    // NOTE: negating the most negative value of T overflows; callers here only
    // print non-negative counts.
    template<class T>inline void putint(T value){
        int digits[40];
        int count=0;
        const bool negative=value<0;
        if(negative)value=-value;
        // Collect digits least-significant first, then emit them reversed.
        do{
            digits[count++]=value%(T)10^(T)48;
            value/=10;
        }while(value);
        if(negative)putchar('-');
        for(int k=count-1;k>=0;k--)putchar(digits[k]);
    }
}
using namespace mine;

const int maxn=100010;

// Splay-tree node. root[i] is the tree attached to Fenwick index i.
struct node{
    int data,size;       // stored value / number of nodes in this subtree
    node *lc,*rc,*prt;   // left child, right child, parent
    node(int d=0):data(d),size(1),lc(NULL),rc(NULL),prt(NULL){}
    // Recomputes `size` from the two children.
    inline void refresh(){size=siz(lc)+siz(rc)+1;}
}*root[maxn]={NULL};

void add(int);
void query(int);
void build(int,int);
int qlss(int,int);
int qgrt(int,int);
void mdel(int,int);
void insert(node*,int);
node *find(int,int);
void erase(node*,int);
int rank(int,int);
int rerank(int,int);
void splay(node*,node*,int);
void lrot(node*,int);
void rrot(node*,int);
node *findmax(node*);

// n: sequence length, m: number of deletions.
// a[i]: value at position i, b[v]: position of value v,
// c[]: counting Fenwick array used only for the initial inversion count.
int n,m,a[maxn],b[maxn],c[maxn]={0},x;
long long ans=0ll; // current number of inversions

// Reads the sequence, counts its inversions, then for every deletion prints the
// inversion count before the deletion and removes the element's contribution.
int main(){
#define MINE
#ifdef MINE
    freopen("inverse.in","r",stdin);
    freopen("inverse.out","w",stdout);
#endif
    readint(n);
    readint(m);
    for(int i=1;i<=n;i++){
        readint(a[i]);
        b[a[i]]=i;
        query(a[i]);   // inversions formed with elements already read
        add(a[i]);
        build(i,a[i]);
    }
    while(m--){
        putint(ans);
        // NOTE(review): one answer per line; the separator is assumed to have
        // been '\n' before the listing's escapes were flattened - confirm.
        putchar('\n');
        readint(x);
        x=b[x]; // from here on x is a position, not a value
        // Larger values in positions 1..x plus smaller values in positions x..n.
        ans-=(long long)qgrt(x,a[x])+(long long)qlss(n,a[x])-(long long)qlss(x-1,a[x]);
        mdel(x,a[x]);
    }
#ifndef MINE
    printf(" --------------------DONE-------------------- ");
    for(;;);
#endif
    return 0;
}

// Marks value x as seen in the counting array. The walk goes downwards, the
// mirror image of query(), so a later query(v) sees every added value >= v.
inline void add(int x){
    while(x){
        c[x]++;
        x-=lowbit(x);
    }
}

// Adds to `ans` the number of values >= x recorded so far by add().
inline void query(int x){
    while(x<=n){
        ans+=c[x];
        x+=lowbit(x);
    }
}

// Inserts value d into the splay tree of every Fenwick index that covers
// position x (x, x+lowbit(x), ... up to n).
inline void build(int x,int d){
    while(x<=n){
        insert(new node(d),x);
        x+=lowbit(x);
    }
}
inline int qlss(int x,int d){ int ans=0; while(x){ ans+=rank(d,x); x-=lowbit(x); } return ans; } inline int qgrt(int x,int d){ int ans=0; while(x){ ans+=rerank(d,x); x-=lowbit(x); } return ans; } inline void mdel(int x,int d){ while(x<=n){ erase(find(d,x),x); x+=lowbit(x); } } inline void insert(node *x,int i){ if(!root[i]){ root[i]=x; return; } node *rt=root[i]; for(;;){ if(x->data<rt->data){ if(rt->lc)rt=rt->lc; else{ rt->lc=x; break; } } else{ if(rt->rc)rt=rt->rc; else{ rt->rc=x; break; } } } x->prt=rt; for(;rt;rt=rt->prt)rt->refresh(); splay(x,NULL,i); } inline node *find(int x,int i){ node *rt=root[i]; while(rt){ if(x==rt->data)return rt; else if(x<rt->data)rt=rt->lc; else rt=rt->rc; } return NULL; } inline void erase(node *x,int i){ splay(x,NULL,i); if(x->lc){ splay(findmax(x->lc),x,i); x->lc->rc=x->rc; if(x->rc)x->rc->prt=x->lc; x->lc->prt=NULL; root[i]=x->lc; x->lc->refresh(); } else{ if(x->rc)x->rc->prt=NULL; root[i]=x->rc; } delete x; } inline int rank(int x,int i){ node *rt=root[i],*y=NULL; int ans=0; while(rt){ y=rt; if(x<=rt->data)rt=rt->lc; else{ ans+=siz(rt->lc)+1; rt=rt->rc; } } if(y)splay(y,NULL,i); return ans; } inline int rerank(int x,int i){ return siz(root[i])-rank(x+1,i); } inline void splay(node *x,node *tar,int i){ for(node *rt=x->prt;rt!=tar;rt=x->prt){ if(rt->prt==tar){ if(x==rt->lc)rrot(rt,i); else lrot(rt,i); break; } if(rt==rt->prt->lc){ if(x==rt->lc)rrot(rt,i); else lrot(rt,i); rrot(x->prt,i); } else{ if(x==rt->rc)lrot(rt,i); else rrot(rt,i); lrot(x->prt,i); } } } inline void lrot(node *x,int i){ node *y=x->rc; if(x->prt){ if(x==x->prt->lc)x->prt->lc=y; else x->prt->rc=y; } else root[i]=y; y->prt=x->prt; x->rc=y->lc; if(y->lc)y->lc->prt=x; y->lc=x; x->prt=y; x->refresh(); y->refresh(); } inline void rrot(node *x,int i){ node *y=x->lc; if(x->prt){ if(x==x->prt->lc)x->prt->lc=y; else x->prt->rc=y; } else root[i]=y; y->prt=x->prt; x->lc=y->rc; if(y->rc)y->rc->prt=x; y->rc=x; x->prt=y; x->refresh(); y->refresh(); } inline node 
*findmax(node *x){ while(x->rc)x=x->rc; return x; }