Skip to content

Commit 8e20998

Browse files
authored
Simulate infections (#557)
1 parent 66b2e6a commit 8e20998

10 files changed

+454
-44
lines changed

NAMESPACE

+1
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ importFrom(checkmate,assert_names)
117117
importFrom(checkmate,assert_numeric)
118118
importFrom(checkmate,assert_path_for_output)
119119
importFrom(checkmate,assert_string)
120+
importFrom(checkmate,assert_subset)
120121
importFrom(checkmate,test_data_frame)
121122
importFrom(checkmate,test_numeric)
122123
importFrom(data.table,":=")

NEWS.md

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
* The `fixed` argument to `dist_spec` has been deprecated and replaced by a `fix_dist()` function. By @sbfnk in #503 and reviewed by @seabbs.
88
* Updated `estimate_infections()` so that rather than imputing missing data, it now skips these data points in the likelihood. This is a breaking change as it alters the behaviour of the model when dates are missing from a time series but are known to be zero. We recommend that users check their results when updating to this version, but we expect this change to improve performance in most cases. By @seabbs in #528 and reviewed by @sbfnk.
99
* `simulate_infections` has been renamed to `forecast_infections` in line with `simulate_secondary` and `forecast_secondary`. The terminology is: a forecast is done from a fit to existing data, a simulation from first principles. By @sbfnk in #544 and reviewed by @seabbs.
10+
* A new `simulate_infections` function has been added that can be used to simulate from the model given initial conditions and parameters. By @sbfnk in #557 and reviewed by @jamesmbaazam.
1011

1112
## Documentation
1213

R/extract.R

+23-20
Original file line numberDiff line numberDiff line change
@@ -181,28 +181,31 @@ extract_parameter_samples <- function(stan_fit, data, reported_dates,
181181
samples,
182182
reported_dates
183183
)
184-
if (data$estimate_r == 1) {
185-
out$R <- extract_parameter(
186-
"R",
187-
samples,
188-
reported_dates
189-
)
190-
if (data$bp_n > 0) {
191-
out$breakpoints <- extract_parameter(
192-
"bp_effects",
184+
if ("estimate_r" %in% names(data)) {
185+
if (data$estimate_r == 1) {
186+
out$R <- extract_parameter(
187+
"R",
193188
samples,
194-
1:data$bp_n
189+
reported_dates
190+
)
191+
if (data$bp_n > 0) {
192+
out$breakpoints <- extract_parameter(
193+
"bp_effects",
194+
samples,
195+
1:data$bp_n
196+
)
197+
out$breakpoints <- out$breakpoints[
198+
,
199+
strat := date
200+
][, c("time", "date") := NULL]
201+
}
202+
} else {
203+
out$R <- extract_parameter(
204+
"gen_R",
205+
samples,
206+
reported_dates
195207
)
196-
out$breakpoints <- out$breakpoints[,
197-
strat := date][, c("time", "date") := NULL
198-
]
199208
}
200-
} else {
201-
out$R <- extract_parameter(
202-
"gen_R",
203-
samples,
204-
reported_dates
205-
)
206209
}
207210
out$growth_rate <- extract_parameter(
208211
"r",
@@ -243,7 +246,7 @@ extract_parameter_samples <- function(stan_fit, data, reported_dates,
243246
value.V1 := NULL
244247
]
245248
}
246-
if (data$obs_scale_sd > 0) {
249+
if ("obs_scale_sd" %in% names(data) && data$obs_scale_sd > 0) {
247250
out$fraction_observed <- extract_static_parameter("frac_obs", samples)
248251
out$fraction_observed <- out$fraction_observed[, value := value.V1][,
249252
value.V1 := NULL

R/simulate_infections.R

+199-10
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,208 @@
1-
#' Deprecated; use [forecast_infections()] instead
1+
#' Simulate infections using the renewal equation
22
#'
3-
#' Calling this function passes all arguments to [forecast_infections()]
4-
#' @description `r lifecycle::badge("deprecated")`
5-
#' @param ... Arguments to be passed to [forecast_infections()]
6-
#' @return the result of [forecast_infections()]
3+
#' Simulations are done from given initial infections and, potentially
4+
#' time-varying, reproduction numbers. Delays and parameters of the observation
5+
#' model can be specified using the same options as in [estimate_infections()].
6+
#'
7+
#' In order to simulate, all parameters that are specified, such as the mean and
8+
#' standard deviation of delays or observation scaling, must be fixed.
9+
#' Uncertain parameters are not allowed.
10+
#'
11+
#' A previous function called [simulate_infections()] that simulates from a
12+
#' given model fit has been renamed [forecast_infections()]. Using
13+
#' [simulate_infections()] with existing estimates is now deprecated. This
14+
#' option will be removed in version 2.1.0.
15+
#' @param R a data frame of reproduction numbers (column `R`) by date (column
16+
#' `date`). Column `R` must be numeric and `date` must be in date format. If
17+
#' not all days between the first and last `date` are present,
18+
#' it will be assumed that R stays the same until the next given date.
19+
#' @param initial_infections numeric; the initial number of infections.
20+
#' @param day_of_week_effect either `NULL` (no day of the week effect) or a
21+
#' numerical vector of length specified in [obs_opts()] as `week_length`
22+
#' (default: 7) if `week_effect` is set to TRUE. Each element of the vector
23+
#' gives the weight given to reporting on this day (normalised to 1).
24+
#' The default is `NULL`.
25+
#' @param estimates deprecated; use [forecast_infections()] instead
26+
#' @param ... deprecated; only included for backward compatibility
27+
#' @inheritParams estimate_infections
28+
#' @inheritParams rt_opts
29+
#' @inheritParams stan_opts
30+
#' @importFrom lifecycle deprecate_warn
31+
#' @importFrom checkmate assert_data_frame assert_date assert_numeric
32+
#' assert_subset
33+
#' @importFrom data.table data.table merge.data.table nafill rbindlist
34+
#' @return A data.table of simulated infections (variable `infections`) and
35+
#' reported cases (variable `reported_cases`) by date.
36+
#' @author Sebastian Funk
737
#' @export
8-
simulate_infections <- function(...) {
38+
#' @examples
39+
#' \donttest{
40+
#' R <- data.frame(
41+
#' date = seq.Date(as.Date("2023-01-01"), length.out = 14, by = "day"),
42+
#' R = c(rep(1.2, 7), rep(0.8, 7))
43+
#' )
44+
#' sim <- simulate_infections(
45+
#' R = R,
46+
#' initial_infections = 100,
47+
#' generation_time = generation_time_opts(
48+
#' fix_dist(example_generation_time)
49+
#' ),
50+
#' delays = delay_opts(fix_dist(example_reporting_delay)),
51+
#' obs = obs_opts(family = "poisson")
52+
#' )
53+
#' }
54+
simulate_infections <- function(estimates, R, initial_infections,
55+
day_of_week_effect = NULL,
56+
generation_time = generation_time_opts(),
57+
delays = delay_opts(),
58+
truncation = trunc_opts(),
59+
obs = obs_opts(),
60+
CrIs = c(0.2, 0.5, 0.9),
61+
backend = "rstan",
62+
pop = 0, ...) {
63+
## deprecated usage
64+
if (!missing(estimates)) {
965
deprecate_warn(
1066
"2.0.0",
11-
"simulate_infections()",
67+
"simulate_infections(estimates)",
1268
"forecast_infections()",
13-
"A new [simulate_infections()] function for simulating from given ",
14-
"parameters is planned for implementation in the future."
69+
details = paste0(
70+
"This `estimates` option will be removed from [simulate_infections()] ",
71+
"in version 2.1.0."
72+
)
73+
)
74+
return(forecast_infections(estimates = estimates, ...))
75+
}
76+
77+
## check inputs
78+
assert_data_frame(R, any.missing = FALSE)
79+
assert_subset(colnames(R), c("date", "R"))
80+
assert_date(R$date)
81+
assert_numeric(R$R, lower = 0)
82+
assert_numeric(initial_infections, lower = 0)
83+
assert_numeric(day_of_week_effect, lower = 0, null.ok = TRUE)
84+
assert_numeric(pop, lower = 0)
85+
assert_class(delays, "delay_opts")
86+
assert_class(obs, "obs_opts")
87+
assert_class(generation_time, "generation_time_opts")
88+
89+
## create R for all dates modelled
90+
all_dates <- data.table(date = seq.Date(min(R$date), max(R$date), by = "day"))
91+
R <- merge.data.table(all_dates, R, by = "date", all.x = TRUE)
92+
R <- R[, R := nafill(R, type = "locf")]
93+
## remove any initial NAs
94+
R <- R[!is.na(R)]
95+
96+
seeding_time <- get_seeding_time(delays, generation_time)
97+
if (seeding_time > 1) {
98+
## estimate initial growth from initial reproduction number if seeding time
99+
## is greater than 1
100+
initial_growth <- (R$R[1] - 1) / mean(generation_time)
101+
} else {
102+
initial_growth <- numeric(0)
103+
}
104+
105+
data <- list(
106+
n = 1,
107+
t = nrow(R) + seeding_time,
108+
seeding_time = seeding_time,
109+
future_time = 0,
110+
initial_infections = array(log(initial_infections), dim = c(1, 1)),
111+
initial_growth = array(initial_growth, dim = c(1, length(initial_growth))),
112+
R = array(R$R, dim = c(1, nrow(R))),
113+
pop = pop
114+
)
115+
116+
data <- c(data, create_stan_delays(
117+
gt = generation_time,
118+
delay = delays,
119+
trunc = truncation
120+
))
121+
122+
if ((length(data$delay_mean_sd) > 0 && any(data$delay_mean_sd > 0)) ||
123+
(length(data$delay_sd_sd) > 0 && any(data$delay_sd_sd > 0))) {
124+
stop(
125+
"Cannot simulate from uncertain parameters. Use the [fix_dist()] ",
126+
"function to set the parameters of uncertain distributions either the ",
127+
"mean or a randomly sampled value"
15128
)
16-
forecast_infections(...)
129+
}
130+
data$delay_mean <- array(
131+
data$delay_mean_mean, dim = c(1, length(data$delay_mean_mean))
132+
)
133+
data$delay_sd <- array(
134+
data$delay_sd_mean, dim = c(1, length(data$delay_sd_mean))
135+
)
136+
data$delay_mean_sd <- NULL
137+
data$delay_sd_sd <- NULL
138+
139+
data <- c(data, create_obs_model(
140+
obs, dates = R$date
141+
))
142+
143+
if (data$obs_scale_sd > 0) {
144+
stop(
145+
"Cannot simulate from uncertain observation scaling; use fixed scaling ",
146+
"instead."
147+
)
148+
}
149+
if (data$obs_scale) {
150+
data$frac_obs <- array(data$obs_scale_mean, dim = c(1, 1))
151+
} else {
152+
data$frac_obs <- array(dim = c(1, 0))
153+
}
154+
data$obs_scale_mean <- NULL
155+
data$obs_scale_sd <- NULL
156+
157+
if (obs$family == "negbin") {
158+
if (data$phi_sd > 0) {
159+
stop(
160+
"Cannot simulate from uncertain overdispersion; use fixed ",
161+
"overdispersion instead."
162+
)
163+
}
164+
data$rep_phi <- array(data$phi_mean, dim = c(1, 1))
165+
} else {
166+
data$rep_phi <- array(dim = c(1, 0))
167+
}
168+
data$phi_mean <- NULL
169+
data$phi_sd <- NULL
170+
171+
## day of week effect
172+
if (is.null(day_of_week_effect)) {
173+
day_of_week_effect <- rep(1, data$week_effect)
174+
}
175+
176+
day_of_week_effect <- day_of_week_effect / sum(day_of_week_effect)
177+
data$day_of_week_simplex <- array(
178+
day_of_week_effect, dim = c(1, data$week_effect)
179+
)
180+
181+
# Create stan arguments
182+
stan <- stan_opts(backend = backend, chains = 1, samples = 1, warmup = 1)
183+
args <- create_stan_args(
184+
stan, data = data, fixed_param = TRUE, model = "simulate_infections",
185+
verbose = FALSE
186+
)
187+
188+
## simulate
189+
sim <- fit_model(args, id = "simulate_infections")
190+
191+
## join batches
192+
dates <- c(
193+
seq(min(R$date) - seeding_time, min(R$date) - 1, by = "day"),
194+
R$date
195+
)
196+
out <- extract_parameter_samples(sim, data,
197+
reported_inf_dates = dates,
198+
reported_dates = dates[-(1:seeding_time)],
199+
drop_length_1 = TRUE
200+
)
201+
202+
out <- rbindlist(out[c("infections", "reported_cases")], idcol = "variable")
203+
out <- out[, c("sample", "parameter", "time") := NULL]
204+
205+
return(out[])
17206
}
18207

19208
#' Forecast infections from a given fit and trajectory of the time-varying

inst/stan/data/simulation_rt.stan

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
array[seeding_time ? n : 0, 1] real initial_infections; // initial logged infections
2-
array[seeding_time > 1 ? n : 0, 1] real initial_growth; //initial growth
1+
array[n, 1] real initial_infections; // initial logged infections
2+
array[n, seeding_time > 1 ? 1 : 0] real initial_growth; //initial growth
33

44
matrix[n, t - seeding_time] R; // reproduction number
55
int pop; // susceptible population

inst/stan/simulate_infections.stan

+17-3
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,9 @@ generated quantities {
6363
to_vector(infections[i]), delay_rev_pmf, seeding_time)
6464
);
6565
} else {
66-
reports[i] = to_row_vector(infections[(seeding_time + 1):t]);
66+
reports[i] = to_row_vector(
67+
infections[i, (seeding_time + 1):t]
68+
);
6769
}
6870

6971
// weekly reporting effect
@@ -72,6 +74,18 @@ generated quantities {
7274
day_of_week_effect(to_vector(reports[i]), day_of_week,
7375
to_vector(day_of_week_simplex[i])));
7476
}
77+
// truncate near time cases to observed reports
78+
if (trunc_id) {
79+
vector[delay_type_max[trunc_id] + 1] trunc_rev_cmf = get_delay_rev_pmf(
80+
trunc_id, delay_type_max[trunc_id] + 1, delay_types_p, delay_types_id,
81+
delay_types_groups, delay_max, delay_np_pmf,
82+
delay_np_pmf_groups, delay_mean[i], delay_sd[i], delay_dist,
83+
0, 1, 1
84+
);
85+
reports[i] = to_row_vector(truncate(
86+
to_vector(reports[i]), trunc_rev_cmf, 0)
87+
);
88+
}
7589
// scale observations
7690
if (obs_scale) {
7791
reports[i] = to_row_vector(scale_obs(to_vector(reports[i]), frac_obs[i, 1]));
@@ -81,8 +95,8 @@ generated quantities {
8195
to_vector(reports[i]), rep_phi[i], model_type
8296
);
8397
{
84-
real gt_mean = rev_pmf_mean(gt_rev_pmf, 1);
85-
real gt_var = rev_pmf_var(gt_rev_pmf, 1, gt_mean);
98+
real gt_mean = rev_pmf_mean(gt_rev_pmf, 0);
99+
real gt_var = rev_pmf_var(gt_rev_pmf, 0, gt_mean);
86100
r[i] = R_to_growth(to_vector(R[i]), gt_mean, gt_var);
87101
}
88102
}

0 commit comments

Comments
 (0)