트리의 지름 - 1167
[Gold II] 트리의 지름 - 1167
트리의 지름이란 💫
- 트리에서 임의의 두 점 사이의 거리 중 가장 긴 것
구하는 방법 (?)
-
임의의 점 A에서 가장 먼 지점 B를 찾음
-
B에서 가장 먼 지점 C를 찾음
-
B ~ C의 거리가 트리의 지름
증명 🔓
어떤 임의의 점 A에서 가장 먼 곳에 있는 것은 트리의 지름 끝 점 중 하나이다.
귀류법
A에서 가장 먼 지점을 B, 트리의 지름 (D, E), B가 D나 E가 아니라 가정
- X는 (A, B)위의 노드, (A, X) >= (B, X)라고 가정
- Y는 (D, E)위의 노드, (D, Y) >= (E, Y)라고 가정
트리의 지름 : (D, E) 따라서 (D, E) >= (A, B)
- (A, X) >= (A, B)/2 >= (B, X)
- (D, Y) >= (D, E)/2 >= (E, Y)
따라서, (D, Y) >= (D, E)/2 >= (A, B)/2 >= (B, X) 따라서, (A, D) = (A, X) + (X, Y) + (Y, D) >= (A, X) + (X, Y) + (B, X) = (A, B) + (X, Y)
A에서 가장 먼 지점을 B라고 가정했지만, (A, B)보다 더 먼 (A, D)가 존재
따라서, 귀류법에 의하여 A에서 가장 먼 점 B가 D, E가 아니라는 것이 거짓이므로 B는 D, E중 하나 결국 A에서 가장 먼 지점은 트리의 지름 양 끝점 중 하나
분류 1️⃣
깊이 우선 탐색(dfs), 그래프 이론(graphs), 그래프 탐색(graph_traversal), 트리(trees)
문제 설명 2️⃣
트리의 지름이란, 트리에서 임의의 두 점 사이의 거리 중 가장 긴 것을 말한다. 트리의 지름을 구하는 프로그램을 작성하시오.
입력 3️⃣
트리가 입력으로 주어진다. 먼저 첫 번째 줄에서는 트리의 정점의 개수 V가 주어지고 (2 ≤ V ≤ 100,000)둘째 줄부터 V개의 줄에 걸쳐 간선의 정보가 다음과 같이 주어진다. 정점 번호는 1부터 V까지 매겨져 있다.
먼저 정점 번호가 주어지고, 이어서 연결된 간선의 정보를 의미하는 정수가 두 개씩 주어지는데, 하나는 정점번호, 다른 하나는 그 정점까지의 거리이다. 예를 들어 네 번째 줄의 경우 정점 3은 정점 1과 거리가 2인 간선으로 연결되어 있고, 정점 4와는 거리가 3인 간선으로 연결되어 있는 것을 보여준다. 각 줄의 마지막에는 -1이 입력으로 주어진다. 주어지는 거리는 모두 10,000 이하의 자연수이다.
출력 4️⃣
첫째 줄에 트리의 지름을 출력한다.
풀이 🔓
위의 트리의 지름 증명에 따르면 결국 임의의 노드에서 가장 긴 경로로 연결돼 있는 노드는 트리의 지름에 해당하는 두 노드 중 하나다. 따라서 임의의 노드 A에서 가장 긴 경로의 노드 B를 구하고, B는 트리의 지름의 양 끝 점중 하나 니깐 B에서 가장 긴 경로의 노드 C를 구하면 트리의 지름 (B, C)를 얻게 된다✌
구현 ✨
- 그래프를 인접 리스트로 저장하는데 이때 [노드, 가중치]로 인접리스트를 구성한다.
- 방문 체크 배열인 visited 배열을 선언하는데 이 배열은 DFS 또는 BFS를 수행하고 탐색할 때 각 노드의 거리를 배열에 저장한다.
- 2 과정에서 얻은 배열에서 가장 긴 거리를 가지는 노드를 뽑아서 다시 DFS 또는 BFS를 수행한다
- 트리의 지름을 얻을 수 있다.
코드(DFS) 📃
import sys
input = sys.stdin.readline
v = int(input())
graph = [[] for _ in range(v+1)]
for _ in range(v):
tmp = list(map(int, input().split()))
for i in range(1, len(tmp)-2, 2):
graph[tmp[0]].append([tmp[i], tmp[i+1]])
visited = [-1]*(v+1)
def dfs(a, b):
for e, d in graph[a]:
if visited[e] == -1:
visited[e] = d + b
dfs(e, visited[e])
visited[1] = 0
dfs(1, visited[1])
start = visited.index(max(visited))
visited = [-1]*(v+1)
visited[start] = 0
dfs(start, visited[start])
print(max(visited))
코드(BFS) 📃
import sys
from collections import deque
input = sys.stdin.readline
v = int(input())
graph = [[] for _ in range(v+1)]
for _ in range(v):
tmp = list(map(int, input().split()))
for i in range(1, len(tmp)-2, 2):
graph[tmp[0]].append([tmp[i], tmp[i+1]])
def bfs(start):
visited = [-1]*(v+1)
q = deque()
q.append(start)
visited[start] = 0
_max = [0, 0] # distance, node
while q:
x = q.popleft()
for e, d in graph[x]:
if visited[e] == -1:
visited[e] = d + visited[x]
q.append(e)
if visited[e] > _max[0]:
_max = [visited[e], e]
return _max
distance, vertex = bfs(1)
distance, vertex = bfs(vertex)
print(distance)
댓글남기기