LOJ #6733. Artificial emotion
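First consider how to compute \(W(S)\). Let \(f_u\) be the \(W\) value of the set of paths lying entirely inside the subtree of \(u\). Then the transition is
\[f_u=\max\left(\sum\limits_{v\in\operatorname{ch}_u}f_v,\ \max_{(x,y,w)}\Big(w+\sum\limits_{v\notin\operatorname{path}(x,y),\ \operatorname{fa}_v\in\operatorname{path}(x,y)}f_v\Big)\right)\]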
where \((x,y,w)\) ranges over the paths of the given set \(U\) with \(\operatorname{lca}(x,y)=u\). Evaluating this directly is far too slow, so let us examine the structure of the transition.
This transition sums \(f\) over the "nodes hanging below the path", i.e. nodes not on the path whose parent is on it. That suggests taking differences on the tree. Let \(f'_u=f_u-\sum\limits_{v\in\operatorname{ch}_u}f_v\). Substituting \(f'\) for \(f\) in the formula above, the transition becomes
\[f'_u=\max\left(0,\ \max_{(x,y,w)}\Big(w-\sum\limits_{v\in\operatorname{path}(x,y),\ v\neq u}f'_v\Big)\right)\]
where, again, \((x,y,w)\) ranges over the paths of \(U\) with \(\operatorname{lca}(x,y)=u\).
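So we only need two operations: add \(f'_u\) at a node, and query the sum of \(f'\) along a vertical chain. A Fenwick tree (binary indexed tree) over the \(\operatorname{dfn}\) order handles both in \(\mathcal O(\log n)\): add \(f'_u\) to the whole \(\operatorname{dfn}\) interval of the subtree of \(u\), then query a single position. The total time is \(\mathcal O((n+m)\log n)\).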
Next, consider how to compute \(f(x,y)\). Note that \(f(x,y)=f_\mathrm{root}-h_{\operatorname{lca}(x,y)}-\sum\limits_{u\notin\operatorname{path}(x,y),\ \operatorname{fa}_u\in\operatorname{path}(x,y)}f_{u}\), where \(h_u\) is the \(W\) value of the set of paths lying entirely outside the subtree of \(u\). The sum over the "nodes hanging below the path" can again be eliminated by the difference on the tree, giving \(f(x,y)=f_{\mathrm{root}}+\sum\limits_{u\in\operatorname{path}(x,y)}f'_u-f_{\operatorname{lca}(x,y)}-h_{\operatorname{lca}(x,y)}\). The values of \(f\) and \(f'\) were already obtained above, so the key is how to compute \(h_u\).
Similarly to \(f\), define \(h'_u=h_u-h_{\operatorname{fa}_u}-\sum\limits_{v\in\operatorname{ch}_{\operatorname{fa}_u},\ v\neq u}f_v\). Then \(h'\) satisfies a transition similar to that of \(f'\).

Consider all paths \((x,y,w)\) that pass through \(\operatorname{fa}_u\) but not through \(u\), and let \(z=\operatorname{lca}(x,y)\). For the current \(u\), define
\[g'_v=\begin{cases}f'_v, & v\text{ is not an ancestor of }u\\ h'_v, & v\text{ is an ancestor of }u\end{cases}\]
It is easy to see that the contribution of path \((x,y,w)\) to \(h'_u\) is \(w-\sum\limits_{v\in\operatorname{path}(x,y),\ v\neq z}g'_v\). \(h'_u\) is the maximum of \(0\) and all such contributions.

This gives a simple method: enumerate the node \(u\), enumerate the paths of \(U\) that pass through \(\operatorname{fa}_u\) but not through \(u\), and compute their contributions while maintaining \(g'\). This method takes \(\mathcal O(n^2)\) time.
Now optimize. A path \((x,y,w)\in U\) contributes only to the nodes hanging below it. Moreover, its contribution is the same for all children hanging below the same path node.

Let \(u\) be a node on the path \((x,y)\), let \(v\) be a child of \(u\) that is not on the path, and let \(z=\operatorname{lca}(x,y)\). Then the contribution of the path to \(h'_v\) is
\[\begin{cases}w-(F_x-F_z)-(F_y-F_z), & u=z\\ w-(F_x-F_u)-(H_u-H_z)-(F_y-F_z), & u\in\operatorname{path}(x,z),\ u\neq z\\ w-(F_y-F_u)-(H_u-H_z)-(F_x-F_z), & u\in\operatorname{path}(y,z),\ u\neq z\end{cases}\]
Here \(F\) and \(H\) are the root-path prefix sums of \(f'\) and \(h'\), i.e. \(F_x=\sum\limits_{y\in\operatorname{path}(\mathrm{root},x)}f'_y\) and \(H_x=\sum\limits_{y\in\operatorname{path}(\mathrm{root},x)}h'_y\). The second and third cases above are symmetric and can be handled together, so only the cases \(u=z\) and \(u\neq z\) need separate treatment.
- \(u=z\): the path \((x,y,w)\) contributes to a child \(v\) of \(z\) if and only if \(v\) is not on the path \((x,y)\). If the children of \(z\) are viewed as a sequence, the children receiving the contribution form 1, 2 or 3 intervals. So the sequence of children needs range "chmax" updates and single-point queries. A segment tree over the children of \(z\) suffices.
- \(u\neq z\): assume \(u\in\operatorname{path}(x,z)\). The path \((x,y,w)\) contributes to a child \(v\) of \(u\) if and only if \(v\) is not on the path \((x,y)\). So when we enumerate \(u\) and compute \(h'_v\), every path contributing to \(h'_v\) has an endpoint inside the subtree of \(u\) but outside the subtree of \(v\).

  We store the part of the contribution that does not depend on \(u\), namely \(w-F_x-F_y+F_z+H_z\), at both endpoints \(x\) and \(y\). A query then adds \(F_u-H_u\). In \(\operatorname{dfn}\) order, "subtree of \(u\) minus subtree of \(v\)" is at most two intervals. Therefore a segment tree over the \(\operatorname{dfn}\) order (point chmax, range max) suffices.
The total time complexity is \(\mathcal O((n+m)\log n)\).
// Reference code for LOJ #6733 (Artificial emotion).
//
// Array naming used below (see the editorial above):
//   f[u]  = f'_u                        f1[u] = subtree sum of f'  (= f_u)
//   F[u]  = root-path prefix sum of f'  sf[u] = subtree sum of sz*f' (mod p)
//   h[u]  = h'_u                        h1[u] = h_u, rebuilt from h'
//   H[u]  = root-path prefix sum of h'
#include <bits/stdc++.h>
using namespace std;

// In-place min/max: assign min(x, y) / max(x, y) to x and return x.
template<typename T> T &min_eq(T &x, const T &y) { return x = min(x, y); }
template<typename T> T &max_eq(T &x, const T &y) { return x = max(x, y); }

static constexpr int mod = 998244353;
static constexpr int Maxn = 3e5 + 5;
// Used negated as the identity for range-max queries that hit no position.
static constexpr int64_t inf = 0x3f3f3f3f3f3f3f3f;

int n, m;
int64_t ans;
vector<int> g[Maxn];  // adjacency lists of the tree, nodes 1..n

// Heavy-light decomposition: LCA, k-th ancestor and dfn/subtree intervals.
namespace hld {
int par[Maxn], sz[Maxn], son[Maxn], dep[Maxn];
int top[Maxn], dfn[Maxn], idfn[Maxn], ed[Maxn], dn;

// Computes parent, depth (root has depth 1), subtree size and heavy child.
void predfs1(int u, int fa, int depth) {
  par[u] = fa, dep[u] = depth;
  sz[u] = 1, son[u] = 0;
  for (const int &v : g[u]) {
    if (v == par[u]) continue;
    predfs1(v, u, depth + 1);
    sz[u] += sz[v];
    if (son[u] == 0 || sz[v] > sz[son[u]]) son[u] = v;
  }
} // hld::predfs1

// Assigns chain tops and dfn order; the heavy child is visited first so
// every heavy chain occupies a contiguous dfn range.
void predfs2(int u, int topv) {
  top[u] = topv, idfn[dfn[u] = ++dn] = u;
  if (son[u] != 0) predfs2(son[u], topv);
  for (const int &v : g[u])
    if (v != par[u] && v != son[u]) predfs2(v, v);
} // hld::predfs2

// Lowest common ancestor of u and v.
inline int get_lca(int u, int v) {
  for (; top[u] != top[v]; v = par[top[v]])
    if (dep[top[u]] > dep[top[v]]) swap(u, v);
  return dep[u] < dep[v] ? u : v;
} // hld::get_lca

// k-th ancestor of u (k == 0 gives u itself); returns 0 when k < 0 or the
// ancestor would lie above the root.
inline int get_anc(int u, int k) {
  if (k < 0 || k >= dep[u]) return 0;
  for (; dep[u] - dep[par[top[u]]] <= k; u = par[top[u]])
    k -= (dep[u] - dep[par[top[u]]]);
  return idfn[dfn[u] - k];
} // hld::get_anc

// Builds the decomposition rooted at `root`; the subtree of i is the dfn
// interval [dfn[i], ed[i]].
inline void initialize(int root) {
  dn = 0, predfs1(root, 0, 1), predfs2(root, root);
  for (int i = 1; i <= n; ++i) ed[i] = dfn[i] + sz[i] - 1;
} // hld::initialize
} // namespace hld
using namespace hld;

// Fenwick tree over positions 1..n: point add, prefix-sum query.
namespace fen {
int64_t b[Maxn];
inline void clr(void) { memset(b, 0, sizeof(b)); }
inline void upd(int x, int64_t v) { for (; x <= n; x += x & -x) b[x] += v; }
inline int64_t ask(int x) {
  int64_t r = 0;
  for (; x; x -= x & -x) r += b[x];
  return r;
}
} // namespace fen

struct path { int x, y; int64_t w; } pa[Maxn];
vector<path> plca[Maxn];  // plca[z]: the input paths whose LCA is z
int64_t f[Maxn], F[Maxn], f1[Maxn];

// Post-order computation of f'. While u is processed, every finished node t
// has f'_t added over its subtree's dfn range, so fen::ask(dfn[x]) for x in
// subtree(u) is the sum of f' on the path from x up to (excluding) u.
void dfs1(int u, int fa) {
  for (const int &v : g[u])
    if (v != fa) dfs1(v, u);
  for (const auto &[x, y, w] : plca[u])
    max_eq(f[u], w - fen::ask(dfn[x]) - fen::ask(dfn[y]));
  fen::upd(dfn[u], f[u]), fen::upd(ed[u] + 1, -f[u]);
} // dfs1

// F = root-path prefix sums of f'; f1 = subtree sums of f' (i.e. f itself).
void dfs11(int u, int fa) {
  F[u] = f[u] + F[fa], f1[u] = f[u];
  for (const int &v : g[u])
    if (v != fa) dfs11(v, u), f1[u] += f1[v];
} // dfs11

int64_t h[Maxn], H[Maxn], h1[Maxn];

// Segment tree over dfn positions 1..n: point chmax, range max.
namespace sgt1 {
int64_t tr[Maxn * 4];
void update(int p, int l, int r, int x, int64_t v) {
  max_eq(tr[p], v);
  if (l == r) return;
  int mid = (l + r) / 2;
  if (x <= mid) update(p * 2 + 0, l, mid, x, v);
  else update(p * 2 + 1, mid + 1, r, x, v);
} // sgt1::update
int64_t query(int p, int l, int r, int L, int R) {
  if (L > r || l > R) return -inf;
  if (L <= l && r <= R) return tr[p];
  int mid = (l + r) / 2;
  return max(query(p * 2 + 0, l, mid, L, R), query(p * 2 + 1, mid + 1, r, L, R));
} // sgt1::query
inline void upd(int x, int64_t v) { return update(1, 1, n, x, v); }
inline int64_t ask(int l, int r) { return query(1, 1, n, l, r); }
} // namespace sgt1

// Segment tree over the children (labels 1..N) of one node: range chmax via
// un-pushed tags, point query takes the max of tags on the root-to-leaf path.
namespace sgt2 {
int N;
int64_t tr[Maxn * 4];
// Resets only the cells used for N leaves to a large negative (bytes 0xC1).
void build(int n) { N = n, memset(tr, -63, (n + 1) * 4 * sizeof(*tr)); }
void update(int p, int l, int r, int L, int R, int64_t v) {
  if (L > r || l > R) return;
  if (L <= l && r <= R) return max_eq(tr[p], v), void();
  int mid = (l + r) / 2;
  update(p * 2 + 0, l, mid, L, R, v);
  update(p * 2 + 1, mid + 1, r, L, R, v);
} // sgt2::update
int64_t query(int p, int l, int r, int x) {
  if (l == r) return tr[p];
  int mid = (l + r) / 2;
  int64_t t = tr[p];
  if (x <= mid) max_eq(t, query(p * 2 + 0, l, mid, x));
  else max_eq(t, query(p * 2 + 1, mid + 1, r, x));
  return t;
} // sgt2::query
inline void upd(int l, int r, int64_t v) { return update(1, 1, N, l, r, v); }
inline int64_t ask(int x) { return query(1, 1, N, x); }
} // namespace sgt2

// Pre-order computation of h'. On entry h[u] already holds h'_u (set while
// processing the parent; 0 for the root). For each child v of u this takes
// the max of:
//  * paths whose LCA is a proper ancestor of u, read from sgt1 over
//    subtree(u) \ subtree(v) and shifted by F[u] - H[u];
//  * paths whose LCA is u, pushed into sgt2 over the children of u that the
//    path does not enter.
// Afterwards the u-independent part w - F_x - F_y + F_u + H_u of every path
// with LCA u is stored in sgt1 at its endpoints strictly below u.
void dfs2(int u, int fa) {
  H[u] = H[fa] + h[u];
  static int label[Maxn];
  int N = 0;
  for (const int &v : g[u])
    if (v != fa) label[v] = ++N;
  if (N != 0) {
    for (const int &v : g[u])
      if (v != fa)
        max_eq(h[v], max(sgt1::ask(dfn[u], dfn[v] - 1), sgt1::ask(ed[v] + 1, ed[u])) + F[u] - H[u]);
    sgt2::build(N);
    for (auto [x, y, w] : plca[u]) {
      if (dep[x] > dep[y]) swap(x, y);
      // xk / yk: the child of u on the way to x / y, or 0 if the endpoint is u.
      int xk = get_anc(x, dep[x] - dep[u] - 1);
      int yk = get_anc(y, dep[y] - dep[u] - 1);
      if (xk == 0 && yk == 0) {
        // Single-node path {u}: every child hangs below it.
        sgt2::upd(1, N, w);
      } else if (xk == 0) {
        // x == u: every child except yk.
        int64_t v = w - F[y] + F[u];
        if (1 < label[yk]) sgt2::upd(1, label[yk] - 1, v);
        if (label[yk] < N) sgt2::upd(label[yk] + 1, N, v);
      } else {
        // Both endpoints strictly below u: every child except xk and yk.
        int64_t v = w - F[x] + F[u] - F[y] + F[u];
        if (label[xk] > label[yk]) swap(x, y), swap(xk, yk);
        if (1 < label[xk]) sgt2::upd(1, label[xk] - 1, v);
        if (label[yk] < N) sgt2::upd(label[yk] + 1, N, v);
        if (label[xk] + 1 <= label[yk] - 1) sgt2::upd(label[xk] + 1, label[yk] - 1, v);
      }
    }
    for (const int &v : g[u])
      if (v != fa) max_eq(h[v], sgt2::ask(label[v]));
  }
  for (auto [x, y, w] : plca[u]) {
    if (dep[x] > dep[y]) swap(x, y);
    int xk = get_anc(x, dep[x] - dep[u] - 1);
    int yk = get_anc(y, dep[y] - dep[u] - 1);
    if (xk == 0 && yk == 0) {
      // x == y == u: no endpoint strictly below u, nothing to store.
    } else if (xk == 0) {
      int64_t v = w - F[y] + H[u];
      sgt1::upd(dfn[y], v);
    } else {
      int64_t v = w - F[x] - F[y] + H[u] + F[u];
      sgt1::upd(dfn[x], v);
      sgt1::upd(dfn[y], v);
    }
  }
  for (const int &v : g[u])
    if (v != fa) dfs2(v, u);
} // dfs2

// Rebuilds h from h': h_v = h'_v + h_u + (sum of f over v's siblings), where
// the sibling sum is f1[u] - f[u] - f1[v].
void dfs3(int u, int fa) {
  for (const int &v : g[u])
    if (v != fa) h1[v] = h1[u] + f1[u] - f1[v] - f[u] + h[v], dfs3(v, u);
} // dfs3

int64_t sf[Maxn];

// Adds to ans (mod p), over all ordered pairs (x, y):
//   sum_{t in path(x,y)} f'_t - f_lca - h_lca;
// main() then adds n^2 * f_root. z is the number of ordered pairs whose LCA
// is u, and sf[v] = sum over t in subtree(v) of sz[t] * f'_t.
void dfs4(int u, int fa) {
  for (const int &v : g[u])
    if (v != fa) dfs4(v, u), (sf[u] += sf[v]) %= mod;
  int64_t z = (int64_t)sz[u] * sz[u];
  for (const int &v : g[u])
    if (v != fa) z -= (int64_t)sz[v] * sz[v];
  ((ans -= (z % mod) * ((f1[u] + h1[u]) % mod) % mod) += mod) %= mod;
  // Pairs whose path passes through a node t below v and has LCA u.
  for (const int &v : g[u])
    if (v != fa) (ans += 2 * sf[v] * (sz[u] - sz[v]) % mod) %= mod;
  (ans += (f[u] % mod) * (z % mod) % mod) %= mod;
  (sf[u] += sz[u] * f[u] % mod) %= mod;
} // dfs4

int main(void) {
  scanf("%d%d", &n, &m);
  for (int i = 2; i <= n; ++i) {
    int u, v;
    scanf("%d%d", &u, &v);
    g[u].push_back(v);
    g[v].push_back(u);
  }
  hld::initialize(1);
  for (int i = 1; i <= m; ++i) {
    // %lld requires a long long*, but int64_t may be `long` (e.g. LP64
    // Linux), so read the weight through a long long temporary.
    long long w = 0;
    scanf("%d%d%lld", &pa[i].x, &pa[i].y, &w);
    pa[i].w = w;
    int z = get_lca(pa[i].x, pa[i].y);
    plca[z].push_back(pa[i]);
  }
  fen::clr(), dfs1(1, 0), dfs11(1, 0);
  memset(sgt1::tr, -63, sizeof(sgt1::tr));
  dfs2(1, 0), dfs3(1, 0);
  ans = 0, dfs4(1, 0);
  (ans += (int64_t)n * n % mod * (f1[1] % mod) % mod) %= mod;
  // Same reason as above: %lld must be given an actual long long.
  printf("%lld\n", static_cast<long long>(ans));
  return 0;
} // main