diff --git a/crates/ast/src/attribute.rs b/crates/ast/src/attribute.rs index abe4c522..12abf739 100644 --- a/crates/ast/src/attribute.rs +++ b/crates/ast/src/attribute.rs @@ -9,12 +9,12 @@ pub enum NumAttr {} /// An flag attribute #[derive(Enum, Clone, Copy, PartialEq, EnumString)] pub enum BoolAttr { + /// This is a toplevel component + #[strum(serialize = "toplevel")] + TopLevel, /// Use a counter based FSM design #[strum(serialize = "counter_fsm")] CounterFSM, - /// Dummy attribute because the [Enum] trait derivation throws errors if there is only one varianat - #[strum(disabled)] - Dummy, } /// Represents a single attribute. This is a private enum that is used during @@ -50,6 +50,11 @@ impl Attributes { } } + /// Check if an attribute is set + pub fn has(&self, attr: impl Into) -> bool { + self.attrs[attr.into()].is_some() + } + /// Get the value of an attribute. pub fn get(&self, attr: impl Into) -> Option { self.attrs[attr.into()].map(|(v, _)| v) diff --git a/crates/ast/src/component.rs b/crates/ast/src/component.rs index eeaf42ee..1c5516e8 100644 --- a/crates/ast/src/component.rs +++ b/crates/ast/src/component.rs @@ -1,3 +1,5 @@ +use crate::BoolAttr; + use super::{Command, Id, Signature}; use fil_gen as gen; use gen::GenConfig; @@ -44,6 +46,7 @@ impl Component { } } +#[derive(Default)] pub struct Namespace { /// Imported files pub imports: Vec, @@ -51,23 +54,11 @@ pub struct Namespace { pub externs: Vec, /// Components defined in this file pub components: Vec, - /// Top-level component id - pub toplevel: String, /// Top level bindings pub bindings: Vec, } impl Namespace { - pub fn new(toplevel: String) -> Self { - Self { - imports: Vec::default(), - externs: Vec::default(), - components: Vec::default(), - bindings: Vec::default(), - toplevel, - } - } - /// Returns true if the namespace declares at least one generative module pub fn requires_gen(&self) -> bool { self.externs.iter().any(|Extern { gen, .. }| gen.is_some()) @@ -103,8 +94,14 @@ impl Namespace { /// Get the index to the top-level component. /// Currently, this is the distinguished "main" component pub fn main_idx(&self) -> Option { - self.components.iter().position(|c| { - c.sig.name.inner() == &Id::from(self.toplevel.clone()) - }) + self.components + .iter() + .position(|c| c.sig.attributes.has(BoolAttr::TopLevel)) + } + + /// Get the toplevel component name + pub fn toplevel(&self) -> Option<&str> { + self.main_idx() + .map(|idx| self.components[idx].sig.name.inner().as_ref()) } } diff --git a/crates/ast/src/parser.rs b/crates/ast/src/parser.rs index 313af481..1a9097c9 100644 --- a/crates/ast/src/parser.rs +++ b/crates/ast/src/parser.rs @@ -902,8 +902,11 @@ impl FilamentParser { Ok(match_nodes!( input.into_children(); [imports(imps), comp_or_ext(mixed).., _EOI] => { - let mut namespace = ast::Namespace::new("main".to_string()); - namespace.imports = imps; + let mut namespace = ast::Namespace { + imports: imps, + ..Default::default() + }; + for m in mixed { match m { BodyEl::Ext(sig) => namespace.externs.push(sig), diff --git a/crates/filament/src/ast_passes/mod.rs b/crates/filament/src/ast_passes/mod.rs new file mode 100644 index 00000000..f83ab4e4 --- /dev/null +++ b/crates/filament/src/ast_passes/mod.rs @@ -0,0 +1,3 @@ +mod toplevel; + +pub use toplevel::TopLevel; diff --git a/crates/filament/src/ast_passes/toplevel.rs b/crates/filament/src/ast_passes/toplevel.rs new file mode 100644 index 00000000..20c09265 --- /dev/null +++ b/crates/filament/src/ast_passes/toplevel.rs @@ -0,0 +1,84 @@ +use crate::ast_visitor::{Action, Visitor}; +use fil_ast as ast; +use fil_utils::{Diagnostics, Error, GPosIdx}; + +/// Sets the proper FSM Attributes for every component +#[derive(Default)] +pub struct TopLevel { + /// Set to true if we find a toplevel component + has_toplevel: Option, + /// Error reporting + diag: Diagnostics, +} + +impl Visitor for TopLevel { + fn name() -> &'static str { + "fsm-attributes" + } + + fn signature(&mut self, sig: &mut ast::Signature) -> Action { + if sig.attributes.get(ast::BoolAttr::TopLevel) == Some(1) { + if self.has_toplevel.is_some() { + let err = Error::malformed("Multiple top-level components") + .add_note(self.diag.add_info( + "first top-level component here", + self.has_toplevel.unwrap(), + )) + .add_note( + self.diag.add_info( + "second top-level component here", + sig.attributes + .get_loc(ast::BoolAttr::TopLevel) + .unwrap(), + ), + ); + + self.diag.add_error(err); + } else { + self.has_toplevel = Some( + sig.attributes.get_loc(ast::BoolAttr::TopLevel).unwrap(), + ); + } + } + + // Stop traversal into component body + Action::Stop + } + + fn external(&mut self, ext: &mut ast::Extern) { + for sig in &mut ext.comps { + if sig.attributes.get(ast::BoolAttr::TopLevel) == Some(1) { + let err = + Error::malformed("External components cannot be top-level") + .add_note( + self.diag.add_info( + "toplevel attribute here", + sig.attributes + .get_loc(ast::BoolAttr::TopLevel) + .unwrap(), + ), + ); + + self.diag.add_error(err); + } + } + } + + fn after_traversal(mut self) -> Option { + self.diag.report_all() + } + + fn finish(&mut self, ast: &mut ast::Namespace) { + // If no toplevel component was found, find the component with the name "main" + if self.has_toplevel.is_none() { + for comp in ast.components.iter_mut() { + if comp.sig.name.as_ref() == "main" { + // Add the toplevel attribute to the component + comp.sig.attributes.set(ast::BoolAttr::TopLevel, 1); + + return; + } + } + } + } +} diff --git a/crates/filament/src/ast_visitor/mod.rs b/crates/filament/src/ast_visitor/mod.rs new file mode 100644 index 00000000..3daca7fa --- /dev/null +++ b/crates/filament/src/ast_visitor/mod.rs @@ -0,0 +1,3 @@ +mod visitor; + +pub use visitor::{Action, Construct, Visitor}; diff --git a/crates/filament/src/ast_visitor/visitor.rs b/crates/filament/src/ast_visitor/visitor.rs new file mode 100644 index 00000000..6a4ad3be --- /dev/null +++ b/crates/filament/src/ast_visitor/visitor.rs @@ -0,0 +1,259 @@ +use crate::cmdline; +use fil_ast as ast; + +#[must_use] +/// Action performed by the visitor +pub enum Action { + /// Stop visiting the CFG + Stop, + /// Continue visiting the CFG + Continue, + /// Add commands after this command + AddBefore(Vec), + /// Change the current command with other commands + Change(Vec), +} + +impl Action { + /// Run the traversal specified by `next` if this traversal succeeds. + /// If the result of this traversal is not `Action::Continue`, do not + /// run `next()`. + pub fn and_then(self, mut next: F) -> Action + where + F: FnMut() -> Action, + { + match self { + Action::Continue => next(), + Action::Change(_) | Action::AddBefore(_) | Action::Stop => self, + } + } +} + +/// Construct a visitor +pub trait Construct { + fn from(opts: &cmdline::Opts, ast: &mut ast::Namespace) -> Self; + + /// Clear data before the next component has been visited + fn clear_data(&mut self); +} + +impl Construct for T { + fn from(_: &cmdline::Opts, _: &mut ast::Namespace) -> Self { + Self::default() + } + + fn clear_data(&mut self) { + *self = Self::default(); + } +} + +/// Visit and transform the given AST +pub trait Visitor +where + Self: Sized + Construct, +{ + /// The user visible name for the pass + fn name() -> &'static str; + + #[must_use] + /// Executed after the visitor has visited all the components. + /// If the return value is `Some`, the number is treated as an error code. + fn after_traversal(self) -> Option { + None + } + + fn bundle(&mut self, _: &mut ast::Bundle) -> Action { + Action::Continue + } + + fn connect(&mut self, _: &mut ast::Connect) -> Action { + Action::Continue + } + + fn exists(&mut self, _: &mut ast::Exists) -> Action { + Action::Continue + } + + fn fact(&mut self, _: &mut ast::Fact) -> Action { + Action::Continue + } + + fn start_loop(&mut self, _: &mut ast::ForLoop) -> Action { + Action::Continue + } + + fn end_loop(&mut self, _: &mut ast::ForLoop) -> Action { + Action::Continue + } + + fn do_loop(&mut self, l: &mut ast::ForLoop) -> Action { + self.start_loop(l) + .and_then(|| self.visit_cmds(&mut l.body)) + .and_then(|| self.end_loop(l)) + } + + fn start_if(&mut self, _: &mut ast::If) -> Action { + Action::Continue + } + + fn end_if(&mut self, _: &mut ast::If) -> Action { + Action::Continue + } + + fn do_if(&mut self, i: &mut ast::If) -> Action { + self.start_if(i) + .and_then(|| self.visit_cmds(&mut i.then)) + .and_then(|| self.visit_cmds(&mut i.alt)) + .and_then(|| self.end_if(i)) + } + + fn instance(&mut self, _: &mut ast::Instance) -> Action { + Action::Continue + } + + fn invoke(&mut self, _: &mut ast::Invoke) -> Action { + Action::Continue + } + + fn param_let(&mut self, _: &mut ast::ParamLet) -> Action { + Action::Continue + } + + fn visit_cmd(&mut self, cmd: &mut ast::Command) -> Action { + match cmd { + ast::Command::Bundle(bundle) => self.bundle(bundle), + ast::Command::Connect(connect) => self.connect(connect), + ast::Command::Exists(exists) => self.exists(exists), + ast::Command::Fact(fact) => self.fact(fact), + ast::Command::ForLoop(forloop) => self.do_loop(forloop), + ast::Command::If(i) => self.do_if(i), + ast::Command::Instance(inst) => self.instance(inst), + ast::Command::Invoke(inv) => self.invoke(inv), + ast::Command::ParamLet(pl) => self.param_let(pl), + } + } + + fn start_cmds(&mut self, _: &mut Vec) {} + + fn end_cmds(&mut self, _: &mut Vec) {} + + /// Visit a list of commands (a scope) + fn visit_cmds(&mut self, cmds: &mut Vec) -> Action { + self.start_cmds(cmds); + + let cs = std::mem::take(cmds); + let mut n_cmds = Vec::with_capacity(cs.len()); + let mut iter = cs.into_iter(); + + let mut stopped = false; + for mut cmd in iter.by_ref() { + match self.visit_cmd(&mut cmd) { + Action::Stop => { + stopped = true; + break; + } + Action::Continue => { + n_cmds.push(cmd); + } + Action::Change(cmds) => { + n_cmds.extend(cmds.into_iter()); + } + Action::AddBefore(cmds) => { + n_cmds.extend(cmds.into_iter()); + n_cmds.push(cmd); + } + } + } + n_cmds.extend(iter); + *cmds = n_cmds; + + if stopped { + Action::Stop + } else { + self.end_cmds(cmds); + Action::Continue + } + } + + /// Visit a component signature + fn signature(&mut self, _: &mut ast::Signature) -> Action { + Action::Continue + } + + fn after_component(&mut self, _: &mut ast::Component) {} + + /// Visit a component + fn component(&mut self, comp: &mut ast::Component) { + let pre_cmds = match self.signature(&mut comp.sig) { + Action::Stop => return, + Action::Continue => None, + Action::AddBefore(cmds) => Some(cmds), + Action::Change(_) => { + unreachable!( + "visit_cmds should not attempt to change AST nodes" + ) + } + }; + + // Traverse the commands + let mut cmds = std::mem::take(&mut comp.body); + match self.visit_cmds(&mut cmds) { + Action::Stop | Action::Continue => (), + Action::Change(_) | Action::AddBefore(_) => { + unreachable!( + "visit_cmds should not attempt to change AST nodes" + ) + } + } + + if let Some(pre_cmds) = pre_cmds { + comp.body = pre_cmds; + } + + comp.body.extend(cmds); + + self.after_component(comp); + } + + /// Visit an extern + fn external(&mut self, ext: &mut ast::Extern) { + for sig in &mut ext.comps { + match self.signature(sig) { + Action::Stop => break, + Action::Continue => (), + Action::AddBefore(_) | Action::Change(_) => { + unreachable!( + "Externs should not attempt to change AST nodes" + ) + } + } + } + } + + /// Run after the pass is done + fn finish(&mut self, _: &mut ast::Namespace) {} + + /// Perform the pass + fn do_pass( + opts: &cmdline::Opts, + ast: &mut ast::Namespace, + ) -> Result<(), u64> { + let mut visitor = Self::from(opts, ast); + for comp in ast.components.iter_mut() { + visitor.component(comp); + visitor.clear_data(); + } + + for ext in ast.externs.iter_mut() { + visitor.external(ext); + visitor.clear_data(); + } + + visitor.finish(ast); + + match visitor.after_traversal() { + Some(n) => Err(n), + None => Ok(()), + } + } +} diff --git a/crates/filament/src/cmdline.rs b/crates/filament/src/cmdline.rs index f87de769..76cf7747 100644 --- a/crates/filament/src/cmdline.rs +++ b/crates/filament/src/cmdline.rs @@ -83,10 +83,6 @@ pub struct Opts { #[argh(option, long = "log", default = "log::LevelFilter::Warn")] pub log_level: log::LevelFilter, - /// set toplevel - #[argh(option, long = "toplevel", default = "\"main\".into()")] - pub toplevel: String, - /// skip the discharge pass (unsafe) #[argh(switch, long = "unsafe-skip-discharge")] pub unsafe_skip_discharge: bool, @@ -104,10 +100,10 @@ pub struct Opts { /// backend to use (default: verilog): calyx, verilog #[argh(option, long = "backend", default = "Backend::Verilog")] pub backend: Backend, - /// disable generation of counter-based FSMs in the backend. The default (non-counter) FSM - /// is represented by a single bit Shift Register counting through the number of states. - /// However, for components with a large number of states or a large II, it may be more efficient to use a - /// counter-based FSM, where one counter loops every II states, at which point it increments the state counter. + /// disable generation of counter-based FSMs in the backend. + /// The default (non-counter) FSM is represented by a single bit Shift Register counting through the number of states. + /// However, for components with a large number of states or a large II, it may be more efficient to use a counter-based FSM, + /// where one counter loops every II states, at which point it increments the state counter. #[argh(switch, long = "no-counter-fsms")] pub no_counter_fsms: bool, /// preserves original port names during compilation. diff --git a/crates/filament/src/ir_passes/fsm_attributes.rs b/crates/filament/src/ir_passes/fsm_attributes.rs index 6f91d7bc..8a71c3e1 100644 --- a/crates/filament/src/ir_passes/fsm_attributes.rs +++ b/crates/filament/src/ir_passes/fsm_attributes.rs @@ -1,4 +1,4 @@ -use crate::ir_visitor::{Visitor, VisitorData}; +use crate::ir_visitor::{Action, Visitor, VisitorData}; use fil_ast as ast; use fil_ir::TimeSub; @@ -11,17 +11,17 @@ impl Visitor for FSMAttributes { "fsm-attributes" } - fn visit(&mut self, mut data: VisitorData) { + fn start(&mut self, data: &mut VisitorData) -> Action { let attrs = &data.comp.attrs; // Check if the component already has FSM attributes if attrs.get(ast::BoolAttr::CounterFSM).is_some() { - return; + return Action::Stop; } // If the component is external or generated, do not add any slow FSM attributes if data.comp.is_ext() || data.comp.is_gen() { - return; + return Action::Stop; } // Get the delay of the component if it is a single event component @@ -38,10 +38,12 @@ impl Visitor for FSMAttributes { // TODO(UnsignedByte): Find a better heuristic for slow FSMs if delay > 1 { data.comp.attrs.set(ast::BoolAttr::CounterFSM, 1); - return; + return Action::Stop; } } data.comp.attrs.set(ast::BoolAttr::CounterFSM, 0); + + Action::Stop } } diff --git a/crates/filament/src/lib.rs b/crates/filament/src/lib.rs index 4abe1171..9ae65b65 100644 --- a/crates/filament/src/lib.rs +++ b/crates/filament/src/lib.rs @@ -1,3 +1,5 @@ +pub mod ast_passes; +pub mod ast_visitor; pub mod cmdline; pub mod ir_passes; pub mod ir_visitor; diff --git a/crates/filament/src/macros.rs b/crates/filament/src/macros.rs index 16de979a..1fb9a436 100644 --- a/crates/filament/src/macros.rs +++ b/crates/filament/src/macros.rs @@ -47,18 +47,19 @@ macro_rules! log_pass { } #[macro_export] -/// A macro generate the pass pipeline. For each provided pass, it will: +/// A macro to generate the IR pass pipeline. For each provided pass, it will: /// 1. Record the amount of time it took to run the pass. -/// 2. Print out the state of the AST if the name of the pass was in the +/// 2. Print out the state of the IR if the name of the pass was in the /// print-after declaration. /// /// Usage: /// ``` -/// pass_pipeline! { opts, ir; +/// ir_pass_pipeline! { opts, ir; /// Pass1, -/// Pass2, ... +/// Pass2, // ... /// } -macro_rules! pass_pipeline { +/// ``` +macro_rules! ir_pass_pipeline { ($opts:ident, $ir:ident; $($pass:path),*) => { $( let name = <$pass as $crate::ir_visitor::Visitor>::name(); @@ -69,3 +70,25 @@ macro_rules! pass_pipeline { )* }; } + +#[macro_export] +/// A macro to generate the AST pass pipeline. For each provided pass, it will: +/// 1. Record the amount of time it took to run the pass. +/// 2. Print out the state of the AST if the name of the pass was in the +/// print-after declaration. +/// +/// Usage: +/// ``` +/// ast_pass_pipeline! { opts, ir; +/// Pass1, +/// Pass2, // ... +/// } +/// ``` +macro_rules! ast_pass_pipeline { + ($opts:ident, $ast:ident; $($pass:path),*) => { + $( + let name = <$pass as $crate::ast_visitor::Visitor>::name(); + $crate::log_time!(<$pass as $crate::ast_visitor::Visitor>::do_pass($opts, &mut $ast)?, name); + )* + }; +} diff --git a/crates/filament/src/main.rs b/crates/filament/src/main.rs index 97e2339e..c49d7116 100644 --- a/crates/filament/src/main.rs +++ b/crates/filament/src/main.rs @@ -3,8 +3,10 @@ use calyx_opt::pass_manager::PassManager; use fil_gen::GenConfig; use fil_ir as ir; use filament::ir_passes::BuildDomination; -use filament::{cmdline, ir_passes as ip, resolver::Resolver}; -use filament::{log_pass, log_time, pass_pipeline}; +use filament::{ast_pass_pipeline, ir_pass_pipeline, log_pass, log_time}; +use filament::{ + ast_passes as ap, cmdline, ir_passes as ip, resolver::Resolver, +}; use serde::Deserialize; use std::collections::HashMap; use std::fs; @@ -37,22 +39,25 @@ fn run(opts: &cmdline::Opts) -> Result<(), u64> { .map(|path| toml::from_str(&fs::read_to_string(path).unwrap()).unwrap()) .unwrap_or_default(); - let ns = match Resolver::from(opts).parse_namespace() { - Ok(mut ns) => { - ns.toplevel = opts.toplevel.clone(); - ns.bindings = provided_bindings - .params - .get(opts.toplevel.as_str()) - .cloned() - .unwrap_or_default(); - ns - } + let mut ns = match Resolver::from(opts).parse_namespace() { + Ok(ns) => ns, Err(e) => { eprintln!("Error: {e:?}"); return Err(1); } }; + ast_pass_pipeline! { opts, ns; ap::TopLevel }; + + // Set the parameter bindings for the top-level component + if let Some(main) = ns.toplevel() { + ns.bindings = provided_bindings + .params + .get(main) + .cloned() + .unwrap_or_default(); + } + // Initialize the generator let mut gen_exec = if ns.requires_gen() { if opts.out_dir.is_none() @@ -71,7 +76,7 @@ fn run(opts: &cmdline::Opts) -> Result<(), u64> { // Transform AST to IR let mut ir = log_pass! { opts; ir::transform(ns)?, "astconv" }; - pass_pipeline! {opts, ir; + ir_pass_pipeline! {opts, ir; ip::BuildDomination, ip::TypeCheck, ip::IntervalCheck, @@ -79,13 +84,13 @@ fn run(opts: &cmdline::Opts) -> Result<(), u64> { ip::Assume } if !opts.unsafe_skip_discharge { - pass_pipeline! {opts, ir; ip::Discharge } + ir_pass_pipeline! {opts, ir; ip::Discharge } } - pass_pipeline! { opts, ir; + ir_pass_pipeline! { opts, ir; BuildDomination }; ir = log_pass! { opts; ip::Monomorphize::transform(&ir, &mut gen_exec), "monomorphize"}; - pass_pipeline! { opts, ir; + ir_pass_pipeline! { opts, ir; ip::FSMAttributes, ip::Simplify, ip::AssignCheck,