Counting maximum spanning trees
Recently I solved a problem about counting maximum spanning trees. This note describes a general method for counting the maximum spanning trees of a graph.
Problem statement
Given an undirected graph with \(n\) vertices and \(m\) edges, where each edge has a weight, count the maximum spanning trees of the graph.
Output the answer modulo \(998244353\).
Solution
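The key observation: for any fixed weight \(w\), every maximum spanning tree of a given graph contains the same number of edges of weight \(w\). Moreover, in every maximum spanning tree, the tree edges of weight at least \(w\) connect exactly the same sets of vertices as all graph edges of weight at least \(w\) do.
Therefore, each weight can be counted separately, and the answer is the product of the counts over all weights.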
Suppose all weights greater than \(w\) have already been processed, and we now count the ways to choose the edges of weight \(w\).
The edges of weight greater than \(w\) form some connected components; contract each of them into a single vertex. Among the edges of weight \(w\), only those connecting two different components matter. We count the spanning trees they form on the contracted graph using the matrix-tree theorem.
Note that, when processing weight \(w\), the contracted graph is generally not connected. Therefore, we apply the matrix-tree theorem separately to each connected component formed by the edges of weight at least \(w\), and multiply the results.
Here is an example problem.
In this problem, we connect each pair of vertices \(i, j\) with an edge of weight \(w(i,j)\), where \(w(i,j)\) is the number of restrictions that contain both \(i\) and \(j\). A spanning tree that meets the requirements is a spanning tree of weight \(\sum S_i - K\) (here \(S_i\) is the number of vertices in the \(i\)-th restriction), and such a tree must have the largest possible weight. So we only need to count maximum spanning trees; if the maximum weight differs from \(\sum S_i - K\), the answer is \(0\).
#include <algorithm>
#include <bitset>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Counts maximum spanning trees of a complete graph on N vertices, modulo MOD.
//
// Input:  N K, then for every pair i < j (row-major) an integer w(i,j),
//         then K strings of N characters ('1' = vertex belongs to the restriction).
// An edge's "key" is the number of restrictions containing both endpoints; the
// tree weight being maximized is the sum of keys. w(i,j) is used as the edge's
// entry in the Laplacian, i.e. its multiplicity for the matrix-tree theorem.
// Output: 0 if the maximum key-sum differs from S - K (S = total count of '1'),
//         otherwise the (w-weighted) number of maximum spanning trees mod MOD.
//
// Method: process keys in decreasing order. f2 tracks connectivity of edges
// with key strictly greater than the current level, f1 of edges with key >= it.
// For each f1 component touched by the current level, contract f2 components
// to points and multiply the answer by the matrix-tree determinant.

const int MAXN = 505, MOD = 998244353;

using std::bitset;
using std::cin;
using std::cout;
using std::sort;
using std::swap;
using std::vector;

struct Edge {
    int x, y;  // endpoints, 1-based
    int w;     // Laplacian weight, normalized to [0, MOD)
    int key;   // number of restrictions containing both endpoints
} e[MAXN * MAXN];

int N, K, ecnt;
int f1[MAXN];           // DSU: edges with key >= current level
int f2[MAXN];           // DSU: edges with key >  current level
int tot;                // number of groups at the current level
int id[MAXN];           // scratch map: DSU root -> compact index (0 = unassigned)
int A[MAXN][MAXN];      // 1-based Laplacian scratch matrix, kept all-zero between uses
bitset<2010> bs[MAXN];  // bs[j][i] set iff restriction i contains vertex j
vector<Edge> v[MAXN];   // current-level edges grouped by f1 component (<= N groups)

int find1(int x) { return f1[x] == x ? x : f1[x] = find1(f1[x]); }
int find2(int x) { return f2[x] == x ? x : f2[x] = find2(f2[x]); }

// Maps a value in (-MOD, 2*MOD) into [0, MOD).
int Mod(int x) {
    if (x >= MOD) return x - MOD;
    if (x < 0) return x + MOD;
    return x;
}

// Modular exponentiation: x^y mod MOD.
int Ksm(int x, int y) {
    int ret = 1;
    for (; y; y >>= 1, x = (long long) x * x % MOD)
        if (y & 1) ret = (long long) ret * x % MOD;
    return ret;
}

// Determinant of the leading m x m submatrix of A (1-based) mod MOD, by
// Gaussian elimination. Destroys A's contents. Returns 1 for m <= 0.
int det(int m) {
    int ret = 1;
    for (int i = 1; i <= m; ++i) {
        int piv = i;
        while (piv <= m && !A[piv][i]) ++piv;
        if (piv > m) return 0;  // no usable pivot: the matrix is singular
        if (piv != i) {
            // Columns < i are already zero in rows >= i, so a partial swap suffices.
            for (int k = i; k <= m; ++k) swap(A[i][k], A[piv][k]);
            ret = Mod(-ret);
        }
        ret = (long long) ret * A[i][i] % MOD;
        const int inv = Ksm(A[i][i], MOD - 2);
        for (int j = i + 1; j <= m; ++j) {
            if (!A[j][i]) continue;
            const int mul = (long long) inv * A[j][i] % MOD;
            for (int k = i; k <= m; ++k)
                A[j][k] = Mod(A[j][k] - (long long) mul * A[i][k] % MOD);
        }
    }
    return ret;
}

// Counts spanning trees of group `x`: its edges join distinct f2 components,
// which are contracted to points. Afterwards merges those components in f2 and
// restores A and id to all-zero.
int calc(int x) {
    int cnt = 0;
    for (const Edge &ed : v[x]) {
        if (!id[find2(ed.x)]) id[f2[ed.x]] = ++cnt;
        if (!id[find2(ed.y)]) id[f2[ed.y]] = ++cnt;
        const int a = id[f2[ed.x]], b = id[f2[ed.y]];
        A[a][a] = Mod(A[a][a] + ed.w);
        A[b][b] = Mod(A[b][b] + ed.w);
        A[a][b] = Mod(A[a][b] - ed.w);
        A[b][a] = Mod(A[b][a] - ed.w);
    }
    // Reset ids before merging: f2 roots change once the components are joined.
    for (const Edge &ed : v[x]) id[f2[ed.x]] = id[f2[ed.y]] = 0;
    for (const Edge &ed : v[x])
        if (find2(ed.x) != find2(ed.y)) f2[f2[ed.x]] = f2[ed.y];
    const int ret = det(cnt - 1);  // matrix-tree theorem: any cofactor of the Laplacian
    for (int i = 1; i <= cnt; ++i)
        for (int j = 1; j <= cnt; ++j) A[i][j] = 0;
    return ret;
}

int main() {
    std::ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> N >> K;
    for (int i = 1; i < N; ++i) {
        for (int j = i + 1; j <= N; ++j) {
            long long w;
            cin >> w;
            // Normalize so Laplacian updates stay within int range.
            e[++ecnt] = {i, j, static_cast<int>((w % MOD + MOD) % MOD), 0};
        }
    }

    int S = 0;
    std::string s;
    for (int i = 1; i <= K; ++i) {
        cin >> s;
        for (int j = 1; j <= N; ++j) {
            const bool in = j <= static_cast<int>(s.size()) && s[j - 1] == '1';
            bs[j].set(i, in);
            S += in;
        }
    }

    for (int i = 1; i <= ecnt; ++i)
        e[i].key = static_cast<int>((bs[e[i].x] & bs[e[i].y]).count());
    sort(e + 1, e + 1 + ecnt, [](const Edge &a, const Edge &b) { return a.key > b.key; });

    // Kruskal: the maximum spanning tree weight must equal S - K.
    for (int i = 1; i <= N; ++i) f1[i] = i;
    int best = 0;
    for (int i = 1; i <= ecnt; ++i) {
        if (find1(e[i].x) != find1(e[i].y)) {
            f1[f1[e[i].x]] = f1[e[i].y];
            best += e[i].key;
        }
    }
    if (best != S - K) {
        cout << 0 << '\n';
        return 0;
    }

    for (int i = 1; i <= N; ++i) f1[i] = f2[i] = i;
    int ANS = 1;
    for (int i = 1, p; i <= ecnt; i = p + 1) {
        // [i, p] is the run of edges sharing the current key.
        p = i;
        while (p < ecnt && e[p + 1].key == e[i].key) ++p;

        for (int j = i; j <= p; ++j)
            if (find1(e[j].x) != find1(e[j].y)) f1[f1[e[j].x]] = f1[e[j].y];

        // Group edges that join distinct f2 components by their f1 component.
        tot = 0;  // group ids are per level
        for (int j = i; j <= p; ++j) {
            if (find2(e[j].x) == find2(e[j].y)) continue;
            int &g = id[find1(e[j].x)];
            if (!g) g = ++tot;
            v[g].push_back(e[j]);
        }
        for (int j = i; j <= p; ++j) id[find1(e[j].x)] = 0;

        for (int j = 1; j <= tot; ++j) {
            ANS = (long long) ANS * calc(j) % MOD;
            v[j].clear();
        }
    }
    cout << ANS << '\n';
    return 0;
}