출처 - https://www.acmicpc.net/problem/2213

트리에서 DP를 사용하는 문제였다. 독립집합의 크기를 구하는 것은 쉽게 생각할 수 있었다. 백준 트리DP 알고리즘 분류에서 가장 위에 있는 SNS 문제와 거의 유사했다. 경로를 구하기 위해서 따로 추적을 하는 함수를 작성했는데 간단하게 만드는 방법도 존재할 것 같다.

1. 풀이

  1. 최대값 구하기

트리의 독립집합에 포함될 노드를 가중치를 비교하며 선택해가면 된다. 가장 먼저 트리의 독립집합을 구하기 위해서 트리의 부모-자식 관계를 설정해야 한다.

임의의 노드를 루트 노드로 설정해도 트리의 특성은 유지된다. 즉 인접 리스트로 트리를 저장한 뒤 방문 표시를 하며 루트부터 순회하면 자식 노드를 차례대로(DFS) 모두 방문할 수 있다.

  • 부모 노드를 선택했다면 자식 노드는 선택할 수 없다. 부모 노드를 선택하지 않았다면 자식 노드는 선택’할 수’ 있다. (반드시 자식 노드를 선택한 경우가 큰 것은 아니다)
  • 즉 dp 테이블을 2차원으로 만들고 해당 노드를 선택한 경우에 누적된 독립집합의 크기와 선택하지 않은 경우의 크기를 모두 저장한다.
  1. 경로 구하기
    • dp 테이블에 2가지 경우가 저장되어 있으므로 이를 이용해서 추적한다.
    • 재귀함수를 이용하여 방문할 노드와 방문할 노드의 부모 노드가 선택되었는 지 여부를 전달한다. 해당 노드가 선택되었다면 check 배열을 통해 표시한다.
    • 만약 현재 노드를 선택한 경우에 독립집합의 크기가 크다면 방문하지 않은 자식 노드와 함께 재귀함수에 1을 전달하여 자식 노드는 선택할 수 없음을 표시한다. (ps=1일 때 무조건 check=0)
    • 만약 현재 노드를 선택하지 않은 경우에 독립집합의 크기가 크다면 재귀함수에 0을 전달한다.
    • 이 때 자식 노드가 입력값으로 함수가 실행될 때는 노드가 선택된 경우 혹은 선택 되지 않은 경우 비교가 가능하다. (ps=0일 때 두 가지 경우로 분기처리)

2. 코드

#트리의 독립집합
N = int(input())
W = [0]+ list(map(int, input().split()))
tree = [[] for _ in range(N+1)]
for _ in range(N-1):
    u,v = map(int,input().split())
    tree[u].append(v)
    tree[v].append(u)
    

dp=[ [0,0] for _ in range(N+1) ]
visited = [0 for _ in range(N+1)]
def dynamic(node):
    
    if visited[node] == 1:    
        return max(dp[node][0], dp[node][1])

    
    visited[node]=1
    dp[node][1]=W[node]
    
    for child in tree[node]:
        if visited[child]==0:
            dynamic(child)
            
            dp[node][0] += max(dp[child][1], dp[child][0])
            
            dp[node][1] += dp[child][0] 

    
    
    return max(dp[node][0], dp[node][1])

result = dynamic(1)

check = [0 for _ in range(N+1)]
visited = [0 for _ in range(N+1)]
def findPath(node,ps):
    if visited[node]==1:
        return    
    visited[node]=1
    
    if ps==1:
        
        check[node]=0
        for child in tree[node]:
            if visited[child]==0:
                findPath(child, 0)
    
    else:
        if dp[node][0]> dp[node][1]:
            check[node]=0

            for child in tree[node]:
                if visited[child]==0:
                    findPath(child,0)    
            
        else:
            check[node]=1

            for child in tree[node]:
                if visited[child]==0:
                    findPath(child,1)    

findPath(1, 0)

print(result)
for i in range(N+1):
    if check[i]==1:
        print(i, end=' ')

업데이트:

댓글남기기