• [ BZOJ 3451 ] Normal


    Description

    题目链接

    定义一次点分治的复杂度是所有分治中心分治时的子树大小之和。

    给定一棵树,问所有点等概率被选做重心,点分治的期望复杂度。

    Solution

    根据期望的线性性,答案等价于每个点在点分树上的深度期望之和。

    思路是从点对的角度考虑某一个点是否会产生贡献。

    [E(depth[x])=sum_{y=1}^n P(xin subtree[y]) ]

    也就是 (x) 在点分树上在 (1dots n) 的子树中的概率和。

    考虑点分树上 (y)(x) 的祖先的条件,要求 (x)(y) 构成的这条链上第一个在点分治过程中被删除的点是 (y) ,由于链上被选中的概率相等,因此这个概率为 (frac{1}{dist(x,y) + 1})

    所以答案为

    [sum_{x=1}^nsum_{j=1}^n frac{1}{dis(i,j) + 1}=sum_{len = 0}^n frac{cnt[i]}{i + 1} ]

    因此需要点分治求长度为 (i) 的路径条数 (cnt[i]) ,注意到合并的时候是卷积的形式。

    容斥做法

    不考虑重复路径,把子树 dfs 一遍,直接自己进行卷积,再去掉子树内重复计数的路径即可。

    每一层最差以自己的 (size) 作为长度进行卷积,因此复杂度为 (mathcal O(nlog^2 n))

    #include <cmath>
    #include <cstdio>
    #include <cctype>
    #include <cstdlib>
    #include <cstring>
    #include <iostream>
    #include <algorithm>
    #define N 65537
    #define mod 998244353
    using namespace std;
    typedef long long ll;
     
    inline int rd() {
      int x = 0;
      char c = getchar();
      while (!isdigit(c)) c = getchar();
      while (isdigit(c)) {
        x = x * 10 + (c ^ 48); c = getchar();
      }
      return x;
    }
     
    inline void print(ll x) {
      int y = 10, len = 1;
      while(y <= x) {y *= 10; ++len;}
      while(len--) {y /= 10; putchar(x / y + 48); x %= y;}
      putchar('
    ');
    }
     
    inline int fpow(int x, int t = mod - 2) {
      int res = 1;
      while (t) {
        if (t & 1) res = 1ll * res * x % mod;
        x = 1ll * x * x % mod; t >>= 1;
      }
      return res;
    }
     
    int mxlen = (1 << 16), w[2][N], rev[N];
     
    inline int mo(int x) {
      return x >= mod ? x - mod : x;
    }
     
    inline void init() {
      int per = fpow(3, (mod - 1) / mxlen);
      int invper = fpow(per);
      w[0][0] = w[1][0] = 1;
      for (int i = 1; i < mxlen; ++i) {
        w[0][i] = 1ll * w[0][i - 1] * per % mod;
        w[1][i] = 1ll * w[1][i - 1] * invper % mod;
      }
    }
     
    inline int Rev(int n) {
      int len = 1, bit = 0;
      while (len <= n) len <<= 1, ++bit;
      for (int i = 0; i < len; ++i)
        rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (bit - 1)));
      return len;
    }
     
    inline void NTT(int *f, int len, int o) {
      for (int i = 0; i < len; ++i)
        if (i > rev[i]) swap(f[i], f[rev[i]]);
      for (int i = 1; i < len; i <<= 1) {
        int wn = mxlen / (i << 1);
        for (int j = 0; j < len; j += (i << 1)) {
          int nw = 0, x, y;
          for (int k = 0; k < i; ++k, nw += wn) {
            x = f[j + k];
            y = 1ll * w[o][nw] * f[i + j + k] % mod;
            f[j + k] = mo(x + y);
            f[i + j + k] = mo(x - y + mod);
          }
        }
      }
      if (o == 1) {
        int invl = fpow(len);
        for (int i = 0; i < len; ++i) f[i] = 1ll * f[i] * invl % mod;
      }
    }
     
    bool vis[N];
     
    int n, m, tot, totn, mx, rt, mxd;
     
    int bkt[N], cnt[N], sz[N], hd[N];
     
    struct edge{int to, nxt;} e[N << 1];
     
    inline void add(int u, int v) {
      e[++tot].to = v; e[tot].nxt = hd[u]; hd[u] = tot;
      e[++tot].to = u; e[tot].nxt = hd[v]; hd[v] = tot;
    }
     
    void getrt(int u, int fa) {
      sz[u] = 1;
      int mxs = 0;
      for (int i = hd[u], v; i; i = e[i].nxt)
        if ((v = e[i].to) != fa && !vis[v]) {
          getrt(v, u);
          sz[u] += sz[v];
          mxs = max(mxs, sz[v]);
        }
      mxs = max(mxs, totn - sz[u]);
      if (mxs < mx) {mx = mxs; rt = u;}
    }
     
    void getsz(int u, int fa) {
      sz[u] =  1;
      for (int i = hd[u], v; i; i = e[i].nxt)
        if ((v = e[i].to) != fa && !vis[v]) {
          getsz(v, u); sz[u] += sz[v];
        }
    }
     
    void dfs(int u, int fa, int dep) {
      ++bkt[dep]; mxd = max(mxd, dep);
      for (int i = hd[u], v; i; i = e[i].nxt)
        if ((v = e[i].to) != fa && !vis[v]) dfs(v, u, dep + 1);
    }
     
    inline void mul(int *a, int len, int o) {
      len = Rev(len << 1);
      NTT(a, len, 0);
      for (int i = 0; i < len; ++i) a[i] = 1ll * a[i] * a[i] % mod;
      NTT(a, len, 1);
      if (o > 0) for (int i = 0; i < len; ++i) cnt[i + 1] += a[i];
      else for (int i = 0; i < len; ++i) cnt[i + 3] -= a[i];
      for (int i = 0; i < len; ++i) a[i] = 0;
    }
     
    inline void calc(int u, int o) {
      mxd = 0;
      dfs(u, 0, 0);
      mul(bkt, mxd, o);
    }
     
    void divide(int u) {
      vis[u] = 1;
      calc(u, 1);
      for (int i = hd[u], v; i; i = e[i].nxt)
        if (!vis[v = e[i].to]) {
          calc(v, -1);
          getsz(v, u);
          totn = mx = sz[v]; rt = v;
          getrt(v, 0); divide(rt);
        }
    }
     
    int main() {
      init();
      n = rd();
      for (int i = 1; i < n; ++i) add(rd() + 1, rd() + 1);
      mx = totn = n;
      getrt(1, 0); divide(rt);
      double ans = 0.0;
      for (int i = 1; i <= n + 1; ++i) ans += (double) cnt[i] / i;
      printf("%.4lf", ans);
      return 0;
    }
    

    子树按秩合并做法

    在点分治求路径条数时,我们尝试用按秩合并的思路去搞,也就是将子树按照最深深度排序,然后逐个合并计算答案。

    开始的时候只有 (bkt[0]=1),然后按顺序卷每一个子树求出来的计数数组 (bktson)

    把贡献直接计算,然后再将 (bktson) 按位加到 (bkt) 上。

    考虑复杂度,将子树按照深度从小到大排序后,每次卷积得到的新的链长不会超过新合并的子树深度的二倍,所以每次卷积的数组长度为 (mxdep[v]) 的,且每个位置只会和其父节点卷积一次,因此总复杂度为 (mathcal O(nlog^2 n))

    #include <cmath>
    #include <cstdio>
    #include <cctype>
    #include <cstdlib>
    #include <cstring>
    #include <iostream>
    #include <algorithm>
    #define N 65537
    #define mod 998244353
    using namespace std;
    typedef long long ll;
     
    inline int rd() {
      int x = 0;
      char c = getchar();
      while (!isdigit(c)) c = getchar();
      while (isdigit(c)) {
        x = x * 10 + (c ^ 48); c = getchar();
      }
      return x;
    }
     
    inline void print(ll x) {
      int y = 10, len = 1;
      while(y <= x) {y *= 10; ++len;}
      while(len--) {y /= 10; putchar(x / y + 48); x %= y;}
      putchar('
    ');
    }
     
    inline int fpow(int x, int t = mod - 2) {
      int res = 1;
      while (t) {
        if (t & 1) res = 1ll * res * x % mod;
        x = 1ll * x * x % mod; t >>= 1;
      }
      return res;
    }
     
    int mxlen = (1 << 16), w[2][N], rev[N];
     
    inline int mo(int x) {
      return x >= mod ? x - mod : x;
    }
     
    inline void init() {
      int per = fpow(3, (mod - 1) / mxlen);
      int invper = fpow(per);
      w[0][0] = w[1][0] = 1;
      for (int i = 1; i < mxlen; ++i) {
        w[0][i] = 1ll * w[0][i - 1] * per % mod;
        w[1][i] = 1ll * w[1][i - 1] * invper % mod;
      }
    }
     
    inline int Rev(int n) {
      int len = 1, bit = 0;
      while (len <= n) len <<= 1, ++bit;
      for (int i = 0; i < len; ++i)
        rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (bit - 1)));
      return len;
    }
     
    inline void NTT(int *f, int len, int o) {
      for (int i = 0; i < len; ++i)
        if (i > rev[i]) swap(f[i], f[rev[i]]);
      for (int i = 1; i < len; i <<= 1) {
        int wn = mxlen / (i << 1);
        for (int j = 0; j < len; j += (i << 1)) {
          int nw = 0, x, y;
          for (int k = 0; k < i; ++k, nw += wn) {
            x = f[j + k];
            y = 1ll * w[o][nw] * f[i + j + k] % mod;
            f[j + k] = mo(x + y);
            f[i + j + k] = mo(x - y + mod);
          }
        }
      }
      if (o == 1) {
        int invl = fpow(len);
        for (int i = 0; i < len; ++i) f[i] = 1ll * f[i] * invl % mod;
      }
    }
     
    bool vis[N];
     
    double ans = 0.0;
     
    int n, m, tot, totn, mx, rt;
     
    int bkt[N], sz[N], hd[N];
     
    struct edge{int to, nxt;} e[N << 1];
     
    inline void add(int u, int v) {
      e[++tot].to = v; e[tot].nxt = hd[u]; hd[u] = tot;
      e[++tot].to = u; e[tot].nxt = hd[v]; hd[v] = tot;
    }
     
    void getrt(int u, int fa) {
      sz[u] = 1;
      int mxs = 0;
      for (int i = hd[u], v; i; i = e[i].nxt)
        if ((v = e[i].to) != fa && !vis[v]) {
          getrt(v, u);
          sz[u] += sz[v];
          mxs = max(mxs, sz[v]);
        }
      mxs = max(mxs, totn - sz[u]);
      if (mxs < mx) {mx = mxs; rt = u;}
    }
     
    void getsz(int u, int fa) {
      sz[u] =  1;
      for (int i = hd[u], v; i; i = e[i].nxt)
        if ((v = e[i].to) != fa && !vis[v]) {
          getsz(v, u); sz[u] += sz[v];
        }
    }
     
    int res[N], tmp[N];
     
    inline int mul(int *a, int *b, int lena, int lenb) {
      int len = Rev(lenb << 1);
      for (int i = 0; i < lena; ++i) res[i] = a[i];
      for (int i = lena; i < len; ++i) res[i] = 0;
      for (int i = 0; i < lenb; ++i) tmp[i] = b[i];
      for (int i = lenb; i < len; ++i) tmp[i] = 0;
      NTT(res, len, 0); NTT(tmp, len, 0);
      for (int i = 0; i < len; ++i) res[i] = 1ll * res[i] * tmp[i] % mod;
      NTT(res, len, 1);
      for (int i = 0; i < len; ++i) ans += 2.0 * res[i] / (i + 1);
      return len;
    }
     
    int mxd[N], s[N], bkts[N];
     
    inline bool cmp(int x, int y) {return mxd[x] < mxd[y];}
     
    int dfs(int u, int fa, int dep) {
      int resd = dep;
      for (int i = hd[u], v; i; i = e[i].nxt)
        if ((v = e[i].to) != fa && !vis[v]) resd = max(resd, dfs(v, u, dep + 1));
      return resd;
    }
     
    void dfs2(int u, int fa, int dep) {
      ++bkts[dep];
      for (int i = hd[u], v; i; i = e[i].nxt)
        if ((v = e[i].to) != fa && !vis[v]) dfs2(v, u, dep + 1);
    }
     
    void divide(int u) {
      vis[u] = 1;
      s[0] = 0;
      for (int i = hd[u], v; i; i = e[i].nxt)
        if (!vis[v = e[i].to]) {
          s[++s[0]] = v;
          mxd[v] = dfs(v, u, 1);
        }
      sort(s + 1, s + 1 + s[0], cmp);
      bkt[0] = 1;
      int nowlen = 1;
      for (int i = 1, v; i <= s[0]; ++i) {
        dfs2(v = s[i], 0, 1);
        nowlen = mul(bkt, bkts, nowlen, mxd[v] + 1);
        for (int i = 0; i <= mxd[v]; ++i) {
          bkt[i] += bkts[i]; bkts[i] = 0;
        }
      }
      for (int i = 0; i <= nowlen; ++i) bkt[i] = 0;
      for (int i = hd[u], v; i; i = e[i].nxt)
        if (!vis[v = e[i].to]) {
          getsz(v, u);
          totn = mx = sz[v]; rt = v;
          getrt(v, 0); divide(rt);
        }
    }
     
    int main() {
      init();
      n = rd();
      for (int i = 1; i < n; ++i) add(rd() + 1, rd() + 1);
      mx = totn = n;
      getrt(1, 0); divide(rt);
      printf("%.4lf", ans + n);
      return 0;
    }
    
  • 相关阅读:
    harbor 报错,注意harbor.yml文件格式。
    nginx中rewrite文件
    改善github网络连接
    代码层实现六种质量属性战术
    读漫谈架构有感
    “淘宝网”的六个质量属性的六个常见属性
    寒假学习进度16:tensorflow2.0 波士顿房价预测
    寒假学习进度15:tensorflow2.0 优化器
    寒假学习进度14:对mnist数据集实现逻辑回归
    寒假学习进度13:tensorflow2.0 mnist简介
  • 原文地址:https://www.cnblogs.com/SGCollin/p/10597925.html
Copyright © 2020-2023  润新知