leetcode : Number of Good Paths
https://leetcode.com/problems/number-of-good-paths/
말도 안되게 어려워...
절대 내 혼자 힘으론 못풀고
가장 인상깊었던 discussion 솔루션의 코드를 해석해보고자 한다.
class Solution:
def numberOfGoodPaths(self, vals, edges):
res = n = len(vals)
f = list(range(n))
count = [Counter({vals[i]: 1}) for i in range(n)]
edges = sorted([max(vals[i], vals[j]),i,j] for i,j in edges)
def find(x):
if f[x] != x:
f[x] = find(f[x])
return f[x]
for v,i,j in edges:
fi, fj = find(i), find(j)
cj, ci = count[fi][v], count[fj][v]
res += ci * cj
f[fj] = fi
count[fi] = Counter({v: ci + cj})
return res
이걸 내 방식대로 해석해서 어려운 문법 최대한 빼고 작성해봤다.
class Solution:
def numberOfGoodPaths(self, vals, edges):
answer = len(vals)
group = [i for i in range(len(vals))]
group_counter = [defaultdict(int) for _ in range(len(vals))]
for i in range(len(vals)):
group_counter[i][vals[i]] = 1
def find(x):
if group[x] != x:
group[x] = find(group[x])
return group[x]
# i < j
def union(i, j):
nonlocal answer
if i > j:
i,j = j,i
val = max(vals[i], vals[j])
root_i = find(i)
root_j = find(j)
group[j] = i
answer += group_counter[root_i][val] * group_counter[root_j][val]
group_counter[root_i][val] += group_counter[root_j][val]
edges.sort(key = lambda x : max(vals[x[0]], vals[x[1]]))
for i, j in edges:
union(i, j)
return answer
먼저 intuition을 설명하면서 시작해야겠다.
위와 같이 7이 여러개 들어있는 트리 두 개가 있다고 치자. 현재는 good path가 2개 있다.
이때 이렇게 중간에 3이라는, 7보다 작은값으로 두 트리를 이어주면 goodpath는 몇 개가 늘어날까?
왼쪽에 있는 7 2개, 오른쪽에 있는 7 2개가 end to end로 이어지는 경우의 수를 세면 되니까 2*2=4개의 good path가 새로 생긴다.
이 원리를 이용해서 정답을 찾을거다!! 주의할점은 우리가 7을 기준으로 계산할 때 7이 max값이어야 한다는거다. 중간에 3이 아니라 8같은게 있었으면 good path는 만들어지지 못했을거다.
그래서 값이 작은 노드부터 트리를 만들어나갈거다.
각 single node들이 스스로 goodpath를 형성하고(5개) 0번노드-1번노드 와 2번노드-3번노드 의 2개의 good path가 추가되어 정답은 7이다. 이거 어떻게 찾는지 보자.
값이 제일 작은 1부터 본다. 0번 노드와 1번 노드를 이어준다. 각각을 트리로 본다면 0번 트리에는 1이 1개 있고 1번 트리에도 1이 1개 있으므로 1*1=1. 즉 양 끝이 1인 good path가 1개 추가되었다.
1 다음으로 작은값은 2다. 2번 트리에는 1개의 2가 있고 0번 노드와 1번 노드가 이어진 트리에는 2가 0개 있으므로 양끝이 2인 good path는 1 * 0 = 0개 추가된다.
3번 트리에는 2가 1개 있고 0-1-2번 노드가 이어진 트리에는 2가 1개 있으니 1*1=1개의 새로운 good path가 추가된다.
이제 모든 2를 다 이었으니 3을 볼 차례다.
마찬가지로 1 * 0 = 0개의 새로운 good path가 추가된다.
다시 내 코드를 보자.
class Solution:
def numberOfGoodPaths(self, vals, edges):
answer = len(vals)
group = [i for i in range(len(vals))]
group_counter = [defaultdict(int) for _ in range(len(vals))]
edges.sort(key = lambda x : max(vals[x[0]], vals[x[1]]))
for i in range(len(vals)):
group_counter[i][vals[i]] = 1
def find(x):
if group[x] != x:
group[x] = find(group[x])
return group[x]
# i < j
def union(i, j):
nonlocal answer
if i > j:
i,j = j,i
val = max(vals[i], vals[j])
root_i = find(i)
root_j = find(j)
group[j] = i
answer += group_counter[root_i][val] * group_counter[root_j][val]
group_counter[root_i][val] += group_counter[root_j][val]
for i, j in edges:
union(i, j)
return answer
class Solution:
def numberOfGoodPaths(self, vals, edges):
answer = len(vals)
group = [i for i in range(len(vals))]
group_counter = [defaultdict(int) for _ in range(len(vals))]
edges.sort(key = lambda x : max(vals[x[0]], vals[x[1]]))
for i in range(len(vals)):
group_counter[i][vals[i]] = 1
먼저 이부분을 보자.
answer : 정답값을 담는 변수다. len(vals) 그러니까 노드 개수로 초기화한다. single node는 그자체로 하나의 good path를 형성하니까 미리 추가해두고 시작하는거다.
group : union-find를 구현하기 위한 배열이다. union-find를 잘 모른다면 이부분은 따로 보고 오길 바란다..! ( https://keep-your-pace.tistory.com/112?category=1000047 )
group counter는 각 그룹마다 dict를 하나씩 둔 것이다. 왜뒀냐면 특정 그룹에 특정 값이 몇 개나 있는지 한 번에 확인하기 위함이다. 아까 위에서 두 트리를 합칠때 양측에 있는 특정 값의 개수를 곱해서 새로운 good path의 개수를 구한다는걸 떠올려보면 이해하기 쉬울것이다.
edges.sort : 엣지를 정렬했다. 엣지 양쪽에 있는 값중 큰값을 기준으로 정렬했다. 큰값을 기준으로 정렬해야 한다 무조건. 작은값을 기준으로 정렬했다간 한쪽은 엄청 작은데 한쪽은 엄청 큰 엣지를 먼저 보게되는 불상사가 생긴다. 우리는 큰 값은 최대한 나중에 추가하는 방향으로 트리를 구성할거니까 안전하게 엣지 양측값중 큰값을 기준으로 오름차순 정렬해야한다.
for i in range(len(vals))부분은 group_counter를 초기화해주는 부분이다. 현재는 각 노드들이 스스로 하나의 트리를 이루는 상태이므로 값을 하나씩밖에 가지고 있지 않고 그래서 자기자신의 값(vals[i])은 1개다 라고 초기화해주는 모습.
def find(x):
if group[x] != x:
group[x] = find(group[x])
return group[x]
# union부분 생략
for i, j in edges:
union(i, j)
return answer
일단 union함수는 핵심부분이니까 생략하고 나머지 먼저 보자.
find는 그냥 인터넷에 union-find 검색하면 나오는 find 알고리즘 그대로 쓴거다. path compression이 적용되었다.
그리고 아까 오름차순 정렬된 edges를 차례로 union하면 정답이 계산된다.
그럼 이제 union을보자.
# i < j
def union(i, j):
nonlocal answer
if i > j:
i,j = j,i
val = max(vals[i], vals[j])
root_i = find(i)
root_j = find(j)
group[j] = i
answer += group_counter[root_i][val] * group_counter[root_j][val]
group_counter[root_i][val] += group_counter[root_j][val]
nonlocal answer는 함수 내 함수에서 바깥 스코프의 변수를 참조하기 때문에 선언해둔것이다.
if i > j : i ,j. = j, i는 i가 j보다 작음을 보장하기 위함이다. 이건 union-find를 위해서 보장하는거다. 둘을 합칠 때 더 작은 값인 i를 parent로 둘거다.
다른 부분은 그냥 union find 알고리즘과 완벽히 동일하고 이부분만 다르다.
#...
val = max(vals[i], vals[j])
#...
answer += group_counter[root_i][val] * group_counter[root_j][val]
group_counter[root_i][val] += group_counter[root_j][val]
val = max(..를 통해서 내가 보고 있는 기준값을 val에 저장했다. edge가 잇는 양측의 값중에 큰 값을 기준으로 봐야한다. good path의 사이에는 나보다 작거나 같은놈들만 있어야하므로 내가 보고있는 기준값은 항상 현재 형성한 tree에서 가장 max값임이 보장되어야 한다.
answer += .. 이부분에서 아까 말한 그거를 수행한다! 양측의 트리에서 값의 개수를 세서 곱하는거. 그걸 정답 변수에 더해준다.
group_counter[root_i][val] +=... 이부분은 이제 두 트리가 합쳐졌으니 값의 개수도 합쳐주는 것이다. 다만 val보다 작은 값들에 대해서는 합쳐주지 않는데 걔네들은 어차피 이제 쓸일이 없기 때문이다.
상당히 어려운 문제였다.