《算法竞赛·快冲300题》将于2024年出版,是《算法竞赛》的辅助练习册。
所有题目放在自建的OJ New Online Judge。
用C/C++、Java、Python三种语言给出代码,以中低档题为主,适合入门、进阶。
文章目录
- 题目描述
- 题解
- C++代码
- Java代码
- Python代码
“ 树与排列” ,链接: http://oj.ecustacm.cn/problem.php?id=1834
题目描述
【题目描述】 给你一棵树和它的顶点的排列。
可以证明:对于任何树,任何一对源点/终点,存在排列第一个节点是源点,最后一个节点是终点,排列的相邻节点之间的距离小于或等于3。
你的任务是为该性质编写一个验证程序。
给定这样一个排列和树,验证排列中相邻节点之间的距离是否小于或等于3。
【输入格式】 第一行为正整数T,表示存在T组测试数据,T≤50000。
对于每组测试数据,第一行输入n,表示树的节点数量,节点编号为1-n,2≤n≤100000。
接下来n-1行,每行两个数字a和b,表示节点a和b之间存在边。
接下来n行,每行一个数字p,表示给定的排列。
输入保证n总和不超过100000。
【输出格式】 对于每组测试数据,表示满足题目条件输出1,否则输出0。
【输入样例】
2
5
1 2
2 3
3 4
4 5
1
3
2
5
4
5
1 2
2 3
3 4
4 5
1
5
2
3
4
【输出样例】
1
0
题解
本题要求检查排列中相邻两点的距离是否小于等于3,以下图的树为例。
设给定一个排列{1 4 2 6 3 7 5 8 9},1-4距离为2,4-2距离为1,2-6距离为3,等等。
本题显然是最近公共祖先(LCA)问题。例如求2-6之间的距离,先求得它们的公共祖先是1,那么2-6的距离等于2-1的距离加上1-6的距离。
如果用标准的LCA算法,例如倍增法或Tarjan,对两个点x、y求一次LCA(x,y),计算复杂度为O(logn)。共T个测试,总复杂度为O(Tlogn),能通过测试。
不过,本题的要求比较简单,只要判断两个点x、y之间的距离dis(x,y)≤3,并不用算x、y之间的距离,所以计算量很小,不需要用标准的LCA算法,用简单的LCA算法即可(《算法竞赛》清华大学出版社,罗勇军,郭卫斌著。234页,“4.8 LCA”)。
在进行以下步骤之前,先求出所有点在树上的深度depth,例如depth[2]=2,depth[6]=3。在树上做一次DFS即可求出所有点的深度,计算复杂度为O(n)。
下面2个步骤可以计算出di(x,y)的距离是否大于3。
(1)从x和y中较深的点往上走,直到和另一个点等高。以(2, 6)为例,6更深,从6出发走到点3的位置停下,此时和点2等高。每往上走一步,dis(x,y)加1,例如6走到点3的位置,得dis(2,6)=1。也可以理解为x、y之间的总距离减少了1。到x、y相遇时,减少的总数就是距离dis(x, y)。在x或y往上走的过程中,如果超过了3步还没有相遇,说明dis(x, y)大于3,不符合要求,停止。
(2)经过(1)之后,x,y等高,让x 和y 同步向上走。每走一步就判断是否相遇,如果相遇就停止。同时累加dis(x, y),如果大于3,也停止。
以上操作的计算量很小。每次检查x、y的距离是否大于3,只需让x或y一共走3次。共T个测试,总计算量只有3T。
【重点】 。
C++代码
#include<bits/stdc++.h>
using namespace std;
const int N = 1e5 + 10;
int n;
vector<int>G[N]; //邻接表,存树
int a[N]; //记录排列
int depth[N]; //depth[i]:节点i的深度
int pre[N]; //pre[i]:节点i的父节点
void dfs(int u, int fa, int d){ //计算每个点的深度,结果记录在depth[]中
pre[u] = fa;
depth[u] = d;
for(auto v : G[u])
if(v != fa)
dfs(v, u, d + 1);
}
int main(){
int T; scanf("%d", &T);
while(T--){
scanf("%d", &n);
for(int i = 1; i <= n; i++) G[i].clear();
for(int i = 1; i < n; i++) { //用邻接表存树
int u, v; scanf("%d%d", &u, &v);
G[u].push_back(v);
G[v].push_back(u);
}
for(int i = 1; i <= n; i++) scanf("%d", &a[i]); //读排列
dfs(1, -1, 0); //DFS出所有点的深度
int ok = 1;
for(int i = 1; i <= n-1; i++) {
int x = a[i], y = a[i+1], dis = 0; //检查排列中相邻2个点
while(dis <= 3 && depth[x] > depth[y]) //x比y深,让x往上走,直到和y深度相同
x = pre[x], dis++; //x每往上走一步,dis(x,y)加1
while(dis <= 3 && depth[y] > depth[x]) //y比x深,让y往上走,直到和x深度相同
y = pre[y], dis++;
while(dis <= 3 && depth[x] && x != y)
//经过前面2个while,x、y已经走到同一层。如果它们不在同一个位置,那么就在2个子树的相同深度上
x = pre[x], y = pre[y], dis += 2; //x、y同时往上走,dis(x,y)加2
if(dis > 3) {ok = 0; break;} //不符合要求,停止
}
printf("%d\n",ok);
}
return 0;
}
Java代码
import java.util.*;
public class Main {
static final int N = 100010;
static int n;
static List<Integer>[] G = new ArrayList[N];
static int[] a = new int[N];
static int[] depth = new int[N];
static int[] pre = new int[N];
public static void main(String[] args) {
Scanner scanner = new Scanner(System.in);
int T = scanner.nextInt();
while (T-- > 0) {
n = scanner.nextInt();
for (int i = 1; i <= n; i++) G[i] = new ArrayList<Integer>();
for (int i = 1; i < n; i++) {
int u = scanner.nextInt(), v = scanner.nextInt();
G[u].add(v);
G[v].add(u);
}
for (int i = 1; i <= n; i++) a[i] = scanner.nextInt();
dfs(1, -1, 0);
int ok = 1;
for (int i = 1; i <= n - 1; i++) {
int x = a[i], y = a[i + 1], dis = 0;
while (dis <= 3 && depth[x] > depth[y]) { x = pre[x]; dis++;}
while (dis <= 3 && depth[y] > depth[x]) { y = pre[y]; dis++;}
while (dis <= 3 && depth[x] > 0 && x != y) {
x = pre[x];
y = pre[y];
dis += 2;
}
if (dis > 3) { ok = 0; break; }
}
System.out.println(ok);
}
scanner.close();
}
static void dfs(int u, int fa, int d) {
pre[u] = fa;
depth[u] = d;
for (int v : G[u])
if (v != fa)
dfs(v, u, d + 1);
}
}
Python代码
import sys
sys.setrecursionlimit(1000000)
N = 100010
G = [[] for _ in range(N)]
a = [0] * N
depth = [0] * N
pre = [0] * N
def dfs(u, fa, d):
pre[u] = fa
depth[u] = d
for v in G[u]:
if v != fa:
dfs(v, u, d + 1)
T = int(input())
for _ in range(T):
n = int(input())
for i in range(1, n + 1): G[i].clear()
for i in range(1, n):
u, v = map(int, input().split())
G[u].append(v)
G[v].append(u)
#a[1:n + 1] = map(int, input().split())
for i in range(1, n + 1): a[i]=int(input())
dfs(1, -1, 0)
ok = 1
for i in range(1, n):
x, y = a[i], a[i + 1]
dis = 0
while dis <= 3 and depth[x] > depth[y]:
x = pre[x]
dis += 1
while dis <= 3 and depth[y] > depth[x]:
y = pre[y]
dis += 1
while dis <= 3 and depth[x] and x != y:
x = pre[x]
y = pre[y]
dis += 2
if dis > 3:
ok = 0
break
print(ok)