egg tutorial
01 background
Why any of this?
プログラミング言語のツールを作ろうとするとき, 項に対して
- 等価なより良い項に変換 (optimization)
- ある仕様に沿って項を作成 (synthesis)
- 項の等価性の判定 (verification)
Term Rewriting は一般的な方法, matchしてrewriteする.
問題点として選択の問題がある. 書き換えは破壊的な変更.
やり直すのも大変だし, 正しい書き換えを見つけるのはとてもむずかしい
good locally でも wrong globally な時がある
ex. C optimizer は a * 2 を a << 1 にするが, (a * 2) / 2 の時, greedyにやると (a << 1) / 2 になってしまって, a に最適化出来ない.
軽減策(mitigate)として, backtrackingと global value numbering あるが,
全てを考慮しながらやるという意味で, 選択肢が指数的に爆発する.
E-graphはこれを解決する. 候補は似たような構造をしていることが多くそれらをコンパクトに格納できる. 指数爆発なしに多くの書き換えを同時に適用できる.
What’s an e-graph
式の等価な表現を保持するデータ構造. e-graph は e-class の集合.
e-class には等価な e-node が入っている.
e-node は子に対する演算, 子は項の代わりに e-class.
egg では e-node は Language 型.
下図で, e-class は点線. 辺は e-node -> e-class に繋がっていることに注意.

