题意
给定两颗二叉树 S S S 和 T T T,如果对于 S S S 的某个子树 S ′ S^\prime S′,删除若干个(或不删除)其子树后,可以和 T T T 相同(左子树与左子树匹配,右子树与右子树匹配),那么称 S ′ S^\prime S′ 与 T T T 匹配
统计
S
S
S 中有多少个子树
S
′
S^\prime
S′ 与
T
T
T 匹配
保证
T
T
T 的叶子节点最多只有
20
20
20 个
思路
首先我们可以将 T T T 看成是一颗 01 T r i e 01 Trie 01Trie,假设 T T T 有 k ( k ≤ 20 ) k(k \leq 20) k(k≤20) 个叶子,那么 T T T 就可以看成是插入了 k k k 个 01 01 01 串后的 T r i e Trie Trie,对于 S ′ S^\prime S′ 也看成是一颗 01 T r i e 01Trie 01Trie
那么问题转化为了:在每一个 S ′ S^\prime S′ 子树里,确定是否可以找到所有 k k k 个 01 01 01 串是从 S ′ S^\prime S′ 这个节点延伸下去的,即这 k k k 个串包含在 S ′ S^\prime S′ 这颗 T r i e Trie Trie 中
注意到
k
k
k 个串肯定两两不同,所以我们可以在
S
S
S 上
d
f
s
dfs
dfs 来统计,对于当前节点
u
u
u,我们维护当前深度,以及
d
f
s
dfs
dfs 下来的路径,那么我们枚举
k
k
k 个
01
01
01 串,看看这些串是不是
u
u
u 结尾的某个后缀,即对于
s
i
s_i
si,如果
s
i
s_i
si 是以
u
u
u 结尾的串,那么
u
u
u 的
∣
s
i
∣
|s_i|
∣si∣ 级祖先的
c
n
t
cnt
cnt 要
+
1
+1
+1
字符串匹配同样可以在
d
f
s
dfs
dfs 过程中维护
h
a
s
h
hash
hash 数组
d f s dfs dfs 完后, c n t = k cnt = k cnt=k 的那些点就是题目要求的点 S ′ S^\prime S′
时间复杂度: O ( m + n k ) O(m + nk) O(m+nk)
#include<bits/stdc++.h>
#define fore(i,l,r) for(int i=(int)(l);i<(int)(r);++i)
#define fi first
#define se second
#define endl '\n'
#define ull unsigned long long
#define ALL(v) v.begin(), v.end()
#define Debug(x, ed) std::cerr << #x << " = " << x << ed;
const int INF=0x3f3f3f3f;
const long long INFLL=1e18;
typedef long long ll;
const int N = 300005;
const int P = 1313;
struct Tree{
int n;
std::vector<int> l, r;
Tree(int n = 0){
this -> n = n;
l.assign(n + 1, 0);
r.assign(n + 1, 0);
}
void set(int n = 0){
this -> n = n;
l.assign(n + 1, 0);
r.assign(n + 1, 0);
}
};
struct Seq{
int len;
ull w; //哈希值
};
Tree S, T;
ull Pow[N];
std::vector<Seq> leaf; //不超过20个01串
ull hash[N]; //搜索过程中的前缀哈希
std::vector<int> ans;
int cnt[N];
std::vector<int> pt; //dfs过程中的访问路径
inline ull get_hash(int l, int r){
return hash[r] - hash[l - 1] * Pow[r - l + 1];
}
void dfs0(int u, ull val, int len){
if(!T.l[u] && !T.r[u]){
leaf.push_back({len, val});
return;
}
if(T.l[u]) dfs0(T.l[u], val * P + 1, len + 1);
if(T.r[u]) dfs0(T.r[u], val * P + 2, len + 1);
}
void dfs1(int u, int dep){
cnt[u] = 0;
pt.push_back(u);
for(auto [len, w] : leaf)
if(len <= dep){
ull val = get_hash(dep - len + 1, dep);
if(w == val) ++cnt[pt[dep - len]];
}
if(S.l[u]){
hash[dep + 1] = hash[dep] * P + 1;
dfs1(S.l[u], dep + 1);
}
if(S.r[u]){
hash[dep + 1] = hash[dep] * P + 2;
dfs1(S.r[u], dep + 1);
}
if(cnt[u] == leaf.size()) ans.push_back(u);
pt.pop_back();
}
int main(){
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
std::cout.tie(nullptr);
Pow[0] = 1;
fore(i, 1, N) Pow[i] = Pow[i - 1] * P;
int t;
std::cin >> t;
while(t--){
int n, m;
std::cin >> n;
S.set(n);
fore(i, 0, n){
std::cin >> S.l[i] >> S.r[i];
}
std::cin >> m;
T.set(m);
fore(i, 0, m){
std::cin >> T.l[i] >> T.r[i];
}
leaf.clear();
dfs0(0, 0, 0);
ans.clear();
pt.clear();
dfs1(0, 0);
std::sort(ALL(ans));
std::cout << ans.size() << endl;
fore(i, 0, ans.size()) std::cout << ans[i] << " \n"[i == ans.size() - 1];
}
return 0;
}