题目描述
输入一个整数 $n$ 和一个整数 $p$ ,你需要求出 $\sum_{i=1}^n \sum_{j=1}^n ij \gcd(i,j) \bmod p$,其中 $\gcd(a, b)$ 表示 $a$ 与 $b$ 的最大公约数
输入格式
一行两个整数 $p$ 和 $n$
输出格式
一行一个整数,表示答案
样例输入
998244353 2000
样例输出
883968974
数据范围及提示
对于 $20\%$的数据,$n \leq 1000$。
对于 $30\%$ 的数据,$n \leq 5000$。
对于 $60\%$ 的数据,$n \leq 10^6$,时限1s。
对于另外 $20\%$ 的数据,$n \leq 10^9$,时限3s。
对于最后 $20\%$ 的数据,$n \leq 10^{10}$,时限6s。
对于 $100\%$ 的数据,$5 \times 10^8 \leq p \leq 1.1 \times 10^9$ 且 $p$ 为质数。
解题思路
$$
\begin{aligned}
& \sum_{i=1}^n \sum_{j=1}^n ij \gcd(i,j)\
= & \sum_{i=1}^n \sum_{j=1}^n ij \sum_{d \mid i} \sum_{d \mid j} \boldsymbol{\varphi} (d) \
= & \sum_{d=1}^n \boldsymbol{\varphi}(d) \sum_{d \mid i} \sum_{d \mid j} ij \
= & \sum_{d=1}^n \boldsymbol{\varphi}(d) d^2 \left( \sum_{i=1}^{n/k} i \right)^2 \
= & \sum_{d=1}^n \boldsymbol{\varphi}(d) d^2 \sum_{i=1}^{n/k} i^3
\end{aligned}
$$
其中最后一步是由
$$
\left( \sum_{i=1}^n i \right)^2 = \sum_{i=1}^n i^3
$$
得到的,同时
$$
\sum_{i=1}^n i^2 = \frac{n(n+1)(2n+1)}{6}
$$
所以我们可以通过杜教筛筛出 $\sum_{d=1}^n \boldsymbol{\varphi}(d) d^2$ 的值,然后使用整除分块,在 $O(n^{\frac{5}{6}})$ 的时间内算出答案
代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
using namespace std;
const int MAXN = 10000000 + 10;
const int MAXM = 2500 + 10;
const int SIZ = 10000000;
long long n, MOD, INV;
long long pri[MAXN], cnt;
long long phi[MAXN], sum1[MAXN];
bool isNotPrime[MAXN];
inline long long mul(long long a, long long b) {return a * b % MOD;}
inline long long add(long long a, long long b) {return (a + b) % MOD;}
inline long long powx(long long a, long long b) {
long long ans = 1;
for (; b; b >>= 1) {
if (b & 1) ans = mul(ans, a);
a = mul(a, a);
}
return ans;
}
inline void init() {
INV = powx(6, MOD - 2);
phi[1] = 1;
for (register int i=2; i<=SIZ; i++) {
if (!isNotPrime[i]) {
pri[++cnt] = i;
phi[i] = i - 1;
}
for (register int j=1; j<=cnt; j++) {
long long m = i * pri[j];
if (m > SIZ) break;
isNotPrime[m] = true;
if (i % pri[j] == 0) {
phi[m] = phi[i] * pri[j];
break;
} else {
phi[m] = phi[i] * (pri[j] - 1);
}
}
}
for (register int i=1; i<=SIZ; i++) {
sum1[i] = add(sum1[i-1], (mul(phi[i], mul(i, i))));
}
}
inline long long calc2(long long x) {
x %= MOD;
return mul(mul(mul(x, x+1), 2*x+1), INV);
}
inline long long calc3(long long x) {
x %= MOD;
return mul((x * (x + 1) / 2) % MOD, (x * (x + 1) / 2) % MOD);
}
long long sum2[MAXM];
bool vis[MAXM];
inline long long calc(long long x) {
if (x < SIZ) return sum1[x];
int k = n / x;
if (vis[k]) return sum2[k];
long long &ans = sum2[k];
long long r = 0;
vis[k] = true;
ans = calc3(x);
for (register long long l=2; l<=x; l=r+1) {
r = x / (x / l);
ans = (ans - ((((calc2(r) - calc2(l-1)) % MOD + MOD) % MOD ) * calc(x / l) % MOD)) % MOD;
if (ans < 0) ans += MOD;
}
return ans;
}
inline long long work () {
long long ans = 0;
long long r = 0;
for (register long long l=1; l<=n; l=r+1){
r = n / (n / l);
ans = (ans + ((calc(r) - calc(l-1) + MOD) % MOD * calc3(n / l)) % MOD) % MOD;
}
return ans;
}
int main(){
scanf("%lld%lld", &MOD, &n);
init();
printf("%lld", work());
return 0;
}