문제 출처 - https://www.acmicpc.net/problem/2263

문제를 처음봤을 때 접근 방법이 도저히 떠오르지 않아서 다른 문제를 풀다가 다시 시도했다. 트리의 노드를 여러 개 그려서 inorder과 postorder을 각각 써보면 규칙을 발견할 수 있다.

각각의 순회는 재귀적으로 표현할 수 있다.

inorder - 왼쪽 서브 트리 / 루트 / 오른쪽 서브 트리 → 여기서 각각의 서브 트리도 다시 inorder이다. 즉 다음과 같이 표현할 수 있다.

왼쪽 서브트리의-왼쪽 서브 트리 / 왼쪽 서브 트리의 루트 / 왼쪽 서브트리의-오른쪽 서브 트리 // 루트 // 오른쪽 서브트리의-왼쪽 서브 트리 / 오른쪽 서브 트리의 루트 / 오른쪽 서브트리의-오른쪽 서브 트리

postorder- 왼쪽 서브 트리 / 오른쪽 서브 트리 / 루트 → 여기서 마찬가지이다.

왼쪽 서브트리의-왼쪽 서브 트리 / 왼쪽 서브트리의-오른쪽 서브 트리 / 왼쪽 서브 트리의 루트 / 오른쪽 서브트리의-왼쪽 서브 트리 / 오른쪽 서브트리의-오른쪽 서브 트리 / 오른쪽 서브 트리의 루트 / 루트

규칙이 보인다. 주어진 postorder에서 각각의 서브트리 가장 오른쪽에는 루트 노드의 번호가 적혀있다.

1. 풀이

  • postorder에서 루트의 번호를 알아낸다.
  • 루트의 번호를 이용하여 inorder에서 루트의 위치(인덱스)를 확인한다.
  • inorder에서 루트는 왼쪽 서브 트리와 오른쪽 서브 트리 중간에 위치하므로 인덱스를 이용해 왼쪽 및 오른쪽 서브 트리의 개수를 구할 수 있다.
  • 왼쪽 및 오른쪽 서브 트리의 개수를 이용하여 postorder에서 각각의 서브 트리의 마지막 노드의 위치를 알 수 있다.
  • postorder에서 서브 트리의 마지막 노드의 위치는 다시 서브 트리의 루트의 번호다.

정리하자면, 루트의 번호를 찾고(post) → 왼쪽 및 오른쪽 서브트리의 길이를 찾는다(in) → 왼쪽 및 오른쪽 서브트리의 루트의 번호를 찾는다. (post) → order의 길이가 1일 때까지 반복.

위 과정에서 보면 가장 위의 루트 → 왼쪽 서브트리의 루트 → 오른쪽 서브트리의 루트 순서이므로 루트를 찾는대로 출력하면 preorder이다.

2-1. 코드 (메모리 초과)

아래와 같이 루트 기준 왼쪽 오른쪽 서브트리의 inorder과 postorder를 재귀로 넘기며 풀었다. 하지만 이렇게 풀었더니 메모리 초과가 발생했다. 서브트리 전체를 넘기지 않고 inorder과 postorder의 인덱스를 넘기며 풀 수 있다.

N = int(input())
order1 = list(map(int, input().split()))
order2 = list(map(int, input().split()))

tmp=[]
def find(inorder, postorder):
    if len(inorder)==1:
        tmp.append(inorder[0])
        return

    root = postorder[-1]
    tmp.append(root)

    #inorder 배열에서 root의 idx
    in_root_idx = inorder.index(root)

    #왼쪽 서브트리 개수
    left_num = in_root_idx
    right_num = len(inorder)-in_root_idx-1    

    if left_num!=0:
        in_left_subtree = inorder[:in_root_idx]
        post_left_subtree=postorder[:left_num]
        find(in_left_subtree, post_left_subtree)
    
    if right_num!=0:
        post_right_subtree = postorder[left_num:-1]
        in_right_subtree = inorder[in_root_idx+1:]
        find(in_right_subtree, post_right_subtree)
        
    return

find(order1,order2)
print(*tmp)

2-2. 코드(정답)

마찬가지로 개수를 이용한다. 다만 inorder과 postorder의 시작점과 끝점을 재귀로 전달한다.

import sys
sys.setrecursionlimit(10**4)
N = int(input())
inorder = list(map(int, input().split()))
postorder = list(map(int, input().split()))

tmp=[]
def find(ins,ine,posts,poste):
    if ins==ine:
        tmp.append(inorder[ins])
        return
    elif ins>ine or posts>poste:
        return
    
    root = postorder[poste]
    tmp.append(root)
    #inorder 배열에서 root의 idx
    in_root_idx = inorder.index(root)
    l_num = in_root_idx-ins
    r_num = ine-in_root_idx
    
    #왼쪽은 같아.
    find(ins, in_root_idx-1, posts, posts+l_num-1)
    #오른쪽은
    find(in_root_idx+1,ine , posts+l_num , poste-1)
    return

find(0,N-1,0,N-1)
print(*tmp)

업데이트:

댓글남기기