Maximum spanning tree count

Posted by jmcall10 on Tue, 18 Jan 2022 01:58:26 +0100

Maximum spanning tree count

Recently I solved a problem about counting maximum spanning trees. Here is the method for counting the maximum spanning trees of a graph.

Problem statement

You are given an undirected graph with \(n\) vertices and \(m\) edges. Each edge has a weight. Count the maximum spanning trees of this graph, i.e. the spanning trees whose total weight is as large as possible.

Output the answer modulo \(998244353\).

Solution

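The key observation: for every fixed weight \(w\), all maximum spanning trees of a graph contain the same number of edges of weight \(w\). Moreover, after taking all tree edges of weight at least \(w\), every maximum spanning tree induces the same partition of the vertices into connected components, so the choices made for different weights are independent.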

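Therefore we process the weights one at a time, from largest to smallest, count the choices for each weight separately, and multiply the counts.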

Suppose all weights larger than \(w\) have already been processed, and we now want to count the ways to choose the edges of weight \(w\).

The edges of weight greater than \(w\) have already merged the vertices into connected components. Contract each such component into a single vertex. Among the edges of weight \(w\), keep only those joining two different components. The number of ways to choose the weight-\(w\) edges equals the number of spanning trees of this contracted multigraph, which the matrix-tree theorem gives.

Note that the contracted graph for weight \(w\) need not be connected. We therefore apply the matrix-tree theorem separately to each connected component formed by the edges of weight at least \(w\), and multiply the results.

Here is an example problem that reduces to this technique.

[2022 Provincial Selection Ten-Round Mock Contest, Day 3] treecnt — Zhengrui Online Judge (zhengruioi.com)

In this problem, between every pair of vertices \(i, j\) we add an edge of weight \(w(i,j)\), where \(w(i,j)\) is the number of restrictions that contain both \(i\) and \(j\). A spanning tree satisfies all requirements exactly when its weight equals \(\sum S_i - K\), where \(S_i\) is the number of vertices in the \(i\)-th restriction. Such a tree must also be a maximum-weight spanning tree, so the task reduces to counting maximum spanning trees (the answer is \(0\) if the maximum weight is not \(\sum S_i - K\)).

#include <bits/stdc++.h>
constexpr int MAXN = 505;       // array capacity; vertex indices run 1..MAXN-1
constexpr int MOD = 998244353;  // prime modulus for every count and determinant
using std::cin;
using std::cout;
using std::bitset;
using std::sort;
using std::vector;
using std::swap;
// One candidate edge of the complete graph on vertices 1..N.
//   x, y : endpoints (main stores them with x < y)
//   w    : the value read from input for this pair; calc adds it to the
//          Laplacian as the edge weight (presumably a multiplicity — TODO confirm)
//   key  : number of restrictions containing both x and y; edges are ranked
//          by this value when building maximum spanning trees
struct Edge {
	int x;
	int y;
	int w;
	int key;
};
Edge e[MAXN * MAXN];  // edge list, used 1-based: e[1..ecnt]
int N, K, ecnt, f1[MAXN], f2[MAXN], tot, id[MAXN], A[MAXN][MAXN];
char s[MAXN];
bitset<2010> bs[MAXN];
vector<Edge> v[MAXN * MAXN];
// Returns the representative of x in the f1 union-find and re-points every
// node on the path from x directly at it (full path compression, iterative).
int find1(int x) {
	int root = x;
	while (f1[root] != root) {
		root = f1[root];
	}
	while (x != root) {
		int parent = f1[x];
		f1[x] = root;
		x = parent;
	}
	return root;
}
// Representative lookup for the f2 union-find. Every node visited on the
// way up ends up pointing straight at the root (iterative path compression).
int find2(int x) {
	int root = x;
	for (; f2[root] != root; root = f2[root]) {
	}
	for (int cur = x; cur != root;) {
		int up = f2[cur];
		f2[cur] = root;
		cur = up;
	}
	return root;
}
// Brings x back into [0, MOD) with at most one correction; exact for any
// x in [-MOD, 2 * MOD).
auto Mod = [] (int x) {
	return x < 0 ? x + MOD : (x >= MOD ? x - MOD : x);
};
// Binary exponentiation: returns x^y mod MOD for y >= 0. det calls it with
// y = MOD - 2 to get modular inverses (Fermat; MOD is prime).
auto Ksm = [] (int x, int y) -> int {
	long long result = 1;
	long long base = x;
	while (y != 0) {
		if (y & 1) {
			result = result * base % MOD;
		}
		base = base * base % MOD;
		y >>= 1;
	}
	return static_cast<int>(result);
};
// Determinant of the leading m x m submatrix A[1..m][1..m] modulo MOD, by
// Gaussian elimination with modular inverses. A is reduced in place.
// Returns 1 when m <= 0 (empty product) and 0 as soon as some column has no
// nonzero pivot, without doing the remaining (useless) elimination work.
int det(int m) {
	int ret = 1;
	for (int i = 1; i <= m; ++i) {
		// Locate a row at or below i with a nonzero entry in column i.
		int pivot = i;
		while (pivot <= m && !A[pivot][i]) {
			++pivot;
		}
		if (pivot > m) {
			return 0;  // singular: column i is zero from row i downward
		}
		if (pivot != i) {
			// Columns < i of both rows are already eliminated, so start at i.
			for (int k = i; k <= m; ++k) {
				std::swap(A[i][k], A[pivot][k]);
			}
			ret = Mod(-ret);  // a row swap flips the sign of the determinant
		}
		ret = (long long) ret * A[i][i] % MOD;
		int invl = Ksm(A[i][i], MOD - 2);
		for (int j = i + 1; j <= m; ++j) {
			if (A[j][i]) {
				int mul = (long long) invl * A[j][i] % MOD;
				for (int k = i; k <= m; ++k) {
					A[j][k] = Mod(A[j][k] - (long long) mul * A[i][k] % MOD);
				}
			}
		}
	}
	return ret;
}
int calc(int x) {
	int cnt = 0;
	for (auto &i: v[x]) {
		if (!id[find2(i.x)]) {
			id[f2[i.x]] = ++cnt;
		}
		if (!id[find2(i.y)]) {
			id[f2[i.y]] = ++cnt;
		}
		int x = id[f2[i.x]], y = id[f2[i.y]];
		A[x][x] = Mod(A[x][x] + i.w);
		A[y][y] = Mod(A[y][y] + i.w);
		A[x][y] = Mod(A[x][y] - i.w + MOD);
		A[y][x] = Mod(A[y][x] - i.w + MOD);
	}
	for (auto &i: v[x]) {
		id[f2[i.x]] = 0;
		id[f2[i.y]] = 0;
	}
	for (auto &i: v[x]) {
		if (find2(i.x) != find2(i.y)) {
			f2[f2[i.x]] = f2[i.y];
		}
	}
	int ret = det(cnt - 1);
	for (int i = 1; i <= cnt; ++i) {
		for (int j = 1; j <= cnt; ++j) {
			A[i][j] = 0;
		}
	}
	return ret;
}
int main() {
	std::ios::sync_with_stdio(0);
	cin.tie(0);
	cout.tie(0);
	cin >> N >> K;
	for (int i = 1, x; i <= N - 1; ++i) {
		for (int j = i + 1; j <= N; ++j) {
			cin >> x;
			e[++ecnt] = {i, j, x, 0};
		}
	}
	int S = 0;
	for (int i = 1; i <= K; ++i) {
		cin >> s + 1;
		for (int j = 1; j <= N; ++j) {
			bs[j].set(i, s[j] == '1');
			S += s[j] == '1';
		}
	}
	for (int i = 1; i <= ecnt; ++i) {
		e[i].key = (bs[e[i].x] & bs[e[i].y]).count();
	}
	for (int i = 1; i <= N; ++i) {
		f1[i] = i;
	}
	sort(e + 1, e + 1 + ecnt, [&] (const Edge &a, const Edge &b) -> int {
		return a.key > b.key;
	});
	int singercoder = 0;
	for (int i = 1; i <= ecnt; ++i) {
		if (find1(e[i].x) != find1(e[i].y)) {
			f1[f1[e[i].x]] = f1[e[i].y];
			singercoder += e[i].key;
		}
	}
	if (singercoder != S - K) {
		cout << 0 << '\n';
		return 0;
	}
	for (int i = 1; i <= N; ++i) {
		f1[i] = f2[i] = i;
	}
	int ANS = 1;
	for (int i = 1, p; i <= ecnt; i = p + 1) {
		p = i;
		while (p < ecnt && e[p + 1].key == e[i].key) {
			++p;
		}
		for (int j = i; j <= p; ++j) {
			if (find1(e[j].x) != find1(e[j].y)) {
				f1[f1[e[j].x]] = f1[e[j].y];
			}
		}
		for (int j = i; j <= p; ++j) {
			if (find2(e[j].x) == find2(e[j].y)) {
				continue;
			}
			if (!id[find1(e[j].x)]) {
				id[find1(e[j].x)] = ++tot;
			}
			v[id[f1[e[j].x]]].push_back(e[j]);
		}
		for (int j = i; j <= p; ++j) {
			id[f1[e[j].x]] = 0;
		}
		for (int j = 1; j <= tot; ++j) {
			ANS = (long long) ANS * calc(j) % MOD;
			v[j].clear();
		}
	}
	cout << ANS << '\n';
	return 0;
}