Problem Link

Description


You are given an undirected weighted tree with n nodes, numbered from 0 to n - 1. It is represented by a 2D integer array edges of length n - 1, where edges[i] = [u_i, v_i, w_i] indicates that there is an edge between nodes u_i and v_i with weight w_i.

Additionally, you are given a 2D integer array queries, where queries[j] = [src1_j, src2_j, dest_j].

Return an array answer of length equal to queries.length, where answer[j] is the minimum total weight of a subtree such that it is possible to reach dest_j from both src1_j and src2_j using edges in this subtree.

A subtree here is any connected subset of nodes and edges of the original tree forming a valid tree.
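Because all weights are positive and the graph is a tree, the lightest valid subtree for a query is unique: it is exactly what remains after repeatedly deleting leaves that are not one of the three query nodes. The sketch below implements that definition directly. The helper name `min_subtree_weight_bruteforce` is hypothetical and not part of the original solution; it costs O(n) per query, so it is only meant as a reference for checking small cases, not for the full constraints.

```python
from collections import defaultdict, deque
from typing import List


def min_subtree_weight_bruteforce(n: int, edges: List[List[int]], terminals: List[int]) -> int:
    """Weight of the smallest subtree containing every node in `terminals`,
    found by repeatedly pruning non-terminal leaves (O(n) per call)."""
    adj = defaultdict(dict)  # adj[u][v] = weight of edge u-v
    for u, v, w in edges:
        adj[u][v] = w
        adj[v][u] = w

    keep = set(terminals)
    alive = set(range(n))
    # start with every leaf that is not a terminal
    queue = deque(u for u in range(n) if len(adj[u]) <= 1 and u not in keep)
    while queue:
        leaf = queue.popleft()
        if leaf not in alive:
            continue  # already pruned via a duplicate queue entry
        alive.remove(leaf)
        for nb in list(adj[leaf]):
            del adj[nb][leaf]
            del adj[leaf][nb]
            # the neighbour may have just become a removable leaf
            if len(adj[nb]) <= 1 and nb not in keep:
                queue.append(nb)

    # each surviving edge is stored twice (u->v and v->u)
    return sum(w for u in alive for w in adj[u].values()) // 2
```

For Example 1 below, calling it with terminals [2, 3, 4] prunes nodes 0 and 5 and leaves the edges 1-2, 1-3 and 1-4.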


Example 1:

Input: edges = [[0,1,2],[1,2,3],[1,3,5],[1,4,4],[2,5,6]], queries = [[2,3,4],[0,2,5]]

Output: [12,11]

Explanation:

[Figure: the blue edges mark one of the subtrees that yield the optimal answer.]

  • answer[0]: The total weight of the selected subtree that ensures a path from src1 = 2 and src2 = 3 to dest = 4 is 3 + 5 + 4 = 12.

  • answer[1]: The total weight of the selected subtree that ensures a path from src1 = 0 and src2 = 2 to dest = 5 is 2 + 3 + 6 = 11.
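The same totals follow from pairwise path lengths alone, which is the identity the Python solution below relies on: the answer is half the sum of the three pairwise distances. Writing dist(x, y) for the weighted length of the path between x and y in Example 1, we have dist(2,4) = 3 + 4, dist(3,4) = 5 + 4, dist(2,3) = 3 + 5, dist(0,5) = 2 + 3 + 6, dist(2,5) = 6 and dist(0,2) = 2 + 3, so:

```latex
\begin{aligned}
\text{query } [2,3,4]:&\quad \tfrac{1}{2}\bigl(\operatorname{dist}(2,4) + \operatorname{dist}(3,4) + \operatorname{dist}(2,3)\bigr) = \tfrac{1}{2}(7 + 9 + 8) = 12,\\
\text{query } [0,2,5]:&\quad \tfrac{1}{2}\bigl(\operatorname{dist}(0,5) + \operatorname{dist}(2,5) + \operatorname{dist}(0,2)\bigr) = \tfrac{1}{2}(11 + 6 + 5) = 11.
\end{aligned}
```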

Example 2:

Input: edges = [[1,0,8],[0,2,7]], queries = [[0,1,2]]

Output: [15]

Explanation:

  • answer[0]: The total weight of the selected subtree that ensures a path from src1 = 0 and src2 = 1 to dest = 2 is 8 + 7 = 15.
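Before adding any LCA machinery, the half-sum identity can be checked with a plain traversal per query. The helper names `distances_from` and `min_weight_naive` are hypothetical; this O(n · q) sketch is far too slow for n and queries.length up to 10^5, but it reproduces both example outputs.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def distances_from(graph: Dict[int, List[Tuple[int, int]]], src: int) -> Dict[int, int]:
    """Weighted distance from src to every node (iterative DFS; tree paths are unique)."""
    dist = {src: 0}
    stack = [src]
    while stack:
        node = stack.pop()
        for nxt, w in graph[node]:
            if nxt not in dist:
                dist[nxt] = dist[node] + w
                stack.append(nxt)
    return dist


def min_weight_naive(edges: List[List[int]], queries: List[List[int]]) -> List[int]:
    graph = defaultdict(list)
    for u, v, w in edges:
        graph[u].append((v, w))
        graph[v].append((u, w))

    answer = []
    for src1, src2, dest in queries:
        from_dest = distances_from(graph, dest)
        a = from_dest[src1]                    # src1 -> dest
        b = from_dest[src2]                    # src2 -> dest
        c = distances_from(graph, src1)[src2]  # src1 -> src2
        # every edge of the minimal subtree lies on exactly two of the three paths
        answer.append((a + b + c) // 2)
    return answer


print(min_weight_naive([[0, 1, 2], [1, 2, 3], [1, 3, 5], [1, 4, 4], [2, 5, 6]], [[2, 3, 4], [0, 2, 5]]))  # [12, 11]
print(min_weight_naive([[1, 0, 8], [0, 2, 7]], [[0, 1, 2]]))  # [15]
```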


Constraints:

  • 3 <= n <= 10^5
  • edges.length == n - 1
  • edges[i].length == 3
  • 0 <= u_i, v_i < n
  • 1 <= w_i <= 10^4
  • 1 <= queries.length <= 10^5
  • queries[j].length == 3
  • 0 <= src1_j, src2_j, dest_j < n
  • src1_j, src2_j, and dest_j are pairwise distinct.
  • The input is generated such that edges represents a valid tree.

Solution


Python3

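from collections import defaultdict
from typing import List


class Solution:
    def minimumWeight(self, edges: List[List[int]], queries: List[List[int]]) -> List[int]:
        N = len(edges) + 1
        M = N.bit_length() + 1  # binary-lifting levels; 2^(M-1) > N covers any depth
        graph = defaultdict(list)
        for a, b, w in edges:
            graph[a].append((b, w))
            graph[b].append((a, w))

        # parent[node][p] = 2^p-th ancestor of node. The root is its own parent,
        # so jumps past the root stay at node 0 instead of wrapping to index -1.
        parent = [[0] * M for _ in range(N)]
        d = [0] * N     # depth (number of edges) from the root
        dist = [0] * N  # weighted distance from the root

        # Iterative DFS from root 0: a path-shaped tree with n = 10^5 nodes
        # would exceed Python's default recursion limit with a recursive DFS.
        stack = [(0, 0)]
        while stack:
            node, prev = stack.pop()
            parent[node][0] = prev
            for adj, w in graph[node]:
                if adj != prev:
                    d[adj] = d[node] + 1
                    dist[adj] = dist[node] + w
                    stack.append((adj, node))

        # binary lifting: the 2^p-th ancestor is the 2^(p-1)-th ancestor
        # of the 2^(p-1)-th ancestor
        for power in range(1, M):
            for node in range(N):
                parent[node][power] = parent[parent[node][power - 1]][power - 1]

        def lca(a, b):
            if d[a] > d[b]:
                a, b = b, a

            # lift b up to the depth of a
            diff = d[b] - d[a]
            for p in range(M):
                if diff & (1 << p):
                    b = parent[b][p]

            if a == b:
                return a

            # lift both while their ancestors differ; they stop just below the LCA
            for p in range(M - 1, -1, -1):
                if parent[a][p] != parent[b][p]:
                    a = parent[a][p]
                    b = parent[b][p]

            return parent[a][0]

        def path_weight(a, b):
            # weighted length of the unique path between a and b
            return dist[a] + dist[b] - 2 * dist[lca(a, b)]

        res = []
        for src1, src2, dest in queries:
            a = path_weight(src1, dest)
            b = path_weight(src2, dest)
            c = path_weight(src1, src2)
            # every edge of the minimal subtree lies on exactly two of the three
            # pairwise paths, so the sum counts it twice and // 2 is exact
            res.append((a + b + c) // 2)

        return res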
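Why halving works: in a tree, the three paths between src1 = s_1, src2 = s_2 and dest = t all pass through a single common node m (possibly one of the three query nodes). The minimal subtree is the union of the edge-disjoint paths from m to s_1, s_2 and t, so its weight is W = dist(s_1, m) + dist(s_2, m) + dist(t, m), and each pairwise path splits at m:

```latex
\begin{aligned}
\operatorname{dist}(s_1,t) + \operatorname{dist}(s_2,t) + \operatorname{dist}(s_1,s_2)
&= \bigl[\operatorname{dist}(s_1,m) + \operatorname{dist}(m,t)\bigr]
 + \bigl[\operatorname{dist}(s_2,m) + \operatorname{dist}(m,t)\bigr]
 + \bigl[\operatorname{dist}(s_1,m) + \operatorname{dist}(m,s_2)\bigr] \\
&= 2\bigl(\operatorname{dist}(s_1,m) + \operatorname{dist}(s_2,m) + \operatorname{dist}(t,m)\bigr) = 2W.
\end{aligned}
```

Each term on the left is what `path_weight` returns, via dist(a, b) = dist[a] + dist[b] - 2 · dist[lca(a, b)], so every query costs three O(log n) LCA lookups after the O(n log n) preprocessing.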