P3384 [Template] Tree Chain Partitioning - (Introduction to Tree Chain Partitioning Template)

Posted by ovisopa on Wed, 04 Sep 2019 11:31:57 +0200

Problem Link

Explanation

Tree chain partitioning (heavy-light decomposition) is similar to a DFS order: both flatten the tree into a linear sequence, which is then maintained with a segment tree. The only difference is the order in which children are visited. A plain DFS order visits the children in whatever order they appear, while heavy-light decomposition always visits the heavy son (the child with the largest subtree) first. Everything else is basically the same.

Consider the operations involved. For a subtree query (e.g., the sum of all values in a node's subtree), the subtree already occupies a contiguous range in the DFS order. Heavy-light decomposition therefore offers no advantage there.

For the sum along the path between two nodes, we could use the LCA. If dis[x] records the distance from node x to the root, then the distance between x and y is $dis[x] + dis[y] - 2\,dis[\mathrm{lca}(x,y)]$. (For node-weighted sums, the value of the LCA itself must be added back.) Once modification operations are added, however, the dis array would have to be recomputed after every update, and the complexity explodes. This is where heavy-light decomposition comes in.

A DFS order in effect lays the tree out on the segment tree block by block. A range on the segment tree is only meaningful if it corresponds to a contiguous chain in the tree. So we can divide the tree into chains and maintain each chain as a contiguous interval on the segment tree. With an arbitrary DFS order, however, these chains are usually very short, because their lengths are determined by whichever child the DFS happens to visit first. We want the chains to be as long as possible, so that each long chain maps to a contiguous range of indices on the segment tree.

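Always continuing the chain through the heavy son guarantees that any path from a node to the root crosses at most O(log N) chains. This is what saves time: if every node were its own chain, there would be no optimization at all. When querying the path between two nodes x and y, proceed as follows:
- While they are not on the same heavy chain, take x to be the node whose chain top is deeper. Add to ans the segment from x up to the top of its chain, then move x to the parent of that chain top.
- Once both nodes are on the same heavy chain, add the interval sum between the two of them.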

For an illustrated explanation of the two DFS passes, see https://www.cnblogs.com/ivanovcraft/p/9019090.html

A version of the code with detailed comments: https://www.cnblogs.com/chinhhh/p/7965433.html

#include <bits/stdc++.h>
#define ll long long
using namespace std;

int mod;                      // modulus read from the input; sums are reported modulo this value
const int N = 200000 + 10;    // capacity of every per-node array (200010 entries)
struct Edge{
    int next,to;
}edge[2*N];
int head[N],tot;
void addEdge(int from,int to)
{
    edge[tot].to = to; edge[tot].next = head[from];
    head[from] = tot++;
}

int cnt;        // running position counter advanced by dfs2
int f[N];       // f[u]: parent of u (0 for the root, as passed by main)
int d[N];       // d[u]: depth of u (root has depth 1)
int siz[N];     // siz[u]: number of nodes in u's subtree
int son[N];     // son[u]: heavy child of u, 0 when u has no children
int id[N];      // id[u]: 1-based position of u in the heavy-first DFS order
int rk[N];      // rk[i]: node at position i, the inverse of id[]
int top[N];     // top[u]: first node of the heavy chain that contains u
// First DFS pass over the subtree rooted at u (whose parent is fa and whose
// depth is dep): fills f[], d[] and siz[], and records as son[u] the child
// with the largest subtree (the first one found on ties).
void dfs1(int u,int fa,int dep)
{
    siz[u] = 1;
    d[u] = dep;
    f[u] = fa;
    for (int e = head[u]; e != -1; e = edge[e].next) {
        const int child = edge[e].to;
        if (child != fa) {
            dfs1(child, u, dep + 1);
            siz[u] += siz[child];
            // siz[0] is 0, so the first real child always beats "no heavy child yet".
            if (siz[son[u]] < siz[child]) son[u] = child;
        }
    }
}
// Second DFS pass: assigns positions so that every heavy chain is a
// contiguous range. t is the top of the chain u lies on. The heavy child is
// visited first, so it continues the same chain; each light child starts
// a new chain with itself as top.
void dfs2(int u,int t)
{
    top[u] = t;
    ++cnt;
    id[u] = cnt;
    rk[cnt] = u;
    if (son[u] == 0) return;   // no heavy child means u has no children at all
    dfs2(son[u], t);
    for (int e = head[u]; e != -1; e = edge[e].next) {
        const int child = edge[e].to;
        if (child == son[u] || child == f[u]) continue;
        dfs2(child, child);
    }
}

// Segment tree over positions 1..n of the heavy-first DFS order (root node is 1,
// children of k are 2k and 2k+1).
struct {
    int l, r;      // positions [l, r] covered by this node
    ll w;          // sum of the values in [l, r]
    ll lazy;       // pending addition not yet pushed to the children
} tr[N * 4];
int a[N];          // a[u]: initial value of tree node u, as read by main()
// Build segment-tree node k over positions [l, r]. Leaf i stores the value of
// the tree node placed at DFS position i, i.e. a[rk[i]], reduced modulo mod.
void build(int k,int l,int r)
{
    tr[k].l = l;
    tr[k].r = r;
    tr[k].lazy = 0;
    if (l != r) {
        const int mid = (l + r) / 2;
        const int lc = k * 2;
        const int rc = k * 2 + 1;
        build(lc, l, mid);
        build(rc, mid + 1, r);
        tr[k].w = (tr[lc].w + tr[rc].w) % mod;
    } else {
        tr[k].w = a[rk[l]] % mod;
    }
}

void pushdown(int k)
{
    if(tr[k].lazy){
        tr[k*2].lazy += tr[k].lazy;
        tr[k*2+1].lazy += tr[k].lazy;
        tr[k*2].lazy %= mod;
        tr[k*2].lazy %= mod;
        tr[k*2].w += (tr[k*2].r-tr[k*2].l+1)*tr[k].lazy;
        tr[k*2+1].w += (tr[k*2+1].r-tr[k*2+1].l+1)*tr[k].lazy;
        tr[k*2].w %= mod;
        tr[k*2+1].w %= mod;
        tr[k].lazy = 0;
    }
}

void update(int k,int l,int r,int v)
{
    v %= mod;
    if(l>tr[k].r||r<tr[k].l) return;
    if(l<=tr[k].l&&tr[k].r<=r){
        tr[k].lazy += v;
        tr[k].lazy %= mod;
        tr[k].w +=(tr[k].r-tr[k].l+1)*v%mod;
        return;
    }
    pushdown(k);
    update(k*2,l,r,v);
    update(k*2+1,l,r,v);
    tr[k].w = (tr[k*2].w + tr[k*2+1].w)%mod;
}

int query(int k,int l,int r)
{
    if(l>tr[k].r||r<tr[k].l) return 0;
    pushdown(k);
    if(l<=tr[k].l&&tr[k].r<=r) return tr[k].w%mod;
    return (query(k*2,l,r)+query(k*2+1,l,r))%mod;
}

// Add v to every node on the tree path between x and y.
// While x and y lie on different heavy chains, the endpoint whose chain top is
// deeper has its chain segment [top, endpoint] updated (a contiguous id range)
// and then jumps to the parent of that chain top. Once both share a chain,
// the remaining id range between them is updated.
void uprange(int x,int y,int v)
{
    v %= mod;
    for (; top[x] != top[y]; x = f[top[x]]) {
        if (d[top[x]] < d[top[y]]) swap(x, y);
        update(1, id[top[x]], id[x], v);
    }
    if (d[x] > d[y]) swap(x, y);
    update(1, id[x], id[y], v);
}

int qrange(int x,int y)
{
    int ans = 0;
    while(top[x]!=top[y])
    {
        if(d[top[x]]<d[top[y]]) swap(x, y);
        ans += query(1,id[top[x]],id[x]);
        ans %= mod;
        x = f[top[x]];
    }
    if (d[x]>d[y]) swap(x, y);
    ans += query(1,id[x],id[y]);
    return ans%mod;
}

// Add v to every node in k's subtree. In the DFS order built by dfs2, that
// subtree occupies the siz[k] consecutive positions starting at id[k].
void upson(int k,int v)
{
    const int first = id[k];
    const int last = first + siz[k] - 1;
    update(1, first, last, v);
}
// Return the sum of the values in k's subtree (positions id[k] .. id[k]+siz[k]-1),
// modulo mod.
int qson(int k)
{
    const int first = id[k];
    const int last = first + siz[k] - 1;
    return query(1, first, last) % mod;
}

// Input format (as read below):
//   n m root mod
//   n node values a[1..n]
//   n-1 undirected edges "x y"
//   m operations:
//     1 x y z : add z to every node on the path x-y
//     2 x y   : print the path sum x-y modulo mod
//     3 x z   : add z to every node in x's subtree
//     4 x     : print the subtree sum of x modulo mod
int main()
{
    tot = 0;
    memset(head,-1,sizeof(head));   // -1 terminates every adjacency list

    int n,m,root;
    scanf("%d%d%d%d",&n,&m,&root,&mod);

    for(int i=1;i<=n;i++) scanf("%d",&a[i]);
    for(int i=1;i<n;i++){
        int x,y; scanf("%d%d",&x,&y);
        addEdge(x,y);
        addEdge(y,x);
    }

    // Decompose the tree (root has parent 0 and depth 1), then build the
    // segment tree over the resulting DFS order.
    dfs1(root,0,1);
    dfs2(root,root);
    build(1,1,n);

    int op,x,y,z;
    while(m--){
        scanf("%d",&op);
        if(op==1){
            scanf("%d%d%d",&x,&y,&z);
            uprange(x,y,z);
        }
        else if(op==2){
            scanf("%d%d",&x,&y);
            printf("%d\n",qrange(x,y));
        }
        else if(op==3){
            scanf("%d%d",&x,&z);
            upson(x,z);
        }
        else if(op==4){
            scanf("%d",&x);
            printf("%d\n",qson(x));
        }
    }
    return 0;
}