Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sketch Data Improvement #8913

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 63 additions & 50 deletions R/sketching.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@ NULL
#'
#' @param object A Seurat object.
#' @param assay Assay name. Default is NULL, in which case the default assay of the object is used.
#' @param ncells A positive integer indicating the number of cells to sample for the sketching. Default is 5000.
#' @param cell.ratio Proportion of cells to sample from each layer.
#' @param min.cells Minimum number of cells a layer must have in order to be subsampled; layers with fewer cells keep all of their cells. Default is 2500.
#' @param sketched.assay Sketched assay name. A sketch assay is created or overwritten with the sketch data. Default is 'sketch'.
#' @param method Sketching method to use. Can be 'LeverageScore' or 'Uniform'.
#' Default is 'LeverageScore'.
#' @param var.name A metadata column name to store the leverage scores. Default is 'leverage.score'.
#' @param over.write whether to overwrite existing column in the metadata. Default is FALSE.
#' @param leverage.already.calculated Whether leverage scores have already been calculated. Default is FALSE.
#' @param seed A positive integer for the seed of the random number generator. Default is 123.
#' @param cast The type to cast the resulting assay to. Default is 'dgCMatrix'.
#' @param verbose Print progress and diagnostic messages
Expand All @@ -41,50 +43,51 @@ NULL
SketchData <- function(
object,
assay = NULL,
ncells = 5000L,
cell.ratio = 0.25,
min.cells = 2500,
sketched.assay = 'sketch',
method = c('LeverageScore', 'Uniform'),
var.name = "leverage.score",
over.write = FALSE,
over.write = F,
leverage.already.calculated = F,
seed = 123L,
cast = 'dgCMatrix',
verbose = TRUE,
verbose = T,
...
) {
assay <- assay[1L] %||% DefaultAssay(object = object)
assay <- match.arg(arg = assay, choices = Assays(object = object))
assay <- match.arg(arg = assay, choices = SeuratObject::Assays(object = object))
method <- match.arg(arg = method)
if (sketched.assay == assay) {
abort(message = "Cannot overwrite existing assays")
rlang::abort(message = "Cannot overwrite existing assays")
}
if (sketched.assay %in% Assays(object = object)) {
if (sketched.assay %in% SeuratObject::Assays(object = object)) {
if (sketched.assay == DefaultAssay(object = object)) {
DefaultAssay(object = object) <- assay
}
object[[sketched.assay]] <- NULL
}
if (!over.write) {
var.name <- CheckMetaVarName(object = object, var.name = var.name)
}

if (method == 'LeverageScore') {
if (verbose) {
message("Calcuating Leverage Score")
}
object <- LeverageScore(
object = object,
assay = assay,
var.name = var.name,
over.write = over.write,
seed = seed,
verbose = FALSE,
...
)
} else if (method == 'Uniform') {
if (verbose) {
message("Uniformly sampling")

if (over.write == T | leverage.already.calculated == F) {
if (method == 'LeverageScore') {
if (verbose) {
message("Calcuating Leverage Score")
}
object <- LeverageScore(
object = object,
assay = assay,
var.name = var.name,
over.write = over.write,
seed = seed,
verbose = verbose,
...
)
} else if (method == 'Uniform') {
if (verbose) {
message("Uniformly sampling")
}
object[[var.name]] <- 1
}
object[[var.name]] <- 1
}
leverage.score <- object[[var.name]]
layers.data <- Layers(object = object[[assay]], search = 'data')
Expand All @@ -93,12 +96,14 @@ SketchData <- function(
FUN = function(i, seed) {
set.seed(seed = seed)
lcells <- Cells(x = object[[assay]], layer = layers.data[i])
if (length(x = lcells) < ncells) {
return(lcells)
if (length(lcells) < min.cells) {
ncells_per_sample = length(lcells)
} else {
ncells_per_sample = max(round(length(lcells)*cell.ratio), min.cells)
}
return(sample(
x = lcells,
size = ncells,
size = ncells_per_sample,
prob = leverage.score[lcells,]
))
},
Expand All @@ -113,13 +118,13 @@ SketchData <- function(
try(
expr = VariableFeatures(object = sketched, method = "sketch", layer = lyr) <-
VariableFeatures(object = object[[assay]], layer = lyr),
silent = TRUE
silent = F
)
}
if (!is.null(x = cast) && inherits(x = sketched, what = 'Assay5')) {
sketched <- CastAssay(object = sketched, to = cast, ...)
}
Key(object = sketched) <- Key(object = sketched.assay, quiet = TRUE)
Key(object = sketched) <- Key(object = sketched.assay, quiet = F)
object[[sketched.assay]] <- sketched
DefaultAssay(object = object) <- sketched.assay
return(object)
Expand Down Expand Up @@ -369,6 +374,7 @@ TransferSketchLabels <- function(
#' @param seed A positive integer. The seed for the random number generator, defaults to 123.
#' @param verbose Print progress and diagnostic messages
#' @importFrom Matrix qrR t
#' @importFrom matrixcalc is.singular.matrix
#' @importFrom irlba irlba
#'
#' @rdname LeverageScore
Expand Down Expand Up @@ -448,23 +454,30 @@ LeverageScore.default <- function(
} else {
base::qr.R(qr = qr.sa)
}
R.inv <- as.sparse(x = backsolve(r = R, x = diag(x = ncol(x = R))))
if (isTRUE(x = verbose)) {
message("Performing random projection")
}
JL <- as.sparse(x = JLEmbed(
nrow = ncol(x = R.inv),
ncol = ndims,
eps = eps,
seed = seed
))
Z <- object %*% (R.inv %*% JL)
if (inherits(x = Z, what = 'IterableMatrix')) {
Z.score <- BPCells::matrix_stats(matrix = Z ^ 2, row_stats = 'mean'
)$row_stats['mean',]*ncol(x = Z)
} else {
Z.score <- rowSums(x = Z ^ 2)
}
A <- diag(x = R)
if (any(A == 0)) {
bad_elem <- which(A == 0)
message(paste0("Found 0 in diagonal of input matrix at ", bad_elem, ". Assigning all cells leverage score of 1"))
Z.score <- rep(1, nrow(x = object))
} else {
R.inv <- as.sparse(x = backsolve(r = R, x = diag(x = ncol(x = R))))
if (isTRUE(x = verbose)) {
message("Performing random projection")
}
JL <- as.sparse(x = JLEmbed(
nrow = ncol(x = R.inv),
ncol = ndims,
eps = eps,
seed = seed
))
Z <- object %*% (R.inv %*% JL)
if (inherits(x = Z, what = 'IterableMatrix')) {
Z.score <- BPCells::matrix_stats(matrix = Z ^ 2, row_stats = 'mean'
)$row_stats['mean',]*ncol(x = Z)
} else {
Z.score <- rowSums(x = Z ^ 2)
}
}
return(Z.score)
}

Expand Down