1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46
#include <bits/stdc++.h>

// Inclusive ascending loop: i runs over a, a+1, ..., b (no iterations if a > b).
#define rep(i, a, b) for (int i = (a); i <= (b); ++i)
// Inclusive descending loop: i runs over a, a-1, ..., b (no iterations if a < b).
#define per(i, a, b) for (int i = (a); i >= (b); --i)
// Prints "<expression text> <value>" to stderr.
// The argument is parenthesized so operators that bind looser than << (e.g. &, |)
// still work, and there is no trailing ';' inside the macro so that
// `if (c) debug(x); else ...` parses correctly.
#define debug(x) cerr << #x << ' ' << (x) << endl

using namespace std;
using ll = long long;
const int mod = 1e9 + 7;   // NOTE(review): not referenced anywhere in this file.
const int MAXN = 5e5 + 7;  // Array bound: every node id read in main must be < MAXN (not checked).

// G[u] holds (neighbour, edge weight) pairs; main inserts every edge in both directions.
vector<pair<int, int> > G[MAXN];
// dis[v]: sum of edge weights from the dfs start vertex down to v (written by dfs,
//         not read anywhere else in this file).
// msd[v]: largest sum of edge weights on a downward path from v into its subtree;
//         stays 0 for leaves because globals are zero-initialised.
// maxdis: not referenced in this file.
// ans:    accumulator updated by dfs2 and printed by main.
ll dis[MAXN], msd[MAXN], maxdis, ans;

// Traverses the tree below x, where fa is x's parent (main passes 0 for the root).
// Every neighbour equal to the parent is skipped.
// For every vertex v reached below x:
//   dis[v] = dis[parent(v)] + w(parent(v), v)
// and, once all of v's children are finished, its parent p gets
//   msd[p] = max(msd[p], msd[v] + w(p, v)).
// An explicit stack replaces recursion. The recursive version needed one call frame
// per tree level, which overflows the stack on deep (path-like) trees of up to ~MAXN
// nodes.
void dfs(int x, int fa) {
    struct Frame {
        int node;     // vertex being expanded
        int parent;   // vertex we arrived from; skipped when scanning neighbours
        int up;       // weight of edge parent -> node (unused for the start vertex)
        size_t next;  // index of the next entry of G[node] to examine
    };
    vector<Frame> stk;
    stk.push_back(Frame{x, fa, 0, 0});
    while (!stk.empty()) {
        Frame &top = stk.back();
        const int u = top.node;
        if (top.next < G[u].size()) {
            const int v = G[u][top.next].first;
            const int w = G[u][top.next].second;
            ++top.next;
            if (v == top.parent) continue;
            dis[v] = dis[u] + w;
            // push_back may reallocate and invalidate `top`; it is not used again
            // in this iteration.
            stk.push_back(Frame{v, u, w, 0});
        } else {
            // Every child of u is finished, so msd[u] is final: fold it into the parent.
            const ll height = msd[u] + top.up;
            stk.pop_back();
            if (!stk.empty()) {
                const int p = stk.back().node;
                msd[p] = max(msd[p], height);
            }
        }
    }
}
// Walks the tree below x, where fa is x's parent (main passes 0 for the root).
// For every tree edge u -> v of weight w it adds msd[u] - (msd[v] + w) to ans.
// That term is how far the deepest downward path through v falls short of the deepest
// downward path from u. It is >= 0 once dfs has filled msd, because dfs sets
// msd[u] >= msd[v] + w for each child v.
// An explicit stack replaces recursion so deep (path-like) trees cannot overflow the
// call stack. The visiting order does not affect the resulting sum.
void dfs2(int x, int fa) {
    vector<pair<int, int> > todo;  // (vertex, parent) pairs still to expand
    todo.emplace_back(x, fa);
    while (!todo.empty()) {
        const int u = todo.back().first;
        const int p = todo.back().second;
        todo.pop_back();
        for (const auto &e : G[u]) {
            const int v = e.first, w = e.second;
            if (v == p) continue;
            ans += msd[u] - msd[v] - w;
            todo.emplace_back(v, u);
        }
    }
}

// Reads "n rt", then n - 1 undirected weighted edges "u v w". Runs dfs and then dfs2
// from rt, and prints the accumulated ans.
// NOTE(review): this looks like the "raise edge weights so all leaves are equidistant
// from the root, minimising the total increase" problem — confirm against the task
// statement.
// Returns 1 on malformed/truncated input. Previously, the variables read by scanf were
// left uninitialised in that case.
int main() {
    int n, rt;
    if (scanf("%d %d", &n, &rt) != 2) return 1;
    rep(i, 1, n - 1) {
        int u, v, w;
        if (scanf("%d %d %d", &u, &v, &w) != 3) return 1;
        G[u].push_back({v, w});
        G[v].push_back({u, w});
    }
    dfs(rt, 0);
    dfs2(rt, 0);
    printf("%lld\n", ans);
    return 0;
}
|