前言
有线段树合并就应该有线段树分裂。它是线段树合并的逆过程。具体的,你需要以权值线段树中第 k 小的数为分界线,把线段树分成两半。
算法流程
和线段树上二分类似。假设原来的线段树为 u,要分裂出线段树 v
- 记左子树的权值为 val。
- 如果 k>val,那么分界线在右子树,那么左子树归 u,递归右子树,此时 k=k-val。
- 如果 k==val,那么分界线正好就是mid,那么左子树归 u,右子树归 v。
- 如果 k<val,那么分界线在左子树,那么右子树归 v,递归左子树
- 计算 u,v 的权值 tr[v].val=tr[u].val-k; tr[u].val=k;
核心代码
int split(int u,int v,int st,int ed,int k)
{
if(u==0) return 0;
int mid=st+ed>>1;
tr.push_back(seg());
v=tr.size()-1;
int val=tr[tr[u].ls].val;
if(k>val)
tr[v].rs=split(tr[u].rs,tr[v].rs,mid+1,ed,k-val);
else
swap(tr[u].rs,tr[v].rs);
if(k<val)
tr[v].ls=split(tr[u].ls,tr[v].ls,st,mid,k);
tr[v].val=tr[u].val-k;
tr[u].val=k;
return v;
}
【模板】线段树分裂
题解
操作0
先把线段树分裂成 ( 1 , x − 1 ) , ( x , n ) (1,x-1),(x,n) (1,x−1),(x,n),再把 ( x , n ) (x,n) (x,n) 线段树分裂成 ( x , y ) , ( y + 1 , n ) (x,y),(y+1,n) (x,y),(y+1,n),最后合并线段树 ( 1 , x − 1 ) , ( y + 1 , n ) (1,x-1),(y+1,n) (1,x−1),(y+1,n)。
操作1
线段树合并
操作2
单点加
操作3
区间查
操作4
线段树上二分
代码
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=1e6+7,inf=1e18;
int n,m;
struct seg
{
int ls,rs,val;
seg():ls(0),rs(0),val(0)
{
}
seg(int a,int b,int c):ls(a),rs(b),val(c)
{
}
};
vector<seg> tr(2);
void update(int u)
{
tr[u].val=tr[tr[u].ls].val+tr[tr[u].rs].val;
}
void insert(int u,int st,int ed,int x,int t)
{
if(st==ed)
{
tr[u].val+=t;
return;
}
int mid=st+ed>>1;
if(x<=mid)
{
if(!tr[u].ls)
{
tr.push_back(seg());
tr[u].ls=tr.size()-1;
}
insert(tr[u].ls,st,mid,x,t);
}
else
{
if(!tr[u].rs)
{
tr.push_back(seg());
tr[u].rs=tr.size()-1;
}
insert(tr[u].rs,mid+1,ed,x,t);
}
update(u);
}
int query(int u,int st,int ed,int l,int r)
{
if(l<=st&&ed<=r)
{
return tr[u].val;
}
int mid=st+ed>>1,res=0;
if(l<=mid)
{
if(tr[u].ls)
res+=query(tr[u].ls,st,mid,l,r);
}
if(mid<r)
{
if(tr[u].rs)
res+=query(tr[u].rs,mid+1,ed,l,r);
}
return res;
}
void merge(int u,int v,int st,int ed)
{
if(st==ed)
{
tr[u].val+=tr[v].val;
return;
}
int mid=st+ed>>1;
if(tr[u].ls&&tr[v].ls)
merge(tr[u].ls,tr[v].ls,st,mid);
else if(tr[v].ls)
tr[u].ls=tr[v].ls;
if(tr[u].rs&&tr[v].rs)
merge(tr[u].rs,tr[v].rs,mid+1,ed);
else if(tr[v].rs)
tr[u].rs=tr[v].rs;
update(u);
}
int split(int u,int v,int st,int ed,int k)
{
if(u==0) return 0;
int mid=st+ed>>1;
tr.push_back(seg());
v=tr.size()-1;
int val=tr[tr[u].ls].val;
if(k>val)
tr[v].rs=split(tr[u].rs,tr[v].rs,mid+1,ed,k-val);
else
swap(tr[u].rs,tr[v].rs);
if(k<val)
tr[v].ls=split(tr[u].ls,tr[v].ls,st,mid,k);
tr[v].val=tr[u].val-k;
tr[u].val=k;
return v;
}
int find(int u,int st,int ed,int k)
{
if(k>tr[u].val||st>ed||k==0)
return -1;
if(st==ed)
{
return st;
}
int mid=st+ed>>1;
int val=tr[tr[u].ls].val;
if(k>val)
{
return find(tr[u].rs,mid+1,ed,k-val);
}
else
{
return find(tr[u].ls,st,mid,k);
}
}
vector<int> rt(1);
void O_o()
{
cin>>n>>m;
rt.push_back(1);
for(int i=1; i<=n; i++)
{
int x;
cin>>x;
insert(rt[1],1,n,i,x);
}
while(m--)
{
int op,id;
cin>>op>>id;
if(op==0)
{
int x,y;
cin>>x>>y;
rt.push_back(0);
int now=rt.size()-1;
int v1=query(rt[id],1,n,1,x-1),v2=query(rt[id],1,n,x,y);
rt[now]=split(rt[id],rt[now],1,n,v1);
int t=split(rt[now],0,1,n,v2);
merge(rt[id],t,1,n);
}
else if(op==1)
{
int t;
cin>>t;
merge(rt[id],rt[t],1,n);
}
else if(op==2)
{
int x,q;
cin>>x>>q;
insert(rt[id],1,n,q,x);
}
else if(op==3)
{
int l,r;
cin>>l>>r;
cout<<query(rt[id],1,n,l,r)<<"\n";
}
else if(op==4)
{
int k;
cin>>k;
cout<<find(rt[id],1,n,k)<<"\n";
}
else assert(0);
}
}
signed main()
{
ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
cout<<fixed<<setprecision(12);
int T=1;
// cin>>T;
while(T--)
{
O_o();
}
}