[python刷题模板] 树的直径/换根DP

news2025/1/18 4:31:55

[python刷题模板] 树的直径/换根DP

    • 一、 算法&数据结构
      • 1. 描述
      • 2. 复杂度分析
      • 3. 常见应用
      • 4. 常用优化
    • 二、 模板代码
      • 1. 单纯询问树的直径值
      • 2. 求出树的直径两端搞事情
      • 3. 换根DP求树的直径(大炮打蚊子,别这么做,只是用来帮助理解换根DP)
      • 4. 换根dp求特定值(另附小日子模板)
      • 5. 换根dp求去掉一个叶子的值。
    • 三、其他
    • 四、更多例题
    • 五、参考链接

一、 算法&数据结构

1. 描述

树的直径代表树上最远的两个点的距离。
在某些特定的时候会用到。
  • 求树的直径通常可以
    1. 树形DP:仅用来计算直径的值,一次dfs即可。由于直径一定是某个子树的两个连接根的最长简单路径拼起来(其中一个可能是空),因此令dfs(u)计算u的子树树高(最长路径),计算v时,同时维护v的兄弟节点最大那个,每次都加一下尝试更新答案即可。
    2. 两次遍历(bfs/dfs均可):可以同时计算直径的值且找到两个端点。第一次遍历从任意一点root找最远节点u,第二次从u出发,找最远节点v,u和v就是直径的两端。
    3. 换根DP(大炮打蚊子):进用来计算直径的值,假设直径两端是u,v,那么以u为树根,v一定是最远节点,求最大树高即可。

换根DP是一种树上DP,可以用O(n)的复杂度计算出:分别以每个节点作为整个树的树根,树的某个属性值(如:树高)。
之所以把换根dp和树的直径写在一起,是因为,求树的直径时不论是树形dp还是换根DP,都可以结合着思考更容易吃透。
  • 记录三个数组down1,down2,up,分别表示:
    • down1[u]: 以u为根的子树,向下最大的值(比如求树高,就是最长简单路径)
    • down2[u]: 以u为根的子树,向下次大的值(比如求树高,就是最长简单路径)
    • up[u]: 以u为根的子树,向上最大的值(比如求树高,就是最长简单路径)
  • 那么当整颗树以u为根时,最大值就是max(down1[u],up[u])。可以想象一下揪着u节点往上提,所有路径其实就是down和up。
  • 第一遍dfs,后根遍历求出down1和down2。
  • 第二遍dfs,先根遍历求出up,注意讨论:如果当前节点v在u的最大路径(down1)下,up[v]应该从应该从u的次大(down2[u])来,或up[u]来,见途中绿色轨迹;否则从最大来。
  • 在这里插入图片描述

2. 复杂度分析

  1. O(ln)

3. 常见应用

4. 常用优化

  1. 换根DP的三个数组,可以写成f = [[0,0,0] for _ in range(n)],但这是负优化,还是写三个好使。
  2. 换根DP可以用BFS先列出dp序,然后遍历,省去dfs爆栈的开销。

二、 模板代码

1. 单纯询问树的直径值

例题: 4799. 最远距离
这题是一道裸的树的直径。下边的代码大部分会用这题测试。

  • 如果没特殊需求,建议用两次bfs免去爆栈风险。
def bootstrap(f, stack=[]):
    def wrappedfunc(*args, **kwargs):
        if stack:
            return f(*args, **kwargs)
        else:
            to = f(*args, **kwargs)
            while True:
                if type(to) is GeneratorType:
                    stack.append(to)
                    to = next(to)
                else:
                    stack.pop()
                    if not stack:
                        break
                    to = stack[-1].send(to)
            return to

    return wrappedfunc


#    4994    ms
def solve2():
    n, m = RI()
    g = [[] for _ in range(n)]
    for _ in range(m):
        u, v = RI()
        g[u - 1].append(v - 1)
        g[v - 1].append(u - 1)
    start = 0
    ans = 0
    for _ in range(2):
        @bootstrap
        def dfs(u, fa, depth=0):
            nonlocal ans, start
            if depth > ans:
                start = u
                ans = depth
            for v in g[u]:
                if v != fa:
                    yield dfs(v, u, depth + 1)
            yield

        dfs(start, -1)

    print(ans)


#   3618    ms
def solve1():
    n, m = RI()
    g = [[] for _ in range(n)]
    for _ in range(m):
        u, v = RI()
        g[u - 1].append(v - 1)
        g[v - 1].append(u - 1)
    start = 0
    ans = 0
    for _ in range(2):
        q = deque([start])
        fas = [-1] * n
        step = 0
        while q:
            step += 1
            for _ in range(len(q)):
                u = q.popleft()
                start = u
                for v in g[u]:
                    fas[v] = u
                    if v != fas[u]:
                        q.append(v)
        ans = max(ans, step)
    print(ans - 1)

2. 求出树的直径两端搞事情

链接: abc267_f - Exactly K Steps
在这里插入图片描述

  • 这题找出一条直径,就是最大支持的距离,如果够就可以。
  • 因此还需要多一次dfs,记录当前路径。
  • 方便的写法就是三次dfs,前两次找直径,第三次算答案,每次都记录最远的端点作为下一次的起始点即可。
    bfs找直径模板

    def get_tree_diameter(g, root=0):  # bfs两次找直径的端点
        """
        求树的直径,g是0-indexed,默认第一次root是从0
        返回某条直径的两个端点u,v,以及直径值d(边数而不是点数)
        简述:求树的直径时,可以通过树形DP做,也可以通过两次遍历找最远的点(bfs或dfs都可以)
            第二次的起始点就是某条直径的端点
        """
        if not g[root]:
            return root, root, 0

        def bfs(start):
            q = deque([(start, -1)])
            step = -1
            while q:
                step += 1
                for _ in range(len(q)):
                    u, fa = q.popleft()
                    for v in g[u]:
                        if v != fa:
                            q.append((v, u))
            return u, step

        x, _ = bfs(root)
        y, d = bfs(x)
        return x, y, d

