Skip to content

move prior weights to model arguments #450

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
rename and expand doc
  • Loading branch information
sbfnk committed Sep 8, 2023
commit 7c16faf33cb0d704dbf2cbcf32936abfb2566b10
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ This release is in development. For a stable release install 1.3.5 from CRAN.
* Added content to the vignette for the estimate_truncation model. By @sbfnk in #439 and reviewed by @seabbs.
* Added a feature to the `estimate_truncation` to allow it to be applied to time series that are shorter than the truncation max. By @sbfnk in #438 and reviewed by @seabbs.
* Changed the `estimate_truncation` to use the `dist_spec` interface, deprecating existing options `max_trunc` and `trunc_dist`. By @sbfnk in #448 reviewed by @seabbs.
* Added a `weigh_prior_delays` argument to the main functions, allowing the users to choose whether to weigh delay priors by the number of data points or not. By @sbfnk in #450 and reviewed by @seabbs.
* Added a `weigh_delay_priors` argument to the main functions, allowing the users to choose whether to weigh delay priors by the number of data points or not. By @sbfnk in #450 and reviewed by @seabbs.

## Documentation

Expand Down
9 changes: 5 additions & 4 deletions R/estimate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@
#' @param horizon Numeric, defaults to 7. Number of days into the future to
#' forecast.
#'
#' @param weigh_prior_delays Logical. If TRUE (default), all delay distribution
#' priors will be weighted by the number of observation data points, usually
#' @param weigh_delay_priors Logical. If TRUE (default), all delay distribution
#' priors will be weighted by the number of observation data points, in doing so
#' approximately placing an independent prior at each time step and usually
#' preventing the posteriors from shifting. If FALSE, no weight will be applied,
#' i.e. delay distributions will be treated as a single parameters.
#'
Expand Down Expand Up @@ -251,7 +252,7 @@ estimate_infections <- function(reported_cases,
CrIs = c(0.2, 0.5, 0.9),
filter_leading_zeros = TRUE,
zero_threshold = Inf,
weigh_prior_delays = TRUE,
weigh_delay_priors = TRUE,
id = "estimate_infections",
verbose = interactive()) {
set_dt_single_thread()
Expand Down Expand Up @@ -319,7 +320,7 @@ estimate_infections <- function(reported_cases,
delay = delays,
trunc = truncation,
weight = ifelse(
weigh_prior_delays, data$t - data$seeding_time - data$horizon, 1
weigh_delay_priors, data$t - data$seeding_time - data$horizon, 1
)
))

Expand Down
12 changes: 7 additions & 5 deletions R/estimate_secondary.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,12 @@
#' use for estimation but not to fit to at the beginning of the time series.
#' This must be less than the number of observations.
#'
#' @param weigh_prior_delays Logical. If TRUE, all delay distribution
#' priors will be weighted by the number of observation data points, usually
#' @param weigh_delay_priors Logical. If TRUE, all delay distribution priors
#' will be weighted by the number of observation data points, in doing so
#' approximately placing an independent prior at each time step and usually
#' preventing the posteriors from shifting. If FALSE (default), no weight will
#' be applied, i.e. delay distributions will be treated as a single parameters.
#' be applied, i.e. delay distributions will be treated as a single
#' parameters.
#'
#' @param verbose Logical, should model fitting progress be returned. Defaults
#' to `interactive()`.
Expand Down Expand Up @@ -143,7 +145,7 @@ estimate_secondary <- function(reports,
CrIs = c(0.2, 0.5, 0.9),
priors = NULL,
model = NULL,
weigh_prior_delays = FALSE,
weigh_delay_priors = FALSE,
verbose = interactive(),
...) {
reports <- data.table::as.data.table(reports)
Expand All @@ -166,7 +168,7 @@ estimate_secondary <- function(reports,
data <- c(data, create_stan_delays(
delay = delays,
trunc = truncation,
weight = ifelse(weigh_prior_delays, data$t, 1)
weight = ifelse(weigh_delay_priors, data$t, 1)
))

# observation model data
Expand Down
14 changes: 8 additions & 6 deletions R/estimate_truncation.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,13 @@
#' @param model A compiled stan model to override the default model. May be
#' useful for package developers or those developing extensions.
#'
#' @param weigh_prior_delays Logical. If TRUE, all delay distribution
#' priors will be weighted by the number of observation data points, usually
#' @param weigh_delay_priors Logical. If TRUE, all delay distribution priors
#' will be weighted by the number of observation data points, in doing so
#' approximately placing an independent prior at each time step and usually
#' preventing the posteriors from shifting. If FALSE (default), no weight will
#' be applied, i.e. delay distributions will be treated as a single parameters.
#
#' be applied, i.e. delay distributions will be treated as a single
#' parameters.
#'
#' @param verbose Logical, should model fitting progress be returned.
#'
#' @param ... Additional parameters to pass to `rstan::sampling`.
Expand Down Expand Up @@ -143,7 +145,7 @@ estimate_truncation <- function(obs, max_truncation, trunc_max = 10,
),
model = NULL,
CrIs = c(0.2, 0.5, 0.9),
weigh_prior_delays = FALSE,
weigh_delay_priors = FALSE,
verbose = TRUE,
...) {

Expand Down Expand Up @@ -224,7 +226,7 @@ estimate_truncation <- function(obs, max_truncation, trunc_max = 10,

data <- c(data, create_stan_delays(
trunc = truncation,
weight = ifelse(weigh_prior_delays, data$t, 1)
weight = ifelse(weigh_delay_priors, data$t, 1)
))

## convert to integer
Expand Down
7 changes: 4 additions & 3 deletions man/estimate_infections.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 6 additions & 4 deletions man/estimate_secondary.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 6 additions & 4 deletions man/estimate_truncation.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions tests/testthat/test-estimate_secondary.R
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,14 @@ test_that("forecast_secondary can return values from simulated data and plot
expect_error(plot(inc_preds, new_obs = cases, from = "2020-05-01"), NA)
})

test_that("estimate_secondary works with weigh_prior_delays = TRUE", {
test_that("estimate_secondary works with weigh_delay_priors = TRUE", {
delays <- dist_spec(
mean = 2.5, mean_sd = 0.5, sd = 0.47, sd_sd = 0.25, max = 30
)
inc_weigh <- estimate_secondary(
cases[1:60], delays = delays,
obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), week_effect = FALSE),
weigh_prior_delays = TRUE, verbose = FALSE
weigh_delay_priors = TRUE, verbose = FALSE
)
expect_s3_class(inc_weigh, "estimate_secondary")
})