diff --git a/DESCRIPTION b/DESCRIPTION index 3cd7c59..81ff8ff 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,5 +1,5 @@ Package: marshal -Version: 0.0.0-9029 +Version: 0.0.0-9030 Title: Framework to Marshal Objects to be Used in Another R Process Suggests: bundle, @@ -20,6 +20,8 @@ Suggests: stats, terra, tools, + torch, + luz, xgboost, XML, xml2 diff --git a/NAMESPACE b/NAMESPACE index a50087b..9fecfa0 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -13,6 +13,7 @@ S3method(marshal,XMLAbstractNode) S3method(marshal,connection) S3method(marshal,data.table) S3method(marshal,keras.engine.base_layer.Layer) +S3method(marshal,luz_module_fitted) S3method(marshal,model_fit) S3method(marshal,ncdf4) S3method(marshal,stanfit) @@ -35,6 +36,7 @@ S3method(marshallable,connection) S3method(marshallable,data.table) S3method(marshallable,default) S3method(marshallable,keras.engine.base_layer.Layer) +S3method(marshallable,luz_module_fitted) S3method(marshallable,marshalled) S3method(marshallable,model_fit) S3method(marshallable,ncdf4) diff --git a/R/marshal.torch.R b/R/marshal.torch.R new file mode 100644 index 0000000..7968ba7 --- /dev/null +++ b/R/marshal.torch.R @@ -0,0 +1,60 @@ +#' Marshalling of 'torch' objects +#' +#' @param model +#' A `luz_module_fitted` object. +#' +#' @param \dots Not used. +#' +#' @return +#' A `marshalled` object as described in [marshal()]. +#' +#' @details +#' [luz::luz_save()] is used to produce a marshalled version +#' of the original object. +#' [luz::luz_load()] is used to reconstruct a version of the +#' original object from the marshalled object. +#' +#' @rdname marshal.torch +#' @aliases marshal.luz_module_fitted +#' @export +marshal.luz_module_fitted <- function(model, ...) { + raw <- suppressWarnings(local({ + con <- rawConnection(raw(), open = "wb") + on.exit(close(con)) + luz::luz_save(model, con) + rawConnectionValue(con) + })) + + res <- list( + marshalled = raw + ) + class(res) <- marshal_class(model) + + ## IMPORTANT: We don't want any of the input arguments + ## to be part of the unmarshal() environment + rm(list = c("model", names(list(...)))) + + res[["unmarshal"]] <- unmarshal_luz_module_fitted + assert_no_references(res) + res +} + +unmarshal_luz_module_fitted <- function(model, ...) { + object <- model[["marshalled"]] + + res <- local({ + con <- rawConnection(object) + on.exit(close(con)) + luz::luz_load(con) + }) + stopifnot(all.equal(class(res), marshal_unclass(model), check.attributes = FALSE)) + res +} + + +#' @rdname marshal.torch +#' @aliases marshallable.luz_module_fitted +#' @export +marshallable.luz_module_fitted <- function(...) { + TRUE +} diff --git a/man/marshal.torch.Rd b/man/marshal.torch.Rd new file mode 100644 index 0000000..9835b15 --- /dev/null +++ b/man/marshal.torch.Rd @@ -0,0 +1,28 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/marshal.torch.R +\name{marshal.luz_module_fitted} +\alias{marshal.luz_module_fitted} +\alias{marshallable.luz_module_fitted} +\title{Marshalling of 'torch' objects} +\usage{ +\method{marshal}{luz_module_fitted}(model, ...) + +\method{marshallable}{luz_module_fitted}(...) +} +\arguments{ +\item{model}{A \code{luz_module_fitted} object.} + +\item{\dots}{Not used.} +} +\value{ +A \code{marshalled} object as described in \code{\link[=marshal]{marshal()}}. +} +\description{ +Marshalling of 'torch' objects +} +\details{ +\code{\link[luz:luz_save]{luz::luz_save()}} is used to produce a marshalled version +of the original object. +\code{\link[luz:luz_load]{luz::luz_load()}} is used to reconstruct a version of the +original object from the marshalled object. +}