Skip to content

Commit

Permalink
Cleaner Python API, Python tests
Browse files Browse the repository at this point in the history
  • Loading branch information
franzpoeschel committed Nov 4, 2024
1 parent 6353063 commit 6ae9e1c
Show file tree
Hide file tree
Showing 8 changed files with 131 additions and 20 deletions.
10 changes: 4 additions & 6 deletions include/openPMD/Mesh.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#include "openPMD/backend/BaseRecord.hpp"
#include "openPMD/backend/MeshRecordComponent.hpp"

#include <array>
#include <ostream>
#include <string>
#include <type_traits>
Expand Down Expand Up @@ -196,7 +195,7 @@ class Mesh : public BaseRecord<MeshRecordComponent>
*/
Mesh &setGridUnitSI(double gridUnitSI);

/** Alias for `setGridUnitSI(std::vector<double>)`.
/** Alias for `setGridUnitSIPerDimension(std::vector<double>)`.
*
* Set the unit-conversion factor per dimension to multiply each value in
* Mesh::gridSpacing and Mesh::gridGlobalOffset, in order to convert from
Expand Down Expand Up @@ -245,8 +244,7 @@ class Mesh : public BaseRecord<MeshRecordComponent>
* that represent the power of the particular base.
* @return Reference to modified mesh.
*/
Mesh &
setUnitDimension(std::map<UnitDimension, double> const &unitDimension);
Mesh &setUnitDimension(unit_representations::AsMap const &unitDimension);

/**
* @brief Set the unitDimension for each axis of the current grid.
Expand All @@ -260,8 +258,8 @@ class Mesh : public BaseRecord<MeshRecordComponent>
*
* @return Reference to modified mesh.
*/
Mesh &setGridUnitDimension(
std::vector<std::map<UnitDimension, double>> const &gridUnitDimension);
Mesh &
setGridUnitDimension(unit_representations::AsMaps const &gridUnitDimension);

/**
* @brief Return the physical dimensions of the mesh axes.
Expand Down
4 changes: 2 additions & 2 deletions include/openPMD/Record.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
#pragma once

#include "openPMD/RecordComponent.hpp"
#include "openPMD/UnitDimension.hpp"
#include "openPMD/backend/BaseRecord.hpp"

#include <map>
#include <string>
#include <type_traits>

Expand All @@ -40,7 +40,7 @@ class Record : public BaseRecord<RecordComponent>
Record &operator=(Record const &) = default;
~Record() override = default;

Record &setUnitDimension(std::map<UnitDimension, double> const &);
Record &setUnitDimension(unit_representations::AsMap const &);

template <typename T>
T timeOffset() const;
Expand Down
2 changes: 2 additions & 0 deletions include/openPMD/backend/Attributable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,8 @@ class Attributable
*/
void touch();

[[nodiscard]] OpenpmdStandard openPMDStandard() const;

// clang-format off
OPENPMD_protected
// clang-format on
Expand Down
7 changes: 3 additions & 4 deletions src/Mesh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,6 @@ std::vector<double> Mesh::gridUnitSIPerDimension() const

Mesh &Mesh::setGridUnitSIPerDimension(std::vector<double> gridUnitSI)
{
setAttribute("gridUnitSI", std::move(gridUnitSI));
if (auto standard = IOHandler()->m_standard;
standard < OpenpmdStandard::v_2_0_0)
{
Expand All @@ -261,10 +260,11 @@ Mesh &Mesh::setGridUnitSIPerDimension(std::vector<double> gridUnitSI)
"openPMD 2.0. Either upgrade the file to openPMD >= 2.0 "
"or specify a scalar that applies to all axes.");
}
setAttribute("gridUnitSI", std::move(gridUnitSI));
return *this;
}

