Skip to content

Commit 0a8569e

Browse files
seabbssbfnkactions-user
authored
Add support for missing data (NAs) in the estimate_infections() model (#528)
* add a lookup to estimate_infections * add R side support * don't internally impute missing as zero * update news * fix data preprocessing order * correction data ingestion * clean up filtering of leading zeros * error check create_clean_reported_cases and add unit tests to cover function * correct handling of missing data in data preprocessing: * refine data preprocessing * update news and tests * update global variables * Update NEWS.md Co-authored-by: Sebastian Funk <sebastian.funk@lshtm.ac.uk> * Update R/create.R Co-authored-by: Sebastian Funk <sebastian.funk@lshtm.ac.uk> * Update R/create.R Co-authored-by: Sebastian Funk <sebastian.funk@lshtm.ac.uk> * Update R/create.R Co-authored-by: Sebastian Funk <sebastian.funk@lshtm.ac.uk> * fix line length linting * Document --------- Co-authored-by: Sebastian Funk <sebastian.funk@lshtm.ac.uk> Co-authored-by: GitHub Actions <actions@github.com>
1 parent 0cbbefb commit 0a8569e

13 files changed

+126
-47
lines changed

NEWS.md

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
* The functions `get_dist`, `get_generation_time`, `get_incubation_period` have been deprecated and replaced with examples. By @sbfnk in #481 and reviewed by @seabbs.
66
* The utility function `update_list()` has been deprecated in favour of `utils::modifyList()` because it comes with an installation of R. By @jamesmbaazam in #491 and reviewed by @seabbs.
77
* The `fixed` argument to `dist_spec` has been deprecated and replaced by a `fix_dist()` function. By @sbfnk in #503 and reviewed by @seabbs.
8+
* Updated `estimate_infections()` so that, rather than imputing missing data, it now skips these data points in the likelihood. This is a breaking change, as it alters the behaviour of the model when dates are missing from a time series but are known to be zero. We recommend that users check their results when updating to this version, but we expect this change to improve performance in most cases. By @seabbs in #528 and reviewed by @sbfnk.
89

910
## Documentation
1011

R/create.R

+48-28
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,35 @@
11
#' Create Clean Reported Cases
22
#' @description `r lifecycle::badge("stable")`
3-
#' Cleans a data frame of reported cases by replacing missing dates with 0
4-
#' cases and applies an optional threshold at which point 0 cases are replaced
5-
#' with a moving average of observed cases. See `zero_threshold` for details.
3+
#' Filters leading zeros, completes dates, and applies an optional threshold at
4+
#' which point 0 cases are replaced with a user supplied value (defaults to
5+
#' `NA`).
66
#'
77
#' @param filter_leading_zeros Logical, defaults to TRUE. Should zeros at the
88
#' start of the time series be filtered out.
99
#'
1010
#' @param zero_threshold `r lifecycle::badge("experimental")` Numeric defaults
1111
#' to Inf. Indicates if detected zero cases are meaningful by using a threshold
12-
#' number of cases based on the 7 day average. If the average is above this
13-
#' threshold then the zero is replaced with the backwards looking rolling
14-
#' average. If set to infinity then no changes are made.
12+
#' number of cases based on the 7-day average. If the average is above this
13+
#' threshold then the zero is replaced using `fill`.
14+
#'
15+
#' @param fill Numeric, defaults to NA. Value to use to replace NA values or
16+
#' zeroes that are flagged because the 7-day average is above the
17+
#' `zero_threshold`. If the default NA is used then dates with NA values or with
18+
#' 7-day averages above the `zero_threshold` will be skipped in model fitting.
19+
#' If this is set to 0 then the only effect is to replace NA values with 0.
1520
#'
1621
#' @inheritParams estimate_infections
1722
#' @importFrom data.table copy merge.data.table setorder setDT frollsum
1823
#' @return A cleaned data frame of reported cases
1924
#' @author Sam Abbott
2025
#' @author Lloyd Chapman
2126
#' @export
27+
#' @examples
28+
#' create_clean_reported_cases(example_confirmed, 7)
2229
create_clean_reported_cases <- function(reported_cases, horizon,
2330
filter_leading_zeros = TRUE,
24-
zero_threshold = Inf) {
31+
zero_threshold = Inf,
32+
fill = NA_integer_) {
2533
reported_cases <- data.table::setDT(reported_cases)
2634
reported_cases_grid <- data.table::copy(reported_cases)[,
2735
.(date = seq(min(date), max(date) + horizon, by = "days"))
@@ -35,35 +43,35 @@ create_clean_reported_cases <- function(reported_cases, horizon,
3543
if (is.null(reported_cases$breakpoint)) {
3644
reported_cases$breakpoint <- 0
3745
}
38-
reported_cases <- reported_cases[
39-
is.na(confirm), confirm := 0][, .(date = date, confirm, breakpoint)
40-
]
41-
reported_cases <- reported_cases[is.na(breakpoint), breakpoint := 0]
46+
reported_cases[is.na(breakpoint), breakpoint := 0]
4247
reported_cases <- data.table::setorder(reported_cases, date)
4348
## Filter out 0 reported cases from the beginning of the data
4449
if (filter_leading_zeros) {
4550
reported_cases <- reported_cases[order(date)][
46-
,
47-
cum_cases := cumsum(confirm)
48-
][cum_cases > 0][, cum_cases := NULL]
51+
date >= min(date[confirm[!is.na(confirm)] > 0])
52+
]
4953
}
50-
54+
# Calculate `average_7_day` which for rows with `confirm == 0`
55+
# (the only instance where this is being used) equates to the 7-day
56+
# right-aligned moving average at the previous data point.
57+
reported_cases <-
58+
reported_cases[
59+
,
60+
`:=`(average_7_day = (
61+
data.table::frollsum(confirm, n = 8, na.rm = TRUE)
62+
) / 7
63+
)
64+
]
5165
# Flag zero case counts as missing (`NA`, later replaced by `fill`) if the
5266
# preceding 7-day average of cases is greater than `zero_threshold`
5367
if (!is.infinite(zero_threshold)) {
54-
reported_cases <-
55-
reported_cases[
56-
,
57-
`:=`(average_7 = (data.table::frollsum(confirm, n = 8)) / 7)
58-
]
5968
reported_cases <- reported_cases[
60-
confirm == 0 & average_7 > zero_threshold,
61-
confirm := as.integer(average_7)
62-
][
63-
,
64-
"average_7" := NULL
69+
confirm == 0 & average_7_day > zero_threshold,
70+
confirm := NA_integer_
6571
]
6672
}
73+
reported_cases[is.na(confirm), confirm := fill]
74+
reported_cases[, "average_7_day" := NULL]
6775
return(reported_cases)
6876
}
6977

@@ -429,14 +437,26 @@ create_obs_model <- function(obs = obs_opts(), dates) {
429437
#' @author Sam Abbott
430438
#' @author Sebastian Funk
431439
#' @export
440+
#' @examples
441+
#' create_stan_data(
442+
#' example_confirmed, 7, rt_opts(), gp_opts(), obs_opts(), 7,
443+
#' backcalc_opts(), create_shifted_cases(example_confirmed, 7, 14, 7)
444+
#' )
432445
create_stan_data <- function(reported_cases, seeding_time,
433446
rt, gp, obs, horizon,
434447
backcalc, shifted_cases) {
435448

436-
cases <- reported_cases[(seeding_time + 1):(.N - horizon)]$confirm
449+
cases <- reported_cases[(seeding_time + 1):(.N - horizon)]
450+
cases[, lookup := seq_len(.N)]
451+
complete_cases <- cases[!is.na(cases$confirm)]
452+
cases_time <- complete_cases$lookup
453+
complete_cases <- complete_cases$confirm
454+
cases <- cases$confirm
437455

438456
data <- list(
439-
cases = cases,
457+
cases = complete_cases,
458+
cases_time = cases_time,
459+
lt = length(cases_time),
440460
shifted_cases = shifted_cases,
441461
t = length(reported_cases$date),
442462
horizon = horizon,
@@ -455,7 +475,7 @@ create_stan_data <- function(reported_cases, seeding_time,
455475
first_week <- data.table::data.table(
456476
confirm = cases[seq_len(min(7, length(cases)))],
457477
t = seq_len(min(7, length(cases)))
458-
)
478+
)[!is.na(confirm)]
459479
data$prior_infections <- log(mean(first_week$confirm, na.rm = TRUE))
460480
data$prior_infections <- ifelse(
461481
is.na(data$prior_infections) || is.null(data$prior_infections),

R/estimate_infections.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ estimate_infections <- function(reported_cases,
165165
name = "EpiNow2.epinow.estimate_infections"
166166
)
167167
}
168-
# Make sure there are no missing dates and order cases
168+
# Filter leading zeros, complete dates, and order cases
169169
reported_cases <- create_clean_reported_cases(
170170
reported_cases, horizon,
171171
filter_leading_zeros = filter_leading_zeros,

R/utilities.R

+2-2
Original file line numberDiff line numberDiff line change
@@ -444,9 +444,9 @@ globalVariables(
444444
"New confirmed cases by infection date", "Data", "R", "reference",
445445
".SD", "day_of_week", "forecast_type", "measure", "numeric_estimate",
446446
"point", "strat", "estimate", "breakpoint", "variable", "value.V1",
447-
"central_lower", "central_upper", "mean_sd", "sd_sd", "average_7",
447+
"central_lower", "central_upper", "mean_sd", "sd_sd", "average_7_day",
448448
"..lowers", "..upper_CrI", "..uppers", "timing", "dataset", "last_confirm",
449449
"report_date", "secondary", "id", "conv", "meanlog", "primary", "scaled",
450-
"scaling", "sdlog"
450+
"scaling", "sdlog", "lookup"
451451
)
452452
)

inst/stan/data/observations.stan

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
int t; // unobserved time
2+
int lt; // timepoints in the likelihood
23
int seeding_time; // time period used for seeding and not observed
34
int horizon; // forecast horizon
45
int future_time; // time in future for Rt
5-
array[t - horizon - seeding_time] int<lower = 0> cases; // observed cases
6+
array[lt] int<lower = 0> cases; // observed cases
7+
array[lt] int cases_time; // time of observed cases
68
vector<lower = 0>[t] shifted_cases; // prior infections (for backcalculation)

inst/stan/estimate_infections.stan

+3-2
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,8 @@ model {
148148
// observed reports from mean of reports (update likelihood)
149149
if (likelihood) {
150150
report_lp(
151-
cases, obs_reports, rep_phi, phi_mean, phi_sd, model_type, obs_weight
151+
cases, obs_reports[cases_time], rep_phi, phi_mean, phi_sd, model_type,
152+
obs_weight
152153
);
153154
}
154155
}
@@ -191,7 +192,7 @@ generated quantities {
191192
// log likelihood of model
192193
if (return_likelihood) {
193194
log_lik = report_log_lik(
194-
cases, obs_reports, rep_phi, model_type, obs_weight
195+
cases, obs_reports[cases_time], rep_phi, model_type, obs_weight
195196
);
196197
}
197198
}

man/create_clean_reported_cases.Rd

+16-7
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/create_stan_data.Rd

+6
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/epinow.Rd

+2-3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/estimate_infections.Rd

+2-3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
2+
test_that("create_clean_reported_cases runs without errors", {
3+
expect_no_error(create_clean_reported_cases(example_confirmed, 7))
4+
})
5+
6+
test_that("create_clean_reported_cases returns a data table", {
7+
result <- create_clean_reported_cases(example_confirmed, 7)
8+
expect_s3_class(result, "data.table")
9+
})
10+
11+
test_that("create_clean_reported_cases filters leading zeros correctly", {
12+
# Modify example_confirmed to have leading zeros
13+
modified_data <- example_confirmed
14+
modified_data[1:3, "confirm"] <- 0
15+
16+
result <- create_clean_reported_cases(modified_data, 7)
17+
# Check if the first row with non-zero cases is retained
18+
expect_equal(
19+
result$date[1], min(modified_data$date[modified_data$confirm > 0])
20+
)
21+
})
22+
23+
test_that("create_clean_reported_cases replaces zero cases correctly", {
24+
# Modify example_confirmed to have zero cases that should be replaced
25+
modified_data <- example_confirmed
26+
modified_data$confirm[10:16] <- 0
27+
threshold <- 10
28+
29+
result <- create_clean_reported_cases(
30+
modified_data, 0, zero_threshold = threshold
31+
)
32+
# Check that zero cases whose 7-day average exceeds the threshold are replaced
33+
expect_equal(sum(result$confirm == 0, na.rm = TRUE), 0)
34+
})

tests/testthat/test-create_stan_data.R

Whitespace-only changes.

tests/testthat/test-estimate_infections.R

+8
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,14 @@ test_that("estimate_infections successfully returns estimates using default sett
3636
test_estimate_infections(reported_cases)
3737
})
3838

39+
test_that("estimate_infections successfully returns estimates when passed NA values", {
40+
skip_on_cran()
41+
reported_cases_na <- data.table::copy(reported_cases)
42+
reported_cases_na[sample(1:30, 5), confirm := NA]
43+
test_estimate_infections(reported_cases_na)
44+
})
45+
46+
3947
test_that("estimate_infections successfully returns estimates using no delays", {
4048
skip_on_cran()
4149
test_estimate_infections(reported_cases, delay = FALSE)

0 commit comments

Comments
 (0)