Skip to content

Commit

Permalink
fix fields IO bug
Browse files Browse the repository at this point in the history
  • Loading branch information
jinluchang committed Mar 21, 2024
1 parent 080d231 commit 42c7445
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 2 deletions.
41 changes: 41 additions & 0 deletions examples-py/fields-io.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,49 @@

prop = q.Prop(geo)
prop.set_rand(rs.split("prop-1"))

q.displayln_info("CHECK: prop", prop.crc32(), f"{prop.qnorm():.14E}")

s_prop = q.SelProp(fselc)
s_prop @= prop

sfw = q.open_fields("results/prop1.fields", "w", q.Coordinate([ 1, 1, 1, 1, ]))

prop.save_float_from_double(sfw, "prop", skip_if_exist=True)

s_prop.save_float_from_double(sfw, "s_prop", skip_if_exist=True)

sfw.close()

prop_1 = q.Prop()
s_prop_1 = q.SelProp(None)
s_prop_2 = q.SelProp(fselc)
s_prop_3 = q.SelProp(fselc)

sfr = q.open_fields("results/prop1.fields", "r")

prop_1.load_double_from_float(sfr, "prop")
s_prop_1.load_double_from_float(sfr, "s_prop")
s_prop_2.load_double_from_float(sfr, "s_prop")
s_prop_3.load_double_from_float(sfr, "s_prop")

sfr.close()

assert q.is_matching_fsel(s_prop.fsel, fselc)
assert q.is_matching_fsel(s_prop_1.fsel, fselc)
assert q.is_matching_fsel(s_prop_2.fsel, fselc)
assert q.is_matching_fsel(s_prop_3.fsel, fselc)

prop_1 -= prop
s_prop_1 -= s_prop
s_prop_2 -= s_prop
s_prop_3 -= s_prop

assert q.qnorm(prop_1) < 1e-10
assert q.qnorm(s_prop_1) < 1e-10
assert q.qnorm(s_prop_2) < 1e-10
assert q.qnorm(s_prop_3) < 1e-10

sfw = q.open_fields("results/prop.fields", "w", q.Coordinate([ 1, 1, 1, 8, ]))

sf_list = sorted(q.show_all_shuffled_fields_writer())
Expand Down
2 changes: 2 additions & 0 deletions qlat/qlat/everything.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ cdef extern from "qlat/mpi.h" namespace "qlat":
int glb_sum(ComplexD& ld) except +
int glb_sum(ComplexF& ld) except +
int glb_sum(LatData& ld) except +
bool glb_all(const bool b) except +
bool glb_any(const bool b) except +

cdef extern from "qlat/geometry.h" namespace "qlat":

Expand Down
10 changes: 10 additions & 0 deletions qlat/qlat/fields_io.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,16 @@ cdef class ShuffledFieldsWriter:
return x

def get_cache_sbs(self, FieldSelection fsel):
if fsel is None:
cache_fields_io.pop(id(self), None)
return
if id(self) in cache_fields_io:
c_fsel, c_sbs = cache_fields_io[id(self)]
if fsel is c_fsel:
return c_sbs
sbs = ShuffledBitSet(fsel, self.new_size_node())
assert sbs.xx.fsel.n_elems == fsel.xx.n_elems
assert cc.is_matching_fsel(sbs.xx.fsel, fsel.xx)
cache_fields_io[id(self)] = (fsel, sbs,)
return sbs

Expand Down Expand Up @@ -90,11 +95,16 @@ cdef class ShuffledFieldsReader:
return x

def get_cache_sbs(self, FieldSelection fsel):
if fsel is None:
cache_fields_io.pop(id(self), None)
return
if id(self) in cache_fields_io:
c_fsel, c_sbs = cache_fields_io[id(self)]
if fsel is c_fsel:
return c_sbs
sbs = ShuffledBitSet(fsel, self.new_size_node())
assert sbs.xx.fsel.n_elems == fsel.xx.n_elems
assert cc.is_matching_fsel(sbs.xx.fsel, fsel.xx)
cache_fields_io[id(self)] = (fsel, sbs,)
return sbs

