树上背包问题

tech2026-01-26  10

树上背包问题

树上背包问题算法原理例题一:有依赖的背包问题题意数据范围思路代码 例题二:二叉苹果树题意数据范围思路代码 例题三:Factories(2018icpc银川网络赛)题意数据范围思路代码

树上背包问题

一些题目给定了树形结构,在这个树形结构中选取一定数量的点或边(也可能是其他属性),使得某种与点权或者边权相关的花费最大或者最小。解决这类问题,一般要考虑使用树上背包。

算法原理

树上背包,顾名思义,就是在树上做背包问题。一个节点的若干子树可以看作是若干组背包,也就是用树形dp的方式做分组背包问题。一般来说, f ( i , j ) f(i,j) f(i,j)表示以 i i i为根的子树中,在 j j j的容量范围内,最大或者最小可以获得多少收益。根据分组背包的思想,第一维枚举物品(在树上指的是子树),第二维枚举容量,第三维枚举决策(这里指的是给子树分配多少容量)。基本的代码框架如下:

void dfs(int u, int fa) { for(int i = h[u]; ~i; i = ne[i]) { int son = e[i]; if(son == fa) continue; dfs(son, u); for(int j = m; j >= 0; j --) for(int k = 0; k <= j; k ++) f[u][j] = max(f[u][j], f[u][j-k] + f[son][k] + val); } }

例题一:有依赖的背包问题

题意

n n n个物品和一个容量是 m m m的背包。物品之间具有依赖关系,且依赖关系组成一棵树的形状。如果选择一个物品,则必须选择它的父节点。 求解将哪些物品装入背包,可使物品总体积不超过背包容量,且总价值最大。输出最大价值。 每件物品的编号是 i i i,体积是 v i v_i vi,价值是 w i w_i wi,依赖的父节点编号是 p i p_i pi。物品的下标范围是 1 … N 1 \dots N 1N

数据范围

1 ≤ n , m ≤ 100 1 \leq n,m \leq 100 1n,m100 1 ≤ v i , w i ≤ 100 1 \leq v_i,w_i \leq 100 1vi,wi100

思路

f ( i , j ) f(i,j) f(i,j)表示选择以 i i i为子树的物品,在容量不超过 j j j时所获得的最大价值。 由于只有选择了根节点,才会继续往下遍历,所以在遍历到 i i i节点时,先考虑一定选上它。 在分组背包部分, j j j的范围为 [ m , v [ i ] ] [m,v[i]] [m,v[i]],否则没有意义,因为连根节点也放不下; k k k的范围 [ 0 , j − v [ i ] ] [0,j-v[i]] [0,jv[i]],当大于 j − v [ i ] j-v[i] jv[i]时分给该子树的容量过多,剩余的容量连根节点的物品都放不下了。 递推式为: f ( i , j ) = m a x ( f ( i , j ) , f ( i , j − k ) + f ( s o n , k ) ) f(i,j) = max(f(i,j), f(i,j - k) + f(son,k)) f(i,j)=max(f(i,j),f(i,jk)+f(son,k))

代码

void dfs(int u) { for(int i = v[u]; i <= m; i ++) f[u][i] = w[u]; for(int i = h[u]; ~i; i = ne[i]) { int son = e[i]; dfs(son); for(int j = m; j >= v[u]; j --) for(int k = 0; k <= j - v[u]; k ++) f[u][j] = max(f[u][j], f[u][j - k] + f[son][k]); } }

例题二:二叉苹果树

题意

给定一棵二叉树,每条边有边权,保留一定数量的边(其他边删除),使得保留下来的边的边权和最大。

数据范围

1 ≤ n < m ≤ 100 1 \leq n < m \leq 100 1n<m100 w i ≤ 30000 w_i \leq 30000 wi30000

思路

