题目描述
您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:
- 查询 $k$ 在区间内的排名
- 查询区间内排名为 $k$ 的值
- 修改某一位值上的数值
- 查询 $k$ 在区间内的前驱(前驱定义为严格小于 $x$,且最大的数,若不存在输出 $-2147483647$ )
- 查询 $k$ 在区间内的后继(后继定义为严格大于 $x$,且最小的数,若不存在输出 $2147483647$ )
输入格式
第一行两个数 $n,m$ 表示长度为 $n$ 的有序序列和 $m$ 个操作
第二行有 $n$ 个数,表示有序序列
下面有 $m$ 行,$opt$ 表示操作标号
- 若 $opt=1$ 则为操作 $1$,之后有三个数 $l,r,k$ 表示查询 $k$ 在区间 $[l,r]$ 的排名
- 若 $opt=2$ 则为操作 $2$,之后有三个数 $l,r,k$ 表示查询区间 $[l,r]$ 内排名为k的数
- 若 $opt=3$ 则为操作 $3$,之后有两个数 $pos,k$ 表示将 $pos$ 位置的数修改为 $k$
- 若 $opt=4$ 则为操作 $4$,之后有三个数 $l,r,k$ 表示查询区间 $[l,r]$ 内 $k$ 的前驱
- 若 $opt=5$ 则为操作 $5$,之后有三个数 $l,r,k$ 表示查询区间 $[l,r]$ 内 $k$ 的后继
输出格式
对于操作 $1, 2, 4, 5$ 各输出一行,表示查询结果
样例输入
9 6
4 2 2 1 9 4 0 1 1
2 1 4 3
3 4 10
2 1 4 3
1 2 5 9
4 3 9 5
5 2 8 5
样例输出
2
4
3
4
9
数据范围及提示
$n, m \leq 5 \times 10^4$,保证有序序列所有值在任何时刻满足 $[0, 10^8]$
解题思路
线段树套平衡树模板题,在这里平衡树选用了 $Splay$
对于在线段树中的每一个区间 $[l, r]$,都用一棵平衡树来维护区间信息
对于操作 $1$,在线段树内查询 $[l, r]$ 对应的 $Splay$ 中比 $k$ 小的数的个数,相加即可。(输出答案时需 $+1$)
对于操作 $3$,在线段树内对应修改即可(先在 $Splay$ 中删除旧值,再插入新值)
对于操作 $4$,在线段树递归时不断取最大值即可
对于操作 $5$,在线段树递归时不断取最小值即可
上面的操作全都是 $O(log^2 n)$复杂度的,对于操作 $2$,它不满足区间累加性,无法通过简单操作得到结果
我们考虑二分一个值 $x$,判断 $x$ 在区间内的排名是多少,即可得到答案
代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
using namespace std;
const int MAXN = 50000;
const int SIZ = 50000 * 25;
const int INF = 2147483647;
int n, m;
int data[MAXN];
inline int read(){
int x = 0;
char ch = getchar();
while(ch < '0' || ch > '9')
ch = getchar();
while('0' <= ch && ch <= '9'){
x = x*10 + ch - '0';
ch = getchar();
}
return x;
}
namespace Splay{
int tree[SIZ][2], fa[SIZ], tot;
int num[SIZ], cnt[SIZ], siz[SIZ];
inline void clear(int x){
tree[x][0] = tree[x][1] = 0;
fa[x] = num[x] = cnt[x] = siz[x] = 0;
}
inline int getpath(int x){
return tree[fa[x]][1] == x;
}
inline void update(int x){
siz[x] = siz[tree[x][0]] + siz[tree[x][1]] + cnt[x];
}
inline void rotate(int x){
int father = fa[x];
int grandfather = fa[father];
int path = getpath(x);
tree[father][path] = tree[x][path^1];
fa[tree[father][path]] = father;
tree[x][path^1] = father;
fa[father] = x;
fa[x] = grandfather;
if(grandfather){
tree[grandfather][tree[grandfather][1] == father] = x;
}
update(father);
update(x);
}
inline void splay(int &root, int x){
for(register int f; (f = fa[x]); rotate(x)){
if(fa[f]){
rotate(getpath(f) == getpath(x) ? f : x);
}
}
root = x;
}
inline void insert(int &root, int x){
if(!root){
root = ++tot;
num[tot] = x;
cnt[tot] = siz[tot] = 1;
return;
}
int pos = root;
int f;
while(true){
if(num[pos] == x){
cnt[pos]++;
siz[pos]++;
splay(root, pos);
return;
}
f = pos;
pos = tree[pos][num[pos] < x];
if(!pos){
pos = ++tot;
cnt[pos] = siz[pos] = 1;
num[pos] = x;
fa[pos] = f;
tree[f][num[f] < x] = tot;
splay(root, pos);
return;
}
}
}
inline int pre(int &root){
int pos = tree[root][0];
while(tree[pos][1])
pos = tree[pos][1];
return pos;
}
inline int pre(int &root, int x){
int pos = root;
int ans = -INF;
while(pos){
if(x > num[pos]){
ans = max(ans, num[pos]);
pos = tree[pos][1];
continue;
}
pos = tree[pos][0];
}
return ans;
}
inline int next(int &root){
int pos = tree[pos][1];
while(tree[pos][0])
pos = tree[pos][0];
return pos;
}
inline int next(int &root, int x){
int pos = root;
int ans = INF;
while(pos){
if(x < num[pos]){
ans = min(ans, num[pos]);
pos = tree[pos][0];
continue;
}
pos = tree[pos][1];
}
return ans;
}
inline int find(int &root, int x){
int ans = 0;
int pos = root;
while(true){
if(x < num[pos]){
pos = tree[pos][0];
continue;
}
ans += siz[tree[pos][0]];
if(x == num[pos]){
splay(root, pos);
return ans + 1;
}
ans += cnt[pos];
pos = tree[pos][1];
}
}
inline int rank(int &root, int x){
int pos = root;
int ans = 0;
while(pos){
if(x < num[pos]){
pos = tree[pos][0];
continue;
}
ans += siz[tree[pos][0]];
if(x == num[pos]){
splay(root, pos);
return ans;
}
if(x > num[pos]){
ans += cnt[pos];
pos = tree[pos][1];
}
}
return ans;
}
inline void Delete(int &root, int x){
find(root, x);
if(cnt[root] > 1){
cnt[root]--;
siz[root]--;
return;
}
if(!tree[root][0] && !tree[root][1]){
clear(root);
root = 0;
return;
}
if(!tree[root][0] && tree[root][1]){
int oldroot = root;
root = tree[root][1];
fa[root] = 0;
clear(oldroot);
return;
}
if(tree[root][0] && !tree[root][1]){
int oldroot = root;
root = tree[root][0];
fa[root] = 0;
clear(oldroot);
return;
}
int oldroot = root;
int Pre = pre(root);
splay(root, Pre);
tree[root][1] = tree[oldroot][1];
fa[tree[root][1]] = root;
fa[root] = 0;
clear(oldroot);
update(root);
return;
}
}
namespace SEG{
int rt[MAXN << 2];
inline void build(int root, int left, int right){
int lc = root << 1, rc = root << 1 | 1, mid = (left + right) >> 1;
for(register int i=left; i<=right; i++){
Splay::insert(rt[root], data[i]);
}
if(left == right)
return;
if(left <= mid)
build(lc, left, mid);
if(mid < right)
build(rc, mid+1, right);
}
inline void update(int root, int left, int right, int last, int now, int pos){
int lc = root << 1, rc = root << 1 | 1, mid = (left + right) >> 1;
Splay::insert(rt[root], now);
Splay::Delete(rt[root], last);
if(left == right)
return;
if(pos <= mid)
update(lc, left, mid, last, now, pos);
if(mid < pos)
update(rc, mid+1, right, last, now, pos);
}
inline int queryRank(int root, int left, int right, int qleft, int qright, int num){
int lc = root << 1, rc = root << 1 | 1, mid = (left + right) >> 1;
if(qleft <= left && right <= qright){
return Splay::rank(rt[root], num);
}
int ans = 0;
if(qleft <= mid)
ans += queryRank(lc, left, mid, qleft, qright, num);
if(mid < qright)
ans += queryRank(rc, mid+1, right, qleft, qright, num);
return ans;
}
inline int queryPre(int root, int left, int right, int qleft, int qright, int num){
int lc = root << 1, rc = root << 1 | 1, mid = (left + right) >> 1;
if(qleft <= left && right <= qright){
return Splay::pre(rt[root], num);
}
int ans = -INF;
if(qleft <= mid)
ans = max(ans, queryPre(lc, left, mid, qleft, qright, num));
if(mid < qright)
ans = max(ans, queryPre(rc, mid+1, right, qleft, qright, num));
return ans;
}
inline int queryNext(int root, int left, int right, int qleft, int qright, int num){
int lc = root << 1, rc = root << 1 | 1, mid = (left + right) >> 1;
if(qleft <= left && right <= qright){
return Splay::next(rt[root], num);
}
int ans = INF;
if(qleft <= mid)
ans = min(ans, queryNext(lc, left, mid, qleft, qright, num));
if(mid < qright)
ans = min(ans, queryNext(rc, mid+1, right, qleft, qright, num));
return ans;
}
inline int queryKth(int root, int left, int right, int num){
int l = 0;
int r = 1e8 + 10;
while(l < r){
int mid = ((l + r) >> 1) + 1;
if(queryRank(1, 1, n, left, right, mid) < num)
l = mid;
else
r = mid - 1;
}
return l;
}
}
int main(){
n = read();
m = read();
for(register int i=1; i<=n; i++)
data[i] = read();
SEG::build(1, 1, n);
for(register int i=1; i<=m; i++){
int op = read();
int a = read();
int b = read();
if(op == 3){
SEG::update(1, 1, n, data[a], b, a);
data[a] = b;
}else{
int c = read();
if(op == 1){
printf("%d\n", SEG::queryRank(1, 1, n, a, b, c) + 1);
}else if(op == 2){
printf("%d\n", SEG::queryKth(1, a, b, c));
}else if(op == 4){
printf("%d\n", SEG::queryPre(1, 1, n, a, b, c));
}else{
printf("%d\n", SEG::queryNext(1, 1, n, a, b, c));
}
}
}
}