Contents
JZOJ – 5788 餐馆
Time Limit : 2000ms
Memory Limit : 64MB
Description
K妹的胡椒粉大卖,这辣味让食客们感到刺激,许多餐馆也买这位K妹的账。有 $N$ 家餐馆,有 $N-1$ 条道路,这 $N$ 家餐馆能相互到达。K妹从 $1$ 号餐馆开始。每一个单位时间,K妹可以在所在餐馆卖完尽量多的胡椒粉,或者移动到有道路直接相连的隔壁餐馆。第 $i$ 家餐馆最多需要 $A_i$ 瓶胡椒粉。K妹有 $M$ 个单位的时间,问她最多能卖多少胡椒粉。
Input
第一行有两个正整数 $N$,$M$。
第二行描述餐馆对胡椒粉的最大需求量,有 $N$ 个正整数,表示 $A_i$。
接下来有 $N-1$ 行描述道路的情况,每行两个正整数 $u$,$v$,描述这条道路连接的两个餐馆。
Output
一个整数,表示她最多能卖的胡椒粉瓶数。
Sample Input
Case 1
3 5
9 2 5
1 2
1 3
Case 2
4 5
1 1 1 2
1 2
2 3
3 4
Case 3
5 10
1 3 5 2 4
5 2
3 1
2 3
4 2
Sample Output
Case 1
14
Case 2
3
Case 3
15
Data Constraint
对于10%的数据,N≤20。
对于50%的数据,N≤110。
对于100%的数据1 ≤ N, M ≤ 500,1 ≤ A[i]≤ 10^6
解题思路
题目中给出的图是一棵树,要求最多能卖得胡椒粉数量,考虑树形DP
我们注意到,K妹可以从一棵子树再走到另一棵子树贩卖胡椒粉,因此不能简单地使用树形背包
我们设 $f[i][j][0/1]$ 表示K妹现在正在节点 $i$,还有 $j$ 时间剩余,最后是否回到本节点
在子树中有三种转移方法
– 从之前的子树中回到根,走向当前子树不回到根
– 从之前的子树中回到根,走向当前子树回到根
– 走向当前子树回到根,再走向之前的子树不回到根
之前的子树就是已经处理过的 $i$ 的儿子
注意到走向当前子树再回到根需要2个单位的时间,不回到根只需要1个单位的时间
在树形DP向上回溯的时候需要考虑是否在当前节点花费1单位时间贩卖胡椒粉,这一部分和背包问题非常相似
代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <vector>
using namespace std;
const int MAXN = 500 + 10;
const int MAXM = 1000 + 10;
int n, m;
int Head[MAXN], to[MAXM], Next[MAXM], tot = 1;
int cnt[MAXN];
int dp[MAXN][MAXN][2];
inline void add(int a, int b){
to[tot] = b;
Next[tot] = Head[a];
Head[a] = tot++;
}
inline void dfs(int x, int fa){
for(register int i=Head[x]; i; i=Next[i]){
int v = to[i];
if(v != fa){
dfs(v, x);
for(register int j=m; j>=1; j--){
for(register int k=1; k<=j-1; k++){
dp[x][j][0] = max(dp[x][j][0], dp[v][k][0] + dp[x][j-k-1][1]);
}
for(register int k=1; k<=j-2; k++){
dp[x][j][0] = max(dp[x][j][0], dp[v][k][1] + dp[x][j-k-2][0]);
dp[x][j][1] = max(dp[x][j][1], dp[v][k][1] + dp[x][j-k-2][1]);
}
}
}
}
for(register int j=m; j>=1; j--){
dp[x][j][0] = max(dp[x][j][0], dp[x][j-1][0] + cnt[x]);
dp[x][j][1] = max(dp[x][j][1], dp[x][j-1][1] + cnt[x]);
}
}
int main(){
freopen("dostavljac.in", "r", stdin);
freopen("dostavljac.out", "w", stdout);
scanf("%d%d", &n, &m);
for(register int i=1; i<=n; i++)
scanf("%d", &cnt[i]);
for(register int i=1; i<=n-1; i++){
int a, b;
scanf("%d%d", &a, &b);
add(a, b);
add(b, a);
}
dfs(1, 0);
printf("%d\n", max(dp[1][m][0], dp[1][m][1]));
return 0;
}