Contents
JZOJ – 5813 计算
Time Limit : 1000ms
Memory Limit : 512MB
Description
Input
一行由空格隔开的两个整数,分别是 $ n $ 和 $ m $。
Output
一行表示答案。
Sample Input
Case 1
6 1
Case 2
6 3
Sample Output
Case 1
10
Case 2
2248
Data Constraint
解题思路
我们令 $ F(x) = \prod_{i=1}^{2m} x_i $
要求 $ F(x) \leq n^m $的组数,我们考虑将它拆为两个部分 $ s_1 = F(x) < n^m $,$ s_2 = F(x) = n^m $
显然 $ s_3 = s_1 = F(x) > n^m $
那么我们可以得到 $ s_1 = \frac{d(n)^{2m}+s_2}{2} $
现在问题转化为求 $ s_2 $ 就可以了,问题简单了很多
将 $ n $ 分解质因数,对于每一个质因数 $ p_i $,我们假设 $ n $ 含有 $ k_i $ 个质因数 $ p_i $
令 $ a_j $ 表示 $ x_j $ 中含有 $ a_j $ 个质因数 $ p_i $
问题又转化为求 $ \sum_{j=1}^{2m} a_j = k_i \times n $的方案数
可以通过计数DP实现
令 $ f[i][j] $ 表示处理到第 $ i $ 个数,和为 $ j $ 的方案数
$ f[i][j] = \sum_{t=1}^{k_i} f[i-1][j-t] $
最后依次对每一个 $ p_i $ 我们都跑一遍DP,将答案乘起来即可,就可以得到 $ s_2 $
最后再求出 $ s_1 $,$ s_1 + s_2 $得到最终答案
代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
using namespace std;
const int MAXN = 20 + 10;
inline int read(){
int x = 0;
char ch = getchar();
while(ch < '0' || ch > '9'){
ch = getchar();
}
while('0' <= ch && ch <= '9'){
x = x*10 + ch - '0';
ch = getchar();
}
return x;
}
inline long long exgcd(long long a, long long b, long long &x, long long &y){
if(b == 0){
x = 1;
y = 0;
return a;
}else{
long long d = exgcd(b, a%b, y, x);
y -= x * (a / b);
return d;
}
}
inline long long gcd(long long a, long long b){
if(b == 0)
return a;
return gcd(b, a%b);
}
inline long long lcm(long long a, long long b){
if(!a)
return b;
if(!b)
return a;
return a / gcd(a, b) * b;
}
int data[MAXN];
int n, m;
long long ans;
inline void solve(int x, long long a, long long b, int k){
if(a > n || b > n)
return;
if(x == m + 1){
if(!a || !b)
return;
long long x, y;
long long d = exgcd(a, b, x, y);
if(d != 1)
return;
x %= b;
while(x <= 0)
x += b;
if(k%2 == 0)
ans += (n/a - x + b) / b;
else
ans -= (n/a - x + b) / b;
return;
}
solve(x+1, lcm(a, data[x]), b, k+1);
solve(x+1, a, lcm(b, data[x]), k+1);
solve(x+1, a, b, k);
}
int main(){
freopen("sazetak.in", "r", stdin);
freopen("sazetak.out", "w", stdout);
n = read();
m = read();
for(register int i=1; i<=m; i++)
data[i] = read();
data[++m] = n;
solve(1, 0, 0, 0);
printf("%lld", ans);
return 0;
}