[JSOI2011]柠檬

tech2024-11-17  33

Link!

题意

一次操作可以选择一段连续区间和一个颜色 s 0 s_0 s0,获得 s 0 t 2 s_0t^2 s0t2的收益,其中 t t t是区间内这个颜色出现的个数。

思路

首先观察到取的顺序不影响结果,不妨从左到右取。注意到每次取的区间端点颜色一定一样,否则将不一样的颜色单独出来,显然更优。 设 d p [ i ] dp[i] dp[i]表示取到前 i i i个的最大收益,有转移方程 d p [ i ] = m a x ( d p [ j ] + ( s [ i ] − s [ j ] + 1 ) 2 ) ( c o l [ i ] = = c o l [ j ] ) dp[i] = max(dp[j] + (s[i]-s[j]+1)^2) (col[i] == col[j]) dp[i]=max(dp[j]+(s[i]s[j]+1)2)(col[i]==col[j]) 对于每种颜色,具有决策单调性,即对于 j 1 < j 2 < i 1 < i 2 j1 < j2 < i1 < i2 j1<j2<i1<i2当一个 j 1 j1 j1转移到 i 1 i1 i1, j 2 j2 j2就不可能转移到 i 2 i2 i2,也就是说决策点是不增的,其原因是二次函数的增长是越来越快的。 于是我们可以用一个单调栈维护最优决策点。 记 f ( x , y ) f(x,y) f(x,y)表示 x x x这个决策点是什么时候优于 y y y的,这个可以二分求出。 如果 f ( S t k t o p − 1 , S t k t o p ) f(Stk_{top-1},Stk_{top}) f(Stktop1,Stktop)早于当前点 x x x,将 S t k t o p Stk_{top} Stktop弹出,这里会遇到一个问题,如果 S t k t o p − 2 Stk_{top-2} Stktop2优于 S t k t o p − 1 Stk_{top-1} Stktop1 S t k t o p Stk_{top} Stktop S t k t o p − 1 Stk_{top-1} Stktop1劣于 S t k t o p Stk_{top} Stktop,我们便不会弹栈,因此取不到最优决策点。 解决办法就是如果 f ( S t k t o p − 1 , S t k t o p ) < = f ( S t k t o p , x ) f(Stk_{top-1},Stk_{top}) <= f(Stk_{top},x) f(Stktop1,Stktop)<=f(Stktop,x),就弹栈,因为如果 S t k t o p Stk_{top} Stktop超过了 x x x,那么 S t k t o p − 1 Stk_{top-1} Stktop1早已超过了他们。 这道决策单调性貌似只能单调栈而不能分治。

代码

#include<bits/stdc++.h> #define int long long #define N 100015 #define rep(i,a,n) for (int i=a;i<=n;i++) #define per(i,a,n) for (int i=n;i>=a;i--) #define inf 0x3f3f3f3f #define pb push_back #define mp make_pair #define lowbit(i) ((i)&(-i)) #define VI vector<int> #define SZ(x) ((int)x.size()) using namespace std; int n,cnt[N],val[N],s[N],dp[N]; VI st[N]; int calc(int x,int y){ return dp[x-1]+val[x]*y*y; } int beyond(int x,int y){ int l = 0,r = n; while(l + 3 < r){ int mid = (l+r)>>1; if(calc(x,mid-s[x]+1) >= calc(y,mid-s[y]+1)){ r = mid; }else l = mid; } rep(mid,l,r) if(calc(x,mid-s[x]+1) >= calc(y,mid-s[y]+1)) return mid; return l; } signed main(){ //freopen(".in","r",stdin); //freopen(".out","w",stdout); scanf("%lld",&n); rep(i,1,n) scanf("%lld",&val[i]),s[i] = ++cnt[val[i]]; rep(i,1,n){ while(st[val[i]].size()>=2 && beyond(st[val[i]][SZ(st[val[i]])-2],st[val[i]][SZ(st[val[i]])-1]) <= beyond(st[val[i]][SZ(st[val[i]])-1],i)) st[val[i]].pop_back(); st[val[i]].pb(i); while(st[val[i]].size() >= 2 && beyond(st[val[i]][SZ(st[val[i]])-2],st[val[i]][SZ(st[val[i]])-1]) <= s[i]) st[val[i]].pop_back(); dp[i] = calc(st[val[i]][SZ(st[val[i]])-1],s[i]-s[st[val[i]][SZ(st[val[i]])-1]]+1); } printf("%lld\n", dp[n]); return 0; }
最新回复(0)