少年易酔學難成

IT/技術的な話題について書きます

二分木の畳み込みを末尾再帰化したい人生だった

※ この記事はアルゴリズムに詳しくない人が自分の思考の記録のために書いた記事です。 なので書いた内容が合っているかは確証が持てません。 「大間違いだよ、馬鹿野郎!」というツッコミは大歓迎します。

突然どうした

現在、自分はプロジェクトでScalaを使っている。 そのプロジェクトでは毎週木曜日に1時間FPinScalaの勉強会をやっているのだけど、現在読んでいる3章は前にも解いたことがあり、ただ同じように解いてもつまらないのでなるべく末尾再帰の形にするようにしている。

昨日はexercise3.25を解いていた。 この問題は二分木のノードの数をカウントするsize関数を実装せよ、という問題だ。 ここで用いる二分木の定義は以下の通り。

sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

で、size関数は普通に書けば↓のように書ける。

def size[A](tree: Tree[A]): Int = tree match {
  case Leaf(_)      => 1
  case Branch(l, r) => size(l) + 1 + size(r)
}

これの末尾再帰化は二分木だし結構時間かかるのでは…と思ったら、意外にもすぐに末尾再帰化できた。

def size[A](tree: Tree[A]): Int = {
  // ListはFPinScalaで定義した自前のものを使っている
  @annotation.tailrec
  def go[A](tree: Tree[A], acc: Int, notCalculated: List[Tree[A]]): Int = tree match {
    case Leaf(_) => notCalculated match {
      case Nil        => acc + 1
      case Cons(h, t) => go(h, acc + 1, t)
    }
    case Branch(l, r) => go(l, acc + 1, Cons(r, notCalculated))
  }
  go(tree, 0, Nil)
}

二分木を扱う関数で末尾再帰化するイメージは自分にはなかったので、とても感動した。 この例では、未計算のツリーのリストを再帰で渡してやることで難なく末尾再帰化できるのだ。

もう少し突き詰めると、スタックの代わりとなるデータ型をうまく考えてやれば色々なものを末尾再帰化できるようになるのではないか。 例えば、二分木のfoldやmap関数だ。

そして試行錯誤が始まった

元々書いていたfold関数の回答は↓の通り。

def fold[A, B](a: Tree[A])(f: A => B)(g: (B, B) => B): B = a match {
  case Branch(l, r) => g(fold(l)(f)(g), fold(r)(f)(g))
  case Leaf(v)      => f(v)
}

これをまずはこのように修正した。

def fold[A, B](tree: Tree[A])(f: A => B)(g: (B, B) => B): B = {
  @annotation.tailrec
  def go(tree: Tree[A], acc: Option[B], notCalculated: List[Tree[A]])(f: A => B)(g: (B, B) => B): B = tree match {
    case Leaf(v) => notCalculated match {
      case Nil        => acc.map(g(_, f(v))).getOrElse(f(v))
      case Cons(h, t) => go(h, acc.map(g(_, f(v))).orElse(Some(f(v))), t)(f)(g)
    }
    case Branch(l, r) => go(l, acc, Cons(r, notCalculated))(f)(g)
  }
  go(tree, None, Nil)(f)(g)
}

が、これは実装の途中で悪手だと気づいた。 これは木の左側からLeafを見つけ次第たたみ込んでいるので元々のfold関数と評価の順番が変わってしまっている。 畳み込み関数gの実装によっては出力される答えが変わってしまうだろう。

ここで元々のfold関数の評価順を守ろうとすると別の問題が出てくる。 再帰呼び出しで渡すデータは未計算のツリーだけではなく左側のLeafで計算した値も必要になるのだ。

この問題の解決方法はなかなか思い浮かばなかった。

途中、notCalculated: List[Tree[A]]の他に、別の引数として左側のリーフを計算した値leftAcc: Option[B]も一緒に渡す事も検討したが、 処理するLeafが左側であるのか右側であるのかわからないのでうまく行かない。

ここで初めてfoldを実装するには関数がツリーの構造を知る必要があるのだと気づいた。

