Skip to content

Commit

Permalink
all pickle more objects
Browse files Browse the repository at this point in the history
  • Loading branch information
jinluchang committed Apr 7, 2024
1 parent 6509bcb commit b7bd0db
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 3 deletions.
8 changes: 8 additions & 0 deletions examples-py/selected-field.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@

q.displayln_info(f"CHECK: prop.crc32() = {prop.crc32()} ; prop.qnorm() = {prop.qnorm():.12E}")

q.save_pickle_obj(prop, f"results/prop-{q.get_id_node()}.pickle", is_sync_node=False)
prop_load = q.load_pickle_obj(f"results/prop-{q.get_id_node()}.pickle", is_sync_node=False)
assert np.all(prop[:] == prop_load[:])

psel = q.PointsSelection([
[ 0, 0, 0, 0, ],
[ 0, 1, 2, 0, ],
Expand All @@ -67,6 +71,10 @@
s_prop = q.SelProp(fselc)
s_prop @= prop

q.save_pickle_obj(s_prop, f"results/s_prop-{q.get_id_node()}.pickle", is_sync_node=False)
s_prop_load = q.load_pickle_obj(f"results/s_prop-{q.get_id_node()}.pickle", is_sync_node=False)
assert np.all(s_prop[:] == s_prop_load[:])

n1 = s_prop.n_elems()
n2 = q.glb_sum(n1)

Expand Down
39 changes: 39 additions & 0 deletions qlat/qlat/field_base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,25 @@ cdef class FieldBase:
def __getnewargs__(self):
return ()

def __getstate__(self):
"""
Only work when single node (or if all nodes has the same data).
"""
geo = self.geo()
data_arr = self[:]
return [ data_arr, geo, ]

def __setstate__(self, state):
"""
Only work when single node (or if all nodes has the same data).
"""
if self.view_count > 0:
raise ValueError("can't load while being viewed")
self.__init__()
[ data_arr, geo, ] = state
self.init_from_geo(geo)
self[:] = data_arr

### -------------------------------------------------------------------

def split_fields(fs, f):
Expand Down Expand Up @@ -615,6 +634,26 @@ cdef class SelectedFieldBase:
def __getnewargs__(self):
return ()

def __getstate__(self):
"""
Only work when single node (or if all nodes has the same data).
"""
fsel = self.fsel
multiplicity = self.multiplicity()
data_arr = self[:]
return [ data_arr, multiplicity, fsel, ]

def __setstate__(self, state):
"""
Only work when single node (or if all nodes has the same data).
"""
if self.view_count > 0:
raise ValueError("can't load while being viewed")
self.__init__()
[ data_arr, multiplicity, fsel, ] = state
self.init_from_fsel(fsel, multiplicity)
self[:] = data_arr

### -------------------------------------------------------------------

cdef class SelectedPointsBase:
Expand Down
17 changes: 16 additions & 1 deletion qlat/qlat/field_types.pyx.in
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ cdef class Field{{name}}(FieldBase):
self.cdata = <cc.Long>&(self.xx)
self.view_count = 0

def __init__(self, Geometry geo=None, int multiplicity=0):
def __init__(self, *args):
self.init_from_geo(*args)

def init_from_geo(self, Geometry geo=None, int multiplicity=0):
if geo is None:
self.xx.init()
return
Expand Down Expand Up @@ -226,6 +229,18 @@ cdef class Field{{name}}(FieldBase):
return 0
return cc.write(sfw.xx, fn, self.xx)

def __getstate__(self):
"""
Only work when single node (or if all nodes has the same data).
"""
return super().__getstate__()

def __setstate__(self, state):
"""
Only work when single node (or if all nodes has the same data).
"""
super().__setstate__(state)

field_type_dict[ElemType{{name}}] = Field{{name}}

{{endfor}}
Expand Down
26 changes: 25 additions & 1 deletion qlat/qlat/propagator.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,23 @@ cdef class Prop(FieldWilsonMatrix):
np.asarray(wm)[:] = self[index, m]
return wm

def __getstate__(self):
"""
Only work when single node (or if all nodes has the same data).
"""
return super().__getstate__()

def __setstate__(self, state):
"""
Only work when single node (or if all nodes has the same data).
"""
super().__setstate__(state)

###

cdef class SelProp(SelectedFieldWilsonMatrix):

def __init__(self, FieldSelection fsel):
def __init__(self, FieldSelection fsel=None):
super().__init__(fsel, 1)

cdef cc.Handle[cc.SelProp] xxx(self):
Expand All @@ -50,6 +62,18 @@ cdef class SelProp(SelectedFieldWilsonMatrix):
np.asarray(wm)[:] = self[idx, m]
return wm

def __getstate__(self):
"""
Only work when single node (or if all nodes has the same data).
"""
return super().__getstate__()

def __setstate__(self, state):
"""
Only work when single node (or if all nodes has the same data).
"""
super().__setstate__(state)

###

cdef class PselProp(SelectedPointsWilsonMatrix):
Expand Down
17 changes: 16 additions & 1 deletion qlat/qlat/selected_field_types.pyx.in
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,10 @@ cdef class SelectedField{{name}}(SelectedFieldBase):
self.cdata = <cc.Long>&(self.xx)
self.view_count = 0

def __init__(self, FieldSelection fsel, int multiplicity=0):
def __init__(self, *args):
self.init_from_fsel(*args)

def init_from_fsel(self, FieldSelection fsel=None, cc.Int multiplicity=0):
self.fsel = fsel
if multiplicity > 0 and self.fsel is not None:
if self.view_count > 0:
Expand Down Expand Up @@ -211,6 +214,18 @@ cdef class SelectedField{{name}}(SelectedFieldBase):
cdef cc.Long total_bytes = cc.write(sfw.xx, fn, sbs.xx, self.xx)
return total_bytes

def __getstate__(self):
"""
Only work when single node (or if all nodes has the same data).
"""
return super().__getstate__()

def __setstate__(self, state):
"""
Only work when single node (or if all nodes has the same data).
"""
super().__setstate__(state)

selected_field_type_dict[ElemType{{name}}] = SelectedField{{name}}

{{endfor}}
Expand Down

0 comments on commit b7bd0db

Please sign in to comment.