Tree chain segmentation (heavy-light decomposition) (to be continued)

Posted by broomstick on Sat, 09 Nov 2019 18:32:14 +0100

Time complexity:

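Preprocessing: O(N) (two DFS passes plus building the segment tree). Path modification and path query: O(log^2 N) (a path crosses O(log N) heavy chains, and each chain segment is one O(log N) segment-tree operation). Subtree modification and subtree query: O(log N).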

Principle of tree chain segmentation:

  1. The first DFS computes the depth dep[x] of every node, records its parent fa[x], the size sz[x] of the subtree rooted at it (including the node itself), and its heavy child wson[x] (the child with the largest subtree).
  2. The second DFS visits each node's heavy child first. For every node x it records the head top[x] of the heavy chain containing x, the position pos[x] of x in this DFS order, and the inverse mapping id[pos[x]] = x, which gives the node stored at each position.
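  3. To update the path between u and v: while u and v lie on different chains, take the node whose chain head is deeper, update the contiguous range from its chain head down to it, and move it to the parent of its chain head. Each move crosses a light edge, so O(log N) moves suffice. Once both nodes are on the same chain, update the range between the shallower and the deeper node.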
  4. A path query works the same way, calling the query function instead of the update function.

 

 

Tree problems that can be solved by tree chain segmentation:

  • Choose u and v and add a value to all points (or all edges) on the path between them
  • Choose a point x and add a value to x and to every point in its subtree
  • Calculate the sum of point weights (or edge weights) on the path between u and v
  • To be continued

 

Basic example:

Luogu P3384 (point weights)

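Operation types:

  1. 1 x y z: add z to every node on the path between x and y
  2. 2 x y: output the sum of node weights on the path between x and y, modulo P (the modulus given in the input)
  3. 3 x z: add z to every node in the subtree of x
  4. 4 x: output the sum of node weights in the subtree of x, modulo P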

 

AC Code:

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <string>
#include <cmath>
#include <iostream>
#include <algorithm>
#include <queue>
#include <map>
#include <stack>
#include <vector>
using namespace std;
typedef long long LL;
const int Maxn = 2e5 + 10;  // capacity for nodes and DFS positions (nodes are numbered from 1; assumes N < Maxn)
const int Inf = 1e9 + 7;    // NOTE(review): not referenced in the visible code
int N , M , cnt , root;     // node count, operation count, DFS position counter, root node
int Mod;                    // modulus applied to every reported sum
// Heavy-light decomposition data, indexed by node (except id[], which is indexed by DFS position):
//   pos  - position of the node in the heavy-child-first DFS order
//   top  - head of the heavy chain containing the node
//   sz   - subtree size, counting the node itself
//   dep  - depth (the root call in main uses depth 1)
//   fa   - parent (-1 for the root)
//   wson - heavy child (-1 when the node has no child)
//   id   - inverse of pos: id[pos[x]] == x
int pos[Maxn] , top[Maxn] , sz[Maxn] , dep[Maxn] , fa[Maxn] , wson[Maxn] , id[Maxn];
// Adjacency lists: one vector per node is all that is needed.
std::vector<int> G[Maxn];

// Segment tree over the heavy-child-first DFS order: range add, range sum (modulo Mod).
int A[Maxn];    // initial weight of each node, indexed by node number (filled in main)
// One segment-tree node; despite the name `edge`, each element is a node of the segment tree.
struct edge{
    int l , r;             // inclusive range of DFS positions covered by this node
    long long sum , lazy;  // range sum and pending per-position add; 64-bit gives headroom for unreduced intermediates
}tree[Maxn<<2];

void PushUp( int x ){
    tree[x].sum = tree[x<<1].sum + tree[x<<1|1].sum;
}

void Build( int l , int r , int x ){
    tree[x].lazy = 0;
    tree[x].l = l , tree[x].r = r;
    if( l == r ){
        tree[x].sum = A[id[l]];
        return;
    }
    int mid = ( l + r ) >> 1;
    Build( l , mid , x << 1 );
    Build( mid + 1 , r , x<<1|1 );
    tree[x].sum = tree[x<<1].sum + tree[x<<1|1].sum;
    //PushUp( x );
}

/*
 * Pushes node x's pending add tag down to its two children and clears it.
 * Each child's tag grows by the tag. Each child's sum grows by tag * (number of positions the
 * child covers). The products and sums are computed in 64-bit and reduced modulo Mod, so
 * neither the tags nor the sums can overflow. Does nothing if x has no pending tag.
 */
void PushDown( int x ){
    if( !tree[x].lazy )	return;
    const long long tag = tree[x].lazy;
    const int mid = ( tree[x].l + tree[x].r ) >> 1;   // same split as Build
    const int lc = x << 1 , rc = x << 1 | 1;
    tree[lc].lazy = ( tree[lc].lazy + tag ) % Mod;
    tree[rc].lazy = ( tree[rc].lazy + tag ) % Mod;
    tree[lc].sum = ( tree[lc].sum + tag * ( mid - tree[x].l + 1 ) ) % Mod;
    tree[rc].sum = ( tree[rc].sum + tag * ( tree[x].r - mid ) ) % Mod;
    tree[x].lazy = 0;
}

/*
 * Adds `add` to every DFS position in [L, R] within the range covered by tree index x.
 * A node that is fully covered is updated in place and tagged for later push-down.
 * A node that is only partly covered pushes its own tag down, recurses into the overlapping
 * children, and then recomputes its sum from them.
 * All arithmetic is 64-bit and reduced modulo Mod, so add * length cannot overflow int.
 */
void Update_line( int L , int R , int add , int x ){
    if( L <= tree[x].l && tree[x].r <= R ){
        const long long len = tree[x].r - tree[x].l + 1;
        tree[x].sum = ( tree[x].sum + add * len ) % Mod;
        tree[x].lazy = ( tree[x].lazy + add ) % Mod;
        return;
    }
    PushDown( x );
    int mid = ( tree[x].l + tree[x].r ) >> 1;
    if( L <= mid )	Update_line( L , R , add , x <<1 );
    if( R > mid )	Update_line( L , R , add , x <<1|1 );
    tree[x].sum = ( 1LL * tree[x<<1].sum + tree[x<<1|1].sum ) % Mod;
}

/*
 * Returns the sum over DFS positions [L, R] within the range covered by tree index x,
 * reduced modulo Mod. Pending tags are pushed down on partly covered nodes.
 * Partial results are accumulated in 64-bit, so adding two values just below a large Mod
 * (up to INT_MAX) cannot overflow int.
 */
int Query( int L , int R , int x ){
    if( L <= tree[x].l && tree[x].r <= R ){
        tree[x].sum %= Mod;
        return static_cast<int>( tree[x].sum );
    }
    PushDown( x );
    int mid = ( tree[x].l + tree[x].r ) >> 1;
    long long res = 0;
    if( L <= mid )	res += Query( L , R , x<<1 );
    if( R > mid )	res += Query( L , R , x<<1|1 );
    return static_cast<int>( res % Mod );
}


// Heavy-light decomposition (tree chain segmentation)
/*
 * First DFS, over the subtree of x. Fills in, for every node visited:
 *   fa[]   - parent (`fat` for x itself; main passes -1 for the root)
 *   dep[]  - depth, starting at `dept` for x
 *   sz[]   - subtree size, counting the node itself
 *   wson[] - child with the largest subtree; the first such child wins ties
 * Expects wson[] to be pre-filled with -1, which then marks nodes with no child.
 */
void dfs1(int x , int fat , int dept){
    fa[x] = fat;
    dep[x] = dept;
    sz[x] = 1;
    for (std::size_t k = 0; k < G[x].size(); ++k) {
        const int child = G[x][k];
        if (child == fat)
            continue;                       // do not walk back up to the parent
        dfs1(child, x, dept + 1);
        sz[x] += sz[child];
        const bool heavier = (wson[x] == -1) || (sz[child] > sz[wson[x]]);
        if (heavier)
            wson[x] = child;
    }
}
/*
 * Second DFS, visiting the heavy child first. For every node v in the subtree of x it sets:
 *   top[v] - head of the heavy chain containing v (`line_top` for x itself)
 *   pos[v] - position of v in this DFS order, taken from the global counter cnt
 *   id[]   - inverse of pos[], so id[pos[v]] == v
 * Because the heavy child is visited first, each heavy chain occupies consecutive positions.
 * Each subtree occupies the range [pos[v], pos[v] + sz[v] - 1].
 * Requires dfs1 to have filled fa[] and wson[].
 */
void dfs2(int x , int line_top){
    top[x] = line_top;
    pos[x] = ++cnt;
    id[pos[x]] = x;
    const int heavy = wson[x];
    if (heavy != -1) {
        dfs2(heavy, line_top);              // the heavy child continues the current chain
        for (std::size_t k = 0; k < G[x].size(); ++k) {
            const int child = G[x][k];
            if (child == fa[x] || child == heavy)
                continue;
            dfs2(child, child);             // each light child starts a new chain
        }
    }
}

/*
 * Path operation between nodes u and v.
 *   op == 0: add c to every node on the path u..v; returns 0.
 *   op != 0: returns the sum of node weights on the path u..v modulo Mod; c is ignored.
 *
 * While u and v are on different heavy chains, take the node whose chain head is deeper.
 * Its chain segment [pos[top], pos[node]] is contiguous in DFS order, so it is one
 * segment-tree range. Process that range, then jump to the parent of the chain head.
 * When both nodes share a chain, process the range between the shallower and the deeper one.
 * The running sum is kept in 64-bit: adding two values below a large Mod (up to INT_MAX)
 * would overflow int.
 */
int updata0_query1(int u , int v , int c , int op){
    long long ans = 0;
    while(top[u] != top[v]){
        if(dep[top[u]] > dep[top[v]])	std::swap(u,v);
        if(op == 0)	Update_line(pos[top[v]] , pos[v] , c , 1);
        else	ans = ( ans + Query(pos[top[v]] , pos[v] , 1) ) % Mod;
        v = fa[top[v]];
    }
    if(dep[u] > dep[v])	std::swap(u,v);
    if(op == 0)	Update_line(pos[u] , pos[v] , c , 1);
    else	ans = ( ans + Query(pos[u] , pos[v] , 1) ) % Mod;
    return static_cast<int>( ans % Mod );
}

/*
 * Driver. Input, as read below:
 *   N M root Mod, then the N node weights, then N-1 undirected edges, then M operations:
 *     1 x y z : add z to every node on the path x..y
 *     2 x y   : print the path sum x..y modulo Mod
 *     3 x z   : add z to every node in the subtree of x
 *     4 x     : print the subtree sum of x modulo Mod
 * After dfs2, the subtree of x is the DFS range [pos[x], pos[x] + sz[x] - 1].
 */
int main()
{
    cnt = 0;
    memset(wson , -1 , sizeof(wson));       // -1 means "no heavy child yet"
    scanf(" %d %d %d %d", &N, &M, &root, &Mod);
    for (int node = 1; node <= N; ++node)
        scanf(" %d", &A[node]);

    int u , v;
    for (int e = 1; e < N; ++e) {
        scanf(" %d %d", &u, &v);
        G[u].push_back(v);
        G[v].push_back(u);
    }

    dfs1(root, -1, 1);
    dfs2(root, root);
    Build(1, N, 1);

    int op , x , y , z;
    for (int q = 1; q <= M; ++q) {
        scanf(" %d", &op);
        switch (op) {
        case 1:
            scanf(" %d %d %d", &x, &y, &z);
            updata0_query1(x, y, z, 0);
            break;
        case 2:
            scanf(" %d %d", &x, &y);
            printf("%d\n", updata0_query1(x, y, 0, 1));
            break;
        case 3: {
            scanf(" %d %d", &x, &z);
            const int first = pos[x];
            const int last = pos[x] + sz[x] - 1;
            Update_line(first, last, z, 1);
            break;
        }
        case 4:
            scanf(" %d", &x);
            printf("%d\n", Query(pos[x], pos[x] + sz[x] - 1, 1) % Mod);
            break;
        default:
            break;                          // unknown op codes are ignored
        }
    }
    return 0;
}