Skip to content

Commit

Permalink
Fix unique_ptr<T, Del> constructor of UniquePtrWithLambda (#1552)
Browse files Browse the repository at this point in the history
* Add test with copyable custom deleter

* Remove wrong double application of `get_deleter()`

* Test custom non-copyable deleter type

* Consider that deleter may not be copy-constructible

We then need to put them behind a shared_ptr to create a lambda for
std::function

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add a comment and some cleanup

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
franzpoeschel and pre-commit-ci[bot] authored Nov 7, 2023
1 parent 31e927b commit 442481a
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 49 deletions.
39 changes: 21 additions & 18 deletions include/openPMD/auxiliary/UniquePtr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,14 @@ namespace auxiliary
public:
using deleter_type = std::function<void(T_decayed *)>;

deleter_type const &get_deleter() const
{
return *this;
}
deleter_type &get_deleter()
{
return *this;
}

/*
* Default constructor: Use std::default_delete<T>.
* This ensures correct destruction of arrays by using delete[].
*/
CustomDelete()
: deleter_type{[](T_decayed *ptr) {
: deleter_type{[]([[maybe_unused]] T_decayed *ptr) {
if constexpr (std::is_void_v<T_decayed>)
{
(void)ptr;
std::cerr << "[Warning] Cannot standard-delete a void-type "
"pointer. Please specify a custom destructor. "
"Will let the memory leak."
Expand Down Expand Up @@ -144,12 +134,25 @@ UniquePtrWithLambda<T>::UniquePtrWithLambda(std::unique_ptr<T> stdPtr)
template <typename T>
template <typename Del>
UniquePtrWithLambda<T>::UniquePtrWithLambda(std::unique_ptr<T, Del> ptr)
: BasePtr{
ptr.release(),
auxiliary::CustomDelete<T>{
[deleter = std::move(ptr.get_deleter())](T_decayed *del_ptr) {
deleter.get_deleter()(del_ptr);
}}}
: BasePtr{ptr.release(), auxiliary::CustomDelete<T>{[&]() {
if constexpr (std::is_copy_constructible_v<Del>)
{
return [deleter = std::move(ptr.get_deleter())](
T_decayed *del_ptr) { deleter(del_ptr); };
}
else
{
/*
* The constructor of std::function requires a copyable
* lambda. Since Del is not a copyable type, we cannot
* capture it directly, but need to put it into a
* shared_ptr to make it copyable.
*/
return [deleter = std::make_shared<Del>(
std::move(ptr.get_deleter()))](
T_decayed *del_ptr) { (*deleter)(del_ptr); };
}
}()}}
{}

template <typename T>
Expand All @@ -170,7 +173,7 @@ UniquePtrWithLambda<U> UniquePtrWithLambda<T>::static_cast_() &&
return UniquePtrWithLambda<U>{
static_cast<other_type *>(this->release()),
[deleter = std::move(this->get_deleter())](other_type *ptr) {
deleter.get_deleter()(static_cast<T_decayed *>(ptr));
deleter(static_cast<T_decayed *>(ptr));
}};
}
} // namespace openPMD
106 changes: 75 additions & 31 deletions test/SerialIOTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,29 @@ TEST_CASE("close_iteration_interleaved_test", "[serial]")
}
}

namespace detail
{
template <typename T>
struct CopyableDeleter : std::function<void(T *)>
{
CopyableDeleter()
: std::function<void(T *)>{[](T const *ptr) { delete[] ptr; }}
{}
};

template <typename T>
struct NonCopyableDeleter : std::function<void(T *)>
{
NonCopyableDeleter()
: std::function<void(T *)>{[](T const *ptr) { delete[] ptr; }}
{}
NonCopyableDeleter(NonCopyableDeleter const &) = delete;
NonCopyableDeleter &operator=(NonCopyableDeleter const &) = delete;
NonCopyableDeleter(NonCopyableDeleter &&) = default;
NonCopyableDeleter &operator=(NonCopyableDeleter &&) = default;
};
} // namespace detail

void close_and_copy_attributable_test(std::string file_ending)
{
using position_t = int;
Expand Down Expand Up @@ -685,6 +708,27 @@ void close_and_copy_attributable_test(std::string file_ending)
{0},
{global_extent});

// UniquePtrWithLambda from unique_ptr with custom delete type
auto pos_v = electronPositions["v"];
pos_v.resetDataset(dataset);
std::unique_ptr<int, ::detail::CopyableDeleter<int>>
ptr_v_copyable_deleter(new int[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
pos_v.storeChunk(
UniquePtrWithLambda<int>(std::move(ptr_v_copyable_deleter)),
{0},
{global_extent});

// UniquePtrWithLambda from unique_ptr with non-copyable custom delete
// type
auto posOff_v = electronPositionsOffset["v"];
posOff_v.resetDataset(dataset);
std::unique_ptr<int, ::detail::NonCopyableDeleter<int>>
ptr_v_noncopyable_deleter(new int[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
posOff_v.storeChunk(
UniquePtrWithLambda<int>(std::move(ptr_v_noncopyable_deleter)),
{0},
{global_extent});

iteration_ptr->close();
// force re-flush of previous iterations
series.flush();
Expand Down Expand Up @@ -940,7 +984,7 @@ inline void constant_scalar(std::string file_ending)
s.iterations[1]
.meshes["rho"][MeshRecordComponent::SCALAR]
.getAttribute("shape")
.get<std::vector<uint64_t> >() == Extent{1, 2, 3});
.get<std::vector<uint64_t>>() == Extent{1, 2, 3});
REQUIRE(s.iterations[1]
.meshes["rho"][MeshRecordComponent::SCALAR]
.containsAttribute("value"));
Expand All @@ -962,7 +1006,7 @@ inline void constant_scalar(std::string file_ending)
s.iterations[1]
.meshes["E"]["x"]
.getAttribute("shape")
.get<std::vector<uint64_t> >() == Extent{1, 2, 3});
.get<std::vector<uint64_t>>() == Extent{1, 2, 3});
REQUIRE(s.iterations[1].meshes["E"]["x"].containsAttribute("value"));
REQUIRE(
s.iterations[1]
Expand Down Expand Up @@ -990,7 +1034,7 @@ inline void constant_scalar(std::string file_ending)
s.iterations[1]
.particles["e"]["position"][RecordComponent::SCALAR]
.getAttribute("shape")
.get<std::vector<uint64_t> >() == Extent{3, 2, 1});
.get<std::vector<uint64_t>>() == Extent{3, 2, 1});
REQUIRE(s.iterations[1]
.particles["e"]["position"][RecordComponent::SCALAR]
.containsAttribute("value"));
Expand All @@ -1014,7 +1058,7 @@ inline void constant_scalar(std::string file_ending)
s.iterations[1]
.particles["e"]["positionOffset"][RecordComponent::SCALAR]
.getAttribute("shape")
.get<std::vector<uint64_t> >() == Extent{3, 2, 1});
.get<std::vector<uint64_t>>() == Extent{3, 2, 1});
REQUIRE(s.iterations[1]
.particles["e"]["positionOffset"][RecordComponent::SCALAR]
.containsAttribute("value"));
Expand All @@ -1036,7 +1080,7 @@ inline void constant_scalar(std::string file_ending)
s.iterations[1]
.particles["e"]["velocity"]["x"]
.getAttribute("shape")
.get<std::vector<uint64_t> >() == Extent{3, 2, 1});
.get<std::vector<uint64_t>>() == Extent{3, 2, 1});
REQUIRE(
s.iterations[1].particles["e"]["velocity"]["x"].containsAttribute(
"value"));
Expand Down Expand Up @@ -1388,55 +1432,55 @@ inline void dtype_test(const std::string &backend)
REQUIRE(s.getAttribute("emptyString").get<std::string>().empty());
}
REQUIRE(
s.getAttribute("vecChar").get<std::vector<char> >() ==
s.getAttribute("vecChar").get<std::vector<char>>() ==
std::vector<char>({'c', 'h', 'a', 'r'}));
REQUIRE(
s.getAttribute("vecInt16").get<std::vector<int16_t> >() ==
s.getAttribute("vecInt16").get<std::vector<int16_t>>() ==
std::vector<int16_t>({32766, 32767}));
REQUIRE(
s.getAttribute("vecInt32").get<std::vector<int32_t> >() ==
s.getAttribute("vecInt32").get<std::vector<int32_t>>() ==
std::vector<int32_t>({2147483646, 2147483647}));
REQUIRE(
s.getAttribute("vecInt64").get<std::vector<int64_t> >() ==
s.getAttribute("vecInt64").get<std::vector<int64_t>>() ==
std::vector<int64_t>({9223372036854775806, 9223372036854775807}));
REQUIRE(
s.getAttribute("vecUchar").get<std::vector<unsigned char> >() ==
s.getAttribute("vecUchar").get<std::vector<unsigned char>>() ==
std::vector<unsigned char>({'u', 'c', 'h', 'a', 'r'}));
REQUIRE(
s.getAttribute("vecSchar").get<std::vector<signed char> >() ==
s.getAttribute("vecSchar").get<std::vector<signed char>>() ==
std::vector<signed char>({'s', 'c', 'h', 'a', 'r'}));
REQUIRE(
s.getAttribute("vecUint16").get<std::vector<uint16_t> >() ==
s.getAttribute("vecUint16").get<std::vector<uint16_t>>() ==
std::vector<uint16_t>({65534u, 65535u}));
REQUIRE(
s.getAttribute("vecUint32").get<std::vector<uint32_t> >() ==
s.getAttribute("vecUint32").get<std::vector<uint32_t>>() ==
std::vector<uint32_t>({4294967294u, 4294967295u}));
REQUIRE(
s.getAttribute("vecUint64").get<std::vector<uint64_t> >() ==
s.getAttribute("vecUint64").get<std::vector<uint64_t>>() ==
std::vector<uint64_t>({18446744073709551614u, 18446744073709551615u}));
REQUIRE(
s.getAttribute("vecFloat").get<std::vector<float> >() ==
s.getAttribute("vecFloat").get<std::vector<float>>() ==
std::vector<float>({0.f, 3.40282e+38f}));
REQUIRE(
s.getAttribute("vecDouble").get<std::vector<double> >() ==
s.getAttribute("vecDouble").get<std::vector<double>>() ==
std::vector<double>({0., 1.79769e+308}));
if (test_long_double)
{
REQUIRE(
s.getAttribute("vecLongdouble").get<std::vector<long double> >() ==
s.getAttribute("vecLongdouble").get<std::vector<long double>>() ==
std::vector<long double>(
{0.L, std::numeric_limits<long double>::max()}));
}
REQUIRE(
s.getAttribute("vecString").get<std::vector<std::string> >() ==
s.getAttribute("vecString").get<std::vector<std::string>>() ==
std::vector<std::string>({"vector", "of", "strings"}));
if (!adios1)
{
REQUIRE(
s.getAttribute("vecEmptyString").get<std::vector<std::string> >() ==
s.getAttribute("vecEmptyString").get<std::vector<std::string>>() ==
std::vector<std::string>({"", "", ""}));
REQUIRE(
s.getAttribute("vecMixedString").get<std::vector<std::string> >() ==
s.getAttribute("vecMixedString").get<std::vector<std::string>>() ==
std::vector<std::string>({"hi", "", "ho"}));
}
REQUIRE(s.getAttribute("bool").get<bool>() == true);
Expand Down Expand Up @@ -1648,22 +1692,22 @@ void test_complex(const std::string &backend)
"longDoublesYouSay", std::complex<long double>(5.5, -4.55));