f ( i , j ) f(i,j) f(i,j)表示以 i i i为根的子树中,恰好保留 j j j条边的最大边权和。 若需要选择该子树中的边,则根结点到子树的边一定要选,因此能用上的总边数一定减 1 1 1,总共可以选择 j j j条边时,当前子树son分配的最大边数是 j − 1 j - 1 j1。 递推式为, f ( i , j ) = m a x ( f ( i , j ) , f ( i , j − k − 1 ) + f ( s o n , k ) + w [ i ] ) f(i,j) = max(f(i,j), f(i,j-k-1) + f(son, k) + w[i]) f(i,j)=max(f(i,j),f(i,jk1)+f(son,k)+w[i])

代码

void dfs(int u, int fa) { for(int i = h[u]; ~i; i = ne[i]) { int son = e[i]; if(son == fa) continue; dfs(son, u); for(int j = m; j >= 1; j -- ) for(int k = 0; k <= j - 1; k ++ ) f[u][j] = max(f[u][j], f[u][j - k - 1] + f[son][k] + w[i]); } }

例题三:Factories(2018icpc银川网络赛)

题意

给定一棵树,边有边权。每个叶子节点上最多可以布置一个工厂,总共要布置 k k k个工厂。问怎样布置工厂,使得工厂之间的距离和最小。

数据范围

10 s 10s 10s 2 ≤ n ≤ 1 0 5 2 \leq n \leq 10^5 2n105, 1 ≤ m ≤ 100 1 \leq m \leq 100 1m100 1 ≤ w i ≤ 1 0 5 1 \leq w_i \leq 10^5 1wi105 多组测试数据, n n n总数不超过 1 0 6 10^6 106

思路

直接考虑距离之和非常困难,所以可以考虑每条边被计算了几次(距离和等类似问题很多都是这么考虑的)。不妨设一条边为 i i i,与 i i i相连的子树中有 j j j个工厂,则这条边被计算的次数为 j ∗ ( m − j ) j*(m - j) j(mj) f ( i , j ) f(i,j) f(i,j)表示以 i i i为根节点的子树中,选择恰好 j j j个叶子节点的距离总和。 递推式为, f ( i , j ) = m i n ( f ( i , j ) , f ( i , j − k ) + f ( s o n , k ) + w [ i ] ∗ j ∗ ( m − j ) ) f(i,j) = min(f(i,j), f(i,j - k) + f(son, k) + w[i] * j * (m - j)) f(i,j)=min(f(i,j),f(i,jk)+f(son,k)+w[i]j(mj))。 因为只能分布在叶子节点,因此初始化的时候要注意,如果点 i i i为叶子节点,那么 f ( i , 1 ) = 0 f(i,1) = 0 f(i,1)=0。 同时这道题要卡常数,所以要对状态做一个优化,即把无效状态去掉。

代码

#include <iostream> #include <cstdio> #include <cstring> #include <algorithm> using namespace std; typedef long long ll; const int N = 100003, M = 103; const ll inf = 1e18; int n, m; int h[N], e[2*N], ne[2*N], w[2*N], idx; int s[N], deg[N]; ll f[N][M]; void add(int a,int b,int c) { e[idx] = b, ne[idx] = h[a], w[idx] = c, h[a] = idx ++; } void dfs(int u,int fa) { for(int i = h[u]; ~i; i = ne[i]) { int son = e[i]; if(son == fa) continue; dfs(son, u); s[u] += s[son]; for(int j = min(m, s[u]); j >= 1; j --) for(int k = 1; k <= min(j, s[son]); k ++) f[u][j] = min(f[u][j], f[u][j-k] + f[son][k] + (ll)w[i] * k * (m - k)); } } int main() { int T; scanf("%d", &T); int cas = 0; while(T --) { scanf("%d%d", &n,&m); for(int i = 1; i <= n; i ++) h[i] = -1, deg[i] = 0; idx = 0; for(int i = 0; i < n - 1; i ++) { int a,b,c; scanf("%d%d%d", &a,&b,&c); add(a,b,c), add(b,a,c); deg[a] ++, deg[b] ++; } for(int i = 1; i <= n; i ++) { s[i] = 0; for(int j = 1; j <= m; j ++) f[i][j] = inf; if(deg[i]==1) f[i][1] = 0, s[i] = 1; } dfs(1, -1); printf("Case #%d: %lld\n",++cas,f[1][m]); } return 0; }
最新回复(0)