diff --git a/.Rbuildignore b/.Rbuildignore index 8f8e47e..53bec18 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -8,4 +8,6 @@ utils/ ^pkgdown$ ^\.github$ ^codecov\.yml$ -^_mall_cache/ \ No newline at end of file +^_mall_cache/ +^_readme_cache/ +^_prompt_cache/ diff --git a/.github/workflows/R-CMD-check.yaml b/.github/workflows/R-CMD-check.yaml index c914ce4..960d062 100644 --- a/.github/workflows/R-CMD-check.yaml +++ b/.github/workflows/R-CMD-check.yaml @@ -43,9 +43,16 @@ jobs: - uses: r-lib/actions/setup-r-dependencies@v2 with: - extra-packages: any::rcmdcheck - needs: check - + extra-packages: | + any::rcmdcheck + needs: | + check + + - name: Installing dbplyr + run: | + pak::pak("dbplyr") + shell: Rscript {0} + - uses: r-lib/actions/check-r-package@v2 with: upload-snapshots: true diff --git a/.github/workflows/test-coverage.yaml b/.github/workflows/test-coverage.yaml index bf1c324..cf09f1b 100644 --- a/.github/workflows/test-coverage.yaml +++ b/.github/workflows/test-coverage.yaml @@ -25,7 +25,10 @@ jobs: - uses: r-lib/actions/setup-r-dependencies@v2 with: - extra-packages: any::covr, any::xml2 + extra-packages: | + any::covr + any::xml2 + any::dbplyr needs: coverage - name: Test coverage diff --git a/DESCRIPTION b/DESCRIPTION index 495e73b..c124321 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -21,6 +21,7 @@ Imports: ollamar, rlang Suggests: + dbplyr, testthat (>= 3.0.0) Config/testthat/edition: 3 URL: https://edgararuiz.github.io/mall/ diff --git a/R/llm-extract.R b/R/llm-extract.R index 2ca49d8..027374f 100644 --- a/R/llm-extract.R +++ b/R/llm-extract.R @@ -77,14 +77,10 @@ llm_extract.data.frame <- function(.data, llm_vec_extract <- function(x, labels = c(), additional_prompt = "") { - resp <- l_vec_prompt( + l_vec_prompt( x = x, prompt_label = "extract", labels = labels, additional_prompt = additional_prompt ) - map_chr( - resp, - \(x) paste0(as.character(fromJSON(x, flatten = TRUE)), collapse = "|") - ) } diff --git a/R/m-backend-prompt.R b/R/m-backend-prompt.R index 1eb642e..16424ec 100644 --- a/R/m-backend-prompt.R +++ b/R/m-backend-prompt.R @@ -57,11 +57,12 @@ m_backend_prompt.mall_defaults <- function(backend, additional = "") { json_labels <- paste0("\"", labels, "\":your answer", collapse = ",") json_labels <- paste0("{{", json_labels, "}}") plural <- ifelse(no_labels > 1, "s", "") + text_multi <- ifelse( + no_labels > 1, + "Return the response in a simple list, pipe separated, and no headers. ", + "" + ) list( - list( - role = "system", - content = "You only speak simple JSON. Do not write normal text." - ), list( role = "user", content = glue(paste( @@ -69,7 +70,7 @@ m_backend_prompt.mall_defaults <- function(backend, additional = "") { "Extract the {col_labels} being referred to on the text.", "I expect {no_labels} item{plural} exactly.", "No capitalization. No explanations.", - "You will use this JSON this format exclusively: {json_labels} .", + "{text_multi}", "{additional}", "The answer is based on the following text:\n{{x}}" )) @@ -100,7 +101,7 @@ l_vec_prompt <- function(x, prompt = NULL, ...) { # Initializes session LLM - backend <- llm_use(.silent = TRUE, force = FALSE) + backend <- llm_use(.silent = TRUE, .force = FALSE) # If there is no 'prompt', then assumes that we're looking for a # prompt label (sentiment, classify, etc) to set 'prompt' if (is.null(prompt)) { diff --git a/R/m-backend-submit.R b/R/m-backend-submit.R index 4f0f7b7..6d66417 100644 --- a/R/m-backend-submit.R +++ b/R/m-backend-submit.R @@ -47,6 +47,8 @@ m_backend_submit.mall_simulate_llm <- function(backend, x, prompt) { out <- map_chr(x, \(x) trimws(strsplit(x, "\\|")[[1]][[2]])) } else if (args$model == "echo") { out <- x + } else if (args$model == "prompt") { + out <- prompt } res <- NULL if (m_cache_use()) { @@ -77,9 +79,6 @@ m_cache_record <- function(.args, .response, hash_args) { } m_cache_check <- function(hash_args) { - if (!m_cache_use()) { - return(invisible()) - } folder_root <- m_cache_folder() resp <- suppressWarnings( try(read_json(m_cache_file(hash_args)), TRUE) diff --git a/R/utils.R b/R/utils.R index 842b724..68ccf15 100644 --- a/R/utils.R +++ b/R/utils.R @@ -1,13 +1,10 @@ -clean_names <- function(x, replace_periods = FALSE) { +clean_names <- function(x) { x <- tolower(x) map_chr( x, \(x) { out <- str_replace_clean(x, " ") out <- str_replace_clean(out, "\\:") - if (replace_periods) { - out <- str_replace_clean(out, "\\.") - } out } ) diff --git a/tests/testthat/_snaps/llm-extract.md b/tests/testthat/_snaps/llm-extract.md index 4da22b7..da82dcf 100644 --- a/tests/testthat/_snaps/llm-extract.md +++ b/tests/testthat/_snaps/llm-extract.md @@ -1,3 +1,18 @@ +# Extract works + + Code + llm_vec_extract("toaster", labels = "product") + Output + [[1]] + [[1]]$role + [1] "user" + + [[1]]$content + You are a helpful text extraction engine. Extract the product being referred to on the text. I expect 1 item exactly. No capitalization. No explanations. The answer is based on the following text: + {x} + + + # Extract on Ollama works Code @@ -12,3 +27,10 @@ 2 laptop 3 washing machine +--- + + Code + llm_vec_extract("bob smith, 105 2nd street", c("name", "address")) + Output + [1] "bob smith | 105 2nd street" + diff --git a/tests/testthat/_snaps/llm-sentiment.md b/tests/testthat/_snaps/llm-sentiment.md index 6076166..e0671b5 100644 --- a/tests/testthat/_snaps/llm-sentiment.md +++ b/tests/testthat/_snaps/llm-sentiment.md @@ -1,3 +1,12 @@ +# Sentiment translates expected Spark SQL + + Code + llm_sentiment(df_spark, x) + Output + + SELECT `df`.*, ai_analyze_sentiment(`x`) AS `.sentiment` + FROM `df` + # Sentiment on Ollama works Code diff --git a/tests/testthat/_snaps/llm-summarize.md b/tests/testthat/_snaps/llm-summarize.md index ad13916..3a7ae57 100644 --- a/tests/testthat/_snaps/llm-summarize.md +++ b/tests/testthat/_snaps/llm-summarize.md @@ -1,3 +1,21 @@ +# Summarize translates expected Spark SQL + + Code + llm_summarize(df_spark, x) + Output + + SELECT `df`.*, ai_summarize(`x`, CAST(10.0 AS INT)) AS `.summary` + FROM `df` + +--- + + Code + llm_summarize(df_spark, x, max_words = 50) + Output + + SELECT `df`.*, ai_summarize(`x`, CAST(50.0 AS INT)) AS `.summary` + FROM `df` + # Summarize on Ollama works Code diff --git a/tests/testthat/_snaps/zzz-cache.md b/tests/testthat/_snaps/zzz-cache.md index 001e5fa..269034e 100644 --- a/tests/testthat/_snaps/zzz-cache.md +++ b/tests/testthat/_snaps/zzz-cache.md @@ -5,23 +5,24 @@ Output _mall_cache/65 _mall_cache/65/654ffcb598cbbdeb0c1c7ebd05239ed4.json - _mall_cache/65/65fd4ea24b687f30e12881a3c5ee4acc.json _mall_cache/72 _mall_cache/72/72433745387c1c32b53ec05a77fa3b97.json - _mall_cache/73 - _mall_cache/73/73962724ef856c31dffa1dfcae15daf1.json _mall_cache/76 _mall_cache/76/768d2519acbc4cce84d8069f665ee0b2.json _mall_cache/7a _mall_cache/7a/7a500c559374fe68cf6cf57605d9de46.json + _mall_cache/7d + _mall_cache/7d/7d2911f2710df85c421e40c2991a6cac.json + _mall_cache/7e + _mall_cache/7e/7e20b6c382a19b95ea3e43d76af92b0f.json _mall_cache/83 _mall_cache/83/83dafa73c6187bf2156f4f07d6fdfc36.json - _mall_cache/87 - _mall_cache/87/8787f5fe54447691660ccb29bedbb420.json _mall_cache/ab _mall_cache/ab/abcd4e05a8a4cf61e0053a7a7bed1360.json _mall_cache/df _mall_cache/df/dfe963014849a2dbd09171246c45c511.json + _mall_cache/e6 + _mall_cache/e6/e624d61d6a9dab8052ff6053b5be2f62.json _mall_cache/f4 _mall_cache/f4/f47435d14c52e6a5c07400fb5f43db58.json _mall_cache/f7 diff --git a/tests/testthat/test-llm-extract.R b/tests/testthat/test-llm-extract.R index bee000f..46fdf35 100644 --- a/tests/testthat/test-llm-extract.R +++ b/tests/testthat/test-llm-extract.R @@ -1,45 +1,44 @@ test_that("Extract works", { - llm_use("simulate_llm", "echo", .silent = TRUE, .force = TRUE) + llm_use("simulate_llm", "prompt", .silent = TRUE, .force = TRUE) - expect_equal( - llm_vec_extract("{\"product\":\"toaster\"}", labels = "product"), - "toaster" + expect_snapshot( + llm_vec_extract("toaster", labels = "product") ) +}) + +test_that("Extract data frame works", { + llm_use("simulate_llm", "echo", .silent = TRUE, .force = TRUE) - entries2 <- "{\"product\":\"toaster\", \"product\":\"TV\"}" - entries2_result <- "toaster|TV" - expect_equal( - llm_vec_extract( - entries2, - labels = "product" - ), - entries2_result - ) expect_equal( - llm_extract(data.frame(x = entries2), x, labels = "product"), - data.frame(x = entries2, .extract = entries2_result) + llm_extract(data.frame(x = "test"), x, labels = "product"), + data.frame(x = "test", .extract = "test") ) + expect_equal( llm_extract( - .data = data.frame(x = entries2), + .data = data.frame(x = "test1|test2"), col = x, labels = c("product1", "product2"), expand_cols = TRUE ), - data.frame(x = entries2, product1 = "toaster", product2 = "TV") + data.frame(x = "test1|test2", product1 = "test1", product2 = "test2") ) + expect_equal( llm_extract( - .data = data.frame(x = entries2), + .data = data.frame(x = "test1|test2"), col = x, labels = c(y = "product1", z = "product2"), expand_cols = TRUE ), - data.frame(x = entries2, y = "toaster", z = "TV") + data.frame(x = "test1|test2", y = "test1", z = "test2") ) }) test_that("Extract on Ollama works", { skip_if_no_ollama() expect_snapshot(llm_extract(reviews_table(), review, "product")) + expect_snapshot( + llm_vec_extract("bob smith, 105 2nd street", c("name", "address")) + ) }) diff --git a/tests/testthat/test-llm-sentiment.R b/tests/testthat/test-llm-sentiment.R index 5318cae..30ce1d1 100644 --- a/tests/testthat/test-llm-sentiment.R +++ b/tests/testthat/test-llm-sentiment.R @@ -20,6 +20,13 @@ test_that("Sentiment works", { ) }) +test_that("Sentiment translates expected Spark SQL", { + suppressPackageStartupMessages(library(dbplyr)) + df <- data.frame(x = 1) + df_spark <- tbl_lazy(df, con = simulate_spark_sql()) + expect_snapshot(llm_sentiment(df_spark, x)) +}) + test_that("Sentiment on Ollama works", { skip_if_no_ollama() vec_reviews <- reviews_vec() diff --git a/tests/testthat/test-llm-summarize.R b/tests/testthat/test-llm-summarize.R index 9d78c6b..4b2d18c 100644 --- a/tests/testthat/test-llm-summarize.R +++ b/tests/testthat/test-llm-summarize.R @@ -21,6 +21,14 @@ test_that("Summarize works", { ) }) +test_that("Summarize translates expected Spark SQL", { + suppressPackageStartupMessages(library(dbplyr)) + df <- data.frame(x = 1) + df_spark <- tbl_lazy(df, con = simulate_spark_sql()) + expect_snapshot(llm_summarize(df_spark, x)) + expect_snapshot(llm_summarize(df_spark, x, max_words = 50)) +}) + test_that("Summarize on Ollama works", { skip_if_no_ollama() expect_snapshot(llm_summarize(reviews_table(), review, max_words = 5)) diff --git a/tests/testthat/test-llm-use.R b/tests/testthat/test-llm-use.R index d44be4a..f630cac 100644 --- a/tests/testthat/test-llm-use.R +++ b/tests/testthat/test-llm-use.R @@ -1,3 +1,15 @@ +test_that("Ollama not found error", { + local_mocked_bindings( + test_connection = function() { + x <- list() + x$status_code <- 400 + x + } + ) + .env_llm$defaults <- list() + expect_error(llm_use()) +}) + test_that("Init code is covered", { local_mocked_bindings( test_connection = function() { diff --git a/tests/testthat/test-m-backend-prompt.R b/tests/testthat/test-m-backend-prompt.R new file mode 100644 index 0000000..ccd497b --- /dev/null +++ b/tests/testthat/test-m-backend-prompt.R @@ -0,0 +1,14 @@ +test_that("Prompt handles list()", { + llm_use( + backend = "simulate_llm", + model = "prompt", + .silent = TRUE, + .force = TRUE, + .cache = "_prompt_cache" + ) + test_text <- "Custom:{prompt}\n{{x}}" + expect_equal( + llm_vec_custom(x = "new test", prompt = test_text), + list(list(role = "user", content = test_text)) + ) +}) diff --git a/tests/testthat/test-m-backend-submit.R b/tests/testthat/test-m-backend-submit.R index e4fd5f5..bff26ee 100644 --- a/tests/testthat/test-m-backend-submit.R +++ b/tests/testthat/test-m-backend-submit.R @@ -8,3 +8,11 @@ test_that("Ollama code is covered", { "positive" ) }) + +test_that("No cache is saved if turned off", { + llm_use("simulate_llm", "echo", .silent = TRUE, .force = TRUE, .cache = "") + expect_equal( + llm_vec_custom("test text", "nothing new: "), + "test text" + ) +}) diff --git a/tests/testthat/test-zzz-cache.R b/tests/testthat/test-zzz-cache.R index 1ca913d..29adf18 100644 --- a/tests/testthat/test-zzz-cache.R +++ b/tests/testthat/test-zzz-cache.R @@ -8,7 +8,7 @@ test_that("Ollama cache exists and delete", { skip_if_no_ollama() expect_equal( length(fs::dir_ls("_ollama_cache", recurse = TRUE)), - 53 + 55 ) fs::dir_delete("_ollama_cache") })