Description
You are given a tree (i.e. a connected, undirected graph that has no cycles) rooted at node 0 consisting of n nodes numbered from 0 to n - 1. The tree is represented by a 0-indexed array parent of size n, where parent[i] is the parent of node i. Since node 0 is the root, parent[0] == -1.
You are also given a string s of length n, where s[i] is the character assigned to the edge between i and parent[i]. s[0] can be ignored.
Return the number of pairs of nodes (u, v) such that u < v and the characters assigned to edges on the path from u to v can be rearranged to form a palindrome.
A string is a palindrome when it reads the same backwards as forwards.
Example 1:

Input: parent = [-1,0,0,1,1,2], s = "acaabc"
Output: 8
Explanation: The valid pairs are:
- All the pairs (0,1), (0,2), (1,3), (1,4) and (2,5) result in one character, which is always a palindrome.
- The pair (2,3) results in the string "aca", which is a palindrome.
- The pair (1,5) results in the string "cac", which is a palindrome.
- The pair (3,5) results in the string "acac", which can be rearranged into the palindrome "acca".
Example 2:
Input: parent = [-1,0,0,0,0], s = "aaaaa"
Output: 10
Explanation: Any pair of nodes (u,v) where u < v is valid.
Constraints:
- n == parent.length == s.length
- 1 <= n <= 10^5
- 0 <= parent[i] <= n - 1 for all i >= 1
- parent[0] == -1
- parent represents a valid tree.
- s consists of only lowercase English letters.
Solution
Python3
class Solution:
    def countPalindromePaths(self, parent: List[int], s: str) -> int:
        """Count pairs (u, v), u < v, whose path labels can be permuted into a palindrome.

        Each node gets a 26-bit parity mask. Bit c is set iff letter c occurs an
        odd number of times on the root-to-node path. The shared part of two
        root paths (above their LCA) cancels under XOR, so the u..v path has
        parity ``mask[u] ^ mask[v]``. A multiset of letters can form a
        palindrome iff at most one letter has odd count. That means the XOR is
        0 or a single bit.

        Args:
            parent: parent[i] is the parent of node i; parent[0] == -1.
            s: s[i] labels the edge (i, parent[i]); s[0] is ignored.

        Returns:
            The number of valid unordered node pairs.
        """
        n = len(parent)
        children = [[] for _ in range(n)]
        for node in range(n):
            if parent[node] >= 0:
                children[parent[node]].append(node)

        # Iterative traversal from the root. A recursive/memoised version
        # overflows Python's recursion limit on deep (e.g. chain) trees.
        # Parents are not guaranteed to have smaller indices, so a plain
        # index-order scan would not work either.
        masks = [0] * n
        stack = [0] if n else []
        while stack:
            node = stack.pop()
            for child in children[node]:
                masks[child] = masks[node] ^ (1 << (ord(s[child]) - ord('a')))
                stack.append(child)

        res = 0
        seen = Counter()  # Counter returns 0 for missing keys without inserting them
        for v in masks:
            # Pair v with every earlier mask equal to it or differing in one bit.
            res += seen[v] + sum(seen[v ^ (1 << k)] for k in range(26))
            seen[v] += 1
        return res
C++
class Solution {
public:
    /// @brief Counts pairs (u, v), u < v, whose path edge labels can be
    ///        rearranged into a palindrome.
    ///
    /// Each node gets a 26-bit parity mask. Bit c is set iff letter c occurs
    /// an odd number of times on the root-to-node path. The common prefix
    /// above the LCA cancels under XOR, so the u..v path has parity
    /// mask[u] ^ mask[v]. The path can form a palindrome iff that XOR has at
    /// most one set bit.
    ///
    /// @param parent parent[i] is the parent of node i; parent[0] == -1.
    /// @param s      s[i] labels edge (i, parent[i]); s[0] is ignored.
    /// @return Number of valid unordered node pairs.
    long long countPalindromePaths(std::vector<int>& parent, std::string s) {
        const int n = static_cast<int>(parent.size());
        if (n == 0) {
            return 0;
        }

        std::vector<std::vector<int>> children(n);
        for (int node = 0; node < n; ++node) {
            if (parent[node] >= 0) {
                children[parent[node]].push_back(node);
            }
        }

        // Explicit-stack traversal: a recursive DFS would recurse to the tree
        // height (up to n on a chain) and can overflow the call stack.
        std::vector<int> mask(n, 0);
        std::vector<int> pending{0};
        while (!pending.empty()) {
            const int node = pending.back();
            pending.pop_back();
            for (const int child : children[node]) {
                mask[child] = mask[node] ^ (1 << (s[child] - 'a'));
                pending.push_back(child);
            }
        }

        std::unordered_map<int, long long> seen;
        seen.reserve(static_cast<std::size_t>(n));
        // Read-only lookup. operator[] would insert a zero entry for every
        // probed neighbour mask (up to 26 per node).
        const auto count_of = [&seen](int m) -> long long {
            const auto it = seen.find(m);
            return it == seen.end() ? 0LL : it->second;
        };

        long long res = 0;
        for (int node = 0; node < n; ++node) {
            const int m = mask[node];
            res += count_of(m);  // identical parity: every letter even
            for (int bit = 0; bit < 26; ++bit) {
                res += count_of(m ^ (1 << bit));  // exactly one odd letter
            }
            ++seen[m];
        }
        return res;
    }
};