diff --git a/NAMESPACE b/NAMESPACE index f7d3347..200d85b 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -8,10 +8,10 @@ S3method(llm_sentiment,data.frame) S3method(llm_summarize,"tbl_Spark SQL") S3method(llm_summarize,data.frame) S3method(llm_translate,data.frame) -S3method(m_backend_prompt,mall_defaults) +S3method(m_backend_prompt,mall_session) S3method(m_backend_submit,mall_ollama) S3method(m_backend_submit,mall_simulate_llm) -S3method(print,mall_defaults) +S3method(print,mall_session) export(llm_classify) export(llm_custom) export(llm_extract) diff --git a/R/defaults.R b/R/defaults.R deleted file mode 100644 index b687c5d..0000000 --- a/R/defaults.R +++ /dev/null @@ -1,24 +0,0 @@ -defaults_get <- function() { - .env_llm$defaults -} - -defaults_set <- function(...) { - new_args <- list(...) - for (i in seq_along(new_args)) { - nm <- names(new_args[i]) - .env_llm$defaults[[nm]] <- new_args[[i]] - } - obj_class <- clean_names(c( - .env_llm$defaults[["model"]], - .env_llm$defaults[["backend"]], - "defaults" - )) - class(.env_llm$defaults) <- paste0("mall_", obj_class) - defaults_get() -} - -#' @export -print.mall_defaults <- function(x, ...) { - cli_inform(glue("{col_green('Provider:')} {x$backend}")) - cli_inform(glue("{col_green('Model:')} {x$model}")) -} diff --git a/R/llm-use.R b/R/llm-use.R index ad01c70..025879a 100644 --- a/R/llm-use.R +++ b/R/llm-use.R @@ -16,7 +16,7 @@ #' character: `""`. 'It defaults to '_mall_cache'. If this argument is left #' `NULL` when calling this function, no changes to the path will be made. #' -#' @returns A `mall_defaults` object +#' @returns A `mall_session` object #' #' @export llm_use <- function( @@ -28,7 +28,7 @@ llm_use <- function( .force = FALSE) { models <- list() supplied <- sum(!is.null(backend), !is.null(model)) - not_init <- inherits(defaults_get(), "list") + not_init <- inherits(m_defaults_get(), "list") if (supplied == 2) { not_init <- FALSE } @@ -56,21 +56,23 @@ llm_use <- function( } if (.force) { - .env_llm$cache <- .cache %||% "_mall_cache" - .env_llm$defaults <- list() + cache <- .cache %||% "_mall_cache" + m_defaults_reset() } else { - .env_llm$cache <- .cache %||% .env_llm$cache %||% "_mall_cache" + cache <- .cache %||% m_defaults_cache() %||% "_mall_cache" } - if (!is.null(backend) && !is.null(model)) { - defaults_set( - backend = backend, - model = model, - ... - ) - } + backend <- backend %||% m_defaults_backend() + model <- model %||% m_defaults_model() + + m_defaults_set( + backend = backend, + model = model, + .cache = cache, + ... + ) if (!.silent || not_init) { - print(defaults_get()) + print(m_defaults_get()) } - invisible(defaults_get()) + invisible(m_defaults_get()) } diff --git a/R/m-backend-prompt.R b/R/m-backend-prompt.R index 1a9599c..8d87fd6 100644 --- a/R/m-backend-prompt.R +++ b/R/m-backend-prompt.R @@ -5,7 +5,7 @@ m_backend_prompt <- function(backend, additional) { } #' @export -m_backend_prompt.mall_defaults <- function(backend, additional = "") { +m_backend_prompt.mall_session <- function(backend, additional = "") { list( sentiment = function(options) { options <- process_labels( diff --git a/R/m-backend-submit.R b/R/m-backend-submit.R index 7ce10cc..7d23565 100644 --- a/R/m-backend-submit.R +++ b/R/m-backend-submit.R @@ -1,6 +1,6 @@ #' Functions to integrate different back-ends #' -#' @param backend An `mall_defaults` object +#' @param backend An `mall_session` object #' @param x The body of the text to be submitted to the LLM #' @param prompt The additional information to add to the submission #' @param additional Additional text to insert to the `base_prompt` @@ -29,11 +29,10 @@ m_backend_submit.mall_ollama <- function(backend, x, prompt, preview = FALSE) { .args <- c( messages = list(map(prompt, \(i) map(i, \(j) glue(j, x = x)))), output = "text", - backend + m_defaults_args(backend) ) res <- NULL if (preview) { - .args$backend <- NULL res <- expr(ollamar::chat(!!!.args)) } if (m_cache_use() && is.null(res)) { @@ -41,7 +40,6 @@ m_backend_submit.mall_ollama <- function(backend, x, prompt, preview = FALSE) { res <- m_cache_check(hash_args) } if (is.null(res)) { - .args$backend <- NULL res <- exec("chat", !!!.args) m_cache_record(.args, res, hash_args) } @@ -56,8 +54,7 @@ m_backend_submit.mall_simulate_llm <- function(backend, prompt, preview = FALSE) { .args <- as.list(environment()) - args <- backend - class(args) <- "list" + args <- m_defaults_args(backend) if (args$model == "pipe") { out <- map_chr(x, \(x) trimws(strsplit(x, "\\|")[[1]][[2]])) } else if (args$model == "echo") { diff --git a/R/m-cache.R b/R/m-cache.R index 5564208..7a2c64f 100644 --- a/R/m-cache.R +++ b/R/m-cache.R @@ -2,7 +2,7 @@ m_cache_record <- function(.args, .response, hash_args) { if (!m_cache_use()) { return(invisible()) } - folder_root <- m_cache_folder() + folder_root <- m_defaults_cache() try(dir_create(folder_root)) content <- list( request = .args, @@ -15,7 +15,7 @@ m_cache_record <- function(.args, .response, hash_args) { } m_cache_check <- function(hash_args) { - folder_root <- m_cache_folder() + folder_root <- m_defaults_cache() resp <- suppressWarnings( try(read_json(m_cache_file(hash_args)), TRUE) ) @@ -28,17 +28,13 @@ m_cache_check <- function(hash_args) { } m_cache_file <- function(hash_args) { - folder_root <- m_cache_folder() + folder_root <- m_defaults_cache() folder_sub <- substr(hash_args, 1, 2) path(folder_root, folder_sub, hash_args, ext = "json") } -m_cache_folder <- function() { - .env_llm$cache -} - m_cache_use <- function() { - folder <- m_cache_folder() %||% "" + folder <- m_defaults_cache() %||% "" out <- FALSE if (folder != "") { out <- TRUE diff --git a/R/m-defaults.R b/R/m-defaults.R new file mode 100644 index 0000000..b0535de --- /dev/null +++ b/R/m-defaults.R @@ -0,0 +1,76 @@ +m_defaults_set <- function(...) { + new_args <- list2(...) + defaults <- .env_llm$defaults + for (i in seq_along(new_args)) { + nm <- names(new_args[i]) + defaults[[nm]] <- new_args[[i]] + } + obj_class <- clean_names(c( + defaults[["model"]], + defaults[["backend"]], + "session" + )) + .env_llm$defaults <- defaults + .env_llm$session <- structure( + list( + name = defaults[["backend"]], + args = defaults[names(defaults) != "backend" & names(defaults) != ".cache"], + session = list( + cache_folder = defaults[[".cache"]] + ) + ), + class = paste0("mall_", obj_class) + ) + m_defaults_get() +} + +m_defaults_get <- function() { + .env_llm$session +} + +m_defaults_backend <- function() { + .env_llm$session$name +} + +m_defaults_model <- function() { + .env_llm$session$args$model +} + +m_defaults_cache <- function() { + .env_llm$session$session$cache_folder +} + +m_defaults_reset <- function() { + .env_llm$defaults <- list() + .env_llm$session <- list() +} + +m_defaults_args <- function(x = m_defaults_get()) { + x$args +} + +#' @export +print.mall_session <- function(x, ...) { + cli_h3("{col_cyan('mall')} session object") + cli_inform(glue("{col_green('Backend:')} {x$name}")) + args <- imap(x$args, \(x, y) glue("{col_yellow({paste0(y, ':')})}{x}")) + label_argument <- "{col_green('LLM session:')}" + if (length(args) == 1) { + cli_inform(paste(label_argument, args[[1]])) + } else { + cli_inform(label_argument) + args <- as.character(args) + args <- set_names(args, " ") + cli_bullets(args) + } + session <- imap(x$session, \(x, y) glue("{col_yellow({paste0(y, ':')})}{x}")) + label_argument <- "{col_green('R session:')}" + if (length(session) == 1) { + cli_inform(paste(label_argument, session[[1]])) + } else { + cli_inform(label_argument) + session <- as.character(session) + session <- set_names(session, " ") + cli_bullets(session) + } +} diff --git a/R/mall.R b/R/mall.R index eb8b4f8..36e63e8 100644 --- a/R/mall.R +++ b/R/mall.R @@ -8,5 +8,4 @@ #' @import cli .env_llm <- new.env() -.env_llm$defaults <- list() -.env_llm$cache <- NULL +m_defaults_reset() diff --git a/man/llm_use.Rd b/man/llm_use.Rd index 0b067e5..e237508 100644 --- a/man/llm_use.Rd +++ b/man/llm_use.Rd @@ -34,7 +34,7 @@ character: \code{""}. 'It defaults to '_mall_cache'. If this argument is left R session} } \value{ -A \code{mall_defaults} object +A \code{mall_session} object } \description{ Allows us to specify the back-end provider, model to use during the current diff --git a/man/m_backend_submit.Rd b/man/m_backend_submit.Rd index d91fa93..f64bb69 100644 --- a/man/m_backend_submit.Rd +++ b/man/m_backend_submit.Rd @@ -10,7 +10,7 @@ m_backend_prompt(backend, additional) m_backend_submit(backend, x, prompt, preview = FALSE) } \arguments{ -\item{backend}{An \code{mall_defaults} object} +\item{backend}{An \code{mall_session} object} \item{additional}{Additional text to insert to the \code{base_prompt}} diff --git a/tests/testthat/_snaps/zzz-cache.md b/tests/testthat/_snaps/zzz-cache.md index 5cdc9a9..b6d85ca 100644 --- a/tests/testthat/_snaps/zzz-cache.md +++ b/tests/testthat/_snaps/zzz-cache.md @@ -3,30 +3,30 @@ Code fs::dir_ls("_mall_cache", recurse = TRUE) Output - _mall_cache/47 - _mall_cache/47/4719ddcccab04eee3f2355a53cc27219.json - _mall_cache/55 - _mall_cache/55/55616b50a571c2246950e9a2179e3ae0.json - _mall_cache/6f - _mall_cache/6f/6f5cffa79bde9ec4d56472204d4f49c2.json - _mall_cache/71 - _mall_cache/71/71690a76540d567abe0bf76b8be7b3f0.json - _mall_cache/8c - _mall_cache/8c/8cc0fb726002ccec0345938877129357.json - _mall_cache/8d - _mall_cache/8d/8db41f08bc77b48e4387babc74ad3fd3.json + _mall_cache/00 + _mall_cache/00/004088f786ed0f6a3abc08f2aa55ae2b.json + _mall_cache/14 + _mall_cache/14/14afc26cb4f76497b80b5552b2b1e217.json + _mall_cache/18 + _mall_cache/18/18560280fe5b5a85f2d66fa2dc89aa00.json + _mall_cache/29 + _mall_cache/29/296f3116c07dab7f3ecb4a71776e3b64.json + _mall_cache/2c + _mall_cache/2c/2cbb57fd4a7e7178c489d068db063433.json + _mall_cache/44 + _mall_cache/44/44fd00c39a9697e24e93943ef5f2ad1b.json + _mall_cache/57 + _mall_cache/57/5702ff773afb880c746037a5d8254019.json + _mall_cache/65 + _mall_cache/65/65c76a53ebea14a6695adf433fb2faa6.json + _mall_cache/98 + _mall_cache/98/98a43dc690b06455d6b0a5046db31d84.json _mall_cache/9c - _mall_cache/9c/9c2d6c1858e9e0e03cfee5411de6e1ca.json - _mall_cache/9d - _mall_cache/9d/9d859cbfb2855ce7b13c63037ff27351.json - _mall_cache/a0 - _mall_cache/a0/a00cf53bede87a270fb59b2199745037.json - _mall_cache/a3 - _mall_cache/a3/a3bd49c839ea6ca5d4013c1dbe77e65c.json - _mall_cache/a5 - _mall_cache/a5/a597588456064a215c11791442be6587.json - _mall_cache/d9 - _mall_cache/d9/d90e4685e9fd358bccf99751b3eee3e3.json - _mall_cache/df - _mall_cache/df/dfe963014849a2dbd09171246c45c511.json + _mall_cache/9c/9c4ed89921994aa00c712bada91ef941.json + _mall_cache/b0 + _mall_cache/b0/b02d0fab954e183a98787fa897b47d59.json + _mall_cache/b7 + _mall_cache/b7/b7c613386c94b2500b2b733632fedd1a.json + _mall_cache/c2 + _mall_cache/c2/c2e2ca95eaaa64b8926b185d6eeec18f.json diff --git a/tests/testthat/test-llm-use.R b/tests/testthat/test-llm-use.R index f630cac..445768c 100644 --- a/tests/testthat/test-llm-use.R +++ b/tests/testthat/test-llm-use.R @@ -6,7 +6,7 @@ test_that("Ollama not found error", { x } ) - .env_llm$defaults <- list() + m_defaults_reset() expect_error(llm_use()) }) @@ -20,6 +20,6 @@ test_that("Init code is covered", { list_models = function() data.frame(name = c("model1", "model2")), menu = function(...) 1 ) - .env_llm$defaults <- list() + m_defaults_reset() expect_message(llm_use()) })