Problem Link

Description


You are given two integers n and k.

For any positive integer x, define the following sequence:

  • p_0 = x
  • p_{i+1} = popcount(p_i) for all i >= 0, where popcount(y) is the number of set bits (1's) in the binary representation of y.

This sequence will eventually reach the value 1, because popcount(y) < y for every y > 1.

The popcount-depth of x is defined as the smallest integer d >= 0 such that p_d = 1.

For example, if x = 7 (binary representation "111"), the sequence is 7 → 3 → 2 → 1, so the popcount-depth of 7 is 3.

Your task is to determine the number of integers in the range [1, n] whose popcount-depth is exactly equal to k.

Return the number of such integers.


Example 1:

Input: n = 4, k = 1

Output: 2

Explanation:

The following integers in the range [1, 4] have popcount-depth exactly equal to 1:

x    Binary    Sequence
2    "10"      2 → 1
4    "100"     4 → 1

Thus, the answer is 2.

Example 2:

Input: n = 7, k = 2

Output: 3

Explanation:

The following integers in the range [1, 7] have popcount-depth exactly equal to 2:

x    Binary    Sequence
3    "11"      3 → 2 → 1
5    "101"     5 → 2 → 1
6    "110"     6 → 2 → 1

Thus, the answer is 3.


Constraints:

  • 1 <= n <= 10^15
  • 0 <= k <= 5

Solution


Python3

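from functools import cache

class Solution:
    def popcountDepth(self, n: int, k: int) -> int:
        # Only x = 1 has popcount-depth 0, and 1 is always in [1, n].
        if k == 0:
            return 1

        def compute(x):
            # Popcount-depth of x: number of popcount steps to reach 1.
            depth = 0
            while x > 1:
                x = x.bit_count()
                depth += 1
            return depth

        # For x > 1, depth(x) = 1 + depth(popcount(x)), so x has depth k
        # exactly when its popcount is in `seen`. Since n <= 10^15 < 2^50,
        # checking popcounts up to 64 is more than enough.
        seen = set()
        for i in range(1, 65):
            if compute(i) + 1 == k:
                seen.add(i)

        if not seen:
            return 0

        # Binary digits of n, most significant first.
        nums = list(map(int, bin(n)[2:]))

        @cache
        def dp(index, tight, ones):
            # Count completions of positions index.. that keep the number
            # <= n (while tight) and end with a popcount that is in `seen`.
            if index == len(nums):
                return 1 if ones in seen else 0

            digit = nums[index] if tight else 1
            count = 0
            for d in range(digit + 1):
                newTight = tight and d == digit
                count += dp(index + 1, newTight, ones + d)
            return count

        ans = dp(0, True, 0)

        # x = 1 has popcount 1 and is counted when k == 1,
        # but its depth is 0, so remove it.
        if k == 1 and 1 in seen:
            return ans - 1

        return ans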