零、前言
关于前缀和:
前缀和详解,朴素前缀和,前缀和变形,二维前缀和_前缀积-CSDN博客
关于LCA:
LCA算法-倍增算法_lca倍增算法-CSDN博客
LCA算法-Tarjan算法_lca数组-CSDN博客
树链剖分——重链剖分,原理剖析,代码详解-CSDN博客
一、树上前缀和
1.1 问题引入
给定一棵 n 个节点的树,多次询问 x, y 路径上的节点和。
对于一次询问,我们可以一次dfs解决,如果多次,我们就要用树上前缀和了。
1.2 点前缀和
设 acc[i] 表示从根节点到节点 i 的点权和。
先自顶向下 dfs 计算出前缀和 acc[],然后用 前缀和 拼凑(x, y)的路径和。
acc[x] + acc[y] - acc[lca(x, y)] - acc[fa(lca)]
1.3 边前缀和
设 acc[i] 表示从根节点到节点i的边权和
先自顶向下 dfs 计算出前缀和 acc[],然后用 前缀和 拼凑(x, y)的路径和。
acc[x] + acc[y] - 2 · acc[lca(x, y)]
二、OJ练习
2.1 BJOI2018 求和
原题链接
[P4427 BJOI2018] 求和 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
思路分析
先建图,然后 dfs 预处理 acc[][], acc[j][i] 代表从根节点到 i 路径上节点的 深度的 j 次方 之和
对于每个查询,我们求 lca,输出结果即可
AC代码
#include <bits/stdc++.h>
// #define DEBUG
using u32 = unsigned;
using i64 = long long;
using u64 = unsigned long long;
constexpr int inf32 = 1E9 + 7;
constexpr i64 inf64 = 1E18 + 7;
constexpr int P = 998244353;
constexpr int B = 19;
constexpr int K = 50;
void solve() {
int n;
std::cin >> n;
std::vector<std::vector<int>> adj(n);
for (int i = 1, u, v; i < n; ++ i) {
std::cin >> u >> v;
-- u, -- v;
adj[u].push_back(v);
adj[v].push_back(u);
}
std::vector<int> d(n);
std::vector<std::vector<int>> acc(K + 1, std::vector<int>(n));
std::vector<std::array<int, B>> f(n, std::array<int, B>{});
auto dfs = [&](auto &&self, int u, int p) -> void {
if (u) {
f[u][0] = p;
for (int i = 1; i < B; ++ i) {
f[u][i] = f[f[u][i - 1]][i - 1];
}
}
for (int i = 0, val = 1; i <= K; ++ i) {
acc[i][u] = val;
if (u) {
acc[i][u] += acc[i][p];
if (acc[i][u] >= P)
acc[i][u] -= P;
}
val = 1LL * val * d[u] % P;
}
for (int v : adj[u]) {
if (v == p) continue;
d[v] = d[u] + 1;
self(self, v, u);
}
};
dfs(dfs, 0, -1);
auto LCA = [&](int u, int v) -> int {
if (d[u] < d[v]) std::swap(u, v);
for (int i = B - 1; ~i; -- i)
if (d[f[u][i]] >= d[v])
u = f[u][i];
if (u == v)
return u;
for (int i = B - 1; ~i; -- i)
if (f[u][i] != f[v][i]) {
u = f[u][i];
v = f[v][i];
}
return f[u][0];
};
int m;
std::cin >> m;
for (int i = 0, u, v, k; i < m; ++ i) {
std::cin >> u >> v >> k;
-- u, -- v;
int lca = LCA(u, v);
int ans = ((acc[k][u] + acc[k][v] - acc[k][lca]) % P + P) % P;
if (lca)
ans = ((ans - acc[k][f[lca][0]]) % P + P) % P;
std::cout << ans << '\n';
}
}
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
#ifdef DEBUG
int cur = clock();
freopen("in.txt", "r", stdin);
freopen("out.txt", "w", stdout);
#endif
int t = 1;
// std::cin >> t;
while (t--) {
solve();
}
#ifdef DEBUG
std::cerr << "run-time: " << clock() - cur << '\n';
#endif
return 0;
}