[BOJ] Q1162 도로포장

Question

Language: Python

Difficulty: Gold 1

도로 k 개 이하를 포장해서 얻을 수 있는 최단 경로의 길이를 구하는 문제이다.

처음 생각으로는 재귀문을 이용해서 도로를 지우고, k개를 지워졌을때 다이직스트라를 이용해서 최단경로를 구하려고 했다.

Fail Code

import heapq
import sys
from math import inf

def dijkstra():
    distance=[inf] * (v+1)
    distance[1]=0

    heap=[(0,1)]

    while heap:
        cost,vertex=heapq.heappop(heap)

        if cost>distance[vertex]:
            continue
            
        for v1,v2,weight in edges:
            if v1 != vertex:
                continue

            cost=distance[v1]+weight
            if distance[v2] > cost:
                distance[v2]=cost
                heapq.heappush(heap,(cost,v2))
    
    return distance[-1]

def solution(cnt):
    global result
    if cnt==k:
        result=min(result,dijkstra())

    for i in range(m):
        if deleted[i]==0:
            deleted[i]=edges[i][2]
            edges[i][2]=0
            
            solution(cnt+1)
            
            edges[i][2]=deleted[i]
            deleted[i]=0

        
if __name__ == "__main__":
    sys.setrecursionlimit(10**6)
    v,m,k=map(int,input().split())
    edges=[list(map(int,input().split())) for _ in range(m)]
    deleted=[0]*(m)
    result=inf
    solution(0)
    print(result)

하지만 이렇게 풀다 보니 시간 초과가 났는데, 어쩌면 당연하게도 시간 초과가 날수 밖에 없는 게 문제의 조건이다. 도로의 개수는 최대 50000개에 포장가능한 도로는 최대 20개로 이를 재귀문으로 돌리게 되면 시간이 많이 소요된다.

그러면 이 문제를 다이직스트라로 해결해야한다는 소리인데… 다이직스트라에서는 distance 배열이 1차원 배열이다. 하지만 우리는 도로포장의 경우도 고려해야된다. 그렇다 distance 배열을 2차원 배열로 확장시킨다 pave 변수 까지 고려하도록 한다.

기존 distance 최신화

if distance[adj_vertex][pave] > cost:
    distance[adj_vertex][pave]=cost
    heapq.heappush(heap,(cost,adj_vertex,pave))

도로 포장 시 거리 최신화

if pave+1<=k and distance[adj_vertex][pave+1]>distance[vertex][pave]:
    distance[adj_vertex][pave+1]=distance[vertex][pave]
    heapq.heappush(heap,(distance[vertex][pave],adj_vertex,pave+1))

Solution

import heapq
from math import inf

def solution():
    global result
    distance=[[inf] * (k+1) for _ in range(v+1)]
    for i in range(k+1):
        distance[1][i]=0

    heap=[(0,1,0)]

    while heap:
        weight,vertex,pave=heapq.heappop(heap)

        if distance[vertex][pave]<weight:
            continue
        if vertex==v:
            result=min(result,weight)
            continue
            
        for adj_vertex,weight in graph[vertex]:
            cost=distance[vertex][pave]+weight
            if distance[adj_vertex][pave] > cost:
                distance[adj_vertex][pave]=cost
                heapq.heappush(heap,(cost,adj_vertex,pave))

            if pave+1<=k and distance[adj_vertex][pave+1]>distance[vertex][pave]:
                distance[adj_vertex][pave+1]=distance[vertex][pave]
                heapq.heappush(heap,(distance[vertex][pave],adj_vertex,pave+1))

    

        
if __name__ == "__main__":
    v,m,k=map(int,input().split())
    graph=[[] for _ in range(v+1)]
    result=inf
    for _ in range(m):
        v1,v2,weight=map(int,input().split())
        graph[v1].append((v2,weight))
        graph[v2].append((v1,weight))

    solution()
    print(result)

댓글남기기