egg tutorial

01 background

Why any of this?

プログラミング言語のツールを作ろうとするとき, 項に対して

  • 等価なより良い項に変換 (optimization)
  • ある仕様に沿って項を作成 (synthesis)
  • 項の等価性の判定 (verification)

Term Rewriting は一般的な方法, matchしてrewriteする.
問題点として選択の問題がある. 書き換えは破壊的な変更.
やり直すのも大変だし, 正しい書き換えを見つけるのはとてもむずかしい
good locally でも wrong globally な時がある

ex. C optimizer は a * 2a << 1 にするが, (a * 2) / 2 の時, greedyにやると (a << 1) / 2 になってしまって, a に最適化出来ない.
軽減策(mitigate)として, backtrackingと global value numbering あるが,
全てを考慮しながらやるという意味で, 選択肢が指数的に爆発する.
E-graphはこれを解決する. 候補は似たような構造をしていることが多くそれらをコンパクトに格納できる. 指数爆発なしに多くの書き換えを同時に適用できる.

What’s an e-graph

式の等価な表現を保持するデータ構造. e-graphe-class の集合.
e-class には等価な e-node が入っている.
e-node は子に対する演算, 子は項の代わりに e-class.
egg では e-nodeLanguage 型.

下図で, e-class は点線. 辺は e-node -> e-class に繋がっていることに注意.

e-graphe-matching (Searcher API -> Pattern API)によってパターを問い合わせる. e-graph を変更する操作は

  • add: e-nodee-graph に追加する
  • union: 2つの e-class をマージする

e-graph 上で の書き換えを行うには

  1. を検索し, を得る
  2. e-graph に項 を追加し, マッチした e-class と union する

ex.

  1. 最初, (a * 2) / 2 を表している. 全ての e-class は1つの e-node を持っているのでASTとなっている(2は複製せずに共有している)
  2. 書き換え x * 2 -> x << 1 は等式 x * 2 == x << 1 として記録される. (= 新たに追加された a << 1a を使うし, * と同じ e-class に結合されている)
  3. 結合則に基づいた書き換え (x * y) / z -> z * (y / z) を適用.
    1. (x * y) / z(a * 2) / 2 にマッチし, a * (2 / 2) に相当する e-class を結合.
  4. 書き換え x / x -> 1, 1 * x -> x を適用.
    1. x / x -> 1 に対応する結合は /1 を加える
    2. 1 * x -> x では *: {/, 1} aa を追加 -> *: *{a, \, *} になる, 自分自身の e-class を指しているのは a, a * 1, a * 1 * 1… という無限の集合に対する書き換えを表していることに相当する

Invariants and Rebuilding

addunion は2つの不変性を保つ変換.

  1. Congruence(合同性): e-graph は等価関係で無く, 合同関係を管理するもの.(同値類みたいな). 例えば, 2つの e-class a + x, a + y があり, その後, xy の等価性が判明したら, 2つ e-class は合同なので統合する.
  2. Uniqueness of e-nodes: 合同な2つ別の e-class が存在しないこと.
    1. add の時は Hash consing
    2. union, rebuild の時はdedupが行われる

egg の実装では, 不変性を保つ処理は rebuild() でやっている.
Runner は書き換えの反復で rebuild() してくれる.

Equality Saturation

Ross Tateさんの論文 [1012.1802] Equality Saturation: A New Approach to Optimization.
プログラム最適化の手法で, Runner APIで実装されている.

fn equality_saturation(expr: Expression, rewrites: Vec<Rewrite>) -> Expression {
    let mut egraph = make_initial_egraph(expr);
 
    while !egraph.is_saturated_or_timeout() {
        let mut matches = vec![];
 
        // read-only phase, invariants are preserved
        for rw in rewrites {
            for (subst, eclass) in egraph.search(rw.lhs) {
                matches.push((rw, subst, eclass));
            }
        }
 
        // write-only phase, temporarily break invariants
        for (rw, subst, eclass) in matches {
            eclass2 = egraph.add(rw.rhs.subst(subst));
            egraph.union(eclass, eclass2);
        }
 
        // restore the invariants once per iteration
        egraph.rebuild();
    }
 
    return egraph.extract_best();
}

Saturation: 書き換えによって情報が追加されなくなった時に起こる ex. x + y -> y + x. 与えられた全ての書き換えから導出可能な全ての合同関係が符号化されていること.
Extraction: あるコスト関数に従って, e-class から1つの表現を選ぶ手続き, Extractor がこれをする.
Equality Saturation は可能な全ての表現を探索して, 最適な項を抽出すること.

tutorial 02 getting started

Language trait (e-node) を実装すれば使える.
SymbolLang は用意されている.

let my_expression: RecExpr<SymbolLang> = "(foo a b)".parse().unwrap();
println!("this is my expression {}", my_expression);
let my_enode = SymbolLang::new("bar", vec![]);

e-class(RecExpr) にも e-graph にも追加できる.
2回同じ構造を追加すると同一の ID が返る.

let mut expr = RecExpr::default();
let a = expr.add(SymbolLang::leaf("a"));
let b = expr.add(SymbolLang::leaf("b"));
let foo = expr.add(SymbolLang::new("foo", vec![a, b]));
 
// we can do the same thing with an EGraph
let mut egraph: EGraph<SymbolLang, ()> = Default::default();
let a = egraph.add(SymbolLang::leaf("a"));
let b = egraph.add(SymbolLang::leaf("b"));
let foo = egraph.add(SymbolLang::new("foo", vec![a, b]));
 
