Description
You are given an undirected weighted tree with n nodes, numbered from 0 to n - 1. It is represented by a 2D integer array edges of length n - 1, where edges[i] = [ui, vi, wi] indicates that there is an edge between nodes ui and vi with weight wi.
Additionally, you are given a 2D integer array queries, where queries[j] = [src1j, src2j, destj].
Return an array answer of length equal to queries.length, where answer[j] is the minimum total weight of a subtree such that it is possible to reach destj from both src1j and src2j using edges in this subtree.
A subtree here is any connected subset of nodes and edges of the original tree forming a valid tree.
Example 1:
Input: edges = [[0,1,2],[1,2,3],[1,3,5],[1,4,4],[2,5,6]], queries = [[2,3,4],[0,2,5]]
Output: [12,11]
Explanation:
The blue edges represent one of the subtrees that yield the optimal answer.

- answer[0]: The total weight of the selected subtree that ensures a path from src1 = 2 and src2 = 3 to dest = 4 is 3 + 5 + 4 = 12.
- answer[1]: The total weight of the selected subtree that ensures a path from src1 = 0 and src2 = 2 to dest = 5 is 2 + 3 + 6 = 11.
Example 2:
Input: edges = [[1,0,8],[0,2,7]], queries = [[0,1,2]]
Output: [15]
Explanation:

answer[0]: The total weight of the selected subtree that ensures a path from src1 = 0 and src2 = 1 to dest = 2 is 8 + 7 = 15.
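
Both examples can be checked directly from the definition. The sketch below is a brute-force check, not the intended solution: it assumes that the cheapest connecting subtree is the union of the tree paths from src1 and src2 to dest, and it re-walks the tree for every query, so it is only suitable for small inputs. The name brute_force_min_weight and its helper are hypothetical and do not appear in the original write-up.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def brute_force_min_weight(edges: List[List[int]], queries: List[List[int]]) -> List[int]:
    """Hypothetical helper: answer each query as the union of two tree paths."""
    graph: Dict[int, List[Tuple[int, int]]] = defaultdict(list)
    for u, v, w in edges:
        graph[u].append((v, w))
        graph[v].append((u, w))

    def parents_from(root: int) -> Dict[int, Tuple[int, int]]:
        # Iterative DFS from root; maps each non-root node to (parent, edge weight).
        parent: Dict[int, Tuple[int, int]] = {}
        seen = {root}
        stack = [root]
        while stack:
            node = stack.pop()
            for adj, w in graph[node]:
                if adj not in seen:
                    seen.add(adj)
                    parent[adj] = (node, w)
                    stack.append(adj)
        return parent

    answers = []
    for src1, src2, dest in queries:
        parent = parents_from(dest)
        used: Dict[Tuple[int, int], int] = {}  # edge (child, parent) -> weight
        for src in (src1, src2):
            node = src
            while node != dest:
                up, w = parent[node]
                used[(node, up)] = w  # an edge shared by both paths is stored once
                node = up
        answers.append(sum(used.values()))
    return answers


print(brute_force_min_weight(
    [[0, 1, 2], [1, 2, 3], [1, 3, 5], [1, 4, 4], [2, 5, 6]],
    [[2, 3, 4], [0, 2, 5]],
))  # [12, 11]
```

Each query costs O(n) here, which is too slow for the stated limits but convenient for cross-checking a faster solution on small trees.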
Constraints:
- 3 <= n <= 10^5
- edges.length == n - 1
- edges[i].length == 3
- 0 <= ui, vi < n
- 1 <= wi <= 10^4
- 1 <= queries.length <= 10^5
- queries[j].length == 3
- 0 <= src1j, src2j, destj < n
- src1j, src2j, and destj are pairwise distinct.
- The input is generated such that edges represents a valid tree.
Solution
Python3
from collections import defaultdict
from typing import List


class Solution:
    def minimumWeight(self, edges: List[List[int]], queries: List[List[int]]) -> List[int]:
        N = len(edges) + 1
        M = N.bit_length() + 1  # number of binary-lifting levels
        graph = defaultdict(list)
        for a, b, w in edges:
            graph[a].append((b, w))
            graph[b].append((a, w))

        dist = [0] * N  # weighted distance from the root (node 0)
        d = [0] * N  # depth (edge count) from the root
        parent = [[0] * M for _ in range(N)]  # parent[v][p] = 2^p-th ancestor of v

        # Iterative DFS from node 0, so a path-shaped tree with 10^5 nodes
        # cannot exceed Python's recursion limit. The root is its own parent.
        stack = [(0, 0)]
        while stack:
            node, prev = stack.pop()
            parent[node][0] = prev
            for adj, w in graph[node]:
                if adj != prev:
                    dist[adj] = dist[node] + w
                    d[adj] = d[node] + 1
                    stack.append((adj, node))

        # binary lifting
        for power in range(1, M):
            for node in range(N):
                parent[node][power] = parent[parent[node][power - 1]][power - 1]

        def lca(a, b):
            if d[a] > d[b]:
                a, b = b, a
            # let a and b jump to the same depth
            diff = d[b] - d[a]
            for p in range(M):
                if diff & (1 << p):
                    b = parent[b][p]
            if a == b:
                return a
            # lift both nodes to just below their lowest common ancestor
            for p in range(M - 1, -1, -1):
                if parent[a][p] != parent[b][p]:
                    a = parent[a][p]
                    b = parent[b][p]
            return parent[a][0]

        def path_weight(a, b):
            return dist[a] + dist[b] - 2 * dist[lca(a, b)]

        res = []
        for src1, src2, dest in queries:
            a = path_weight(src1, dest)
            b = path_weight(src2, dest)
            c = path_weight(src1, src2)
            # every edge of the connecting subtree lies on exactly two of the
            # three pairwise paths, so the sum counts each edge twice
            res.append((a + b + c) // 2)
        return res
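
The halving step in the final loop can be justified with a short derivation. The notation is introduced here and is not part of the original solution: d(x, y) is the weight of the tree path between x and y, s_1, s_2 and t stand for src1, src2 and dest, m is the unique vertex where the three pairwise paths meet, and W is the weight of the minimal connecting subtree. That subtree consists of the three edge-disjoint paths from m to s_1, s_2 and t, so, as a sketch:

```latex
\begin{aligned}
a &= d(s_1, t) = d(s_1, m) + d(m, t),\\
b &= d(s_2, t) = d(s_2, m) + d(m, t),\\
c &= d(s_1, s_2) = d(s_1, m) + d(m, s_2),\\
a + b + c &= 2\,\bigl(d(s_1, m) + d(s_2, m) + d(t, m)\bigr) = 2W,\\
W &= \frac{a + b + c}{2}.
\end{aligned}
```

For the first query of Example 1, a = 7, b = 9 and c = 8, which gives W = 24 / 2 = 12. Since a + b + c is always even, the integer division in the code loses nothing.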