e-graph は e-matching (Searcher API -> Pattern API)によってパターを問い合わせる. e-graph を変更する操作は
- add:
e-nodeをe-graphに追加する - union: 2つの
e-classをマージする
e-graph 上で の書き換えを行うには
- を検索し, を得る
e-graphに項 を追加し, マッチしたe-classと union する
ex.
- 最初,
(a * 2) / 2を表している. 全てのe-classは1つのe-nodeを持っているのでASTとなっている(2は複製せずに共有している) - 書き換え
x * 2 -> x << 1は等式x * 2 == x << 1として記録される. (= 新たに追加されたa << 1はaを使うし,*と同じe-classに結合されている) - 結合則に基づいた書き換え
(x * y) / z -> z * (y / z)を適用.(x * y) / zは(a * 2) / 2にマッチし,a * (2 / 2)に相当するe-classを結合.
- 書き換え
x / x -> 1,1 * x -> xを適用.x / x -> 1に対応する結合は/に1を加える1 * x -> xでは*: {/, 1} aにaを追加 ->*: *{a, \, *}になる, 自分自身のe-classを指しているのはa,a * 1,a * 1 * 1… という無限の集合に対する書き換えを表していることに相当する
Invariants and Rebuilding
add と union は2つの不変性を保つ変換.
Congruence(合同性):e-graphは等価関係で無く, 合同関係を管理するもの.(同値類みたいな). 例えば, 2つのe-classa + x,a + yがあり, その後,xとyの等価性が判明したら, 2つe-classは合同なので統合する.Uniqueness of e-nodes: 合同な2つ別のe-classが存在しないこと.addの時は Hash consingunion,rebuildの時はdedupが行われる
egg の実装では, 不変性を保つ処理は rebuild() でやっている.
Runner は書き換えの反復で rebuild() してくれる.
Equality Saturation
Ross Tateさんの論文 [1012.1802] Equality Saturation: A New Approach to Optimization.
プログラム最適化の手法で, Runner APIで実装されている.
- 代数学でよく出る (ex. ホッジ・テイト加群) の人と同じ人?
- 違った
fn equality_saturation(expr: Expression, rewrites: Vec<Rewrite>) -> Expression {
let mut egraph = make_initial_egraph(expr);
while !egraph.is_saturated_or_timeout() {
let mut matches = vec![];
// read-only phase, invariants are preserved
for rw in rewrites {
for (subst, eclass) in egraph.search(rw.lhs) {
matches.push((rw, subst, eclass));
}
}
// write-only phase, temporarily break invariants
for (rw, subst, eclass) in matches {
eclass2 = egraph.add(rw.rhs.subst(subst));
egraph.union(eclass, eclass2);
}
// restore the invariants once per iteration
egraph.rebuild();
}
return egraph.extract_best();
}Saturation: 書き換えによって情報が追加されなくなった時に起こる ex. x + y -> y + x. 与えられた全ての書き換えから導出可能な全ての合同関係が符号化されていること.
Extraction: あるコスト関数に従って, e-class から1つの表現を選ぶ手続き, Extractor がこれをする.
Equality Saturation は可能な全ての表現を探索して, 最適な項を抽出すること.
tutorial 02 getting started
Language trait (e-node) を実装すれば使える.
SymbolLang は用意されている.
let my_expression: RecExpr<SymbolLang> = "(foo a b)".parse().unwrap();
println!("this is my expression {}", my_expression);
let my_enode = SymbolLang::new("bar", vec![]);e-class(RecExpr) にも e-graph にも追加できる.
2回同じ構造を追加すると同一の ID が返る.
let mut expr = RecExpr::default();
let a = expr.add(SymbolLang::leaf("a"));
let b = expr.add(SymbolLang::leaf("b"));
let foo = expr.add(SymbolLang::new("foo", vec![a, b]));
// we can do the same thing with an EGraph
let mut egraph: EGraph<SymbolLang, ()> = Default::default();
let a = egraph.add(SymbolLang::leaf("a"));
let b = egraph.add(SymbolLang::leaf("b"));
let foo = egraph.add(SymbolLang::new("foo", vec![a, b]));
// we can also add RecExprs to an egraph
let foo2 = egraph.add_expr(&expr);
// note that if you add the same thing to an e-graph twice, you'll get back equivalent Ids
assert_eq!(foo, foo2);// let's make an e-graph
let mut egraph: EGraph<SymbolLang, ()> = Default::default();
let a = egraph.add(SymbolLang::leaf("a"));
let b = egraph.add(SymbolLang::leaf("b"));
let foo = egraph.add(SymbolLang::new("foo", vec![a, b]));
// rebuild the e-graph since we modified it
egraph.rebuild();
// we can make Patterns by parsing, similar to RecExprs
// names preceded by ? are parsed as Pattern variables and will match anything
let pat: Pattern<SymbolLang> = "(foo ?x ?x)".parse().unwrap();
// since we use ?x twice, it must match the same thing,
// so this search will return nothing
let matches = pat.search(&egraph);
assert!(matches.is_empty());
egraph.union(a, b);
// recall that rebuild must be called to "see" the effects of adds or unions
egraph.rebuild();
// now we can find a match since a = b
let matches = pat.search(&egraph);
assert!(!matches.is_empty());書き換えルールを簡単に記述するマクロが用意されている.
use egg::{*, rewrite as rw};
let rules: &[Rewrite<SymbolLang, ()>] = &[
rw!("commute-add"; "(+ ?x ?y)" => "(+ ?y ?x)"),
rw!("commute-mul"; "(* ?x ?y)" => "(* ?y ?x)"),
rw!("add-0"; "(+ ?x 0)" => "?x"),
rw!("mul-0"; "(* ?x 0)" => "0"),
rw!("mul-1"; "(* ?x 1)" => "?x"),
];
// While it may look like we are working with numbers,
// SymbolLang stores everything as strings.
// We can make our own Language later to work with other types.
let start = "(+ 0 (* 1 a))".parse().unwrap();
// That's it! We can run equality saturation now.
let runner = Runner::default().with_expr(&start).run(rules);
// Extractors can take a user-defined cost function,
// we'll use the egg-provided AstSize for now
let extractor = Extractor::new(&runner.egraph, AstSize);
// We want to extract the best expression represented in the
// same e-class as our initial expression, not from the whole e-graph.
// Luckily the runner stores the eclass Id where we put the initial expression.
let (best_cost, best_expr) = extractor.find_best(runner.roots[0]);
// we found the best thing, which is just "a" in this case
assert_eq!(best_expr, "a".parse().unwrap());
assert_eq!(best_cost, 1);デバッグの為の Explanation
ex. (/ (* (/ 2 3) (/ 3 2)) 1) がなぜ 1 に簡約されるか
FlatExplanation は読みやすい形で表示してくれる.
use egg::{*, rewrite as rw};
let rules: &[Rewrite<SymbolLang, ()>] = &[
rw!("div-one"; "?x" => "(/ ?x 1)"),
rw!("unsafe-invert-division"; "(/ ?a ?b)" => "(/ 1 (/ ?b ?a))"),
rw!("simplify-frac"; "(/ ?a (/ ?b ?c))" => "(/ (* ?a ?c) (* (/ ?b ?c) ?c))"),
rw!("cancel-denominator"; "(* (/ ?a ?b) ?b)" => "?a"),
rw!("times-zero"; "(* ?a 0)" => "0"),
];
let start = "(/ (* (/ 2 3) (/ 3 2)) 1)".parse().unwrap();
let end = "1".parse().unwrap();
let mut runner = Runner::default().with_explanations_enabled().with_expr(&start).run(rules);
println!("{}", runner.explain_equivalence(&start, &end).get_flat_string());(/ (* (/ 2 3) (/ 3 2)) 1)
(Rewrite<= div-one (* (/ 2 3) (/ 3 2)))
(* (Rewrite=> unsafe-invert-division (/ 1 (/ 3 2))) (/ 3 2))
(Rewrite=> cancel-denominator 1)
これで0 を入れると,
0
(Rewrite<= times-zero (* (/ 1 0) 0))
(Rewrite=> cancel-denominator 1)なぜ, 0 が 1になるのかとして (/ 1 0) というzero divisionがあるからということがわかる.
TreeExplanation はテストとかしやすい. 値の共有は let で書いてくれる.
(+ 1 (- a (* (- 2 1) a)))
(+
1
(Explanation
(- a (* (- 2 1) a))
(-
a
(Explanation
(* (- 2 1) a)
(* (Explanation (- 2 1) (Rewrite=> constant_fold 1)) a)
(Rewrite=> comm-mul (* a 1))
(Rewrite<= mul-one a)))
(Rewrite=> cancel-sub 0)))
(Rewrite=> constant_fold 1)
(let
(v_0 (- 2 1))
(let
(v_1 (- 2 (Explanation v_0 (Rewrite=> constant_fold 1))))
(Explanation
(* (- 2 (- 2 1)) (- 2 (- 2 1)))
(*
(Explanation (- 2 (- 2 1)) v_1 (Rewrite=> constant_fold 1))
(Explanation (- 2 (- 2 1)) v_1 (Rewrite=> constant_fold 1)))
(Rewrite=> constant_fold 1))))