Given an array of integers arr, we want to select three indices i, j and k where 0 <= i < j <= k < arr.length.
Let's define a and b as follows:
a = arr[i] ^ arr[i + 1] ^ ... ^ arr[j - 1]
b = arr[j] ^ arr[j + 1] ^ ... ^ arr[k]
Note that ^ denotes the bitwise-xor operation.
Return the number of triplets (i, j, k) where a == b.
Example 1:
Input: arr = [2,3,1,6,7] Output: 4 Explanation: The triplets are (0,1,2), (0,2,2), (2,3,4) and (2,4,4)
Example 2:
Input: arr = [1,1,1,1,1] Output: 10
Constraints:
1 <= arr.length <= 300
1 <= arr[i] <= 10^8
from collections import Counter
from typing import List


class Solution:
    def countTriplets(self, arr: List[int]) -> int:
        """Count triplets (i, j, k) with 0 <= i < j <= k < len(arr) and a == b.

        Here a = arr[i] ^ ... ^ arr[j - 1] and b = arr[j] ^ ... ^ arr[k].

        Because x ^ x == 0, a == b holds exactly when a ^ b == 0, i.e. when
        arr[i] ^ ... ^ arr[k] == 0. With prefix[t] defined as the XOR of
        arr[:t], that condition is prefix[k + 1] == prefix[i]. For each such
        pair, every split point j in (i, k] works, contributing k - i triplets.

        Args:
            arr: The input integers.

        Returns:
            The number of qualifying triplets (0 for fewer than two elements).
        """
        # prefix[t] == arr[0] ^ ... ^ arr[t - 1]; prefix[0] == 0 (empty XOR).
        prefix = [0] * (len(arr) + 1)
        for t, value in enumerate(arr, start=1):
            prefix[t] = prefix[t - 1] ^ value

        res = 0
        # Per prefix value seen so far: number of occurrences, and the sum of
        # the indices where it occurred.
        count = Counter()
        total = Counter()
        for q, x in enumerate(prefix):
            # Each earlier p with prefix[p] == x plays the role of i, with
            # q == k + 1, and adds k - i == q - 1 - p triplets. Summed over all
            # such p this is count * (q - 1) - sum(p), computed in O(1).
            # Adjacent indices (p == q - 1) contribute 0, as they should.
            res += count[x] * (q - 1) - total[x]
            count[x] += 1
            total[x] += q
        return res