题目
给你一棵 nn 个点的无根树。
树上的每条边具有颜色。一共有 mm 种颜色,编号为 11 到 mm,第 ii 种颜色的权值为 c_ic i 。
对于一条树上的简单路径,路径上经过的所有边按顺序组成一个颜色序列,序列可以划分成若干个相同颜色段。定义路径权值为颜色序列上每个同颜色段的颜色权值之和。
请你计算,经过边数在 ll 到 rr 之间的所有简单路径中,路径权值的最大值。
思路
考虑分治重心为k的时候。由于求最大值,所以不能容斥,所以考虑给每个点染色(染他是哪个子树的),记录这个点到k的路径上最靠上的一条边的颜色。 考虑如何合并路径。将这个剩下的一堆半链按照最顶端颜色排序,再按照所属子树排序。这样所有子树相同的点在一个区间里,所有顶端颜色相同的点也在一个连续区间里。 考虑到顶端颜色相同的需要减去顶端颜色权值,顶端颜色不同的不需要。 考虑维护一个权值线段树,以长度为下标,以路径权值为值的线段树,维护区间最大值。这个线段树在一个颜色区间里的所有点都计算完之后,把这些点全都插入。这是为了计算跨颜色的路径合并。另一个线段树,在同一个颜色里初始化一次,在一个子树区间里全部计算完后把这些点都插入。以长度为下标的目的是,求长度在某一范围内的和的最大值。
代码
#include <bits/stdc++.h>
using namespace std
;
const int N
=2e5+77,inf
=0x3f3f3f3f;
int n
,m
,L
,R
,c
[N
],ls
[N
],to
[N
<<1],clr
[N
<<1],nx
[N
<<1],sum
,top
,rt
,fiz
[N
],siz
[N
];
bool ban
[N
];
struct Seg
{
int dis
,num
,bel
;
}p
[N
];
struct Rmq
{
#define ls (x<<1)
#define rs (x<<1|1)
int val
[N
<<2];
bool tag
[N
<<2];
void clear() {
val
[1]=-inf
;
tag
[1]=1;
}
void pushDown(int x
) {
val
[ls
]=-inf
,tag
[ls
]=1;
val
[rs
]=-inf
,tag
[rs
]=1;
tag
[x
]=0;
}
void modify(int x
,int l
,int r
,int p
,int w
) {
if(l
==r
) return void(val
[x
]=w
);
int mid
=(l
+r
)>>1;
val
[x
]=max(val
[x
],w
);
if(tag
[x
]) pushDown(x
);
if(p
<=mid
) modify(ls
,l
,mid
,p
,w
);
else modify(rs
,mid
+1,r
,p
,w
);
}
int query(int x
,int l
,int r
,int L
,int R
) {
if(L
<=l
&& r
<=R
) return val
[x
];
int mid
=(l
+r
)>>1, ret
=-inf
;
if(tag
[x
]) return ret
;
if(L
<=mid
) ret
=max(ret
,query(ls
,l
,mid
,L
,R
));
if(mid
<R
) ret
=max(ret
,query(rs
,mid
+1,r
,L
,R
));
return ret
;
}
void modify(int p
,int w
) {
modify(1,1,n
+1,p
+1,w
);
}
int query(int L
,int R
) {
if(R
<L
|| R
<0) return -inf
;
return query(1,1,n
+1,L
+1,R
+1);
}
#undef ls
#undef rs
} A
,B
;
void add(int x
,int y
,int c
) {
int cnt
=0;
to
[++cnt
]=y
;
clr
[cnt
]=c
;
nx
[cnt
]=ls
[x
];
ls
[x
]=cnt
;
}
void getrt(int x
,int pa
) {
fiz
[x
]=0,siz
[x
]=1;
for(int i
=ls
[x
]; i
; i
=nx
[i
]) {
if(to
[i
]==pa
||ban
[to
[i
]]) continue;
getrt(to
[i
],x
);
siz
[x
]+=siz
[to
[i
]];
fiz
[x
]=max(fiz
[x
],siz
[to
[i
]]);
}
fiz
[x
]=max(fiz
[x
],sum
-siz
[x
]);
if(fiz
[x
]<fiz
[rt
]) rt
=x
;
}
int ans
=-2e9;
void getDis(int x
,int pa
,int dis
,int num
,int pClr
,int bel
) {
p
[++top
]=(Seg
){dis
,num
,bel
};
for(int i
=ls
[x
]; i
; i
=nx
[i
]) {
if(to
[i
]==pa
||ban
[to
[i
]]) continue;
if(pClr
==clr
[i
]) getDis(to
[i
],x
,dis
,num
+1,clr
[i
],bel
);
else getDis(to
[i
],x
,dis
+c
[clr
[i
]],num
+1,clr
[i
],bel
);
}
}
void calc(int x
)
{
p
[top
=1]=(Seg
){0,0,0};
for(int i
=ls
[x
]; i
; i
=nx
[i
])
{
if(ban
[to
[i
]]) continue;
getDis(to
[i
],x
,c
[clr
[i
]],1,clr
[i
],i
);
}
sort(p
+1,p
+top
+1,[=](Seg x
,Seg y
)
{
if(clr
[x
.bel
]!=clr
[y
.bel
]) return clr
[x
.bel
]<clr
[y
.bel
];
return x
.bel
<y
.bel
;
});
A
.clear();
B
.clear();
for(int l
=1,r
; l
<=top
; l
=r
+1)
{
for(r
=l
; r
<top
&& clr
[p
[l
].bel
]==clr
[p
[r
+1].bel
]; ++r
);
for(int x
=l
,y
; x
<=r
; x
=y
+1)
{
for(y
=x
; y
<r
&& p
[x
].bel
==p
[y
+1].bel
; ++y
);
if(x
!=l
) for(int i
=x
; i
<=y
; ++i
)
ans
=max(ans
,p
[i
].dis
+B
.query(L
-p
[i
].num
,R
-p
[i
].num
)-c
[clr
[p
[i
].bel
]]);
for(int i
=x
; i
<=y
; ++i
) B
.modify(p
[i
].num
,p
[i
].dis
);
}
B
.clear();
if(l
!=1) for(int i
=l
; i
<=r
; ++i
)
ans
=max(ans
,p
[i
].dis
+A
.query(L
-p
[i
].num
,R
-p
[i
].num
));
for(int i
=l
; i
<=r
; ++i
) A
.modify(p
[i
].num
,p
[i
].dis
);
}
}
void solve(int x
)
{
ban
[x
]=true;
calc(x
);
for(int i
=ls
[x
]; i
; i
=nx
[i
])
{
if(ban
[to
[i
]]) continue;
rt
=0;
sum
=siz
[to
[i
]];
getrt(to
[i
],x
);
solve(rt
);
}
}
int main()
{
scanf("%d%d%d%d",&n
,&m
,&L
,&R
);
for(int i
=1; i
<=m
; ++i
) scanf("%d",&c
[i
]);
for(int x
,y
,c
,i
=n
; i
>=1; i
--)
{
scanf("%d%d%d",&x
,&y
,&c
);
add(x
,y
,c
);
add(y
,x
,c
);
}
rt
=0;
sum
=n
;
fiz
[0]=2e9;
getrt(1,0);
solve(rt
);
printf("%d\n",ans
);
}