Skip to content

fixed generation times or distributions #276

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 8, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
fixed generation times or distributions
  • Loading branch information
sbfnk committed Oct 8, 2021
commit 516317a42de5a3c198755010bedb7cc4ca45f087
47 changes: 39 additions & 8 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,32 @@ create_stan_data <- function(reported_cases, generation_time,
rt, gp, obs, delays, horizon,
backcalc, shifted_cases,
truncation) {
## complete generation time parameters if not all are given
if (is.null(generation_time)) {
generation_time <- list(mean = 1)
}
for (param in c("mean_sd", "sd", "sd_sd")) {
if (!(param %in% names(generation_time))) generation_time[[param]] <- 0
}
## check if generation time is fixed
if (generation_time$sd == 0 && generation_time$sd_sd == 0) {
if ("max_gt" %in% names(generation_time)) {
if (generation_time$max_gt != generation_time$mean) {
stop("Error in generation time defintion: if max_gt(",
generation_time$max_gt,
") is given it must be equal to the mean (",
generation_time$mean,
")")
}
} else {
generation_time$max_gt <- generation_time$mean
}
if (any(generation_time$mean_sd > 0, generation_time$sd_sd > 0)) {
stop("Error in generation time definition: if sd_mean is 0 and ",
"sd_sd is 0 then mean_sd must be 0, too.")
}
}

cases <- reported_cases[(delays$seeding_time + 1):(.N - horizon)]$confirm

data <- list(
Expand Down Expand Up @@ -516,14 +542,19 @@ create_initial_conditions <- function(data) {
n = 1, mean = convert_to_logmean(data$r_mean, data$r_sd),
sd = convert_to_logsd(data$r_mean, data$r_sd) * 0.1
))
out$gt_mean <- array(truncnorm::rtruncnorm(1,
a = 0, mean = data$gt_mean_mean,
sd = data$gt_mean_sd * 0.1
))
out$gt_sd <- array(truncnorm::rtruncnorm(1,
a = 0, mean = data$gt_sd_mean,
sd = data$gt_sd_sd * 0.1
))
if (data$gt_mean_sd > 0) {
out$gt_mean <- array(truncnorm::rtruncnorm(1,
a = 0, mean = data$gt_mean_mean,
sd = data$gt_mean_sd * 0.1
))
}
if (data$gt_sd_sd > 0) {
out$gt_sd <- array(truncnorm::rtruncnorm(1,
a = 0, mean = data$gt_sd_mean,
sd = data$gt_sd_sd * 0.1
))
}

if (data$bp_n > 0) {
out$bp_sd <- array(truncnorm::rtruncnorm(1, a = 0, mean = 0, sd = 0.1))
out$bp_effects <- array(rnorm(data$bp_n, 0, 0.1))
Expand Down
2 changes: 1 addition & 1 deletion R/epinow.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
#' summary(out, type = "parameters", params = "R")
#' }
epinow <- function(reported_cases,
generation_time,
generation_time = NULL,
delays = delay_opts(),
truncation = trunc_opts(),
rt = rt_opts(),
Expand Down
4 changes: 3 additions & 1 deletion R/extract.R
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,11 @@ extract_parameter_samples <- function(stan_fit, data, reported_dates, reported_i
out$truncation_sd <-
out$truncation_sd[, strat := as.character(time)][, time := NULL][, date := NULL]
}
if (data$estimate_r == 1) {
if (data$estimate_r && data$gt_mean_sd > 0) {
out$gt_mean <- extract_static_parameter("gt_mean", samples)
out$gt_mean <- out$gt_mean[, value := value.V1][, value.V1 := NULL]
}
if (data$estimate_r && data$gt_sd_sd > 0) {
out$gt_sd <- extract_static_parameter("gt_sd", samples)
out$gt_sd <- out$gt_sd[, value := value.V1][, value.V1 := NULL]
}
Expand Down
2 changes: 1 addition & 1 deletion inst/stan/data/generation_time.stan
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
real gt_mean_sd; // prior sd of mean generation time
real gt_mean_mean; // prior mean of mean generation time
real gt_sd_mean; // prior sd of sd of generation time
real gt_sd_mean; // prior mean of sd of generation time
real gt_sd_sd; // prior sd of sd of generation time
int max_gt; // maximum generation time
18 changes: 11 additions & 7 deletions inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ parameters{
vector[estimate_r] log_R; // baseline reproduction number estimate (log)
real initial_infections[estimate_r] ; // seed infections
real initial_growth[estimate_r && seeding_time > 1 ? 1 : 0]; // seed growth rate
real<lower = 0, upper = max_gt> gt_mean[estimate_r]; // mean of generation time
real<lower = 0> gt_sd[estimate_r]; // sd of generation time
real<lower = 0, upper = max_gt> gt_mean[estimate_r && gt_mean_sd > 0]; // mean of generation time (if uncertain)
real<lower = 0> gt_sd[estimate_r && gt_sd_sd > 0]; // sd of generation time (if uncertain)
real<lower = 0> bp_sd[bp_n > 0 ? 1 : 0]; // standard deviation of breakpoint effect
real bp_effects[bp_n]; // Rt breakpoint effects
// observation model
Expand All @@ -67,11 +67,13 @@ transformed parameters {
// Estimate latent infections
if (estimate_r) {
// via Rt
real set_gt_mean = (gt_mean_sd > 0 ? gt_mean[1] : gt_mean_mean);
real set_gt_sd = (gt_sd_sd > 0 ? gt_sd[1] : gt_sd_mean);
R = update_Rt(R, log_R[estimate_r], noise, breakpoints, bp_effects, stationary);
infections = generate_infections(R, seeding_time, gt_mean, gt_sd, max_gt,
infections = generate_infections(R, seeding_time, set_gt_mean, set_gt_sd, max_gt,
initial_infections, initial_growth,
pop, future_time);
}else{
} else {
// via deconvolution
infections = deconvolve_infections(shifted_cases, noise, fixed, backcalc_prior);
}
Expand Down Expand Up @@ -126,11 +128,13 @@ generated quantities {
vector[return_likelihood > 1 ? ot : 0] log_lik;
if (estimate_r){
// estimate growth from estimated Rt
r = R_to_growth(R, gt_mean[1], gt_sd[1]);
real set_gt_mean = (gt_mean_sd > 0 ? gt_mean[1] : gt_mean_mean);
real set_gt_sd = (gt_sd_sd > 0 ? gt_sd[1] : gt_sd_mean);
r = R_to_growth(R, set_gt_mean, set_gt_sd);
}else{
// sample generation time
real gt_mean_sample = normal_rng(gt_mean_mean, gt_mean_sd);
real gt_sd_sample = normal_rng(gt_sd_mean, gt_sd_sd);
real gt_mean_sample = (gt_mean_sd > 0 ? normal_rng(gt_mean_mean, gt_mean_sd) : gt_mean_mean);
real gt_sd_sample = (gt_sd_sd > 0 ? normal_rng(gt_sd_mean, gt_sd_sd) : gt_sd_mean);
// calculate Rt using infections and generation time
gen_R = calculate_Rt(infections, seeding_time, gt_mean_sample, gt_sd_sample,
max_gt, rt_half_window);
Expand Down
19 changes: 15 additions & 4 deletions inst/stan/functions/generated_quantities.stan
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ vector calculate_Rt(vector infections, int seeding_time,
for (i in 1:(max_gt)) {
gt_indexes[i] = max_gt - i + 1;
}
gt_pmf = discretised_gamma_pmf(gt_indexes, gt_mean, gt_sd, max_gt);
if (gt_sd > 0) {
gt_pmf = discretised_gamma_pmf(gt_indexes, gt_mean, gt_sd, max_gt);
} else {
gt_pmf = discretised_delta_pmf(gt_indexes);
}
// calculate Rt using Cori et al. approach
for (s in 1:ot) {
infectiousness[s] += update_infectiousness(infections, gt_pmf, seeding_time, max_gt, s);
Expand All @@ -36,11 +40,18 @@ vector calculate_Rt(vector infections, int seeding_time,
}
// Convert an estimate of Rt to growth
real[] R_to_growth(vector R, real gt_mean, real gt_sd) {
real k = pow(gt_sd / gt_mean, 2);
int t = num_elements(R);
real r[t];
for (s in 1:t) {
r[s] = (pow(R[s], k) - 1) / (k * gt_mean);
if (gt_sd > 0) {
real k = pow(gt_sd / gt_mean, 2);
for (s in 1:t) {
r[s] = (pow(R[s], k) - 1) / (k * gt_mean);
}
} else {
// limit as gt_sd -> 0
for (s in 1:t) {
r[s] = log(R[s]) / gt_mean;
}
}
return(r);
}
21 changes: 19 additions & 2 deletions inst/stan/functions/infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,21 @@
// for a single time point
real update_infectiousness(vector infections, vector gt_pmf,
int seeding_time, int max_gt, int index){
// work out where to start the convolution of past infections with the
// generation time distribution: (current_time - maximal generation time) if
// that is >= 1, otherwise 1
int inf_start = max(1, (index + seeding_time - max_gt));
// work out where to end the convolution: (current_time - 1)
int inf_end = (index + seeding_time - 1);
// number of indices of the generation time to sum over (inf_end - inf_start + 1)
int pmf_accessed = min(max_gt, index + seeding_time - 1);
// calculate the elements of the convolution
real new_inf = dot_product(infections[inf_start:inf_end], tail(gt_pmf, pmf_accessed));
return(new_inf);
}
// generate infections by using Rt = Rt-1 * sum(reversed generation time pmf * infections)
vector generate_infections(vector oR, int uot,
real[] gt_mean, real[] gt_sd, int max_gt,
real gt_mean, real gt_sd, int max_gt,
real[] initial_infections, real[] initial_growth,
int pop, int ht) {
// time indices and storage
Expand All @@ -24,11 +30,18 @@ vector generate_infections(vector oR, int uot,
vector[ot] infectiousness = rep_vector(1e-5, ot);
// generation time pmf
vector[max_gt] gt_pmf = rep_vector(1e-5, max_gt);
// revert indices (this is for later doing the convolution with recent cases)
int gt_indexes[max_gt];
for (i in 1:(max_gt)) {
gt_indexes[i] = max_gt - i + 1;
}
gt_pmf = gt_pmf + discretised_gamma_pmf(gt_indexes, gt_mean[1], gt_sd[1], max_gt);
if (gt_sd > 0) {
// SD > 0: use discretised gamma
gt_pmf = gt_pmf + discretised_gamma_pmf(gt_indexes, gt_mean, gt_sd, max_gt);
} else {
// SD == 0: use discretised delta
gt_pmf = gt_pmf + discretised_delta_pmf(gt_indexes);
}
// Initialise infections using daily growth
infections[1] = exp(initial_infections[1]);
if (uot > 1) {
Expand Down Expand Up @@ -81,7 +94,11 @@ vector deconvolve_infections(vector shifted_cases, vector noise, int fixed,
// Update the log density for the generation time distribution mean and sd
void generation_time_lp(real[] gt_mean, real gt_mean_mean, real gt_mean_sd,
real[] gt_sd, real gt_sd_mean, real gt_sd_sd, int weight) {
if (gt_mean_sd > 0) {
target += normal_lpdf(gt_mean[1] | gt_mean_mean, gt_mean_sd) * weight;
}
if (gt_sd_sd > 0) {
target += normal_lpdf(gt_sd[1] | gt_sd_mean, gt_sd_sd) * weight;
}
}

13 changes: 13 additions & 0 deletions inst/stan/functions/pmfs.stan
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,16 @@ vector reverse_mf(vector pmf, int max_pmf) {
}
return rev_pmf;
}

// discretised truncated gamma pmf
vector discretised_delta_pmf(int[] y) {
int n = num_elements(y);
vector[n] pmf;
pmf[y[1]] = 1;
if (n > 1) {
for (i in 2:n) {
pmf[y[i]] = 0;
}
}
return(pmf);
}
2 changes: 1 addition & 1 deletion inst/stan/simulate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ generated quantities {
for (i in 1:n) {
// generate infections from Rt trace
infections[i] = to_row_vector(generate_infections(to_vector(R[i]), seeding_time,
gt_mean[i], gt_sd[i], max_gt,
gt_mean[i, 1], gt_sd[i, 1], max_gt,
initial_infections[i], initial_growth[i],
pop, future_time));
// convolve from latent infections to mean of observations
Expand Down
5 changes: 5 additions & 0 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@

using namespace Rcpp;

#ifdef RCPP_USE_GLOBAL_ROSTREAM
Rcpp::Rostream<true>& Rcpp::Rcout = Rcpp::Rcpp_cout_get();
Rcpp::Rostream<false>& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get();
#endif


RcppExport SEXP _rcpp_module_boot_stan_fit4estimate_infections_mod();
RcppExport SEXP _rcpp_module_boot_stan_fit4estimate_secondary_mod();
Expand Down