Skip to content

Latest commit

 

History

History
135 lines (89 loc) · 4.37 KB

5999.count-good-triplets-in-an-array.md

File metadata and controls

135 lines (89 loc) · 4.37 KB

题目地址(5999. 统计数组中好三元组数目)

https://leetcode-cn.com/problems/count-good-triplets-in-an-array/

题目描述

给你两个下标从 0 开始且长度为 n 的整数数组 nums1 和 nums2 ,两者都是 [0, 1, ..., n - 1] 的 排列 。

好三元组 指的是 3 个 互不相同 的值,且它们在数组 nums1 和 nums2 中出现顺序保持一致。换句话说,如果我们将 pos1v 记为值 v 在 nums1 中出现的位置,pos2v 为值 v 在 nums2 中的位置,那么一个好三元组定义为 0 <= x, y, z <= n - 1 ,且 pos1x < pos1y < pos1z 和 pos2x < pos2y < pos2z 都成立的 (x, y, z) 。

请你返回好三元组的 总数目 。

 

示例 1:

输入:nums1 = [2,0,1,3], nums2 = [0,1,2,3]
输出:1
解释:
总共有 4 个三元组 (x,y,z) 满足 pos1x < pos1y < pos1z ,分别是 (2,0,1) ,(2,0,3) ,(2,1,3) 和 (0,1,3) 。
这些三元组中,只有 (0,1,3) 满足 pos2x < pos2y < pos2z 。所以只有 1 个好三元组。


示例 2:

输入:nums1 = [4,0,1,3,2], nums2 = [4,1,0,2,3]
输出:4
解释:总共有 4 个好三元组 (4,0,3) ,(4,0,2) ,(4,1,3) 和 (4,1,2) 。


 

提示:

n == nums1.length == nums2.length
3 <= n <= 10^5
0 <= nums1[i], nums2[i] <= n - 1
nums1 和 nums2 是 [0, 1, ..., n - 1] 的排列。

前置知识

  • 平衡二叉树
  • 枚举

公司

  • 暂无

思路

本题的第一个关键点是:根据数组 A 的索引对应关系置换数组 B,得到新的数组 C,问题转化为堆 C 求递增三元组的个数

比如对于题目给的:nums1 = [2,0,1,3], nums2 = [0,1,2,3]

我们可以获取到 nums1 的索引对应关系,即 2->0, 0->1, 1->2, 3->3。

n = len(nums1)
for i in range(n):
    d[nums1[i]] = i

用这个对应关系更新 nums2,最终得到的数组为 [1,2,0,3],我们只需要求 [1,2,0,3] 的递增三元组的个数即可。

for i in range(n):
    nums.append(d[nums2[i]])

第二个关键单是枚举中间值 x,这样以 x 为中间值的递增三元组的个数就是 ycnt * zcnt(笛卡尔积),其中 ycnt 为 x 前面的比 x 小的,zcnt 为后面的比 x 大的。

比较容易想到的方法是使用两个平衡二叉树,代码:

sl1 = SortedList()
sl2 = SortedList(nums)
for num in nums:
    sl1.add(num)
    sl2.remove(num)
    ans += sl1.bisect_left(num) * (len(sl2) - sl2.bisect_left(num + 1))
return ans

实际上使用 sl1 就足够了。这是因为 zcnt 其实也等价于所有比 x 大的数的总数 - sl1 中比 x 大的数的个数,而这个信息通过 sl1 就足以求得。具体代码见下方代码区。我们可以省去一个 SortedList 的开销,因此不管是空间复杂度还是时间复杂度都可以获得常数级别的优化。

关键点

  • 根据数组 A 的索引对应关系置换数组 B,得到新的数组 C,问题转化为堆 C 求递增三元组的个数
  • 枚举三元组中中间的数

代码

  • 语言支持:Python3

Python3 Code:

from sortedcontainers import SortedList
class Solution:
    def goodTriplets(self, nums1: List[int], nums2: List[int]) -> int:
        d = {}
        nums = []
        ans = 0
        n = len(nums1)
        for i in range(n):
            d[nums1[i]] = i
        for i in range(n):
            nums.append(d[nums2[i]])
        sl1 = SortedList()
        for num in nums:
            sl1.add(num)
            ans += sl1.bisect_left(num) * ((n - num - (len(sl1) - sl1.bisect_left(num))))
        return ans

复杂度分析

令 n 为数组长度。

  • 时间复杂度:$O(nlogn)$
  • 空间复杂度:$O(n)$

此题解由 力扣刷题插件 自动生成。

力扣的小伙伴可以关注我,这样就会第一时间收到我的动态啦~

以上就是本文的全部内容了。大家对此有何看法,欢迎给我留言,我有时间都会一一查看回答。更多算法套路可以访问我的 LeetCode 题解仓库:https://github.com/azl397985856/leetcode 。 目前已经 40K star 啦。大家也可以关注我的公众号《力扣加加》带你啃下算法这块硬骨头。

关注公众号力扣加加,努力用清晰直白的语言还原解题思路,并且有大量图解,手把手教你识别套路,高效刷题。