1.树状数组
设计二分,二叉树,位运算,前缀和等思想
lowbit = x & -x
功能:找到x的二进制数的最后一个1
1.1 树状数组模板
def lowbit(x):
return x &-x
def add (x,d):
while(x < n) :
tree[x] +=d
x+=lowbit(x)
def sum(x):
ans = 0
while(x >0):
ans += tree[x]
x-=lowbit(x)
return ans
1.2前缀和应用
1.2.1单点修改,区间查询
def lowbit(x):
return x &-x
def add (x,d): # 修改元素a[x],a[x]=a[x]+d
while(x <= N) :
tree[x] +=d
x+=lowbit(x)
def sum(x): # 前缀和思想,返回前缀和sum=a[1]+a[2]+...a[n]
ans = 0
while(x >0):
ans += tree[x]
x-=lowbit(x)
return ans
N=1000
tree =[0]*N
a=[0,4,5,6,7,8,9,10,11,12,13]
for i in range(1,11): # 计算tree数组,即初始化
add(i,a[i])
print("old:[5,8]=",sum(8)-sum(4)) # 区间和查询
add(5,100)
print("new:[5,8]",sum(8)-sum(4))
1.2.2区间修改,区间查询
# python3.6
# -*- coding: utf-8 -*-
# @Time : 2023/4/29 9:15
# @Author : Jin
# @File : 树状数组.py
# @Software: PyCharm
# python3.6
# -*- coding: utf-8 -*-
# @Time : 2023/4/29 9:15
# @Author : Jin
# @File : 树状数组.py
# @Software: PyCharm
def lowbit(x):
return x &-x
def add1 (x,d): # 修改元素a[x],a[x]=a[x]+d
while(x <= N) :
tree1[x] +=d
x+=lowbit(x)
def add2(x,d):
while(x <= N) :
tree2[x] +=d
x+=lowbit(x)
def sum1(x): # 前缀和思想,返回前缀和sum=a[1]+a[2]+...a[n]
ans = 0
while(x >0):
ans += tree1[x]
x-=lowbit(x)
return ans
def sum2(x):
ans = 0
while (x > 0):
ans += tree2[x]
x -= lowbit(x)
return ans
N=10010
tree1 =[0]*N
tree2 =[0]*N #2个树状数组
n,m = map(int,input().split())
old=0
a=[0]+[int(i) for i in input().split()] # 不用a[0]
for i in range(1,n+1): # 计算tree数组,即初始化
add1(i,a[i]-old) # 差分数组原理初始化
add2(i,(i-1)*(a[i]-old))
old=a[i]
for _ in range(m):
g = [int(i) for i in input().split()]
if (g[0]==1): # 区间修改
L,R,d = g[1],g[2],g[3]
add1(L,d) # 第一个树状数组,差分
add1(R+1,-d)
add2(L,d*(L-1)) # 第二个树状数组,前缀和
add2(R+1,-d*R) # d*R = d*(R+1-1)
else: # 区间询问
L,R = g[1],g[2]
print(R*sum1(R)-sum2(R)-(L-1)*sum1(L-1)+sum2(L-1))
1.2.3 逆序对问题(归并排序)
def merge_sort(L,R):
if L < R:
mid = (L+R)//2
merge_sort(L,mid)
merge_sort(mid+1,R)
merge(L,mid,R)
def merge(L,mid,R):
global res # 记录答案
i=L;j=mid+1;t=0
while(i<=mid and j<=R): #归并
a[i]
a[j]
if (a[i]>a[j]): #4 5 / 2 3 L=0 mid=1,R=3
b[t]=a[j];t+=1;j+=1;
res = res+(mid-i+1) # 记录逆序对数量
else:
b[t] = a[i];t += 1;i += 1
# 其中一个处理完了,把剩下的复制过来,直接整体复制
# 这里注意区间取值,采用的是整体复制的思想,b是辅助数组
if i<=mid: b[t:R-L+1]=a[i:mid+1] # 取不到mid+1
elif j<=R:b[t:R-L+1]=a[j:]
# 把排序好的b[]复制回去a[]
a[L:R+1]=b[:R-L+1]
n= int(input())
a = list(map(int,input().split()))
b = [0]*n
res = 0
merge_sort(0,n-1)
print(res)
1.2.4逆序对问题(树状数组)
def lowbit(x):
return x&-x
def update(x,d): # 更新为 +lowbit(x)
while(x<=n):
tree[x]+=d
x+=lowbit(x)
def sum(x): # 求和为 -lowbit(x)
ans=0
while(x>0):
ans+=tree[x]
x-=lowbit(x)
return ans
n=eval(input())
a = [0]+list(map(int,input().split())) #从a[1]开始
b=sorted(a) # 从小到大排序
for i in range(n+1): # 将a更新为排序后的索引元素+1 [0 1 4 2] -> [1 2 4 3],即转为树状数组下标
a[a.index(b[i])]=i+1
tree = [0]*(n+1) # 下标从1开始
res =0
for i in range(len(a)-1,0,-1): # 倒序处理求逆序对
update(a[i],1) # 更新a[i]+1
res+=sum(a[i]-1)
print(res)
1.2.5将元素离散化
def discretization(h):
b = list(set(h)) # 去重,使得离散化后一样
b.sort()
for i in range(len(h)):
h[i]=b.index(h[i])+1
a=[1,20543,19,376,5460007640,19]
print(a)
discretization(a)
print(a)
2.线段树
2.1线段树介绍
基于二叉树(通过数字模拟二叉树),二分法(mid=(left+right)//2),递归(sys.setrecursionlimit(300000))
应用背景:
- 解决区间查询
- 区间修改问题
- 多次区间查询最值和区间修改
线段树构建
2.2线段树的树结构
"""
定义根接点为 tree[1]
通过数组模拟存储,空间开为元素个数的四倍,即 4*N
tree[p]
tree[2p] #左儿子
tree[2p+1] #右儿子
"""
2.3 利用线段树求最大数例题模板,即单点修改和查询操作模板
"""
def build(p,pl,pr): #建树
if pl=pr:
tree[p]=
return
mid=(pl+pr)//2
build(2*p,pl,mid)
build(2*p+1,mid+1,pr)
tree[p]=max( tree[2*p],tree[2*p+1]) # 这步关键
def update(p,pl,pr,L,R,d):
if L<=pl and pr <=R: # 说明当前区间完全包含在要查询的区间中
tree[p]=d
return
# 没有完全包含
mid = (pl+pr)//2
if L<=mid: #查左边
update(2*p,pl,mid,L,R,d)
if R>mid: #查右边
update(2*p+1,mid+1,pr,L,R,d)
tree[p]=max(tree[2*p],tree[2*p+1])
return
def query(p,pl,pr,L,R): # 查 [L R] 区间最值
res =-inf
if L<=pl and pr <=R: # 说明当前区间完全包含在要查询的区间中
return tree[p]
# 没有完全包含
mid = (pl+pr)//2
if L<=mid: #查左边
return max(res , query(2*p,pl,mid,L,R))
if R>mid: #查右边
return max(res , query(2*p+1,mid+1,pr,L,R))
tree[p]=max(tree[2*p],tree[2*p+1])
return
"""
注意递归问题
L≤mid :递归[ pl , mid ]
R>mid :递归[ mid+1 , pr ]
import sys
import collections
import itertools
import heapq
sys.setrecursionlimit(300000)
N=100001
inf=2**50
tree = [0]*4*N # 初始化树的大小
def build(p,pl,pr):
if (pl==pr):
tree[p]=-inf
return
mid=(pl+pr)//2
build(2*p,pl,mid) #递归构建左孩子
build(2*p,mid+1,pr) #递归构造右孩子
tree[p]=max(tree[2*p],tree[2*p+1]) #即 push_up操作
def update(p,pl,pr,L,R,d):
if L<=pl and pr <=R: # 说明当前区间完全包含在要查询的区间中
tree[p]=d
return
# 没有完全包含
mid = (pl+pr)//2
if L<=mid: #查左边
update(2*p,pl,mid,L,R,d)
if R>mid: #查右边
update(2*p+1,mid+1,pr,L,R,d)
tree[p]=max(tree[2*p],tree[2*p+1])
return
def query(p,pl,pr,L,R):
res =-inf
if L<=pl and pr <=R: # 说明当前区间完全包含在要查询的区间中
return tree[p]
# 没有完全包含
mid = (pl+pr)//2
if L<=mid: #查左边
return max(res , query(2*p,pl,mid,L,R))
if R>mid: #查右边
return max(res , query(2*p+1,mid+1,pr,L,R))
tree[p]=max(tree[2*p],tree[2*p+1])
return
2.4 区间修改的Lazy-tag技术
内部思想
多次区间修改可能会有冲突,需要push_down()函数解决
2.4.1区间修改,区间查询例题
import sys
import collections
import itertools
import heapq
sys.setrecursionlimit(300000)
def build(p,pl,pr):
if (pl==pr):
tree[p]=a[pl]
return
mid=(pl+pr)//2
build(2*p,pl,mid) #递归构建左孩子
build(2*p+1,mid+1,pr) #递归构造右孩子
tree[p]=tree[2*p]+tree[2*p+1] #记录区间和, push_up操作
def update(p,pl,pr,L,R,d):
if L<=pl and pr <=R: # 说明当前区间完全包含在要查询的区间中
addtag(p,pl,pr,d)
return
# 没有完全包含
push_down(p,pl,pr) # 将懒惰标记传递下去,如果修改区间重叠会有问题
mid = (pl+pr)//2
if L<=mid: #查左边
update(2*p,pl,mid,L,R,d)
if R>mid: #查右边
update(2*p+1,mid+1,pr,L,R,d)
tree[p]=tree[2*p]+tree[2*p+1]
return
def addtag(p,pl,pr,d): # 给结点p打上标记同时更新tree[p]
tag[p]+=d
tree[p]+=d*(pr-pl+1)
def push_down(p,pl,pr):
if tag[p]>0: # 有tag标记,需要传递并清空
mid=(pl+pr)//2
addtag(2*p,pl,mid,tag[p]) # 传给左孩子
addtag(2*p+1,mid+1,pr,tag[p]) # 传给右孩子
tag[p]=0 # 清空当前tag标记
def query(p,pl,pr,L,R):
if L<=pl and pr <=R: # 说明当前区间完全包含在要查询的区间中
return tree[p]
# 没有完全包含
push_down(p,pl,pr) # 如果查询的是标记内部子区间就会有问题!!
mid = (pl+pr)//2
res=0
if L<=mid: #查左边
res +=query(2*p,pl,mid,L,R)
if R>mid: #查右边
res +=query(2*p+1,mid+1,pr,L,R)
return res
n,m = map(int,input().split())
a=[0]+list(map(int,input().split()))
tag=[0]*4*len(a)
tree=[0]*4*len(a)
# 建树
build(1,1,n)
for i in range(m):
w=list(map(int,input().split()))
if len(w)==3: # 区间查询,查询[L,R]区间和
q,L,R=w
print(query(1,1,n,L,R))
else: # 区间修改,把[L,R]的每个元素加上d
q,L,R,d=w
update(1,1,n,L,R,d)