Skip to content

Commit

Permalink
Add marshalling for 'torch'/'luz' models; TODO: add example and tests [
Browse files Browse the repository at this point in the history
  • Loading branch information
HenrikBengtsson committed Sep 30, 2023
1 parent 15dcb0a commit 168e355
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 1 deletion.
4 changes: 3 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Package: marshal
Version: 0.0.0-9029
Version: 0.0.0-9030
Title: Framework to Marshal Objects to be Used in Another R Process
Suggests:
bundle,
Expand All @@ -20,6 +20,8 @@ Suggests:
stats,
terra,
tools,
torch,
luz,
xgboost,
XML,
xml2
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ S3method(marshal,XMLAbstractNode)
S3method(marshal,connection)
S3method(marshal,data.table)
S3method(marshal,keras.engine.base_layer.Layer)
S3method(marshal,luz_module_fitted)
S3method(marshal,model_fit)
S3method(marshal,ncdf4)
S3method(marshal,stanfit)
Expand All @@ -35,6 +36,7 @@ S3method(marshallable,connection)
S3method(marshallable,data.table)
S3method(marshallable,default)
S3method(marshallable,keras.engine.base_layer.Layer)
S3method(marshallable,luz_module_fitted)
S3method(marshallable,marshalled)
S3method(marshallable,model_fit)
S3method(marshallable,ncdf4)
Expand Down
60 changes: 60 additions & 0 deletions R/marshal.torch.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#' Marshalling of 'torch' objects
#'
#' @param model
#' A `luz_module_fitted` object.
#'
#' @param \dots Not used.
#'
#' @return
#' A `marshalled` object as described in [marshal()].
#'
#' @details
#' [luz::luz_save()] is used to produce a marshalled version
#' of the original object.
#' [luz::luz_load()] is used to reconstruct a version of the
#' original object from the marshalled object.
#'
#' @rdname marshal.torch
#' @aliases marshal.luz_module_fitted
#' @export
marshal.luz_module_fitted <- function(model, ...) {
raw <- suppressWarnings(local({
con <- rawConnection(raw(), open = "wb")
on.exit(close(con))
luz::luz_save(model, con)
rawConnectionValue(con)
}))

res <- list(
marshalled = raw
)
class(res) <- marshal_class(model)

## IMPORTANT: We don't want any of the input arguments
## to be part of the unmarshal() environment
rm(list = c("model", names(list(...))))

res[["unmarshal"]] <- unmarshal_luz_module_fitted
assert_no_references(res)
res
}

unmarshal_luz_module_fitted <- function(model, ...) {
object <- model[["marshalled"]]

res <- local({
con <- rawConnection(object)
on.exit(close(con))
luz::luz_load(con)
})
stopifnot(all.equal(class(res), marshal_unclass(model), check.attributes = FALSE))
res
}


#' @rdname marshal.torch
#' @aliases marshallable.luz_module_fitted
#' @export
marshallable.luz_module_fitted <- function(...) {
TRUE
}
28 changes: 28 additions & 0 deletions man/marshal.torch.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 168e355

Please sign in to comment.