継続渡しとコンパイル

実装はまだ先だが、継続渡しに変換してからコンパイルする方法がなんとなく分かった。

継続渡しへの変換

継続渡しへの変換は、簡単な例で示すと以下のようになる。
次の式を評価するためには、まず(g)を評価して、その戻り値を使って(f # x)を評価する。

(f (g) x)

つまり以下の式と等価。

((lambda (v) (f v x)) (g))

ここで関数gの引数を一つ増やして、1引数の関数kを取れるようにしたg_cpsを考える。この関数g_cpsは、本来のgの計算を行った後で、戻り値を返す代わりに渡された関数を呼び出す(呼ばれた関数から戻ってこないので、呼び出すというよりジャンプすると言った方が正しい)。

(g_cps (lambda (v) (f v x)))

このときの増えた引数が継続。

(擬似)バイトコードへのコンパイル

具体例で考える。

(define (fact n)
  (if (> n 1)
      (* n (fact (1- n)))
      1))

この関数を継続渡しに変換すると、こうなる(途中省略)。

(define (fact_cps n k)
  (>_cps n 1 (lambda (v1)
               (if v1
                   (1-_cps n (lambda (v2)
                               (fact_cps v2 (lambda (v3)
                                              (*_cps n v3 k)))))
                   (k 1)))))

ところで、継続の中に自由変数(外の環境で定義された変数)があると後の処理で都合が悪いので、次のようにlambdaのパラメータに自由変数のリストを追加する(この処理をクロージャー変換というらしい)。

(define (fact_cps n k)
  (>_cps n 1 (lambda ([n k] v1)
               (if v1
                   (1-_cps n (lambda ([n k] v2)
                               (fact_cps v2 (lambda ([n k] v3)
                                              (*_cps n v3 k)))))
                   (k 1)))))

コンパイルする都合上、継続に名前を付ける。

(define (fact_cps n k)
  (>_cps n 1 [k1 n k]))

(define (k1 n k v1)
  (if v1
      (1-_cps n [k2 n k])
      (k 1)))

(define (k2 n k v2)
  (fact_cps v2 [k3 n k]))

(define (k3 n k v3)
  (*_cps n v3 k))

すると次のようなバイトコードになるだろう。

fact_cps:
        pop k
        pop n
        push n
        push 1
        push [k1 n k]
        jmp >_cps

k1:
        pop v1
        pop k
        pop n
        if v1 == #f jmp L1
        push n
        push [k2 n k]
        jmp 1-_cps

L1:
        push 1
        jmp k

k2:
        pop v2
        pop k
        pop n
        push v2
        push [k3 n k]
        jmp fact_cps

k3:
        pop v3
        pop k
        pop n
        push n
        push v3
        push k
        jmp *_cps

ただし、push vはvの内容をスタックに積む、pop vはスタックからポップした内容をvに入れるという意味。[〜]はヒープにクロージャを作成する擬似コード

これで例えば

(fact 3)

コンパイルすると、

toplevel:
        push 3
        push END
        jmp fact_cps

END:
        pop ret

こんな感じになる。

追記

別にスタックマシンでなくてもいい。レジスタが無限にあると仮定すれば同じことだ。ついでに四則演算ぐらいはプリミティブで処理することにすると、

; (r1: n, r2: k)
fact_cps:
        r0 := r1 - 1
        r0 := r0 > 0
        (r1, r2 r3) := (r1, r2, r0)
        jmp k1

; (r1: n, r2: k, r3: v1)
k1:
        if r3 == #f jmp L1
        r0 := r1 - 1
        (r1, r2, r3) := (r1, r2, r0)
        jump k2

L1:
        (r1) := (1)
        jmp r2

; (r1: n, r2: k, r3: v2)
k2:
        (r1, r2) := (r3, [k3 r1 r2])
        jmp fact_cps

; (r1: n, r2: k, r3: v3)
k3:
        r0 := r1 * r3
        (r1) := (r0)
        jmp r2

toplevel:
        (r1, r2) := (3, END)
        jmp fact_cps

END:
        ; r1を使って何かをする

だいぶ簡単になった。これはさらに最適化できて、

fact_cps:
        r3 := r1 - 1
        if (r3 <= 0) jmp L1
        r2 := [k3 r1 r2]
        r1 := r3
        jmp fact_cps

L1:
        r1 := 1
        jmp r2

k3:
        r1 := r1 * r3
        jmp r2

toplevel:
        r1 := 3
        r2 := END
        jmp fact_cps

END:

さすがにここまで自動的にやるのは厳しいかな。