// we can also add RecExprs to an egraph
let foo2 = egraph.add_expr(&expr);
// note that if you add the same thing to an e-graph twice, you'll get back equivalent Ids
assert_eq!(foo, foo2);
// let's make an e-graph
let mut egraph: EGraph<SymbolLang, ()> = Default::default();
let a = egraph.add(SymbolLang::leaf("a"));
let b = egraph.add(SymbolLang::leaf("b"));
let foo = egraph.add(SymbolLang::new("foo", vec![a, b]));
 
// rebuild the e-graph since we modified it
egraph.rebuild();
 
// we can make Patterns by parsing, similar to RecExprs
// names preceded by ? are parsed as Pattern variables and will match anything
let pat: Pattern<SymbolLang> = "(foo ?x ?x)".parse().unwrap();
 
// since we use ?x twice, it must match the same thing,
// so this search will return nothing
let matches = pat.search(&egraph);
assert!(matches.is_empty());
 
egraph.union(a, b);
// recall that rebuild must be called to "see" the effects of adds or unions
egraph.rebuild();
 
// now we can find a match since a = b
let matches = pat.search(&egraph);
assert!(!matches.is_empty());

書き換えルールを簡単に記述するマクロが用意されている.

use egg::{*, rewrite as rw};
 
let rules: &[Rewrite<SymbolLang, ()>] = &[
    rw!("commute-add"; "(+ ?x ?y)" => "(+ ?y ?x)"),
    rw!("commute-mul"; "(* ?x ?y)" => "(* ?y ?x)"),
 
    rw!("add-0"; "(+ ?x 0)" => "?x"),
    rw!("mul-0"; "(* ?x 0)" => "0"),
    rw!("mul-1"; "(* ?x 1)" => "?x"),
];
 
// While it may look like we are working with numbers,
// SymbolLang stores everything as strings.
// We can make our own Language later to work with other types.
let start = "(+ 0 (* 1 a))".parse().unwrap();
 
// That's it! We can run equality saturation now.
let runner = Runner::default().with_expr(&start).run(rules);
 
// Extractors can take a user-defined cost function,
// we'll use the egg-provided AstSize for now
let extractor = Extractor::new(&runner.egraph, AstSize);
 
// We want to extract the best expression represented in the
// same e-class as our initial expression, not from the whole e-graph.
// Luckily the runner stores the eclass Id where we put the initial expression.
let (best_cost, best_expr) = extractor.find_best(runner.roots[0]);
 
// we found the best thing, which is just "a" in this case
assert_eq!(best_expr, "a".parse().unwrap());
assert_eq!(best_cost, 1);

デバッグの為の Explanation

ex. (/ (* (/ 2 3) (/ 3 2)) 1) がなぜ 1 に簡約されるか

FlatExplanation は読みやすい形で表示してくれる.

use egg::{*, rewrite as rw};
let rules: &[Rewrite<SymbolLang, ()>] = &[
    rw!("div-one"; "?x" => "(/ ?x 1)"),
    rw!("unsafe-invert-division"; "(/ ?a ?b)" => "(/ 1 (/ ?b ?a))"),
    rw!("simplify-frac"; "(/ ?a (/ ?b ?c))" => "(/ (* ?a ?c) (* (/ ?b ?c) ?c))"),
    rw!("cancel-denominator"; "(* (/ ?a ?b) ?b)" => "?a"),
    rw!("times-zero"; "(* ?a 0)" => "0"),
];
 
let start = "(/ (* (/ 2 3) (/ 3 2)) 1)".parse().unwrap();
let end = "1".parse().unwrap();
let mut runner = Runner::default().with_explanations_enabled().with_expr(&start).run(rules);
 
println!("{}", runner.explain_equivalence(&start, &end).get_flat_string());
(/ (* (/ 2 3) (/ 3 2)) 1)
(Rewrite<= div-one (* (/ 2 3) (/ 3 2)))
(* (Rewrite=> unsafe-invert-division (/ 1 (/ 3 2))) (/ 3 2))
(Rewrite=> cancel-denominator 1)

これで0 を入れると,

0
(Rewrite<= times-zero (* (/ 1 0) 0))
(Rewrite=> cancel-denominator 1)

なぜ, 0 が 1になるのかとして (/ 1 0) というzero divisionがあるからということがわかる.

TreeExplanation はテストとかしやすい. 値の共有は let で書いてくれる.

(+ 1 (- a (* (- 2 1) a)))
 (+
    1
    (Explanation
      (- a (* (- 2 1) a))
      (-
        a
        (Explanation
          (* (- 2 1) a)
          (* (Explanation (- 2 1) (Rewrite=> constant_fold 1)) a)
          (Rewrite=> comm-mul (* a 1))
          (Rewrite<= mul-one a)))
      (Rewrite=> cancel-sub 0)))
 (Rewrite=> constant_fold 1)
(let
  (v_0 (- 2 1))
  (let
    (v_1 (- 2 (Explanation v_0 (Rewrite=> constant_fold 1))))
    (Explanation
      (* (- 2 (- 2 1)) (- 2 (- 2 1)))
      (*
        (Explanation (- 2 (- 2 1)) v_1 (Rewrite=> constant_fold 1))
        (Explanation (- 2 (- 2 1)) v_1 (Rewrite=> constant_fold 1)))
      (Rewrite=> constant_fold 1))))

より具体的な例

参考文献