树链剖分就是将树分割成多条链,然后利用数据结构(线段树、树状数组等)来维护这些链。
树链剖分可以用来解决两点间路径相关的查询,修改问题。
Contents
树链剖分基本概念
重结点:子树结点数目最多的结点
轻节点:父亲节点中除了重结点以外的结点
重边:父亲结点和重结点连成的边
轻边:父亲节点和轻节点连成的边
重链:由多条重边连接而成的路径
轻链:由多条轻边连接而成的路径
比如上面这幅图中,用黑线连接的结点都是重结点,其余均是轻结点,2-11、1-14就是重链,其他就是轻链,用红点标记的就是该结点所在链的起点。
定义一些全局数组
son[] : 记录重儿子信息
siz[] : 记录子树个数
top[] : 记录所在链的链顶
deep[] : 记录节点深度
fa[] : 记录节点父亲
id[] : 记录dfs序
Step 1 : 进行第一次dfs
第一次dfs的目的是处理出每一个节点的重儿子,深度,父亲,子树大小
void dfs1(int u, int f){
deep[u] = deep[f] + 1; //计算每个节点的深度
siz[u] = 1; //计算每个节点子树大小
fa[u] = f; //计算每个节点的父亲
for(int i=0; i<map[u].size(); i++){
int v = map[u][i];
if(v != f){
dfs1(v,u); //递归处理子节点
siz[u] += siz[v];
if(siz[v] > siz[son[u]]) //如果这个儿子子树节点数目更多,更新重儿子
son[u] = v;
}
}
}
Step 2: 进行第二次dfs
第二次dfs的目的是处理出链顶, dfs序, (同时也可以处理出线段树build所需要的数组)
void dfs2(int u,int topf){
id[u] = ++time_stamp; //处理dfs序
top[u] = topf; //处理链顶
if(son[u]){
dfs2(son[u],topf); //重儿子,重链链顶延续
for(int i=0; i<map[u].size(); i++){
int v = map[u][i];
if(v != son[u] && v!=fa[u])
dfs2(v,v); //重链链顶链顶重新设置
}
}
}
Step 3: 线段树操作
这里的线段树以 BZOJ 1036 为例
void pushup(int o){
int lc = o << 1;
int rc = o << 1 | 1;
num[o] = max(num[lc],num[rc]);
sum[o] = sum[lc] + sum[rc];
}
void update(int o, int l, int r, int p, int k){
int lc = o << 1, rc = o << 1 | 1, mid = (l+r) >>1;
if(l == r){
sum[o] = num[o] = k;
return;
}
if(p <= mid)
update(lc, l, mid, p, k);
if(mid < p)
update(rc, mid+1, r, p, k);
pushup(o);
}
int getMax(int o, int l, int r, int ql, int qr){
int lc = o << 1, rc = o <<1 | 1, mid=(l+r)>>1, ans=-0x3f3f3f3f;
if(ql <= l && r <= qr)
return num[o];
if(ql <= mid)
ans = max(ans,getMax(lc, l, mid, ql, qr));
if(mid < qr)
ans = max(ans,getMax(rc, mid+1, r, ql, qr));
return ans;
}
int getSum(int o, int l, int r, int ql, int qr){
int lc = o << 1, rc = o <<1 | 1, mid = (l+r) >> 1, ans = 0;
if(ql <=l && r <= qr)
return sum[o];
if(ql <= mid)
ans += getSum(lc, l, mid, ql, qr);
if(mid < qr)
ans += getSum(rc, mid+1, r, ql, qr);
return ans;
}
Step 4: 树链剖分操作
查询操作,非常类似于倍增求LCA,不过这里直接跳转到top的父亲节点,(但是轻链的top就是自己)。需要注意的是,每次循环只能跳一次,并且让top结点深的那个来跳到top的位置,避免两个一起跳从而错过。
本题只涉及到了路径查询,对于路径修改,和查询操作非常类似,只是将向线段树查询链信息改为向线段树修改链信息。
int findMax(int u, int v){
int f1 = top[u], f2= top[v];
int ans = -0x3f3f3f3f;
while(f1 != f2){
if(deep[f1] < deep[f2]){
swap(f1,f2);
swap(u,v);
}
ans = max(ans, getMax(1, 1, n, id[f1], id[u]));
u = fa[f1];
f1 = top[u];
}
if(deep[u] > deep[v])
swap(u,v);
ans = max(ans, getMax(1, 1, n, id[u], id[v]));
return ans;
}
int findSum(int u, int v){
int f1 = top[u], f2=top[v];
int ans = 0;
while(f1 != f2){
if(deep[f1] < deep[f2]){
swap(u,v);
swap(f1,f2);
}
ans += getSum(1, 1, n, id[f1], id[u]);
u = fa[f1];
f1 = top[u];
}
if(deep[u] > deep[v])
swap(u,v);
ans += getSum(1, 1, n, id[u], id[v]);
return ans;
}
Step 5:主函数
因为这里涉及到了一些函数的特殊调用,特将主函数也一同附上
int main(){
scanf("%d",&n);
for(int i=0; i<n-1; i++){
int u,v;
scanf("%d%d",&u,&v);
map[u].push_back(v);
map[v].push_back(u);
}
deep[1] = 1;
dfs1(1,0);
dfs2(1,1);
for(int i=1; i<=n; i++){
scanf("%d",&w[i]);
update(1, 1, n, id[i], w[i]);
}
scanf("%d",&q);
for(int i=1; i<=q; i++){
char ch[10];
scanf("%s",ch);
if(ch[0] == 'C'){
int u,k;
scanf("%d%d",&u,&k);
update(1, 1, n, id[u], k);
}else if(ch[1] == 'S'){
int u,v;
scanf("%d%d",&u,&v);
printf("%d\n",findSum(u,v));
}else if(ch[1] == 'M'){
int u,v;
scanf("%d%d",&u,&v);
printf("%d\n",findMax(u,v));
}
}
}