题目背景
对于一棵树,我们定义 $dis(i, j)$ 为节点 $i$ 和 $j$ 之间最短路径上的边数。
对于一个长度为 $n$ 的序列 $a$,我们定义 $w(l, r)$ 为 $\sum_{i=l}^r \sum_{j=i}^r dis(a_i, a_j)$
题目描述
给你一棵 $n$ 个点的树以及一个 $1-n$ 的排列 $a$,有 $q$ 次询问,每次给出 $k$,求 $\sum_{i=1}^k \sum_{j=i}^k w(i, j)$ 对 $998244353$ 取模的值
输入格式
第一行两个正整数 $n$ 和 $q$
接下来 $n-1$ 行,每行两个正整数 $u$ 和 $v$,表示 $u$ 和 $v$ 之间有一条树边
接下来一行 $n$ 个数字描述排列 $a$
接下来 $q$ 行每行一个正整数 $k_i$ 表示询问
输出格式
输出 $q$ 行,第 $i$ 行表示询问 $k_i$ 的答案
样例输入
4 4
1 2
2 3
2 4
3 2 1 4
1
2
3
4
样例输出
0
1
6
21
数据范围及提示
对于前 $30 \%$ 的数据,$n \leq 1000$
对于另 $20 \%$ 的数据,$q_i = n$
对于另 $20 \%$ 的数据,第 $i$ 条边连接点 $i$ 与点 $i+1$
对于 $100 \%$ 的数据,$n, q \leq 10^5, u, v, k_i \leq n$
解题思路
题目要求
$$
\begin{aligned}
f_k & = \sum_{l=1}^k \sum_{r=l}^k w(l, r) \\
& = \sum_{l=1}^k \sum_{r=l}^k \sum_{i=l}^r \sum_{j=i}^r dis(a_i, a_j)
\end{aligned}
$$
考虑 $k$ 从 $k-1$ 转移到 $k$ 时答案的增加量
$$
\begin{aligned}
f_k – f_{k-1} & = \sum_{i=1}^k w(i, k) \\
& = \sum_{i=1}^k \left[ w(i, k-1) + \sum_{j=i}^k dis(a_j,a_k) \right] \\
&= \sum_{i=1}^k w(i, k-1) + \sum_{i=1}^k \sum_{j=i}^k dis(a_j, a_k)
\end{aligned}
$$
因为
$$
w(k, k-1) = 0
$$
所以式子的前半部分就是上一次从 $k-2$ 转移到 $k-1$ 的增加量
考虑将和式展开,发现对于每一个 $dis(a_j, a_k)$ , 都计算了 $j$ 次
$$
\sum_{i=1}^k \sum_{j=i}^k dis(a_j, a_k) = \sum_{i=1}^k i \times dis(a_i, a_k)
$$
这样的算法是 $O(n^2)$ 的, 再考虑将 $dis(a_i, a_k)$ 展开
$$
dis(a_i, a_k) = deep_{a_i} + deep_{a_k} – 2 \times deep_{lca(a_i, a_k)}
$$
那么式子就被划分成了三部分
$$
\sum_{i=1}^k i \times deep_{a_i}
$$
可以通过预处理前缀和,$O(n)$ 预处理,$O(1)$ 查询
$$
\sum_{i=1}^k i \times deep_{a_k}
$$
一共出现了 $\frac{k \times (k-1)}{2}$ 次,可以 $O(1)$ 计算
问题转化为如何快速求解下面这个式子
$$
\sum_{i=1}^k i \times deep_{lca(a_i, a_k)}
$$
它可以通过树链剖分在 $O(\log^2 n)$ 的时间内求解,也可以使用 $\mathrm{lct}$ 在 $O(\log n)$ 的时间内求解
对于每一个 $a_i$,它与 $a_x$ 的 $\mathrm{lca}$ 一定在它到根节点的路径上,因此将 $a_i$ 到根节点上的每一个节点权值加上 $i$,统计 $a_k$ 到根节点的权值和即可
代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
using namespace std;
const long long MAXN = 200000 + 10;
const long long MAXM = 400000 + 10;
const long long MOD = 998244353;
long long Head[MAXN], to[MAXM], Next[MAXM], tot = 1;
long long sum[MAXN << 2], add[MAXN << 2];
inline void _add(long long a, long long b){
to[tot] = b;
Next[tot] = Head[a];
Head[a] = tot++;
}
inline void pushdown(long long root, long long left, long long right){
if(add[root]){
long long lson = root << 1, rson = root << 1 | 1, mid = (left + right) >> 1;
add[lson] = (add[lson] + add[root]) % MOD;
add[rson] = (add[rson] + add[root]) % MOD;
sum[lson] = (sum[lson] + (mid - left + 1) * add[root]) % MOD;
sum[rson] = (sum[rson] + (right - mid) * add[root]) % MOD;
add[root] = 0;
}
}
inline void update(long long root, long long left, long long right, long long qleft, long long qright, long long k){
long long lson = root << 1, rson = root << 1 | 1, mid = (left + right) >> 1;
if(qleft <= left && right <= qright){
add[root] = (add[root] + k) % MOD;
sum[root] = (sum[root] + (right - left + 1) * k) % MOD;
return;
}
pushdown(root, left, right);
if(qleft <= mid)
update(lson, left, mid, qleft, qright, k);
if(mid < qright)
update(rson, mid+1, right, qleft, qright, k);
sum[root] = (sum[lson] + sum[rson]) % MOD;
}
inline long long query(long long root, long long left, long long right, long long qleft, long long qright){
long long lson = root << 1, rson = root << 1 | 1, mid = (left + right) >> 1;
if(qleft <= left && right <= qright){
return sum[root];
}
pushdown(root, left, right);
long long ans = 0;
if(qleft <= mid)
ans += query(lson, left, mid, qleft, qright);
if(mid < qright)
ans += query(rson, mid+1, right, qleft, qright);
return ans % MOD;
}
long long n, q;
long long deep[MAXN], topf[MAXN], son[MAXN], fa[MAXN], siz[MAXN];
long long id[MAXN], data[MAXN], top[MAXN], time_stamp;
inline void dfs1(long long x, long long f){
siz[x] = 1;
deep[x] = deep[f] + 1;
fa[x] = f;
for(register long long i=Head[x]; i; i=Next[i]){
long long v = to[i];
if(v != f){
dfs1(v, x);
siz[x] += siz[v];
if(siz[v] > siz[son[x]])
son[x] = v;
}
}
}
inline void dfs2(long long x, long long topf){
id[x] = ++time_stamp;
data[time_stamp] = 0;
top[x] = topf;
if(son[x]){
dfs2(son[x], topf);
for(register long long i=Head[x]; i; i=Next[i]){
long long v = to[i];
if(v != son[x] && v != fa[x]){
dfs2(v, v);
}
}
}
}
inline long long treeQuery(long long x, long long y){
long long ans = 0;
long long f1 = top[x];
long long f2 = top[y];
while(f1 != f2){
if(deep[f1] < deep[f2]){
swap(x, y);
swap(f1, f2);
}
ans = (ans + query(1, 1, n, id[f1], id[x])) % MOD;
x = fa[f1];
f1 = top[x];
}
if(deep[x] > deep[y])
swap(x, y);
ans = (ans + query(1, 1, n, id[x], id[y])) % MOD;
return ans;
}
inline long long treeUpdate(long long x, long long y, long long k){
long long ans = 0;
long long f1 = top[x];
long long f2 = top[y];
while(f1 != f2){
if(deep[f1] < deep[f2]){
swap(x, y);
swap(f1, f2);
}
update(1, 1, n, id[f1], id[x], k);
x = fa[f1];
f1 = top[x];
}
if(deep[x] > deep[y])
swap(x, y);
update(1, 1, n, id[x], id[y], k);
return ans;
}
inline long long read(){
long long x = 0;
char ch = getchar();
while(ch < '0' || ch > '9')
ch = getchar();
while('0' <= ch && ch <= '9'){
x = x*10 + ch - '0';
ch = getchar();
}
return x;
}
long long a[MAXN];
long long result[MAXN];
long long prefix[MAXN];
signed main(){
n = read();
q = read();
for(register long long i=1; i<n; i++){
long long a = read();
long long b = read();
_add(a, b);
_add(b, a);
}
dfs1(1, 0);
dfs2(1, 1);
for(register long long i=1; i<=n; i++)
a[i] = read();
long long lastans = 0;
for(register long long i=1; i<=n; i++){
prefix[i] = (prefix[i-1] + (i * deep[a[i]] % MOD)) % MOD;
}
for(register long long i=1; i<=n; i++){
long long valadd = ((lastans + (((i * (i-1)) / 2) % MOD * deep[a[i]]) % MOD) % MOD + prefix[i-1]) % MOD;
valadd -= 2 * treeQuery(1, a[i]);
valadd = (valadd % MOD + MOD) % MOD;
lastans = valadd;
result[i] = (result[i-1] + valadd) % MOD;
treeUpdate(1, a[i], i);
}
for(register long long i=1; i<=q; i++){
long long k = read();
printf("%lld\n", result[k]);
}
}