Skip to content

Commit

Permalink
Ml graph fixes (#99)
Browse files Browse the repository at this point in the history
* Fixes to the ML graph and visualise inst body

* Clean up diff

* Do not reconnect on hidden proof-children
  • Loading branch information
JonasAlaif authored Dec 25, 2024
1 parent 09ac3cf commit 67bd225
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 42 deletions.
18 changes: 10 additions & 8 deletions axiom-profiler-GUI/src/screen/graphviz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,9 +326,9 @@ impl
Some(false) => "output".to_owned(),
None => "const".to_owned(),
},
QI(ref sig, pattern) => pattern
QI(_, quantifier, term) | InstBody(quantifier, term) => term
.simp
.with_data(&ctxt, &mut Some(sig.qpat.quant))
.with_data(&ctxt, &mut Some(quantifier))
.to_string(),
FixedENode(matched_term) => matched_term.simp.with(&ctxt).to_string(),
RecurringENode(matched_term, input) => {
Expand Down Expand Up @@ -362,7 +362,9 @@ impl
"Fixed nodes which do not change but are used in each iteration.".to_owned()
}
},
QI(ref sig, _) => ctxt.parser[sig.qpat.quant].kind.with(&ctxt).to_string(),
QI(_, quantifier, _) | InstBody(quantifier, _) => {
ctxt.parser[*quantifier].kind.with(&ctxt).to_string()
}
FixedENode(matched_term) => matched_term.orig.with(&ctxt).to_string(),
RecurringENode(matched_term, input) => {
ctxt.config.input = input.rec_input();
Expand All @@ -382,7 +384,7 @@ impl
use MLGraphNode::*;
match self {
HiddenNode(..) => "",
QI(..) | RecurringENode(..) | RecurringEquality(..) => "filled",
QI(..) | InstBody(..) | RecurringENode(..) | RecurringEquality(..) => "filled",
FixedENode(..) | FixedEquality(..) => "filled,dashed",
}
}
Expand All @@ -398,8 +400,8 @@ impl
use MLGraphNode::*;
match self {
HiddenNode(..) => Default::default(),
QI(sig, _) => {
let hue = ctx.get_rbg_hue(Some(sig.qpat.quant)).unwrap() / 360.0;
QI(_, quantifier, _) | InstBody(quantifier, _) => {
let hue = ctx.get_rbg_hue(Some(*quantifier)).unwrap() / 360.0;
format!("{hue} {NODE_COLOUR_SATURATION} {NODE_COLOUR_VALUE}")
}
FixedENode(..) | RecurringENode(..) => ENODE_COLOUR.to_owned(),
Expand All @@ -421,7 +423,7 @@ impl
Some(false) => "output",
None => "fixed",
},
QI(..) => "middle",
QI(..) | InstBody(..) => "middle",
};
class.to_string()
}
Expand Down Expand Up @@ -741,7 +743,7 @@ impl DotEdgeProperties<bool, (), (), (), (), (), (), (), ()> for MLGraphEdge {
use MLGraphEdge::*;
match self {
HiddenEdge(..) => "none",
Blame(..) | Yield => "normal",
Instantiation | Blame(..) | Yield => "normal",
BlameEq(..) | YieldEq | CombineEq => "empty",
}
}
Expand Down
108 changes: 81 additions & 27 deletions smt-log-parser/src/analysis/graph/analysis/matching_loop/explain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,18 @@ impl MlExplainer {
self.error = true;
return Vec::new();
};
let prev_inst = self.graph.add_node(MLGraphNode::QI(sig, pat.into()));
let old = self.instantiations.insert(prev, prev_inst);
let quantifier = sig.qpat.quant;

let prev_inst = MLGraphNode::QI(sig, quantifier, pat.into());
let prev_inst = self.graph.add_node(prev_inst);

let body = parser.quantifier_body(quantifier).unwrap().into();
let inst_body = MLGraphNode::InstBody(quantifier, body);
let inst_body = self.graph.add_node(inst_body);

self.graph
.add_edge(prev_inst, inst_body, MLGraphEdge::Instantiation);
let old = self.instantiations.insert(prev, inst_body);
assert!(old.is_none());

