题目大意
对于序列$A$,它的逆序对数定义为满足$i< j$,且$A_i> A_j$的数对$(i,j)$的个数
给$1$到$n$的一个排列,按照某种顺序依次删除$m$个元素,询问在每次删除一个元素之前统计整个序列的逆序对数
数据范围
$n\leqslant 100000,m\leqslant 50000$
题解
如果没有删除操作,则只要跑一遍树状数组或归并排序就可以跑出一开始的逆序对总数
多了删除操作之后,发现每个元素对整个序列的逆序对总数贡献为:
这个元素前面有多少个元素比它大,记录为
front[]
这个元素后面有多少个元素比它小,记录为
back[]
则总贡献就是front[]+back[]
,我们只要找到这样的数据结构,在删除操作的同时维护所有元素的front[]
和back[]
即可
这两种记录方式需要我们维护维护元素之间的大小关系,显然用可持久化线段树(权值线段树),删除时需要维护区间的可持久化线段树,在外套一个树状数组就可以优化区间维护过程
所以在删除的时候用树状数组统计被删除元素前面和后面的整个区间内所有权值线段树的信息,就相当于得到front[]
和back[]
,在一开始的序列跑出逆序对总数,慢慢减就好啦
代码
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#define For(i,l,r) for(int i=l;i<=r;i++)
#define Ford(i,r,l) for(int i=r;i>=l;i--)
#define lowbit(x) (x&-x)
#define ll long long
using namespace std;
inline int read(){
int x=0,f=1;char ch=getchar();
while (ch<'0'||ch>'9') {if (ch=='-') f=-1;ch=getchar();}
while (ch>='0'&&ch<='9') {x=x*10+ch-48;ch=getchar();}
return x*f;
}
int n,m,size;ll ans;
int A[30],B[30];
int a[100050],pos[100050],front[100050],back[100050],c[100050];
int root[100050],ls[5000050],rs[5000050];ll sum[5000050];
inline void insert(int x){
for (;x<=n;x+=lowbit(x)) c[x]++;
}
inline int query(int x){
int result=0;
for (;x;x-=lowbit(x)) result+=c[x];
return result;
}
inline void prepare(){
For(i,1,n){
a[i]=read();pos[a[i]]=i;
ans+=(front[i]=i-1-query(a[i]));
insert(a[i]);
}
memset(c,0,sizeof(c));
Ford(i,n,1){
back[i]=query(a[i]-1);
insert(a[i]);
}
}
void update(int l,int r,int &y,int c){
if (!y) y=++size;
sum[y]++;
if (l==r) return;
int mid=(l+r)>>1;
if (c<=mid) update(l,mid,ls[y],c);
else update(mid+1,r,rs[y],c);
}
inline ll ask_more(int l,int r,int c){
if (l>r) return 0;
l--;A[0]=B[0]=0;
for (int i=l;i;i-=lowbit(i)) A[++A[0]]=root[i];
for (int i=r;i;i-=lowbit(i)) B[++B[0]]=root[i];
l=1;r=n;ll ans=0;
while (l!=r){
int mid=(l+r)>>1;
if (c<=mid){
For(i,1,A[0]) ans-=sum[rs[A[i]]];
For(i,1,B[0]) ans+=sum[rs[B[i]]];
For(i,1,A[0]) A[i]=ls[A[i]];
For(i,1,B[0]) B[i]=ls[B[i]];
r=mid;
}
else {
For(i,1,A[0]) A[i]=rs[A[i]];
For(i,1,B[0]) B[i]=rs[B[i]];
l=mid+1;
}
}
return ans;
}
inline ll ask_less(int l,int r,int c){
if (l>r) return 0;
l--;A[0]=B[0]=0;
for (int i=l;i;i-=lowbit(i)) A[++A[0]]=root[i];
for (int i=r;i;i-=lowbit(i)) B[++B[0]]=root[i];
l=1;r=n;ll ans=0;
while (l!=r){
int mid=(l+r)>>1;
if (c<=mid){
For(i,1,A[0]) A[i]=ls[A[i]];
For(i,1,B[0]) B[i]=ls[B[i]];
r=mid;
}
else {
For(i,1,A[0]) ans-=sum[ls[A[i]]];
For(i,1,B[0]) ans+=sum[ls[B[i]]];
For(i,1,A[0]) A[i]=rs[A[i]];
For(i,1,B[0]) B[i]=rs[B[i]];
l=mid+1;
}
}
return ans;
}
int main(){
n=read(),m=read();
prepare();
For(i,1,m-1){
printf("%lld\n",ans);
int del=read(),x=pos[del];
ans-=front[x]-ask_more(1,x-1,del);
ans-=back[x]-ask_less(x+1,n,del);
for (;x<=n;x+=lowbit(x)) update(1,n,root[x],del);
}
printf("%lld\n",ans);
return 0;
}