diff --git a/plugins/machine_learning/CMakeLists.txt b/plugins/machine_learning/CMakeLists.txt index 13d83256421..959c5f8cb3c 100644 --- a/plugins/machine_learning/CMakeLists.txt +++ b/plugins/machine_learning/CMakeLists.txt @@ -9,6 +9,6 @@ if(PL_MACHINE_LEARNING OR BUILD_ALL_PLUGINS) SHARED HEADER ${MACHINE_LEARNING_INC} SOURCES ${MACHINE_LEARNING_SRC} ${MACHINE_LEARNING_PYTHON_SRC} - LINK_LIBRARIES nlohmann_json::nlohmann_json + LINK_LIBRARIES nlohmann_json::nlohmann_json netlist_preprocessing ) endif() diff --git a/plugins/machine_learning/include/machine_learning/labels/gate_pair_label.h b/plugins/machine_learning/include/machine_learning/labels/gate_pair_label.h index dc5a48d1ff9..68c8a148b88 100644 --- a/plugins/machine_learning/include/machine_learning/labels/gate_pair_label.h +++ b/plugins/machine_learning/include/machine_learning/labels/gate_pair_label.h @@ -17,6 +17,7 @@ namespace hal /* Forward declarations */ class Gate; class Netlist; + enum class PinDirection : int; namespace machine_learning { @@ -34,12 +35,12 @@ namespace hal /** * @brief Maps word pairs to corresponding gates. */ - std::map, std::vector> word_to_gates; + std::map, std::vector> word_to_gates; /** * @brief Maps gates to associated word pairs. */ - std::map>> gate_to_words; + std::map>> gate_to_words; }; /** diff --git a/plugins/machine_learning/src/labels/gate_pair_label.cpp b/plugins/machine_learning/src/labels/gate_pair_label.cpp index ada1be4a263..b3f2430c7db 100644 --- a/plugins/machine_learning/src/labels/gate_pair_label.cpp +++ b/plugins/machine_learning/src/labels/gate_pair_label.cpp @@ -4,6 +4,7 @@ #include "hal_core/netlist/netlist.h" #include "hal_core/utilities/log.h" #include "nlohmann/json.hpp" +#include "netlist_preprocessing/netlist_preprocessing.h" #include @@ -17,7 +18,7 @@ namespace hal { MultiBitInformation calculate_multi_bit_information(const std::vector& gates) { - std::map, std::set>> word_to_gates_unsorted; + std::map, std::set>> word_to_gates_unsorted; for (const auto g : gates) { @@ -33,38 +34,59 @@ namespace hal // std::cout << "Trying to parse string: " << json_string << std::endl; // TODO catch exceptions and return result - const nlohmann::json j = nlohmann::json::parse(json_string); - const std::vector> index_information = j; + nlohmann::json j = nlohmann::json::parse(json_string); + std::vector index_information = j.get>(); // TODO remove - // if (!index_information.empty()) - // { - // std::cout << "For gate " << g->get_id() << " found " << std::get<0>(index_information.front()) << " - " << std::get<1>(index_information.front()) << std::endl; - // } + if (!index_information.empty()) + { + std::cout << "For gate " << g->get_id() << " found " < pin_to_min_distance; + for (const auto& [_name, _index, _origin, pin, _direction, distance] : index_information) { - word_to_gates_unsorted[{name, direction}].insert({index, g}); + if (const auto it = pin_to_min_distance.find(pin); it == pin_to_min_distance.end()) + { + pin_to_min_distance.insert({pin, distance}); + } + else + { + pin_to_min_distance.at(pin) = std::min(it->second, distance); + } + } + + + for (const auto& [name, index, _origin, pin, direction, distance] : index_information) + { + if (pin_to_min_distance.at(pin) == distance) + { + word_to_gates_unsorted[{name, direction, pin}].insert({index, g}); + } } } // 1. Sort out words with the same name by checking whether they contain duplicate indices // 2. Dedupe all words by only keeping one word/name_direction for each multi_bit_signal/vector of gates. - std::map, std::pair> gates_to_word; + std::map, std::tuple> gates_to_word; for (const auto& [name_direction, word] : word_to_gates_unsorted) { std::set indices; - std::vector gates; + std::set unique_gates; + std::vector gates; // TODO remove // std::cout << "Order Word: " << std::endl; - for (const auto& [index, gate] : word) + for (auto& [index, gate] : word) { // TODO remove // std::cout << index << std::endl; indices.insert(index); + unique_gates.insert(gate); + gates.push_back(gate); } @@ -72,7 +94,7 @@ namespace hal if (indices.size() != word.size()) { // TODO return result - log_error("machine_learning", "Found index double in word {}-{}!", name_direction.first, name_direction.second); + log_error("machine_learning", "Found index double in word {}-{} - !", std::get<0>(name_direction), enum_to_string(std::get<1>(name_direction)), std::get<2>(name_direction)); // TODO remove std::cout << "Insane Word: " << std::endl; @@ -84,25 +106,82 @@ namespace hal continue; } + if (unique_gates.size() != word.size()) + { + continue; + } + + if (unique_gates.size() <= 1) + { + continue; + } + + std::cout << "Word [" << word.size() << "] " << std::get<0>(name_direction) << " - " << std::get<1>(name_direction) << " - " << std::get<2>(name_direction) << " : " << std::endl; + for (const auto& [index, gate] : word) + { + std::cout << index << ": " << gate->get_id() << std::endl; + } + if (const auto it = gates_to_word.find(gates); it == gates_to_word.end()) { gates_to_word.insert({gates, name_direction}); } // NOTE could think about a priorization of shorter names or something similar - // else } MultiBitInformation mbi; - for (const auto& [gates, name_direction] : gates_to_word) + for (auto& [word_gates, name_direction] : gates_to_word) { - mbi.word_to_gates[name_direction] = gates; - for (const auto g : gates) + mbi.word_to_gates[name_direction] = word_gates; + for (const auto g : word_gates) { mbi.gate_to_words[g].push_back(name_direction); } } + // filter words for each gate: + // 1) For each direction only take the biggest word + // 2) From all remaining only take the smallest word + // std::map>> filtered_gate_to_words; + // for (const auto g : gates) + // { + // const auto it = mbi.gate_to_words.find(g); + // if (it == mbi.gate_to_words.end()) + // { + // continue; + // } + + // std::set sizes; + // for (const auto& w : it->second) + // { + // sizes.insert(mbi.word_to_gates.at(w).size()); + // } + + // std::vector> filtered_words; + // for (const auto& w : it->second) + // { + // if (mbi.word_to_gates.at(w).size() == *(sizes.begin())) + // { + // filtered_words.push_back(w); + // } + // } + + // filtered_gate_to_words.insert({g, filtered_words}); + // } + + // std::map, std::vector> filtered_word_to_gates; + // for (const auto& [g, words] : filtered_gate_to_words) + // { + // for (const auto& w : words) + // { + // filtered_word_to_gates[w].push_back(g); + // } + // } + + // mbi.gate_to_words = filtered_gate_to_words; + // mbi.word_to_gates = filtered_word_to_gates; + return mbi; } } // namespace @@ -126,7 +205,7 @@ namespace hal for (const auto& g : gates) { // positive labels - std::unordered_set pos_gates; + std::unordered_set pos_gates; if (mbi.gate_to_words.find(g) == mbi.gate_to_words.end()) { // gate is only in a group with itself @@ -138,8 +217,8 @@ namespace hal // add all gates that are part of at least one other signal group as positive pair for (const auto& name_direction : mbi.gate_to_words.at(g)) { - const auto& gates = mbi.word_to_gates.at(name_direction); - for (const auto g_i : gates) + const auto& word_gates = mbi.word_to_gates.at(name_direction); + for (const auto* g_i : word_gates) { if (g == g_i) { @@ -147,7 +226,7 @@ namespace hal } pairs.push_back({g, g_i}); - pos_gates.insert(g); + pos_gates.insert(g_i); } } } @@ -156,13 +235,15 @@ namespace hal const u64 pos_count = pos_gates.size(); const u64 neg_count = std::min(gates.size() - pos_count, pos_count); + std::cout << "Gate ID: " << g->get_id() << " " << pos_count << " vs. " << neg_count << std::endl; + std::set chosen_gates; for (u32 i = 0; i < neg_count; i++) { - const u32 start = std::rand() % lc.nl->get_gates().size(); - for (u32 idx = start; idx < start + lc.nl->get_gates().size(); idx++) + const u32 start = std::rand() % gates.size(); + for (u32 idx = start; idx < start + gates.size(); idx = (idx + 1) % gates.size()) { - const auto g_i = lc.nl->get_gates().at(idx % lc.nl->get_gates().size()); + const auto g_i = gates.at(idx % gates.size()); if (pos_gates.find(g_i) == pos_gates.end() && chosen_gates.find(g_i) == chosen_gates.end()) { pairs.push_back({g, g_i}); @@ -179,8 +260,54 @@ namespace hal std::vector SharedSignalGroup::calculate_label(LabelContext& lc, const Gate* g_a, const Gate* g_b) const { const auto& mbi = lc.get_multi_bit_information(); - const auto& words_a = mbi.gate_to_words.at(g_a); - const auto& words_b = mbi.gate_to_words.at(g_b); + + const auto it_a = mbi.gate_to_words.find(g_a); + if (it_a == mbi.gate_to_words.end()) + { + return {0}; + } + + const auto it_b = mbi.gate_to_words.find(g_b); + if (it_b == mbi.gate_to_words.end()) + { + return {0}; + } + + const auto& words_a = it_a->second; + const auto& words_b = it_b->second; + + // // only consider the smallest words a gate is part of + // std::set sizes_a; + // std::set sizes_b; + + // for (const auto& w_a : words_a) + // { + // sizes_a.insert(mbi.word_to_gates.at(w_a).size()); + // } + + // for (const auto& w_b : words_b) + // { + // sizes_b.insert(mbi.word_to_gates.at(w_b).size()); + // } + + // std::vector> filtered_words_a; + // std::vector> filtered_words_b; + + // for (const auto& w_a : words_a) + // { + // if (mbi.word_to_gates.at(w_a).size() == *(sizes_a.begin())) + // { + // filtered_words_a.push_back(w_a); + // } + // } + + // for (const auto& w_b : words_b) + // { + // if (mbi.word_to_gates.at(w_b).size() == *(sizes_b.begin())) + // { + // filtered_words_b.push_back(w_b); + // } + // } for (const auto& wa : words_a) { diff --git a/plugins/machine_learning/src/plugin_machine_learning.cpp b/plugins/machine_learning/src/plugin_machine_learning.cpp index 1fdc5dc178e..e17c63d73b0 100644 --- a/plugins/machine_learning/src/plugin_machine_learning.cpp +++ b/plugins/machine_learning/src/plugin_machine_learning.cpp @@ -24,6 +24,6 @@ namespace hal std::set MachineLearningPlugin::get_dependencies() const { - return {}; + return {"netlist_preprocessing"}; } } // namespace hal diff --git a/plugins/netlist_preprocessing/include/netlist_preprocessing/netlist_preprocessing.h b/plugins/netlist_preprocessing/include/netlist_preprocessing/netlist_preprocessing.h index 8779448021a..18668cefb77 100644 --- a/plugins/netlist_preprocessing/include/netlist_preprocessing/netlist_preprocessing.h +++ b/plugins/netlist_preprocessing/include/netlist_preprocessing/netlist_preprocessing.h @@ -27,6 +27,7 @@ #include "hal_core/defines.h" #include "hal_core/utilities/result.h" +#include "nlohmann/json.hpp" #include #include @@ -40,6 +41,8 @@ namespace hal class Module; class Net; + enum class PinDirection; + namespace netlist_preprocessing { /** @@ -152,6 +155,46 @@ namespace hal */ Result simplify_lut_inits(Netlist* nl); + + /** + * Represents an identifier with an associated index and additional metadata, used for reconstructing and annotating names and indices + * for flip flops in synthesized netlists based on input and output net names as well as gate names. + * + * This struct is designed specifically for use with synthesized netlists. By analyzing net and gate names, we attempt to reconstruct a + * multi bit word and index for each flip flop. + * + * The reconstructed identifiers, stored as `indexed_identifier` instances, are added to the gate data container in the netlist. + * + * Members: + * - identifier: The reconstructed name of the flip flop. + * - index: The index number associated with the identifier, if part of a multi-bit signal. + * - origin: The original source or scope of the identifier. + * - pin: The specific pin associated with the identifier. + * - direction: The direction of the pin (e.g., INPUT, OUTPUT, INOUT). + * - distance: The distance or offset, representing additional structural information. + */ + struct indexed_identifier + { + indexed_identifier(); + indexed_identifier(const std::string& identifier, const u32 index, const std::string& origin, const std::string& pin, const PinDirection& direction, const u32 distance); + + std::string identifier; /**< The reconstructed name of the multi-bit words. */ + u32 index; /**< The index associated with the identifier, used for multi-bit signals. */ + std::string origin; /**< The origin or source of the identifier within the netlist (either "gate_name" or "net_name"). */ + std::string pin; /**< The pin name associated with this identifier. */ + PinDirection direction; /**< Direction of the pin. */ + u32 distance; /**< Distance to merged net which name this index was derived from. */ + + // Overload < operator for strict weak ordering + bool operator<(const indexed_identifier& other) const; + }; + + // Serialization function for indexed_identifier as a list of values + void to_json(nlohmann::json& j, const indexed_identifier& id); + + // Deserialization function for indexed_identifier from a list of values + void from_json(const nlohmann::json& j, indexed_identifier& id); + /** * Tries to reconstruct a name and index for each flip flop that was part of a multi-bit wire in the verilog code. * This is NOT a general netlist reverse engineering algorithm and ONLY works on synthesized netlists with names annotated by the synthesizer. diff --git a/plugins/netlist_preprocessing/src/netlist_preprocessing.cpp b/plugins/netlist_preprocessing/src/netlist_preprocessing.cpp index d6e2830b2f9..1deedf252b8 100644 --- a/plugins/netlist_preprocessing/src/netlist_preprocessing.cpp +++ b/plugins/netlist_preprocessing/src/netlist_preprocessing.cpp @@ -11,10 +11,9 @@ #include "hal_core/netlist/net.h" #include "hal_core/netlist/netlist.h" #include "hal_core/utilities/token_stream.h" -#include "nlohmann/json.hpp" -#include "rapidjson/document.h" #include "resynthesis/resynthesis.h" #include "z3_utils/netlist_comparison.h" +#include "z3_utils/subgraph_function_generation.h" #include #include @@ -924,6 +923,10 @@ namespace hal } std::vector> equality_classes; + z3::context ctx; + + std::map net_cache; + std::map, BooleanFunction> gate_cache; for (const auto& [_fingerprint, nets] : fingerprint_to_nets) { @@ -951,14 +954,36 @@ namespace hal for (const auto& m : current_candidate_nets) { - auto comp_res = z3_utils::compare_nets(nl, nl, n, m); - if (comp_res.is_error()) + + const auto bf_n = z3_utils::get_subgraph_z3_function(all_comb_gates_vec, n, ctx, net_cache, gate_cache); + const auto bf_m = z3_utils::get_subgraph_z3_function(all_comb_gates_vec, m, ctx, net_cache, gate_cache); + + if (bf_n.is_error()) { - return ERR_APPEND(comp_res.get_error(), - "Unable to remove redundant logic trees: failed to compare net " + n->get_name() + " with ID " + std::to_string(n->get_id()) + " with net " - + m->get_name() + " with ID " + std::to_string(m->get_id())); + return ERR_APPEND(bf_n.get_error(), + "Unable to remove redundant logic trees: failed to build Boolean function for net " + n->get_name() + " with ID " + std::to_string(n->get_id())); } - const auto are_equal = comp_res.get(); + + if (bf_m.is_error()) + { + return ERR_APPEND(bf_m.get_error(), + "Unable to remove redundant logic trees: failed to build Boolean function for net " + m->get_name() + " with ID " + std::to_string(m->get_id())); + } + + z3::solver s(ctx); + s.add(bf_n.get() != bf_m.get()); + const auto res = s.check(); + + const bool are_equal = (res == z3::unsat); + + // auto comp_res = z3_utils::compare_nets(nl, nl, n, m); + // if (comp_res.is_error()) + // { + // return ERR_APPEND(comp_res.get_error(), + // "Unable to remove redundant logic trees: failed to compare net " + n->get_name() + " with ID " + std::to_string(n->get_id()) + " with net " + // + m->get_name() + " with ID " + std::to_string(m->get_id())); + // } + // const auto are_equal = comp_res.get(); if (are_equal) { @@ -1951,20 +1976,41 @@ namespace hal return OK(num_inits); } - namespace + indexed_identifier::indexed_identifier() { - struct indexed_identifier - { - indexed_identifier(const std::string& identifier, const u32 index, const std::string& origin, const PinDirection& direction) : identifier{identifier}, index{index}, origin{origin}, direction{direction} - { - } + } - std::string identifier; - u32 index; - std::string origin; - PinDirection direction; - }; + indexed_identifier::indexed_identifier(const std::string& identifier, const u32 index, const std::string& origin, const std::string& pin, const PinDirection& direction, const u32 distance) : identifier{identifier}, index{index}, origin{origin}, pin{pin}, direction{direction}, distance{distance} + { + } + + // Overload < operator for strict weak ordering + bool indexed_identifier::operator<(const indexed_identifier& other) const + { + return std::tie(identifier, index, origin, pin, direction, distance) < + std::tie(other.identifier, other.index, other.origin, other.pin, other.direction, other.distance); + } + + // Serialization function for indexed_identifier as a list of values + void to_json(nlohmann::json& j, const indexed_identifier& id) + { + j = nlohmann::json{ id.identifier, id.index, id.origin, id.pin, enum_to_string(id.direction), id.distance }; + } + + // Deserialization function for indexed_identifier from a list of values + void from_json(const nlohmann::json& j, indexed_identifier& id) + { + j.at(0).get_to(id.identifier); + j.at(1).get_to(id.index); + j.at(2).get_to(id.origin); + j.at(3).get_to(id.pin); + const std::string direction_string = j.at(4).get(); + id.direction = enum_from_string(direction_string); + j.at(5).get_to(id.distance); + } + namespace + { // TODO when the verilog parser changes are merged into the master this will no longer be needed const std::string hal_instance_index_pattern = "__\\[(\\d+)\\]__"; const std::string hal_instance_index_pattern_reverse = "(\\d+)"; @@ -2005,7 +2051,7 @@ namespace hal const std::string gate_index_pattern = "\\[(\\d+)\\]"; // Extracts an index from a string by taking the last integer enclosed by parentheses - std::optional extract_index(const std::string& name, const std::string& index_pattern, const std::string& origin, const PinDirection& direction) + std::optional extract_index(const std::string& name, const std::string& index_pattern, const std::string& origin, const std::string& pin, const PinDirection& direction, const u32 distance) { std::regex re(index_pattern); @@ -2032,22 +2078,14 @@ namespace hal auto identifier_name = name; identifier_name = identifier_name.replace(name.rfind(found_match), found_match.size(), ""); - return std::optional{{identifier_name, last_index.value(), origin, direction}}; + return std::optional{{identifier_name, last_index.value(), origin, pin, direction, distance}}; } // annotate all found identifiers to a gate bool annotate_indexed_identifiers(Gate* gate, const std::vector& identifiers) { - std::string json_identifier_str = "[" - + utils::join(", ", - identifiers, - [](const auto& i) { - return std::string("[") + '"' + i.identifier + '"' + ", " + std::to_string(i.index) + ", " + '"' + i.origin + '"' + "," + '"' - + enum_to_string(i.direction) + '"' + "]"; - }) - + "]"; - - return gate->set_data("preprocessing_information", "multi_bit_indexed_identifiers", "string", json_identifier_str); + nlohmann::json j = identifiers; // Convert the vector to JSON + return gate->set_data("preprocessing_information", "multi_bit_indexed_identifiers", "string", j.dump()); } // search for a net that connects to the gate at a pin of a specific type and tries to reconstruct an indexed identifier from its name or form a name of its merged wires @@ -2071,44 +2109,47 @@ namespace hal continue; } - // 1) search the net name itself - const auto net_name_index = extract_index(typed_net->get_name(), net_index_pattern, "net_name", pin->get_direction()); - if (net_name_index.has_value()) - { - found_identfiers.push_back(net_name_index.value()); - } - - // 2) search all the names of the wires that where merged into this net - if (!typed_net->has_data("parser_annotation", "merged_nets")) - { - continue; - } - - const auto all_merged_nets_str = std::get<1>(typed_net->get_data("parser_annotation", "merged_nets")); - - if (all_merged_nets_str.empty()) + std::vector> merged_nets; + + // 1) search all the names of the wires that where merged into this net + if (typed_net->has_data("parser_annotation", "merged_nets")) { - continue; - } - - // parse json list of merged net names - rapidjson::Document doc; - doc.Parse(all_merged_nets_str.c_str()); + const auto all_merged_nets_str = std::get<1>(typed_net->get_data("parser_annotation", "merged_nets")); - for (u32 i = 0; i < doc.GetArray().Size(); i++) - { - const auto list = doc[i].GetArray(); - for (u32 j = 0; j < list.Size(); j++) + if (all_merged_nets_str.empty()) { - const auto merged_wire_name = list[j].GetString(); + nlohmann::json merged_nets_json = nlohmann::json::parse(all_merged_nets_str); + merged_nets = merged_nets_json.get>>(); - const auto merged_wire_name_index = extract_index(merged_wire_name, net_index_pattern, "net_name", pin->get_direction()); - if (merged_wire_name_index.has_value()) + // the order of the merged nets starts with nets closest to the destination of the net (which is connected to an input pin) + if (pin->get_direction() == PinDirection::output) { - found_identfiers.push_back(merged_wire_name_index.value()); + std::reverse(merged_nets.begin(), merged_nets.end()); + } + + for (u32 i = 0; i < merged_nets.size(); i++) + { + for (u32 j = 0; j < merged_nets.at(i).size(); j++) + { + const auto merged_wire_name = merged_nets.at(i).at(j); + + const auto merged_wire_name_index = extract_index(merged_wire_name, net_index_pattern, "net_name", pin->get_name(), pin->get_direction(), i + 1); + if (merged_wire_name_index.has_value()) + { + found_identfiers.push_back(merged_wire_name_index.value()); + } + } } } } + + // 2) search the net name itself + const u32 distance = (pin->get_direction() == PinDirection::output) ? merged_nets.size() + 1 : 0; + const auto net_name_index = extract_index(typed_net->get_name(), net_index_pattern, "net_name", pin->get_name(), pin->get_direction(), distance); + if (net_name_index.has_value()) + { + found_identfiers.push_back(net_name_index.value()); + } } return found_identfiers; @@ -2124,7 +2165,7 @@ namespace hal // 1) Check whether the ff gate already has an index annotated in its gate name const auto cleaned_gate_name = replace_hal_instance_index(ff->get_name()); - const auto gate_name_index = extract_index(cleaned_gate_name, gate_index_pattern, "gate_name", PinDirection::none); + const auto gate_name_index = extract_index(cleaned_gate_name, gate_index_pattern, "gate_name", "", PinDirection::none, 0); if (gate_name_index.has_value()) { @@ -2159,13 +2200,13 @@ namespace hal for (const auto& pin : nl->get_top_module()->get_pins()) { - auto reconstruct = extract_index(pin->get_name(), net_index_pattern, "", pin->get_direction()); + auto reconstruct = extract_index(pin->get_name(), net_index_pattern, "pin_name", pin->get_name(), pin->get_direction(), 0); if (!reconstruct.has_value()) { continue; } - auto [pg_name, index, _origin, _direction] = reconstruct.value(); + auto [pg_name, index, _origin, _pin, _direction, _distance] = reconstruct.value(); pg_name_to_indexed_pins[pg_name][index].push_back(pin); } diff --git a/plugins/z3_utils/include/z3_utils/subgraph_function_generation.h b/plugins/z3_utils/include/z3_utils/subgraph_function_generation.h index 02f20a9a8e3..92924146efc 100644 --- a/plugins/z3_utils/include/z3_utils/subgraph_function_generation.h +++ b/plugins/z3_utils/include/z3_utils/subgraph_function_generation.h @@ -28,11 +28,14 @@ #include "hal_core/utilities/result.h" #include "z3++.h" +#include #include namespace hal { + class BooleanFunction; class Gate; + class GatePin; class Net; namespace z3_utils @@ -48,6 +51,19 @@ namespace hal */ Result get_subgraph_z3_function(const std::vector& subgraph_gates, const Net* subgraph_output, z3::context& ctx); + + /** + * @brief Get the z3 expression representation of a combined Boolean function of a subgraph of combinational gates starting at the source of the provided subgraph output net. + * + * The variables of the resulting Boolean function are created from the subgraph input nets using `BooleanFunctionNetDecorator::get_boolean_variable`. + * + * @param[in] subgraph_gates - The gates making up the subgraph to consider. + * @param[in] subgraph_output - The subgraph oputput net for which to generate the Boolean function. + * @return The the z3 expression representation of combined Boolean function of the subgraph on success, an error otherwise. + */ + Result get_subgraph_z3_function(const std::vector& subgraph_gates, const Net* subgraph_output, z3::context& ctx, std::map& net_cache, + std::map, BooleanFunction>& gate_cache); + /** * @brief Get the z3 expression representations of combined Boolean functions of a subgraph of combinational gates starting at the sources of the provided subgraph output nets. * diff --git a/plugins/z3_utils/src/netlist_comparison.cpp b/plugins/z3_utils/src/netlist_comparison.cpp index 1a0aa47059d..412769ea35e 100644 --- a/plugins/z3_utils/src/netlist_comparison.cpp +++ b/plugins/z3_utils/src/netlist_comparison.cpp @@ -490,6 +490,7 @@ namespace hal } } } + for (const auto& gate_b : seq_gates_b) { gate_name_to_gate_b[gate_b->get_name()] = gate_b; @@ -516,6 +517,7 @@ namespace hal return OK(false); } } + for (const auto& [gate_b_name, gate_b] : gate_name_to_gate_b) { if (const auto gate_a_it = gate_name_to_gate_a.find(gate_b_name); gate_a_it == gate_name_to_gate_a.end()) @@ -548,24 +550,37 @@ namespace hal const auto out_pins_a = netlist_a->get_top_module()->get_output_pin_names(); const auto out_pins_b = netlist_b->get_top_module()->get_output_pin_names(); - auto all_out_pins = out_pins_a; - all_out_pins.insert(all_out_pins.end(), out_pins_b.begin(), out_pins_b.end()); + std::set all_out_pins = utils::to_set(out_pins_a); + all_out_pins.insert(out_pins_b.begin(), out_pins_b.end()); + + // TODO remove debug print + // std::cout << "PINS A: " << std::endl; + // for (const auto& o : out_pins_a) + // { + // std::cout << "\t" << o << std::endl; + // } + // std::cout << "PINS B: " << std::endl; + // for (const auto& o : out_pins_b) + // { + // std::cout << "\t" << o << std::endl; + // } + for (const auto& pin : all_out_pins) { - auto it_a = std::find(out_pins_a.begin(), out_pins_a.end(), pin); - if (it_a == out_pins_a.end()) + const auto pin_a = netlist_a->get_top_module()->get_pin_by_name(pin); + if (pin_a == nullptr) { log_warning("z3_utils", - "netlist a with ID {} and netlist b with ID {} might not be equal: netlist a has output pin {} that does not exist in netlist b!", + "netlist a with ID {} and netlist b with ID {} might not be equal: netlist b has output pin {} that does not exist in netlist a!", netlist_a->get_id(), netlist_b->get_id(), pin); continue; } - auto it_b = std::find(out_pins_b.begin(), out_pins_b.end(), pin); - if (it_b == out_pins_b.end()) + const auto pin_b = netlist_b->get_top_module()->get_pin_by_name(pin); + if (pin_b == nullptr) { log_warning("z3_utils", "netlist a with ID {} and netlist b with ID {} might not be equal: netlist a has output pin {} that does not exist in netlist b!", @@ -575,8 +590,9 @@ namespace hal continue; } - Net* net_a = netlist_a->get_top_module()->get_pin_by_name(pin)->get_net(); - Net* net_b = netlist_b->get_top_module()->get_pin_by_name(pin)->get_net(); + Net* net_a = pin_a->get_net(); + Net* net_b = pin_b->get_net(); + to_compare.insert({net_a, net_b}); } diff --git a/plugins/z3_utils/src/subgraph_function_generation.cpp b/plugins/z3_utils/src/subgraph_function_generation.cpp index 68a51375a9e..386088f482d 100644 --- a/plugins/z3_utils/src/subgraph_function_generation.cpp +++ b/plugins/z3_utils/src/subgraph_function_generation.cpp @@ -145,6 +145,12 @@ namespace hal } // namespace + Result get_subgraph_z3_function(const std::vector& subgraph_gates, const Net* subgraph_output, z3::context& ctx, std::map& net_cache, + std::map, BooleanFunction>& gate_cache) + { + return get_subgraph_z3_function_internal(subgraph_gates, subgraph_output, ctx, net_cache, gate_cache); + } + Result get_subgraph_z3_function(const std::vector& subgraph_gates, const Net* subgraph_output, z3::context& ctx) { std::map net_cache;