    • 给定一棵 $n$ 个点、以 $1$ 为根的有根树,在给出的 $m$ 条带权祖孙路径中选出若干条,使之覆盖树上所有边,并最小化所选路径的权值之和。

    • $n, m \leq 3 \times 10^5$


    看到题先想到了 [NOI2020]命运,就想了类似的做法。

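    $O(n^2)$ 的 DP 很好想:设 $f[x][d]$ 表示覆盖了 $x$ 子树中的所有边、且所选路径中端点深度的最小值为 $d$ 时的最小权值之和。合并两棵子树 $f_1, f_2$ 是一个 min 卷积:$f'[d] = \min\left(f_1[d] + \min_{d' \ge d} f_2[d'],\ f_2[d] + \min_{d' > d} f_1[d']\right)$,像 NOI 那道题一样用线段树合并优化即可。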


    #include <bits/stdc++.h>
    // perr: printf-style output to stderr. dbg: debug logging, currently
    // disabled (expands to the constant 42); the colored variant is kept below.
    #define perr(...) fprintf(stderr, __VA_ARGS__)
    #define dbg(...) 42 // perr("\033[32;1m"), perr(__VA_ARGS__), perr("\033[0m")
    // Set x to y when y is smaller; returns true iff x was changed.
    template <class T, class U>
    inline bool smin(T &x, const U &y) {
      return y < x ? (x = y, true) : false;
    }
    // Set x to y when y is larger; returns true iff x was changed.
    template <class T, class U>
    inline bool smax(T &x, const U &y) {
      return x < y ? (x = y, true) : false;
    }
    typedef long long LL;
    typedef std::pair<int, int> PII;
    // Capacity of the per-vertex arrays (n <= 3e5, plus slack).
    constexpr int N = 300000 + 5;
    // Vertex count and number of candidate paths.
    int n;
    int m;
    // Adjacency lists of the tree.
    std::vector<int> g[N];
    // p[x]: (other endpoint y, weight z) of each input path "x y z" with x != y.
    std::vector<PII> p[N];
    struct Node {
      Node *ls, *rs;
      LL min, tag;
      Node() : ls(nullptr), rs(nullptr), min(LLONG_MAX), tag(0) {}
      void add(LL x) { min += x, tag += x; }
      void pushup() {
        min = LLONG_MAX;
        if (ls) smin(min, ls->min);
        if (rs) smin(min, rs->min);
      void pushdown() {
        if (tag) {
          if (ls) ls->add(tag);
          if (rs) rs->add(tag);
          tag = 0;
    } *root[N];
    // Point update: value at position x := min(value, y) in the tree rooted at
    // o, which covers [l, r]. Missing nodes along the path are created.
    void ins(Node *&o, int l, int r, int x, LL y) {
      if (!o) o = new Node;
      smin(o->min, y);
      if (l == r) return;
      // The children must receive this node's pending addition before y is
      // stored in one of them; otherwise y would later get that addition too.
      o->pushdown();
      int m = (l + r) >> 1;
      if (x <= m) {
        ins(o->ls, l, m, x, y);
      } else {
        ins(o->rs, m + 1, r, x, y);
      }
    }
    // Free every node of the subtree rooted at o (post-order); null is a no-op.
    void trash(Node *o) {
      if (!o) return;
      trash(o->ls), trash(o->rs);
      delete o;
    }
    // Remove every position in [x, y] from the tree rooted at o (covering
    // [l, r]) and free the removed nodes. A node left without children is freed
    // and o becomes nullptr, so an empty tree is always represented by nullptr.
    void del(Node *&o, int l, int r, int x, int y) {
      if (!o || x > r || y < l) return;
      if (x <= l && r <= y) {
        trash(o);
        o = nullptr;
        return;
      }
      // Make the children's values exact before pushup() recomputes o->min.
      o->pushdown();
      int m = (l + r) >> 1;
      del(o->ls, l, m, x, y);
      del(o->rs, m + 1, r, x, y);
      if (o->ls || o->rs) {
        o->pushup();
      } else {
        delete o;
        o = nullptr;
      }
    }
    // Min-convolution merge of the trees x and y over [l, r]:
    //   new[d] = min(x[d] + min_{d' >= d} y[d'], y[d] + min_{d' > d} x[d'])
    // xr / yr are the minima of x / y over positions > r (LLONG_MAX if none).
    // Both trees are consumed; returns nullptr when the result is empty.
    Node *merge(Node *x, Node *y, int l, int r, LL xr = LLONG_MAX, LL yr = LLONG_MAX) {
      if (!x) {
        if (!y) return nullptr;
        // x has nothing in [l, r]: every y[d] pairs with the best x to the right.
        if (xr == LLONG_MAX) return trash(y), nullptr;
        y->add(xr);
        return y;
      }
      if (!y) {
        if (yr == LLONG_MAX) return trash(x), nullptr;
        x->add(yr);
        return x;
      }
      if (l == r) {
        x->min += std::min(y->min, yr);
        if (xr < LLONG_MAX) smin(x->min, y->min + xr);
        delete y;
        return x;
      }
      int m = (l + r) >> 1;
      x->pushdown(), y->pushdown();
      // Seen from the left half, the right half also lies at larger positions.
      LL nxr = xr, nyr = yr;
      if (x->rs) smin(nxr, x->rs->min);
      if (y->rs) smin(nyr, y->rs->min);
      x->ls = merge(x->ls, y->ls, l, m, nxr, nyr);
      x->rs = merge(x->rs, y->rs, m + 1, r, xr, yr);
      delete y;
      // Both halves may have been discarded; never keep a childless inner node.
      if (!x->ls && !x->rs) {
        delete x;
        return nullptr;
      }
      x->pushup();
      return x;
    }
    // Minimum over positions [x, y] of the tree rooted at o (covering [l, r]);
    // LLONG_MAX when no stored position lies in the range.
    LL ask(Node *o, int l, int r, int x, int y) {
      if (!o || x > r || y < l) return LLONG_MAX;
      if (x <= l && r <= y) return o->min;
      // The children's stored minima do not yet include this node's tag.
      o->pushdown();
      int m = (l + r) >> 1;
      return std::min(ask(o->ls, l, m, x, y), ask(o->rs, m + 1, r, x, y));
    }
    int dep[N], max_dep;
    // Fill dep[] for the subtree of x (dep[0] == 0, so the root gets depth 1)
    // and track the largest depth in max_dep.
    void dfs0(int x, int fa) {
      dep[x] = dep[fa] + 1;
      smax(max_dep, dep[x]);
      for (int y : g[x]) {
        if (y == fa) continue;
        dfs0(y, x);
      }
    }
    void dfs(int x, int fa) {
      if (x > 1 && g[x].size() == 1) {
        for (auto &v : p[x]) {
          ins(root[x], 1, max_dep, dep[v.first], v.second);
        if (!root[x]) {
          std::cout << "-1
      for (int y : g[x]) {
        if (y == fa) continue;
        dfs(y, x);
        root[x] = root[x] ? merge(root[x], root[y], 1, max_dep) : root[y];
      for (auto &v : p[x]) {
        LL s = ask(root[x], 1, max_dep, dep[v.first], dep[x]);
        if (s < LLONG_MAX) ins(root[x], 1, max_dep, dep[v.first], s + v.second);
      del(root[x], 1, max_dep, dep[x] + (x == 1), max_dep);
      if (!root[x]) {
        std::cout << "-1
    // Reads the tree and the candidate paths, runs the DP and prints the
    // minimum total weight (dfs prints -1 and exits if no cover exists).
    int main() {
      // Input is large (n, m up to 3e5); untie cin from C stdio.
      std::ios::sync_with_stdio(false);
      std::cin.tie(nullptr);
      std::cin >> n >> m;
      if (n == 1) {
        // A single vertex has no edges to cover.
        std::cout << "0\n";
        return 0;
      }
      for (int i = 1, x, y; i < n; i++) {
        std::cin >> x >> y;
        g[x].push_back(y), g[y].push_back(x);
      }
      while (m--) {
        int x, y, z;
        std::cin >> x >> y >> z;
        if (x == y) continue;  // an empty path covers no edge
        p[x].emplace_back(y, z);
      }
      dfs0(1, 0);
      dfs(1, 0);
      std::cout << root[1]->min << "\n";
      return 0;
    }



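    也可以用小根堆存下覆盖 $x$ 子树的各个方案的权值和(每个方案同时记录其最浅端点的深度)。合并两个儿子的堆时,先给两个堆分别整体加上另一个堆的最小值,再合并即可。堆顶方案的深度若不小于 $\mathrm{dep}[x]$,说明它无法覆盖 $x$ 到父亲的边,直接弹出。这里加的是全局最小值而不是后缀最小值:这样得到的方案记录的深度可能比实际更深,只会让它被更早弹出;而最优组合总会在"最浅端点所在一侧"的堆中以正确的深度、不劣的代价出现,所以答案不变。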


    #include <bits/stdc++.h>
    // Set x to y when y is smaller; returns true iff x was changed.
    template <class T, class U>
    inline bool smin(T &x, const U &y) {
      return y < x ? (x = y, true) : false;
    }
    // Set x to y when y is larger; returns true iff x was changed.
    template <class T, class U>
    inline bool smax(T &x, const U &y) {
      return x < y ? (x = y, true) : false;
    }
    typedef long long LL;
    typedef std::pair<int, int> PII;
    // Capacity of the per-vertex arrays (n <= 3e5, plus slack).
    constexpr int N = 300000 + 5;
    // Vertex count.
    int n;
    // Adjacency lists of the tree.
    std::vector<int> g[N];
    // p[x]: (other endpoint y, weight z) of each input path "x y z" with x != y.
    std::vector<PII> p[N];
    struct Node {
      Node *ls, *rs;
      LL val, tag;
      int dep;
      Node(LL v, int d) : ls(nullptr), rs(nullptr), val(v), tag(0), dep(d) {}
      void add(LL x) {
        val += x;
        tag += x;
      void pushdown() {
        if (tag) {
          if (ls) ls->add(tag);
          if (rs) rs->add(tag);
          tag = 0;
    } *root[N];
    // Meld two min-heaps (either may be null) and return the new root. The
    // side to descend into is chosen at random (randomized meldable heap),
    // which keeps the expected recursion depth logarithmic.
    Node *merge(Node *x, Node *y) {
      if (!x) return y;
      if (!y) return x;
      static std::mt19937 rnd(std::chrono::high_resolution_clock::now().time_since_epoch().count());
      if (x->val > y->val) std::swap(x, y);
      x->pushdown(), y->pushdown();
      if (rnd() & 1) {
        x->rs = merge(x->rs, y);
      } else {
        x->ls = merge(x->ls, y);
      }
      return x;
    }
    int dep[N];
    void update(int x) {
      while (root[x] && root[x]->dep >= dep[x] + (x == 1)) {
        root[x] = merge(root[x]->ls, root[x]->rs);
      if (!root[x]) {
    void dfs(int x, int fa) {
      dep[x] = dep[fa] + 1;
      if (x > 1 && g[x].size() == 1) {
        for (auto [f, v] : p[x]) {
          root[x] = merge(root[x], new Node(v, dep[f]));
      for (int y : g[x]) {
        if (y == fa) continue;
        dfs(y, x);
        if (root[x]) {
          LL v = root[x]->val;
        root[x] = merge(root[x], root[y]);
      for (auto [f, v] : p[x]) {
        root[x] = merge(root[x], new Node(v + root[x]->val, dep[f]));
    // Reads the tree and the candidate paths, runs the heap DP and prints the
    // minimum total weight (update() prints -1 and exits if no cover exists).
    int main() {
      // Input is large (n, m up to 3e5); untie cin from C stdio.
      std::ios::sync_with_stdio(false);
      std::cin.tie(nullptr);
      int m;
      std::cin >> n >> m;
      if (n == 1) {
        // A single vertex has no edges to cover.
        std::cout << "0\n";
        return 0;
      }
      for (int i = 1, x, y; i < n; i++) {
        std::cin >> x >> y;
        g[x].push_back(y), g[y].push_back(x);
      }
      while (m--) {
        int x, y, z;
        std::cin >> x >> y >> z;
        if (x == y) continue;  // an empty path covers no edge
        p[x].emplace_back(y, z);
      }
      dfs(1, 0);
      std::cout << root[1]->val << "\n";
      return 0;
    }
