LC 100112. 平衡子序列的最大和

https://leetcode.cn/problems/maximum-balanced-subsequence-sum/description/

这题最开始的想法是维护一个排序结构 `st = [(a, b)]`

其实这个结构挺好的,唯一的问题就在于更新,维护这个数据结构的成本比较高:需要将某些元素从st里面删除掉,这个成本有点高。 我花了比较长时间进行调试,但是还有有一些test cases出错了。 错误的思路,糟糕的数据结构,从一开始就会让整个调试成本增加。

后面换了一个思路,就是其实 `i-nums[i]` 这个值是可以枚举出来的。如果我们将这些值映射成为偏移量的话,那么其实我们想求解的就是某个区间内的最大值, 并且每次只是更新其中一个点,这个数据结构正好就是上周使用的线段树。

#!/usr/bin/env python
# coding:utf-8
# Copyright (C) dirlt

from typing import List


class Solution:
    def maxBalancedSubsequenceSum(self, nums: List[int]) -> int:
        # dp[i] = max(dp[j]) + nums[i] if (i - j) <= nums[i] - nums[j]
        # i - nums[i] <= j - nums[j]

        diff = [i - nums[i] for i in range(len(nums))]
        diff.sort(reverse=True)
        pos = {}
        for d in diff:
            if d not in pos:
                pos[d] = len(pos)
        N = len(pos)

        INF = 1 << 63
        SZ = 1
        while SZ < N:
            SZ = SZ * 2
        MAX = [-INF] * (2 * SZ)

        def update_max(p, v):
            k = p + SZ
            MAX[k] = max(MAX[k], v)
            while k != 1:
                p = k // 2
                MAX[p] = max(MAX[2 * p], MAX[2 * p + 1])
                k = p

        def query_max(p):
            def do(i, j, k, s, sz):
                if i <= s <= (s + sz - 1) <= j:
                    return MAX[k]
                mid = s + sz // 2
                res = -INF
                if i < mid:
                    a = do(i, j, 2 * k, s, sz // 2)
                    res = max(res, a)
                if j >= mid:
                    a = do(i, j, 2 * k + 1, mid, sz // 2)
                    res = max(res, a)
                return res

            return do(0, p, 1, 0, SZ)

        ans = -INF
        for i in range(len(nums)):
            d = i - nums[i]
            p = pos[d]
            value = nums[i]
            last = query_max(p)
            value += max(last, 0)
            update_max(p, value)
            ans = max(ans, value)

        return ans