[模板] 快速数论变换 (NTT)

简介

在之前我们已经介绍了快速傅里叶变换,利用单位复根,我们可以在 $O(n \ log\ n)$的时间内进行 $DFT$ 和 $IDFT$ 变换。

但是由于不可避免的精度问题,单位复根存在一定的局限性,因此对于正整数的卷积运算,我们通常使用快速数论变换来避免精度问题。

原根

原根的定义

对于一个正整数 $p$,若存在一个数 $g$ 满足 $(p, g) = 1$,并且 $\delta p(g)$ = $\varphi(p)$,其中 $\delta p(g)$ 为使 $g^d \equiv 1 \ (mod \ p)$ 的最小正整数 $d$,称为为 $g$ 是模 $p$ 的原根。

由欧拉定理可知,$\delta p(g)$ 一定小于等于 $\varphi(p)$

原根的计算

我们可以通过枚举一个数 $g$,并检验 $g$ 是不是 $p$ 的原根

在 $NTT$ 中,我们的模数 $p$ 需要是一个质数,因此 $\varphi(p) = p-1$,根据定义,$ \delta p (g) = \varphi(p) = p-1 $,因此对于原根 $g$

$$ \forall_{i \in [1, p-2]} g^i \not \equiv 1 $$

这样的复杂度是 $O(p)$ 的,我们考虑如何优化

对于一个数 $g$,$\delta p(g)$ 一定是 $p-1$ 的约数

假设存在最小的 $d$ 不是 $p-1$ 的约数,那么可以找到一个 $x$ 满足 $xd > p-1 > (x-1)d$

$$g^{d^x} \equiv g^{dx} \equiv g^{p-1} \equiv 1 (mod \ p)$$

那么就存在一个更小的 $dx-(p-1)$ 数 $d’$ ,与假设相反。

那么我们只需要枚举 $p-1$ 中除 $p-1$ 的所有约数 $q$

$$\forall g^q \not \equiv 1 (mod \ p)$$

实际上,我们还可以优化,将 $p-1$ 分解质因数

$$p – 1 = \prod_{i=1}^r p_{i}^{k_i} $$

我们只需要判断

$$\forall_{i \in [1, r]} g^{\frac{p-1}{p_i}} \not \equiv 1 (mod \ p)$$

因为对于一个更小的约数 $q$,如果它已经同余 $1$ 了,因为一定存在至少一个 $p_i$ 满足 $q \mid \frac{p-1}{p_i}$,那么它一定能通过一个幂运算使得 $g^{\frac{p-1}{p_i}}$ 也同余 $1$。

inline int root(int x){
    for(register int i=2; i<=x; i++){
        int tmp = x - 1;

        bool flag = true;

        for(register int k=2; k * k <= (x - 1); k++){
            if(tmp % k == 0){
                if(powx(i, (x - 1) / k, x) == 1){
                    flag = false;
                    break;
                }

                while(tmp % k == 0)
                    tmp /= k;
            }
        }

        if(flag && (tmp == 1 || powx(i, (x - 1) / tmp, x) != 1)){
            return i;
        }
    }
}

原根的性质

我们还需要原根拥有和单位复根一样的性质来进行 $DFT$ 和 $IDFT$ 变换。

令 $p = qn + 1$ 其中 $n$ 是 $2$ 的幂

性质一

令 $\omega_n = g^q$,那么 $1, g^q, g^{2q}, \cdots, g^{(n-1)q} $互不相同,满足单位复根的性质一

性质二

由 $\omega_n = p^q$,那么 $\omega_{2n} = p^{\frac{q}{2}} (p = \frac{q}{2} \times 2n + 1)$,所以 $\omega_{2n}^{2k} = \omega_n^k $,满足单位复根的性质二

性质三

因为

$$\omega_{n}^{n} \equiv g^{p-1} \equiv 1 (mod \ p)$$

所以

$$\omega_{n}^{\frac{n}{2}} \equiv \pm 1 (mod \ p)$$

又因为

$$\omega_{n}^{\frac{n}{2}} \not \equiv \omega_{n}^{0} (mod \ p) $$

所以

$$\omega_{n}^{\frac{n}{2}} \equiv – 1 (mod \ p)$$


$$\omega_{n}^{k + \frac{n}{2}} \equiv – \omega_{n}^{k} (mod \ p)$$

满足单位复根的性质三

性质四

$$
\begin{aligned}
S(\omega_n^k) & = 1 + \omega_n^k + (\omega_n^k)^2 + \cdots + (\omega_n^k)^{n-1} \\
& = \frac{1-(\omega_n^k)^n}{1-\omega_n^k} \\
& = \frac{(\omega_n^k)^n – 1}{\omega_n^k – 1} \\
\end{aligned}
$$

