Merge pull request #109 from mayer79/faster-perm-imp
Slight speed-up of perm_importance() for data.frames
mayer79 authored Nov 26, 2023
2 parents bfca690 + 09af03c commit fab07aa
Showing 3 changed files with 21 additions and 8 deletions.
5 changes: 3 additions & 2 deletions NEWS.md
@@ -2,8 +2,9 @@

 ## Other changes

-- In multivariate cases, it was possible that normalized H-statistics could equal `0/0 (= NaN)`. Such values are now replaced by 0 ([Issue #107](https://github.com/mayer79/hstats/issues/107)).
-- Removed an unnecessary special case when calculating column means ([PR #106](https://github.com/mayer79/hstats/pull/106)).
+- In multivariate cases, it was possible that normalized H-statistics could equal `0/0 (= NaN)`. Such values are now replaced by 0 ([#107](https://github.com/mayer79/hstats/issues/107)).
+- Removed an unnecessary special case when calculating column means ([#106](https://github.com/mayer79/hstats/pull/106)).
+- Slight speed-up of permutation importance for non-matrix `X` ([#109](https://github.com/mayer79/hstats/pull/109)).

 # hstats 1.1.0

6 changes: 5 additions & 1 deletion R/perm_importance.R
@@ -135,7 +135,11 @@ perm_importance.default <- function(object, X, y, v = NULL,

   shuffle_perf <- function(z, XX) {
     ind <- c(replicate(m_rep, sample(seq_len(n))))  # shuffle within n rows
-    XX[, z] <- XX[ind, z]
+    if (is.matrix(XX) || length(z) > 1L) {
+      XX[, z] <- XX[ind, z]
+    } else {
+      XX[[z]] <- XX[[z]][ind]
+    }
     pred <- prepare_pred(pred_fun(object, XX, ...))
     t(wrowmean(loss(y, pred), ngroups = m_rep, w = w))
   }
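
The new branch relies on the two ways of permuting a single data.frame column giving the same result. A minimal sketch in base R, assuming (this is not shown in the diff) that `XX` holds the `n` rows of `X` stacked `m_rep` times; the toy data and the names `old` and `new` are illustrative only:

```r
# Illustrative sketch; `n`, `m_rep`, `XX`, `z`, and `ind` mirror the names in the diff.
set.seed(1L)
n <- 5L
m_rep <- 2L

# Toy data stacked m_rep times (assumed layout of XX, not taken from the diff)
X <- data.frame(a = 1:5, b = letters[1:5])
XX <- X[rep(seq_len(n), times = m_rep), , drop = FALSE]

# One permutation of 1..n per repetition, concatenated into a single index vector
ind <- c(replicate(m_rep, sample(seq_len(n))))
z <- "a"

# General path: works for matrices and for several columns at once
old <- XX
old[, z] <- old[ind, z]

# Single-column data.frame path added in this commit
new <- XX
new[[z]] <- new[[z]][ind]

# Both paths permute column `a` identically and leave column `b` untouched
stopifnot(identical(old[[z]], new[[z]]))
stopifnot(identical(old[["b"]], new[["b"]]))
```

The second form indexes the column vector directly rather than row-indexing the whole data.frame with `XX[ind, z]`, which is presumably where the slight speed-up comes from. The original form is kept for matrices and for groups of several columns.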
18 changes: 13 additions & 5 deletions tests/testthat/test_perm_importance.R
@@ -44,7 +44,7 @@ test_that("results are positive for modeled features and zero otherwise (univari

 test_that("perm_importance() raises some errors (univariate)", {
   expect_error(perm_importance(fit, X = iris[-1L], y = 1:10, verbose = FALSE))
-  expect_error(perm_importance(fit, X = iris[-1], y = "Hello", verbose = FALSE))
+  expect_error(perm_importance(fit, X = iris[-1L], y = "Hi", verbose = FALSE))
 })

 test_that("constant weights is same as unweighted (univariate)", {
@@ -139,12 +139,20 @@ test_that("groups of variables work as well", {
 })

 test_that("matrix case works as well", {
-  X <- cbind(i = 1, data.matrix(iris[2:4]))
+  v <- c("Petal.Length", "Petal.Width", "Sepal.Width")
+  X <- cbind(intercept = 1, data.matrix(iris[v]))
   fit <- lm.fit(x = X, y = y)
   pred_fun <- function(m, X) X %*% m$coefficients
-  expect_no_error(
-    perm_importance(fit, X = X, y = y, pred_fun = pred_fun, verbose = FALSE)
-  )
+
+  set.seed(1L)
+  s2 <- perm_importance(fit, X = X, y = y, v = v, pred_fun = pred_fun, verbose = FALSE)
+  expect_equal(dim(s2), c(3L, 1L))
+
+  v2 <- list(Petal.Length = "Petal.Length", Petal = c("Petal.Length", "Petal.Width"))
+  set.seed(1L)
+  s3 <- perm_importance(fit, X = X, y = y, v = v2, pred_fun = pred_fun, verbose = FALSE)
+  expect_equal(dim(s3), c(2L, 1L))
+  expect_equal(s2$M["Petal.Length", ], s3$M["Petal.Length", ])
 })

 test_that("non-numeric predictions can work as well (classification error)", {
