継続渡しとコンパイル
実装はまだ先だが、継続渡しに変換してからコンパイルする方法がなんとなく分かった。
継続渡しへの変換
継続渡しへの変換は、簡単な例で示すと以下のようになる。
次の式を評価するためには、まず(g)を評価して、その戻り値を使って(f # x)を評価する。
(f (g) x)
つまり以下の式と等価。
((lambda (v) (f v x)) (g))
ここで関数gの引数を一つ増やして、1引数の関数kを取れるようにしたg_cpsを考える。この関数g_cpsは、本来のgの計算を行った後で、戻り値を返す代わりに渡された関数を呼び出す(呼ばれた関数から戻ってこないので、呼び出すというよりジャンプすると言った方が正しい)。
(g_cps (lambda (v) (f v x)))
このときの増えた引数が継続。
(擬似)バイトコードへのコンパイル
具体例で考える。
(define (fact n) (if (> n 1) (* n (fact (1- n))) 1))
この関数を継続渡しに変換すると、こうなる(途中省略)。
(define (fact_cps n k) (>_cps n 1 (lambda (v1) (if v1 (1-_cps n (lambda (v2) (fact_cps v2 (lambda (v3) (*_cps n v3 k))))) (k 1)))))
ところで、継続の中に自由変数(外の環境で定義された変数)があると後の処理で都合が悪いので、次のようにlambdaのパラメータに自由変数のリストを追加する(この処理をクロージャー変換というらしい)。
(define (fact_cps n k) (>_cps n 1 (lambda ([n k] v1) (if v1 (1-_cps n (lambda ([n k] v2) (fact_cps v2 (lambda ([n k] v3) (*_cps n v3 k))))) (k 1)))))
コンパイルする都合上、継続に名前を付ける。
(define (fact_cps n k) (>_cps n 1 [k1 n k])) (define (k1 n k v1) (if v1 (1-_cps n [k2 n k]) (k 1))) (define (k2 n k v2) (fact_cps v2 [k3 n k])) (define (k3 n k v3) (*_cps n v3 k))
すると次のようなバイトコードになるだろう。
fact_cps: pop k pop n push n push 1 push [k1 n k] jmp >_cps k1: pop v1 pop k pop n if v1 == #f jmp L1 push n push [k2 n k] jmp 1-_cps L1: push 1 jmp k k2: pop v2 pop k pop n push v2 push [k3 n k] jmp fact_cps k3: pop v3 pop k pop n push n push v3 push k jmp *_cps
ただし、push vはvの内容をスタックに積む、pop vはスタックからポップした内容をvに入れるという意味。[〜]はヒープにクロージャを作成する擬似コード。
これで例えば
(fact 3)
をコンパイルすると、
toplevel: push 3 push END jmp fact_cps END: pop ret
こんな感じになる。
追記
別にスタックマシンでなくてもいい。レジスタが無限にあると仮定すれば同じことだ。ついでに四則演算ぐらいはプリミティブで処理することにすると、
; (r1: n, r2: k) fact_cps: r0 := r1 - 1 r0 := r0 > 0 (r1, r2 r3) := (r1, r2, r0) jmp k1 ; (r1: n, r2: k, r3: v1) k1: if r3 == #f jmp L1 r0 := r1 - 1 (r1, r2, r3) := (r1, r2, r0) jump k2 L1: (r1) := (1) jmp r2 ; (r1: n, r2: k, r3: v2) k2: (r1, r2) := (r3, [k3 r1 r2]) jmp fact_cps ; (r1: n, r2: k, r3: v3) k3: r0 := r1 * r3 (r1) := (r0) jmp r2 toplevel: (r1, r2) := (3, END) jmp fact_cps END: ; r1を使って何かをする
だいぶ簡単になった。これはさらに最適化できて、
fact_cps: r3 := r1 - 1 if (r3 <= 0) jmp L1 r2 := [k3 r1 r2] r1 := r3 jmp fact_cps L1: r1 := 1 jmp r2 k3: r1 := r1 * r3 jmp r2 toplevel: r1 := 3 r2 := END jmp fact_cps END:
さすがにここまで自動的にやるのは厳しいかな。