由性质三

$${(\omega_{n}^{k})}^{n} – 1 \equiv 0 (mod \ p) $$

那么 $S(\omega_n^k) = 0$,满足单位复根的性质四

快速数论变换

接下来的操作和快速傅里叶变换一样,只是每一步的操作都需要取模,在最后输出结果时,由除以 $n$ 变为乘以 $n^{-1}$

$NTT$ 常用的模数为 $998244353$ 和 $1004535809$ ,它们的原根为 $3$

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>

using namespace std;
const int MOD = 998244353;
const int MAXN = 4194304 + 10;

inline long long powx(long long a, long long b, long long mod){
    long long ans = 1;

    for(; b; b >>= 1){
        if(b & 1) ans *= a, ans %= mod;
        a *= a, a %= mod;
    }

    return ans;
}

inline long long exgcd(long long a, long long b, long long &x, long long &y){
    if(!b){
        x = 1, y = 0;
        return a;
    }

    long long d = exgcd(b, a%b, y, x);
    y -= x * (a / b);

    return d;
}

inline long long inv(long long a, long long mod){
    long long x, y;
    exgcd(a, mod, x, y);
    return (x + mod) % mod;
}

inline int root(int x){
    for(register int i=2; i<=x; i++){
        int tmp = x - 1;

        bool flag = true;

        for(register int k=2; k * k <= (x - 1); k++){
            if(tmp % k == 0){
                if(powx(i, (x - 1) / k, x) == 1){
                    flag = false;
                    break;
                }

                while(tmp % k == 0)
                    tmp /= k;
            }
        }

        if(flag && (tmp == 1 || powx(i, (x - 1) / tmp, x) != 1)){
            return i;
        }
    }
    throw;
}

namespace NTT{
    int place[MAXN];
    long long omega[MAXN];
    long long omegaInverse[MAXN];

    inline void init(int n){
        int k = 0;

        while((1 << k) < n)
            k++;

        for(register int i=0; i<n; i++){
            for(register int j=0; j<k; j++){
                if(i & (1 << j)){
                    place[i] |= 1 << (k - j - 1);
                }
            }
        }

        long long g = root(MOD);
        long long tmp = powx(g, (MOD - 1) / n, MOD);

        for(register int i=0; i<n; i++){
            omega[i] = (i == 0) ? 1 : omega[i - 1] * tmp % MOD;
            omegaInverse[i] = inv(omega[i], MOD);
        }
    }

    inline void transform(long long *a, int n, long long *omega){
        for(register int i=0; i<n; i++){
            if(i < place[i]){
                swap(a[i], a[place[i]]);
            }
        }

        for(register int range=2; range <= n; range <<= 1){
            int mid = range >> 1;

            for(register long long *p = a; p != a + n; p += range){
                int k = n / range;

                for(register int i=0; i<mid; i++){
                    int tmp = i + mid;

                    long long t = omega[k * i] * p[i + mid] % MOD;

                    p[tmp] = (p[i] - t + MOD) % MOD;
                    p[i] = (p[i] + t) % MOD;
                }
            }
        }
    }

    inline void dft(long long *a, int n){
        transform(a, n, omega);
    }

    inline void idft(long long *a, int n){
        transform(a, n, omegaInverse);

        long long tmp = inv(n, MOD);

        for(register int i=0; i<n; i++){
            a[i] = a[i] * tmp % MOD;
        }
    }
}

inline int read(){
    int x = 0;
    int p = 1;
    char ch = getchar();

    while(ch < '0' || ch > '9'){
        if(ch == '-')
            p = 0;
        ch = getchar();
    }

    while('0' <= ch && ch <= '9'){
        x = x*10 + ch - '0';
        ch = getchar();
    }

    return p ? x : (-x);
}

int n, m, tot;

long long a[MAXN], b[MAXN];

int main(){
    n = read();
    m = read();

    n++;
    m++;

    for(register int i=0; i<n; i++)
        a[i] = read();

    for(register int i=0; i<m; i++){
        b[i] = read();
    }

    tot = 1;
    while(tot < n + m){
        tot <<= 1;
    }

    NTT::init(tot);

    NTT::dft(a, tot);
    NTT::dft(b, tot);

    for(register int i=0; i<tot; i++){
        a[i] = a[i] * b[i] % MOD;
    }

    NTT::idft(a, tot);

    for(register int i=0; i<n+m-1; i++){
        printf("%lld ", a[i]);
    }
    return 0;
}

参考资料