Contents
JZOJ – 5455 拆网线
Time Limit : 1000ms
Memory Limit : 64MB
Description
企鹅国的网吧们之间由网线互相连接,形成一棵树的结构。现在由于冬天到了,供暖部门缺少燃料,于是他们决定去拆一些网线来做燃料。但是现在有K只企鹅要上网和别人联机游戏,所以他们需要把这K只企鹅安排到不同的机房(两只企鹅在同一个机房会吵架),然后拆掉一些网线,但是需要保证每只企鹅至少还能通过留下来的网线和至少另一只企鹅联机游戏。
所以他们想知道,最少需要保留多少根网线?
Input
第一行一个整数 $T$ ,表示数据组数;
每组数据第一行两个整数 $N$,$K$,表示总共的机房数目和企鹅数目。
第二行 $N-1$ 个整数,第 $i$ 个整数 $A_i$ 表示机房 $i+1$ 和机房 $A_i$ 有一根网线连接($1 \leq A_i \leq i$)。
Output
每组数据输出一个整数表示最少保留的网线数目。
Sample Input
2
4 4
1 2 3
4 3
1 1 1
Sample Output
2
2
Data Constraint
对于30%的数据:N≤15;
对于50%的数据:N≤300;
对于70%的数据:N≤2000;
对于100%的数据:2≤K≤N≤100000,T≤10。
解题思路
我们发现这是一棵树,题目要求一只企鹅只需要有另一只企鹅联机即可。
那么我们就尽可能期望有更多的边将两只不同的企鹅两两连接在一起。一根网线连接两只企鹅, $n$ 根网线连接 $2n$ 只企鹅
题目转变为计算在一棵树中最多有多少条这样的边,考虑树形DP
对于状态 $f[x][0]$ 表示为不包括节点 $x$ , 其子树最多有多少条这样的边
状态 $f[x][1]$ 表示包括节点 $x$,其子树最多有多少条这样的边
$f[x][0]$ 直接由子节点 $v$ 的 $f[v][1]$ 暴力累加即可
$f[x][1]$ 尝试以此对每一个子节点 $v$,连上一条从 $x$ 到 $v$ 的网线
$f[x][0] – f[v][1] + f[v][0] + 1$ 更新 $f[x][1]$
最后我们就可以计算出有 $k$ 条这样的边。
如果企鹅的数量小于等于 $2k$,就需要 $\frac{k+1}{2}$条网线
如果大于那么一根网线只能连接一只不同的企鹅,另一端接在已经连接的企鹅上。那么网线数量加上剩余的企鹅数量即可。
代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <vector>
using namespace std;
const int MAXN = 100000 + 10;
vector <int> tree[MAXN];
int T, n, k;
inline int read(){
int x = 0;
char ch = getchar();
while(ch < '0' || ch > '9')
ch = getchar();
while('0' <= ch && ch <= '9'){
x = x*10 + ch - '0';
ch = getchar();
}
return x;
}
int f[MAXN][2];
inline void dfs(int x, int fa){
for(register int i=0; i<tree[x].size(); i++){
int v = tree[x][i];
if(v != fa){
dfs(v, x);
f[x][0] += f[v][1];
}
}
for(register int i=0; i<tree[x].size(); i++){
int v = tree[x][i];
if(v != fa){
f[x][1] = max(f[x][1], f[x][0] - f[v][1] + f[v][0] + 1);
}
}
}
int main(){
freopen("tree.in", "r", stdin);
freopen("tree.out", "w", stdout);
T = read();
while(T--){
n = read();
k = read();
for(register int i=1; i<=n; i++)
tree[i].clear();
memset(f, 0, sizeof(f));
for(register int i=2; i<=n; i++){
int tmp = read();
tree[tmp].push_back(i);
tree[i].push_back(tmp);
}
dfs(1, 0);
int ans = max(f[1][0], f[1][1]);
if(ans * 2 >= k){
printf("%d\n", (k+1)/2);
}else{
printf("%d\n", ans + (k-ans*2));
}
}
return 0;
}