そこでnotCalculated: Tree[A]ではなく(Option[B], List[Tree[A]])のようなデータを渡したらどうかとも考えたが、データ構造がやや複雑に思えたので断念し、他に良いアイデアも思い浮かばなかったのでこの日は諦めた。

試行錯誤二日目

昨晩考えても答えが出なかったので、ちょっと一旦PCを離れて紙と鉛筆で考えてみることにした。

その結果、次の再帰に渡したいデータは、右のLeafの場合も左のLeafの場合も計算の小計だがBranchを処理する場合は未計算のツリーであることに気づいた。 そこで次の再帰に渡したいデータをstack: List[Either[B, Tree[A]]]に格納し、stackに格納されたデータ型によって関数が元のツリーの構造を判断するようにした。

それがこちらの実装。

def fold[A, B](tree: Tree[A])(f: A => B)(g: (B, B) => B): B = {
  @annotation.tailrec
  def go(tree: Tree[A], stack: List[Either[B, Tree[A]]])(f: A => B)(g: (B, B) => B): B = tree match {
    case Leaf(v) => stack match {
      case Cons(Right(next), rest) => go(next, Cons(Left(f(v)), rest))(f)(g) // 左枝の場合
      case Cons(Left(leftAcc), Cons(Right(next), rest)) => go(next, Cons(Left(g(leftAcc, f(v))), rest))(f)(g) // 右枝の場合
      case Cons(Left(leftAcc), Cons(Left(acc), _)) => g(acc, g(leftAcc, f(v))) // 右端のLeaf到達パターン1
      case Cons(Left(leftAcc), Nil) => g(leftAcc, f(v)) // 右端のLeaf到達パターン2
      case Nil => f(v) // この場合はLeafのみのTree
    }
    case Branch(l, r) => go(l, Cons(Right(r), stack))(f)(g)
  }
  go(tree, Nil)(f)(g)
}

自分はこれを書いた後、とても後悔することになった。 よく確かめもせず周囲に完成したと吹聴してしまったからだ。

これは全く正しく動作しない。

右側に偏った木をfoldしようとすると、未計算の木がstack奥深くに格納されてしまうのでパターンマッチで拾えず、未計算の木を残したまま計算が終了してしまうからだ。

これを解消するために、常に未計算の木が先頭に来るまで計算を進めることにした。

def fold[A, B](tree: Tree[A])(f: A => B)(g: (B, B) => B): B = {
  @annotation.tailrec
  def eval(stack: List[Either[B, Tree[A]]]): List[Either[B, Tree[A]]] = stack match {
    case Cons(Left(e1), Cons(Left(e2), rest))   => eval(Cons(Left(g(e2, e1)), rest))
    case Cons(Left(e1), Cons(r@Right(_), rest)) => Cons(r, Cons(Left(e1), rest)) // 呼出元で木を取り出しやすいよう入れ替え
    case _                                      => stack
  }
  @annotation.tailrec
  def go(tree: Tree[A], stack: List[Either[B, Tree[A]]]): B = tree match {
    case Leaf(v) =>
      // 計算済みの値がstackに連続して溜まると未計算の木が取り出せないので出来るだけ計算を進める
      eval(Cons(Left(f(v)), stack)) match {
        case Cons(Right(next), rest) => go(next, rest)
        case Cons(Left(e1), _)       => e1 // 最後まで評価が終わっているはず
        case Nil                     => f(v)
      }
    case Branch(l, r) => go(l, Cons(Right(r), stack))
  }
  go(tree, Nil)
}

これで、正しく動くようになった。(たぶん)

さてfoldさえできれば、mapを実装することは簡単だ。

def map[A, B](tree: Tree[A])(f: A => B): Tree[B] = {
  fold(tree)(v => Leaf(f(v)): Tree[B])(Branch(_, _))
}

パフォーマンスなどの議論は詳しくないから抜きにして、とりあえず二分木の畳み込みを末尾再帰化するという目的は達成できたように思える。