가장 큰 정사각형 - 1915
[Gold II] 중앙값 구하기 - 2696
분류 📌
자료 구조(data_structures), 우선순위 큐(priority_queue)
문제 설명 📌
어떤 수열을 읽고, 홀수번째 수를 읽을 때 마다, 지금까지 입력받은 값의 중앙값을 출력하는 프로그램을 작성하시오.
예를 들어, 수열이 1, 5, 4, 3, 2 이면, 홀수번째 수는 1번째 수, 3번째 수, 5번째 수이고, 1번째 수를 읽었을 때 중앙값은 1, 3번째 수를 읽었을 때는 4, 5번째 수를 읽었을 때는 3이다.
입력 📌
첫째 줄에 테스트 케이스의 개수 T(1 ≤ T ≤ 1,000)가 주어진다. 각 테스트 케이스의 첫째 줄에는 수열의 크기 M(1 ≤ M ≤ 9999, M은 홀수)이 주어지고, 그 다음 줄부터 이 수열의 원소가 차례대로 주어진다. 원소는 한 줄에 10개씩 나누어져있고, 32비트 부호있는 정수이다.
출력 📌
각 테스트 케이스에 대해 첫째 줄에 출력하는 중앙값의 개수를 출력하고, 둘째 줄에는 홀수 번째 수를 읽을 때 마다 구한 중앙값을 차례대로 공백으로 구분하여 출력한다. 이때, 한 줄에 10개씩 출력해야 한다.
풀이 🏆
-
처음에는 직관적으로 비어 있는 배열 하나에 입력값을 2개씩 넣고 정렬하여 중앙값을 고르는 방법으로 풀었음
-
이는 테스트 케이스 T만큼의 반복분에 또 중간에 정렬을 해야하기 때문에 MlogM 즉 1000*9999log9999 이 기본으로 붙고 부가적인 시간 복잡도가 붙는다. 따라서 불안하니 다른 방법
-
우선순위 큐를 이용하여,
입력 받은 value를 중앙값보다 작다면 왼쪽 힙 max힙에 삽입, 중앙값보다 크다면 오른쪽 힙 min힙에 삽입
2개 씩 받을 때 마다 둘 중(min 힙, max 힙) 길이가 짧은 힙에 중앙값을 삽입하고 길이가 긴 힙에서 루트 부분을 꺼내면 새로운 중앙 값이 된다. -
힙에서 값을 넣고, 빼는 것은 O(logM)이고 이를 M번 반복하니깐(물론 한 쪽 힙에서 다른 쪽 이동 등 부가적인 거 빼고)
총 시간 복잡도는 O(MlogM)이다.
구현 🚀
- heap의 사용을 위해 heapq 모듈을 import함
- 중앙값 보다 작은 값을 저장 하기 위한 중앙값 왼쪽의 max heap과 큰 값을 저장 하기 위한 중앙값 오른쪽의 min heap을 각각 준비
- 입력 받은 value 들을 중앙값과 비교하여 맞는 위치의 heap에 넣는다
- 2번씩 입력 받을 때 마다 heap의 길이를 비교하여 더 긴 쪽에 중앙값이 존재하므로 기존에 알아 냈던 중앙값은 짧은 곳에 넣고 긴 곳에서 루트 값을 꺼내 중앙값이 됨
코드 📃
import heapq
import sys
input = sys.stdin.readline
t = int(input())
def check(data):
mid = data[0]
maxl = []
minr = []
result = []
result.append(mid)
# idx 는 1부터 시작
# heapq는 기본 적으로 min heap이기 때문에
# max heap을 구현 할려면 값을 부호를 바꿔서 넣고, 부호를 바꿔서 빼야함
for idx, val in enumerate(data[1:], 1):
if val < mid:
heapq.heappush(maxl, -val)
else:
heapq.heappush(minr, val)
if idx % 2 == 0:
if len(maxl) > len(minr):
heapq.heappush(minr, mid)
mid = -heapq.heappop(maxl)
elif len(maxl) < len(minr):
heapq.heappush(maxl, -mid)
mid = heapq.heappop(minr)
result.append(mid)
print(len(result))
for i in range(len(result)):
if (i+1) != 1 and (i+1) % 10 == 1:
print()
print(result[i], end=' ')
print()
for _ in range(t):
m = int(input())
data = []
if m % 10 == 0:
for _ in range(m//10):
data.extend(list(map(int, input().rstrip().split())))
else:
for _ in range(m//10 + 1):
data.extend(list(map(int, input().rstrip().split())))
check(data)
댓글남기기