알고리즘/BOJ

[백준] 1717_집합의 표현 (disjoint set 이용)

엉아_ 2021. 10. 20. 21:21
728x90

📕 문제

https://www.acmicpc.net/problem/1717

 

1717번: 집합의 표현

첫째 줄에 n(1 ≤ n ≤ 1,000,000), m(1 ≤ m ≤ 100,000)이 주어진다. m은 입력으로 주어지는 연산의 개수이다. 다음 m개의 줄에는 각각의 연산이 주어진다. 합집합은 0 a b의 형태로 입력이 주어진다. 이는

www.acmicpc.net

 

💡 풀이법

: disjoint-set (분리 집합) 를 이용하는 문제이다.

 

1. find_set 함수는 현재 노드의 루트 노드를 찾아주는 함수이다.

2. union 함수는 두 노드를 합치는 함수이다.

- 두 노드중 높이가 큰 노드 아래로 높이가 낮은 노드가 담긴다.

3. parents 리스트는 각 노드의 부모 노드가 담긴다.

4. ranks 리스트는 각 노드의 높이가 담긴다.

 

import sys
sys.setrecursionlimit(10000000)
input = sys.stdin.readline

def find_set(x):
    """
    path compression 적용
    => 부모 노드를 루트 노드로 갱신
    """
    if x != parents[x]:
        parents[x] = find_set(parents[x])

    return parents[x]


def union(x, y):
    """
    union by rank 적용
    => rank 값이 더 큰 쪽에 붙이기
    """
    root_x = find_set(x)
    root_y = find_set(y)

    # root_x의 트리의 높이(rank)가 더 클 경우
    if ranks[root_x] > ranks[root_y]:
        parents[root_y] = root_x
    # root_y의 트리의 높이가 더 크거나, 혹은 둘이 같을 경우
    else:
        parents[root_x] = root_y
        # 만약에 높이가 같다면 rank 증가
        if ranks[root_x] == ranks[root_y]:
            ranks[root_y] += 1

N, M = map(int, input().split())
nodes = [i for i in range(N + 1)]
parents = [i for i in range(N + 1)]
ranks = [0] * (N + 1)

for i in range(M):
    c, a, b = map(int, input().split())
    if c :
        if find_set(a) == find_set(b):
            print('YES')
        else:
            print('NO')
    else:
        union(a, b)