코딩 테스트 및 알고리즘/leetcode for google

leetcode medium : Range Sum Query Mutable

띠리링구 2022. 11. 15. 04:05

https://leetcode.com/problems/range-sum-query-mutable/

 

Range Sum Query - Mutable - LeetCode

Level up your coding skills and quickly land a job. This is the best place to expand your knowledge and get prepared for your next interview.

leetcode.com

 

나열된 값 중에서 특정 범위의 값들의 합을 빠르게 구하려면 Segment Tree를 생각해봐야 한다.

(정확히 말하면 '구간에 대한 쿼리'를 효과적으로 처리할 수 있는 자료구조)

"""
    The idea here is to build a segment tree. Each node stores the left and right
    endpoint of an interval and the sum of that interval. All of the leaves will store
    elements of the array and each internal node will store sum of leaves under it.
    Creating the tree takes O(n) time. Query and updates are both O(log n).
"""

#Segment tree node
class Node(object):
    def __init__(self, start, end):
        self.start = start
        self.end = end
        self.total = 0
        self.left = None
        self.right = None
        

class NumArray(object):
    def __init__(self, nums):
        """
        initialize your data structure here.
        :type nums: List[int]
        """
        #helper function to create the tree from input array
        def createTree(nums, l, r):
            
            #base case
            if l > r:
                return None
                
            #leaf node
            if l == r:
                n = Node(l, r)
                n.total = nums[l]
                return n
            
            mid = (l + r) // 2
            
            root = Node(l, r)
            
            #recursively build the Segment tree
            root.left = createTree(nums, l, mid)
            root.right = createTree(nums, mid+1, r)
            
            #Total stores the sum of all leaves under root
            #i.e. those elements lying between (start, end)
            root.total = root.left.total + root.right.total
                
            return root
        
        self.root = createTree(nums, 0, len(nums)-1)
            
    def update(self, i, val):
        """
        :type i: int
        :type val: int
        :rtype: int
        """
        #Helper function to update a value
        def updateVal(root, i, val):
            
            #Base case. The actual value will be updated in a leaf.
            #The total is then propogated upwards
            if root.start == root.end:
                root.total = val
                return val
        
            mid = (root.start + root.end) // 2
            
            #If the index is less than the mid, that leaf must be in the left subtree
            if i <= mid:
                updateVal(root.left, i, val)
                
            #Otherwise, the right subtree
            else:
                updateVal(root.right, i, val)
            
            #Propogate the changes after recursive call returns
            root.total = root.left.total + root.right.total
            
            return root.total
        
        return updateVal(self.root, i, val)

    def sumRange(self, i, j):
        """
        sum of elements nums[i..j], inclusive.
        :type i: int
        :type j: int
        :rtype: int
        """
        #Helper function to calculate range sum
        def rangeSum(root, i, j):
            
            #If the range exactly matches the root, we already have the sum
            if root.start == i and root.end == j:
                return root.total
            
            mid = (root.start + root.end) // 2
            
            #If end of the range is less than the mid, the entire interval lies
            #in the left subtree
            if j <= mid:
                return rangeSum(root.left, i, j)
            
            #If start of the interval is greater than mid, the entire inteval lies
            #in the right subtree
            elif i >= mid + 1:
                return rangeSum(root.right, i, j)
            
            #Otherwise, the interval is split. So we calculate the sum recursively,
            #by splitting the interval
            else:
                return rangeSum(root.left, i, mid) + rangeSum(root.right, mid+1, j)
        
        return rangeSum(self.root, i, j)

내가 푼건 아니고 리트코드 디스커션에서 읽기 쉽게 너무 잘 풀어놔서 복붙했다. (https://leetcode.com/problems/range-sum-query-mutable/discuss/75784/Python%3A-Well-commented-solution-using-Segment-Trees)

 

이 문제를 본 이유는 최근 본 면접에서 Segment Tree 관련 문제를 받았고 제대로 못 풀었기 때문이다 ㅠ

Segment Tree가 뭐냐면... (참고 https://www.youtube.com/watch?v=075fcq7oCC8 )

 

세그먼트 트리는 구간 연산을 빠르게 할 수 있는 자료구조라고 한다. 특히 값이 update 될 수 있는 경우에 유용하며 값이 고정되어있는 경우엔 prefix sum 혹은 sparse table 방식을 이용할 수 있다고 영상에서 설명하고 있다. 그림에서 보듯이 구간별로 연산결과를 미리 계산해놓고 각 연산 결과를 이용해서 구간 연산 결과를 계산하는 방식이다. 예를 들어 인덱스 6~9의 합을 구한다고 하면 root에서부터 분기해서 6~7의 연산결과랑 8~9의 연산결과를 찾고(O(log N)) 그걸 합해서 리턴하는 식이 되겠다.

이처럼 부분(segment)를 나눠서 미리 pre-compute하는 방식이라서 segment tree라고 불리나보다. 리프노드는 각각의 요소 하나하나인데 구간 연산 결과를 구할 때 최대한 큰 세그먼트들을 모아서 연산 결과를 취합해야한다.