木の分岐数

algorithm

頂点と親のインデックス , が与えられます. 根を1とした根付き木とした場合, 分岐数 を以下のように定めます.

  • 各頂点による部分木の最大深さが小さい方が分岐が1増える

このとき分岐数が0, 1の枝のみを出力し,

  • 分岐数が0のどれかをネストせずに出力し
  • それ以外の場合は入れ子のリスト
    にして出力してください.
    入力は 頂点数のあとに 行でインデックス, 親のインデックスが与えられます.

入力例1

// 1 - 2 - 4 - 7 - 8 (k = 0)
//   L 3 - 5 - 9 (k = 1)
//       L 6 (k = 2)
9
2 1
3 1
4 2
5 3
6 3
7 4
8 7
9 5

出力例1

[1, 2, [3, 5, 9], 4, 7, 8]

入力例2

// 1 - 2 - 4 - 5 - 7 (k = 0)
//   L 3 (1)     L 6 (0)
7
2 1
3 1
4 2
5 4
6 5
7 5

出力例2

同じ k が複数ある場合はどちらをメインにしても構いません

[1, 2, [3], 4, 5, 6, [7]]

解答例(Python)

from icecream import ic
from anytree import Node
 
def branches(leaf: Node):
    """ leaf の分岐数 """
    cnt = 0
    for i, anc in enumerate(leaf.iter_path_reverse()):
        if anc.height != i:
            cnt += 1
    return cnt
 
def prune(leaf: Node):
    """ leaf の枝をかる """
    n = leaf
    while len(n.siblings) == 0:
        n = n.parent
    n.parent = None
 
def flatten(root: Node):
    """ 兄弟がいないならまとめる """
    cs = root.children
    if len(cs) == 0:
        return [root]
    elif len(cs) == 1:
        return [root, *flatten(cs[0])]
    else:
        res = []
        l0, *l1s = sorted(cs, key=lambda n: branches(n), reverse=True)
        l0s = flatten(l0)
 
        # [root, l0[0], l1..., l0[1:]]
        res.append(root)
        res.append(l0s[0])
        for l1 in l1s:
            res.append(flatten(l1))
        res.extend(l0s[1:])
        return res
 
def flatten_by_lt_l1(root: Node):
    """ level1 以下を枝刈りしてまとめる """
    for leaf in root.leaves:
        if branches(leaf) > 1:
            prune(leaf)
    return flatten(root)
 
def construct_tree(n: int, links: list):
    """ 隣接情報から木を構築 """
    nodes = [Node(i + 1) for i in range(n)]
    for i, p in links:
        nodes[i - 1].parent = nodes[p - 1]
    return nodes
 
if __name__ == '__main__':
    # 入力
    n = int(input())
    links = []
    for i in range(n - 1):
        i, p = map(int, input().split())
        links.append((i, p))
 
    # 計算
    nodes = construct_tree(n, links)
    res = flatten_by_lt_l1(nodes[0])
 
    # 出力
    for e in res:
        if type(e) is list:
            print(list(map(lambda n: n.name, e)), end=' ')
        else:
            print(e.name, end=' ')