ac代码

import sys
import heapq
import bisect
import random
import io, os
from bisect import *
from collections import *
from contextlib import redirect_stdout
from itertools import *
from math import sqrt, gcd, inf
from array import *
from functools import lru_cache
from types import GeneratorType

RI = lambda: map(int, sys.stdin.buffer.readline().split())
RS = lambda: map(bytes.decode, sys.stdin.buffer.readline().strip().split())
RILST = lambda: list(RI())
DEBUG = lambda *x: sys.stderr.write(f'{str(x)}\n')

MOD = 10 ** 9 + 7
PROBLEM = """https://atcoder.jp/contests/abc267/tasks/abc267_f

输入 n(≤2e5) 和一棵树的 n-1 条边(节点编号从 1 开始)。
然后输入 q(≤2e5) 和 q 个询问,每个询问输入 u 和 k。
输出到 u 的距离为 k 的任意一个点。如果这个点不存在则输出 -1。
距离指两点最短路上的边的数目。
输入
5
1 2
2 3
3 4
3 5
3
2 2
5 3
3 3
输出
4
1
-1
"""
"""https://atcoder.jp/contests/abc267/submissions/37595672

求出树的任意一条直径,设直径端点为 x 和 y。

从 x 出发 dfs,同时记录 dfs 路径上的点。
如果点 u 的深度 d >= k,那么 dfs 路径上的第 d-k 个点就是答案。

一次 dfs 不一定能满足所有点,再从 y 出发 dfs 一次就能保证所有点都有答案(除了 k 非常大的)。
"""


def bootstrap(f, stack=[]):
    def wrappedfunc(*args, **kwargs):
        if stack:
            return f(*args, **kwargs)
        else:
            to = f(*args, **kwargs)
            while True:
                if type(to) is GeneratorType:
                    stack.append(to)
                    to = next(to)
                else:
                    stack.pop()
                    if not stack:
                        break
                    to = stack[-1].send(to)
            return to

    return wrappedfunc


# 1629 ms
if __name__ == '__main__':
    n, = RI()
    g = [[] for _ in range(n)]
    for _ in range(n - 1):
        u, v = RI()
        g[u - 1].append(v - 1)
        g[v - 1].append(u - 1)

    q, = RI()
    qs = defaultdict(list)
    for i in range(q):
        u, d = RI()
        qs[u - 1].append([i, d])
    ans = [-1] * q

    leaf = mx = 0
    path = [-1] * n


    @bootstrap
    def dfs(u, fa, d=0):
        path[d] = u
        global leaf, mx
        if d > mx:
            leaf = u
            mx = d
        for i, k in qs[u]:
            if d >= k:
                ans[i] = path[d - k] + 1
        for v in g[u]:
            if v != fa:
                yield dfs(v, u, d + 1)
        yield


    for _ in range(3):
        dfs(leaf, -1)

    print('\n'.join(map(str, ans)))

# # 1423 ms
# if __name__ == '__main__':
#     n, = RI()
#     g = [[] for _ in range(n)]
#     for _ in range(n - 1):
#         u, v = RI()
#         g[u - 1].append(v - 1)
#         g[v - 1].append(u - 1)
#
#
#     def get_tree_diameter(g, root=0):  # bfs两次找直径的端点
#         if not g[root]:
#             return root, root
#
#         def bfs(start):
#             q = deque([(start, -1)])
#             while q:
#                 u, fa = q.popleft()
#                 for v in g[u]:
#                     if v != fa:
#                         q.append((v, u))
#             return u
#
#         x = bfs(root)
#         y = bfs(x)
#         return x, y
#
#
#     x, y = get_tree_diameter(g)
#
#     q, = RI()
#     qs = defaultdict(list)
#     for i in range(q):
#         u, d = RI()
#         qs[u - 1].append([i, d])
#     # print(qs)
#     ans = [-1] * q
#     path = [0] * n  # 当前深度链接到根的路径
#
#
#     @bootstrap
#     def dfs(u, fa, d=0):
#         path[d] = u
#         for i, k in qs[u]:
#             if d >= k:
#                 ans[i] = path[d - k] + 1
#         for v in g[u]:
#             if v != fa:
#                 yield dfs(v, u, d + 1)
#         yield
#
#
#     dfs(x, -1)
#     dfs(y, -1)
#     print('\n'.join(map(str, ans)))

