10989번: 수 정렬하기 3

첫째 줄에 수의 개수 N(1 ≤ N ≤ 10,000,000)이 주어진다. 둘째 줄부터 N개의 줄에는 수가 주어진다. 이 수는 10,000보다 작거나 같은 자연수이다.

www.acmicpc.net


카운팅 정렬에 대해서는 다음 블로그 참고

https://elrion018.tistory.com/37

 

카운팅 정렬은 non-comparison sort 기법으로 두 수를 반복적으로 비교하는 정렬 기법이 아니다.

 

counting sort

1. input 배열 요소들의 빈도 값을 세어서 counting array에 저장

counting array의 각 요소 값은 인덱스에 해당하는 요소가 input 배열에 얼마나 존재하는지를 의미한다.

 

2. counting array의 각 요소에 직전 요소를 더해서 업데이트, input 배열을 정렬해서 담을 output 배열 생성

 

3. input 배열의 요소를 역순으로 output배열에 채워 넣음

counting 배열을 참조하되 참조 후에는 값을 하나씩 빼준다.

 

import sys
n = int(input())
input = [int(sys.stdin.readline()) for _ in range(n)]

cnt = [0] * (max(input) + 1)
for n in input:
  if cnt[n] == 0:
    cnt[n] = input.count(n)

for i in range(1, len(cnt)):
  cnt[i] += cnt[i - 1]

output = [-1] * len(input)

for n in reversed(input):
  output[cnt[n] - 1] = n
  cnt[n] -= 1

for n in output: print(n)

위 블로그를 참고하여 짠 카운팅 정렬 코드

근데 이대로 제출하면 메모리 초과 오류가 난다. 

 

이유는 아래 링크 참고

 

글 읽기 - ★☆★☆★ [필독] 수 정렬하기 3 FAQ ★☆★☆★

댓글을 작성하려면 로그인해야 합니다.

www.acmicpc.net

간단하게 요약하자면 문제의 메모리 제한은 겨우 8MB인데 모든 입력을 배열에 저장하면 당연히 메모리 초과라는 이야기다. 심지어 작성한 코드에서는 입력만 저장한 게 아니라 정렬 후 결과를 저장한 배열까지 있으니 메모리 차지를 많이 할 수 밖에 없다. 

 

그래서 메모리 초과없이 짜기 위해서 입력 배열과 출력 배열을 만들지 않기로 했다.

 

1. counting array에 바로 값 저장

주어지는 입력 값들은 모두 10,000보다 작거나 같은 자연수라고 명시되어 있기 때문에 cnt의 최대 크기는 10001이므로 다음과 같이 크기를 설정한다. 

import sys
n = int(input())

cnt = [0] * (10001)
for i in range(n):
  cnt[int(sys.stdin.readline())] += 1
더보기

정렬된 배열을 만들지 않을 것이기 때문에 사실 counting array의 각 요소에 직전 요소를 더해서 업데이트를 하는 과정은 필요가 없다. 그것도 모르고 한참 삽질을 해서 다음과 같은 코드를 짰는데 출력을 하다가 뭔가 이상하다는 것을 깨달았다.

for i in range(1, 10001):
  if cnt[i - 1] != 0 and cnt[i] != 0:
    cnt[i] += cnt[i - 1]
    tmp = cnt[i]
  elif cnt[i] != 0:
    cnt[i] += tmp
    tmp = cnt[i]

예제가 다음과 같다면

20
2
0
1
4

10
5
4
3
2
0

10
1
1
0

20
5
4
3
10
20

이런식으로 cnt에 저장되게 하는 코드다.

 

하지만 전혀 필요가 없는 코드였다는 사실~

 

2. 그냥 cnt에 저장되어 있는 수만큼 인덱스를 출력해주면 된다. 

for i in range(10001):
  if cnt[i] != 0:
    for _ in range(cnt[i]):
      print(i)

 

전체 코드

import sys
n = int(input())

cnt = [0] * (10001)
for i in range(n):
  cnt[int(sys.stdin.readline())] += 1
  
for i in range(10001):
  if cnt[i] != 0:
    for _ in range(cnt[i]):
      print(i)

 

복사했습니다!