Skip to content

Commit

Permalink
predict: Fix confidence intervals and standard errors
Browse files Browse the repository at this point in the history
  • Loading branch information
hsbadr committed Sep 29, 2021
1 parent 8187e41 commit 3170b19
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 16 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# additive 0.0.3

- Fixed class predictions for binary classification
- Fixed confidence intervals and standard errors
- Added threshold probability option for class predictions
- Replaced deprecated `pull_workflow_fit()`

Expand Down
28 changes: 12 additions & 16 deletions R/additive_make.R
Original file line number Diff line number Diff line change
Expand Up @@ -458,11 +458,10 @@ additive_make <- function(modes = c("classification", "regression")) {
hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level) / 2
const <-
stats::qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE)
trans <- object$fit$family$linkinv
res_2 <-
tibble::tibble(
lo = trans(results$fit - const * results$se.fit),
hi = trans(results$fit + const * results$se.fit)
lo = results$fit - const * results$se.fit,
hi = results$fit + const * results$se.fit
)
res_1 <- res_2
res_1$lo <- 1 - res_2$hi
Expand All @@ -482,7 +481,7 @@ additive_make <- function(modes = c("classification", "regression")) {
args = list(
object = rlang::expr(object$fit),
newdata = rlang::expr(new_data),
type = "link",
type = "response",
se.fit = TRUE
)
)
Expand All @@ -499,11 +498,10 @@ additive_make <- function(modes = c("classification", "regression")) {
hf_lvl <- (1 - object$spec$method$pred$pred_int$extras$level) / 2
const <-
stats::qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE)
trans <- object$fit$family$linkinv
res_2 <-
tibble::tibble(
lo = trans(results$fit - const * results$se.fit),
hi = trans(results$fit + const * results$se.fit)
lo = results$fit - const * results$se.fit,
hi = results$fit + const * results$se.fit
)
res_1 <- res_2
res_1$lo <- 1 - res_2$hi
Expand All @@ -523,7 +521,7 @@ additive_make <- function(modes = c("classification", "regression")) {
args = list(
object = rlang::expr(object$fit),
newdata = rlang::expr(new_data),
type = "link",
type = "response",
se.fit = TRUE
)
)
Expand Down Expand Up @@ -559,11 +557,10 @@ additive_make <- function(modes = c("classification", "regression")) {
hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level) / 2
const <-
stats::qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE)
trans <- object$fit$family$linkinv
res <-
tibble::tibble(
.pred_lower = trans(results$fit - const * results$se.fit),
.pred_upper = trans(results$fit + const * results$se.fit)
.pred_lower = results$fit - const * results$se.fit,
.pred_upper = results$fit + const * results$se.fit
)
# In case of inverse or other links
if (any(res$.pred_upper < res$.pred_lower)) {
Expand All @@ -581,7 +578,7 @@ additive_make <- function(modes = c("classification", "regression")) {
args = list(
object = rlang::expr(object$fit),
newdata = rlang::expr(new_data),
type = "link",
type = "response",
se.fit = TRUE
)
)
Expand All @@ -598,11 +595,10 @@ additive_make <- function(modes = c("classification", "regression")) {
hf_lvl <- (1 - object$spec$method$pred$pred_int$extras$level) / 2
const <-
stats::qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE)
trans <- object$fit$family$linkinv
res <-
tibble::tibble(
.pred_lower = trans(results$fit - const * results$se.fit),
.pred_upper = trans(results$fit + const * results$se.fit)
.pred_lower = results$fit - const * results$se.fit,
.pred_upper = results$fit + const * results$se.fit
)
# In case of inverse or other links
if (any(res$.pred_upper < res$.pred_lower)) {
Expand All @@ -620,7 +616,7 @@ additive_make <- function(modes = c("classification", "regression")) {
args = list(
object = rlang::expr(object$fit),
newdata = rlang::expr(new_data),
type = "link",
type = "response",
se.fit = TRUE
)
)
Expand Down

0 comments on commit 3170b19

Please sign in to comment.