Skip to content

Commit

Permalink
Added common_xregs to fabletools from fable
Browse files Browse the repository at this point in the history
  • Loading branch information
mitchelloharawild committed Mar 16, 2021
1 parent 601e342 commit 9684c8b
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 1 deletion.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ export(box_cox)
export(combination_ensemble)
export(combination_model)
export(common_periods)
export(common_xregs)
export(components)
export(construct_fc)
export(dable)
Expand Down
6 changes: 5 additions & 1 deletion R/definitions.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ model_definition <- R6::R6Class(NULL,
specials = list(),
formula = NULL,
extra = NULL,
origin = NULL,
env = global_env(),
check = function(.data){
},
Expand All @@ -28,7 +29,6 @@ model_definition <- R6::R6Class(NULL,
xreg_env <- get_env(self$specials$xreg)
xreg_env$lag <- self$recall_lag


self$prepare(formula, ...)

self$extra <- list2(...)
Expand All @@ -49,6 +49,10 @@ model_definition <- R6::R6Class(NULL,
data = NULL,
add_data = function(.data){
self$check(.data)
# Add data origin if not yet known (fitting model)
if(is.null(self$origin)) {
self$origin <- .data[[index_var(.data)]][[1]]
}
self$data <- .data
},
remove_data = function(){
Expand Down
140 changes: 140 additions & 0 deletions R/xregs.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
fbl_trend <- function(x, knots = NULL, origin = NULL) {
idx_num <- as.double(x[[index_var(x)]])
knots_num <- if (is.null(knots)) {
NULL
} else {
as.double(knots)
}
index_interval <- default_time_units(interval(x))
idx_num <- idx_num / index_interval
knots_num <- knots_num / index_interval
if (!is.null(origin)) {
# trend should count from 1
origin <- as.double(origin) / index_interval - 1
idx_num <- idx_num - origin
knots_num <- knots_num - origin
}

knots_exprs <- map(knots_num, function(.x) pmax(0, idx_num - .x))
knots_exprs <- set_names(
knots_exprs,
map_chr(knots_num, function(.x) paste0("trend_", format(.x)))
)
tibble(
trend = idx_num,
!!!knots_exprs
)
}

fbl_season <- function(x, period) {
idx_num <- as.double(x[[index_var(x)]])
index_interval <- default_time_units(interval(x))
idx_num <- idx_num / index_interval
period <- get_frequencies(period, x, .auto = "smallest")
season_exprs <- map(period, function(.x) expr(factor(floor((idx_num %% (!!.x)) + 1), levels = seq_len(!!.x))))
season_exprs <- set_names(season_exprs, names(period) %||% paste0("season_", period))
tibble(!!!season_exprs)
}

fbl_fourier <- function(x, period, K, origin = NULL) {
idx_num <- as.double(x[[index_var(x)]])
index_interval <- default_time_units(interval(x))
idx_num <- idx_num / index_interval
if (!is.null(origin)) {
origin <- as.double(origin) / index_interval
}
period <- get_frequencies(period, x, .auto = "smallest")

if (length(period) != length(K)) {
abort("Number of periods does not match number of orders")
}
if (any(2 * K > period)) {
abort("K must be not be greater than period/2")
}

fourier_exprs <- map2(
as.numeric(period), K,
function(period, K) {
set_names(seq_len(K) / period, paste0(seq_len(K), "_", round(period)))
}
) %>%
invoke(c, .) %>%
.[!duplicated(.)] %>%
map2(., names(.), function(p, name) {
out <- exprs(C = cospi(2 * !!p * idx_num))
if (abs(2 * p - round(2 * p)) > .Machine$double.eps) {
out <- c(out, exprs(S = sinpi(2 * !!p * idx_num)))
}
names(out) <- paste0(names(out), name)
out
}) %>%
set_names(NULL) %>%
unlist(recursive = FALSE)

tibble(!!!fourier_exprs)
}

#' Common exogenous regressors
#'
#' These special functions provide interfaces to more complicated functions within
#' the model formulae interface.
#'
#' @section Specials:
#'
#' \subsection{trend}{
#' The `trend` special includes common linear trend regressors in the model. It also supports piecewise linear trend via the `knots` argument.
#' \preformatted{
#' trend(knots = NULL, origin = NULL)
#' }
#'
#' \tabular{ll}{
#' `knots` \tab A vector of times (same class as the data's time index) identifying the position of knots for a piecewise linear trend.\cr
#' `origin` \tab An optional time value to act as the starting time for the trend.
#' }
#' }
#'
#' \subsection{season}{
#' The `season` special includes seasonal dummy variables in the model.
#' \preformatted{
#' season(period = NULL)
#' }
#'
#' \tabular{ll}{
#' `period` \tab The periodic nature of the seasonality. This can be either a number indicating the number of observations in each seasonal period, or text to indicate the duration of the seasonal window (for example, annual seasonality would be "1 year").
#' }
#' }
#'
#' \subsection{fourier}{
#' The `fourier` special includes seasonal fourier terms in the model. The maximum order of the fourier terms must be specified using `K`.
#' \preformatted{
#' fourier(period = NULL, K, origin = NULL)
#' }
#'
#' \tabular{ll}{
#' `period` \tab The periodic nature of the seasonality. This can be either a number indicating the number of observations in each seasonal period, or text to indicate the duration of the seasonal window (for example, annual seasonality would be "1 year"). \cr
#' `K` \tab The maximum order of the fourier terms.\cr
#' `origin` \tab An optional time value to act as the starting time for the fourier series.
#' }
#' }
#'
#' @format NULL
#'
#' @export
common_xregs <- list(
trend = function(knots = NULL, origin = NULL) {
if (is.null(origin)) {
origin <- self$origin
}
as.matrix(fabletools:::fbl_trend(self$data, knots, origin))
},
season = function(period = NULL) {
out <- as_model_matrix(fabletools:::fbl_season(self$data, period))
stats::model.matrix(~., data = out)[, -1, drop = FALSE]
},
fourier = function(period = NULL, K, origin = NULL) {
if (is.null(origin)) {
origin <- self$origin
}
as.matrix(fabletools:::fbl_fourier(self$data, period, K, origin))
}
)
54 changes: 54 additions & 0 deletions man/common_xregs.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 9684c8b

Please sign in to comment.