题目描述
给定一个含 $N$ 个元素的数组 $A$,下标从 $1$ 开始。请找出下面式子的最大值
$$(A[l_1] \oplus A[l_1 + 1] \oplus \cdots A[r_1]) + (A[l_2] \oplus A[l_2+1] \oplus \cdots A[r_2]) $$
其中 $1 \leq l_1 \leq r_1 < l_2 \leq r_2 \leq N$,$x \oplus y$ 表示 $x$ 和 $y$ 的按位异或。
输入格式
输入数据的第一行包含一个整数 $N$,表示数组中的元素个数。
第二行包括 $N$ 个整数 $A_1, A_2, \cdots A_N$。
输出格式
输出一行包含给定表达式可能最大值
样例输入1
5
1 2 3 4 5
样例输出1
6
数据范围及提示
对于 $40 \%$ 的数据,$2 \leq N \leq 10^4$
对于 $100 \%$ 的数据,$2 \leq N \leq 4 \times 10^5$
解题思路
异或运算满足 相同的数字异或两次等于 $0$ 这个性质,因此区间 $[l, r]$ 的值可以用 $sum[r] \oplus sum[l-1]$ 来实现
我们定义两个数组 $f[x]$ 和 $g[x]$ 分别表示区间 $[1, x]$ 异或最大值和 $[x, N]$ 区间异或最大值
转移方程很好想
$$f[i] = max(f[l-1], sum[i], sum[i] \oplus sum[j]) $$
这个转移方程在取 $sum[i]$ 区间时表示取整段区间 $[1, i]$ ,否则表示取其中一段区间 $[j, i]$
$$g[i] = max(g[l+1], sum[i], sum[i] \oplus sum[j]) $$
这里的 $sum[i]$ 分别表示前缀异或和和后缀异或和
暴力求解 $sum[i] \oplus sum[j]$ 显然是不可行的,我们可以用 $Trie$ 树帮助求解
代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#define r read()
#define Insert insert
#define Query query
#define Max max
using namespace std;
const int MAXN = 400000 + 10;
const int SIZ = 12400000 + 10;
inline int read(){
int x = 0;
char ch = getchar();
while(ch < '0' || ch > '9')
ch = getchar();
while('0' <= ch && ch <= '9'){
x = x*10 + ch - '0';
ch = getchar();
}
return x;
}
int tree[SIZ][2], cnt;
int data[MAXN], f[MAXN], g[MAXN];
int n, sum;
inline void insert(int num){
int root = 0;
for(register int i=0; i<32; i++){
int x = (num >> (32 - i - 1)) & 1;
if(!tree[root][x]){
tree[root][x] = ++cnt;
}
root = tree[root][x];
}
}
inline int query(int num){
int root = 0;
int k = 0;
for(register int i=0; i<32; i++){
int x = (num >> (32 - i - 1)) & 1;
if(tree[root][!x]){
k |= (!x) << (32 - i - 1);
root = tree[root][!x];
}else{
k |= x << (32 - i - 1);
root = tree[root][x];
}
}
return k;
}
int main(){
n = read();
for(register int i=1; i<=n; i++){
data[i] = read();
}
f[1] = sum = data[1];
insert(sum);
for(register int i=2; i<=n; i++){
sum = sum^data[i];
f[i] = max(f[i-1], max(sum, sum^query(sum)));
insert(sum);
}
memset(tree, 0, sizeof(tree));
cnt = 0;
sum = 0;
g[n] = sum = data[n];
insert(sum);
for(register int i=n-1; i>=1; i--){
sum = sum^data[i];
g[i] = max(g[i+1], max(sum, sum^query(sum)));
insert(sum);
}
int ans = 0;
for(register int i=1; i<n; i++){
ans = max(ans, f[i] + g[i+1]);
}
printf("%d", ans);
return 0;
}