3 분 소요

[Gold II] 트리의 지름 - 1167

문제 링크

트리의 지름이란 💫

  • 트리에서 임의의 두 점 사이의 거리 중 가장 긴 것

구하는 방법 (?)

  1. 임의의 점 A에서 가장 먼 지점 B를 찾음

  2. B에서 가장 먼 지점 C를 찾음

  3. B ~ C의 거리가 트리의 지름

증명 🔓

어떤 임의의 점 A에서 가장 먼 곳에 있는 것은 트리의 지름 끝 점 중 하나이다.

귀류법

A에서 가장 먼 지점을 B, 트리의 지름 (D, E), B가 D나 E가 아니라 가정

image

  1. X는 (A, B)위의 노드, (A, X) >= (B, X)라고 가정
  2. Y는 (D, E)위의 노드, (D, Y) >= (E, Y)라고 가정

트리의 지름 : (D, E) 따라서 (D, E) >= (A, B)

  • (A, X) >= (A, B)/2 >= (B, X)
  • (D, Y) >= (D, E)/2 >= (E, Y)

따라서, (D, Y) >= (D, E)/2 >= (A, B)/2 >= (B, X) 따라서, (A, D) = (A, X) + (X, Y) + (Y, D) >= (A, X) + (X, Y) + (B, X) = (A, B) + (X, Y)

A에서 가장 먼 지점을 B라고 가정했지만, (A, B)보다 더 먼 (A, D)가 존재

따라서, 귀류법에 의하여 A에서 가장 먼 점 B가 D, E가 아니라는 것이 거짓이므로 B는 D, E중 하나 결국 A에서 가장 먼 지점은 트리의 지름 양 끝점 중 하나

분류 1️⃣

깊이 우선 탐색(dfs), 그래프 이론(graphs), 그래프 탐색(graph_traversal), 트리(trees)

문제 설명 2️⃣

트리의 지름이란, 트리에서 임의의 두 점 사이의 거리 중 가장 긴 것을 말한다. 트리의 지름을 구하는 프로그램을 작성하시오.

입력 3️⃣

트리가 입력으로 주어진다. 먼저 첫 번째 줄에서는 트리의 정점의 개수 V가 주어지고 (2 ≤ V ≤ 100,000)둘째 줄부터 V개의 줄에 걸쳐 간선의 정보가 다음과 같이 주어진다. 정점 번호는 1부터 V까지 매겨져 있다.

먼저 정점 번호가 주어지고, 이어서 연결된 간선의 정보를 의미하는 정수가 두 개씩 주어지는데, 하나는 정점번호, 다른 하나는 그 정점까지의 거리이다. 예를 들어 네 번째 줄의 경우 정점 3은 정점 1과 거리가 2인 간선으로 연결되어 있고, 정점 4와는 거리가 3인 간선으로 연결되어 있는 것을 보여준다. 각 줄의 마지막에는 -1이 입력으로 주어진다. 주어지는 거리는 모두 10,000 이하의 자연수이다.

출력 4️⃣

첫째 줄에 트리의 지름을 출력한다.



풀이 🔓

위의 트리의 지름 증명에 따르면 결국 임의의 노드에서 가장 긴 경로로 연결돼 있는 노드는 트리의 지름에 해당하는 두 노드 중 하나다. 따라서 임의의 노드 A에서 가장 긴 경로의 노드 B를 구하고, B는 트리의 지름의 양 끝 점중 하나 니깐 B에서 가장 긴 경로의 노드 C를 구하면 트리의 지름 (B, C)를 얻게 된다✌

구현 ✨

  1. 그래프를 인접 리스트로 저장하는데 이때 [노드, 가중치]로 인접리스트를 구성한다.
  2. 방문 체크 배열인 visited 배열을 선언하는데 이 배열은 DFS 또는 BFS를 수행하고 탐색할 때 각 노드의 거리를 배열에 저장한다.
  3. 2 과정에서 얻은 배열에서 가장 긴 거리를 가지는 노드를 뽑아서 다시 DFS 또는 BFS를 수행한다
  4. 트리의 지름을 얻을 수 있다.

코드(DFS) 📃

import sys

input = sys.stdin.readline

v = int(input())
graph = [[] for _ in range(v+1)]

for _ in range(v):
  tmp = list(map(int, input().split()))
  for i in range(1, len(tmp)-2, 2):
    graph[tmp[0]].append([tmp[i], tmp[i+1]])

visited = [-1]*(v+1)

def dfs(a, b):
  for e, d in graph[a]:
    if visited[e] == -1:
      visited[e] = d + b
      dfs(e, visited[e])


visited[1] = 0
dfs(1, visited[1])
start = visited.index(max(visited))
visited = [-1]*(v+1)
visited[start] = 0
dfs(start, visited[start])
print(max(visited))

코드(BFS) 📃


import sys
from collections import deque

input = sys.stdin.readline
v = int(input())
graph = [[] for _ in range(v+1)]

for _ in range(v):
  tmp = list(map(int, input().split()))
  for i in range(1, len(tmp)-2, 2):
    graph[tmp[0]].append([tmp[i], tmp[i+1]])

def bfs(start):
  visited = [-1]*(v+1)
  q = deque()
  q.append(start)
  visited[start] = 0
  _max = [0, 0] # distance, node

  while q:
    x = q.popleft()

    for e, d in graph[x]:
      if visited[e] == -1:
        visited[e] = d + visited[x]
        q.append(e)
        if visited[e] > _max[0]:
          _max = [visited[e], e]
  
  return _max

distance, vertex = bfs(1)
distance, vertex = bfs(vertex)
print(distance)

댓글남기기