Expand Down
6 changes: 6 additions & 0 deletions qlat/qlat/include/qlat/fields-io.h
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,8 @@ Long read(ShuffledFieldsReader& sfr, const std::string& fn, Field<M>& field)
qassert(0 == total_bytes);
}
if (0 == total_bytes) {
qwarn(fname + ssprintf(": total_bytes=%ld (fn='%s', sfr.path='%s')",
total_bytes, fn.c_str(), sfr.path.c_str()));
return 0;
}
Coordinate total_site;
Expand Down Expand Up @@ -849,6 +851,8 @@ Long read(ShuffledFieldsReader& sfr, const std::string& fn,
qassert(0 == total_bytes);
}
if (0 == total_bytes) {
qwarn(fname + ssprintf(": total_bytes=%ld (fn='%s', sfr.path='%s')",
total_bytes, fn.c_str(), sfr.path.c_str()));
return 0;
}
Coordinate total_site;
Expand Down Expand Up @@ -893,6 +897,8 @@ Long read(ShuffledFieldsReader& sfr, const std::string& fn,
qassert(0 == total_bytes);
}
if (0 == total_bytes) {
qwarn(fname + ssprintf(": total_bytes=%ld (fn='%s', sfr.path='%s')",
total_bytes, fn.c_str(), sfr.path.c_str()));
return 0;
}
Coordinate total_site;
Expand Down
20 changes: 20 additions & 0 deletions qlat/qlat/include/qlat/mpi.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,26 @@ int glb_sum(Vector<M> xx)

int glb_sum(LatData& ld);

inline bool glb_any(const bool b)
{
Long ret = 0;
if (b) {
ret = 1;
}
glb_sum(ret);
return ret > 0;
}

inline bool glb_all(const bool b)
{
Long ret = 0;
if (not b) {
ret = 1;
}
glb_sum(ret);
return ret == 0;
}

int bcast(Vector<Char> recv, const int root = 0);

template <class T, QLAT_ENABLE_IF(is_data_vector_type<T>())>
Expand Down
4 changes: 2 additions & 2 deletions qlat/qlat/lib/selected-field.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ bool is_matching_fsel(const FieldSelection& fsel1, const FieldSelection& fsel2)
}

bool is_containing(const FieldSelection& fsel, const FieldSelection& fsel_small)
// local checking
{
TIMER("is_containing(fsel,fsel_small)");
Long n_missing_points = 0;
Expand All @@ -400,11 +401,11 @@ bool is_containing(const FieldSelection& fsel, const FieldSelection& fsel_small)
n_missing_points += 1;
}
});
glb_sum(n_missing_points);
return n_missing_points == 0;
}

bool is_containing(const FieldSelection& fsel, const PointsSelection& psel)
// local checking
{
TIMER("is_containing(fsel,psel)");
const Geometry& geo = fsel.f_rank.geo();
Expand All @@ -422,7 +423,6 @@ bool is_containing(const FieldSelection& fsel, const PointsSelection& psel)
}
}
});
glb_sum(n_missing_points);
return n_missing_points == 0;
}

Expand Down
10 changes: 10 additions & 0 deletions qlat/qlat/selected_field_types.pyx.in
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,15 @@ cdef class SelectedField{{name}}(SelectedFieldBase):
raise ValueError("can't re-init while being viewed")
self.fsel = FieldSelection()
total_bytes = cc.read(sfr.xx, fn, self.xx, self.fsel.xx)
if total_bytes == 0:
return 0
else:
sbs = sfr.get_cache_sbs(self.fsel)
total_bytes = cc.read(sfr.xx, fn, sbs.xx, self.xx)
if total_bytes == 0:
return 0
if sbs.xx.fsel.n_elems != self.fsel.xx.n_elems:
raise Exception(f"read_sfr_direct: sbs.xx.fsel.n_elems={sbs.xx.fsel.n_elems} ; self.fsel.xx.n_elems={self.fsel.xx.n_elems}")
if self.xx.n_elems != self.fsel.xx.n_elems:
raise Exception(f"read_sfr_direct: self.xx.n_elems={self.xx.n_elems} ; self.fsel.xx.n_elems={self.fsel.xx.n_elems}")
return total_bytes
Expand All @@ -198,6 +204,10 @@ cdef class SelectedField{{name}}(SelectedFieldBase):
assert self.fsel is not None
assert self.xx.n_elems == self.fsel.xx.n_elems
cdef ShuffledBitSet sbs = sfw.get_cache_sbs(self.fsel)
if sbs.xx.fsel.n_elems != self.fsel.xx.n_elems:
raise Exception(f"write_sfw_direct: sbs.xx.fsel.n_elems={sbs.xx.fsel.n_elems} ; self.fsel.xx.n_elems={self.fsel.xx.n_elems}")
if self.xx.n_elems != self.fsel.xx.n_elems:
raise Exception(f"write_sfw_direct: self.xx.n_elems={self.xx.n_elems} ; self.fsel.xx.n_elems={self.fsel.xx.n_elems}")
cdef cc.Long total_bytes = cc.write(sfw.xx, fn, sbs.xx, self.xx)
return total_bytes

Expand Down

0 comments on commit 42c7445

Please sign in to comment.