[BOJ] Q2887 행성 터널

Question

Language: Python

Difficulty: Gold 1

각각의 행성들에 대해 모든 인접한 노드들에 대한 거리를 구해서 이에 대한 간선을 모두 추가 시켜주고 난후 kruskal algorithm으로 MST을 구한다.

Fail Code

from math import inf
from pyexpat.errors import codes
def find_parents_collapsed(parents,x):
    if x!= parents[x]:
        parents[x]=find_parents_collapsed(parents,parents[x])
    return parents[x]

def union_parents(parents,x,y):
    pre_x=find_parents_collapsed(parents,x)
    pre_y=find_parents_collapsed(parents,y)

    if pre_x <pre_y:
        parents[pre_x]=pre_y
    else:
        parents[pre_y]=pre_x

def solution():
    parents=[i for i in range(num)]
    edges=[]
    edge_count=0
    result=0
    for i in range(num):
        for j in range(num):
            if i==j:
                continue

            min_value=min(abs(points[i][0]-points[j][0]),abs(points[i][1]-points[j][1]),abs(points[i][2]-points[j][2]))

            edges.append((min_value,i,j))
    
    edges.sort()

    while edge_count < num-1:
        cost,v1,v2=edges.pop(0)

        if find_parents_collapsed(parents,v1) != find_parents_collapsed(parents,v2):
            union_parents(parents,v1,v2)
            result+=cost
            edge_count+=1
    
    return result
if __name__ == "__main__":
    num=int(input())
    points=[list(map(int,input().split())) for _ in range(num)]
    
    print(solution())

하지만 이렇게 풀게 되면 메모리 초과가 발생하게 된다. input이 최대 100,000가 올 수 있는데, 간선은 최대 5,000,000,000 개가 발생 할 수 있다. 따라서 모든 축에 대해서 생각을 하고 이를 MST로 돌리면 안된다.

대신, 각 축별로 분리해서 생각을 해보자.그리고 나서 각각의 축의 좌표들을 정렬을 한 다음 이들을 인접한 좌표끼리 연결하자 이렇게 하면 각 축별로 MST가 생기게 된다. 이렇게 생각을 하게 되면 고려해야 될 간선의 수는 최대 3*(N-1)로 모든 좌표축을 생각했을 때보다 과정을 단출 시킬 수 있다.

Solution

from heapq import heappush,heappop
from math import inf

def find_parent_compressed(parent,x):
    if parent[x] != x:
        parent[x]=find_parent_compressed(parent,parent[x])
    return parent[x]

def union_parent(parent,x,y):
    pre_x=find_parent_compressed(parent,x)
    pre_y=find_parent_compressed(parent,y)

    if pre_x < pre_y:
        parent[pre_y]=pre_x
    else:
        parent[pre_x]=pre_y

num=int(input())
x=[]
y=[]
z=[]

for i in range(1,num+1):
    data=list(map(int,input().split()))
    x.append((data[0],i))
    y.append((data[1],i))
    z.append((data[2],i))

x.sort()
y.sort()
z.sort()

heap=[]
parents=[i for i in range(num+1)]
result=0
edge_count=0

#이렇게 각 축에 대해 정렬을 한 후 인접한 좌표 끼리 이어주게 되면, 고려해야될 간선의 개수를 획기적으로 줄일 수 있다.
for i in range(num-1):
    heappush(heap,(x[i+1][0]-x[i][0],x[i][1],x[i+1][1]))
    heappush(heap,(y[i+1][0]-y[i][0],y[i][1],y[i+1][1]))
    heappush(heap,(z[i+1][0]-z[i][0],z[i][1],z[i+1][1]))

while edge_count<num-1:
    cost,vertex1,vertex2=heappop(heap)

    if find_parent_compressed(parents,vertex1) != find_parent_compressed(parents,vertex2):
        union_parent(parents,vertex1,vertex2)
        result+=cost
        edge_count+=1

print(result)

    
    

댓글남기기