Description
You are given an undirected weighted tree with n nodes, numbered from 0 to n - 1. It is represented by a 2D integer array edges of length n - 1, where edges[i] = [u_i, v_i, w_i] indicates that there is an edge between nodes u_i and v_i with weight w_i.
Additionally, you are given a 2D integer array queries, where queries[j] = [src1_j, src2_j, dest_j].
Return an array answer of length equal to queries.length, where answer[j] is the minimum total weight of a subtree such that it is possible to reach dest_j from both src1_j and src2_j using edges in this subtree.
A subtree here is any connected subset of nodes and edges of the original tree forming a valid tree.
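The definition can be made concrete as follows. In a tree, any connected subtree that lets both src1_j and src2_j reach dest_j must contain the unique path from each source to dest_j. The union of those two paths is already connected, so the answer is simply the total weight of that union.

The sketch below is a minimal brute-force version of this idea. It assumes a plain BFS per query. The helpers path_edges and brute_force_min_weight are hypothetical names introduced here. It costs O(n) per query, so it is only a reference for small inputs, not a solution within the stated limits.

```python
from collections import defaultdict, deque
from typing import Dict, List, Set, Tuple


def path_edges(adj: Dict[int, List[Tuple[int, int]]], src: int, dst: int) -> Set[Tuple[int, int, int]]:
    """Edges (as (smaller node, larger node, weight)) on the unique src -> dst path."""
    prev: Dict[int, Tuple[int, int]] = {src: (-1, 0)}  # node -> (previous node, edge weight)
    queue = deque([src])
    while queue:  # BFS from src records how each node was first reached
        node = queue.popleft()
        for nxt, w in adj[node]:
            if nxt not in prev:
                prev[nxt] = (node, w)
                queue.append(nxt)
    on_path = set()
    node = dst
    while node != src:  # walk back from dst to src
        p, w = prev[node]
        on_path.add((min(node, p), max(node, p), w))
        node = p
    return on_path


def brute_force_min_weight(edges: List[List[int]], queries: List[List[int]]) -> List[int]:
    adj: Dict[int, List[Tuple[int, int]]] = defaultdict(list)
    for u, v, w in edges:
        adj[u].append((v, w))
        adj[v].append((u, w))
    answer = []
    for src1, src2, dest in queries:
        # The minimal subtree is the union of the two paths into dest.
        used = path_edges(adj, src1, dest) | path_edges(adj, src2, dest)
        answer.append(sum(w for _, _, w in used))
    return answer
```

On Example 1 the union for the first query is the edges 1-2, 1-3 and 1-4, giving 3 + 5 + 4 = 12.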
Example 1:
Input: edges = [[0,1,2],[1,2,3],[1,3,5],[1,4,4],[2,5,6]], queries = [[2,3,4],[0,2,5]]
Output: [12,11]
Explanation:
The blue edges represent one of the subtrees that yield the optimal answer.
- answer[0]: The total weight of the selected subtree that ensures a path from src1 = 2 and src2 = 3 to dest = 4 is 3 + 5 + 4 = 12.
- answer[1]: The total weight of the selected subtree that ensures a path from src1 = 0 and src2 = 2 to dest = 5 is 2 + 3 + 6 = 11.
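For three nodes in a tree, the minimum connecting subtree weighs exactly half the sum of the three pairwise path weights. This is the identity the Python3 solution relies on.

As a worked check on Example 1, write D(x, y) for the weight of the path between nodes x and y. This notation is introduced here for illustration only.

```latex
\begin{aligned}
\text{answer}[0] &= \tfrac{1}{2}\bigl(D(2,4) + D(3,4) + D(2,3)\bigr)
  = \tfrac{1}{2}\bigl((3+4) + (5+4) + (3+5)\bigr) = \tfrac{24}{2} = 12,\\
\text{answer}[1] &= \tfrac{1}{2}\bigl(D(0,5) + D(2,5) + D(0,2)\bigr)
  = \tfrac{1}{2}\bigl((2+3+6) + 6 + (2+3)\bigr) = \tfrac{22}{2} = 11.
\end{aligned}
```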
Example 2:
Input: edges = [[1,0,8],[0,2,7]], queries = [[0,1,2]]
Output: [15]
Explanation:
answer[0]: The total weight of the selected subtree that ensures a path from src1 = 0 and src2 = 1 to dest = 2 is 8 + 7 = 15.
Constraints:
3 <= n <= 10^5
edges.length == n - 1
edges[i].length == 3
0 <= u_i, v_i < n
1 <= w_i <= 10^4
1 <= queries.length <= 10^5
queries[j].length == 3
0 <= src1_j, src2_j, dest_j < n
src1_j, src2_j, and dest_j are pairwise distinct.
The input is generated such that edges represents a valid tree.
Solution
Python3
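from collections import defaultdict
from typing import List

class Solution:
    def minimumWeight(self, edges: List[List[int]], queries: List[List[int]]) -> List[int]:
        N = len(edges) + 1
        M = N.bit_length() + 1  # number of binary-lifting levels (2^(M-1) > N)
        graph = defaultdict(list)
        for a, b, w in edges:
            graph[a].append((b, w))
            graph[b].append((a, w))

        dist = [0] * N  # dist[v]: total edge weight from root 0 to v
        d = [0] * N     # d[v]: depth of v (number of edges from root 0)
        parent = [[0] * M for _ in range(N)]  # parent[v][p]: 2^p-th ancestor of v

        # Iterative DFS from root 0. A recursive DFS can exceed Python's default
        # recursion limit when the tree is a long path (n up to 10^5).
        # The root is treated as its own parent, so jumps past it stay at 0.
        stack = [(0, 0)]
        while stack:
            node, prev = stack.pop()
            parent[node][0] = prev
            for adj, w in graph[node]:
                if adj != prev:
                    dist[adj] = dist[node] + w
                    d[adj] = d[node] + 1
                    stack.append((adj, node))

        # binary lifting: the 2^p-th ancestor is the 2^(p-1)-th ancestor
        # of the 2^(p-1)-th ancestor
        for power in range(1, M):
            for node in range(N):
                parent[node][power] = parent[parent[node][power - 1]][power - 1]

        def lca(a, b):
            if d[a] > d[b]:
                a, b = b, a
            # lift b so that a and b are at the same depth
            diff = d[b] - d[a]
            for p in range(M):
                if diff & (1 << p):
                    b = parent[b][p]
            if a == b:
                return a
            # lift both as far as possible while their ancestors still differ
            for p in range(M - 1, -1, -1):
                if parent[a][p] != parent[b][p]:
                    a = parent[a][p]
                    b = parent[b][p]
            return parent[a][0]

        def path_weight(a, b):
            # weight of the unique a-b path, computed from root distances
            return dist[a] + dist[b] - 2 * dist[lca(a, b)]

        res = []
        for src1, src2, dest in queries:
            a = path_weight(src1, dest)
            b = path_weight(src2, dest)
            c = path_weight(src1, src2)
            # Each edge of the minimal subtree lies on exactly two of the three
            # pairwise paths, so a + b + c counts every such edge twice.
            res.append((a + b + c) // 2)
        return res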
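The final step of the solution rests on an identity: for three nodes a, b, c in a tree, the minimum connecting subtree weighs (D(a,b) + D(b,c) + D(a,c)) / 2. The reason is that the three pairwise paths meet at a single node m, the median of a, b and c. The subtree is the union of the branches from m to a, from m to b, and from m to c. Each branch lies on exactly two of the three pairwise paths, so it is counted twice.

The sketch below is an independent sanity check of that identity, not part of the original solution. It compares the formula against leaf pruning (repeatedly deleting non-terminal leaves) on small random trees. The helpers build_adj, distances_from, half_sum_weight, pruned_tree_weight and check_identity are hypothetical names introduced for illustration.

```python
import random
from collections import defaultdict
from typing import List, Set


def build_adj(edges: List[List[int]]):
    adj = defaultdict(list)
    for u, v, w in edges:
        adj[u].append((v, w))
        adj[v].append((u, w))
    return adj


def distances_from(adj, src: int, n: int) -> List[int]:
    """Weighted distance from src to every node (iterative DFS over the tree)."""
    dist = [0] * n
    seen = [False] * n
    seen[src] = True
    stack = [src]
    while stack:
        node = stack.pop()
        for nxt, w in adj[node]:
            if not seen[nxt]:
                seen[nxt] = True
                dist[nxt] = dist[node] + w
                stack.append(nxt)
    return dist


def half_sum_weight(n: int, edges: List[List[int]], a: int, b: int, c: int) -> int:
    """The identity: half the sum of the three pairwise path weights."""
    adj = build_adj(edges)
    da = distances_from(adj, a, n)
    db = distances_from(adj, b, n)
    return (da[b] + db[c] + da[c]) // 2


def pruned_tree_weight(n: int, edges: List[List[int]], terminals: Set[int]) -> int:
    """Delete non-terminal leaves until none remain; the survivors form the minimal subtree."""
    adj = build_adj(edges)
    degree = [len(adj[v]) for v in range(n)]  # number of live neighbours
    removed = [False] * n
    leaves = [v for v in range(n) if degree[v] == 1 and v not in terminals]
    while leaves:
        leaf = leaves.pop()
        removed[leaf] = True
        for nxt, _ in adj[leaf]:
            if not removed[nxt]:
                degree[nxt] -= 1
                if degree[nxt] == 1 and nxt not in terminals:
                    leaves.append(nxt)
    return sum(w for u, v, w in edges if not removed[u] and not removed[v])


def check_identity(trials: int = 200, seed: int = 0) -> bool:
    rng = random.Random(seed)
    for _ in range(trials):
        n = rng.randint(3, 30)
        # random tree: node v attaches to an earlier node; weights in 1..10^4
        edges = [[v, rng.randrange(v), rng.randint(1, 10**4)] for v in range(1, n)]
        a, b, c = rng.sample(range(n), 3)
        if half_sum_weight(n, edges, a, b, c) != pruned_tree_weight(n, edges, {a, b, c}):
            return False
    return True
```

Leaf pruning is used as the reference because it does not rely on distances at all. Agreement between the two on random trees is good evidence that the halving step is sound.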