树形dp(HDU-6820 Tree)

tech2023-11-11  70

题目连接 题意:给一颗带边权的树,你要求一个最大权联通快,要求里面超过k度的节点只能有一个。 题解: 这个题解不是官方思路。而是一位大佬的思路。 首先,我们dp[i][2]表示第i个节点为根的子树。有没有与fa向连的最大权值。我们不妨先考虑,没有大于k度的节点,然后再对于每一个节点,令它的度数大于k,从而求出最大权。 显然对于dp[i][0],他最多可以连接k个孩子,对于dp[i][1],他最多可以连接k-1个孩子。我们首先要择优选取。我们首先递归孩子,得出孩子节点的最大权值,再以(孩子的最大权值+这条边的权值)从大到小排序。dp[i][0] 等于前k个孩子和边权之和。dp[i][1]即为前k-1个。然后我们思考枚举其他的边。这里从上往下枚举。累计更换某个边后的权值之和。变更边时,应该从已选边里找一个最少的,然后加新边。递归结束后,我们让这个节点为度数大于k的点,所以我们直接吧他所有的孩子再累计上即可。

#include<cstdio> #include<iostream> #include<cstring> #include <map> #include <queue> #include <set> #include <cstdlib> #include <cmath> #include <algorithm> #include <vector> #include <string> #include <list> #include <bitset> #include <array> #include <cctype> #include <time.h> #pragma GCC optimize(2) void read_f() { freopen("1.in", "r", stdin); freopen("1.out", "w", stdout); } void fast_cin() { std::ios::sync_with_stdio(false); std::cin.tie(); } void run_time() { std::cout << "ESC in : " << clock() * 1000.0 / CLOCKS_PER_SEC << "ms" << std::endl; } template <typename T> bool bacmp(const T & a, const T & b) { return a > b; } template <typename T> bool pecmp(const T & a, const T & b) { return a < b; } #define ll long long #define ull unsigned ll #define _min(x, y) ((x)>(y)?(y):(x)) #define _max(x, y) ((x)>(y)?(x):(y)) #define max3(x, y, z) ( max( (x), max( (y), (z) ) ) ) #define min3(x, y, z) ( min( (x), min( (y), (z) ) ) ) #define pr(x, y) (make_pair((x), (y))) #define pb(x) push_back(x); using namespace std; const int N = 5e5+5; ll dp[N][2], ans; int n, k; vector< pair<ll, int> > g[N], v[N]; void init(int n) { for (int i = 1; i <= n; i++) { g[i].clear(); v[i].clear(); dp[i][0] = dp[i][1] = 0; } ans = 0; } void dfs(int x, int fa) { int sz = g[x].size(); for (auto i : g[x]) { int y = i.second; ll w = i.first; if (y == fa) continue; dfs(y, x); v[x].pb( pr(dp[y][0] + w, y) ); } sort(v[x].begin(), v[x].end(), bacmp<pair<ll, int> >); for (int i = 0; i < min((int)v[x].size(), k-1 + (x==1) ); i++) dp[x][0] += v[x][i].first; for (int i = 0; i < min((int)v[x].size(), k); i++) dp[x][1] += v[x][i].first; } void dfs2(int x, int fa, ll s) { for (int i = 0; i < v[x].size(); i++) { int y = v[x][i].second; ll w = v[x][i].first; if (x == 1) { if (i >= k) dfs2(y, x, s+dp[x][0] - v[x][k-1].first + w - dp[y][0]); else dfs2(y, x, s+dp[x][0] - dp[y][0]); } else { if (i >= k) dfs2(y, x, max(s+dp[x][0]-v[x][k-2].first + w -dp[y][0], dp[x][1] - v[x][k-1].first + w - dp[y][0])); else if (i == k - 1) dfs2(y,x, max(s + dp[x][0] - v[x][k-2].first + w - dp[y][0], dp[x][1]-dp[y][0])); else dfs2(y,x,max(s + dp[x][0]-dp[y][0], dp[x][1] - dp[y][0])); } } ll sum = 0; for (int i = 0; i < v[x].size(); i++) sum += v[x][i].first; ans = max(ans, sum+s); } int main() { int t; cin >> t; while(t--) { scanf("%d%d", &n, &k); init(n); for (int i = 1; i < n; i++) { int x, y; ll w; scanf("%d%d%lld", &x, &y, &w); g[x].pb(pr(w, y)); g[y].pb(pr(w, x)); } if (k == 0) { puts("0"); continue; } else if (k == 1) { for (int i = 1; i <= n; i++) { ll sum = 0; for (auto j : g[i]) sum += j.first; ans = max(ans, sum); } printf("%lld\n", ans); continue; } dfs(1, 0); dfs2(1, 0, 0); printf("%lld\n", ans); } }
最新回复(0)