BZOJ 4293: [PA2015] Siano — segment tree

Posted by persia on Thu, 02 Apr 2020 02:19:07 +0200

Problem statement

There are $n$ blades of grass, all of height 0 initially. Blade $i$ grows $a_i$ units per day. There are $m$ harvests: the $i$-th harvest happens on day $d_i$ and cuts every blade down to height $b_i$, removing everything above that height. For each harvest, output the total height of grass that is cut off.
Constraints: $n, m \le 500000$.

Analysis

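Observe that, however the grass is cut, at any moment a blade with a higher growth rate is never shorter than a blade with a lower growth rate. A cut replaces every height $x$ by $\min(x, b_i)$, which preserves their order, and regrowth afterwards only widens the gap.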
So we sort the blades by growth rate in decreasing order. In that order, the blades cut by each harvest always form a prefix, which can be found and updated with a segment tree.
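Each segment-tree node stores two things about the rightmost (that is, shortest) blade of its segment: the day of its last cut and the height that cut left it at. This is enough to test whether the entire segment is taller than a cutting height. The node also stores $\sum_j (h_j - a_j t_j)$ over its segment, where $h_j$ is the height blade $j$ was last cut to and $t_j$ is the day of that cut. The segment's current total height on day $T$ is then this sum plus $T \sum_j a_j$.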

Code

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
using namespace std;

// 64-bit integer type used for heights, days, growth rates and totals.
typedef long long LL;

// Array capacity: up to 500000 blades, 1-based, plus a little slack.
const int N = 500005;

int n, m;    // number of blades, number of harvests
LL a[N];     // growth rates (main sorts them into decreasing order, 1-based)
LL sum[N];   // prefix sums of a: sum[i] = a[1] + ... + a[i], sum[0] = 0

// One segment-tree node; the tree is stored heap-style (children 2d, 2d+1).
struct tree
{
    LL val;      // sum over the segment of (cut height - growth rate * cut day)
    LL h;        // height the segment's rightmost blade was last cut to
    LL tim;      // day of that cut
    LL tag_h;    // pending lazy cut for the children: height
    LL tag_tim;  // pending lazy cut for the children: day (0 = nothing pending)
} t[N * 4];

bool check(int d,LL tim,LL h,int id)
{
    return t[d].h+a[id]*(tim-t[d].tim)>h;
}

void mark(int d,LL tim,LL h,int l,int r)
{
    t[d].h=t[d].tag_h=h;t[d].tim=t[d].tag_tim=tim;
    t[d].val=(r-l+1)*h-tim*(sum[r]-sum[l-1]);
}

void pushdown(int d,int l,int r)
{
    if (!t[d].tag_tim) return;
    int mid=(l+r)/2;LL tim=t[d].tag_tim,h=t[d].tag_h;
    t[d].tag_tim=t[d].tag_h=0;
    mark(d*2,tim,h,l,mid);mark(d*2+1,tim,h,mid+1,r);
}

// Harvest on day tim with cutting height h, restricted to node d covering [l,r].
// Blades are sorted by decreasing growth rate, so their heights are
// non-increasing from left to right and the blades taller than h form a
// prefix. Returns the total height cut off inside [l,r] and refreshes node d's
// aggregates (val, plus h/tim of its rightmost blade).
// Aggregate terms such as tim*(a[l]+...+a[mid]) can overflow long long even
// though their combination (a real harvest amount) fits. They are therefore
// combined in unsigned long long, which wraps modulo 2^64, instead of in signed
// arithmetic, where overflow is undefined behaviour.
LL solve(int d,int l,int r,LL tim,LL h)
{
    if (l==r)
    {
        // Single blade: cut it only if it is currently taller than h.
        if (!check(d,tim,h,r)) return 0;
        LL ans=t[d].h+a[l]*(tim-t[d].tim)-h;
        mark(d,tim,h,l,r);
        return ans;
    }
    pushdown(d,l,r);
    int mid=(l+r)/2;
    LL ans=0;
    if (check(d*2,tim,h,mid))
    {
        // Even the shortest blade of the left half exceeds h, so the whole left
        // half is cut. Its current total is val + tim*sum(a), minus h per blade.
        const unsigned long long len=(unsigned long long)(mid-l+1);
        const unsigned long long rate=(unsigned long long)(sum[mid]-sum[l-1]);
        ans+=(LL)((unsigned long long)t[d*2].val
                  +(unsigned long long)tim*rate
                  -(unsigned long long)h*len);
        mark(d*2,tim,h,l,mid);
        // The cut prefix may continue into the right half.
        ans+=solve(d*2+1,mid+1,r,tim,h);
    }
    else
    {
        // The cut prefix ends inside the left half; every blade of the right
        // half is no taller than the left half's shortest, so it is untouched.
        ans+=solve(d*2,l,mid,tim,h);
    }
    // Pull up: val is additive (mod 2^64); h/tim mirror the rightmost blade.
    t[d].val=(LL)((unsigned long long)t[d*2].val+(unsigned long long)t[d*2+1].val);
    t[d].tim=t[d*2+1].tim;
    t[d].h=t[d*2+1].h;
    return ans;
}

int main()
{
    scanf("%d%d",&n,&m);
    for (int i=1;i<=n;i++) scanf("%lld",&a[i]),a[i]*=-1;
    sort(a+1,a+n+1);
    for (int i=1;i<=n;i++) a[i]*=-1,sum[i]=sum[i-1]+a[i];
    while (m--)
    {
        LL tim,h;scanf("%lld%lld",&tim,&h);
        printf("%lld\n",solve(1,1,n,tim,h));
    }
    return 0;
}

Topics: less