想要精通算法和SQL的成长之路 - 最长递增子序列 II(线段树的运用)
- 前言
- 一. 最长递增子序列 II
- 1.1 向下递推
- 1.2 向上递推
- 1.3 更新操作
- 1.4 查询操作
- 1.5 完整代码:
前言
想要精通算法和SQL的成长之路 - 系列导航
一. 最长递增子序列 II
原题链接
在做这个题目之前,先看一下:数据结构 - 线段树的运用 。
在线段树的基础上,思路如下:
- 首先,题目要求了子序列中,相邻的元素差不能超过
k
值。我们假设线段树的val
值,存储的就是最长递增子序列的长度。 - 我们定义
query
函数的返回就是范围区间内的最长递增子序列长度。
那么伪代码就是:
public int lengthOfLIS(int[] nums, int k) {
int ans = 0;
for (int i = 0; i < nums.length; i++) {
int tmp = query(nums[i]);
ans = Math.max(ans, tmp);
}
return ans;
}
但是有一个问题:假设我们以num[i]
作为最后一个元素,但是我并不知道它的前一个元素是谁。那咋办?
结合线段树的一个区间求值性质,我们只要求得区间 [num[i] - k, num[i] - 1]
之间的最长子序列长度,再加上1(当前子序列的最后一个元素num[i]
),那么就可以求得以num[i]
为结尾的最长子序列长度了。
同时我们还要更新各个子区间对应的最长长度,即伪代码:
for (int i = 0; i < nums.length; i++) {
int tmp = query(nums[i]);
update(tmp)
ans = Math.max(ans, tmp);
}
1.1 向下递推
我们做更新操作的时候,求得不再是 数据结构 - 线段树的运用 里面的区间和,而是最大值。因此我们不能在原本值的基础上做加减法运算。而是做覆盖运算。
class Node {
Node left, right;
int val, add;
}
private void pushDown(Node node) {
if (node.left == null) {
node.left = new Node();
}
if (node.right == null) {
node.right = new Node();
}
if (node.add == 0) {
return;
}
node.left.val = node.add; // 替换
node.right.val = node.add; // 替换
node.left.add = node.add; // 替换
node.right.add = node.add; // 替换
node.add = 0;
}
1.2 向上递推
求以当前节点作为最长子序列的最后一个元素时的序列长度时,我们可以拿到:
- 左子序列的最长递增长度。
- 右子序列的最长递增长度。
两者取最大,那么代码就是:
private void pushUp(Node node) {
node.val = Math.max(node.left.val, node.right.val);
}
1.3 更新操作
public void update(Node node, int start, int end, int left, int right, int val) {
// 如果线段树的区间完全在查询区间内,那么直接更新当前节点的 val 值即可
if (start >= left && end <= right) {
// 覆盖旧值
node.val = val;
// 覆盖需要传递的节点值
node.add = val;
return;
}
// 如果不在查询区间内,那么我们需要递归更新左右子树
int mid = (start + end) >> 1;
// 向下传递标记
pushDown(node);
if (left <= mid) {
update(node.left, start, mid, left, right, val);
}
// [mid + 1, end] 和 [l, r] 可能有交集,遍历右孩子区间
if (right > mid) {
update(node.right, mid + 1, end, left, right, val);
}
// 计算当前节点的val值
pushUp(node);
}
1.4 查询操作
public int query(Node node, int start, int end, int left, int right) {
// 若当前区间完全在查询区间内,直接返回当前区间的最值
if (left <= start && end <= right) {
return node.val;
}
// 把当前区间 [start, end] 均分得到左右孩子的区间范围
int mid = (start + end) >> 1, ans = 0;
// 下推标记
pushDown(node);
// [start, mid] 和 [l, r] 可能有交集,遍历左孩子区间
if (left <= mid) {
ans = query(node.left, start, mid, left, right);
}
// [mid + 1, end] 和 [l, r] 可能有交集,遍历右孩子区间
if (right > mid) {
ans = Math.max(ans, query(node.right, mid + 1, end, left, right));
}
return ans;
}
1.5 完整代码:
有个问题就是:我们在遍历数组的每个元素num[i]的时候,我们的线段树区间应该设置为多少?
因为我们是以每个元素的 [num[i] - k, num[i] - 1]
区间来做计算的,因此线段树的范围和num[i]
的范围有关系。
题目有个提示:
那么确定好了线段树的区间范围,我们可以编写代码如下:
class Solution {
public int lengthOfLIS(int[] nums, int k) {
int ans = 0;
Node root = new Node();
for (int i = 0; i < nums.length; i++) {
// 查询区间 [nums[i] - k, nums[i] - 1] 区间范围内的,以每个元素为末尾元素时的最长递增子序列长度。
int cnt = query(root, 0, N, Math.max(0, nums[i] - k), nums[i] - 1) + 1;
// 更新,注意这里是覆盖更新,对应的模版中覆盖更新不需要累加,已在下方代码中标注
update(root, 0, N, nums[i], nums[i], cnt);
ans = Math.max(ans, cnt);
}
return ans;
}
class Node {
Node left, right;
int val, add;
}
private int N = (int) 1e5;
private Node root = new Node();
public void update(Node node, int start, int end, int left, int right, int val) {
// 如果线段树的区间完全在查询区间内,那么直接更新当前节点的 val 值即可
if (start >= left && end <= right) {
// 覆盖旧值
node.val = val;
// 覆盖需要传递的节点值
node.add = val;
return;
}
// 如果不在查询区间内,那么我们需要递归更新左右子树
int mid = (start + end) >> 1;
// 向下传递标记
pushDown(node);
if (left <= mid) {
update(node.left, start, mid, left, right, val);
}
// [mid + 1, end] 和 [l, r] 可能有交集,遍历右孩子区间
if (right > mid) {
update(node.right, mid + 1, end, left, right, val);
}
// 计算当前节点的val值
pushUp(node);
}
public int query(Node node, int start, int end, int left, int right) {
// 若当前区间完全在查询区间内,直接返回当前区间的最值
if (left <= start && end <= right) {
return node.val;
}
// 把当前区间 [start, end] 均分得到左右孩子的区间范围
int mid = (start + end) >> 1, ans = 0;
// 下推标记
pushDown(node);
// [start, mid] 和 [l, r] 可能有交集,遍历左孩子区间
if (left <= mid) {
ans = query(node.left, start, mid, left, right);
}
// [mid + 1, end] 和 [l, r] 可能有交集,遍历右孩子区间
if (right > mid) {
ans = Math.max(ans, query(node.right, mid + 1, end, left, right));
}
return ans;
}
private void pushUp(Node node) {
node.val = Math.max(node.left.val, node.right.val);
}
private void pushDown(Node node) {
if (node.left == null) {
node.left = new Node();
}
if (node.right == null) {
node.right = new Node();
}
if (node.add == 0) {
return;
}
node.left.add = node.add; // 不需要累加
node.right.add = node.add; // 不需要累加
node.left.val = node.add; // 不需要累加
node.right.val = node.add; // 不需要累加
node.add = 0;
}
}