Commit c5c1338

sbfnk and seabbs authored
distribution interface to dist_spec (#504)
* add distribution functions
* deprecate "empty" distribution
* make sd S3
* only generate samples if any params aren't natural
* update stan model with new dist interface
* update lognormal parameters
* return mean function to previous functionality
* update data
* deprecate dist_def functions
* use natural parametrisations in dist_def functions
* deprecate dist_spec
* extract_single_dist function
* update fix_dist to work with composite dists
* extract squash
* update parameters to extract
* specify lower bounds in function
* pass lower bounds to stan model
* update sample/report functions
* max squash adjust report
* update dist functions to new syntax
* re-create data
* update get_dist to new syntax
* fully deprecate get functions
* create delay inits separately
* max squash again
* return correct dist in estimate_truncation
* few more examples/docs
* fix tests
* add documentation to dist interface
* add input checks
* sd function to work with composite dists
* warn when not using natural parameters
* ensure bounds are respected in stan
* add empty distribution for legacy reasons
* add checks to dist_skel
* use lapply for parameters
* don't calculate sd if length 1
* use uncertain reporting in example
* don't add one to sd
* return correct parameters
* dist_skel: calculate rate everywhere
* update dist_skel examples
* add missing man file
* don't run internal examples
* demote warning to message
* update syntax everywhere
* add news item
* turn sd into an internal function
* fix distribution documentation
* remove obsolete default
* spell checking
* use correct sd function
* linting
* remove obsolete tests
* loop over all parameters
* update touchstone arguments
* linting
* fix regex search/replace gone wrong
* remove obsolete space
* update strategy for estimating uncertainty
* update uncertain parameter transformations
* add missing sd to parameter sampling
* update / recompile vignettes
* update var names
* rename argument in docs
* update man pages
* update test result
* add reviewer
  Co-authored-by: Sam Abbott <s.e.abbott12@gmail.com>
* base scaling on variance, not sd
* re-render vignettes
* full text capitalisation of distributions
* separate dist_spec from stan model
* adjust tests/code for new dist_spec set up
* re-create examples
* re-doc
* update tests
* new dist_spec in estimate_truncation example
* update get_seeding_time with updated dist_spec
* estimate_truncation and seeding time tests
* update truncation dist in estimate_truncation
* remove more uses of old dist_spec
* SD explicitly to zero for fixed
* give names
* fix typo
* fix indent
* fix another typo
* squash bugs highlighted by tests
* remove missing variable
* linting
* add missing docs
* import transpose
* ensure sd is positive
* fix estimate_truncation example
* make tolerance user-settable
* use purrr::map instead of lapply
* fix stan dist test
* fix plotting
* Apply suggestions from code review
  Co-authored-by: Sam Abbott <s.e.abbott12@gmail.com>
* rate and scale examples for Gamma
  Co-authored-by: Sam Abbott <s.e.abbott12@gmail.com>
* capitalise gamma and lognormal
  Co-authored-by: Sam Abbott <s.e.abbott12@gmail.com>
* change to single hash
* use bar in normal_cdf
* remove extraneous backticks
* remove space before left parenthesis
* split up dist.R
* move deprecated `dist_spec` function
* add examples
* initial design sketch
* make parameter conversion more flexible
* add test for alternative gamma params
* update syntax in simulate_infections
* add missing tag
* update man pages
* update estimate_secondary tests
* update simulate_infections for new interface
* update snapshots
* get_dist deprecation test with natural params
* update phi syntax
* hide internal example
* update deprecations
* use toString
* pmf -> NonParametric
* add american spelling
* fix gamma deprecation
* add new functions to pkgdown
* update vignette
* recompile vignettes

---------

Co-authored-by: Sam Abbott <s.e.abbott12@gmail.com>
1 parent 50e0b60 commit c5c1338

122 files changed

+3774
-2590
lines changed

NAMESPACE

+14
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,8 @@
11
# Generated by roxygen2: do not edit by hand
22

33
S3method("+",dist_spec)
4+
S3method(c,dist_spec)
5+
S3method(max,dist_spec)
46
S3method(mean,dist_spec)
57
S3method(plot,dist_spec)
68
S3method(plot,epinow)
@@ -10,16 +12,23 @@ S3method(plot,estimate_truncation)
1012
S3method(print,dist_spec)
1113
S3method(summary,epinow)
1214
S3method(summary,estimate_infections)
15+
export(Fixed)
16+
export(Gamma)
17+
export(LogNormal)
18+
export(NonParametric)
19+
export(Normal)
1320
export(R_to_growth)
1421
export(add_day_of_week)
1522
export(adjust_infection_to_report)
23+
export(apply_tolerance)
1624
export(backcalc_opts)
1725
export(bootstrapped_dist_fit)
1826
export(calc_CrI)
1927
export(calc_CrIs)
2028
export(calc_summary_measures)
2129
export(calc_summary_stats)
2230
export(clean_nowcasts)
31+
export(collapse)
2332
export(construct_output)
2433
export(convert_to_logmean)
2534
export(convert_to_logsd)
@@ -35,6 +44,8 @@ export(create_shifted_cases)
3544
export(create_stan_args)
3645
export(create_stan_data)
3746
export(delay_opts)
47+
export(discretise)
48+
export(discretize)
3849
export(dist_fit)
3950
export(dist_skel)
4051
export(dist_spec)
@@ -195,16 +206,19 @@ importFrom(posterior,mcse_mean)
195206
importFrom(progressr,progressor)
196207
importFrom(progressr,with_progress)
197208
importFrom(purrr,compact)
209+
importFrom(purrr,flatten)
198210
importFrom(purrr,keep)
199211
importFrom(purrr,list_transpose)
200212
importFrom(purrr,map)
201213
importFrom(purrr,map2_dbl)
202214
importFrom(purrr,map_chr)
203215
importFrom(purrr,map_dbl)
216+
importFrom(purrr,map_dfc)
204217
importFrom(purrr,pmap_dbl)
205218
importFrom(purrr,quietly)
206219
importFrom(purrr,reduce)
207220
importFrom(purrr,safely)
221+
importFrom(purrr,transpose)
208222
importFrom(purrr,walk)
209223
importFrom(rlang,abort)
210224
importFrom(rlang,arg_match)

NEWS.md

+1
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
* `simulate_infections` has been renamed to `forecast_infections` in line with `simulate_secondary` and `forecast_secondary`. The terminology is: a forecast is done from a fit to existing data, a simulation from first principles. By @sbfnk in #544 and reviewed by @seabbs.
1010
* A new `simulate_infections` function has been added that can be used to simulate from the model from given initial conditions and parameters. By @sbfnk in #557 and reviewed by @jamesmbaazam.
1111
* The function `init_cumulative_fit()` has been deprecated. By @jamesmbaazam in #541 and reviewed by @sbfnk.
12+
* The interface to generating delay distributions has been completely overhauled. Instead of calling `dist_spec()`, users now specify distributions using functions that represent the available distributions, i.e. `LogNormal()`, `Gamma()` and `Fixed()`. Uncertainty is specified using calls of the same nature, to `Normal()`. More information on the underlying design can be found in `inst/dev/design_dist.md`. By @sbfnk in #504 and reviewed by @seabbs.
1213

1314
## Documentation
1415
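Note on "natural" parameters: the commit message refers to natural parametrisations and to updated lognormal parameters. As a minimal sketch of what such a conversion involves (not the package's implementation), the base R code below converts the mean and standard deviation of a lognormal distribution to `meanlog` and `sdlog`, then converts back. The helper name `lognormal_natural()` is hypothetical and does not appear in the diff.

```r
# Hypothetical helper (not part of the package): convert the mean and sd of a
# lognormal distribution to its natural parameters meanlog and sdlog.
lognormal_natural <- function(mean, sd) {
  sdlog <- sqrt(log(1 + sd^2 / mean^2))
  meanlog <- log(mean) - sdlog^2 / 2
  list(meanlog = meanlog, sdlog = sdlog)
}

params <- lognormal_natural(mean = 5, sd = 2)

# Round trip: recover the mean and sd from the natural parameters
back_mean <- exp(params$meanlog + params$sdlog^2 / 2)
back_sd <- sqrt(
  (exp(params$sdlog^2) - 1) * exp(2 * params$meanlog + params$sdlog^2)
)
```

The round trip should return the original mean and standard deviation up to floating-point error, which is a quick way to check the conversion.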

R/adjust.R

+227 -5
Original file line number | Diff line number | Diff line change
@@ -85,6 +85,7 @@ adjust_infection_to_report <- function(infections, delay_defs,
8585
# Reset DT Defaults on Exit
8686
set_dt_single_thread()
8787

88+
## deprecated
8889
sample_single_dist <- function(input, delay_def) {
8990
## Define sample delay fn
9091
sample_delay_fn <- function(n, ...) {
@@ -111,14 +112,50 @@ adjust_infection_to_report <- function(infections, delay_defs,
111112
return(out)
112113
}
113114

114-
report <- sample_single_dist(infections, delay_defs[[1]])
115-
116-
if (length(delay_defs) > 1) {
117-
for (def in 2:length(delay_defs)) {
118-
report <- sample_single_dist(report, delay_defs[[def]])
115+
sample_dist_spec <- function(input, delay_def) {
116+
## Define sample delay fn
117+
sample_delay_fn <- function(n, dist, cum, ...) {
118+
fixed_dist <- discretise(fix_dist(delay_def, strategy = "sample"))
119+
if (dist) {
120+
fixed_dist[[1]]$pmf[n + 1]
121+
} else {
122+
sample(seq_along(fixed_dist[[1]]$pmf) - 1, size = n, replace = TRUE)
123+
}
119124
}
125+
126+
## Infection to onset
127+
out <- EpiNow2::sample_approx_dist(
128+
cases = input,
129+
dist_fn = sample_delay_fn,
130+
max_value = max(delay_def),
131+
direction = "forwards",
132+
type = type,
133+
truncate_future = FALSE
134+
)
135+
136+
return(out)
120137
}
121138

139+
if (is(delay_defs, "dist_spec")) {
140+
report <- sample_dist_spec(infections, extract_single_dist(delay_defs, 1))
141+
if (length(delay_defs) > 1) {
142+
for (def in seq(2, length(delay_defs))) {
143+
report <- sample_dist_spec(report, extract_single_dist(delay_defs, def))
144+
}
145+
}
146+
} else {
147+
deprecate_warn(
148+
"1.5.0",
149+
"adjust_infection_to_report(delay_defs = 'should be a dist_spec')",
150+
details = "Specifying this as a list of data tables is deprecated."
151+
)
152+
report <- sample_single_dist(infections, delay_defs[[1]])
153+
if (length(delay_defs) > 1) {
154+
for (def in 2:length(delay_defs)) {
155+
report <- sample_single_dist(report, delay_defs[[def]])
156+
}
157+
}
158+
}
122159
## Add a weekly reporting effect if present
123160
if (!missing(reporting_effect)) {
124161
reporting_effect <- data.table::data.table(
@@ -146,3 +183,188 @@ adjust_infection_to_report <- function(infections, delay_defs,
146183
}
147184
return(report)
148185
}
186+
187+
#' Approximate Sampling a Distribution using Counts
188+
#'
189+
#' @description `r lifecycle::badge("soft-deprecated")`
190+
#' Convolves cases by a PMF function. This function will soon be removed or
191+
#' replaced with a more robust stan implementation.
192+
#'
193+
#' @param cases A `<data.frame>` of cases (in date order) with the following
194+
#' variables: `date` and `cases`.
195+
#'
196+
#' @param max_value Numeric, maximum value to allow. Defaults to 120 days.
197+
#'
198+
#' @param direction Character string, defaults to "backwards". Direction in which to
199+
#' map cases. Supports either "backwards" or "forwards".
200+
#'
201+
#' @param dist_fn Function that takes two arguments with the first being
202+
#' numeric and the second being logical (and defined as `dist`). Should return
203+
#' the probability density or a sample from the defined distribution. See
204+
#' the examples for more.
205+
#'
206+
#' @param earliest_allowed_mapped A character string representing a date
207+
#' ("2020-01-01"). Indicates the earliest allowed mapped value.
208+
#'
209+
#' @param type Character string indicating the method to use to transform
210+
#' counts. Supports either "sample" which approximates sampling or "median"
211+
#' which shifts by the median of the distribution.
212+
#'
213+
#' @param truncate_future Logical, should cases be truncated if they occur
214+
#' after the last date reported in the data. Defaults to `TRUE`.
215+
#'
216+
#' @return A `<data.table>` of cases by date of onset
217+
#' @export
218+
#' @importFrom purrr map_dfc
219+
#' @importFrom data.table data.table setorder
220+
#' @importFrom lubridate days
221+
#' @examples
222+
#' \donttest{
223+
#' cases <- example_confirmed
224+
#' cases <- cases[, cases := as.integer(confirm)]
225+
#' print(cases)
226+
#'
227+
#' # total cases
228+
#' sum(cases$cases)
229+
#'
230+
#' delay_fn <- function(n, dist, cum) {
231+
#' if (dist) {
232+
#' pgamma(n + 0.9999, 2, 1) - pgamma(n - 1e-5, 2, 1)
233+
#' } else {
234+
#' as.integer(rgamma(n, 2, 1))
235+
#' }
236+
#' }
237+
#'
238+
#' onsets <- sample_approx_dist(
239+
#' cases = cases,
240+
#' dist_fn = delay_fn
241+
#' )
242+
#'
243+
#' # estimated onset distribution
244+
#' print(onsets)
245+
#'
246+
#' # check that sum is equal to reported cases
247+
#' total_onsets <- median(
248+
#' purrr::map_dbl(
249+
#' 1:10,
250+
#' ~ sum(sample_approx_dist(
251+
#' cases = cases,
252+
#' dist_fn = delay_fn
253+
#' )$cases)
254+
#' )
255+
#' )
256+
#' total_onsets
257+
#'
258+
#'
259+
#' # map from onset cases to reported
260+
#' reports <- sample_approx_dist(
261+
#' cases = cases,
262+
#' dist_fn = delay_fn,
263+
#' direction = "forwards"
264+
#' )
265+
#'
266+
#'
267+
#' # map from onset cases to reported using a median shift
268+
#' reports <- sample_approx_dist(
269+
#' cases = cases,
270+
#' dist_fn = delay_fn,
271+
#' direction = "forwards",
272+
#' type = "median"
273+
#' )
274+
#' }
275+
sample_approx_dist <- function(cases = NULL,
276+
dist_fn = NULL,
277+
max_value = 120,
278+
earliest_allowed_mapped = NULL,
279+
direction = "backwards",
280+
type = "sample",
281+
truncate_future = TRUE) {
282+
if (type == "sample") {
283+
if (direction == "backwards") {
284+
direction_fn <- rev
285+
} else if (direction == "forwards") {
286+
direction_fn <- function(x) {
287+
x
288+
}
289+
}
290+
# reverse cases so starts with current first
291+
reversed_cases <- direction_fn(cases$cases)
292+
reversed_cases[is.na(reversed_cases)] <- 0
293+
# draw from the density fn of the dist
294+
draw <- dist_fn(0:max_value, dist = TRUE, cum = FALSE)
295+
296+
# approximate cases
297+
mapped_cases <- do.call(cbind, purrr::map(
298+
seq_along(reversed_cases),
299+
~ c(
300+
rep(0, . - 1),
301+
stats::rbinom(
302+
length(draw),
303+
rep(reversed_cases[.], length(draw)),
304+
draw
305+
),
306+
rep(0, length(reversed_cases) - .)
307+
)
308+
))
309+
310+
311+
# set dates order based on direction mapping
312+
if (direction == "backwards") {
313+
dates <- seq(min(cases$date) - lubridate::days(length(draw) - 1),
314+
max(cases$date),
315+
by = "days"
316+
)
317+
} else if (direction == "forwards") {
318+
dates <- seq(min(cases$date),
319+
max(cases$date) + lubridate::days(length(draw) - 1),
320+
by = "days"
321+
)
322+
}
323+
324+
# summarises movements and sample for placement of non-integer cases
325+
case_sum <- direction_fn(rowSums(mapped_cases))
326+
floor_case_sum <- floor(case_sum)
327+
sample_cases <- floor_case_sum +
328+
as.numeric((runif(seq_along(case_sum)) < (case_sum - floor_case_sum)))
329+
330+
# summarise imputed onsets and build output data.table
331+
mapped_cases <- data.table::data.table(
332+
date = dates,
333+
cases = sample_cases
334+
)
335+
336+
# filter out all zero cases until first recorded case
337+
mapped_cases <- data.table::setorder(mapped_cases, date)
338+
mapped_cases <- mapped_cases[
339+
,
340+
cum_cases := cumsum(cases)
341+
][cum_cases != 0][, cum_cases := NULL]
342+
} else if (type == "median") {
343+
shift <- as.integer(
344+
median(as.integer(dist_fn(1000, dist = FALSE)), na.rm = TRUE)
345+
)
346+
347+
if (direction == "backwards") {
348+
mapped_cases <- data.table::copy(cases)[
349+
,
350+
date := date - lubridate::days(shift)
351+
]
352+
} else if (direction == "forwards") {
353+
mapped_cases <- data.table::copy(cases)[
354+
,
355+
date := date + lubridate::days(shift)
356+
]
357+
}
358+
}
359+
360+
if (!is.null(earliest_allowed_mapped)) {
361+
mapped_cases <- mapped_cases[date >= as.Date(earliest_allowed_mapped)]
362+
}
363+
364+
# filter out future cases
365+
if (direction == "forwards" && truncate_future) {
366+
max_date <- max(cases$date)
367+
mapped_cases <- mapped_cases[date <= max_date]
368+
}
369+
return(mapped_cases)
370+
}
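Note on the `type = "sample"` branch of `sample_approx_dist()`: it can be read as binomial thinning followed by a shifted sum. The base R sketch below reproduces that idea on a toy series under stated assumptions: the Gamma(2, 1) delay, the case counts and all variable names are illustrative and not taken from the package.

```r
set.seed(1)

# Discretised delay PMF on days 0..max_delay (illustrative Gamma(2, 1) delay)
max_delay <- 10
pmf <- pgamma(0:max_delay + 1, 2, 1) - pgamma(0:max_delay, 2, 1)

# Toy daily case counts
cases <- c(10, 20, 15, 5)

# For each day, draw how many of that day's cases land at each delay, then
# place the draws in a column shifted down by the day index
mapped <- vapply(
  seq_along(cases),
  function(i) {
    c(
      rep(0, i - 1),
      rbinom(length(pmf), cases[i], pmf),
      rep(0, length(cases) - i)
    )
  },
  numeric(length(cases) + length(pmf) - 1)
)

# Row sums give the counts by day of the delayed event
reports <- rowSums(mapped)
```

Because each delay is drawn independently rather than jointly, the total of `reports` matches the total of `cases` only approximately. This is presumably why the roxygen example compares a median over repeated runs.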

R/checks.R

+50
Original file line number | Diff line number | Diff line change
@@ -53,3 +53,53 @@ check_reports_valid <- function(reports, model) {
5353
assert_numeric(reports$confirm, lower = 0)
5454
}
5555
}
56+
57+
#' Validate probability distribution for passing to stan
58+
#'
59+
#' @description
60+
#' `check_stan_delay()` checks that the supplied data is a `<dist_spec>`,
61+
#' that it is a supported distribution, and that it has a finite maximum.
62+
#'
63+
#' @param dist A `dist_spec` object.
64+
#' @importFrom checkmate assert_class
65+
#' @importFrom rlang arg_match
66+
#' @return Called for its side effects.
67+
#' @keywords internal
68+
check_stan_delay <- function(dist) {
69+
# Check that `dist` is a `dist_spec`
70+
assert_class(dist, "dist_spec")
71+
# Check that `dist` is lognormal, gamma, fixed or nonparametric
72+
distributions <- vapply(dist, function(x) x$distribution, character(1))
73+
if (
74+
!all(distributions %in% c("lognormal", "gamma", "fixed", "nonparametric"))
75+
) {
76+
stop(
77+
"Distributions passed to the model need to be lognormal, gamma, fixed ",
78+
"or nonparametric."
79+
)
80+
}
81+
# Check that `dist` has parameters that are either numeric or normal
82+
# distributions with numeric parameters and infinite maximum
83+
numeric_parameters <- vapply(dist$parameters, is.numeric, logical(1))
84+
normal_parameters <- vapply(
85+
dist$parameters,
86+
function(x) {
87+
is(x, "dist_spec") &&
88+
x$distribution == "normal" &&
89+
all(vapply(x$parameters, is.numeric, logical(1))) &&
90+
is.infinite(x$max)
91+
},
92+
logical(1)
93+
)
94+
if (!all(numeric_parameters | normal_parameters)) {
95+
stop(
96+
"Delay distributions passed to the model need to have parameters that ",
97+
"are either numeric or normally distributed with numeric parameters ",
98+
"and infinite maximum."
99+
)
100+
}
101+
# Check that `dist` has a finite maximum
102+
if (any(is.infinite(max(dist)))) {
103+
stop("All distributions passed to the model need to have a finite maximum")
104+
}
105+
}
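Note on the validation pattern in `check_stan_delay()`: the base R sketch below shows two of its checks (supported distribution names, finite maximum) on plain lists. It is a simplified illustration only. The list layout and the name `check_delays()` are hypothetical stand-ins, and the real `dist_spec` class has more structure than this.

```r
# Hypothetical stand-in for a set of delay distributions (not a real dist_spec)
delays <- list(
  list(distribution = "lognormal", max = 14),
  list(distribution = "gamma", max = Inf)
)

check_delays <- function(dists) {
  supported <- c("lognormal", "gamma", "fixed", "nonparametric")
  # Collect the distribution name of every element
  names_used <- vapply(dists, function(x) x$distribution, character(1))
  if (!all(names_used %in% supported)) {
    stop("Unsupported distribution: ", toString(setdiff(names_used, supported)))
  }
  # Every distribution must be truncated at a finite maximum
  maxima <- vapply(dists, function(x) x$max, numeric(1))
  if (any(is.infinite(maxima))) {
    stop("All distributions need to have a finite maximum")
  }
  invisible(TRUE)
}

# The second delay has an infinite maximum, so the check fails
result <- tryCatch(
  check_delays(delays),
  error = function(e) conditionMessage(e)
)
```

Checking names before maxima means an unsupported distribution is reported first, in the same order as the checks in the diff.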
