题目描述
九条可怜是一个热爱思考的女孩子。
九条可怜最近正在研究各种排序的性质,她发现了一种很有趣的排序方法: $\operatorname{Gobo sort}$
$\operatorname{Gobo sort}$ 的算法描述大致如下:
- 假设我们要对一个大小为 $n$ 的数列 $a$ 排序。
- 等概率随机生成一个大小为 $n$ 的排列 $p$ 。
- 构造一个大小为 $n$ 的数列 $b$ 满足 $b_i = a_{p_i}$ ,检查 $b$ 是否有序,如果 $b$ 已经有序了就结束算法,并返回 $b$ ,不然返回步骤 $2$。
显然这个算法的期望时间复杂度是 $O(n \times n!)$ 的,但是九条可怜惊奇的发现,利用量子的神奇性质,在量子系统中,可以把这个算法的时间复杂度优化到线性。
九条可怜对这个排序算法进行了进一步研究,她发现如果一个序列满足一些性质,那么 $\operatorname{Gobo sort}$ 会很快计算出正确的结果。为了量化这个速度,她定义 $\operatorname{Gobo sort}$ 的执行轮数是步骤 2 的执行次数。
于是她就想到了这么一个问题:
现在有一个长度为 $n$ 的序列 $x$,九条可怜会在这个序列后面加入 $m$ 个元素,每个元素是 $[l,r]$内的正整数。 她希望新的长度为 $n+m$ 的序列执行 $\operatorname{Gobo sort}$ 的期望执行轮数尽量的多。她希望得到这个最多的期望轮数。
九条可怜很聪明,她很快就算出了答案,她希望和你核对一下,由于这个期望轮数实在是太大了,于是她只要求你输出对 $998244353$ 取模的结果。
输入格式
第一行输入一个整数 $T$,表示数据组数。
接下来 $2 \times T$ 行描述了 $T$ 组数据。
每组数据分成两行,第 $1$ 行有四个正整数 $n,m,l,r$ 表示数列的长度和加入数字的个数和加入数字的范围。 第 $2$ 行有 $n$ 个正整数,第 $i$ 个表示 $x_i$ 。
输出格式
输出 $T$ 个整数,表示答案。
样例输入
2
3 3 1 2
1 3 4
3 3 5 7
1 3 4
样例输出
180
720
数据范围及提示
对于第一组数据,我们可以添加${1,2,2}$ 到序列的最末尾,使得这个序列变成 1 3 4 1 2 2
,那么进行一轮的成功概率是 $\frac{1}{180}$ ,因此期望需要 $180$ 轮。
对于第二组数据,我们可以添加 ${5,6,7}$ 到序列的最末尾,使得这个序列变成 1 3 4 5 6 7
,那么进行一轮的成功概率是 $\frac{1}{720}$ ,因此期望需要 $720$ 轮。
对于 $30\%$ 的数据, $T\leq 10 , n,m,l,r$
对于 $50\%$ 的数据, $T\leq 300,n,m,l,r,a_i\leq 300$ 。
对于 $60\%$ 的数据,$\sum{r-l+1}\leq 10^7$ 。
对于 $70\%$ 的数据, $\sum{n} \leq 2\times 10^5$ 。
对于 $90\%$ 的数据, $m\leq 2\times 10^5*$。
对于 $100\%$ 的数据, $T\leq 10^5,n\leq 2\times 10^5,m\leq 10^7,1\leq l\leq r\leq 10^9$,$1\leq a_i\leq 10^9,\sum{n}\leq 2\times 10^6$ 。
解题思路
总的期望轮数等于一次成功的概率的倒数
考虑如何计算一次成功的概率,假设序列中有 $n$ 个数,其中总共有 $m$ 种,每种出现次数为 $a_i$
$$
\frac{\prod_{i=1}^n a_i}{n!}
$$
就是答案,我们要使期望轮数尽可能多,就是要使分子尽可能小,对于在 $l$ 到 $r$ 区间内的 $a_i$ ,要尽可能平均
也就是每一次加入的数是加入之前出现次数最少的数
考虑将 $l$ 到 $r$ 区间内的 $a_i$ 从小到大排序,将区间内不存在的数看做 $a_i$ 为 $0$
从前往后将第 $1$ 到 第 $i$ 个数一直增加,直至和 $i+1$ 数出现次数一样多,如果不足则只增加部分
注意题目有多组输入,数据清空时只清空已使用的部分数组,避免超时
代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
using namespace std;
const int MAXN = 200000 + 10;
const int MAXM= 10000000 + 200000 + 10;
const int MOD = 998244353;
inline int read(){
int x = 0;
char ch = getchar();
while(ch < '0' || ch > '9')
ch = getchar();
while('0' <= ch && ch <= '9'){
x = x*10 + ch - '0';
ch = getchar();
}
return x;
}
int T;
int cnt[MAXN];
int num[MAXN], que[MAXN];
int n, m, l, r;
long long fac[MAXM], inv[MAXM];
inline long long powx(long long a, long long b){
long long ans = 1;
for(; b; b >>= 1){
if(b & 1) ans = (ans * a) % MOD;
a = (a * a) % MOD;
}
return ans;
}
inline void init(){
fac[0] = 1;
for(register int i=1; i<=10200000; i++){
fac[i] = fac[i-1] * i % MOD;
}
inv[10200000] = powx(fac[10200000], MOD - 2);
for(register int i=10199999; i>=0; i--){
inv[i] = inv[i+1] * (i + 1) % MOD;
}
}
inline void write(int x){
if(x < 10){
putchar('0' + x);
return;
}
write(x / 10);
write(x % 10);
}
int main(){
init();
T = read();
while(T--){
n = read();
m = read();
l = read();
r = read();
int all = n + m;
for(register int i=1; i<=n; i++){
num[i] = que[i] = read();
}
sort(que+1, que+n+1);
int tot = unique(que+1, que+n+1) - que - 1;
memset(cnt, 0, sizeof(int) * (n + 10));
long long ans = 1;
int start = 0x3f3f3f3f;
int end = 0;
for(register int i=1; i<=n; i++){
int tmp = num[i];
int id = lower_bound(que+1, que+tot+1, tmp) - que;
if(l <= tmp && tmp <= r){
start = min(start, id);
end = max(end, id);
}
cnt[id]++;
ans = (ans * cnt[id]) % MOD;
}
if(start == 0x3f3f3f3f){
start = 5;
end = 4;
}
int non = (r - l + 1) - (end - start + 1);
sort(cnt+start, cnt+end+1);
cnt[start - 1] = 0;
cnt[end + 1] = 0x3f3f3f3f;
int len = non;
for(register int i=start; i<=end+1; i++){
if(cnt[i] == cnt[i-1]){
len++;
continue;
}
if(1LL * len * (cnt[i] - cnt[i-1]) <= m){
m -= len * (cnt[i] - cnt[i-1]);
long long tmp = (inv[cnt[i-1]] * fac[cnt[i]]) % MOD;
ans = ans * powx(tmp, len) % MOD;
}else{
int times = m / len;
int rest = m % len;
long long tmp = (inv[cnt[i-1]] * fac[cnt[i-1] + times]) % MOD;
ans = ans * powx(tmp, len) % MOD;
ans = ans * powx(cnt[i-1] + times + 1, rest) % MOD;
m = 0;
break;
}
len++;
}
write(fac[all] * powx(ans, MOD - 2) % MOD);
putchar('\n');
}
return 0;
}