Description
You are given a string s and an integer k.
First, you are allowed to change at most one index in s to another lowercase English letter.
After that, do the following partitioning operation until s is empty:
- Choose the longest prefix of s containing at most k distinct characters.
- Delete the prefix from s and increase the number of partitions by one. The remaining characters (if any) in s maintain their initial order.
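The partitioning step on its own, before any change is considered, can be sketched as below. The function name `partition_prefixes` is hypothetical. It relies on the fact that the longest valid prefix always ends just before the first character that would be the (k+1)-th distinct letter, so one left-to-right scan is enough.

```python
def partition_prefixes(s: str, k: int) -> list[str]:
    """Greedily cut off the longest prefix with at most k distinct characters."""
    parts = []
    start = 0
    seen = set()
    for i, ch in enumerate(s):
        if ch not in seen and len(seen) == k:
            # ch would be the (k+1)-th distinct letter: close the current prefix.
            parts.append(s[start:i])
            start = i
            seen = set()
        seen.add(ch)
    if start < len(s):
        parts.append(s[start:])  # the last prefix is closed by reaching the end
    return parts


print(partition_prefixes("acbca", 2))  # ['ac', 'bc', 'a']
```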
Return an integer denoting the maximum number of resulting partitions after the operations by optimally choosing at most one index to change.
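A direct reading of the statement gives a brute-force reference: try no change plus every (index, letter) replacement and keep the best partition count. This is a minimal sketch with hypothetical names (`count_partitions`, `max_partitions_brute_force`), not the intended solution. It runs in O(26 · n²) time, so it is only practical for small inputs, for example to cross-check a faster solution.

```python
from string import ascii_lowercase


def count_partitions(s: str, k: int) -> int:
    """Number of prefixes removed by the greedy procedure (no change applied)."""
    if not s:
        return 0
    cuts, seen = 0, set()
    for ch in s:
        if ch not in seen and len(seen) == k:
            # ch would be the (k+1)-th distinct letter: close the current prefix.
            cuts += 1
            seen = set()
        seen.add(ch)
    return cuts + 1  # the last prefix is never closed inside the loop


def max_partitions_brute_force(s: str, k: int) -> int:
    """Try no change and every single-letter change; O(26 * n^2) time."""
    best = count_partitions(s, k)  # "at most one" change includes changing nothing
    for i in range(len(s)):
        for c in ascii_lowercase:
            if c != s[i]:
                best = max(best, count_partitions(s[:i] + c + s[i + 1:], k))
    return best


print(max_partitions_brute_force("accca", 2))  # 3
```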
Example 1:
Input: s = "accca", k = 2
Output: 3
Explanation:
The optimal way is to change s[2] to something other than a and c, for example b, so s becomes "acbca".
Then we perform the operations:
- The longest prefix containing at most 2 distinct characters is "ac", so we remove it and s becomes "bca".
- Now the longest prefix containing at most 2 distinct characters is "bc", so we remove it and s becomes "a".
- Finally, we remove "a" and s becomes empty, so the procedure ends.
After these operations, the string has been divided into 3 partitions, so the answer is 3.
Example 2:
Input: s = "aabaab", k = 3
Output: 1
Explanation:
Initially, s contains 2 distinct characters, so whichever character we change, it will contain at most 3 distinct characters. The longest prefix with at most 3 distinct characters is therefore always the whole string, so the answer is 1.
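The reasoning in Example 2 generalizes to an early exit: one change adds at most one new distinct letter, so if s has fewer than k distinct letters the answer is always 1. The helper name below is hypothetical. The condition is sufficient but not necessary: for instance, s = "ab" with k = 2 also gives 1, because every possible change still leaves at most 2 distinct letters.

```python
def single_partition_guaranteed(s: str, k: int) -> bool:
    """Sufficient (not necessary) condition for the answer to be 1."""
    # One change adds at most one new distinct letter, so fewer than k distinct
    # letters means the changed string still has at most k and is one prefix.
    return len(set(s)) < k


print(single_partition_guaranteed("aabaab", 3))  # True
```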
Example 3:
Input: s = "xxyz", k = 1
Output: 4
Explanation:
The optimal way is to change s[0] or s[1] to a letter that does not occur in s, for example, change s[0] to w.
Then s becomes "wxyz", which consists of 4 distinct characters. Since k is 1, every partition holds exactly one character, so s is divided into 4 partitions.
Constraints:
- 1 <= s.length <= 10^4
- s consists only of lowercase English letters.
- 1 <= k <= 26
Solution
Python3
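from functools import cache
import sys


class Solution:
    def maxPartitionsAfterOperations(self, s: str, k: int) -> int:
        N = len(s)
        # Map each letter to 0..25 so it can be stored as one bit of a mask.
        s = [ord(x) - ord('a') for x in s]
        # go() recurses once per index, so the depth grows with len(s) (up to 10^4);
        # raise the interpreter's recursion limit, with headroom, for standalone runs.
        sys.setrecursionlimit(max(sys.getrecursionlimit(), 4 * N + 100))
        
        @cache
        def go(index, left, mask):
            # index: next position to process.
            # left: number of changes still available (1 or 0).
            # mask: bitmask of the letters in the partition currently being built.
            # Returns how many partitions are closed from index to the end.
            if index == N: return 0
            
            # Option 1: keep s[index] unchanged.
            newMask = mask | (1 << s[index])
            bitCount = newMask.bit_count()  # int.bit_count() requires Python 3.10+
            
            if bitCount > k:
                # s[index] would be the (k+1)-th distinct letter: close the current
                # partition and start a new one containing only s[index].
                res = 1 + go(index + 1, left, 1 << s[index])
            else:
                res = go(index + 1, left, newMask)
            
            # Option 2: use the one allowed change and replace s[index] with letter c.
            if left > 0:
                for c in range(26):
                    newMask = mask | (1 << c)
                    bitCount = newMask.bit_count()
                    
                    if bitCount > k:
                        res = max(res, 1 + go(index + 1, left - 1, 1 << c))
                    else:
                        res = max(res, go(index + 1, left - 1, newMask))
                    
            return res
        
        # go() counts only closed partitions; add 1 for the final, still-open one.
        return go(0, 1, 0) + 1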