# # 1892ms
# if __name__ == '__main__':
#     n, = RI()
#     g = [[] for _ in range(n)]
#     for _ in range(n - 1):
#         u, v = RI()
#         g[u - 1].append(v - 1)
#         g[v - 1].append(u - 1)
#
#
#     def get_tree_diameter(g, root=0):
#         ans = (0, root, root)
#         if not g[root]:
#             return ans
#
#         dp = {}
#
#         @bootstrap
#         def dfs(u, fa, depth=0):  # 返回树高以及最深的叶子
#             if len(g[u]) == 1 and u != root:  # 没有子节点了,它就是一个端点(叶子),高度1
#                 dp[u] = (1, u)
#                 yield
#
#             hs = []
#             for v in g[u]:
#                 if v != fa:
#                     yield dfs(v, u, depth + 1)
#                     h, o = dp[v]
#                     if len(hs) < 2:
#                         heapq.heappush(hs, (h, o))
#                     else:
#                         heapq.heappushpop(hs, (h, o))
#
#             if len(hs) == 2:
#                 l, r = max((depth, root), hs[0]), hs[1]
#             else:
#                 l, r = (depth, root), hs[0]
#             p = (l[0] + r[0], l[1], r[1])
#             # print(p)
#             nonlocal ans
#             if p > ans:
#                 ans = p
#             dp[u] = (hs[-1][0] + 1, hs[-1][1])
#             yield
#
#         dfs(root, -1)
#         return ans
#
#
#     d, x, y = get_tree_diameter(g)
#
#     q, = RI()
#     qs = defaultdict(list)
#     for i in range(q):
#         u, d = RI()
#         qs[u - 1].append([i, d])
#     # print(qs)
#     ans = [-1] * q
#     path = [0] * n  # 当前深度链接到根的路径
#
#
#     @bootstrap
#     def dfs(u, fa, d=0):
#         path[d] = u
#         for i, k in qs[u]:
#             if d >= k:
#                 ans[i] = path[d - k] + 1
#         for v in g[u]:
#             if v != fa:
#                 yield dfs(v, u, d + 1)
#         yield
#
#
#     dfs(x, -1)
#     dfs(y, -1)
#     print('\n'.join(map(str, ans)))

3. 换根DP求树的直径(大炮打蚊子,别这么做,只是用来帮助理解换根DP)


#  换根dp   4122    ms
def solve():
    n, m = RI()
    g = [[] for _ in range(n)]
    for _ in range(m):
        u, v = RI()
        g[u - 1].append(v - 1)
        g[v - 1].append(u - 1)

    def get_tree_diameter(g, root=0):  # bfs两次找直径的端点
        """ ms
        求树的直径,g是0-indexed,默认第一次root是从0
        返回直径值(边数而不是点数)
        简述:换根dp,假设直径两端是u,v,那么以u为树根,v一定是最远节点,求最大树高即可。
        """
        if not g[root]:
            return 0
        down1, down2, up = [0] * n, [0] * n, [0] * n  # 初始化向下最大/次大树高、向上树高(其实不是树高,是最远简单路径)
        order = []  # dp序
        fas = [-1] * n  # 记录父节点
        q = deque([root])  # bfs求order
        while q:
            u = q.popleft()
            order.append(u)
            for v in g[u]:
                if v != fas[u]:
                    fas[v] = u
                    q.append(v)

        for u in order[::-1]:  # 第一遍,自底向上求每个子树的最大/次大树高
            for v in g[u]:
                if v == fas[u]:
                    continue
                h = down1[v] + 1  # 高度
                if h > down1[u]:
                    down1[u], down2[u] = h, down1[u]
                elif h > down2[u]:
                    down2[u] = h
        for u in order:
            for v in g[u]:
                if v == fas[u]:
                    continue
                if down1[u] == down1[v] + 1:  # v在u的最大路径上,则往上的路径应该可能从次大走
                    up[v] = max(down2[u], up[u]) + 1
                else:  # 否则一定从最大走
                    up[v] = max(down1[u], up[u]) + 1

        return max(max(x, y) for x, y in zip(down1, up))

    print(get_tree_diameter(g))

4. 换根dp求特定值(另附小日子模板)

链接: abc222_f - Expensive Expense
在这里插入图片描述

  • 这题由于要附加一个路径上的端点节点值,因此更新答案时多了个,可能是d[u]+w。
  • solve2/3分别是用上述方法三个数组down1/2+up计算的过程。
  • 另附一个从周赛抄来的模板,我尽可能的按照自己的理解把注释写成了中文。
    • 模板需要调整三个地方 e/op/composition方法,具体可以看代码。
    • 注意e通常是0,我目前做的题都是。
    • op目前我做的题都是max。
    • composition是最关键的,知道子树和节点,如何拼接答案。
# Problem: F - Expensive Expense
# Contest: AtCoder - Exawizards Programming Contest 2021(AtCoder Beginner Contest 222)
# URL: https://atcoder.jp/contests/abc222/tasks/abc222_f
# Memory Limit: 1024 MB
# Time Limit: 4000 ms

import sys
import bisect
import random
import io, os
from bisect import *
from collections import *
from contextlib import redirect_stdout
from itertools import *
from math import sqrt, gcd, inf
from array import *
from functools import lru_cache
from types import GeneratorType
from heapq import *

RI = lambda: map(int, sys.stdin.buffer.readline().split())
RS = lambda: map(bytes.decode, sys.stdin.buffer.readline().strip().split())
RILST = lambda: list(RI())
DEBUG = lambda *x: sys.stderr.write(f'{str(x)}\n')

MOD = 10 ** 9 + 7
PROBLEM = """https://atcoder.jp/contests/abc222/tasks/abc222_f

输入 n(2≤n≤2e5) 和一棵树的 n-1 条边(节点编号从 1 开始),每条边输入两个端点和边权。
然后输入 n 个数 d,d[i] 表示点 i 的点权。

定义 f(x,y) = 从 x 到 y 的简单路径的边权之和,再加上 d[y]。
定义 g(x) = max{f(x,i)},这里 i 取遍 1~n 的所有不为 x 的点。
输出 g(1),g(2),...,g(n)。
输入
3
1 2 2
2 3 3
1 2 3
输出
8
6
6
"""

