From 3170b1910ca42dfe6b7514586b8183df3b0e5277 Mon Sep 17 00:00:00 2001 From: "Hamada S. Badr" Date: Wed, 29 Sep 2021 04:30:04 -0400 Subject: [PATCH] predict: Fix confidence intervals and standard errors --- NEWS.md | 1 + R/additive_make.R | 28 ++++++++++++---------------- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/NEWS.md b/NEWS.md index cdf43f3..9d4c84d 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,6 +1,7 @@ # additive 0.0.3 - Fixed class predictions for binary classification +- Fixed confidence intervals and standard errors - Added threshold probability option for class predictions - Replaced deprecated `pull_workflow_fit()` diff --git a/R/additive_make.R b/R/additive_make.R index 932fb0a..444a2f6 100644 --- a/R/additive_make.R +++ b/R/additive_make.R @@ -458,11 +458,10 @@ additive_make <- function(modes = c("classification", "regression")) { hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level) / 2 const <- stats::qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE) - trans <- object$fit$family$linkinv res_2 <- tibble::tibble( - lo = trans(results$fit - const * results$se.fit), - hi = trans(results$fit + const * results$se.fit) + lo = results$fit - const * results$se.fit, + hi = results$fit + const * results$se.fit ) res_1 <- res_2 res_1$lo <- 1 - res_2$hi @@ -482,7 +481,7 @@ additive_make <- function(modes = c("classification", "regression")) { args = list( object = rlang::expr(object$fit), newdata = rlang::expr(new_data), - type = "link", + type = "response", se.fit = TRUE ) ) @@ -499,11 +498,10 @@ additive_make <- function(modes = c("classification", "regression")) { hf_lvl <- (1 - object$spec$method$pred$pred_int$extras$level) / 2 const <- stats::qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE) - trans <- object$fit$family$linkinv res_2 <- tibble::tibble( - lo = trans(results$fit - const * results$se.fit), - hi = trans(results$fit + const * results$se.fit) + lo = results$fit - const * results$se.fit, + hi = results$fit + const * results$se.fit ) res_1 <- res_2 res_1$lo <- 1 - res_2$hi @@ -523,7 +521,7 @@ additive_make <- function(modes = c("classification", "regression")) { args = list( object = rlang::expr(object$fit), newdata = rlang::expr(new_data), - type = "link", + type = "response", se.fit = TRUE ) ) @@ -559,11 +557,10 @@ additive_make <- function(modes = c("classification", "regression")) { hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level) / 2 const <- stats::qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE) - trans <- object$fit$family$linkinv res <- tibble::tibble( - .pred_lower = trans(results$fit - const * results$se.fit), - .pred_upper = trans(results$fit + const * results$se.fit) + .pred_lower = results$fit - const * results$se.fit, + .pred_upper = results$fit + const * results$se.fit ) # In case of inverse or other links if (any(res$.pred_upper < res$.pred_lower)) { @@ -581,7 +578,7 @@ additive_make <- function(modes = c("classification", "regression")) { args = list( object = rlang::expr(object$fit), newdata = rlang::expr(new_data), - type = "link", + type = "response", se.fit = TRUE ) ) @@ -598,11 +595,10 @@ additive_make <- function(modes = c("classification", "regression")) { hf_lvl <- (1 - object$spec$method$pred$pred_int$extras$level) / 2 const <- stats::qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE) - trans <- object$fit$family$linkinv res <- tibble::tibble( - .pred_lower = trans(results$fit - const * results$se.fit), - .pred_upper = trans(results$fit + const * results$se.fit) + .pred_lower = results$fit - const * results$se.fit, + .pred_upper = results$fit + const * results$se.fit ) # In case of inverse or other links if (any(res$.pred_upper < res$.pred_lower)) { @@ -620,7 +616,7 @@ additive_make <- function(modes = c("classification", "regression")) { args = list( object = rlang::expr(object$fit), newdata = rlang::expr(new_data), - type = "link", + type = "response", se.fit = TRUE ) )