样例输入:
5 1
1 4 2 8 5
样例输出:
4
分析:看到这种对其中连续k个数进行修改的我们就应该想到答案是由三部分组成,因为求的是最长不下降子序列,那么我们可以找到一个最合适的断点i,使得答案是由区间[1,i],[i+1,i+k],[i+k+1,n]三部分组成,其中区间[i+1,i+k]里面的数是可以任意变化的,那么我们只要在区间[1,i]和区间[i+k+1,n]中找到一个最长不下降子序列b1,b2,……,bm,那么我们就可以将区间[i+1,i+k]中的所有数变为某个bj,使得最长不下降子序列的长度为m+k,所以现在我们的关键问题就是为了求取m。
一般这种问题就是要设置一个前缀和一个后缀,表示含义如下:
f1[i]表示a[1~i]中以a[i]结尾的最长不下降子序列的长度
f2[i]表示a[i~n]中以a[i]开头的最长不下降子序列的长度
这两个数组显然可以用权值线段树预处理出来:
f1[i]:就是每次在加入a[i]之前,先看一下线段树中以小于等于a[i]的值结尾的最长不下降子序列的长度的最大值,然后在这个基础上+1即可得到
f2[i]:这个要从后往前遍历,这个是在每次加入a[i]之前,先看一下线段树中以大于等于a[i]的值开头的最长不下降子序列的长度的最大值,然后在这个基础上+1即可得到
注意当求出这个值后要用f数组对权值线段树进行更新
那么我们枚举前半段区间的最长不下降子序列端点i,那么也就代表最长不下降子序列是由a[1~i]中的一部分和[i+1~i+k]中的全部以及a[i+k+1,n]中的一部分组成,由于我们枚举的前半段区间的最长不下降子序列的末尾,那么我们就要在区间[i+k+1,n]中找到以大于等于a[i]的值开头的最长不下降子序列的长度最大值,这个直接在求解f2[]过程中刚好可以利用权值线段树得到。
答案还有可能就是只有两段区间,这个要分两种情况,一种是只有a[1~i]中的一部分和[i+1~i+k]中的全部,或者是只有[i+1~i+k]中的全部以及a[i+k+1,n]中的一部分组成,这两种情况直接用for循环遍历一遍即可得到,无非就是一种只用到f1[],另一种只用到f2[]。
细节见代码:
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<map>
#include<queue>
#include<vector>
#include<cmath>
using namespace std;
const int N=5e5+10;
int l[N],r[N],mx[N];
int a[N];
int f1[N],f2[N];
/*
f1[i]表示a[1~i]中以a[i]结尾的最长不下降子序列的长度
f2[i]表示a[i~n]中以a[i]开头的最长不下降子序列的长度
*/
vector<int>alls;
int find(int x)
{
return lower_bound(alls.begin(),alls.end(),x)-alls.begin()+1;
}
void pushup(int id)
{
mx[id]=max(mx[id<<1],mx[id<<1|1]);
}
void build(int id,int L,int R)
{
l[id]=L;r[id]=R;mx[id]=0;
if(L==R) return ;
int mid=L+R>>1;
build(id<<1,L,mid);
build(id<<1|1,mid+1,R);
pushup(id);
}
void update_point(int id,int pos,int val)
{
if(l[id]==r[id])
{
mx[id]=val;
return ;
}
int mid=l[id]+r[id]>>1;
if(pos<=mid) update_point(id<<1,pos,val);
else update_point(id<<1|1,pos,val);
pushup(id);
}
int query_interval(int id,int L,int R)
{
if(l[id]>=L&&r[id]<=R) return mx[id];
int mid=l[id]+r[id]>>1;
int ans=0;
if(L<=mid) ans=max(ans,query_interval(id<<1,L,R));
if(mid+1<=R) ans=max(ans,query_interval(id<<1|1,L,R));
return ans;
}
int main()
{
int n,k;
cin>>n>>k;
for(int i=1;i<=n;i++)
{
scanf("%d",&a[i]);
alls.push_back(a[i]);
}
sort(alls.begin(),alls.end());
alls.erase(unique(alls.begin(),alls.end()),alls.end());
for(int i=1;i<=n;i++)
a[i]=find(a[i]);
build(1,1,alls.size());
for(int i=1;i<=n;i++)
{
f1[i]=query_interval(1,1,a[i])+1;
update_point(1,a[i],f1[i]);
}
int ans=0;
build(1,1,alls.size());
for(int i=n;i>=1;i--)
{
f2[i]=query_interval(1,a[i],alls.size())+1;
update_point(1,a[i],f2[i]);
if(i>k)
{
ans=max(ans,f1[i-k-1]+k+query_interval(1,a[i-k-1],alls.size()));
ans=max(ans,k+f2[i]);
}
if(i+k<=n) ans=max(ans,k+f1[i]);
}
printf("%d\n",ans);
return 0;
}