LC 1703. 得到连续 K 个 1 的最少相邻交换次数

https://leetcode-cn.com/problems/minimum-adjacent-swaps-for-k-consecutive-ones/

这题首先要证明,将1往中间靠是最优解,然后就是如何优化计算。

假设有k个不连续的1,我们假设在某个点上,左边有m个,右边则有k-m-1个。首先我们想象,要将这些1全部聚到这个点附近,需要X步骤。如果这个点右移一个单位,那么左边距离会增加m, 而右边距离则会减少(k-m-1),就是 `X+2m-k+1`. 为了确保距离会继续增加,就需要假设 `X=(X+2m-k+1)`, 所以 `m=(k-1)/2`. 而这个中间点的左边就是 `(k-1)/2` (以0为下标,在python函数里面就是 `k//2` )

一旦确定要往中点靠近,接着就是确定每个中点的移动代价了。这个代价其实可以分为两个部分:

  1. 每个点到这个点的距离之和
  2. 每个点因为顺序原因少移动的距离。

关于(2)可以举个例子, 假设 010101 -> 000111, 最左边的1到最右边的1距离是4,不过因为它是最外面的一个1,所以只需要移动4-2=2。将(2)单独分离出来计算可以减少逻辑复杂度,并且(2)这个距离是不变。下面是计算(2)这个值得代码

saved = 0
for i in range(mid):
    saved += (mid - i)
for i in range(mid+1, k):
    saved += (i - mid)

然后每次移动中点,有4个部分会变化:

  1. 中点所有左边的点需要增加 `(a[mid+1]-a[mid])`
  2. 中点所有右边的点需要减少 `(a[mid+1]-a[mid])`
  3. 减去最左边的点 `(a[mid+1]-a[i-k])`
  4. 增加最右边的点 `(a[i]-a[mid+1])`

下面是完整代码

class Solution:
    def minMoves(self, nums: List[int], k: int) -> int:
        arr = []
        for i in range(len(nums)):
            if nums[i] == 1:
                arr.append(i)

        if k == 1:
            return 0

        # 这题首先要证明往中间靠是最优解
        # 之后采用类似滑动窗口办法
        # mid = (k-1) / 2 是最优解

        # initialize cost.
        half = mid = k // 2
        cost = 0
        for i in range(k):
            p0 = arr[mid]
            p1 = arr[i]
            cost += abs(p0 - p1)
        # note: move all around 1 to mid.
        saved = 0
        for i in range(mid):
            saved += (mid - i)
        for i in range(mid+1, k):
            saved += (i - mid)

        ans = cost
        # print(cost)
        for i in range(k, len(arr)):
            # mid -> mid + 1
            it = arr[mid+1] - arr[mid]
            a = (half + 1) * it
            b = (k - half - 1) * it
            # remove (i-k-1) item.
            c = arr[mid+1] - arr[i-k]
            # add (i) item
            d = arr[i] - arr[mid+1]
            cost += (a - b - c + d)
            # print(it, a, b, c, d, cost)
            ans = min(ans, cost)
            mid = mid + 1
        # adjust final cost.
        ans -= saved
        return ans