diff --git a/include/hal_core/netlist/decorators/netlist_abstraction_decorator.h b/include/hal_core/netlist/decorators/netlist_abstraction_decorator.h index d467e56ed15..d56c194a445 100644 --- a/include/hal_core/netlist/decorators/netlist_abstraction_decorator.h +++ b/include/hal_core/netlist/decorators/netlist_abstraction_decorator.h @@ -47,6 +47,8 @@ namespace hal struct NETLIST_API NetlistAbstraction { public: + NetlistAbstraction(NetlistAbstraction&& other) = default; + /** * @brief Creates a `NetlistAbstraction` from a set of gates. * @@ -56,11 +58,11 @@ namespace hal * @param[in] exit_endpoint_filter - Filter condition to stop traversal on a fan-in/out endpoint. * @param[in] entry_endpoint_filter - Filter condition to stop traversal on a successor/predecessor endpoint. */ - static Result create(const Netlist* netlist, - const std::vector& gates, - const bool include_all_netlist_gates = false, - const std::function& exit_endpoint_filter = nullptr, - const std::function& entry_endpoint_filter = nullptr); + static Result> create(const Netlist* netlist, + const std::vector& gates, + const bool include_all_netlist_gates = false, + const std::function& exit_endpoint_filter = nullptr, + const std::function& entry_endpoint_filter = nullptr); /** * @brief Gets the predecessors of a gate within the abstraction. diff --git a/plugins/machine_learning/include/machine_learning/types.h b/plugins/machine_learning/include/machine_learning/types.h index 1583961bc17..3c0376f34b1 100644 --- a/plugins/machine_learning/include/machine_learning/types.h +++ b/plugins/machine_learning/include/machine_learning/types.h @@ -3,6 +3,9 @@ #include "hal_core/defines.h" #include "hal_core/netlist/decorators/netlist_abstraction_decorator.h" +#include +#include +#include #include #include @@ -52,10 +55,16 @@ namespace hal const u32 num_threads; private: - std::optional m_sequential_abstraction; - std::optional m_original_abstraction; - std::optional> m_possible_gate_type_properties; - std::optional m_mbi; + std::shared_ptr m_mbi{nullptr}; + std::shared_ptr m_sequential_abstraction{nullptr}; + std::shared_ptr m_original_abstraction{nullptr}; + std::shared_ptr> m_possible_gate_type_properties{nullptr}; + + // Mutexes for thread-safe initialization + std::mutex m_mbi_mutex; + std::mutex m_sequential_abstraction_mutex; + std::mutex m_original_abstraction_mutex; + std::mutex m_possible_gate_type_properties_mutex; }; enum GraphDirection diff --git a/plugins/machine_learning/src/features/gate_feature.cpp b/plugins/machine_learning/src/features/gate_feature.cpp index 83891fe05b9..117ed681711 100644 --- a/plugins/machine_learning/src/features/gate_feature.cpp +++ b/plugins/machine_learning/src/features/gate_feature.cpp @@ -10,7 +10,7 @@ #include #define MAX_DISTANCE 255 -#define PROGRESS_BAR +// #define PROGRESS_BAR namespace hal { diff --git a/plugins/machine_learning/src/features/gate_pair_feature.cpp b/plugins/machine_learning/src/features/gate_pair_feature.cpp index 6d6dc72cc05..8c911cb269d 100644 --- a/plugins/machine_learning/src/features/gate_pair_feature.cpp +++ b/plugins/machine_learning/src/features/gate_pair_feature.cpp @@ -7,7 +7,7 @@ #include "machine_learning/features/gate_pair_feature.h" #define MAX_DISTANCE 255 -#define PROGRESS_BAR +// #define PROGRESS_BAR namespace hal { diff --git a/plugins/machine_learning/src/graph_neural_network.cpp b/plugins/machine_learning/src/graph_neural_network.cpp index 7298e5d04a4..82b6422489b 100644 --- a/plugins/machine_learning/src/graph_neural_network.cpp +++ b/plugins/machine_learning/src/graph_neural_network.cpp @@ -89,7 +89,7 @@ namespace hal { return ERR_APPEND(sequential_abstraction_res.get_error(), "cannot get sequential netlist abstraction for gate feature context: failed to build abstraction."); } - const auto sequential_abstraction = sequential_abstraction_res.get(); + const auto& sequential_abstraction = sequential_abstraction_res.get(); // edge list std::vector sources; @@ -100,7 +100,7 @@ namespace hal const u32 g_idx = gate_to_idx.at(g); if (dir == GraphDirection::directed) { - const auto unique_predecessors = sequential_abstraction.get_unique_predecessors(g); + const auto unique_predecessors = sequential_abstraction->get_unique_predecessors(g); if (unique_predecessors.is_error()) { return ERR_APPEND(unique_predecessors.get_error(), @@ -115,7 +115,7 @@ namespace hal if (dir == GraphDirection::undirected) { - const auto unique_successors = sequential_abstraction.get_unique_successors(g); + const auto unique_successors = sequential_abstraction->get_unique_successors(g); if (unique_successors.is_error()) { return ERR_APPEND(unique_successors.get_error(), diff --git a/plugins/machine_learning/src/types.cpp b/plugins/machine_learning/src/types.cpp index 64fbe4e422d..55c702ff301 100644 --- a/plugins/machine_learning/src/types.cpp +++ b/plugins/machine_learning/src/types.cpp @@ -272,96 +272,134 @@ namespace hal const MultiBitInformation& Context::get_multi_bit_information() { - if (!m_mbi.has_value()) + auto mbi = std::atomic_load_explicit(&m_mbi, std::memory_order_acquire); + if (mbi) { - const auto seq_gates = nl->get_gates([](const auto* g) { return g->get_type()->has_property(GateTypeProperty::sequential); }); - m_mbi = calculate_multi_bit_information(seq_gates); + return *mbi; } + else + { + std::lock_guard lock(m_mbi_mutex); + mbi = std::atomic_load_explicit(&m_mbi, std::memory_order_acquire); + if (mbi) + { + return *mbi; + } - return m_mbi.value(); + auto new_mbi = std::make_shared(); + const auto seq_gates = nl->get_gates([](const auto* g) { return g->get_type()->has_property(GateTypeProperty::sequential); }); + *new_mbi = calculate_multi_bit_information(seq_gates); + + std::atomic_store_explicit(&m_mbi, new_mbi, std::memory_order_release); + + return *new_mbi; + } } const Result Context::get_sequential_abstraction() { - if (!m_sequential_abstraction.has_value()) + auto abstraction = std::atomic_load_explicit(&m_sequential_abstraction, std::memory_order_acquire); + if (abstraction) { + return OK(abstraction.get()); + } + else + { + std::lock_guard lock(m_sequential_abstraction_mutex); + // Double-check after acquiring the lock + abstraction = std::atomic_load_explicit(&m_sequential_abstraction, std::memory_order_acquire); + if (abstraction) + { + return OK(abstraction.get()); + } + const auto seq_gates = nl->get_gates([](const auto* g) { return g->get_type()->has_property(GateTypeProperty::sequential); }); - const std::vector forbidden_pins = { - PinType::clock, /*PinType::done, PinType::error, PinType::error_detection,*/ /*PinType::none,*/ PinType::ground, PinType::power /*, PinType::status*/}; + const std::vector forbidden_pins = {PinType::clock, PinType::ground, PinType::power}; - const auto endpoint_filter = [forbidden_pins](const auto* ep, const auto& _d) { - UNUSED(_d); + const auto endpoint_filter = [forbidden_pins](const auto* ep, const auto&) { return std::find(forbidden_pins.begin(), forbidden_pins.end(), ep->get_pin()->get_type()) == forbidden_pins.end(); }; - const auto sequential_abstraction_res = NetlistAbstraction::create(nl, seq_gates, true, endpoint_filter, endpoint_filter); + auto sequential_abstraction_res = NetlistAbstraction::create(nl, seq_gates, true, endpoint_filter, endpoint_filter); if (sequential_abstraction_res.is_error()) { - return ERR_APPEND(sequential_abstraction_res.get_error(), "cannot get sequential netlist abstraction for gate feature context: failed to build abstraction."); + return ERR_APPEND(sequential_abstraction_res.get_error(), "Cannot get sequential netlist abstraction: failed to build abstraction."); } - m_sequential_abstraction = sequential_abstraction_res.get(); + auto new_abstraction = sequential_abstraction_res.get(); - // TODO remove debug print - // std::cout << "Built abstraction" << std::endl; - } + std::atomic_store_explicit(&m_sequential_abstraction, new_abstraction, std::memory_order_release); - return OK(&m_sequential_abstraction.value()); + return OK(m_sequential_abstraction.get()); + } } const Result Context::get_original_abstraction() { - if (!m_original_abstraction.has_value()) + auto abstraction = std::atomic_load_explicit(&m_original_abstraction, std::memory_order_acquire); + if (abstraction) { - // const std::vector forbidden_pins = { - // PinType::clock, /*PinType::done, PinType::error, PinType::error_detection,*/ /*PinType::none,*/ PinType::ground, PinType::power /*, PinType::status*/}; - - // const auto endpoint_filter = [forbidden_pins](const auto* ep, const auto& _d) { - // UNUSED(_d); - // return std::find(forbidden_pins.begin(), forbidden_pins.end(), ep->get_pin()->get_type()) == forbidden_pins.end(); - // }; + return OK(abstraction.get()); + } + else + { + std::lock_guard lock(m_original_abstraction_mutex); + // Double-check after acquiring the lock + abstraction = std::atomic_load_explicit(&m_original_abstraction, std::memory_order_acquire); + if (abstraction) + { + return OK(abstraction.get()); + } - const auto original_abstraction_res = NetlistAbstraction::create(nl, nl->get_gates(), true, nullptr, nullptr); + auto original_abstraction_res = NetlistAbstraction::create(nl, nl->get_gates(), true, nullptr, nullptr); if (original_abstraction_res.is_error()) { - return ERR_APPEND(original_abstraction_res.get_error(), "cannot get original netlist abstraction for gate feature context: failed to build abstraction."); + return ERR_APPEND(original_abstraction_res.get_error(), "Cannot get original netlist abstraction: failed to build abstraction."); } - m_original_abstraction = original_abstraction_res.get(); + auto new_abstraction = original_abstraction_res.get(); - // TODO remove debug print - // std::cout << "Built abstraction" << std::endl; - } + std::atomic_store_explicit(&m_original_abstraction, new_abstraction, std::memory_order_release); - return OK(&m_original_abstraction.value()); + return OK(m_original_abstraction.get()); + } } const std::vector& Context::get_possible_gate_type_properties() { - if (!m_possible_gate_type_properties.has_value()) + auto properties = std::atomic_load_explicit(&m_possible_gate_type_properties, std::memory_order_acquire); + if (properties) + { + return *properties; + } + else { - std::set properties; + std::lock_guard lock(m_possible_gate_type_properties_mutex); + // Double-check after acquiring the lock + properties = std::atomic_load_explicit(&m_possible_gate_type_properties, std::memory_order_acquire); + if (properties) + { + return *properties; + } + + std::set property_set; for (const auto& [_name, gt] : nl->get_gate_library()->get_gate_types()) { - const auto gt_properties = gt->get_properties(); - properties.insert(gt_properties.begin(), gt_properties.end()); + const auto& gt_properties = gt->get_properties(); + property_set.insert(gt_properties.begin(), gt_properties.end()); } - // for (auto& [gtp, _name] : EnumStrings::data) - // { - // UNUSED(_name); - // properties.insert(gtp); - // } + auto properties_vec = std::make_shared>(property_set.begin(), property_set.end()); - auto properties_vec = utils::to_vector(properties); - // sort alphabetically - std::sort(properties_vec.begin(), properties_vec.end(), [](const auto& a, const auto& b) { return enum_to_string(a) < enum_to_string(b); }); - m_possible_gate_type_properties = properties_vec; - } + // Sort alphabetically + std::sort(properties_vec->begin(), properties_vec->end(), [](const auto& a, const auto& b) { return enum_to_string(a) < enum_to_string(b); }); - return m_possible_gate_type_properties.value(); + std::atomic_store_explicit(&m_possible_gate_type_properties, properties_vec, std::memory_order_release); + + return *properties_vec; + } } } // namespace machine_learning } // namespace hal \ No newline at end of file diff --git a/src/netlist/decorators/netlist_abstraction_decorator.cpp b/src/netlist/decorators/netlist_abstraction_decorator.cpp index fbb996000ba..74a610659ce 100644 --- a/src/netlist/decorators/netlist_abstraction_decorator.cpp +++ b/src/netlist/decorators/netlist_abstraction_decorator.cpp @@ -7,20 +7,25 @@ namespace hal { - Result NetlistAbstraction::create(const Netlist* netlist, - const std::vector& gates, - const bool include_all_netlist_gates, - const std::function& exit_endpoint_filter, - const std::function& entry_endpoint_filter) + Result> NetlistAbstraction::create(const Netlist* netlist, + const std::vector& gates, + const bool include_all_netlist_gates, + const std::function& exit_endpoint_filter, + const std::function& entry_endpoint_filter) { const auto nl_trav_dec = NetlistTraversalDecorator(*netlist); // transform gates into set to check fast if a gate is part of abstraction - const auto gates_set = utils::to_unordered_set(gates); + const auto gates_set = utils::to_unordered_set(gates); + const auto& included_gates = include_all_netlist_gates ? netlist->get_gates() : gates; - auto new_abstraction = NetlistAbstraction(); + auto new_abstraction = std::shared_ptr(new NetlistAbstraction()); + const u32 approximated_endpoint_count = included_gates.size() * 8; + new_abstraction->m_successors.reserve(approximated_endpoint_count); + new_abstraction->m_predecessors.reserve(approximated_endpoint_count); + new_abstraction->m_global_output_successors.reserve(approximated_endpoint_count); + new_abstraction->m_global_input_predecessors.reserve(approximated_endpoint_count); - const auto& included_gates = include_all_netlist_gates ? netlist->get_gates() : gates; for (const Gate* gate : included_gates) { // TODO remove debug print @@ -29,7 +34,7 @@ namespace hal // gather all successors for (Endpoint* ep_out : gate->get_fan_out_endpoints()) { - new_abstraction.m_successors.insert({ep_out, {}}); + new_abstraction->m_successors.insert({ep_out, {}}); const auto successors = nl_trav_dec.get_next_matching_endpoints( ep_out, true, @@ -46,14 +51,14 @@ namespace hal for (Endpoint* ep : successors.get()) { - new_abstraction.m_successors.at(ep_out).push_back(ep); + new_abstraction->m_successors.at(ep_out).push_back(ep); } } // gather all global output succesors for (Endpoint* ep_out : gate->get_fan_out_endpoints()) { - new_abstraction.m_global_output_successors.insert({ep_out, {}}); + new_abstraction->m_global_output_successors.insert({ep_out, {}}); const auto destinations = nl_trav_dec.get_next_matching_endpoints( ep_out, true, [](const auto& ep) { return ep->is_source_pin() && ep->get_net()->is_global_output_net(); }, false, exit_endpoint_filter, entry_endpoint_filter); @@ -66,14 +71,14 @@ namespace hal for (const auto* ep : destinations.get()) { - new_abstraction.m_global_output_successors.at(ep_out).push_back({ep->get_net()}); + new_abstraction->m_global_output_successors.at(ep_out).push_back({ep->get_net()}); } } // gather all predecessors for (Endpoint* ep_in : gate->get_fan_in_endpoints()) { - new_abstraction.m_predecessors.insert({ep_in, {}}); + new_abstraction->m_predecessors.insert({ep_in, {}}); const auto predecessors = nl_trav_dec.get_next_matching_endpoints(ep_in, false, [gates_set](const auto& ep) { return ep->is_source_pin() && gates_set.find(ep->get_gate()) != gates_set.end(); }); @@ -86,14 +91,14 @@ namespace hal for (Endpoint* ep : predecessors.get()) { - new_abstraction.m_predecessors.at(ep_in).push_back(ep); + new_abstraction->m_predecessors.at(ep_in).push_back(ep); } } // gather all global input predecessors for (Endpoint* ep_in : gate->get_fan_in_endpoints()) { - new_abstraction.m_global_input_predecessors.insert({ep_in, {}}); + new_abstraction->m_global_input_predecessors.insert({ep_in, {}}); const auto predecessors = nl_trav_dec.get_next_matching_endpoints(ep_in, false, [](const auto& ep) { return ep->is_destination_pin() && ep->get_net()->is_global_input_net(); }); @@ -105,7 +110,7 @@ namespace hal for (const auto* ep : predecessors.get()) { - new_abstraction.m_global_input_predecessors.at(ep_in).push_back({ep->get_net()}); + new_abstraction->m_global_input_predecessors.at(ep_in).push_back({ep->get_net()}); } } } diff --git a/src/python_bindings/bindings/netlist_abstraction_decorator.cpp b/src/python_bindings/bindings/netlist_abstraction_decorator.cpp index d955443f296..1c0d8dfc2ec 100644 --- a/src/python_bindings/bindings/netlist_abstraction_decorator.cpp +++ b/src/python_bindings/bindings/netlist_abstraction_decorator.cpp @@ -25,7 +25,7 @@ namespace hal const std::vector& gates, bool include_all_netlist_gates, const std::function& exit_endpoint_filter, - const std::function& entry_endpoint_filter) -> std::optional { + const std::function& entry_endpoint_filter) -> std::optional> { auto res = NetlistAbstraction::create(netlist, gates, include_all_netlist_gates, exit_endpoint_filter, entry_endpoint_filter); if (res.is_ok()) {