from typing import Callable, Generic, List, TypeVar

T = TypeVar("T")
E = Callable[[int], T]
"""identify element of op, and answer of leaf"""
Op = Callable[[T, T], T]
"""merge value of child node"""
Composition = Callable[[T, int, int, int], T]
"""return value from child node to parent node"""


#    1187    ms
def solve():
    class Rerooting(Generic[T]):
        __slots__ = ("g", "_n", "_decrement", "_root", "_parent", "_order")

        def __init__(self, n: int, decrement: int = 0, edges=None):
            """
            n: 节点个数
            decrement: 节点id可能需要偏移 (1-indexed则-1, 0-indexed则0)
            """
            self.g = g = [[] for _ in range(n)]
            self._n = n
            self._decrement = decrement
            self._root = None  # 一开始的根
            if edges:
                for u, v in edges:
                    u -= decrement
                    v -= decrement
                    g[u].append(v)
                    g[v].append(u)

        def add_edge(self, u: int, v: int):
            """
            无向树加边
            """
            u -= self._decrement
            v -= self._decrement
            self.g[u].append(v)
            self.g[v].append(u)

        def rerooting(
                self, e: E["T"], op: Op["T"], composition: Composition["T"], root=0
        ) -> List["T"]:
            """
            - e: 初始化每个节点的价值
              (root) -> res
              mergeの単位元
              例:求最长路径 e=0

            - op: 两个子树答案如何组合或取舍
              (childRes1,childRes2) -> newRes
              例:求最长路径 return max(childRes1,childRes2)

            - composition: 知道子子树答案和节点值,如何更新子树答案
              (from_res,fa,u,use_fa) -> new_res
              use_fa: 0表示用u更新fa的dp1,1表示用fa更新u的dp2
              例:最长路径return from_res+1

            - root: 可能要设置初始根,默认是0
            <概要> 换根DP模板,用线性时间获取以每个节点为根整颗树的情况。
            注意最终返回的dp[u]代表以u为根时,u的所有子树的最优情况(不包括u节点本身),因此如果要整颗子树情况,还要再额外计算。
            1. 记录dp1,dp2。其中:
                dp1[u] 代表 以u为根的子树,它的孩子子树的最优值,即u节点本身不参与计算。注意,和我们一般定义的f[u]代表以u为根的子树2情况不同。
                dp2[v] 代表 除了v以外,它的兄弟子树的最优值。依然注意,v不参与,同时u也不参与(u是v的父节点)。
                建议画图理解。
            2. dp2[v]的含义后边将进行一次变动,变更为v的兄弟、u的父过来的路径,merge上u节点本身最后得出来的值。即v以父亲为邻居向外延伸的最优值(不含v,但含父)。
            3. 同时dp1[u]的含义更新为目标的含义:以u为根,u的子节点们所在子树的最优情况。
            4. 这样dp1,dp2将分别代表u的向下子树的最优,u除了向下子树以外的最优(一定从父节点来,但父节点可能从兄弟来或祖宗来)
            <步骤>
            1. 先从任意root出发(一般是0),获取bfs层序。这里是为了方便dp,或者直接dfs树形DP其实也是可以的,但可能会爆栈。
            2. 自底向上dp,用自身子树情况更新dp1,除自己外的兄弟子树情况更新dp2。
            3. 自顶向下dp,变更dp2和dp1的含义。这时对于u来说存在三种子树(强烈建议画图观察):
                ① u本身的子树,它们的最优解已经存在于之前的dp1[u]。
                ② u的兄弟子树+fa,它们的最优解=composition(dp2[u],fa,u,use_fa=1)。
                ③ 连接到fa的最优子树+fa,最优解=composition(dp2[fa],fa,u,use_fa=1)。
                    注意这里的dp2含义已变更,由于我们是自顶向下计算,因此dp2[fa]已更新。
                    ②和③可以写一起来更新dp2[u]

            計算量 O(|V|) (Vは頂点数)
            参照 https://qiita.com/keymoon/items/2a52f1b0fb7ef67fb89e
            """
            # step1
            root -= self._decrement
            assert 0 <= root < self._n
            self._root = root
            g = self.g
            _fas = self._parent = [-1] * self._n  # 记录每个节点的父节点
            _order = self._order = [root]  # bfs记录遍历层序,便于后续dp
            q = deque([root])
            while q:
                u = q.popleft()
                for v in g[u]:
                    if v == _fas[u]:
                        continue
                    _fas[v] = u
                    _order.append(v)
                    q.append(v)

            # step2
            dp1 = [e(i) for i in range(self._n)]  # !子树部分的dp值,假设u是当前子树的根,vs是第一层儿子(它的非父邻居),则dp1[u]=op(dp1(vs))
            dp2 = [e(i) for i in
                   range(
                       self._n)]  # !非子树部分的dp值,假设u是当前子树的根,vs={v1,v2..vi..}是第一层儿子(它的非父邻居),则dp2[vi]=op(dp1(vs-vi)),即他的兄弟们

            for u in _order[::-1]:  # 从下往上拓扑序dp
                res = e(u)
                for v in g[u]:
                    if _fas[u] == v:
                        continue
                    dp2[v] = res
                    res = op(res, composition(dp1[v], u, v, 0))  # op从下往上更新dp1
                # 由于最大可能在后边,因此还得倒序来一遍
                res = e(u)
                for v in g[u][::-1]:
                    if _fas[u] == v:
                        continue
                    dp2[v] = op(res, dp2[v])
                    res = op(res, composition(dp1[v], u, v, 0))
                dp1[u] = res

            # step3 自顶向下计算每个节点作为根时的dp1,dp2的含义变更为:dp2[u]为u的兄弟+父。这样对v来说dp1[u] = op(dp1[fa],dp1[u])

            for u in _order[1:]:
                fa = _fas[u]
                dp2[u] = composition(
                    op(dp2[u], dp2[fa]), fa, u, 1
                )  # op从上往下更新dp2
                dp1[u] = op(dp1[u], dp2[u])

            return dp1

    n, = RI()
    r = Rerooting(n)
    ws = {}
    for _ in range(n - 1):
        u, v, w = RI()
        u -= 1
        v -= 1
        ws[u, v] = w
        ws[v, u] = w
        r.add_edge(u, v)
    d = RILST()

    def e(root: int) -> int:
        # 转移时单个点不管相邻子树的贡献
        # 例:最も遠い点までの距離を求める場合 e=0
        return 0

    def op(child_res1: int, child_res2: int) -> int:
        # 如何组合/取舍两个子树的答案
        # 例:求最长路径 return max(childRes1,childRes2)
        return max(child_res1, child_res2)

    def composition(from_res: int, fa: int, u: int, use_fa: int = 0) -> int:
        # 知道子树的每个子树和节点值,如何更新子树答案;
        # 例子:求最长路径 return from_res+1
        if use_fa == 0:  # cur -> parent
            return max(from_res, d[u]) + ws[u, fa]
        return max(from_res, d[fa]) + ws[fa, u]

    res = r.rerooting(e, op, composition)
    print(*res, sep='\n')



