Time complexity:
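Preprocessing takes O(N) (two DFS passes plus the segment-tree build). A path update or path query takes O(log^2 N): the path is split into O(log N) chain segments, each handled by one O(log N) segment-tree operation. A subtree update or subtree query takes O(log N).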
Principle of tree chain segmentation:
- The first DFS computes, for every node x, its depth dep[x], its parent fa[x], the size sz[x] of the subtree rooted at x (including x itself), and its heavy child wson[x] (the child with the largest subtree)
- The second DFS splits the tree into chains: for every node x it records the head top[x] of the chain containing x, the position pos[x] of x in the DFS order, and the inverse mapping id[pos[x]] = x (the node stored at position pos[x]). The heavy child is always visited first, so every chain occupies a contiguous range of positions
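- To update the path between u and v, compare the depths of the two chain heads, update the whole segment from the deeper chain head down to its node, and then jump to the parent of that chain head. At most O(log N) such jumps are needed before both nodes are on the same chain; then order u and v by depth and update the remaining segment between them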
- A path query follows exactly the same procedure, calling the segment-tree query function instead of the update function and accumulating the partial results
Tree problems that can be solved by tree chain segmentation:
- Given two nodes u and v, add a value to every node (or every edge) on the path between them
- Given a node x, add a value to x and to every node in its subtree
- Compute the sum of the node weights (or edge weights) on the path between two nodes u and v
- To be continued
Basic example:
Luogu P3384 (node weights)
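Operation types:
- 1 x y z: add z to the weight of every node on the path from x to y
- 2 x y: print the sum of the node weights on the path from x to y, modulo P
- 3 x z: add z to the weight of every node in the subtree rooted at x
- 4 x: print the sum of the node weights in the subtree rooted at x, modulo P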
AC Code:
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <string>
#include <cmath>
#include <iostream>
#include <algorithm>
#include <queue>
#include <map>
#include <stack>
#include <vector>
using namespace std;

typedef long long LL;

const int Maxn = 2e5 + 10;
const int Inf = 1e9 + 7;

int N, M, cnt, root;
int Mod;

// Per-node data, indexed by node id (1..N):
//   pos[x]  - position of x in the chain-first DFS order (1..N)
//   top[x]  - head (shallowest node) of the chain that contains x
//   sz[x]   - number of nodes in the subtree rooted at x, x included
//   dep[x]  - depth of x (the root has depth 1)
//   fa[x]   - parent of x (-1 for the root)
//   wson[x] - heavy child of x, or -1 when x is a leaf
// id[p] is the inverse of pos: the node stored at position p.
int pos[Maxn], top[Maxn], sz[Maxn], dep[Maxn], fa[Maxn], wson[Maxn], id[Maxn];

// Adjacency lists, indexed by node id.
vector<int> G[Maxn];

// Initial node weights, indexed by node id.
int A[Maxn];

// Segment tree over positions 1..N. Node x covers [l, r]; `sum` is the sum of
// the weights in that range and `lazy` is an increment that still has to be
// pushed to the children. Both are kept reduced into [0, Mod) and stored as
// long long so that `increment * range_length` cannot overflow.
struct edge {
    int l, r;
    LL sum, lazy;
} tree[Maxn << 2];

/* Reduce v into [0, Mod); keeps negative weights/increments from yielding negative sums. */
static LL Norm(LL v) {
    v %= Mod;
    return v < 0 ? v + Mod : v;
}

/* Add the already reduced increment `add` to every position covered by node x. */
static void Apply(int x, LL add) {
    const LL len = tree[x].r - tree[x].l + 1;
    tree[x].sum = (tree[x].sum + add * len) % Mod;
    tree[x].lazy = (tree[x].lazy + add) % Mod;
}

/* Recompute the sum of node x from its two children. */
void PushUp(int x) {
    tree[x].sum = (tree[x << 1].sum + tree[x << 1 | 1].sum) % Mod;
}

/*
 * Build the segment tree node x for positions [l, r].
 * Position p holds the initial weight of node id[p], so dfs2 must have run first.
 */
void Build(int l, int r, int x) {
    tree[x].lazy = 0;
    tree[x].l = l;
    tree[x].r = r;
    if (l == r) {
        tree[x].sum = Norm(A[id[l]]);
        return;
    }
    const int mid = (l + r) >> 1;
    Build(l, mid, x << 1);
    Build(mid + 1, r, x << 1 | 1);
    PushUp(x);
}

/* Push the pending increment of node x down to both children. */
void PushDown(int x) {
    if (tree[x].lazy == 0) return;
    Apply(x << 1, tree[x].lazy);
    Apply(x << 1 | 1, tree[x].lazy);
    tree[x].lazy = 0;
}

/*
 * Add `add` to every position in [L, R], starting from segment tree node x
 * (callers pass x = 1). Ranges that do not intersect node x are ignored.
 */
void Update_line(int L, int R, int add, int x) {
    if (R < tree[x].l || tree[x].r < L) return;
    if (L <= tree[x].l && tree[x].r <= R) {
        Apply(x, Norm(add));
        return;
    }
    PushDown(x);
    const int mid = (tree[x].l + tree[x].r) >> 1;
    if (L <= mid) Update_line(L, R, add, x << 1);
    if (R > mid) Update_line(L, R, add, x << 1 | 1);
    PushUp(x);
}

/*
 * Return the sum of positions [L, R] modulo Mod, starting from segment tree
 * node x (callers pass x = 1). A range that does not intersect node x yields 0.
 */
int Query(int L, int R, int x) {
    if (R < tree[x].l || tree[x].r < L) return 0;
    if (L <= tree[x].l && tree[x].r <= R) return static_cast<int>(tree[x].sum);
    PushDown(x);
    const int mid = (tree[x].l + tree[x].r) >> 1;
    LL res = 0;
    if (L <= mid) res += Query(L, R, x << 1);
    if (R > mid) res += Query(L, R, x << 1 | 1);
    return static_cast<int>(res % Mod);
}

/*
 * First DFS pass.
 * For every node in the subtree of x it fills in dep (depth), fa (parent),
 * sz (subtree size, the node itself included) and wson (the child with the
 * largest subtree).
 * x is the node being visited, fat its parent (-1 for the root) and dept its
 * depth (the root is called with 1). wson[] must be filled with -1 beforehand.
 */
void dfs1(int x, int fat, int dept) {
    dep[x] = dept;
    fa[x] = fat;
    sz[x] = 1;
    const int Size = G[x].size();
    for (int i = 0; i < Size; i++) {
        const int v = G[x][i];
        if (v == fat) continue;
        dfs1(v, x, dept + 1);
        sz[x] += sz[v];
        if (wson[x] == -1 || sz[wson[x]] < sz[v]) wson[x] = v;
    }
}

/*
 * Second DFS pass.
 * Assigns top[x] (head of the chain containing x), pos[x] (position of node x
 * in the visiting order) and id[pos[x]] = x. The heavy child is visited first
 * and inherits line_top, so the nodes of one chain receive consecutive
 * positions; every other child starts a new chain with itself as the head.
 * The subtree of x occupies positions pos[x] .. pos[x] + sz[x] - 1.
 */
void dfs2(int x, int line_top) {
    top[x] = line_top;
    pos[x] = ++cnt;
    id[pos[x]] = x;
    if (wson[x] == -1) return;
    dfs2(wson[x], line_top);
    const int Size = G[x].size();
    for (int i = 0; i < Size; i++) {
        const int v = G[x][i];
        if (v != fa[x] && v != wson[x]) dfs2(v, v);
    }
}

/*
 * Path operation between nodes u and v.
 *   op == 0: add c to every node on the path; the return value is 0.
 *   op != 0: return the sum of the node weights on the path, modulo Mod
 *            (c is ignored).
 * While u and v are on different chains, the chain whose head is deeper is
 * processed as one contiguous position range and that end jumps to the parent
 * of its chain head. Once both ends share a chain, the range between them is
 * processed.
 */
int updata0_query1(int u, int v, int c, int op) {
    LL ans = 0;
    while (top[u] != top[v]) {
        // Always lift the end whose chain head is deeper (kept in v).
        if (dep[top[u]] > dep[top[v]]) swap(u, v);
        if (op == 0) Update_line(pos[top[v]], pos[v], c, 1);
        else ans = (ans + Query(pos[top[v]], pos[v], 1)) % Mod;
        v = fa[top[v]];
    }
    // Same chain: the shallower node has the smaller position.
    if (dep[u] > dep[v]) swap(u, v);
    if (op == 0) Update_line(pos[u], pos[v], c, 1);
    else ans = (ans + Query(pos[u], pos[v], 1)) % Mod;
    return static_cast<int>(ans % Mod);
}

/* True when x is a node id of the current tree. */
static bool ValidNode(int x) {
    return 1 <= x && x <= N;
}

/*
 * Input: N M root Mod, then N node weights, then N-1 edges, then M operations:
 *   1 x y z - add z to every node on the path x..y
 *   2 x y   - print the path sum modulo Mod
 *   3 x z   - add z to every node in the subtree of x
 *   4 x     - print the subtree sum modulo Mod
 * Malformed or out-of-range input terminates the program with status 1.
 */
int main() {
    cnt = 0;
    memset(wson, -1, sizeof(wson));

    if (scanf(" %d %d %d %d", &N, &M, &root, &Mod) != 4) return 1;
    // N < 1 would recurse forever in Build, Mod < 1 would divide by zero.
    if (N < 1 || N >= Maxn || !ValidNode(root) || Mod < 1) return 1;

    for (int i = 1; i <= N; i++) {
        if (scanf(" %d", &A[i]) != 1) return 1;
    }

    int u, v;
    for (int i = 1; i < N; i++) {
        if (scanf(" %d %d", &u, &v) != 2) return 1;
        if (!ValidNode(u) || !ValidNode(v)) return 1;
        G[u].push_back(v);
        G[v].push_back(u);
    }

    dfs1(root, -1, 1);
    dfs2(root, root);
    Build(1, N, 1);

    int op, x, y, z;
    for (int i = 1; i <= M; i++) {
        if (scanf(" %d", &op) != 1) return 1;
        if (op == 1) {
            if (scanf(" %d %d %d", &x, &y, &z) != 3) return 1;
            if (!ValidNode(x) || !ValidNode(y)) return 1;
            updata0_query1(x, y, z, 0);
        } else if (op == 2) {
            if (scanf(" %d %d", &x, &y) != 2) return 1;
            if (!ValidNode(x) || !ValidNode(y)) return 1;
            printf("%d\n", updata0_query1(x, y, 0, 1));
        } else if (op == 3) {
            if (scanf(" %d %d", &x, &z) != 2) return 1;
            if (!ValidNode(x)) return 1;
            Update_line(pos[x], pos[x] + sz[x] - 1, z, 1);
        } else if (op == 4) {
            if (scanf(" %d", &x) != 1) return 1;
            if (!ValidNode(x)) return 1;
            printf("%d\n", Query(pos[x], pos[x] + sz[x] - 1, 1) % Mod);
        }
    }
    return 0;
}