[题解] Crash的数字表格 / JZPTAB

题目描述

今天的数学课上,Crash 小朋友学习了最小公倍数 (Least Common Multiple)。对于两个正整数 $a$ 和 $b$, $\mathrm{lcm} (a, b)$ 表示能同时整除 $a$ 和 $b$ 的最小正整数。例如,$\mathrm{lcm}(6, 8) = 24$。

回到家后,Crash 还在想着课上学的东西,为了研究最小公倍数,他画了一张 $n \times m$的表格。每个格子里写了一个数字,其中第 $i$ 行第 $j$ 列的那个格子里写着数为 $\mathrm{lcm}(i, j)$ 。一个 $4 \times 5$ 的表格如下:

1 2 3 4 5
1 1 2 3 4 5
2 2 2 6 4 10
3 3 6 3 12 15
4 4 4 12 4 20

看着这个表格,Crash想到了很多可以思考的问题。不过他最想解决的问题却是一个十分简单的问题:这个表格中所有数的和是多少。当 $n$ 和 $m$ 很大时,Crash 就束手无策了,因此他找到了聪明的你用程序帮他解决这个问题。由于最终结果可能会很大,Crash 只想知道表格里所有数的和 $\bmod 20101009$ 的值。

输入格式

输入的第一行包含两个正整数,分别表示 $n$ 和 $m$

输出格式

输出一个正整数,表示表格中所有数的和 $\bmod 20101009$ 的值

样例输入

4 5

样例输出

122

数据范围及提示

$30 \%$ 的数据满足 $n, m \leq 10^3$

$70\%$ 的数据满足 $n, m \leq 10^5$

$100 \%$ 的数据满足 $n, m \leq 10^7$

解题思路

$$
\begin{aligned}
& \sum_{i=1}^n \sum_{j=1}^m \frac{ij}{\gcd(i, j)} \\
= & \sum_{i=1}^n\sum_{j=1}^m \sum_{d=1}^{\min(n, m)} \frac{ij}{d} [gcd(i, j) = d] \\
= & \sum_{i=1}^n\sum_{j=1}^m \sum_{d=1}^{\min(n, m)} \frac{ij}{d} [gcd(i, j) = d] \\
= & \sum_{d=1}^{\min(n, m)} d^2 \sum_{i=1}^{\lfloor \frac{n}{d} \rfloor} \sum_{j=1}^{\lfloor \frac{m}{d} \rfloor} \frac{ij}{d} [gcd(i, j) = 1] \\
= & \sum_{d=1}^{\min(n, m)} d \sum_{i=1}^{\lfloor \frac{n}{d} \rfloor} \sum_{j=1}^{\lfloor \frac{m}{d} \rfloor} [gcd(i, j) = 1] ij \\
= & \sum_{d=1}^{\min(n, m)} d \sum_{i=1}^{\lfloor \frac{n}{d} \rfloor} \sum_{j=1}^{\lfloor \frac{m}{d} \rfloor} \sum_{k \mid \gcd(i, j)} \mu(k) ij \\
= & \sum_{d=1}^{\min(n, m)} d \sum_{i=1}^{\lfloor \frac{n}{d} \rfloor} \sum_{j=1}^{\lfloor \frac{m}{d} \rfloor} \sum_{k \mid i} \sum_{k \mid j} \mu(k) ij \\
\end{aligned}
$$

这个式子可以被分成两部分解决,我们令
$$
f(n, m) = \sum_{i=1}^n \sum_{j=1}^m \sum_{k \mid i} \sum_{k \mid j} \mu(k) ij
$$
那么原式就等于
$$
\sum_{d=1}^{\min(n, m)} d \cdot f(\lfloor \frac{n}{d} \rfloor, \lfloor \frac{m}{d} \rfloor)
$$
可以通过整除分块计算,而 $f(n, m)$ 的求解则是经典问题
$$
\begin{aligned}
f(n, m) & = \sum_{i=1}^n \sum_{j=1}^m \sum_{k \mid i} \sum_{k \mid j} \mu(k) ij \\
& = \sum_{k=1}^{\min(n, m)} \mu(k) \sum_{i=1}^n \sum_{i=1}^m ij [k \mid i] [k \mid j] \\
& = \sum_{k=1}^{\min(n, m)} \mu(k) k^2 \sum_{i=1}^{\lfloor \frac{n}{k} \rfloor} \sum_{j=1}^{\lfloor \frac{m}{k} \rfloor} ij
\end{aligned}
$$

$$
\begin{aligned}
g(n, m) & = \sum_{i=1}^n \sum_{j=1} ^m ij \\
&= \frac{(n + 1) \times n}{2} \times \frac{(m + 1) \times m}{2}
\end{aligned}
$$
那么式子可以写作
$$
f(n, m) = \sum_{k=1}^{\min(n, m)} \mu(k) k^2 g(\lfloor \frac{n}{k} \rfloor, \lfloor \frac{m}{k} \rfloor)
$$
可以整除分块计算

代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>

using namespace std;
const int MAXN = 10000000 + 10;
const int MOD = 20101009;

inline int read() {
    int x = 0;
    char ch = getchar();

    while (ch < '0' || ch > '9') {
        ch = getchar();
    }

    while ('0' <= ch && ch <= '9') {
        x = x * 10 + ch - '0';
        ch = getchar();
    }

    return x;
}

int prime[MAXN], mu[MAXN], cnt;
bool isNotPrime[MAXN];

long long sum[MAXN];

inline void init (int n) {
    mu[1] = 1;

    for (register int i=2; i<=n; i++) {
        if (!isNotPrime[i]) {
            prime[++cnt] = i;
            mu[i] = -1;
        }

        for (register int j=1; j<=cnt; j++) {
            int v = i * prime[j];
            if (v > n) break;

            isNotPrime[v] = true;

            if (i % prime[j] == 0) {
                mu[v] = 0;
                break;
            } else {
                mu[v] = -mu[i];
            }
        }
    }

    for (register int i=1; i<=n; i++) {
        sum[i] = (sum[i-1] + ((1LL * i * i) % MOD * (mu[i] + MOD) % MOD)) % MOD;
    }
}

inline long long calc2(int n, int m) {
    return (1LL * n * (n + 1) / 2 % MOD) * (1LL * m * (m + 1) / 2 % MOD) % MOD;
}

inline long long calc(int n, int m) {
    long long ans = 0;
    int r = 0;

    for (register int l=1; l<=min(n, m); l = r + 1) {
        r = min(n / (n / l), m / (m / l));
        ans = (ans + (sum[r] - sum[l-1] + MOD) % MOD * calc2(n / l, m / l) % MOD) % MOD;
    }

    return ans;
}

inline long long solve(int n, int m) {
    long long ans = 0;
    int r = 0;

    for (register int l=1; l<=min(n, m); l = r + 1) {
        r = min(n/(n/l), m/(m/l));
        ans = (ans + ((1LL * (l + r) * (r - l + 1) / 2)) % MOD * calc(n / l, m / l)) % MOD;
    }

    return ans;
}

int main(){
    int n = read();
    int m = read();

    init(max(n, m));

    printf("%lld\n", solve(n, m));
    return 0;
}