Mesh &Mesh::setUnitDimension(std::map<UnitDimension, double> const &udim)
Mesh &Mesh::setUnitDimension(unit_representations::AsMap const &udim)
{
if (!udim.empty())
{
Expand All @@ -276,8 +276,7 @@ Mesh &Mesh::setUnitDimension(std::map<UnitDimension, double> const &udim)
return *this;
}

Mesh &Mesh::setGridUnitDimension(
std::vector<std::map<UnitDimension, double>> const &udims)
Mesh &Mesh::setGridUnitDimension(unit_representations::AsMaps const &udims)
{
auto rawGridUnitDimension = [this]() {
try
Expand Down
3 changes: 2 additions & 1 deletion src/Record.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
*/
#include "openPMD/Record.hpp"
#include "openPMD/RecordComponent.hpp"
#include "openPMD/UnitDimension.hpp"
#include "openPMD/backend/BaseRecord.hpp"

#include <iostream>
Expand All @@ -31,7 +32,7 @@ Record::Record()
setTimeOffset(0.f);
}

Record &Record::setUnitDimension(std::map<UnitDimension, double> const &udim)
Record &Record::setUnitDimension(unit_representations::AsMap const &udim)
{
if (!udim.empty())
{
Expand Down
6 changes: 6 additions & 0 deletions src/backend/Attributable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
* If not, see <http://www.gnu.org/licenses/>.
*/
#include "openPMD/backend/Attributable.hpp"
#include "openPMD/IO/AbstractIOHandler.hpp"
#include "openPMD/Iteration.hpp"
#include "openPMD/ParticleSpecies.hpp"
#include "openPMD/RecordComponent.hpp"
Expand Down Expand Up @@ -250,6 +251,11 @@ void Attributable::touch()
setDirtyRecursive(true);
}

OpenpmdStandard Attributable::openPMDStandard() const
{
return IOHandler()->m_standard;
}

template <bool flush_entire_series>
void Attributable::seriesFlush_impl(internal::FlushParams const &flushParams)
{
Expand Down
30 changes: 23 additions & 7 deletions src/binding/python/Mesh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
* If not, see <http://www.gnu.org/licenses/>.
*/
#include "openPMD/Mesh.hpp"
#include "openPMD/Error.hpp"
#include "openPMD/IO/AbstractIOHandler.hpp"
#include "openPMD/backend/Attributable.hpp"
#include "openPMD/backend/BaseRecord.hpp"
#include "openPMD/backend/MeshRecordComponent.hpp"
Expand All @@ -29,14 +31,15 @@
#include "openPMD/binding/python/UnitDimension.hpp"

#include <string>
#include <variant>
#include <vector>

void init_Mesh(py::module &m)
{
auto py_m_cont =
declare_container<PyMeshContainer, Attributable>(m, "Mesh_Container");

py::class_<Mesh, BaseRecord<MeshRecordComponent> > cl(m, "Mesh");
py::class_<Mesh, BaseRecord<MeshRecordComponent>> cl(m, "Mesh");

py::enum_<Mesh::Geometry>(m, "Geometry") // TODO: m -> cl
.value("cartesian", Mesh::Geometry::cartesian)
Expand Down Expand Up @@ -102,12 +105,25 @@ void init_Mesh(py::module &m)
&Mesh::setGridGlobalOffset)
.def_property(
"grid_unit_SI",
&Mesh::gridUnitSI,
py::overload_cast<double>(&Mesh::setGridUnitSI))
.def_property(
"grid_unit_SI_per_dimension",
&Mesh::gridUnitSIPerDimension,
&Mesh::setGridUnitSIPerDimension)
[](Mesh &self) {
using return_t = std::variant<double, std::vector<double>>;
if (self.openPMDStandard() < OpenpmdStandard::v_2_0_0)
{
return return_t(self.gridUnitSI());
}
else
{
return return_t(self.gridUnitSIPerDimension());
}
},
[](Mesh &self, std::variant<double, std::vector<double>> arg) {
return std::visit(
[&](auto &&arg_resolved) {
return self.setGridUnitSI(
static_cast<decltype(arg_resolved)>(arg_resolved));
},
arg);
})
.def_property(
"time_offset",
&Mesh::timeOffset<double>,
Expand Down
89 changes: 89 additions & 0 deletions test/python/unittest/API/APITest.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,95 @@ def testAttributes(self):
for ext in tested_file_extensions:
self.attributeRoundTrip(ext)

def testOpenPMD_2_0(self):
write_2_0 = io.Series("../samples/openpmd_2_0.json", io.Access.create)
write_2_0.openPMD = "2.0.0"
meshes = write_2_0.write_iterations()[100].meshes

E = meshes["E"]
E.reset_dataset(io.Dataset(io.Datatype.DOUBLE, [10, 10, 10]))
E.grid_unit_SI = [1, 2, 3]
E.grid_unit_dimension = [
{io.Unit_Dimension.L: 1},
{io.Unit_Dimension.L: 1},
{io.Unit_Dimension.L: 1, io.Unit_Dimension.T: -1}]
E.make_constant(17)

B = meshes["B"]
B.reset_dataset(io.Dataset(io.Datatype.DOUBLE, [10, 10, 10]))
# This is deprecated for openPMD 2.0, a warning will be printed.
B.grid_unit_SI = 3
B.grid_unit_dimension = [{io.Unit_Dimension.L: 1} for _ in range(3)]
B.make_constant(18)

write_2_0.close()

read_2_0 = io.Series(
"../samples/openpmd_2_0.json", io.Access.read_only)
meshes = read_2_0.iterations[100].meshes

E = meshes["E"]
self.assertEqual(E.grid_unit_SI, [1, 2, 3])
self.assertEqual(E.grid_unit_dimension, io.Unit_Dimension.as_arrays([
{io.Unit_Dimension.L: 1},
{io.Unit_Dimension.L: 1},
{io.Unit_Dimension.L: 1, io.Unit_Dimension.T: -1}]))

B = meshes["B"]
# Will return a list due to openPMD standard being set to 2.0.0
self.assertEqual(B.grid_unit_SI, [3])
self.assertEqual(io.Unit_Dimension.as_maps(B.grid_unit_dimension), [
{io.Unit_Dimension.L: 1} for _ in range(3)])
read_2_0.close()

write_1_1 = io.Series("../samples/openpmd_1_1.json", io.Access.create)
write_1_1.openPMD = "1.1.0"
meshes = write_1_1.write_iterations()[100].meshes

E = meshes["E"]
E.reset_dataset(io.Dataset(io.Datatype.DOUBLE, [10, 10, 10]))

def unsupported_in_1_1():
E.grid_unit_SI = [1, 2, 3]
self.assertRaises(
io.ErrorIllegalInOpenPMDStandard, unsupported_in_1_1)
E.grid_unit_dimension = [
{io.Unit_Dimension.L: 1},
{io.Unit_Dimension.L: 1},
{io.Unit_Dimension.L: 1, io.Unit_Dimension.T: -1}]
E.make_constant(17)

B = meshes["B"]
B.reset_dataset(io.Dataset(io.Datatype.DOUBLE, [10, 10, 10]))
# This is deprecated for openPMD 2.0, a warning will be printed.
B.grid_unit_SI = 3
B.grid_unit_dimension = [{io.Unit_Dimension.L: 1} for _ in range(3)]
B.make_constant(18)

write_1_1.close()

read_1_1 = io.Series(
"../samples/openpmd_1_1.json", io.Access.read_only)
meshes = read_1_1.iterations[100].meshes

E = meshes["E"]
# Will return a default value due to the failed attempt at setting
# a list at write time
self.assertEqual(E.grid_unit_SI, 1)
self.assertEqual(E.grid_unit_dimension, io.Unit_Dimension.as_arrays([
{io.Unit_Dimension.L: 1},
{io.Unit_Dimension.L: 1},
{io.Unit_Dimension.L: 1, io.Unit_Dimension.T: -1}]))

B = meshes["B"]
# Will return a scalar due to openPMD standard being set to 2.0.0
self.assertEqual(B.grid_unit_SI, 3)
self.assertEqual(io.Unit_Dimension.as_maps(B.grid_unit_dimension), [
{io.Unit_Dimension.L: 1} for _ in range(3)])
read_1_1.close()



def makeConstantRoundTrip(self, file_ending):
# write
series = io.Series(
Expand Down

0 comments on commit 6ae9e1c

Please sign in to comment.