[题解] BZOJ – 4665 小w的喜糖

题目描述

废话不多说,反正小w要发喜糖啦!!
小w一共买了 $n$ 块喜糖,发给了 $n$ 个人,每个喜糖有一个种类。这时,小w突发奇想,如果这 $n$ 个人相互交换手中的糖,那会有多少种方案使得每个人手中的糖的种类都与原来不同。
两个方案不同当且仅当,存在一个人,他手中的糖的种类在两个方案中不一样。

输入格式

第一行,一个整数 $n$
接下来 $n$ 行,每行一个整数,第 $i$ 个整数 $A_i$ 表示开始时第 $i$ 个人手中的糖的种类
对于所有数据,$1 \leq A_i \leq k, k \leq N, N \leq 2000$

输出格式

一行,一个整数 $\mathrm{Ans}$,表示方案数模 $1000000009$

样例输入

6
1
1
2
2
3
3

样例输出

10

解题思路

首先我们可以将喜糖颜色一样的人进行合并,他们都是等效的

我们通过 DP 求得 $f_{i, j}$ 表示对于前 $i$ 种糖果,至少有 $j$ 人交换后糖果与之前相同

然后只要通过枚举 $j$ 容斥即可求出答案

下面我们考虑如何进行 DP,对于第 $i$ 种糖果,我们枚举有 $k$ 人交换后不变

$$
f_{i, j+k} \leftarrow f_{i-1, j}
$$

同时我们要乘上一个组合数 ${a_i \choose k}$,$a_i$ 表示有多少人拥有第 $i$ 种糖果

最后其他人($n – j$ 人) 即可直接排列

$$
\frac{(n – j)!}{\prod_{i=1}^n (a_i-k_i)!}
$$

其中

$$
\sum_{i=1}^n k_i = j
$$

可以将分母在 DP 过程中直接求出

代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>

using namespace std;
const int MAXN = 2000 + 10;
const int MOD = 1e9 + 9;

inline int read(){
    int x = 0;
    char ch = getchar();

    while(ch < '0' || ch > '9')
        ch = getchar();

    while('0' <= ch && ch <= '9'){
        x = x*10 + ch - '0';
        ch = getchar();
    }

    return x;
}

int n;
int sugar[MAXN];
long long f[MAXN][MAXN];

long long ans;
long long fac[MAXN], inv[MAXN];

inline long long powx(long long a, long long b){
    long long ans = 1;

    for(; b; b >>= 1){
        if(b & 1){
            ans = (ans * a) % MOD;
        }

        a = (a * a) % MOD;
    }

    return ans;
}  

inline void init(){
    fac[0] = 1;

    for(register int i=1; i<=n; i++){
        fac[i] = fac[i-1] * i % MOD;
    }

    inv[n] = powx(fac[n], MOD - 2);

    for(register int i=n-1; i>=0; i--){
        inv[i] = inv[i+1] * (i+1) % MOD;
    }
}

inline long long C(long long n, long long m){
    if(m > n)
        return 0;

    return fac[n] * inv[m] % MOD * inv[n-m] % MOD;
}

int main(){
    n = read();

    for(register int i=1; i<=n; i++){
        sugar[read()]++;
    }

    init();

    f[0][0] = 1;

    int tot = 0;
    for(register int i=1; i<=n; i++){
        for(register int j=0; j<=tot; j++){
            for(register int k=0; k<=sugar[i]; k++){
                f[i][j+k] = (f[i][j+k] + ((f[i-1][j] * C(sugar[i], k)) % MOD * inv[sugar[i] - k]) % MOD) % MOD;
            }
        }
        tot += sugar[i];
    }

    for(register int i=0; i<=n; i++){
        ans = (ans + ((i & 1 ? -f[n][i] : f[n][i]) * fac[tot - i]) % MOD + MOD) % MOD;
    }

    printf("%lld\n", ans);
    return 0;
}