Description
You are given an integer n and an undirected, weighted tree rooted at node 1 with n nodes numbered from 1 to n. This is represented by a 2D array edges of length n - 1, where edges[i] = [ui, vi, wi] indicates an undirected edge from node ui to vi with weight wi.
You are also given a 2D integer array queries of length q, where each queries[i] is either:
- [1, u, v, w']: Update the weight of the edge between nodes u and v to w', where (u, v) is guaranteed to be an edge present in edges.
- [2, x]: Compute the shortest path distance from the root node 1 to node x.
Return an integer array answer, where answer[i] is the shortest path distance from node 1 to x for the i-th query of the form [2, x].
Example 1:
Input: n = 2, edges = [[1,2,7]], queries = [[2,2],[1,1,2,4],[2,2]]
Output: [7,4]
Explanation:

- Query [2,2]: The shortest path from root node 1 to node 2 is 7.
- Query [1,1,2,4]: The weight of edge (1,2) changes from 7 to 4.
- Query [2,2]: The shortest path from root node 1 to node 2 is 4.
Example 2:
Input: n = 3, edges = [[1,2,2],[1,3,4]], queries = [[2,1],[2,3],[1,1,3,7],[2,2],[2,3]]
Output: [0,4,2,7]
Explanation:

- Query [2,1]: The shortest path from root node 1 to node 1 is 0.
- Query [2,3]: The shortest path from root node 1 to node 3 is 4.
- Query [1,1,3,7]: The weight of edge (1,3) changes from 4 to 7.
- Query [2,2]: The shortest path from root node 1 to node 2 is 2.
- Query [2,3]: The shortest path from root node 1 to node 3 is 7.
Example 3:
Input: n = 4, edges = [[1,2,2],[2,3,1],[3,4,5]], queries = [[2,4],[2,3],[1,2,3,3],[2,2],[2,3]]
Output: [8,3,2,5]
Explanation:

- Query [2,4]: The shortest path from root node 1 to node 4 consists of edges (1,2), (2,3), and (3,4) with weights 2 + 1 + 5 = 8.
- Query [2,3]: The shortest path from root node 1 to node 3 consists of edges (1,2) and (2,3) with weights 2 + 1 = 3.
- Query [1,2,3,3]: The weight of edge (2,3) changes from 1 to 3.
- Query [2,2]: The shortest path from root node 1 to node 2 is 2.
- Query [2,3]: The shortest path from root node 1 to node 3 consists of edges (1,2) and (2,3) with updated weights 2 + 3 = 5.
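The query semantics in the examples can be checked against a direct simulation: record each node's parent once, keep the current weight of every edge in a dictionary, and answer [2, x] by walking from x up to the root. This is a minimal sketch; the function name brute_force_tree_queries is hypothetical. It costs O(depth) per distance query, so a path-shaped tree makes it too slow for the largest inputs, but it is handy as a reference when testing a faster solution.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def brute_force_tree_queries(n: int, edges: List[List[int]], queries: List[List[int]]) -> List[int]:
    """Hypothetical reference implementation: O(depth) work per [2, x] query."""
    graph = defaultdict(list)
    # Edge weights keyed by (smaller endpoint, larger endpoint).
    weight: Dict[Tuple[int, int], int] = {}
    for u, v, w in edges:
        graph[u].append(v)
        graph[v].append(u)
        weight[(min(u, v), max(u, v))] = w

    # BFS from the root to record each node's parent.
    parent = [0] * (n + 1)
    seen = [False] * (n + 1)
    seen[1] = True
    frontier = [1]
    while frontier:
        nxt = []
        for node in frontier:
            for adj in graph[node]:
                if not seen[adj]:
                    seen[adj] = True
                    parent[adj] = node
                    nxt.append(adj)
        frontier = nxt

    answer = []
    for q in queries:
        if q[0] == 1:
            _, u, v, w = q
            weight[(min(u, v), max(u, v))] = w
        else:
            # Walk from x up to the root, summing the current edge weights.
            x, total = q[1], 0
            while x != 1:
                p = parent[x]
                total += weight[(min(x, p), max(x, p))]
                x = p
            answer.append(total)
    return answer


print(brute_force_tree_queries(2, [[1, 2, 7]], [[2, 2], [1, 1, 2, 4], [2, 2]]))  # [7, 4]
```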
Constraints:
- 1 <= n <= 10^5
- edges.length == n - 1
- edges[i] == [ui, vi, wi]
- 1 <= ui, vi <= n
- 1 <= wi <= 10^4
- The input is generated such that edges represents a valid tree.
- 1 <= queries.length == q <= 10^5
- queries[i].length == 2 or 4
- queries[i] == [1, u, v, w'] or queries[i] == [2, x]
- 1 <= u, v, x <= n
- (u, v) is always an edge from edges.
- 1 <= w' <= 10^4
 
Solution
Python3
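from collections import defaultdict
from typing import List


class SegmentTree:
    """Point-assignment, range-sum segment tree over a fixed-length array."""

    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [0] * (4 * self.n)
        self.build(1, 0, self.n - 1, arr)

    def build(self, v, tl, tr, arr):
        if tl == tr:
            self.tree[v] = arr[tl]
        else:
            tm = tl + (tr - tl) // 2
            self.build(v * 2, tl, tm, arr)
            self.build(v * 2 + 1, tm + 1, tr, arr)
            self.tree[v] = self.tree[v * 2] + self.tree[v * 2 + 1]

    def queryHelper(self, l, r):
        # Sum of arr[l..r], inclusive.
        return self.query(1, 0, self.n - 1, l, r)

    def query(self, v, tl, tr, l, r):
        if l > r:
            return 0
        if tl == l and tr == r:
            return self.tree[v]
        tm = tl + (tr - tl) // 2
        return (self.query(v * 2, tl, tm, l, min(tm, r))
                + self.query(v * 2 + 1, tm + 1, tr, max(tm + 1, l), r))

    def updateHelper(self, pos, value):
        # Set arr[pos] = value.
        self.update(1, 0, self.n - 1, pos, value)

    def update(self, v, tl, tr, pos, value):
        if tl == tr:
            self.tree[v] = value
        else:
            tm = tl + (tr - tl) // 2
            if pos <= tm:
                self.update(v * 2, tl, tm, pos, value)
            else:
                self.update(v * 2 + 1, tm + 1, tr, pos, value)
            self.tree[v] = self.tree[v * 2] + self.tree[v * 2 + 1]


class Solution:
    def treeQueries(self, n: int, edges: List[List[int]], queries: List[List[int]]) -> List[int]:
        graph = defaultdict(list)
        for a, b, w in edges:
            graph[a].append((b, w))
            graph[b].append((a, w))

        # Euler tour: each node writes +weight(parent edge) on entry and
        # -weight(parent edge) on exit. The prefix sum up to tin[x] is then the
        # root-to-x distance, because every fully visited subtree cancels out.
        flat_tree = []
        tin = [0] * (n + 1)
        tout = [0] * (n + 1)
        parent = [0] * (n + 1)

        # Iterative DFS: a recursive one exceeds Python's default recursion
        # limit on a path-shaped tree with n = 10^5.
        stack = [(1, 0, 0, False)]  # (node, parent, parent-edge weight, exiting?)
        while stack:
            node, prev, weight, exiting = stack.pop()
            if exiting:
                flat_tree.append(-weight)
                tout[node] = len(flat_tree) - 1
                continue
            parent[node] = prev
            flat_tree.append(weight)
            tin[node] = len(flat_tree) - 1
            # Exit marker is popped only after all children are finished.
            stack.append((node, prev, weight, True))
            for adj, w2 in graph[node]:
                if adj != prev:
                    stack.append((adj, node, w2, False))

        st = SegmentTree(flat_tree)
        res = []
        for q in queries:
            if q[0] == 1:
                _, a, b, w = q
                # The edge weight lives at the child's entry/exit slots. The
                # child is the endpoint whose parent is the other endpoint;
                # node labels alone do not say which one that is.
                child = b if parent[b] == a else a
                st.updateHelper(tin[child], w)
                st.updateHelper(tout[child], -w)
            else:
                node = q[1]
                res.append(st.queryHelper(0, tin[node]))

        return res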
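The same Euler-tour encoding also works with a Fenwick tree (binary indexed tree) in place of the segment tree, since a [2, x] query only ever needs a prefix sum. The sketch below swaps in that structure; the names Fenwick and tree_queries_fenwick are hypothetical and not part of the solution above. For Example 3 the flattened array it builds is [0, 2, 1, 5, -5, -1, -2, 0], and the prefix sum up to node 4's entry slot is 0 + 2 + 1 + 5 = 8.

```python
from collections import defaultdict
from typing import List


class Fenwick:
    """Binary indexed tree over positions 0..size-1: point add, prefix sum."""

    def __init__(self, size: int) -> None:
        self.tree = [0] * (size + 1)

    def add(self, i: int, delta: int) -> None:
        i += 1
        while i < len(self.tree):
            self.tree[i] += delta
            i += i & -i

    def prefix_sum(self, i: int) -> int:
        """Sum of positions 0..i inclusive."""
        i += 1
        total = 0
        while i > 0:
            total += self.tree[i]
            i -= i & -i
        return total


def tree_queries_fenwick(n: int, edges: List[List[int]], queries: List[List[int]]) -> List[int]:
    graph = defaultdict(list)
    for u, v, w in edges:
        graph[u].append((v, w))
        graph[v].append((u, w))

    tin = [0] * (n + 1)
    tout = [0] * (n + 1)
    parent = [0] * (n + 1)
    cur = [0] * (n + 1)  # current weight of the edge (parent[x], x)
    flat = []
    stack = [(1, 0, 0, False)]  # (node, parent, parent-edge weight, exiting?)
    while stack:
        node, prev, w, exiting = stack.pop()
        if exiting:
            tout[node] = len(flat)
            flat.append(-w)
            continue
        parent[node], cur[node] = prev, w
        tin[node] = len(flat)
        flat.append(w)
        stack.append((node, prev, w, True))
        for adj, w2 in graph[node]:
            if adj != prev:
                stack.append((adj, node, w2, False))

    bit = Fenwick(len(flat))
    for i, value in enumerate(flat):
        bit.add(i, value)

    answer = []
    for q in queries:
        if q[0] == 1:
            _, u, v, w = q
            child = v if parent[v] == u else u
            # A Fenwick tree only supports adding, so apply the difference.
            delta = w - cur[child]
            cur[child] = w
            bit.add(tin[child], delta)
            bit.add(tout[child], -delta)
        else:
            answer.append(bit.prefix_sum(tin[q[1]]))
    return answer


print(tree_queries_fenwick(4, [[1, 2, 2], [2, 3, 1], [3, 4, 5]],
                           [[2, 4], [2, 3], [1, 2, 3, 3], [2, 2], [2, 3]]))  # [8, 3, 2, 5]
```

The design difference is in the update: the segment tree above assigns the new value directly, while the Fenwick version has to remember each edge's current weight in cur so it can add the difference. Both give O(log n) per query after O(n) or O(n log n) preprocessing.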