let mut recurring = Vec::new();
Expand Down Expand Up @@ -427,13 +437,6 @@ impl MlExplainer {
self.ancestor_is_recurring = false;
}
self.ancestor_is_recurring = ancestor_is_recurring;
if self.ancestor_is_recurring && !self.add_mode {
self.burned_eqs.insert(eq);
self.add_mode = true;
self.super_walk_trans(eq, forward)?;
self.add_mode = false;
assert!(self.ancestor_is_recurring);
}
Ok(())
}
fn walk_trans(
Expand Down Expand Up @@ -471,7 +474,24 @@ impl MlExplainer {
return Ok(());
}

self.super_walk_trans(eq, forward)
self.super_walk_and_rewalk(eq, forward)
}
}
impl TransEqGraphWalker<'_> {
fn super_walk_and_rewalk(
&mut self,
eq: EqTransIdx,
forward: bool,
) -> core::result::Result<(), Never> {
self.super_walk_trans(eq, forward)?;
if self.ancestor_is_recurring && !self.add_mode {
self.burned_eqs.insert(eq);
self.add_mode = true;
self.super_walk_trans(eq, forward)?;
self.add_mode = false;
assert!(self.ancestor_is_recurring);
}
Ok(())
}
}
let mut walker = TransEqGraphWalker {
Expand All @@ -482,22 +502,18 @@ impl MlExplainer {
add_mode: false,
burned_eqs: FxHashSet::default(),
};
walker.super_walk_trans(eqidx, true).unwrap();
walker.super_walk_and_rewalk(eqidx, true).unwrap();
walker.ancestor_is_recurring
}

pub fn simplify_terms(mut self, parser: &mut Z3Parser) -> Result<MlExplanation> {
let mut collector = QVarParentCollector::new(parser);
for &i in self.instantiations.values() {
let MLGraphNode::QI(_, pattern) = &self.graph[i] else {
unreachable!();
};
let pattern = parser.synth_terms.as_tidx(pattern.orig).unwrap();
let has_qvars = collector.collect_term(pattern);
debug_assert!(has_qvars);
self.collect_inst(&mut collector, i);
}
let mut simplifier = TermSimplifier {
forbidden_apps: collector.forbidden_apps,
terms_with_qvars: collector.terms_with_qvars,
parser,
simplifications: FxHashMap::default(),
stack: Vec::new(),
Expand All @@ -507,44 +523,77 @@ impl MlExplainer {
}
Ok(self.graph)
}

fn collect_inst(&self, collector: &mut QVarParentCollector<'_>, i: NodeIndex) {
let MLGraphNode::InstBody(_, body) = &self.graph[i] else {
unreachable!();
};
collector.stack = None;
let body = collector.parser.synth_terms.as_tidx(body.orig).unwrap();
let _has_qvars = collector.collect_term(body);

let dir = petgraph::Direction::Incoming;
let mut parents = self.graph.neighbors_directed(i, dir);
let i = parents.next().unwrap();
assert_eq!(parents.next(), None);
let MLGraphNode::QI(_, _, pattern) = &self.graph[i] else {
unreachable!();
};
collector.stack = Some(Vec::new());
let pattern = collector.parser.synth_terms.as_tidx(pattern.orig).unwrap();
let has_qvars = collector.collect_term(pattern);
debug_assert!(has_qvars);
}
}

struct QVarParentCollector<'a> {
parser: &'a Z3Parser,
stack: Vec<IString>,
stack: Option<Vec<IString>>,
forbidden_apps: FxHashSet<BoxSlice<IString>>,
terms_with_qvars: FxHashSet<TermIdx>,
}

