Data structure: point divide and conquer on trees (centroid decomposition)

Posted by DangerousDave86 on Fri, 31 Dec 2021 16:13:08 +0100

Original problem
Given a tree with N nodes, numbered 0, 1, ..., N − 1, where each edge has a weight of at most 1000.

The length of the path between two nodes x and y is the sum of the weights of the edges on that path.

Find the number of paths (unordered pairs of distinct nodes) whose length does not exceed K.

Input format
The input contains multiple sets of test cases.

The first line of each set of test cases contains two integers N and K.

Each of the next N − 1 lines contains three integers u, v and l, meaning that there is an edge between nodes u and v with weight l.

A test case with N = 0 and K = 0 marks the end of the input and must not be processed.

Output format
For each test case, output the result on its own line.

Data range
1 ≤ N ≤ 10^4,
1 ≤ K ≤ 5 × 10^6,
1 ≤ l ≤ 10^3
Input example:
5 4
0 1 3
0 2 1
0 3 2
2 4 1
0 0
Output example:
8
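Before the divide-and-conquer solution, a brute-force reference helps pin down exactly what is being counted. The sketch below is a hypothetical helper (`brute_force_count`, not part of the original post) that runs one DFS from every node and counts unordered pairs at distance at most K. It is O(N^2), so it is only meant for checking small inputs against the fast solution, and it assumes the edges arrive as (u, v, l) triples.

```cpp
#include <array>
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical brute-force reference: counts unordered pairs (s, t), s < t,
// whose path length is at most k. O(N^2); for small test inputs only.
long long brute_force_count(int n, int k, const std::vector<std::array<int, 3>>& edges) {
    std::vector<std::vector<std::pair<int, int>>> adj(n);  // (neighbor, weight)
    for (const auto& e : edges) {
        adj[e[0]].push_back({e[1], e[2]});
        adj[e[1]].push_back({e[0], e[2]});
    }
    long long count = 0;
    std::vector<long long> dist(n);
    for (int s = 0; s < n; ++s) {
        // Iterative DFS from s; in a tree every node is reached by exactly one path.
        std::vector<int> parent(n, -1);
        std::vector<int> stack = {s};
        dist[s] = 0;
        while (!stack.empty()) {
            int u = stack.back();
            stack.pop_back();
            for (const auto& edge : adj[u]) {
                int v = edge.first, w = edge.second;
                if (v == parent[u]) continue;  // do not walk back up the tree
                parent[v] = u;
                dist[v] = dist[u] + w;
                stack.push_back(v);
            }
        }
        for (int t = s + 1; t < n; ++t)  // count each unordered pair once
            if (dist[t] <= k) ++count;
    }
    return count;
}
```

On the sample input (5 nodes, K = 4) this counts the 8 pairs given in the expected output.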

First of all, we need a few basic concepts and properties:
Centroid (center of gravity) of a tree: deleting a node splits the tree into several connected components; the node whose deletion makes the largest remaining component as small as possible is the centroid of the tree.
Property: after deleting the centroid, every remaining connected component has at most n / 2 nodes, where n is the number of nodes in the tree.
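A minimal sketch of finding the centroid directly from the definition, assuming the tree is given as a plain adjacency list `adj` over nodes 0..n−1. The helper `find_centroid` is hypothetical and separate from the `get_wc` routine in the code later in this post: it roots the tree at node 0, computes subtree sizes, and picks the node whose largest remaining component is smallest.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical helper: returns {centroid, size of the largest component left
// after deleting it}. adj[u] lists the neighbors of u; assumes a connected tree.
std::pair<int, int> find_centroid(const std::vector<std::vector<int>>& adj) {
    int n = static_cast<int>(adj.size());
    std::vector<int> parent(n, -1), sub(n, 1), order;
    std::vector<int> stack = {0};
    while (!stack.empty()) {  // iterative DFS from node 0, recording preorder
        int u = stack.back();
        stack.pop_back();
        order.push_back(u);
        for (int v : adj[u])
            if (v != parent[u]) {
                parent[v] = u;
                stack.push_back(v);
            }
    }
    // Children appear after their parent in preorder, so a reverse sweep
    // accumulates subtree sizes bottom-up (order[0] is the root).
    for (int i = n - 1; i > 0; --i) sub[parent[order[i]]] += sub[order[i]];

    int best = -1, best_max = n + 1;
    for (int u = 0; u < n; ++u) {
        int largest = n - sub[u];  // component that contains u's parent
        for (int v : adj[u])
            if (v != parent[u]) largest = std::max(largest, sub[v]);
        if (largest < best_max) {
            best_max = largest;
            best = u;
        }
    }
    return {best, best_max};
}
```

For the sample tree, deleting node 0 leaves the components {1}, {2, 4} and {3}, so the sketch returns node 0 with a largest component of 2 nodes, which respects the n / 2 bound.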

===========================
Proof by contradiction: suppose that after deleting the centroid c, the largest component has s nodes with s > n / 2.
Let v be the neighbor of c inside that largest component, and consider deleting v instead (with c restored).
The components left after deleting v are the pieces of the large component other than v, each with at most s − 1 nodes, plus the component containing c, which has n − s < n / 2 < s nodes.
So deleting v leaves a largest component with fewer than s nodes, which contradicts the choice of c as the node minimizing the largest component. Hence the largest component after deleting the centroid has at most n / 2 nodes.

============================
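This halving property is what keeps the recursion cheap. A rough sketch of the usual argument, assuming (as in the code in this post) that each component is sorted once per recursion level and that the other per-component work is linear:

```latex
\begin{aligned}
|C| &\le \frac{n}{2^{d}} \quad \text{for any component } C \text{ at recursion depth } d
     \;\Rightarrow\; d \le \log_2 n,\\
\sum_{C \text{ at depth } d} O\bigl(|C| \log |C|\bigr) &\le O(n \log n)
     \quad \text{(components at one depth are disjoint)},\\
T(n) &= O(n \log n) \cdot O(\log n) = O(n \log^2 n).
\end{aligned}
```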

To count the paths of length ≤ K, take the centroid as the root of the current tree. Every path then falls into one of three types:
The first type: both endpoints lie in the same subtree of the centroid. Such a path does not pass through the centroid, so it is counted by recursing into that subtree.
The second type: the two endpoints lie in different subtrees, so the path passes through the centroid.
First compute the distance from every node to the centroid, then count the pairs of nodes whose distances sum to at most K. This also counts invalid pairs whose two nodes lie in the same subtree (their real path does not go through the centroid). Counting only pairs drawn from different subtrees directly is awkward, so use inclusion-exclusion instead: count all pairs, then, for each subtree, subtract the pairs with both nodes in that subtree whose distances sum to at most K.
The third type: one endpoint is the centroid itself. This is easy to handle: traverse each subtree from the centroid and count the nodes at distance at most K.
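To make the inclusion-exclusion step concrete, here is a small sketch with hypothetical helper names (`count_pairs_naive`, `count_through_root`) that do not appear in the original code. It assumes the distances from the centroid to the nodes of each child subtree have already been collected into one vector per subtree, and it uses a quadratic pair count for readability rather than the sorted count described next.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Quadratic count of pairs i < j with d[i] + d[j] <= k (readability over speed).
long long count_pairs_naive(const std::vector<int>& d, int k) {
    long long res = 0;
    for (std::size_t i = 0; i < d.size(); ++i)
        for (std::size_t j = i + 1; j < d.size(); ++j)
            if (d[i] + d[j] <= k) ++res;
    return res;
}

// Pairs of length <= k that pass through the root (second and third types),
// given root-to-node distances grouped by child subtree.
long long count_through_root(const std::vector<std::vector<int>>& subtrees, int k) {
    std::vector<int> all;
    long long res = 0;
    for (const auto& sub : subtrees) {
        res -= count_pairs_naive(sub, k);  // same-subtree pairs do not pass through the root
        for (int d : sub) {
            if (d <= k) ++res;             // third type: the root is one endpoint
            all.push_back(d);
        }
    }
    res += count_pairs_naive(all, k);      // all pairs, including the ones subtracted above
    return res;
}
```

For the sample tree rooted at node 0, the child subtrees give distances {3}, {1, 2} and {2}; the sketch counts 7 pairs through node 0, and the remaining pair (2, 4) is found by recursing into the subtree {2, 4}, for the expected total of 8.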

The last problem to solve: given a list of numbers, count the pairs whose sum is at most K.
Sort the numbers, enumerate one element of each pair, and binary-search for the range of valid partners. This reduces the cost from O(n^2) to O(n log n).
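A minimal sketch of the sort-plus-binary-search count, using `std::upper_bound`; the function name `count_pairs_binary` is hypothetical. For each a[i] it counts how many later elements are at most K − a[i].

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical helper: counts pairs i < j with a[i] + a[j] <= k.
// Takes the vector by value so the caller's data is not reordered.
long long count_pairs_binary(std::vector<int> a, int k) {
    std::sort(a.begin(), a.end());
    long long res = 0;
    int n = static_cast<int>(a.size());
    for (int i = 0; i < n; ++i) {
        auto first = a.begin() + i + 1;  // partners must come after i
        // First element > k - a[i]; everything in [first, last) is a valid partner.
        auto last = std::upper_bound(first, a.end(), k - a[i]);
        res += last - first;
    }
    return res;
}
```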
Alternatively, after sorting you can use the two-pointer technique; the scan itself is then linear.
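A sketch of the two-pointer variant with a hypothetical name, `count_pairs_two_pointer`. It is written with pointers moving inward from both ends, which is a common formulation; the `get` function in the code below scans in a slightly different order but relies on the same monotonicity.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical helper: counts pairs i < j with a[i] + a[j] <= k using two pointers.
long long count_pairs_two_pointer(std::vector<int> a, int k) {
    std::sort(a.begin(), a.end());
    long long res = 0;
    int l = 0, r = static_cast<int>(a.size()) - 1;
    while (l < r) {
        if (a[l] + a[r] <= k) {
            res += r - l;  // a[l] pairs with every index in (l, r]
            ++l;
        } else {
            --r;           // a[r] is too large even for the smallest remaining a[l]
        }
    }
    return res;
}
```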

The full code follows; it is worth reading through carefully.

#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
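const int N = 10010, M = 2 * N; // M: each undirected edge is stored twice
int n, m; // m holds K from the input
int h[N], e[M], w[M], ne[M], idx; // adjacency list stored as edge arrays
bool st[N]; // st[u] == true once u has been used as a centroid (deleted)

int p[N]; // distances from the current centroid to all nodes of its component
int q[N]; // distances from the current centroid to the nodes of one child subtree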

void add(int a, int b, int c) {
	e[idx] = b, ne[idx] = h[a], w[idx] = c, h[a] = idx++;
}


int get_size(int u, int fa) { // size of u's subtree, stopping at deleted nodes
	if (st[u])
		return 0;
	int res = 1;
	for (int i = h[u]; ~i; i = ne[i]) {
		if (e[i] != fa) {
			res += get_size(e[i], u);
		}
	}
	return res;
}

// tot: size of the current component; stores a centroid in wc and returns the size of u's subtree
int get_wc(int u, int fa, int tot, int &wc) { // find the centroid
	// Any node whose deletion leaves every component with at most tot / 2 nodes keeps the
	// time complexity bound, so there is no need to find the exact centroid
	if (st[u])
		return 0;
	int sum = 1; // size of the subtree rooted at u
	int ms = 0; // size of the largest component after deleting u
	for (int i = h[u]; ~i; i = ne[i]) {
		int j = e[i];
		if (j == fa)
			continue;
		int t = get_wc(j, u, tot, wc);
		ms = max(ms, t);
		sum += t;
	}
	ms = max(ms, tot - sum); //Also consider the size of the connected block where its parent node is located
	if (ms <= tot / 2)
		wc = u;
	return sum;
}

void get_dist(int u, int fa, int dist, int &qt) { // append the distance of every node reachable from u to q
	if (st[u])
		return ;
	q[qt++] = dist;
	for (int i = h[u]; ~i; i = ne[i]) {
		if (e[i] != fa) {
			get_dist(e[i], u, dist + w[i], qt);
		}
	}
}

int get(int a[], int k) { // count pairs in a[0..k-1] whose sum is <= m (that is, K)
	sort(a, a + k);
	int res = 0;
	for (int i = k - 1, j = -1; i >= 0; i--) { // two pointers: i moves down, j moves up
		while (j + 1 < i && a[j + 1] + a[i] <= m)
			j++;
		j = min(j, i - 1); // partners of a[i] must have an index smaller than i
		res += j + 1; // a[0..j] can each be paired with a[i]
	}
	return res;
}

int calc(int u) { // count the valid pairs in the component containing u
	if (st[u])
		return 0;
	int res = 0; // number of valid pairs found in this component
	get_wc(u, -1, get_size(u, -1), u); // find the centroid; u is overwritten with it
	st[u] = true; // delete the centroid

	// Merging part
	int pt = 0; // number of entries in p
	for (int i = h[u]; ~i; i = ne[i]) {
		int j = e[i], qt = 0; // qt: number of entries in q for this subtree
		// Distances from the centroid to every node in j's subtree; w[i] is the weight of
		// the first edge (centroid -> j). fa = -1 is safe because the centroid is already
		// marked deleted, so the walk cannot go back through it.
		get_dist(j, -1, w[i], qt);
		res -= get(q, qt); // subtract pairs with both ends in this subtree (invalid)
		for (int k = 0; k < qt; k++) { // merge this subtree's distances into p
			if (q[k] <= m)
				res++; // third type: the centroid is one endpoint
			p[pt++] = q[k];
		}
	}
	res += get(p, pt); // all pairs with sum <= K; same-subtree pairs were subtracted above
	for (int i = h[u]; ~i; i = ne[i])
		res += calc(e[i]); // first type: recurse into each remaining component
	return res;
}

int main() {
	while (cin >> n >> m) {
		if (n == 0 && m == 0)
			break;
		memset(h, -1, sizeof h);
		memset(st, false, sizeof st);
		idx = 0;
		for (int i = 0; i < n - 1; i++) {
			int a, b, c;
			cin >> a >> b >> c;
			add(a, b, c), add(b, a, c);
		}
		cout << calc(0) << endl;
	}
	return 0;
}
