[BOJ] Q11437 & Q11438 LCA

Question

Language: Python

Difficulty: Gold 3

트리 내 임의의 노드 2개에 대한 최단 거리 공통 조상을 찾는 문제이다.

공통 조상을 찾는 로직은 비교적 간단하다.

  1. bfs/dfs을 실행해서 각 노드에 대한 부모 노드, 깊이를 초기화한다.
  2. 공통 조상을 조사하고자할 노드에 대해 LCA 알고리즘을 수행한다.

LCA 알고리즘

Lowest Common Ancestor(최단 거리 공통 조상) 찾기를 위해 요구되는 알고리즘으로, 특정 노드의 깊이, 부모 노드를 알고 있을때 활용한다. 2개의 노드에 대해, 서로 같은 높이가 되도록 맞춰준다음, 같은 노드로 모일때까지 올라가는 작업을 반복해서 진행하면 된다.

#높이를 서로 맞춰주는 작업
while levels[v1] != levels[v2]:            
    if levels[v2] < levels[v1]:
        v1=parents[v1]
    else:
        v2=parents[v2]

#서로 같은 노드에서 모일때 까지 위로 올라가는 작업
while v1 != v2:
    v1=parents[v1]
    v2=parents[v2]

print(v1)

Solution 1

from collections import deque
def solution():
    parents=[0] * (n+1)
    levels=[0]*(n+1)
    
    queue=deque([(1)])
    visited=[False]*(n+1)
    visited[1]=True

    while queue:
        vertex=queue.popleft()

        for adj_vertex in graph[vertex]:
            if visited[adj_vertex]:
                continue
            visited[adj_vertex]=True

            parents[adj_vertex]=vertex
            levels[adj_vertex]=levels[vertex]+1
            queue.append(adj_vertex)
    
    for v1, v2 in queries:
        while levels[v1] != levels[v2]:            
            if levels[v2] < levels[v1]:
                v1=parents[v1]
            
            else:
                v2=parents[v2]
        while v1 != v2:
            v1=parents[v1]
            v2=parents[v2]

        print(v1)


if __name__ == "__main__":
    n=int(input())
    graph=[[] for _ in range(n+1)]
    for _ in range(n-1):
        v1,v2=map(int,input().split())
        graph[v1].append(v2)
        graph[v2].append(v1)
    m=int(input())
    queries=[list(map(int,input().split())) for _ in range(m)]
    
    solution()

Question

Language: Python

Difficulty: Platinum 5

위의 문제의 경우, 일반적인 LCA 알고리즘을 활용하여 문제를 풀이하는 것이 가능하다. 하지만, N(노드의 갯수) 값이 커지게 되면 시간 초과가 발생하게 된다.

11438 문제의 경우를 확인해보면, N의 최대값은 50000, LCA를 찾고자하는 순서쌍의 갯수는 최대 10000개인데, Tree의 최악의 시간 복잡도는 완전 편향트리로, O(n)이기 때문에, 최종적인 시간의 복잡도의 경우 O(NM)이 되면서 5*108이 되므로 시간 초과가 발생한다.

그래서 위와 같이 N이 매우 큰 경우에 대해서는 한칸씩 비교하는 것이 아닌, 2의 거듭제곱꼴로 부모를 저장하므로써 한번에 여러칸을 올라가는 작업을 진행해서, N을 logN으로 시간 복잡도를 낮춘다.

거듭제곱꼴의 부모 노드 배열 초기화

parents=[[0]*LENGTH for _ in range(n+1)]
for i in range(1,LENGTH):
    for j in range(1,n+1):
        parents[j][i]=parents[parents[j][i-1]][i-1]

현재 노드에 대해 각각 2의 거듭제곱 윗칸의 노드를 초기화하도록한다.

개선된 LCA 알고리즘

v1,v2=(v1,v2) if levels[v2]>levels[v1] else (v2,v1)
#깊이를 맞춰주는 작업
for i in range(LENGTH-1,-1,-1):
    if levels[v2]-levels[v1]>=2**i:
        v2=parents[v2][i]

if v1==v2:
    print(v1)
    continue

#깊이가 같을 때거슬러 올라가는 작업
for i in range(LENGTH-1,-1,-1):
    if parents[v1][i]!=parents[v2][i]:
        v1=parents[v1][i]
        v2=parents[v2][i]

print(parents[v1][0])

기존의 LCA 알고리즘의 경우 부모노드로의 접근 과정에서 한칸씩 올라갔지만, 현재 알고리즘의 경우 2의 거듭제곱꼴로 올라가기 때문에 시간 복잡도를 O(n) 에서 O(logn)으로 줄일 수 있다.

from collections import deque
def solution():
    parents=[[0]*LENGTH for _ in range(n+1)]
    levels=[0]*(n+1)
    
    queue=deque([(1)])
    visited=[False]*(n+1)
    visited[1]=True

    while queue:
        vertex=queue.popleft()

        for adj_vertex in graph[vertex]:
            if visited[adj_vertex]:
                continue
            visited[adj_vertex]=True

            parents[adj_vertex][0]=vertex
            levels[adj_vertex]=levels[vertex]+1
            queue.append(adj_vertex)
    
    for i in range(1,LENGTH):
        for j in range(1,n+1):
            parents[j][i]=parents[parents[j][i-1]][i-1]
    

    for v1, v2 in queries:
        v1,v2=(v1,v2) if levels[v2]>levels[v1] else (v2,v1)

        for i in range(LENGTH-1,-1,-1):
            if levels[v2]-levels[v1]>=2**i:
                v2=parents[v2][i]

        if v1==v2:
            print(v1)
            continue

        for i in range(LENGTH-1,-1,-1):
            if parents[v1][i]!=parents[v2][i]:
                v1=parents[v1][i]
                v2=parents[v2][i]
        
        print(parents[v1][0])
        
if __name__ == "__main__":
    n=int(input())
    graph=[[] for _ in range(n+1)]
    LENGTH=21
    for _ in range(n-1):
        v1,v2=map(int,input().split())
        graph[v1].append(v2)
        graph[v2].append(v1)
    m=int(input())
    queries=[list(map(int,input().split())) for _ in range(m)]
    
    solution()

댓글남기기