From b7bd0db558391ec0b8e5431c10355e5b24dcca45 Mon Sep 17 00:00:00 2001 From: Luchang Jin Date: Sun, 7 Apr 2024 06:17:24 -0400 Subject: [PATCH] all pickle more objects --- examples-py/selected-field.py | 8 ++++++ qlat/qlat/field_base.pyx | 39 +++++++++++++++++++++++++++ qlat/qlat/field_types.pyx.in | 17 +++++++++++- qlat/qlat/propagator.pyx | 26 +++++++++++++++++- qlat/qlat/selected_field_types.pyx.in | 17 +++++++++++- 5 files changed, 104 insertions(+), 3 deletions(-) diff --git a/examples-py/selected-field.py b/examples-py/selected-field.py index 82f9d6ce2..b18636135 100755 --- a/examples-py/selected-field.py +++ b/examples-py/selected-field.py @@ -42,6 +42,10 @@ q.displayln_info(f"CHECK: prop.crc32() = {prop.crc32()} ; prop.qnorm() = {prop.qnorm():.12E}") +q.save_pickle_obj(prop, f"results/prop-{q.get_id_node()}.pickle", is_sync_node=False) +prop_load = q.load_pickle_obj(f"results/prop-{q.get_id_node()}.pickle", is_sync_node=False) +assert np.all(prop[:] == prop_load[:]) + psel = q.PointsSelection([ [ 0, 0, 0, 0, ], [ 0, 1, 2, 0, ], @@ -67,6 +71,10 @@ s_prop = q.SelProp(fselc) s_prop @= prop +q.save_pickle_obj(s_prop, f"results/s_prop-{q.get_id_node()}.pickle", is_sync_node=False) +s_prop_load = q.load_pickle_obj(f"results/s_prop-{q.get_id_node()}.pickle", is_sync_node=False) +assert np.all(s_prop[:] == s_prop_load[:]) + n1 = s_prop.n_elems() n2 = q.glb_sum(n1) diff --git a/qlat/qlat/field_base.pyx b/qlat/qlat/field_base.pyx index 6de57df4c..39bd7de86 100644 --- a/qlat/qlat/field_base.pyx +++ b/qlat/qlat/field_base.pyx @@ -360,6 +360,25 @@ cdef class FieldBase: def __getnewargs__(self): return () + def __getstate__(self): + """ + Only work when single node (or if all nodes has the same data). + """ + geo = self.geo() + data_arr = self[:] + return [ data_arr, geo, ] + + def __setstate__(self, state): + """ + Only work when single node (or if all nodes has the same data). + """ + if self.view_count > 0: + raise ValueError("can't load while being viewed") + self.__init__() + [ data_arr, geo, ] = state + self.init_from_geo(geo) + self[:] = data_arr + ### ------------------------------------------------------------------- def split_fields(fs, f): @@ -615,6 +634,26 @@ cdef class SelectedFieldBase: def __getnewargs__(self): return () + def __getstate__(self): + """ + Only work when single node (or if all nodes has the same data). + """ + fsel = self.fsel + multiplicity = self.multiplicity() + data_arr = self[:] + return [ data_arr, multiplicity, fsel, ] + + def __setstate__(self, state): + """ + Only work when single node (or if all nodes has the same data). + """ + if self.view_count > 0: + raise ValueError("can't load while being viewed") + self.__init__() + [ data_arr, multiplicity, fsel, ] = state + self.init_from_fsel(fsel, multiplicity) + self[:] = data_arr + ### ------------------------------------------------------------------- cdef class SelectedPointsBase: diff --git a/qlat/qlat/field_types.pyx.in b/qlat/qlat/field_types.pyx.in index 54648747b..af79a87e1 100644 --- a/qlat/qlat/field_types.pyx.in +++ b/qlat/qlat/field_types.pyx.in @@ -58,7 +58,10 @@ cdef class Field{{name}}(FieldBase): self.cdata = &(self.xx) self.view_count = 0 - def __init__(self, Geometry geo=None, int multiplicity=0): + def __init__(self, *args): + self.init_from_geo(*args) + + def init_from_geo(self, Geometry geo=None, int multiplicity=0): if geo is None: self.xx.init() return @@ -226,6 +229,18 @@ cdef class Field{{name}}(FieldBase): return 0 return cc.write(sfw.xx, fn, self.xx) + def __getstate__(self): + """ + Only work when single node (or if all nodes has the same data). + """ + return super().__getstate__() + + def __setstate__(self, state): + """ + Only work when single node (or if all nodes has the same data). + """ + super().__setstate__(state) + field_type_dict[ElemType{{name}}] = Field{{name}} {{endfor}} diff --git a/qlat/qlat/propagator.pyx b/qlat/qlat/propagator.pyx index a6f51d3bd..ba9795720 100644 --- a/qlat/qlat/propagator.pyx +++ b/qlat/qlat/propagator.pyx @@ -34,11 +34,23 @@ cdef class Prop(FieldWilsonMatrix): np.asarray(wm)[:] = self[index, m] return wm + def __getstate__(self): + """ + Only work when single node (or if all nodes has the same data). + """ + return super().__getstate__() + + def __setstate__(self, state): + """ + Only work when single node (or if all nodes has the same data). + """ + super().__setstate__(state) + ### cdef class SelProp(SelectedFieldWilsonMatrix): - def __init__(self, FieldSelection fsel): + def __init__(self, FieldSelection fsel=None): super().__init__(fsel, 1) cdef cc.Handle[cc.SelProp] xxx(self): @@ -50,6 +62,18 @@ cdef class SelProp(SelectedFieldWilsonMatrix): np.asarray(wm)[:] = self[idx, m] return wm + def __getstate__(self): + """ + Only work when single node (or if all nodes has the same data). + """ + return super().__getstate__() + + def __setstate__(self, state): + """ + Only work when single node (or if all nodes has the same data). + """ + super().__setstate__(state) + ### cdef class PselProp(SelectedPointsWilsonMatrix): diff --git a/qlat/qlat/selected_field_types.pyx.in b/qlat/qlat/selected_field_types.pyx.in index 4b66e2562..e53023509 100644 --- a/qlat/qlat/selected_field_types.pyx.in +++ b/qlat/qlat/selected_field_types.pyx.in @@ -59,7 +59,10 @@ cdef class SelectedField{{name}}(SelectedFieldBase): self.cdata = &(self.xx) self.view_count = 0 - def __init__(self, FieldSelection fsel, int multiplicity=0): + def __init__(self, *args): + self.init_from_fsel(*args) + + def init_from_fsel(self, FieldSelection fsel=None, cc.Int multiplicity=0): self.fsel = fsel if multiplicity > 0 and self.fsel is not None: if self.view_count > 0: @@ -211,6 +214,18 @@ cdef class SelectedField{{name}}(SelectedFieldBase): cdef cc.Long total_bytes = cc.write(sfw.xx, fn, sbs.xx, self.xx) return total_bytes + def __getstate__(self): + """ + Only work when single node (or if all nodes has the same data). + """ + return super().__getstate__() + + def __setstate__(self, state): + """ + Only work when single node (or if all nodes has the same data). + """ + super().__setstate__(state) + selected_field_type_dict[ElemType{{name}}] = SelectedField{{name}} {{endfor}}