Skip to content

Commit

Permalink
Merge pull request #472 from tidymodels/print-461
Browse files Browse the repository at this point in the history
  • Loading branch information
EmilHvitfeldt authored Dec 19, 2023
2 parents 2f75b6c + 5de4385 commit 53e1f1c
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 0 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ S3method(finalize_estimator_internal,pr_curve)
S3method(finalize_estimator_internal,roc_auc)
S3method(finalize_estimator_internal,roc_curve)
S3method(format,metric)
S3method(format,metric_factory)
S3method(format,metric_set)
S3method(gain_capture,data.frame)
S3method(gain_curve,data.frame)
Expand Down Expand Up @@ -76,6 +77,7 @@ S3method(precision,matrix)
S3method(precision,table)
S3method(print,conf_mat)
S3method(print,metric)
S3method(print,metric_factory)
S3method(print,metric_set)
S3method(recall,data.frame)
S3method(recall,matrix)
Expand Down
14 changes: 14 additions & 0 deletions R/fair-aaa.R
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,20 @@ groupwise_metric_class <- function(fn) {
class(attr(fn, "metrics")[[1]])
}

#' @noRd
#' @export
print.metric_factory <- function(x, ...) {
cat(format(x), sep = "\n")
invisible(x)
}

#' @export
format.metric_factory <- function(x, ...) {
cli::cli_format_method(
cli::cli_text("A {.help [metric factory](yardstick::new_groupwise_metric)}")
)
}

diff_range <- function(x) {
estimates <- x$.estimate

Expand Down
7 changes: 7 additions & 0 deletions tests/testthat/_snaps/fair-aaa.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
# metric factory print method works

Code
equal_opportunity
Output
A metric factory (`?yardstick::new_groupwise_metric()`)

# handles `direction` input

Code
Expand Down
4 changes: 4 additions & 0 deletions tests/testthat/test-fair-aaa.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ test_that("new_groupwise_metric() works with grouped input", {
expect_identical(grouped_res, split_res)
})

test_that("metric factory print method works", {
expect_snapshot(equal_opportunity)
})

test_that("can accommodate redundant sensitive features", {
data("hpc_cv")

Expand Down

0 comments on commit 53e1f1c

Please sign in to comment.