2 분 소요

[Gold II] 중앙값 구하기 - 2696

문제 링크

분류 📌

자료 구조(data_structures), 우선순위 큐(priority_queue)

문제 설명 📌

어떤 수열을 읽고, 홀수번째 수를 읽을 때 마다, 지금까지 입력받은 값의 중앙값을 출력하는 프로그램을 작성하시오.

예를 들어, 수열이 1, 5, 4, 3, 2 이면, 홀수번째 수는 1번째 수, 3번째 수, 5번째 수이고, 1번째 수를 읽었을 때 중앙값은 1, 3번째 수를 읽었을 때는 4, 5번째 수를 읽었을 때는 3이다.

입력 📌

첫째 줄에 테스트 케이스의 개수 T(1 ≤ T ≤ 1,000)가 주어진다. 각 테스트 케이스의 첫째 줄에는 수열의 크기 M(1 ≤ M ≤ 9999, M은 홀수)이 주어지고, 그 다음 줄부터 이 수열의 원소가 차례대로 주어진다. 원소는 한 줄에 10개씩 나누어져있고, 32비트 부호있는 정수이다.

출력 📌

각 테스트 케이스에 대해 첫째 줄에 출력하는 중앙값의 개수를 출력하고, 둘째 줄에는 홀수 번째 수를 읽을 때 마다 구한 중앙값을 차례대로 공백으로 구분하여 출력한다. 이때, 한 줄에 10개씩 출력해야 한다.

풀이 🏆

  • 처음에는 직관적으로 비어 있는 배열 하나에 입력값을 2개씩 넣고 정렬하여 중앙값을 고르는 방법으로 풀었음

  • 이는 테스트 케이스 T만큼의 반복분에 또 중간에 정렬을 해야하기 때문에 MlogM 즉 1000*9999log9999 이 기본으로 붙고 부가적인 시간 복잡도가 붙는다. 따라서 불안하니 다른 방법

  • 우선순위 큐를 이용하여, 입력 받은 value를 중앙값보다 작다면 왼쪽 힙 max힙에 삽입, 중앙값보다 크다면 오른쪽 힙 min힙에 삽입 2개 씩 받을 때 마다 둘 중(min 힙, max 힙) 길이가 짧은 힙에 중앙값을 삽입하고 길이가 긴 힙에서 루트 부분을 꺼내면 새로운 중앙 값이 된다.

  • 힙에서 값을 넣고, 빼는 것은 O(logM)이고 이를 M번 반복하니깐(물론 한 쪽 힙에서 다른 쪽 이동 등 부가적인 거 빼고) 총 시간 복잡도는 O(MlogM)이다.

구현 🚀

  • heap의 사용을 위해 heapq 모듈을 import함
  • 중앙값 보다 작은 값을 저장 하기 위한 중앙값 왼쪽의 max heap과 큰 값을 저장 하기 위한 중앙값 오른쪽의 min heap을 각각 준비
  • 입력 받은 value 들을 중앙값과 비교하여 맞는 위치의 heap에 넣는다
  • 2번씩 입력 받을 때 마다 heap의 길이를 비교하여 더 긴 쪽에 중앙값이 존재하므로 기존에 알아 냈던 중앙값은 짧은 곳에 넣고 긴 곳에서 루트 값을 꺼내 중앙값이 됨

코드 📃



import heapq
import sys

input = sys.stdin.readline
t = int(input())


def check(data):
    mid = data[0]
    maxl = []
    minr = []
    result = []
    result.append(mid)
    # idx 는 1부터 시작
    # heapq는 기본 적으로 min heap이기 때문에
    # max heap을 구현 할려면 값을 부호를 바꿔서 넣고, 부호를 바꿔서 빼야함
    for idx, val in enumerate(data[1:], 1):
        if val < mid:
            heapq.heappush(maxl, -val)
        else:
            heapq.heappush(minr, val)
        if idx % 2 == 0:
            if len(maxl) > len(minr):
                heapq.heappush(minr, mid)
                mid = -heapq.heappop(maxl)
            elif len(maxl) < len(minr):
                heapq.heappush(maxl, -mid)
                mid = heapq.heappop(minr)
            result.append(mid)
    
    print(len(result))
    for i in range(len(result)):
        if (i+1) != 1 and (i+1) % 10 == 1:
            print()
        print(result[i], end=' ')
    print()


for _ in range(t):
    m = int(input())
    data = []
    if m % 10 == 0:
        for _ in range(m//10):
            data.extend(list(map(int, input().rstrip().split())))
    else:
        for _ in range(m//10 + 1):
            data.extend(list(map(int, input().rstrip().split())))
    check(data)

        

댓글남기기