欢迎关注更多精彩
关注我,学习常用算法与数据结构,一题多解,降维打击。
本期话题:在树上查找2个结点的最近公共祖先
问题提出
最近公共祖先定义
最近公共祖先简称 LCA(Lowest Common Ancestor)。两个节点的最近公共祖先,就是这两个点的公共祖先里面,离根最远(深度最深)的那个。
问题
参考地址:https://www.luogu.com.cn/problem/P3379
给定一棵树,询问每两个结点的最近公共祖先,一般会询问多次。
朴素做法
- 利用dfs求出所有结点的深度和父亲结点。
- 查询时把深度大的结点往上移,直到两个结点深度一样。然后两个结点同时往上移,直到两结点相遇。
复杂度分析
第1步求深度和父亲结点,需要遍历所有结点,复杂度是O(n)。
第2步在极端情况下是O(n) , 在多次查询的情况下,效率很低。
空间换时间
试想一下我们给每1个结点分配1个空间来存储往上移n个位置到达的祖先结点。
当我们要查询两个公共祖先时,就可以使用二分查找的方法来加速。
以A, B为例,可以看到后面黄色部分是公共祖先,我们要找的是最左边的10号祖先。只要利用二分查找即可找到。
该方法可以把查询复杂度降低到log(n). 但同时空间复杂度是O(n^2)。
优化空间(倍增算法)
参考资料:https://oi-wiki.org//graph/lca/#%E5%80%8D%E5%A2%9E%E7%AE%97%E6%B3%95
上面的方法的问题是空间分配的太多了,而且仔细观察,空间是冗余的。
比如A往上1个的祖先分配的数组和A的数组是高度重合的,可以看出是有递归或继承关系的。而且我们每次都是取的数组的一半。
那么我们可以存储往上数2^n个的祖先。
即存储往上1个,2个,4个。。。的祖先分别是谁。
查询的时候,由于任意数字都可以用2进制进行组合而成,可以遍历到所有祖先。
具体算法可以类比二分算法。
代码模板
题目链接:https://www.luogu.com.cn/problem/P3379
#include<stdio.h>
#include<malloc.h>
#include<string.h>
#include<cmath>
#include<algorithm>
using namespace std;
const int M = 500000 + 10;
const int N = 500000 + 10;
const int bitL = 22;
int head[N];
int to[M * 2], nextEdge[M * 2];
int len;
int h[N];
int father[bitL][N];
void initPara(int n)
{
len = 0;
for (int i = 0; i < n; i++)
{
head[i] = -1;
}
}
void add(int a, int b)
{
to[len] = b;
nextEdge[len] = head[a];
head[a] = len++;
}
void dfs(int x, int fa)
{
if (fa == -1) h[x] = 0;
else {
h[x] = h[fa] + 1;
father[0][x] = fa;
// 利用倍增算法初始化father
for (int t = 1; t < bitL && (1<<t)<=h[x]; t++) {
father[t][x] = father[t-1][father[t - 1][x]];
}
}
int i;
for (i = head[x]; i != -1; i = nextEdge[i])
{
int j = to[i];
if (fa==j)continue;
dfs(j, x);
}
}
int lca(int a, int b) {
if (h[a] < h[b]) {
return lca(b, a);
}
// 先将两个结点跳到一样高度
int gap = h[a] - h[b];
for (int t = bitL-1; t>=0; t--) {
if (gap & (1 << t))a = father[t][a];
}
if (a == b)return a;
gap = h[a];
// 利用二分查找找到深度最低的且不一样的结点。
for (int t = bitL-1; t >= 0; t--) {
if (gap <=(1 << t))continue;
if (father[t][a] == father[t][b])continue;
a = father[t][a];
b = father[t][b];
gap -= 1 << t;
}
return father[0][a]; // 再往上1个既是公共祖先
}
void solve()
{
int t;
int n, m, s;
scanf("%d%d%d", &n, &m, &s);
s--;
initPara(n);
int a, b;
for (int i = 0; i < n - 1; ++i) {
scanf("%d%d", &a, &b);
a--, b--;
add(a, b);
add(b, a);
}
dfs(s, -1);
/*for (int i = 0; i < n; ++i) {
printf("%d: %d\n", i, h[i]);
}*/
while (m--) {
scanf("%d%d", &a, &b);
a--, b--;
printf("%d\n", 1+lca(a, b));
}
}
void test() {
int t;
int n=5000, m=500000, s=1;
//scanf("%d%d%d", &n, &m, &s);
s--;
initPara(n);
int a, b;
for (int i = 0; i < n - 1; ++i) {
a = i, b = i + 1;
add(a, b);
add(b, a);
}
dfs(s, -1);
//printf("%d\n", 1 + lca(10, 5000-1));
while (m--) {
a = (m+102)%n, b =( 3823+m*2)%n;
//printf("%d\n", m);
if(lca(a, b)!=min(a,b))
printf("%d %d %d\n", 1 + lca(a, b), a+1, b+1);
}
}
int main()
{
solve();
//test();
return 0;
}
/*
5 5 4
3 1
2 4
5 1
1 4
2 4
3 2
3 5
1 2
4 5
12 11 8
8 1
8 9
8 12
1 5
1 7
7 6
9 4
9 11
9 2
4 3
12 10
1 2
2 3
3 4
4 5
5 6
6 7
7 8
8 9
9 10
10 11
11 12
*/
练习一
链接:https://loj.ac/p/10135
注意点:需要对结点进行编号,无公共祖先时返回-1
#include <stdio.h>
#include <malloc.h>
#include <string.h>
#include <cmath>
#include <algorithm>
#include <map>
using namespace std;
const int M = 500000 + 10;
const int N = 500000 + 10;
map<int, int> num2Ind;
int indLen;
void initIndex() {
num2Ind.clear();
indLen = 0;
}
int getIndex(int n) {
if (num2Ind.count(n) == 0)
return -1;
return num2Ind[n];
}
int addIndex(int n) {
if (num2Ind.count(n) == 0)
num2Ind[n] = indLen++;
return num2Ind[n];
}
const int bitL = 22;
int head[N];
int to[M * 2], nextEdge[M * 2];
int len;
int h[N];
int father[bitL][N];
void initPara(int n) {
len = 0;
for (int i = 0; i < n; i++) {
head[i] = -1;
}
}
void add(int a, int b) {
to[len] = b;
nextEdge[len] = head[a];
head[a] = len++;
}
void dfs(int x, int fa) {
if (fa == -1)
h[x] = 0;
else {
h[x] = h[fa] + 1;
father[0][x] = fa;
for (int t = 1; t < bitL && (1 << t) <= h[x]; t++) {
father[t][x] = father[t - 1][father[t - 1][x]];
}
}
int i;
for (i = head[x]; i != -1; i = nextEdge[i]) {
int j = to[i];
if (fa == j)
continue;
dfs(j, x);
}
}
int lca(int a, int b) {
if (h[a] < h[b]) {
return lca(b, a);
}
int gap = h[a] - h[b];
for (int t = bitL - 1; t >= 0; t--) {
if (gap & (1 << t))
a = father[t][a];
}
if (a == b)
return a;
gap = h[a];
for (int t = bitL - 1; t >= 0; t--) {
if (gap <= (1 << t))
continue;
if (father[t][a] == father[t][b])
continue;
a = father[t][a];
b = father[t][b];
gap -= 1 << t;
}
return father[0][a];
}
void solve() {
int n, m;
int a, b, s;
scanf("%d", &n);
initPara(n);
for (int i = 0; i < n; ++i) {
scanf("%d%d", &a, &b);
if (b == -1) {
s = addIndex(a);
continue;
}
a = addIndex(a);
b = addIndex(b);
add(a, b);
add(b, a);
}
dfs(s, -1);
scanf("%d", &m);
/*for (int i = 0; i < n; ++i) {
printf("%d: %d\n", i, h[i]);
}*/
while (m--) {
scanf("%d%d", &a, &b);
a = getIndex(a);
b = getIndex(b);
if (a < 0 || b < 0 || a == b) {
puts("0");
continue;
}
int lcab = lca(a, b);
if (lcab == a)
puts("1");
else if (lcab == b)
puts("2");
else
puts("0");
}
}
int main() {
solve();
return 0;
}
/*
3
2 -1
1 2
3 1
2
1 2
2 3
2
1 -1
1 2
2
1 2
2 1
12
8 -1
8 1
8 9
8 12
1 5
1 7
7 6
9 4
9 11
9 2
4 3
12 10
11
1 2
2 3
3 4
4 5
5 6
6 7
7 8
8 9
9 10
10 11
11 12
10
234 -1
12 234
13 234
14 234
15 234
16 234
17 234
18 234
19 234
233 19
5
234 233
233 12
233 13
233 15
233 19
*/
练习二
链接:https://loj.ac/p/2610
算法思路:先用最小生成算法把所有大的边加入到树中。
利用倍增算法建立祖先关系,以及到祖先链路上的最小负载。
查询A,B最小负载为=min(A到公共祖先最小负载,B到公共祖先最小负载)。
具体实现分别从A,B查找最近公共祖先时记录链路上的最小值。
注意点:需要事先判断是否可达。题目中规定A!=B。
利用并查集点击前往判断是否在一棵树中。
#include<stdio.h>
#include<malloc.h>
#include<string.h>
#include<cmath>
#include<algorithm>
#include<vector>
using namespace std;
class UnionFindSet {
private:
vector<int> father; // 父结点定义,father[i]=i时,i为本集合的代表
vector<int> height; // 代表树高度,初始为1
int nodeNum; // 集合中的点数
public:
UnionFindSet(int n); // 初始化
bool Union(int x, int y); // 合并
int Find(int x);
bool UnionV2(int x, int y); // 合并
int FindV2(int x);
};
UnionFindSet::UnionFindSet(int n) : nodeNum(n + 1) {
father = vector<int>(nodeNum);
height = vector<int>(nodeNum);
for (int i = 0; i < nodeNum; ++i) father[i] = i, height[i] = 1; // 初始为自己
}
int UnionFindSet::Find(int x) {
while (father[x] != x) x = father[x];
return x;
}
bool UnionFindSet::Union(int x, int y) {
x = Find(x);
y = Find(y);
if (x == y)return false;
father[x] = y;
return true;
}
int UnionFindSet::FindV2(int x) {
int root = x; // 保存好路径上的头结点
while (father[root] != root) {
root = father[root];
}
/*
从头结点开始一直往根上遍历
把所有结点的father直接指向root。
*/
while (father[x] != x) {
// 一定要先保存好下一个结点,下一步是要对father[x]进行赋值
int temp = father[x];
father[x] = root;
x = temp;
}
return root;
}
/*
需要加入height[]属性,初始化为1.
*/
//合并结点
bool UnionFindSet::UnionV2(int x, int y) {
x = Find(x);
y = Find(y);
if (x == y) {
return false;
}
if (height[x] < height[y]) {
father[x] = y;
}
else if (height[x] > height[y]) {
father[y] = x;
}
else {
father[x] = y;
height[y]++;
}
return true;
}
const int M = 500000 + 10;
const int N = 500000 + 10;
const int bitL = 22;
int head[N];
int to[M * 2], nextEdge[M * 2], weight[M * 2];
int len;
int h[N];
int father[bitL][N];
int dis[bitL][N];
void initPara(int n)
{
len = 0;
for (int i = 0; i < n; i++)
{
head[i] = -1;
h[i] = -1;
}
}
void add(int a, int b, int w)
{
to[len] = b;
weight[len] = w;
nextEdge[len] = head[a];
head[a] = len++;
}
void dfs(int x, int fa, int w)
{
if (fa == -1) h[x] = 0;
else {
h[x] = h[fa] + 1;
father[0][x] = fa;
dis[0][x] = w;
for (int t = 1; t < bitL && (1 << t) <= h[x]; t++) {
father[t][x] = father[t - 1][father[t - 1][x]];
dis[t][x] = min(dis[t - 1][x], dis[t - 1][father[t - 1][x]]);
}
}
int i;
for (i = head[x]; i != -1; i = nextEdge[i])
{
int j = to[i];
if (fa == j)continue;
dfs(j, x, weight[i]);
}
}
int lca(int a, int b) {
if (h[a] < h[b]) {
return lca(b, a);
}
int gap = h[a] - h[b];
for (int t = bitL - 1; t >= 0; t--) {
if (gap & (1 << t))a = father[t][a];
}
if (a == b)return a;
gap = h[a];
for (int t = bitL - 1; t >= 0; t--) {
if (gap <= (1 << t))continue;
if (father[t][a] == father[t][b])continue;
a = father[t][a];
b = father[t][b];
gap -= 1 << t;
}
return father[0][a];
}
int optDis(int a, int b) {
if (h[a] < h[b]) {
return optDis(b, a);
}
int d = 1e6;
int gap = h[a] - h[b];
for (int t = bitL - 1; t >= 0; t--) {
if (gap & (1 << t)) {
d=min(d, dis[t][a]);
a = father[t][a];
}
}
if (a == b)return d;
gap = h[a];
for (int t = bitL - 1; t >= 0; t--) {
if (gap <= (1 << t))continue;
if (father[t][a] == father[t][b])continue;
d = min(d,dis[t][a]);
d = min(d,dis[t][b]);
a = father[t][a];
b = father[t][b];
gap -= 1 << t;
}
d = min(d, min(dis[0][a], dis[0][b]));
return d;
}
bool cmp(vector<int> &a, vector<int> &b) {
return a[2] > b[2];
}
void solve()
{
int n, m;
int a, b, w;
scanf("%d%d", &n, &m);
initPara(n);
auto us = UnionFindSet(n);
vector<vector<int>> eds;
for (int i = 0; i < m; ++i) {
scanf("%d%d%d", &a, &b, &w);
a--, b--;
eds.push_back({a,b,w});
}
sort(eds.begin(), eds.end(), cmp);
for (auto ed : eds) {
if (us.UnionV2(ed[0], ed[1])) {
add(ed[0], ed[1], ed[2]);
add(ed[1], ed[0], ed[2]);
}
}
for (int i = 0; i < n; ++i) {
if(h[i]<0)dfs(i, -1, 0);
}
scanf("%d", &m);
while (m--) {
scanf("%d%d", &a, &b);
a--, b--;
if (us.FindV2(a) != us.FindV2(b))puts("-1");
else printf("%d\n", optDis(a, b));
}
}
int main()
{
solve();
return 0;
}
/*
12 11
8 1 4
8 9 3
8 12 6
1 5 5
1 7 1
7 6 2
9 4 2
9 11 10
9 2 9
4 3 2
12 10 7
11
1 2
2 3
3 4
4 5
5 6
6 7
7 8
8 9
9 10
10 11
11 12
12 11
8 1 1
8 9 1
8 12 1
1 5 1
1 7 1
7 6 1
9 4 1
9 11 1
9 2 1
4 3 1
12 10 1
11
1 2
2 3
3 4
4 5
5 6
6 7
7 8
8 9
9 10
10 11
11 12
*/
练习三
https://loj.ac/p/10130
算法思路:
利用倍增算法建立祖先关系,以及到祖先链路上的距离。
查询A,B距离=A到公共祖先距离+B到公共祖先距离)。
具体实现与上一题类似。
#include<stdio.h>
#include<malloc.h>
#include<string.h>
#include<cmath>
#include<algorithm>
using namespace std;
const int M = 500000 + 10;
const int N = 500000 + 10;
const int bitL = 22;
int head[N];
int to[M * 2], nextEdge[M * 2], weight[M * 2];
int len;
int h[N];
int father[bitL][N];
int dis[bitL][N];
void initPara(int n)
{
len = 0;
for (int i = 0; i < n; i++)
{
head[i] = -1;
}
}
void add(int a, int b, int w)
{
to[len] = b;
weight[len] = w;
nextEdge[len] = head[a];
head[a] = len++;
}
void dfs(int x, int fa, int w)
{
if (fa == -1) h[x] = 0;
else {
h[x] = h[fa] + 1;
father[0][x] = fa;
dis[0][x] = w;
for (int t = 1; t < bitL && (1 << t) <= h[x]; t++) {
father[t][x] = father[t - 1][father[t - 1][x]];
dis[t][x] = dis[t - 1][x] + dis[t - 1][father[t - 1][x]];
}
}
int i;
for (i = head[x]; i != -1; i = nextEdge[i])
{
int j = to[i];
if (fa == j)continue;
dfs(j, x, weight[i]);
}
}
int lca(int a, int b) {
if (h[a] < h[b]) {
return lca(b, a);
}
int gap = h[a] - h[b];
for (int t = bitL - 1; t >= 0; t--) {
if (gap & (1 << t))a = father[t][a];
}
if (a == b)return a;
gap = h[a];
for (int t = bitL - 1; t >= 0; t--) {
if (gap <= (1 << t))continue;
if (father[t][a] == father[t][b])continue;
a = father[t][a];
b = father[t][b];
gap -= 1 << t;
}
return father[0][a];
}
int optDis(int a, int b) {
if (h[a] < h[b]) {
return optDis(b, a);
}
int d = 0;
int gap = h[a] - h[b];
for (int t = bitL - 1; t >= 0; t--) {
if (gap & (1 << t)) {
d += dis[t][a];
a = father[t][a];
}
}
if (a == b)return d;
gap = h[a];
for (int t = bitL - 1; t >= 0; t--) {
if (gap <= (1 << t))continue;
if (father[t][a] == father[t][b])continue;
d += dis[t][a];
d += dis[t][b];
a = father[t][a];
b = father[t][b];
gap -= 1 << t;
}
d += dis[0][a] + dis[0][b];
return d;
}
void solve()
{
int n, m;
int a, b;
scanf("%d", &n);
initPara(n);
for (int i = 0; i < n - 1; ++i) {
scanf("%d%d", &a, &b);
a--, b--;
add(a, b, 1);
add(b, a, 1);
}
dfs(0, -1, 0);
scanf("%d", &m);
while (m--) {
scanf("%d%d", &a, &b);
a--, b--;
printf("%d\n", optDis(a, b));
}
}
int main()
{
solve();
return 0;
}
/*
6
1 2
1 3
2 4
2 5
3 6
2
2 6
5 6
12
8 1
8 9
8 12
1 5
1 7
7 6
9 4
9 11
9 2
4 3
12 10
11
1 2
2 3
3 4
4 5
5 6
6 7
7 8
8 9
9 10
10 11
11 12
*/
练习四
https://acm.hdu.edu.cn/showproblem.php?pid=2586
与练习三类似
#include<stdio.h>
#include<malloc.h>
#include<string.h>
#include<cmath>
#include<algorithm>
using namespace std;
const int M = 500000 + 10;
const int N = 500000 + 10;
const int bitL = 22;
int head[N];
int to[M * 2], nextEdge[M * 2],weight[M*2];
int len;
int h[N];
int father[bitL][N];
int dis[bitL][N];
void initPara(int n)
{
len = 0;
for (int i = 0; i < n; i++)
{
head[i] = -1;
}
}
void add(int a, int b, int w)
{
to[len] = b;
weight[len] = w;
nextEdge[len] = head[a];
head[a] = len++;
}
void dfs(int x, int fa, int w)
{
if (fa == -1) h[x] = 0;
else {
h[x] = h[fa] + 1;
father[0][x] = fa;
dis[0][x] = w;
for (int t = 1; t < bitL && (1<<t)<=h[x]; t++) {
father[t][x] = father[t-1][father[t - 1][x]];
dis[t][x] = dis[t-1][x]+ dis[t - 1][father[t - 1][x]];
}
}
int i;
for (i = head[x]; i != -1; i = nextEdge[i])
{
int j = to[i];
if (fa==j)continue;
dfs(j, x, weight[i]);
}
}
int lca(int a, int b) {
if (h[a] < h[b]) {
return lca(b, a);
}
int gap = h[a] - h[b];
for (int t = bitL-1; t>=0; t--) {
if (gap & (1 << t))a = father[t][a];
}
if (a == b)return a;
gap = h[a];
for (int t = bitL-1; t >= 0; t--) {
if (gap <=(1 << t))continue;
if (father[t][a] == father[t][b])continue;
a = father[t][a];
b = father[t][b];
gap -= 1 << t;
}
return father[0][a];
}
int optDis(int a, int b) {
if (h[a] < h[b]) {
return optDis(b, a);
}
int d = 0;
int gap = h[a] - h[b];
for (int t = bitL - 1; t >= 0; t--) {
if (gap & (1 << t)) {
d += dis[t][a];
a = father[t][a];
}
}
if (a == b)return d;
gap = h[a];
for (int t = bitL - 1; t >= 0; t--) {
if (gap <= (1 << t))continue;
if (father[t][a] == father[t][b])continue;
d += dis[t][a];
d += dis[t][b];
a = father[t][a];
b = father[t][b];
gap -= 1 << t;
}
d += dis[0][a] + dis[0][b];
return d;
}
void solve()
{
int t;
int n, m;
int a, b, w;
scanf("%d", &t);
while (t--) {
scanf("%d%d", &n, &m);
initPara(n);
for (int i = 0; i < n - 1; ++i) {
scanf("%d%d%d", &a, &b, &w);
a--, b--;
add(a, b,w);
add(b, a,w);
}
dfs(0, -1, 0);
/*for (int i = 0; i < n; ++i) {
printf("%d: %d\n", i, h[i]);
}*/
while (m--) {
scanf("%d%d", &a, &b);
a--, b--;
printf("%d\n", optDis(a,b));
}
}
}
void test() {
int t;
int n = 5000, m = 500000;
//scanf("%d%d%d", &n, &m, &s);
initPara(n);
int a, b;
for (int i = 0; i < n - 1; ++i) {
a = i, b = i + 1;
add(a, b,1);
add(b, a,1);
}
dfs(0, -1,0);
//printf("%d\n", 1 + lca(10, 5000-1));
while (m--) {
a = (m+102)%n, b =( 3823+m*2)%n;
//printf("%d\n", m);
if(lca(a, b)!=min(a,b))
printf("%d %d %d\n", 1 + lca(a, b), a+1, b+1);
}
}
int main()
{
solve();
//test();
return 0;
}
/*
2
3 2
1 2 10
3 1 15
1 2
2 3
2 2
1 2 100
1 2
2 1
1
12 11
8 1 4
8 9 3
8 12 6
1 5 5
1 7 1
7 6 2
9 4 2
9 11 10
9 2 9
4 3 2
12 10 7
1 2
2 3
3 4
4 5
5 6
6 7
7 8
8 9
9 10
10 11
11 12
1
12 11
8 1 1
8 9 1
8 12 1
1 5 1
1 7 1
7 6 1
9 4 1
9 11 1
9 2 1
4 3 1
12 10 1
1 2
2 3
3 4
4 5
5 6
6 7
7 8
8 9
9 10
10 11
11 12
*/
本人码农,希望通过自己的分享,让大家更容易学懂计算机知识。创作不易,帮忙点击公众号的链接。