【点分治】[BJOI2017]树的难题

tech2024-06-18  72

题目

给你一棵 nn 个点的无根树。

树上的每条边具有颜色。一共有 mm 种颜色,编号为 11 到 mm,第 ii 种颜色的权值为 c_ic i ​ 。

对于一条树上的简单路径,路径上经过的所有边按顺序组成一个颜色序列,序列可以划分成若干个相同颜色段。定义路径权值为颜色序列上每个同颜色段的颜色权值之和。

请你计算,经过边数在 ll 到 rr 之间的所有简单路径中,路径权值的最大值。

思路

考虑分治重心为k的时候。由于求最大值,所以不能容斥,所以考虑给每个点染色(染他是哪个子树的),记录这个点到k的路径上最靠上的一条边的颜色。 考虑如何合并路径。将这个剩下的一堆半链按照最顶端颜色排序,再按照所属子树排序。这样所有子树相同的点在一个区间里,所有顶端颜色相同的点也在一个连续区间里。 考虑到顶端颜色相同的需要减去顶端颜色权值,顶端颜色不同的不需要。 考虑维护一个权值线段树,以长度为下标,以路径权值为值的线段树,维护区间最大值。这个线段树在一个颜色区间里的所有点都计算完之后,把这些点全都插入。这是为了计算跨颜色的路径合并。另一个线段树,在同一个颜色里初始化一次,在一个子树区间里全部计算完后把这些点都插入。以长度为下标的目的是,求长度在某一范围内的和的最大值。

代码

#include <bits/stdc++.h> using namespace std; const int N=2e5+77,inf=0x3f3f3f3f; int n,m,L,R,c[N],ls[N],to[N<<1],clr[N<<1],nx[N<<1],sum,top,rt,fiz[N],siz[N]; bool ban[N]; struct Seg { int dis,num,bel; }p[N]; struct Rmq { #define ls (x<<1) #define rs (x<<1|1) int val[N<<2]; bool tag[N<<2]; void clear() { val[1]=-inf; tag[1]=1; } void pushDown(int x) { val[ls]=-inf,tag[ls]=1; val[rs]=-inf,tag[rs]=1; tag[x]=0; } void modify(int x,int l,int r,int p,int w) { if(l==r) return void(val[x]=w); int mid=(l+r)>>1; val[x]=max(val[x],w); if(tag[x]) pushDown(x); if(p<=mid) modify(ls,l,mid,p,w); else modify(rs,mid+1,r,p,w); } int query(int x,int l,int r,int L,int R) { if(L<=l && r<=R) return val[x]; int mid=(l+r)>>1, ret=-inf; if(tag[x]) return ret; if(L<=mid) ret=max(ret,query(ls,l,mid,L,R)); if(mid<R) ret=max(ret,query(rs,mid+1,r,L,R)); return ret; } void modify(int p,int w) { modify(1,1,n+1,p+1,w); } int query(int L,int R) { if(R<L || R<0) return -inf; return query(1,1,n+1,L+1,R+1); } #undef ls #undef rs } A,B; void add(int x,int y,int c) { int cnt=0; to[++cnt]=y; clr[cnt]=c; nx[cnt]=ls[x]; ls[x]=cnt; } void getrt(int x,int pa) { fiz[x]=0,siz[x]=1; for(int i=ls[x]; i; i=nx[i]) { if(to[i]==pa||ban[to[i]]) continue; getrt(to[i],x); siz[x]+=siz[to[i]]; fiz[x]=max(fiz[x],siz[to[i]]); } fiz[x]=max(fiz[x],sum-siz[x]); if(fiz[x]<fiz[rt]) rt=x; } int ans=-2e9; void getDis(int x,int pa,int dis,int num,int pClr,int bel) { p[++top]=(Seg){dis,num,bel}; for(int i=ls[x]; i; i=nx[i]) { if(to[i]==pa||ban[to[i]]) continue; if(pClr==clr[i]) getDis(to[i],x,dis,num+1,clr[i],bel); else getDis(to[i],x,dis+c[clr[i]],num+1,clr[i],bel); } } void calc(int x) { p[top=1]=(Seg){0,0,0}; for(int i=ls[x]; i; i=nx[i]) { if(ban[to[i]]) continue; getDis(to[i],x,c[clr[i]],1,clr[i],i); } sort(p+1,p+top+1,[=](Seg x,Seg y) { if(clr[x.bel]!=clr[y.bel]) return clr[x.bel]<clr[y.bel]; return x.bel<y.bel; }); A.clear(); B.clear(); for(int l=1,r; l<=top; l=r+1) { for(r=l; r<top && clr[p[l].bel]==clr[p[r+1].bel]; ++r); for(int x=l,y; x<=r; x=y+1) { for(y=x; y<r && p[x].bel==p[y+1].bel; ++y); if(x!=l) for(int i=x; i<=y; ++i) ans=max(ans,p[i].dis+B.query(L-p[i].num,R-p[i].num)-c[clr[p[i].bel]]); for(int i=x; i<=y; ++i) B.modify(p[i].num,p[i].dis); } B.clear(); if(l!=1) for(int i=l; i<=r; ++i) ans=max(ans,p[i].dis+A.query(L-p[i].num,R-p[i].num)); for(int i=l; i<=r; ++i) A.modify(p[i].num,p[i].dis); } } void solve(int x) { ban[x]=true; calc(x); for(int i=ls[x]; i; i=nx[i]) { if(ban[to[i]]) continue; rt=0; sum=siz[to[i]]; getrt(to[i],x); solve(rt); } } int main() { scanf("%d%d%d%d",&n,&m,&L,&R); for(int i=1; i<=m; ++i) scanf("%d",&c[i]); for(int x,y,c,i=n; i>=1; i--) { scanf("%d%d%d",&x,&y,&c); add(x,y,c); add(y,x,c); } rt=0; sum=n; fiz[0]=2e9; getrt(1,0); solve(rt); printf("%d\n",ans); }
最新回复(0)