目录
一、线段树算法的概念
二、为什么需要线段树
三、线段树算法的实现
(1)建树
(2)查询
(3)修改
(4)综合代码,求区间和
(5)综合代码,求区间最大值
四、Lazy标记
一、线段树算法的概念
线段树(Segment Tree)是一种基于二分思想的数据结构,常常用于处理区间查询和区间修改。线段树的常用操作包括建树、查询、修改。
线段树的建树过程可以使用递归实现,也可以使用非递归实现(通常使用栈来实现)。
线段树的查询和修改基本都是从根节点开始,往下遍历到叶子节点或者与查询区间(或修改区间)不相交的节点为止。线段树相关问题经常需要使用懒惰标记(Lazy Tag)来优化。
线段树常用于以下场景:区间最值查询、区间求和、区间修改等
二、为什么需要线段树
考虑这样两个场景:
- 对于一个长度为 n 的数组,现在给定 l,r 让你求 l 到 r 所有元素的和,有多个这样的询问.
- 对于一个长度为 n 的数组,现在对数组的第 k 个元素进行修改后,给定 l,r 让你求 l 到 r 所有元素的和,有多个这样的询问.
大家看到第一种情况的时候,这不就是前缀和,是的.第二种情况呢,前缀和还能不能用,显然每次修改之后,前缀和就不能使用了,所以又退化为 O(n) 的时间复杂度了.
此时我们就需要用到我们的线段树了.
对于线段树中的每一个非叶子节点[a,b],它的左儿子表示的区间为[a,(a+b)/2],右儿子表示的区间为[(a+b)/2+1,b]。最后的子节点数目为 N,即整个线段区间的长度。
我们看一下 1-10 的线段树是如何存储的.
三、线段树算法的实现
(1)建树
void build(int p, int l, int r) { t[p].l = l, t[p].r = r; // 节点p代表区间[l,r] if (l == r) { t[p].dat = a[l]; return; } // 叶节点 int mid = (l + r) / 2; // 折半 build(p*2, l, mid); // 左子节点[l,mid],编号p*2 build(p*2+1, mid+1, r); // 右子节点[mid+1,r],编号p*2+1 // t[p].dat = max(t[p*2].dat, t[p*2+1].dat); // 从下往上传递信息 t[p].dat = t[p*2].dat+t[p*2+1].dat; // 从下往上传递信息 }
(2)查询
int ask(int p, int l, int r) { if (l <= t[p].l && r >= t[p].r) return t[p].dat; // 完全包含,直接返回 int mid = (t[p].l + t[p].r) / 2; int val = 0; if (l <= mid) val = val+ ask(p*2, l, r); // 左子节点有重叠 if (r > mid) val = val+ask(p*2+1, l, r); // 右子节点有重叠 return val; }
(3)修改
void change(int p, int x, int v) { if (t[p].l == t[p].r) { t[p].dat = v; return; } // 找到叶节点 int mid = (t[p].l + t[p].r) / 2; if (x <= mid) change(p*2, x, v); // x属于左半区间 else change(p*2+1, x, v); // x属于右半区间 // t[p].dat = max(t[p*2].dat, t[p*2+1].dat); // 从下往上更新信息 t[p].dat = t[p*2].dat+t[p*2+1].dat; }
(4)综合代码,求区间和
#include<bits/stdc++.h> using namespace std; const int SIZE=11; int a[11]={0,1,2,3,4,5,6,7,8,9,10}; //原始数据 struct SegmentTree { int l, r; int dat; } t[SIZE * 4]; // struct数组存储线段树 void build(int p, int l, int r) { t[p].l = l, t[p].r = r; // 节点p代表区间[l,r] if (l == r) { t[p].dat = a[l]; return; } // 叶节点 int mid = (l + r) / 2; // 折半 build(p*2, l, mid); // 左子节点[l,mid],编号p*2 build(p*2+1, mid+1, r); // 右子节点[mid+1,r],编号p*2+1 // t[p].dat = max(t[p*2].dat, t[p*2+1].dat); // 从下往上传递信息 t[p].dat = t[p*2].dat+t[p*2+1].dat; // 从下往上传递信息 } void change(int p, int x, int v) { if (t[p].l == t[p].r) { t[p].dat = v; return; } // 找到叶节点 int mid = (t[p].l + t[p].r) / 2; if (x <= mid) change(p*2, x, v); // x属于左半区间 else change(p*2+1, x, v); // x属于右半区间 // t[p].dat = max(t[p*2].dat, t[p*2+1].dat); // 从下往上更新信息 t[p].dat = t[p*2].dat+t[p*2+1].dat; } int ask(int p, int l, int r) { if (l <= t[p].l && r >= t[p].r) return t[p].dat; // 完全包含,直接返回 int mid = (t[p].l + t[p].r) / 2; int val = 0; if (l <= mid) val = val+ ask(p*2, l, r); // 左子节点有重叠 if (r > mid) val = val+ask(p*2+1, l, r); // 右子节点有重叠 return val; } int main() { //建树从根节点一点一点往下建立,所以第一个参数就是1号编号 build(1,1,10); //查询区间[4,7]的和,第一个参数是1的原因是查询要从根节点开始递归 int ans=ask(1,4,7); cout<<ans; //修改位置4的值变为25,第一个参数是1的原因是修改也要从根节点开始一步一步往下进行修改 change(1,4,25); ans=ask(1,4,7); cout<<ans; return 0; }
(5)综合代码,求区间最大值
#include<bits/stdc++.h> using namespace std; const int SIZE=11; int a[11]={0,1,2,3,4,5,6,7,8,9,10}; //原始数据 struct SegmentTree { int l, r; int dat; } t[SIZE * 4]; // struct数组存储线段树 void build(int p, int l, int r) { t[p].l = l, t[p].r = r; // 节点p代表区间[l,r] if (l == r) { t[p].dat = a[l]; return; } // 叶节点 int mid = (l + r) / 2; // 折半 build(p*2, l, mid); // 左子节点[l,mid],编号p*2 build(p*2+1, mid+1, r); // 右子节点[mid+1,r],编号p*2+1 // t[p].dat = max(t[p*2].dat, t[p*2+1].dat); // 从下往上传递信息 t[p].dat = max( t[p * 2].dat , t[p * 2 + 1].dat); // 从下往上传递信息 } void change(int p, int x, int v) { if (t[p].l == t[p].r) { t[p].dat = v; return; } // 找到叶节点 int mid = (t[p].l + t[p].r) / 2; if (x <= mid) change(p*2, x, v); // x属于左半区间 else change(p*2+1, x, v); // x属于右半区间 // t[p].dat = max(t[p*2].dat, t[p*2+1].dat); // 从下往上更新信息 t[p].dat = max( t[p * 2].dat , t[p * 2 + 1].dat); } int ask(int p, int l, int r) { if (l <= t[p].l && r >= t[p].r) return t[p].dat; // 完全包含,直接返回 int mid = (t[p].l + t[p].r) / 2; int val = 0; if (l <= mid) val = max(val, ask(p * 2, l, r)); // 左子节点有重叠 if (r > mid) val = max(val, ask(p * 2 + 1, l, r)); // 右子节点有重叠 return val; } int main() { build(1,1,10); int ans=ask(1,4,7); cout<<ans; change(1,4,25); ans=ask(1,4,7); cout<<ans; return 0; }
四、Lazy标记
这种类型的题目,一般都是这样问的:如果每次是对一个区间进行修改,比如让 l,r 区间内的每个值都加 30.然后求和。
如果我们换成对于点的修改,那么时间复杂就太高了.那我们怎么办呢?
我们可以使用 Lazy 标记的方式,进行处理,什么是 Lazy 标记?
若在一次修改操作中发现节点 p 所代表的区间 [pl,pr] 被修改区间 [l,r] 完全覆盖,并且随后的查询操作没有利用到范围 [l,r] 的子区间作为候选答案,那么对节点 p 及其子树进行的更新操作将是没有实际效果的。此情况下,我们需要考虑优化算法,避免对整棵子树进行无意义的更新。
在执行修改指令时,如果发现存在 l < pl < pr < r 的情况,可以立即返回,并在回溯之前向节点 p 添加一个 Lazy 标记,用于表示 '该节点曾被修改,但其子节点尚未被更新'。
在后续的指令中,若需要向下递归至节点 p,应检查节点 p 是否带有标记。如果存在标记,应根据标记信息更新节点 p 的两个子节点,并为这两个子节点添加标记。然后清除节点 p 的标记。
除了在修改指令中直接划分成的 O(logN) 个节点外,对任意节点的修改都延迟到 '在后续操作中递归进入它的父节点时' 再执行。这样一来,每条查询或修改指令的时间复杂度都降低到了 O(logN)。我们将这些标记称为 '延迟标记',它们提供了线段树中从上往下传递信息的方式。通过延迟标记的设计,我们能够更加高效地处理线段树操作。
这种 '延迟' 的思想是设计算法与解决问题时的一个重要思路,它充分利用了操作的特性,避免了不必要的计算,并提升了算法的效率。延迟标记的应用为线段树操作提供了一种优化策略,使得算法的时间复杂度得以降低。
那我们该如何设计呢.
#include <bits/stdc++.h> using namespace std; const int SIZE = 11; int a[11] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; // 原始数据 struct SegmentTree { int l, r; long long sum, add; } tree[SIZE * 4]; // struct数组存储线段树 void build(int p, int l, int r) { tree[p].l = l, tree[p].r = r; if (l == r) { tree[p].sum = a[l]; return; } int mid = (l + r) / 2; build(p * 2, l, mid); // 构建左子树 build(p * 2 + 1, mid + 1, r); // 构建右子树 tree[p].sum = tree[p * 2].sum + tree[p * 2 + 1].sum; // 更新节点的区间和 } void spread(int p) { // 下传延迟标记 if (tree[p].add) { tree[p * 2].sum += tree[p].add * (tree[p * 2].r - tree[p * 2].l + 1); tree[p * 2 + 1].sum += tree[p].add * (tree[p * 2 + 1].r - tree[p * 2 + 1].l + 1); tree[p * 2].add += tree[p].add; // 左子树打延迟标记 tree[p * 2 + 1].add += tree[p].add; // 右子树打延迟标记 tree[p].add = 0; // 清除延迟标记 } } void change(int p, int l, int r, int d) { if (l <= tree[p].l && r >= tree[p].r) { // 完全覆盖节点的区间 tree[p].sum = (long long)d * (tree[p].r - tree[p].l + 1); // 更新节点的区间和 tree[p].add += d; // 打延迟标记 return; } spread(p); int mid = (tree[p].l + tree[p].r) / 2; if (l <= mid) change(p * 2, l, r, d); // 修改左子树 if (r > mid) change(p * 2 + 1, l, r, d); // 修改右子树 tree[p].sum = tree[p * 2].sum + tree[p * 2 + 1].sum; // 更新节点的区间和 } long long ask(int p, int l, int r) { if (l <= tree[p].l && r >= tree[p].r) return tree[p].sum; spread(p); int mid = (tree[p].l + tree[p].r) / 2; long long val = 0; if (l <= mid) val += ask(p * 2, l, r); // 查询左子树 if (r > mid) val += ask(p * 2 + 1, l, r); // 查询右子树 return val; } int main() { build(1, 1, 10); int ans = ask(1, 4, 7); cout << ans << endl; change(1, 4, 5, 10); ans = ask(1, 4, 7); cout << ans << endl; return 0; }