def bootstrap(f, stack=[]):
    def wrappedfunc(*args, **kwargs):
        if stack:
            return f(*args, **kwargs)
        else:
            to = f(*args, **kwargs)
            while True:
                if type(to) is GeneratorType:
                    stack.append(to)
                    to = next(to)
                else:
                    stack.pop()
                    if not stack:
                        break
                    to = stack[-1].send(to)
            return to

    return wrappedfunc


#    927     ms
def solve2():
    n, = RI()
    g = [[] for _ in range(n)]

    for _ in range(n - 1):
        u, v, w = RI()
        u -= 1
        v -= 1
        g[u].append((v, w))
        g[v].append((u, w))
    d = RILST()
    down1, down2, up = [0] * n, [0] * n, [0] * n

    @bootstrap
    def dfs(u, fa):
        for v, w in g[u]:
            if v == fa:
                continue
            yield dfs(v, u)
            s = max(down1[v], d[v]) + w
            if s > down1[u]:
                down2[u] = down1[u]
                down1[u] = s
            elif s > down2[u]:
                down2[u] = s
        yield

    @bootstrap
    def reroot(u, fa):
        for v, w in g[u]:
            if v != fa:
                if down1[u] == down1[v] + w or down1[u] == d[v] + w:
                    up[v] = max(down2[u] + w, up[u] + w, d[u] + w)
                else:
                    up[v] = max(down1[u] + w, up[u] + w, d[u] + w)
                yield reroot(v, u)
        yield

    dfs(0, -1)
    reroot(0, -1)
    print(*[max(a, b) for a, b in zip(up, down1)], sep='\n')


#     715    ms
def solve3():
    n, = RI()
    g = [[] for _ in range(n)]

    for _ in range(n - 1):
        u, v, w = RI()
        u -= 1
        v -= 1
        g[u].append((v, w))
        g[v].append((u, w))
    d = RILST()
    down1, down2, up = [0] * n, [0] * n, [0] * n
    order = []
    q = deque([0])
    fas = [-1] * n
    while q:
        u = q.popleft()
        order.append(u)
        for v, w in g[u]:
            if v != fas[u]:
                fas[v] = u
                q.append(v)

    for u in order[::-1]:
        for v, w in g[u]:
            if v == fas[u]:
                continue
            s = max(down1[v], d[v]) + w
            if s > down1[u]:
                down2[u] = down1[u]
                down1[u] = s
            elif s > down2[u]:
                down2[u] = s
    for u in order:
        for v, w in g[u]:
            if v != fas[u]:
                if down1[u] == down1[v] + w or down1[u] == d[v] + w:
                    up[v] = max(down2[u] + w, up[u] + w, d[u] + w)
                else:
                    up[v] = max(down1[u] + w, up[u] + w, d[u] + w)

    print(*[max(a, b) for a, b in zip(up, down1)], sep='\n')


if __name__ == '__main__':
    solve()

5. 换根dp求去掉一个叶子的值。

链接: 2538. 最大价值和与最小价值和的差值

  • 这题是周赛T4,我当时用树形DP写了一大堆两次dfs做出来了。
    手写换根
class Solution:
    def maxOutput(self, n: int, edges: List[List[int]], price: List[int]) -> int:
        g = [[] for _ in range(n)]
        for u,v in edges:
            g[u].append(v)
            g[v].append(u)
            # print(u,v)
        ans = 0

        f = [[0,0,0] for _ in range(n)]  # f[i][0/1/2]代表:i向下走最大路径和,向下走次大路径和,向上走最大路径和;答案一定在向下或向上走的路径中
        def dfs1(u,fa):  # 更新向下走的最大/次大路径和
            f[u][0] = p = price[u]
            for v in g[u]:
                if v != fa:
                    dfs1(v,u)
                    x = f[v][0]+p
                    if f[u][0]<x:
                        f[u][1] = f[u][0]
                        f[u][0] = x
                    elif f[u][1] < x:
                        f[u][1] = x 
        
        def dfs2(u,fa):
            for v in g[u]:
                if v != fa:
                    p = price[v]
                    if f[u][0] == f[v][0] + price[u]:
                        f[v][2] = max(f[u][2],f[u][1]) + p
                    else:
                        f[v][2] = max(f[u][2],f[u][0]) + p 
                    dfs2(v,u)
        dfs1(0,-1)
        dfs2(0,-1)

        return max(max(a-price[i],c-price[i]) for i,(a,_,c) in enumerate(f))

