Skip to content

Commit

Permalink
Merge pull request #12 from edgararuiz/updates
Browse files Browse the repository at this point in the history
Updates
  • Loading branch information
edgararuiz authored Sep 19, 2024
2 parents 7d036a9 + 9de14f1 commit d99a4ea
Show file tree
Hide file tree
Showing 12 changed files with 132 additions and 86 deletions.
4 changes: 2 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ S3method(llm_sentiment,data.frame)
S3method(llm_summarize,"tbl_Spark SQL")
S3method(llm_summarize,data.frame)
S3method(llm_translate,data.frame)
S3method(m_backend_prompt,mall_defaults)
S3method(m_backend_prompt,mall_session)
S3method(m_backend_submit,mall_ollama)
S3method(m_backend_submit,mall_simulate_llm)
S3method(print,mall_defaults)
S3method(print,mall_session)
export(llm_classify)
export(llm_custom)
export(llm_extract)
Expand Down
24 changes: 0 additions & 24 deletions R/defaults.R

This file was deleted.

30 changes: 16 additions & 14 deletions R/llm-use.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#' character: `""`. It defaults to '_mall_cache'. If this argument is left
#' `NULL` when calling this function, no changes to the path will be made.
#'
#' @returns A `mall_defaults` object
#' @returns A `mall_session` object
#'
#' @export
llm_use <- function(
Expand All @@ -28,7 +28,7 @@ llm_use <- function(
.force = FALSE) {
models <- list()
supplied <- sum(!is.null(backend), !is.null(model))
not_init <- inherits(defaults_get(), "list")
not_init <- inherits(m_defaults_get(), "list")
if (supplied == 2) {
not_init <- FALSE
}
Expand Down Expand Up @@ -56,21 +56,23 @@ llm_use <- function(
}

if (.force) {
.env_llm$cache <- .cache %||% "_mall_cache"
.env_llm$defaults <- list()
cache <- .cache %||% "_mall_cache"
m_defaults_reset()
} else {
.env_llm$cache <- .cache %||% .env_llm$cache %||% "_mall_cache"
cache <- .cache %||% m_defaults_cache() %||% "_mall_cache"
}

if (!is.null(backend) && !is.null(model)) {
defaults_set(
backend = backend,
model = model,
...
)
}
backend <- backend %||% m_defaults_backend()
model <- model %||% m_defaults_model()

m_defaults_set(
backend = backend,
model = model,
.cache = cache,
...
)
if (!.silent || not_init) {
print(defaults_get())
print(m_defaults_get())
}
invisible(defaults_get())
invisible(m_defaults_get())
}
2 changes: 1 addition & 1 deletion R/m-backend-prompt.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ m_backend_prompt <- function(backend, additional) {
}

#' @export
m_backend_prompt.mall_defaults <- function(backend, additional = "") {
m_backend_prompt.mall_session <- function(backend, additional = "") {
list(
sentiment = function(options) {
options <- process_labels(
Expand Down
9 changes: 3 additions & 6 deletions R/m-backend-submit.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#' Functions to integrate different back-ends
#'
#' @param backend An `mall_defaults` object
#' @param backend An `mall_session` object
#' @param x The body of the text to be submitted to the LLM
#' @param prompt The additional information to add to the submission
#' @param additional Additional text to insert to the `base_prompt`
Expand Down Expand Up @@ -29,19 +29,17 @@ m_backend_submit.mall_ollama <- function(backend, x, prompt, preview = FALSE) {
.args <- c(
messages = list(map(prompt, \(i) map(i, \(j) glue(j, x = x)))),
output = "text",
backend
m_defaults_args(backend)
)
res <- NULL
if (preview) {
.args$backend <- NULL
res <- expr(ollamar::chat(!!!.args))
}
if (m_cache_use() && is.null(res)) {
hash_args <- hash(.args)
res <- m_cache_check(hash_args)
}
if (is.null(res)) {
.args$backend <- NULL
res <- exec("chat", !!!.args)
m_cache_record(.args, res, hash_args)
}
Expand All @@ -56,8 +54,7 @@ m_backend_submit.mall_simulate_llm <- function(backend,
prompt,
preview = FALSE) {
.args <- as.list(environment())
args <- backend
class(args) <- "list"
args <- m_defaults_args(backend)
if (args$model == "pipe") {
out <- map_chr(x, \(x) trimws(strsplit(x, "\\|")[[1]][[2]]))
} else if (args$model == "echo") {
Expand Down
12 changes: 4 additions & 8 deletions R/m-cache.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ m_cache_record <- function(.args, .response, hash_args) {
if (!m_cache_use()) {
return(invisible())
}
folder_root <- m_cache_folder()
folder_root <- m_defaults_cache()
try(dir_create(folder_root))
content <- list(
request = .args,
Expand All @@ -15,7 +15,7 @@ m_cache_record <- function(.args, .response, hash_args) {
}

m_cache_check <- function(hash_args) {
folder_root <- m_cache_folder()
folder_root <- m_defaults_cache()
resp <- suppressWarnings(
try(read_json(m_cache_file(hash_args)), TRUE)
)
Expand All @@ -28,17 +28,13 @@ m_cache_check <- function(hash_args) {
}

m_cache_file <- function(hash_args) {
folder_root <- m_cache_folder()
folder_root <- m_defaults_cache()
folder_sub <- substr(hash_args, 1, 2)
path(folder_root, folder_sub, hash_args, ext = "json")
}

m_cache_folder <- function() {
.env_llm$cache
}

m_cache_use <- function() {
folder <- m_cache_folder() %||% ""
folder <- m_defaults_cache() %||% ""
out <- FALSE
if (folder != "") {
out <- TRUE
Expand Down
76 changes: 76 additions & 0 deletions R/m-defaults.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
m_defaults_set <- function(...) {
  # Merge the supplied named values into the stored defaults; a value
  # supplied here overrides any previously stored value of the same name.
  incoming <- list2(...)
  current <- .env_llm$defaults
  for (idx in seq_along(incoming)) {
    key <- names(incoming[idx])
    current[[key]] <- incoming[[idx]]
  }
  .env_llm$defaults <- current

  # Build the S3 class vector from the model, the backend, and "session"
  # (clean_names() is a package helper that sanitizes the pieces).
  class_suffixes <- clean_names(
    c(current[["model"]], current[["backend"]], "session")
  )

  # Everything except the backend name and the cache path is treated as an
  # argument destined for the LLM call itself.
  field_names <- names(current)
  is_llm_arg <- field_names != "backend" & field_names != ".cache"
  .env_llm$session <- structure(
    list(
      name = current[["backend"]],
      args = current[is_llm_arg],
      session = list(cache_folder = current[[".cache"]])
    ),
    class = paste0("mall_", class_suffixes)
  )
  m_defaults_get()
}

m_defaults_get <- function() {
  # Return the current `mall_session` object (an empty list until
  # defaults have been set — see m_defaults_reset()).
  .env_llm[["session"]]
}

m_defaults_backend <- function() {
  # Name of the backend recorded in the active session.
  m_defaults_get()$name
}

m_defaults_model <- function() {
  # Model identifier stored among the session's LLM arguments.
  m_defaults_get()$args$model
}

m_defaults_cache <- function() {
  # Cache folder path kept in the session's R-session info.
  m_defaults_get()$session$cache_folder
}

m_defaults_reset <- function() {
  # Wipe both the stored defaults and the session object back to
  # empty lists, the same state the package starts in.
  .env_llm[["defaults"]] <- list()
  .env_llm[["session"]] <- list()
}

m_defaults_args <- function(x = m_defaults_get()) {
  # Arguments to forward to the backend call; defaults to the
  # active session when no object is supplied.
  x[["args"]]
}

#' @export
print.mall_session <- function(x, ...) {
  # Pretty-print a `mall_session`: the backend name, then the LLM
  # arguments, then the R-session info (cache folder, etc.).
  #
  # @param x A `mall_session` object.
  # @param ... Ignored; present for S3 compatibility with `print()`.
  # @return `x`, invisibly, per the R convention for print methods.
  cli_h3("{col_cyan('mall')} session object")
  cli_inform(glue("{col_green('Backend:')} {x$name}"))
  # The two sections share identical rendering logic, so delegate to one
  # helper instead of duplicating it (the original repeated this block).
  m_print_field_list("{col_green('LLM session:')}", x$args)
  m_print_field_list("{col_green('R session:')}", x$session)
  invisible(x)
}

# Render a labeled set of name/value pairs: on the label's own line when
# there is exactly one field, otherwise as an indented bullet list.
m_print_field_list <- function(label_argument, fields) {
  items <- imap(fields, \(x, y) glue("{col_yellow({paste0(y, ':')})}{x}"))
  if (length(items) == 1) {
    cli_inform(paste(label_argument, items[[1]]))
  } else {
    cli_inform(label_argument)
    items <- as.character(items)
    items <- set_names(items, " ")
    cli_bullets(items)
  }
}
3 changes: 1 addition & 2 deletions R/mall.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,4 @@
#' @import cli

.env_llm <- new.env()
.env_llm$defaults <- list()
.env_llm$cache <- NULL
m_defaults_reset()
2 changes: 1 addition & 1 deletion man/llm_use.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/m_backend_submit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

50 changes: 25 additions & 25 deletions tests/testthat/_snaps/zzz-cache.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,30 @@
Code
fs::dir_ls("_mall_cache", recurse = TRUE)
Output
_mall_cache/47
_mall_cache/47/4719ddcccab04eee3f2355a53cc27219.json
_mall_cache/55
_mall_cache/55/55616b50a571c2246950e9a2179e3ae0.json
_mall_cache/6f
_mall_cache/6f/6f5cffa79bde9ec4d56472204d4f49c2.json
_mall_cache/71
_mall_cache/71/71690a76540d567abe0bf76b8be7b3f0.json
_mall_cache/8c
_mall_cache/8c/8cc0fb726002ccec0345938877129357.json
_mall_cache/8d
_mall_cache/8d/8db41f08bc77b48e4387babc74ad3fd3.json
_mall_cache/00
_mall_cache/00/004088f786ed0f6a3abc08f2aa55ae2b.json
_mall_cache/14
_mall_cache/14/14afc26cb4f76497b80b5552b2b1e217.json
_mall_cache/18
_mall_cache/18/18560280fe5b5a85f2d66fa2dc89aa00.json
_mall_cache/29
_mall_cache/29/296f3116c07dab7f3ecb4a71776e3b64.json
_mall_cache/2c
_mall_cache/2c/2cbb57fd4a7e7178c489d068db063433.json
_mall_cache/44
_mall_cache/44/44fd00c39a9697e24e93943ef5f2ad1b.json
_mall_cache/57
_mall_cache/57/5702ff773afb880c746037a5d8254019.json
_mall_cache/65
_mall_cache/65/65c76a53ebea14a6695adf433fb2faa6.json
_mall_cache/98
_mall_cache/98/98a43dc690b06455d6b0a5046db31d84.json
_mall_cache/9c
_mall_cache/9c/9c2d6c1858e9e0e03cfee5411de6e1ca.json
_mall_cache/9d
_mall_cache/9d/9d859cbfb2855ce7b13c63037ff27351.json
_mall_cache/a0
_mall_cache/a0/a00cf53bede87a270fb59b2199745037.json
_mall_cache/a3
_mall_cache/a3/a3bd49c839ea6ca5d4013c1dbe77e65c.json
_mall_cache/a5
_mall_cache/a5/a597588456064a215c11791442be6587.json
_mall_cache/d9
_mall_cache/d9/d90e4685e9fd358bccf99751b3eee3e3.json
_mall_cache/df
_mall_cache/df/dfe963014849a2dbd09171246c45c511.json
_mall_cache/9c/9c4ed89921994aa00c712bada91ef941.json
_mall_cache/b0
_mall_cache/b0/b02d0fab954e183a98787fa897b47d59.json
_mall_cache/b7
_mall_cache/b7/b7c613386c94b2500b2b733632fedd1a.json
_mall_cache/c2
_mall_cache/c2/c2e2ca95eaaa64b8926b185d6eeec18f.json

4 changes: 2 additions & 2 deletions tests/testthat/test-llm-use.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ test_that("Ollama not found error", {
x
}
)
.env_llm$defaults <- list()
m_defaults_reset()
expect_error(llm_use())
})

Expand All @@ -20,6 +20,6 @@ test_that("Init code is covered", {
list_models = function() data.frame(name = c("model1", "model2")),
menu = function(...) 1
)
.env_llm$defaults <- list()
m_defaults_reset()
expect_message(llm_use())
})

0 comments on commit d99a4ea

Please sign in to comment.