原题传送门 注意 $a \le 10$,可以在倍增的同时,暴力记录每个点往上跳 $2^j$ 步范围内前 10 小的点集,倍增跳的时候暴力合并。
Code:
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <vector>

// NOTE: deliberately no `using namespace std;` -- in C++17 the global
// `struct data` below would become ambiguous with std::data.

const int maxn = 100010;  // upper bound on the number of tree nodes (+ slack)
const int LOG = 18;       // 2^17 > maxn, so lifting levels 0..17 are sufficient
const int TOP = 10;       // queries never ask for more than 10 ids (a <= 10)

// Adjacency list stored as a forward-linked list of edges (index 0 = "none").
struct Edge {
    int to, next;
} edge[maxn << 1];

// Sorted (ascending) list of at most TOP ids; entries live in num[1..tot].
struct data {
    int num[TOP + 1], tot;
} ans, node[maxn][LOG];

// node[u][j]: the TOP smallest ids found on u and its 2^j - 1 nearest ancestors.
// fa[u][j]  : the 2^j-th ancestor of u (0 when it does not exist).
// d[u]      : depth of u, the root having depth 1.
int num, head[maxn], n, m, Q, d[maxn], fa[maxn][LOG];

// Reads the next (optionally negative) integer from stdin.
// Returns 0 when the input ends before any digit is found.
inline int read() {
    int value = 0, sign = 1;
    int c = std::getchar();  // int, so EOF is distinguishable from real bytes
    while (c != EOF && !std::isdigit(c)) {
        if (c == '-') sign = -1;
        c = std::getchar();
    }
    while (c != EOF && std::isdigit(c)) {
        value = value * 10 + (c - '0');
        c = std::getchar();
    }
    return value * sign;
}

// Adds the directed edge x -> y to the adjacency list.
void addedge(int x, int y) {
    ++num;
    edge[num].to = y;
    edge[num].next = head[x];
    head[x] = num;
}

// Merges the sorted list b into the sorted list a, keeping only the TOP
// smallest values. On ties the element of b is taken first.
void merge(data &a, const data &b) {
    data c;
    c.tot = 0;
    int i = 1, j = 1;
    while (c.tot < TOP) {
        const bool hasA = i <= a.tot;
        const bool hasB = j <= b.tot;
        if (hasA && (!hasB || a.num[i] < b.num[j]))
            c.num[++c.tot] = a.num[i++];
        else if (hasB)
            c.num[++c.tot] = b.num[j++];
        else
            break;
    }
    // Copy only the initialised prefix; c is a scratch buffer, so this stays
    // correct even if a and b refer to the same list.
    a.tot = c.tot;
    std::copy(c.num + 1, c.num + c.tot + 1, a.num + 1);
}

// Fills d, fa and node for the subtree of u, whose parent is pre (0 for the
// root). node[*][0] must already hold the ids located at each vertex.
// Uses an explicit stack so that chain-shaped trees cannot overflow the
// call stack.
void build(int u, int pre) {
    std::vector<int> pending;
    d[u] = d[pre] + 1;
    fa[u][0] = pre;
    pending.push_back(u);

    while (!pending.empty()) {
        const int cur = pending.back();
        pending.pop_back();

        // Every ancestor of cur was popped earlier, so its tables are final.
        for (int i = 0; i + 1 < LOG && fa[cur][i]; ++i) {
            const int mid = fa[cur][i];
            fa[cur][i + 1] = fa[mid][i];
            data &up = node[cur][i + 1];
            up.tot = 0;
            merge(up, node[cur][i]);  // lower half of the 2^(i+1) jump
            merge(up, node[mid][i]);  // upper half of the 2^(i+1) jump
        }

        for (int e = head[cur]; e; e = edge[e].next) {
            const int v = edge[e].to;
            if (v == fa[cur][0]) continue;
            d[v] = d[cur] + 1;
            fa[v][0] = cur;
            pending.push_back(v);
        }
    }
}

// Collects into `ans` the TOP smallest ids on the tree path between u and v.
static void collect_path(int u, int v) {
    if (d[u] < d[v]) std::swap(u, v);
    ans.tot = 0;

    // Lift the deeper endpoint until both are on the same depth.
    for (int i = LOG - 1; i >= 0; --i) {
        if (d[u] - (1 << i) >= d[v]) {
            merge(ans, node[u][i]);
            u = fa[u][i];
        }
    }

    if (u != v) {
        // Lift both endpoints to just below their lowest common ancestor.
        for (int i = LOG - 1; i >= 0; --i) {
            if (fa[u][i] != fa[v][i]) {
                merge(ans, node[u][i]);
                merge(ans, node[v][i]);
                u = fa[u][i];
                v = fa[v][i];
            }
        }
        merge(ans, node[u][0]);
        merge(ans, node[v][0]);
        u = fa[u][0];
    }

    // u is now the lowest common ancestor; it has not been counted yet.
    merge(ans, node[u][0]);
}

int main() {
    n = read(), m = read(), Q = read();

    for (int i = 1; i < n; ++i) {
        const int x = read();
        const int y = read();
        addedge(x, y);
        addedge(y, x);
    }

    // Person i lives in the city read here; ids arrive in increasing order.
    for (int i = 1; i <= m; ++i) {
        const int city = read();
        data one;
        one.tot = 1;
        one.num[1] = i;
        merge(node[city][0], one);
    }

    build(1, 0);

    while (Q--) {
        const int u = read();
        const int v = read();
        const int a = read();
        collect_path(u, v);
        ans.tot = std::min(ans.tot, a);

        std::printf("%d ", ans.tot);
        for (int i = 1; i <= ans.tot; ++i) std::printf("%d ", ans.num[i]);
        std::puts("");
    }
    return 0;
}