套模板

from typing import List, Tuple, Optional
from collections import defaultdict, Counter, deque

MOD = int(1e9 + 7)
INF = int(1e20)

# 给你一个 n 个节点的无向无根图,节点编号为 0 到 n - 1 。给你一个整数 n 和一个长度为 n - 1 的二维整数数组 edges ,其中 edges[i] = [ai, bi] 表示树中节点 ai 和 bi 之间有一条边。

# 每个节点都有一个价值。给你一个整数数组 price ,其中 price[i] 是第 i 个节点的价值。

# 一条路径的 价值和 是这条路径上所有节点的价值之和。

# 你可以选择树中任意一个节点作为根节点 root 。选择 root 为根的 开销 是以 root 为起点的所有路径中,价值和 最大的一条路径与最小的一条路径的差值。

# 请你返回所有节点作为根节点的选择中,最大 的 开销 为多少。


from typing import Callable, Generic, List, TypeVar

T = TypeVar("T")

E = Callable[[int], T]
"""identify element of op, and answer of leaf"""

Op = Callable[[T, T], T]
"""merge value of child node"""

Composition = Callable[[T, int, int, int], T]
"""return value from child node to parent node"""


class Rerooting(Generic[T]):
    __slots__ = ("g", "_n", "_decrement", "_root", "_parent", "_order")

    def __init__(self, n: int, decrement: int = 0, edges=None):
        """
        n: 节点个数
        decrement: 节点id可能需要偏移 (1-indexed则-1, 0-indexed则0)
        """
        self.g = g = [[] for _ in range(n)]
        self._n = n
        self._decrement = decrement
        self._root = None  # 一开始的根
        if edges:
            for u, v in edges:
                u -= decrement
                v -= decrement
                g[u].append(v)
                g[v].append(u)

    def add_edge(self, u: int, v: int):
        """
        无向树加边
        """
        u -= self._decrement
        v -= self._decrement
        self.g[u].append(v)
        self.g[v].append(u)

    def rerooting(
            self, e: E["T"], op: Op["T"], composition: Composition["T"], root=0
    ) -> List["T"]:
        """
        - e: 初始化每个节点的价值
          (root) -> res
          mergeの単位元
          例:求最长路径 e=0

        - op: 两个子树答案如何组合或取舍
          (childRes1,childRes2) -> newRes
          例:求最长路径 return max(childRes1,childRes2)

        - composition: 知道子子树答案和节点值,如何更新子树答案
          (from_res,fa,u,use_fa) -> new_res
          use_fa: 0表示用u更新fa的dp1,1表示用fa更新u的dp2
          例:最长路径return from_res+1

        - root: 可能要设置初始根,默认是0
        <概要> 换根DP,用线性时间获取以每个节点为根整颗树的情况。
        注意最终返回的dp[u]代表以u为根时,u的所有子树的最优情况(不包括u节点本身),因此如果要整颗子树情况,还要再额外计算。
        1. 记录dp1,dp2。其中:
            dp1[u] 代表 以u为根的子树,它的孩子子树的最优值,即u节点本身不参与计算。注意,和我们一般定义的f[u]代表以u为根的子树2情况不同。
            dp2[v] 代表 除了v以外,它的兄弟子树的最优值。依然注意,v不参与,同时u也不参与(u是v的父节点)。
            建议画图理解。
        2. dp2[v]的含义后边将进行一次变动,变更为v的兄弟、u的父过来的路径,merge上u节点本身最后得出来的值。即v以父亲为邻居向外延伸的最优值(不含v,但含父)。
        3. 同时dp1[u]的含义更新为目标的含义:以u为根,u的子节点们所在子树的最优情况。
        4. 这样dp1,dp2将分别代表u的向下子树的最优,u除了向下子树以外的最优(一定从父节点来,但父节点可能从兄弟来或祖宗来)
        <步骤>
        1. 先从任意root出发(一般是0),获取bfs层序。这里是为了方便dp,或者直接dfs树形DP其实也是可以的,但可能会爆栈。
        2. 自底向上dp,用自身子树情况更新dp1,除自己外的兄弟子树情况更新dp2。
        3. 自顶向下dp,变更dp2和dp1的含义。这时对于u来说存在三种子树(强烈建议画图观察):
            ① u本身的子树,它们的最优解已经存在于之前的dp1[u]。
            ② u的兄弟子树+fa,它们的最优解=composition(dp2[u],fa,u,use_fa=1)。
            ③ 连接到fa的最优子树+fa,最优解=composition(dp2[fa],fa,u,use_fa=1)。
                注意这里的dp2含义已变更,由于我们是自顶向下计算,因此dp2[fa]已更新。
                ②和③可以写一起来更新dp2[u]

        計算量 O(|V|) (Vは頂点数)
        参照 https://qiita.com/keymoon/items/2a52f1b0fb7ef67fb89e
        """
        # step1
        root -= self._decrement
        assert 0 <= root < self._n
        self._root = root
        g = self.g
        _fas = self._parent = [-1] * self._n  # 记录每个节点的父节点
        _order = self._order = [root]  # bfs记录遍历层序,便于后续dp
        q = deque([root])
        while q:
            u = q.popleft()
            for v in g[u]:
                if v == _fas[u]:
                    continue
                _fas[v] = u
                _order.append(v)
                q.append(v)

        # step2
        dp1 = [e(i) for i in range(self._n)]  # !子树部分的dp值,假设u是当前子树的根,vs是第一层儿子(它的非父邻居),则dp1[u]=op(dp1(vs))
        dp2 = [e(i) for i in
               range(self._n)]  # !非子树部分的dp值,假设u是当前子树的根,vs={v1,v2..vi..}是第一层儿子(它的非父邻居),则dp2[vi]=op(dp1(vs-vi)),即他的兄弟们
        for u in _order[::-1]:  # 从下往上拓扑序dp
            res = e(u)
            for v in g[u]:
                if _fas[u] == v:
                    continue
                dp2[v] = res
                res = op(res, composition(dp1[v], u, v, 0))  # op从下往上更新dp1
            # 由于最大可能在后边,因此还得倒序来一遍
            res = e(u)
            for v in g[u][::-1]:
                if _fas[u] == v:
                    continue
                dp2[v] = op(res, dp2[v])
                res = op(res, composition(dp1[v], u, v, 0))
            dp1[u] = res

        # step3 自顶向下计算每个节点作为根时的dp1,dp2的含义变更为:dp2[u]为u的兄弟+父。这样对v来说dp1[u] = op(dp1[fa],dp1[u])
        for u in _order[1:]:  #
            fa = _fas[u]
            dp2[u] = composition(
                op(dp2[u], dp2[fa]), fa, u, 1
            )  # op从上往下更新dp2
            dp1[u] = op(dp1[u], dp2[u])
        return dp1