auto Cflt = o.iterations[0].meshes["Cflt"][RecordComponent::SCALAR];
std::vector<std::complex<float> > cfloats(3);
std::vector<std::complex<float>> cfloats(3);
cfloats.at(0) = {1., 2.};
cfloats.at(1) = {-3., 4.};
cfloats.at(2) = {5., -6.};
Cflt.resetDataset(Dataset(Datatype::CFLOAT, {cfloats.size()}));
Cflt.storeChunk(cfloats, {0});

auto Cdbl = o.iterations[0].meshes["Cdbl"][RecordComponent::SCALAR];
std::vector<std::complex<double> > cdoubles(3);
std::vector<std::complex<double>> cdoubles(3);
cdoubles.at(0) = {2., 1.};
cdoubles.at(1) = {-4., 3.};
cdoubles.at(2) = {6., -5.};
Cdbl.resetDataset(Dataset(Datatype::CDOUBLE, {cdoubles.size()}));
Cdbl.storeChunk(cdoubles, {0});

std::vector<std::complex<long double> > cldoubles(3);
std::vector<std::complex<long double>> cldoubles(3);
if (o.backend() != "ADIOS2" && o.backend() != "ADIOS1" &&
o.backend() != "MPI_ADIOS1")
{
Expand All @@ -1684,26 +1728,26 @@ void test_complex(const std::string &backend)
Series i = Series(
"../samples/serial_write_complex." + backend, Access::READ_ONLY);
REQUIRE(
i.getAttribute("lifeIsComplex").get<std::complex<double> >() ==
i.getAttribute("lifeIsComplex").get<std::complex<double>>() ==
std::complex<double>(4.56, 7.89));
REQUIRE(
i.getAttribute("butComplexFloats").get<std::complex<float> >() ==
i.getAttribute("butComplexFloats").get<std::complex<float>>() ==
std::complex<float>(42.3, -99.3));
if (i.backend() != "ADIOS2" && i.backend() != "ADIOS1" &&
i.backend() != "MPI_ADIOS1")
{
REQUIRE(
i.getAttribute("longDoublesYouSay")
.get<std::complex<long double> >() ==
.get<std::complex<long double>>() ==
std::complex<long double>(5.5, -4.55));
}

auto rcflt = i.iterations[0]
.meshes["Cflt"][RecordComponent::SCALAR]
.loadChunk<std::complex<float> >();
.loadChunk<std::complex<float>>();
auto rcdbl = i.iterations[0]
.meshes["Cdbl"][RecordComponent::SCALAR]
.loadChunk<std::complex<double> >();
.loadChunk<std::complex<double>>();
i.flush();

REQUIRE(rcflt.get()[1] == std::complex<float>(-3., 4.));
Expand All @@ -1714,7 +1758,7 @@ void test_complex(const std::string &backend)
{
auto rcldbl = i.iterations[0]
.meshes["Cldbl"][RecordComponent::SCALAR]
.loadChunk<std::complex<long double> >();
.loadChunk<std::complex<long double>>();
i.flush();
REQUIRE(rcldbl.get()[2] == std::complex<long double>(7., -6.));
}
Expand Down Expand Up @@ -4957,7 +5001,7 @@ void bp4_steps(
auto E_x = E["x"];
REQUIRE(
E.getAttribute("vector_of_string")
.get<std::vector<std::string> >() ==
.get<std::vector<std::string>>() ==
std::vector<std::string>{"vector", "of", "string"});
REQUIRE(E_x.getDimensionality() == 1);
REQUIRE(E_x.getExtent()[0] == 10);
Expand Down Expand Up @@ -5191,7 +5235,7 @@ struct AreEqual
};

template <typename T>
struct AreEqual<std::vector<T> >
struct AreEqual<std::vector<T>>
{
static bool areEqual(std::vector<T> v1, std::vector<T> v2)
{
Expand Down

0 comments on commit 442481a

Please sign in to comment.