题目链接
题面:
题意: 有一棵 n n n 个节点的树,边权均为 1 1 1,从上面选 m m m 个点的方案为 C n m C_n^m Cnm。 对于每一种方案,该方案的权重定义为这 m m m 个点到树上某一点的距离和的最小值。我们定义这一点为最优点。 求 C n m C_n^m Cnm 种方案的权重的和。
题解: 我没枚举每一条边,假设这条边两侧的节点数分别为 s s s, n − s n-s n−s。我们在这条边两侧选的节点数为 i i i, m − i m-i m−i,我们可以知道,最优点一定在选的点数较多的一侧。
那么对于某条边来说较为容易得到公式:
f ( s ) = ∑ i = 1 m − 1 C s i ∗ C n − s m − i ∗ m i n ( i , m − i ) f(s)=\sum\limits_{i=1}^{m-1}C_s^i*C_{n-s}^{m-i}*min(i,m-i) f(s)=i=1∑m−1Csi∗Cn−sm−i∗min(i,m−i)
显然,对于每一条边,计算该式子的时间复杂度是 O ( n 2 ) O(n^2) O(n2) 的。
考虑化简,我们令 p = ⌊ m − 1 2 ⌋ p=\left\lfloor\frac{m-1}{2}\right\rfloor p=⌊2m−1⌋,那么我们得到式子:
f ( s ) = ∑ i = 1 p C s i ∗ C n − s m − i ∗ i + ∑ i = 1 p C s m − i ∗ C n − s i ∗ i + [ m m o d 2 = 0 ] ∗ C s m 2 ∗ C n − s m 2 ∗ m 2 f(s)=\sum\limits_{i=1}^pC_s^i*C_{n-s}^{m-i}*i+\sum\limits_{i=1}^pC_s^{m-i}*C_{n-s}^i*i+[m\space mod\space 2=0]*C_s^{\frac{m}{2}}*C_{n-s}^{\frac{m}{2}}*\frac{m}{2} f(s)=i=1∑pCsi∗Cn−sm−i∗i+i=1∑pCsm−i∗Cn−si∗i+[m mod 2=0]∗Cs2m∗Cn−s2m∗2m
我们令 g ( s ) = ∑ i = 1 p C s i ∗ C n − s m − i ∗ i g(s)=\sum\limits_{i=1}^pC_s^i*C_{n-s}^{m-i}*i g(s)=i=1∑pCsi∗Cn−sm−i∗i , h ( s ) = ∑ i = 1 p C s m − i ∗ C n − s i ∗ i h(s)=\sum\limits_{i=1}^pC_s^{m-i}*C_{n-s}^i*i h(s)=i=1∑pCsm−i∗Cn−si∗i, k ( s ) = C s m 2 ∗ C n − s m 2 ∗ m 2 k(s)=C_s^{\frac{m}{2}}*C_{n-s}^{\frac{m}{2}}*\frac{m}{2} k(s)=Cs2m∗Cn−s2m∗2m
容易发现 h ( s ) = g ( n − s ) h(s)=g(n-s) h(s)=g(n−s),现在我们考虑怎么快速求出 g ( s ) g(s) g(s)。
g ( s ) = ∑ i = 1 p C s i ∗ C n − s m − i ∗ i g(s)=\sum\limits_{i=1}^pC_s^i*C_{n-s}^{m-i}*i g(s)=i=1∑pCsi∗Cn−sm−i∗i
C s i = s ! i ! ∗ ( s − i ) ! = ( s − 1 ) ! ∗ s ( i − 1 ) ! ∗ i ∗ ( s − i ) ! = C s − 1 i − 1 ∗ s i C_s^i=\dfrac{s!}{i!*(s-i)!}=\dfrac{(s-1)!*s}{(i-1)!*i*(s-i)!}=C_{s-1}^{i-1}*\dfrac{s}{i} Csi=i!∗(s−i)!s!=(i−1)!∗i∗(s−i)!(s−1)!∗s=Cs−1i−1∗is
g ( s ) = s ∗ ∑ i = 1 p C s − 1 i − 1 ∗ C n − s m − i = s ∗ t ( s ) g(s)=s*\sum\limits_{i=1}^pC_{s-1}^{i-1}*C_{n-s}^{m-i}=s*t(s) g(s)=s∗i=1∑pCs−1i−1∗Cn−sm−i=s∗t(s) ,其中 t ( s ) = ∑ i = 1 p C s − 1 i − 1 ∗ C n − s m − i t(s)=\sum\limits_{i=1}^pC_{s-1}^{i-1}*C_{n-s}^{m-i} t(s)=i=1∑pCs−1i−1∗Cn−sm−i。
考虑给定 t ( s ) t(s) t(s) 一个定义: n − 1 n-1 n−1 个位置,放置 m − 1 m-1 m−1 个球,每个球只能放在一个位置上,每个位置至多放置一个球。其中要求前 s − 1 s-1 s−1 个位置至多放置 p − 1 p-1 p−1 个球。 得到: t ( s ) = ∑ i = 1 p C s − 1 i − 1 ∗ C n − s m − i t(s)=\sum\limits_{i=1}^pC_{s-1}^{i-1}*C_{n-s}^{m-i} t(s)=i=1∑pCs−1i−1∗Cn−sm−i。
明显需要满足 p > = 1 p>=1 p>=1。且 s = 1 s=1 s=1时, t ( s ) = C n − 1 m − 1 t(s)=C_{n-1}^{m-1} t(s)=Cn−1m−1
考虑:怎么由 t ( s − 1 ) t(s-1) t(s−1) 得到 t ( s ) t(s) t(s)。
要求改变的地方为,从前 s − 2 s-2 s−2 个位置至多放置 p − 1 p-1 p−1 个球,转化为前 s − 1 s-1 s−1 个位置至多放置 p − 1 p-1 p−1 个球。
考虑哪些不合法。 那些在 t ( s − 1 ) t(s-1) t(s−1) 种合法且在 t ( s ) t(s) t(s) 种不合法的一定是,前 s − 2 s-2 s−2 个位置已经放置了 p − 1 p-1 p−1 个球,但是第 s − 1 s-1 s−1 的位置还有一个球。即 C s − 2 p − 1 ∗ C n − s m − 1 − p C_{s-2}^{p-1}*C_{n-s}^{m-1-p} Cs−2p−1∗Cn−sm−1−p。
这样,我们可以快速求出 t ( s ) t(s) t(s),从而快速得到 g ( s ) g(s) g(s),从而得到 h ( s ) h(s) h(s),最终得到 f ( s ) f(s) f(s)。
注意 m = 1 m=1 m=1 和 m = 2 m=2 m=2 这两种情况下, p = 0 p=0 p=0。
代码:
#include<iostream> #include<cstdio> #include<cstdlib> #include<algorithm> #include<cstring> #include<cmath> #include<string> #include<queue> #include<bitset> #include<map> #include<unordered_map> #include<unordered_set> #include<set> #include<ctime> #define ui unsigned int #define ll long long #define llu unsigned ll #define ld long double #define pr make_pair #define pb push_back //#define lc (cnt<<1) //#define rc (cnt<<1|1) #define len(x) (t[(x)].r-t[(x)].l+1) #define tmid ((l+r)>>1) #define fhead(x) for(int i=head[(x)];i;i=nt[i]) #define max(x,y) ((x)>(y)?(x):(y)) #define min(x,y) ((x)>(y)?(y):(x)) using namespace std; const int inf=0x3f3f3f3f; const ll lnf=0x3f3f3f3f3f3f3f3f; const double dnf=1e18; const double alpha=0.75; const int mod=1e9+7; const double eps=1e-8; const double pi=acos(-1.0); const int hp=13331; const int maxn=1000100; const int maxm=100100; const int maxp=100100; const int up=1100; ll fac[maxn],inv[maxn]; ll t[maxn],g[maxn],h[maxn],k[maxn],ans; int f[maxn],si[maxn],n,m; ll mypow(ll a,ll b) { ll ans=1; while(b) { if(b&1) ans=ans*a%mod; a=a*a%mod; b>>=1; } return ans; } void init(void) { fac[0]=1; for(int i=1;i<maxn;i++) fac[i]=fac[i-1]*i%mod; inv[maxn-1]=mypow(fac[maxn-1],mod-2); for(int i=maxn-2;i>=0;i--) inv[i]=inv[i+1]*(i+1)%mod; } ll C(ll n,ll m) { if(n<0||m<0||m>n) return 0; return fac[n]*inv[m]%mod*inv[n-m]%mod; } int main(void) { init(); int tt; scanf("%d",&tt); while(tt--) { scanf("%d%d",&n,&m); for(int i=2;i<=n;i++) scanf("%d",&f[i]),si[i]=1; ans=0; int p=(m-1)/2; t[1]=p?C(n-1,m-1):0; g[1]=t[1]*1; for(int s=2;s<=n;s++) { t[s]=((t[s-1]-C(s-2,p-1)*C(n-s,m-1-p)%mod)%mod+mod)%mod; g[s]=t[s]*s%mod; } for(int s=1;s<=n;s++) { h[s]=g[n-s]; k[s]=C(s,m/2)*C(n-s,m/2)%mod*(m/2)%mod; } int now=0; for(int i=n;i>=2;i--) { si[f[i]]+=si[i]; now=min(si[i],n-si[i]); ans=(ans+g[now]+h[now]+(m%2==0?k[now]:0))%mod; } printf("%lld\n",ans); } return 0; }