class Solution:
    def maxOutput(self, n: int, edges: List[List[int]], price: List[int]) -> int:
        def e(root: int) -> int:
            # mergeの単位元
            # 例:最も遠い点までの距離を求める場合 e=0
            return 0

        def op(child_res1: int, child_res2: int) -> int:
            # 如何组合/取舍两个子树的答案
            # 例:求最长路径 return max(childRes1,childRes2)
            return max(child_res1, child_res2)

        def composition(from_res: int, fa: int, u: int, use_fa: int = 0) -> int:
            # 知道子树的每个子树和节点值,如何更新子树答案;
            # 例子:求最长路径 return from_res+1
            if use_fa == 0:  # cur -> parent
                return from_res + price[u]
            return from_res + price[fa]

        R = Rerooting(n, edges=edges)
        # for u, v in edges:
        #     R.add_edge(u, v)
        res = R.rerooting(e, max, composition)
        return max(res)

三、其他

四、更多例题

五、参考链接

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/175947.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

UDS诊断系列介绍14-2F服务

本文框架1. 系列介绍1.1 2F服务概述2. 2F服务请求与应答2.1 2F服务请求2.2 2F服务正响应2.3 2F服务否定响应3. 2F诊断使用示例4. Autosar系列文章快速链接1. 系列介绍 UDS&#xff08;Unified Diagnostic Services&#xff09;协议&#xff0c;即统一的诊断服务&#xff0c;是…

学习记录667@项目管理之项目人力资源管理

什么是项目人力资源管理 项目人力资源管理包括编制人力资源管理计划、组建项目团队、建设项目团队与管理项目团队的各个过程&#xff0c;不但要求充分发挥参与项目的个人的作用&#xff0c;还包括充分发挥所有与项目有关的人员-----项目负责人、客户、为项目做出贡献的个人及其…

[QMT]04-在QMT之外调用xtquant直接编写策略

背景希望不用在QMT软件里面憋屈地写代码&#xff0c;想使用pychar、vscode、notepad等IDE编写python代码&#xff0c;因为有代码提示、补全。这完全没问题&#xff01;QMT简直是为个人量化交易者量身打造的神器&#xff0c;它支持以上想法。QMT这个东东基本是由两部分组成的&am…

1. 深度学习简介|计算机视觉简介|得分函数|损失函数作用|前向传播整体流程

文章目录深度学习简介计算机视觉简介k近邻算法得分函数损失函数作用前向传播整体流程机器学习是一个大块&#xff0c;其中就包含着深度学习&#xff0c;计算机视觉等 机器学习的流程&#xff1a; 数据获取特征工程建立模型评估与应用 深度学习简介 深度学习通过特征学习进行…

力扣sql简单篇练习(四)

力扣sql简单篇练习(四) 1 超过五名学生的课 1.1 题目内容 1.1.1 基本题目信息 1.1.2 示例输入输出 1.2 示例sql语句 SELECT class FROM Courses GROUP BY class HAVING count(student)>51.3 运行截图 2 超过经理收入的员工 2.1 题目内容 2.1.1 基本题目信息 2.1.2 示例…

零信任-发展历程及概念(1)

零信任发展历程 2010 Forrester约翰金德维格正式提出零信任概念 2013 CSA成立软件定义边界SDP工作组,次年发布SDP标准规范1.0 2017 Gartner正式提出“CARTA”零信任模型 2018 Forrester发表零信任扩展模型ZTX 2019 Gartner发布零信任网络&#xff08;ZTNA&#xff09;云安…

String 与 StringBuffer 与 StringBuilder 各自的妙用