impl<'a> QVarParentCollector<'a> {
fn new(parser: &'a Z3Parser) -> Self {
Self {
parser,
stack: Default::default(),
stack: Some(Default::default()),
forbidden_apps: Default::default(),
terms_with_qvars: Default::default(),
}
}
fn collect_term(&mut self, tidx: TermIdx) -> bool {
let term = &self.parser[tidx];
match term.kind {
TermKind::Var(_) => true,
TermKind::Var(_) => {
self.terms_with_qvars.insert(tidx);
true
}
TermKind::App(app) => {
self.stack.push(app);
if let Some(stack) = &mut self.stack {
stack.push(app);
}

let mut has_qvar = false;
for &child in term.child_ids.iter() {
has_qvar |= self.collect_term(child);
}

if has_qvar {
self.terms_with_qvars.insert(tidx);
// I have qvars so prevent sequences of terms like me being
// replaced.
for i in 0..self.stack.len() {
let forbidden_app = self.stack[i..].iter().copied().collect();
self.forbidden_apps.insert(forbidden_app);
if let Some(stack) = &self.stack {
for i in 0..stack.len() {
let forbidden_app = stack[i..].iter().copied().collect();
self.forbidden_apps.insert(forbidden_app);
}
}
}

self.stack.pop();
if let Some(stack) = &mut self.stack {
stack.pop();
}
has_qvar
}
TermKind::Quant(..) => unreachable!(),
Expand All @@ -558,13 +607,15 @@ struct TermSimplifier<'a> {

stack: Vec<IString>,
forbidden_apps: FxHashSet<BoxSlice<IString>>,
terms_with_qvars: FxHashSet<TermIdx>,
}

impl TermSimplifier<'_> {
fn simplify_node(&mut self, node: &mut MLGraphNode) -> Result<()> {
match node {
MLGraphNode::HiddenNode(..) => (),
MLGraphNode::QI(_, idx)
MLGraphNode::QI(_, _, idx)
| MLGraphNode::InstBody(_, idx)
| MLGraphNode::FixedENode(idx)
| MLGraphNode::RecurringENode(idx, _) => {
idx.simp = self.simplify_term(idx.simp)?;
Expand All @@ -579,7 +630,9 @@ impl TermSimplifier<'_> {

fn simplify_term(&mut self, idx: SynthIdx) -> Result<SynthIdx> {
let tidx = self.parser.synth_terms.as_tidx(idx);
let has_qvars = tidx.is_some_and(|tidx| self.terms_with_qvars.contains(&tidx));
if let Some(tidx) = tidx.and_then(|tidx| self.simplifications.get(&tidx)) {
assert!(!has_qvars);
return Ok(*tidx);
}
let term = &self.parser[idx];
Expand All @@ -597,7 +650,8 @@ impl TermSimplifier<'_> {
}) if {
self.stack.push(*name);
*must_pop = true;
!self.forbidden_apps.contains(self.stack.as_slice())
let is_forbidden = self.forbidden_apps.contains(self.stack.as_slice());
!is_forbidden && !has_qvars
} =>
{
self.stack.pop();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::parsers::z3::synthetic::SynthIdx;
use crate::{items::QuantIdx, parsers::z3::synthetic::SynthIdx};

use super::MlSignature;

Expand All @@ -9,7 +9,8 @@ use mem_dbg::{MemDbg, MemSize};
#[derive(Debug, Clone, Eq, Hash, PartialEq)]
pub enum MLGraphNode {
HiddenNode(Option<bool>),
QI(MlSignature, SimpIdx),
QI(MlSignature, QuantIdx, SimpIdx),
InstBody(QuantIdx, SimpIdx),
FixedENode(SimpIdx),
RecurringENode(SimpIdx, RecurrenceKind),
FixedEquality(SimpIdx, SimpIdx),
Expand All @@ -22,6 +23,7 @@ pub enum MLGraphEdge {
HiddenEdge(bool, u32),
Blame(usize),
BlameEq(usize),
Instantiation,
Yield,
YieldEq,
CombineEq,
Expand Down
6 changes: 3 additions & 3 deletions smt-log-parser/src/analysis/graph/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -504,9 +504,9 @@ pub enum NodeKind {
TransEquality(EqTransIdx),
/// Corresponds to `ProofIdx`.
///
/// **Parents:** (small) arbitrary count, will always be `Proof` or
/// **Parents:** (large) arbitrary count, will always be `Proof` or
/// `Instantiation`.
/// **Children:** (small) arbitrary count, will always be `Proof`.
/// **Children:** (small?) arbitrary count, will always be `Proof`.
Proof(ProofIdx),
/// Corresponds to `CdclIdx`. Only connected to other `Cdcl` nodes.
///
Expand Down Expand Up @@ -582,7 +582,7 @@ impl NodeKind {
(
Self::ENode(..) | Self::TransEquality(..),
Self::Instantiation(..)
)
) | (Self::Proof(..), Self::Proof(..))
)
}
}
Expand Down
4 changes: 2 additions & 2 deletions smt-log-parser/src/display_with.rs
Original file line number Diff line number Diff line change
Expand Up @@ -598,8 +598,8 @@ impl<'a, 'b> DisplayWithCtxt<DisplayCtxt<'b>, DisplayData<'b>> for &'a SynthTerm
SynthTermKind::Variable(var) => write!(f, "${var}"),
SynthTermKind::Input(offset) => match offset {
Some(offset) => {
write!(f, "⭐ + ")?;
offset.fmt_with(f, ctxt, &mut None)
offset.fmt_with(f, ctxt, &mut None)?;
write!(f, " + ⭐")
}
None => write!(f, "⭐"),
},
Expand Down
5 changes: 5 additions & 0 deletions smt-log-parser/src/parsers/z3/z3parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1014,6 +1014,11 @@ impl Z3Parser {
self.cdcl.cdcls()
}

pub fn quantifier_body(&self, qidx: QuantIdx) -> Option<TermIdx> {
let children = &self[self[qidx].term].child_ids;
children.last().copied()
}

pub fn patterns(&self, q: QuantIdx) -> Option<&TiSlice<PatternIdx, TermIdx>> {
let child_ids = &self[self[q].term].child_ids;
child_ids
Expand Down

0 comments on commit 67bd225

Please sign in to comment.