线段树(Segment Tree)和树状数组
- 线段树的实现
- 链式:
- 数组实现
- 解题思路
- 树状数组
线段树是 二叉树结构 的衍生,用于高效解决区间查询和动态修改的问题,其中区间查询的时间复杂度为 O(logN),动态修改单个元素的时间复杂度为x O(logN)。
这棵二叉树的叶子节点是数组中的元素,非叶子节点就是索引区间(线段)的汇总信息.
线段树结构可以有多种变体及复杂的优化,我们这里只聚焦最核心的两个 API:
class SegmentTree {
// 构造函数,给定一个数组,初始化线段树,时间复杂度 O(N)
// merge 是一个函数,用于定义 query 方法的行为
// 通过修改这个函数,可以让 query 函数返回区间的元素和、最大值、最小值等
public SegmentTree(int[] nums, Function<Integer, Integer> merge) {}
// 查询闭区间 [i, j] 的元素和(也可能是最大最小值,取决于 merge 函数),时间复杂度 O(logN)
public int query(int i, int j) {}
// 更新 nums[i] = val,时间复杂度 O(logN)
public void update(int i, int val) {}
}
- suffixMin[i] 可以在 O(1) 时间内查询; nums[i…] 后缀的最小值线段树的 query 方法不仅可以查询后缀,还可以查询任意 [i, j] 区间,时间复杂度均为
O(logN)。 - 当底层 nums 数组中的任意元素变化时,需要重新计算 suffixMin 数组,时间复杂度为 O(N);而线段树的 update 方法可以在 O(logN) 时间内完成元素的修改。
- 线段树不仅仅支持计算区间的最小值,只要修改 merge 函数,就可以支持计算区间元素和、最大值、乘积等。
线段树的实现
线段树是一种二叉树,所以可以用二叉树的方式实现,包括链式实现和数组实现两种:
链式:
from typing import Callable
# 线段树节点
class SegmentNode:
# 该节点表示的区间范围 [l, r]
def __init__(self, merge_val: int, l: int, r: int):
# [l, r] 区间元素的聚合值(如区间和、区间最大值等)
self.l = l
self.r = r
self.merge_val = merge_val
self.left = None
self.right = None
class SegmentTree:
def __init__(self, nums: list, merger: Callable[[int, int], int]):
# 创建线段树
# 输入数组 nums 和一个聚合函数 merger,merger 用于计算区间的聚合值
self.merger = merger
self.root = self.build(nums, 0, len(nums) - 1)
# 定义:将 nums[l..r] 中的元素构建成线段树,返回根节点
def build(self, nums: list, l: int, r: int) -> SegmentNode:
# 区间内只有一个元素,直接返回
if l == r:
return SegmentNode(nums[l], l, r)
# 从中间切分,递归构建左右子树
mid = l + (r - l) // 2
left = self.build(nums, l, mid)
right = self.build(nums, mid + 1, r)
# 根据左右子树的聚合值,计算当前根节点的聚合值
node = SegmentNode(self.merger(left.merge_val, right.merge_val), l, r)
# 组装左右子树
node.left = left
node.right = right
return node
def update(self, index: int, value: int):
self._update(self.root, index, value)
def _update(self, node: SegmentNode, index: int, value: int):
if node.l == node.r:
# 找到了目标叶子节点,更新值
node.merge_val = value
return
mid = node.l + (node.r - node.l) // 2
if index <= mid:
# 若 index 较小,则去左子树更新
self._update(node.left, index, value)
else:
# 若 index 较大,则去右子树更新
self._update(node.right, index, value)
# 后序位置,左右子树已经更新完毕,更新当前节点的聚合值
node.merge_val = self.merger(node.left.merge_val, node.right.merge_val)
def query(self, qL: int, qR: int) -> int:
return self._query(self.root, qL, qR)
def _query(self, node: SegmentNode, qL: int, qR: int) -> int:
if qL > qR:
raise ValueError("Invalid query range")
if node.l == qL and node.r == qR:
# 命中了目标区间,直接返回
return node.merge_val
# 未直接命中区间,需要继续向下查找
mid = node.l + (node.r - node.l) // 2
if qR <= mid:
# node.l <= qL <= qR <= mid
# 目标区间完全在左子树中
return self._query(node.left, qL, qR)
elif qL > mid:
# mid < qL <= qR <= node.r
# 目标区间完全在右子树中
return self._query(node.right, qL, qR)
else:
# node.l <= qL <= mid < qR <= node.r
# 目标区间横跨左右子树
# 将查询区间拆分成 [qL, mid] 和 [mid + 1, qR] 两部分,分别向左右子树查询
# 最后将左右子树的查询结果合并
return self.merger(
self._query(node.left, qL, mid),
self._query(node.right, mid + 1, qR)
)
# Example usage
if __name__ == "__main__":
arr = [1, 3, 5, 7, 9]
# 示例,创建一棵求和线段树
st = SegmentTree(arr, lambda a, b: a + b)
print(st.query(1, 3)) # 3 + 5 + 7 = 15
st.update(2, 10)
print(st.query(1, 3)) # 3 + 10 + 7 = 20
数组实现
from typing import Callable
class ArraySegmentTree:
# 用数组存储线段树结构
def __init__(self, nums: list[int], merger: Callable[[int, int], int]):
# 元素个数
self.n = len(nums)
self.merger = merger
# 分配 4 倍数组长度的空间,存储线段树
self.tree = [0] * (4 * self.n)
self.build(nums, 0, self.n - 1, 0)
# 定义:对 nums[l..r] 区间的元素构建线段树,rootIndex 是根节点
def build(self, nums: list[int], l: int, r: int, rootIndex: int):
if l == r:
# 区间内只有一个元素,设置为叶子节点
self.tree[rootIndex] = nums[l]
return
# 从中间切分,递归构建左右子树
mid = l + (r - l) // 2
leftRootIndex = self.leftChild(rootIndex)
rightRootIndex = self.rightChild(rootIndex)
# 递归构建 nums[l..mid],根节点为 leftRootIndex
self.build(nums, l, mid, leftRootIndex)
# 递归构建 nums[mid+1..r],根节点为 rightRootIndex
self.build(nums, mid + 1, r, rightRootIndex)
# 后序位置,左右子树已经构建完毕,更新当前节点的聚合值
self.tree[rootIndex] = self.merger(self.tree[leftRootIndex], self.tree[rightRootIndex])
def update(self, index: int, value: int):
self._update(0, self.n - 1, 0, index, value)
# 当前节点为 rootIndex,对应的区间为 [l, r]
# 去子树更新 nums[index] 为 value
def _update(self, l: int, r: int, rootIndex: int, index: int, value: int):
if l == r:
# 找到了目标叶子节点,更新值
self.tree[rootIndex] = value
return
mid = l + (r - l) // 2
if index <= mid:
# 若 index 较小,则去左子树更新
self._update(l, mid, self.leftChild(rootIndex), index, value)
else:
# 若 index 较大,则去右子树更新
self._update(mid + 1, r, self.rightChild(rootIndex), index, value)
# 后序位置,左右子树已经更新完毕,更新当前节点的聚合值
self.tree[rootIndex] = self.merger(
self.tree[self.leftChild(rootIndex)],
self.tree[self.rightChild(rootIndex)]
)
def query(self, qL: int, qR: int) -> int:
if qL < 0 or qR >= self.n or qL > qR:
raise ValueError(f"Invalid range: [{qL}, {qR}]")
return self._query(0, self.n - 1, 0, qL, qR)
def _query(self, l: int, r: int, rootIndex: int, qL: int, qR: int) -> int:
if qL == l and r == qR:
# 命中了目标区间,直接返回
return self.tree[rootIndex]
mid = l + (r - l) // 2
leftRootIndex = self.leftChild(rootIndex)
rightRootIndex = self.rightChild(rootIndex)
if qR <= mid:
# node.l <= qL <= qR <= mid
# 目标区间完全在左子树中
return self._query(l, mid, leftRootIndex, qL, qR)
elif qL > mid:
# mid < qL <= qR <= node.r
# 目标区间完全在右子树中
return self._query(mid + 1, r, rightRootIndex, qL, qR)
else:
# node.l <= qL <= mid < qR <= node.r
# 目标区间横跨左右子树
# 将查询区间拆分成 [qL, mid] 和 [mid + 1, qR] 两部分,分别向左右子树查询
return self.merger(
self._query(l, mid, leftRootIndex, qL, mid),
self._query(mid + 1, r, rightRootIndex, mid + 1, qR)
)
def leftChild(self, pos: int) -> int:
return 2 * pos + 1
def rightChild(self, pos: int) -> int:
return 2 * pos + 2
if __name__ == "__main__":
arr = [1, 3, 5, 7, 9]
# 示例,创建一棵求和线段树
st = ArraySegmentTree(arr, lambda a, b: a + b)
print(st.query(1, 3)) # 3 + 5 + 7 = 15
st.update(2, 10)
print(st.query(1, 3)) # 3 + 10 + 7 = 20
307区域和检索
解题思路
针对不同的题目,我们有不同的方案可以选择(假设我们有一个数组):
1.数组不变,求区间和:「前缀和」、「树状数组」、「线段树」
2.多次修改某个数(单点),求区间和:「树状数组」、「线段树」
3.多次修改某个区间,输出最终结果:「差分」
4.多次修改某个区间,求区间和:「线段树」、「树状数组」(看修改区间范围大小)
5.多次将某个区间变成同一个数,求区间和:「线段树」、「树状数组」(看修改区间范围大小)
这样看来,「线段树」能解决的问题是最多的,那我们是不是无论什么情况都写「线段树」呢?
答案并不是,而且恰好相反,只有在我们遇到第 4 类问题,不得不写「线段树」的时候,我们才考虑线段树。
因为「线段树」代码很长,而且常数很大,实际表现不算很好。我们只有在不得不用的时候才考虑「线段树」。
总结一下,我们应该按这样的优先级进行考虑:
简单求区间和,用「前缀和」
多次将某个区间变成同一个数,用「线段树」
其他情况,用「树状数组」
作者:宫水三叶
链接:https://leetcode.cn/problems/range-sum-query-mutable/solutions/632515/guan-yu-ge-lei-qu-jian-he-wen-ti-ru-he-x-41hv/
来源:力扣(LeetCode)
著作权归作者所有。
树状数组
树状数组是用tree[i]维护从某个值开始到i的区间的元素和,下标从1开始
prefixsum[i]表示前i个数的前缀和,这样可以简单求[left,right]区间的元素和
sum[left,right] = prefixsum[right+1]-prefixsum[left]
prefixsum[i]被分解为几个小区间tree[i]的和,
#倒着分解,每次区间的长度为二进制的最低位bitlower 可以用`i&-i`计算
prefixsum[5] = [5,5]+[1,4] = tree[5]+tree[4]
prefixsum[11] = [11,11]+[9,10]+[1,8] = tree[11]+tree[10]+tree[8]
class NumArray:
def __init__(self, nums: List[int]):
n = len(nums)
self.nums = [0]*n
self.tree = [0]*(n+1) #初始化当成nums[i]每个元素从0更新到nums[i]
for i,x in enumerate(nums):
self.update(i,x)
def update(self, index: int, val: int) -> None:
delta = val-self.nums[index]
self.nums[index]=val
i = index+1
while i<len(self.tree):
self.tree[i]+=delta
i+=i&-i # 下一个包含nums[index]的数字
def sumRange(self, left: int, right: int) -> int:
return self.prefixsum(right+1)-self.prefixsum(left)
def prefixsum(self,i):
ans = 0
while i:
ans+=self.tree[i]
i-= i&-i #
return ans
# Your NumArray object will be instantiated and called as such:
# obj = NumArray(nums)
# obj.update(index,val)
# param_2 = obj.sumRange(left,right)