String 与 StringBuffer 与 StringBuilder 各自的妙用 每博一文案 我从未见过&#xff0c;一个早起&#xff0c;勤奋&#xff0c;谨慎&#xff0c;诚实的人&#xff0c;抱怨命运不好的。 最完美的状态&#xff0c;不是你从不失误&#xff0c;而是你从没放弃成长。没人能把你变…

【SpringMVC】看完这篇简单理解并入门SpringMVC:通过入门案例举例子的方式快速理解

SpringMVC简介1.什么是MVC2.什么是SpringMVC3.SpringMVC的特点4.入门案例1.准备工作2.配置web.xml3.创建请求控制器4.创建SpringMVC的配置文件5.测试HelloWorld6.总结1.什么是MVC MVC是一种软件架构的思想&#xff0c;将软件按照模型、视图、控制器来划分 M&#xff1a;Model&…

【Java】比较器 Comparator Comparable

一、背景 我们在使用 Collections.sort() 对链表进行排序&#xff08;或者使用 Arrays.sort() 对数组进行排序&#xff09;时&#xff0c;常常需要根据不同情况自定义排序规则。比如&#xff1a;当我们存储学生对象时&#xff0c;我们需要按照学生年龄进行排序&#xff0c;这时…

树状数组(Binary Indexed Tree (B.I.T))

树状数组 树状数组 (Binary Indexed Tree(B.I.T), Fenwick Tree) 是一个查询和修改复杂度都为 log(n) 的数据结构。 「前缀和查询」与「单点更新」 直接前驱&#xff1a;c[i] 的直接前驱为 c[i - lowbid(i)]&#xff0c;即 c[i] 左侧紧邻的子树的根。 直接后继&#xff1a;c[i…

财务精度:BigInteger 与 BigDecimal

财务精度&#xff1a;BigInteger 与 BigDecimal 每博一文案 师父说: 人这一辈子&#xff0c;真地好难。 有些人&#xff0c;好着好着&#xff0c;忽然就变陌生了&#xff0c;有些手&#xff0c;牵着牵着&#xff0c;瞬间就放开了&#xff0c;有些路&#xff0c;走着走着&#…

算法练习笔记——栈的常用方法以及算法练习

栈学习常用方法介绍力扣练习力扣 20. 有效的括号力扣 32. 最长有效括号常用方法介绍 Stack<Character> characters new Stack<>();//判断栈是否为空boolean empty characters.empty();//将a压入栈底&#xff0c;同时也返回aCharacter push characters.push(a);/…

MYSQL中的常见知识问题(一)

1、MYSQL中redolog、binlog 、undolog的区别与作用。redolog&#xff1a;即重做日志&#xff0c;用来实现事物的一个持久性&#xff0c;由radiobuff和radiolog两部分组成。其中 radiobuff是一个缓冲&#xff0c;存放在内存里面&#xff1b;radiolog是文件&#xff0c;存放在磁盘…

基于粒子群优化和引力搜索混合优化算法改进的前馈神经网络(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

【3-神经网络八股】北京大学TensorFlow2.0

课程地址&#xff1a;【北京大学】Tensorflow2.0_哔哩哔哩_bilibiliPython3.7和TensorFlow2.1六讲&#xff1a;神经网络计算&#xff1a;神经网络的计算过程&#xff0c;搭建第一个神经网络模型神经网络优化&#xff1a;神经网络的优化方法&#xff0c;掌握学习率、激活函数、损…

走进 HTML

文章目录01 什么是HTML&#xff1f;02 HTML的基本结构03 网页基本标签04 图像标签05 链接标签06 块元素和行内元素07 列表07 表格08 视频和音频09 页面结构10 iframe内联框架11 表单语法&#x1f449; 表单元素格式&#x1f449; 表单的应用&#x1f449; 表单初级验证01 什么是…

【Mysql】 数据库用户管理

【Mysql】 数据库用户管理 DCL:英文全称是Data Control Language(数据控制语言)&#xff0c;用来管理数据库用户、控制数据库的访问权限。 1. 管理用户 想要对数据库用户进行操作&#xff0c;我们首先得进入 mysql 数据库 use mysql1.1 查询用户 select * from user;该条命…

每日一问-ChapGPT-20230122-关于春节习俗

文章目录每日一问-ChapGPT系列起因每日一问-ChapGPT-20230116-关于春节习俗世界有哪些国家过春节中国各个地区过春节都有哪些习俗台湾的春节习俗有哪些新加坡过春节有哪些习俗初一到初七的传统习俗有哪些在热闹的节日里&#xff0c;自己无法融入氛围&#xff0c;是什么原因不喜…

【MySQL】第六部分 单行函数

【MySQL】第六部分 单行函数 文章目录【MySQL】第六部分 单行函数6. 单行函数6.1 常用的函数6.2 角度与弧度转换函数6.3 三角函数6.4 指数与对数6.5 进制转换6.6 字符串函数6.7 日期时间函数6.7.1 获取时间和日期6.7.2 日期与时间戳的转换6.7.3 获取月份、星期、星期数、天数等…

筑基二层 —— 图解函数递归、数组详解

目录 一.修炼必备 二.图解递归的执行过程 三.数组 3.1 一维数组 3.2 二维数组 3.3 数组的共同问题 一.修炼必备 1.入门必备&#xff1a;VS2019社区版&#xff0c;下载地址&#xff1a;Visual Studio 较旧的下载 - 2019、2017、2015 和以前的版本 (microsoft.com) 2.趁手武…