diff --git a/pyzx/graph/multigraph.py b/pyzx/graph/multigraph.py index 37e14817..b8e17179 100644 --- a/pyzx/graph/multigraph.py +++ b/pyzx/graph/multigraph.py @@ -69,7 +69,7 @@ def __init__(self) -> None: self._auto_simplify: bool = True self._vindex: int = 0 self.nedges: int = 0 - self.ty: Dict[int,VertexType] = dict() + self.ty: Dict[int,VertexType] = dict() self._phase: Dict[int, FractionLike] = dict() self._qindex: Dict[int, FloatInt] = dict() self._maxq: FloatInt = -1 @@ -108,6 +108,9 @@ def set_auto_simplify(self, s: bool): """Automatically remove parallel edges as edges are added""" self._auto_simplify = s + def get_auto_simplify(self): + return self._auto_simplify + def multigraph(self): return False diff --git a/pyzx/simplify.py b/pyzx/simplify.py index dd7e7978..aaeac4e5 100644 --- a/pyzx/simplify.py +++ b/pyzx/simplify.py @@ -27,12 +27,14 @@ 'full_reduce', 'teleport_reduce', 'reduce_scalar', 'supplementarity_simp', 'to_clifford_normal_form_graph', 'to_graph_like', 'is_graph_like'] +from ast import Mult from optparse import Option from typing import List, Callable, Optional, Union, Generic, Tuple, Dict, Iterator, cast from .utils import EdgeType, VertexType, toggle_edge, vertex_is_zx, toggle_vertex from .rules import * from .graph.base import BaseGraph, VT, ET +from .graph.multigraph import Multigraph from .circuit import Circuit class Stats(object): @@ -58,6 +60,7 @@ def simp( name: str, match: Callable[..., List[MatchObject]], rewrite: Callable[[BaseGraph[VT,ET],List[MatchObject]],RewriteOutputType[VT,ET]], + auto_simplify_parallel_edges: bool = False, matchf:Optional[Union[Callable[[ET],bool], Callable[[VT],bool]]]=None, quiet:bool=False, stats:Optional[Stats]=None) -> int: @@ -73,6 +76,7 @@ def simp( str name: The name to display if ``quiet`` is set to False. match: One of the ``match_*`` functions of rules_. rewrite: One of the rewrite functions of rules_. + auto_simplify_parallel_edges: whether to automatically combine parallel edges between vertices if the graph is a Multigraph matchf: An optional filtering function on candidate vertices or edges, which is passed as the second argument to the match function. quiet: Suppress output on numbers of matches found during simplification. @@ -80,6 +84,9 @@ def simp( Returns: Number of iterations of ``rewrite`` that had to be applied before no more matches were found.""" + if auto_simplify_parallel_edges and isinstance(g, Multigraph): + auto_simp_value = g.get_auto_simplify() + g.set_auto_simplify(True) i = 0 new_matches = True while new_matches: @@ -103,19 +110,25 @@ def simp( new_matches = True if stats is not None: stats.count_rewrites(name, len(m)) if not quiet and i>0: print(' {!s} iterations'.format(i)) + if auto_simplify_parallel_edges and isinstance(g, Multigraph): + g.set_auto_simplify(auto_simp_value) return i def pivot_simp(g: BaseGraph[VT,ET], matchf:Optional[Callable[[ET],bool]]=None, quiet:bool=False, stats:Optional[Stats]=None) -> int: - return simp(g, 'pivot_simp', match_pivot_parallel, pivot, matchf=matchf, quiet=quiet, stats=stats) + return simp(g, 'pivot_simp', match_pivot_parallel, pivot, + auto_simplify_parallel_edges=True, matchf=matchf, quiet=quiet, stats=stats) def pivot_gadget_simp(g: BaseGraph[VT,ET], matchf:Optional[Callable[[ET],bool]]=None, quiet:bool=False, stats:Optional[Stats]=None) -> int: - return simp(g, 'pivot_gadget_simp', match_pivot_gadget, pivot, matchf=matchf, quiet=quiet, stats=stats) + return simp(g, 'pivot_gadget_simp', match_pivot_gadget, pivot, + auto_simplify_parallel_edges=True, matchf=matchf, quiet=quiet, stats=stats) def pivot_boundary_simp(g: BaseGraph[VT,ET], matchf:Optional[Callable[[ET],bool]]=None, quiet:bool=False, stats:Optional[Stats]=None) -> int: - return simp(g, 'pivot_boundary_simp', match_pivot_boundary, pivot, matchf=matchf, quiet=quiet, stats=stats) + return simp(g, 'pivot_boundary_simp', match_pivot_boundary, pivot, + auto_simplify_parallel_edges=True, matchf=matchf, quiet=quiet, stats=stats) def lcomp_simp(g: BaseGraph[VT,ET], matchf:Optional[Callable[[VT],bool]]=None, quiet:bool=False, stats:Optional[Stats]=None) -> int: - return simp(g, 'lcomp_simp', match_lcomp_parallel, lcomp, matchf=matchf, quiet=quiet, stats=stats) + return simp(g, 'lcomp_simp', match_lcomp_parallel, lcomp, + auto_simplify_parallel_edges=True, matchf=matchf, quiet=quiet, stats=stats) def bialg_simp(g: BaseGraph[VT,ET], quiet:bool=False, stats: Optional[Stats]=None) -> int: return simp(g, 'bialg_simp', match_bialg_parallel, bialg, quiet=quiet, stats=stats) @@ -127,10 +140,12 @@ def id_simp(g: BaseGraph[VT,ET], matchf:Optional[Callable[[VT],bool]]=None, quie return simp(g, 'id_simp', match_ids_parallel, remove_ids, matchf=matchf, quiet=quiet, stats=stats) def gadget_simp(g: BaseGraph[VT,ET], matchf: Optional[Callable[[VT],bool]]=None, quiet:bool=False, stats:Optional[Stats]=None) -> int: - return simp(g, 'gadget_simp', match_phase_gadgets, merge_phase_gadgets, matchf=matchf, quiet=quiet, stats=stats) + return simp(g, 'gadget_simp', match_phase_gadgets, merge_phase_gadgets, + auto_simplify_parallel_edges=True, matchf=matchf, quiet=quiet, stats=stats) def supplementarity_simp(g: BaseGraph[VT,ET], quiet:bool=False, stats:Optional[Stats]=None) -> int: - return simp(g, 'supplementarity_simp', match_supplementarity, apply_supplementarity, quiet=quiet, stats=stats) + return simp(g, 'supplementarity_simp', match_supplementarity, apply_supplementarity, + auto_simplify_parallel_edges=True, quiet=quiet, stats=stats) def copy_simp(g: BaseGraph[VT,ET], quiet:bool=False, stats:Optional[Stats]=None) -> int: """Copies 1-ary spiders with 0/pi phase through neighbors.