• [模板] 多项式全家桶


    包含: 多项式乘法, 多项式求逆, 多项式 ln, 多项式 exp, 多项式快速幂.

    exp 的 $O(n \log n)$ 做法常数太大, 实际表现还不如 $O(n \log^2 n)$ 的做法 (当然也有可能是我写丑了), 所以这里放的是 $O(n \log^2 n)$ 的做法.

    学习笔记什么的会找时间补上 (尽量不咕).

    #include <cctype>
    #include <cstdio>
    #include <cstring>
    #include <iostream>
    
    using namespace std;
    
    using ll = long long;  // wide type for products of two residues

    constexpr int _ = (1 << 18) + 7;  // capacity of every coefficient array
    constexpr int mod = 998244353;    // NTT-friendly prime: 119 * 2^23 + 1
    constexpr int rt = 3;             // generator from which roots of unity are derived

    // n: number of coefficients handled (read by Init and Pow).
    // K1, K2: exponent reduced mod `mod` and mod `mod - 1` (presumably filled via Gi).
    // f, g: global coefficient buffers (presumably input / result -- confirm against main).
    // flag: set by Gi when the exponent is >= mod.
    int n, K1, K2;
    int g[_], f[_];
    bool flag;
    
    int Pw(int a, int p) {
      int res = 1;
      while (p) {
        if (p & 1) res = (ll)res * a % mod;
        a = (ll)a * a % mod;
        p >>= 1;
      }
      return res;
    }
    
    // Formal-power-series routines over Z_mod.
    // A polynomial is a plain int array of coefficients, lowest degree first.
    // Every transform length must be a power of two no larger than `tot`, which
    // Init() derives from the global n. Several routines keep large scratch
    // arrays (fixed-size or variable-length) on the stack.
    namespace POLY {
      // tot : largest supported NTT length (smallest power of two > 2n).
      // num : bit-reversal permutation, rebuilt by every NTT call.
      // pwrt: pwrt[0][len] is a len-th root of unity, pwrt[1][len] its inverse,
      //       for every power of two len <= tot (primitive as long as rt is a
      //       primitive root of mod).
      // inv : modular inverses of 1..tot.
      // tmp : spare scratch buffers (not used by the routines below).
      int tot, num[_], pwrt[2][_], inv[_], tmp[5][_];

      // Zero the first L coefficients of f.
      void Clear(int *f, int L) { memset(f, 0, sizeof(int) * L); }
      // Copy the first L coefficients of f into h.
      void Cpy(int *h, int *f, int L) { memcpy(h, f, sizeof(int) * L); }

      // Precompute tot, the inverse table and the roots of unity.
      // Must run after the global n is set and before any other routine.
      void Init() {
        tot = 1; while (tot <= n + n) tot <<= 1;
        inv[1] = 1;
        // Linear-time inverses: i^{-1} = -(mod / i) * (mod % i)^{-1}.
        for (int i = 2; i <= tot; ++i) inv[i] = (ll)inv[mod % i] * (mod - mod / i) % mod;
        pwrt[0][tot] = Pw(rt, (mod - 1) / tot);
        pwrt[1][tot] = Pw(pwrt[0][tot], mod - 2);
        // A len-th root of unity is the square of a 2len-th one.
        for (int len = (tot >> 1); len; len >>= 1) {
          pwrt[0][len] = (ll)pwrt[0][len << 1] * pwrt[0][len << 1] % mod;
          pwrt[1][len] = (ll)pwrt[1][len << 1] * pwrt[1][len << 1] % mod;
        }
      }

      // In-place iterative NTT of length t (a power of two, t <= tot).
      // ty = 0: forward transform; ty = 1: inverse transform including the 1/t factor.
      void NTT(int *f, int t, bool ty) {
        // Bit-reversal permutation; num[0] is never written and stays 0.
        for (int i = 1; i < t; ++i) {
          num[i] = (num[i >> 1] >> 1) | ((i & 1) ? t >> 1 : 0);
          if (i < num[i]) swap(f[i], f[num[i]]);
        }
        for (int len = 2; len <= t; len <<= 1) {
          int gap = len >> 1, w1 = pwrt[ty][len];
          for (int i = 0, w = 1, x; i < t; i += len, w = 1)
            for (int j = i; j < i + gap; ++j) {
              x = (ll)w * f[j + gap] % mod;
              f[j + gap] = (f[j] - x + mod) % mod;
              f[j] = (f[j] + x) % mod;
              w = (ll)w * w1 % mod;
            }
        }
        if (ty) for (int i = 0; i < t; ++i) f[i] = (ll)f[i] * inv[t] % mod;
      }

      // Power-series inverse h = f^{-1} mod x^L via Newton iteration
      // h <- h (2 - f h), doubling the precision each round.
      // Requires f[0] != 0 and L a power of two with 2L <= tot. h must have
      // room for 2L coefficients; on return h[0..L) is the inverse and
      // h[L..2L) is zero. Only f[0..L) is read.
      void Inv(int *h, int *f, int L) {
        int a[_];  // a = f mod x^len, zero-padded to the current transform length
        // Round `len` transforms length-2len buffers (2L in the last round) and
        // needs h[len..2len) and a[len..2len) to be zero, so clear the whole
        // working range up front; `a` is an uninitialized local.
        Clear(h, L << 1), Clear(a, L << 1);
        h[0] = Pw(f[0], mod - 2), a[0] = f[0];
        if (L > 1) a[1] = f[1];
        for (int len = 2, t = 4; len <= L; len <<= 1, t <<= 1) {
          NTT(h, t, 0), NTT(a, t, 0);
          for (int i = 0; i < t; ++i) h[i] = (ll)h[i] * (2 - (ll)a[i] * h[i] % mod + mod) % mod;
          NTT(h, t, 1), NTT(a, t, 1);
          // Truncate h to x^len; extend a to f mod x^{2len} for the next round.
          for (int i = len; i < t; ++i) h[i] = 0;
          if (len < L) for (int i = len; i < t; ++i) a[i] = f[i];
        }
      }

      // h[0..L-1) = coefficients of f' (f given by L coefficients); h[L-1] is
      // left untouched. Safe with h == f.
      void Deriv(int *h, int *f, int L) { for (int i = 0; i < L - 1; ++i) h[i] = (ll)f[i + 1] * (i + 1) % mod; }
      // h[0..L) = integral of f with zero constant term (reads f[0..L-1)).
      // Safe with h == f because it walks downwards.
      void Integ(int *h, int *f, int L) { for (int i = L - 1; i; --i) h[i] = (ll)f[i - 1] * inv[i] % mod; h[0] = 0; }

      // Power-series logarithm h = ln f mod x^L = integral(f' * f^{-1}).
      // Intended for f[0] == 1 (Pow normalizes its input that way). L must be
      // a power of two with 2L <= tot; h needs room for 2L coefficients. On
      // return h[0..L) holds the result and h[L..2L) is zero.
      void Ln(int *h, int *f, int L) {
        Clear(h, L << 1);
        int a[(L << 1) + 7], b[(L << 1) + 7];
        Clear(a, L << 1), Clear(b, L << 1);
        Deriv(a, f, L), Inv(b, f, L);
        // deg(f') + deg(f^{-1}) < 2L, so a length-2L cyclic product is exact.
        NTT(a, L << 1, 0), NTT(b, L << 1, 0);
        for (int i = 0; i < (L << 1); ++i) h[i] = (ll)a[i] * b[i] % mod;
        NTT(h, L << 1, 1);
        Integ(h, h, L);
        Clear(h + L, L);  // drop the product's high half; only ln f mod x^L is meaningful
      }

      // Divide-and-conquer step for f = exp(A) given g = A'.
      // From f' = A' f:  n * f_n = sum_{i<n} f_i * g_{n-1-i}.
      // f points at a segment of length t (a power of two) whose first global
      // coefficient index is l; r is its end. The left half is solved first,
      // then its contribution to the right half is added through one cyclic
      // NTT of length t (wrap-around only reaches indices below t/2 - 1, which
      // are never read), then the right half is solved. A leaf divides its
      // accumulated sum by its global index l. The top-level call passes l = 1
      // for the segment starting at index 0: the leftmost leaf then divides
      // f_0 by inv[1] = 1, and every right half still gets the correct start
      // index (l + r) / 2.
      void dcExp(int *f, int *g, int t, int l, int r) {
        if (t == 1) { f[0] = l ? (ll)f[0] * inv[l] % mod : f[0]; return; }
        dcExp(f, g, t >> 1, l, (l + r) >> 1);
        int a[t + 7], b[t + 7];
        Clear(a, t), Clear(b, t);
        Cpy(a, f, t >> 1), Cpy(b, g, t);
        NTT(a, t, 0), NTT(b, t, 0);
        for (int i = 0; i < t; ++i) a[i] = (ll)a[i] * b[i] % mod;
        NTT(a, t, 1);
        for (int i = (t >> 1); i < t; ++i) f[i] = (f[i] + a[i - 1]) % mod;
        dcExp(f + (t >> 1), g, t >> 1, (l + r) >> 1, r);
      }

      // Power-series exponential f = exp(g) mod x^L, in O(L log^2 L).
      // g[0] is ignored (treated as 0). L must be a power of two with L <= tot.
      // NOTE: g is overwritten with its derivative.
      void Exp(int *f, int *g, int L) {
        Deriv(g, g, L);
        // dcExp accumulates into f, so it must start from exactly 1 + 0x + ...
        Clear(f, L), f[0] = 1;
        dcExp(f, g, L, 1, L);
      }

      // h = f^K mod x^n, where K1 = K mod `mod`, K2 = K mod (mod - 1) and the
      // global `flag` is set iff K >= mod (as produced by Gi).
      // L must be a power of two with n <= L and 2L <= tot (presumably tot / 2;
      // confirm against the caller). f is normalized in place, f[n..L + n)
      // must be zero, and h needs room for L coefficients.
      void Pow(int *h, int *f, int K1, int K2, int L) {
        int st = 0; while (st < n and !f[st]) ++st;
        // f = c * x^st * (1 + ...), so f^K = c^K * x^(st*K) * exp(K * ln(f / (c x^st))).
        // If st*K >= n nothing survives below x^n. When K >= mod, K1 is only
        // K mod `mod`, but any st >= 1 already gives st*K >= mod > n.
        if ((flag and st) || (ll)st * K1 >= (ll)n) { Clear(h, n); return; }
        int lead = f[st], lead_inv = Pw(lead, mod - 2);
        for (int i = 0; i < n; ++i) f[i] = (ll)f[i + st] * lead_inv % mod;
        int a[(L << 1) + 7]; Clear(a, L << 1);
        Ln(a, f, L);
        for (int i = 0; i < n; ++i) a[i] = (ll)a[i] * K1 % mod;
        Exp(h, a, L);
        // Multiply back c^K (Fermat: exponent mod mod-1) and shift by x^(st*K).
        st *= K1;
        int scale = Pw(lead, K2);
        for (int i = n - 1; i >= st; --i) h[i] = (ll)h[i - st] * scale % mod;
        for (int i = 0; i < st; ++i) h[i] = 0;
      }
    }
    
    // Reads the exponent K from stdin. K is a non-negative decimal integer that
    // may be far too large for any built-in type.
    //   K1 <- K mod `mod`      (multiplier applied to ln f in POLY::Pow)
    //   K2 <- K mod (mod - 1)  (exponent for the leading coefficient, via Fermat)
    // Sets the global `flag` when K >= mod, i.e. when K1 no longer equals K.
    // Leading non-digit characters are skipped. Reaching EOF before any digit
    // yields K1 = K2 = 0 instead of looping forever.
    void Gi(int &K1, int &K2) {
      ll t1 = 0, t2 = 0;
      // int, not char: getchar() signals end of input with EOF, and isdigit()
      // is only defined for EOF and values representable as unsigned char.
      int c = getchar();
      while (c != EOF && !isdigit(c)) c = getchar();
      while (c != EOF && isdigit(c)) {
        t1 = t1 * 10 + (c - '0');
        t2 = t2 * 10 + (c - '0');
        if (t1 >= mod) flag = 1, t1 %= mod;
        t2 %= (mod - 1);
        c = getchar();
      }
      K1 = t1, K2 = t2;
    }
    
  • 相关阅读:
    [LeetCode] Sqrt(x)
    [LeetCode] Rotate Array
    【经典算法】贪心算法
    【经典算法——查找】二分查找
    ARP(Adress Resolution Protocol): 地址解析协议
    【经典算法】分治策略
    [LeetCode] Recover Binary Search Tree
    [LeetCode] Convert Sorted Array to Binary Search Tree
    python数据类型之字典(dict)和其常用方法
    简单了解hash
  • 原文地址:https://www.cnblogs.com/BruceW/p/13892271.html
Copyright © 2020-2023  润新知