[题解] JZOJ – 5813 计算

JZOJ – 5813 计算

Time Limit : 1000ms
Memory Limit : 512MB

Description

Input

一行由空格隔开的两个整数,分别是 $ n $ 和 $ m $。

Output

一行表示答案。

Sample Input

Case 1

6 1

Case 2

6 3

Sample Output

Case 1

10

Case 2

2248

Data Constraint

解题思路

我们令 $ F(x) = \prod_{i=1}^{2m} x_i $

要求 $ F(x) \leq n^m $的组数,我们考虑将它拆为两个部分 $ s_1 = F(x) < n^m $,$ s_2 = F(x) = n^m $

显然 $ s_3 = s_1 = F(x) > n^m $

那么我们可以得到 $ s_1 = \frac{d(n)^{2m}+s_2}{2} $

现在问题转化为求 $ s_2 $ 就可以了,问题简单了很多

将 $ n $ 分解质因数,对于每一个质因数 $ p_i $,我们假设 $ n $ 含有 $ k_i $ 个质因数 $ p_i $

令 $ a_j $ 表示 $ x_j $ 中含有 $ a_j $ 个质因数 $ p_i $

问题又转化为求 $ \sum_{j=1}^{2m} a_j = k_i \times n $的方案数

可以通过计数DP实现

令 $ f[i][j] $ 表示处理到第 $ i $ 个数,和为 $ j $ 的方案数

$ f[i][j] = \sum_{t=1}^{k_i} f[i-1][j-t] $

最后依次对每一个 $ p_i $ 我们都跑一遍DP,将答案乘起来即可,就可以得到 $ s_2 $

最后再求出 $ s_1 $,$ s_1 + s_2 $得到最终答案

代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>

using namespace std;
const int MAXN = 20 + 10;

inline int read(){
    int x = 0;
    char ch = getchar();

    while(ch < '0' || ch > '9'){
        ch = getchar();
    }

    while('0' <= ch && ch <= '9'){
        x = x*10 + ch - '0';
        ch = getchar();
    }

    return x;
}

inline long long exgcd(long long a, long long b, long long &x, long long &y){
    if(b == 0){
        x = 1;
        y = 0;
        return a;
    }else{
        long long d = exgcd(b, a%b, y, x);
        y -= x * (a / b);
        return d;
    }
}

inline long long gcd(long long a, long long b){
    if(b == 0)
        return a;
    return gcd(b, a%b);
}

inline long long lcm(long long a, long long b){
    if(!a)
        return b;
    if(!b)
        return a;

    return a / gcd(a, b) * b;
}

int data[MAXN];
int n, m;
long long ans;

inline void solve(int x, long long a, long long b, int k){
    if(a > n || b > n)
        return;

    if(x == m + 1){
        if(!a || !b)
            return;

        long long x, y;
        long long d = exgcd(a, b, x, y);

        if(d != 1)
            return;

        x %= b;

        while(x <= 0)
            x += b;


        if(k%2 == 0)
            ans += (n/a - x + b) / b;
        else
            ans -= (n/a - x + b) / b;

        return;
    }

    solve(x+1, lcm(a, data[x]), b, k+1);
    solve(x+1, a, lcm(b, data[x]), k+1);
    solve(x+1, a, b, k);
}

int main(){

    freopen("sazetak.in", "r", stdin);
    freopen("sazetak.out", "w", stdout);

    n = read();
    m = read();

    for(register int i=1; i<=m; i++)
        data[i] = read();

    data[++m] = n;

    solve(1, 0, 0, 0);

    printf("%lld", ans);
    return 0;
}