Skip to content

Commit

Permalink
Merge pull request #10 from edgararuiz/updates
Browse files Browse the repository at this point in the history
Updates
  • Loading branch information
edgararuiz authored Sep 18, 2024
2 parents 189de14 + 56b71cd commit 85deefb
Show file tree
Hide file tree
Showing 19 changed files with 152 additions and 48 deletions.
4 changes: 3 additions & 1 deletion .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,6 @@ utils/
^pkgdown$
^\.github$
^codecov\.yml$
^_mall_cache/
^_mall_cache/
^_readme_cache/
^_prompt_cache/
13 changes: 10 additions & 3 deletions .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,16 @@ jobs:

- uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: any::rcmdcheck
needs: check

extra-packages: |
any::rcmdcheck
needs: |
check
- name: Installing dbplyr
run: |
pak::pak("dbplyr")
shell: Rscript {0}

- uses: r-lib/actions/check-r-package@v2
with:
upload-snapshots: true
Expand Down
5 changes: 4 additions & 1 deletion .github/workflows/test-coverage.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ jobs:

- uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: any::covr, any::xml2
extra-packages: |
any::covr
any::xml2
any::dbplyr
needs: coverage

- name: Test coverage
Expand Down
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Imports:
ollamar,
rlang
Suggests:
dbplyr,
testthat (>= 3.0.0)
Config/testthat/edition: 3
URL: https://edgararuiz.github.io/mall/
6 changes: 1 addition & 5 deletions R/llm-extract.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,10 @@ llm_extract.data.frame <- function(.data,
llm_vec_extract <- function(x,
labels = c(),
additional_prompt = "") {
resp <- l_vec_prompt(
l_vec_prompt(
x = x,
prompt_label = "extract",
labels = labels,
additional_prompt = additional_prompt
)
map_chr(
resp,
\(x) paste0(as.character(fromJSON(x, flatten = TRUE)), collapse = "|")
)
}
13 changes: 7 additions & 6 deletions R/m-backend-prompt.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,20 @@ m_backend_prompt.mall_defaults <- function(backend, additional = "") {
json_labels <- paste0("\"", labels, "\":your answer", collapse = ",")
json_labels <- paste0("{{", json_labels, "}}")
plural <- ifelse(no_labels > 1, "s", "")
text_multi <- ifelse(
no_labels > 1,
"Return the response in a simple list, pipe separated, and no headers. ",
""
)
list(
list(
role = "system",
content = "You only speak simple JSON. Do not write normal text."
),
list(
role = "user",
content = glue(paste(
"You are a helpful text extraction engine.",
"Extract the {col_labels} being referred to on the text.",
"I expect {no_labels} item{plural} exactly.",
"No capitalization. No explanations.",
"You will use this JSON this format exclusively: {json_labels} .",
"{text_multi}",
"{additional}",
"The answer is based on the following text:\n{{x}}"
))
Expand Down Expand Up @@ -100,7 +101,7 @@ l_vec_prompt <- function(x,
prompt = NULL,
...) {
# Initializes session LLM
backend <- llm_use(.silent = TRUE, force = FALSE)
backend <- llm_use(.silent = TRUE, .force = FALSE)
# If there is no 'prompt', then assumes that we're looking for a
# prompt label (sentiment, classify, etc) to set 'prompt'
if (is.null(prompt)) {
Expand Down
5 changes: 2 additions & 3 deletions R/m-backend-submit.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ m_backend_submit.mall_simulate_llm <- function(backend, x, prompt) {
out <- map_chr(x, \(x) trimws(strsplit(x, "\\|")[[1]][[2]]))
} else if (args$model == "echo") {
out <- x
} else if (args$model == "prompt") {
out <- prompt
}
res <- NULL
if (m_cache_use()) {
Expand Down Expand Up @@ -77,9 +79,6 @@ m_cache_record <- function(.args, .response, hash_args) {
}

m_cache_check <- function(hash_args) {
if (!m_cache_use()) {
return(invisible())
}
folder_root <- m_cache_folder()
resp <- suppressWarnings(
try(read_json(m_cache_file(hash_args)), TRUE)
Expand Down
5 changes: 1 addition & 4 deletions R/utils.R
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
clean_names <- function(x, replace_periods = FALSE) {
clean_names <- function(x) {
x <- tolower(x)
map_chr(
x,
\(x) {
out <- str_replace_clean(x, " ")
out <- str_replace_clean(out, "\\:")
if (replace_periods) {
out <- str_replace_clean(out, "\\.")
}
out
}
)
Expand Down
22 changes: 22 additions & 0 deletions tests/testthat/_snaps/llm-extract.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Extract works

Code
llm_vec_extract("toaster", labels = "product")
Output
[[1]]
[[1]]$role
[1] "user"
[[1]]$content
You are a helpful text extraction engine. Extract the product being referred to on the text. I expect 1 item exactly. No capitalization. No explanations. The answer is based on the following text:
{x}

# Extract on Ollama works

Code
Expand All @@ -12,3 +27,10 @@
2 laptop
3 washing machine

---

Code
llm_vec_extract("bob smith, 105 2nd street", c("name", "address"))
Output
[1] "bob smith | 105 2nd street"

9 changes: 9 additions & 0 deletions tests/testthat/_snaps/llm-sentiment.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
# Sentiment translates expected Spark SQL

Code
llm_sentiment(df_spark, x)
Output
<SQL>
SELECT `df`.*, ai_analyze_sentiment(`x`) AS `.sentiment`
FROM `df`

# Sentiment on Ollama works

Code
Expand Down
18 changes: 18 additions & 0 deletions tests/testthat/_snaps/llm-summarize.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,21 @@
# Summarize translates expected Spark SQL

Code
llm_summarize(df_spark, x)
Output
<SQL>
SELECT `df`.*, ai_summarize(`x`, CAST(10.0 AS INT)) AS `.summary`
FROM `df`

---

Code
llm_summarize(df_spark, x, max_words = 50)
Output
<SQL>
SELECT `df`.*, ai_summarize(`x`, CAST(50.0 AS INT)) AS `.summary`
FROM `df`

# Summarize on Ollama works

Code
Expand Down
11 changes: 6 additions & 5 deletions tests/testthat/_snaps/zzz-cache.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,24 @@
Output
_mall_cache/65
_mall_cache/65/654ffcb598cbbdeb0c1c7ebd05239ed4.json
_mall_cache/65/65fd4ea24b687f30e12881a3c5ee4acc.json
_mall_cache/72
_mall_cache/72/72433745387c1c32b53ec05a77fa3b97.json
_mall_cache/73
_mall_cache/73/73962724ef856c31dffa1dfcae15daf1.json
_mall_cache/76
_mall_cache/76/768d2519acbc4cce84d8069f665ee0b2.json
_mall_cache/7a
_mall_cache/7a/7a500c559374fe68cf6cf57605d9de46.json
_mall_cache/7d
_mall_cache/7d/7d2911f2710df85c421e40c2991a6cac.json
_mall_cache/7e
_mall_cache/7e/7e20b6c382a19b95ea3e43d76af92b0f.json
_mall_cache/83
_mall_cache/83/83dafa73c6187bf2156f4f07d6fdfc36.json
_mall_cache/87
_mall_cache/87/8787f5fe54447691660ccb29bedbb420.json
_mall_cache/ab
_mall_cache/ab/abcd4e05a8a4cf61e0053a7a7bed1360.json
_mall_cache/df
_mall_cache/df/dfe963014849a2dbd09171246c45c511.json
_mall_cache/e6
_mall_cache/e6/e624d61d6a9dab8052ff6053b5be2f62.json
_mall_cache/f4
_mall_cache/f4/f47435d14c52e6a5c07400fb5f43db58.json
_mall_cache/f7
Expand Down
37 changes: 18 additions & 19 deletions tests/testthat/test-llm-extract.R
Original file line number Diff line number Diff line change
@@ -1,45 +1,44 @@
test_that("Extract works", {
llm_use("simulate_llm", "echo", .silent = TRUE, .force = TRUE)
llm_use("simulate_llm", "prompt", .silent = TRUE, .force = TRUE)

expect_equal(
llm_vec_extract("{\"product\":\"toaster\"}", labels = "product"),
"toaster"
expect_snapshot(
llm_vec_extract("toaster", labels = "product")
)
})

test_that("Extract data frame works", {
llm_use("simulate_llm", "echo", .silent = TRUE, .force = TRUE)

entries2 <- "{\"product\":\"toaster\", \"product\":\"TV\"}"
entries2_result <- "toaster|TV"
expect_equal(
llm_vec_extract(
entries2,
labels = "product"
),
entries2_result
)
expect_equal(
llm_extract(data.frame(x = entries2), x, labels = "product"),
data.frame(x = entries2, .extract = entries2_result)
llm_extract(data.frame(x = "test"), x, labels = "product"),
data.frame(x = "test", .extract = "test")
)

expect_equal(
llm_extract(
.data = data.frame(x = entries2),
.data = data.frame(x = "test1|test2"),
col = x,
labels = c("product1", "product2"),
expand_cols = TRUE
),
data.frame(x = entries2, product1 = "toaster", product2 = "TV")
data.frame(x = "test1|test2", product1 = "test1", product2 = "test2")
)

expect_equal(
llm_extract(
.data = data.frame(x = entries2),
.data = data.frame(x = "test1|test2"),
col = x,
labels = c(y = "product1", z = "product2"),
expand_cols = TRUE
),
data.frame(x = entries2, y = "toaster", z = "TV")
data.frame(x = "test1|test2", y = "test1", z = "test2")
)
})

test_that("Extract on Ollama works", {
skip_if_no_ollama()
expect_snapshot(llm_extract(reviews_table(), review, "product"))
expect_snapshot(
llm_vec_extract("bob smith, 105 2nd street", c("name", "address"))
)
})
7 changes: 7 additions & 0 deletions tests/testthat/test-llm-sentiment.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ test_that("Sentiment works", {
)
})

test_that("Sentiment translates expected Spark SQL", {
suppressPackageStartupMessages(library(dbplyr))
df <- data.frame(x = 1)
df_spark <- tbl_lazy(df, con = simulate_spark_sql())
expect_snapshot(llm_sentiment(df_spark, x))
})

test_that("Sentiment on Ollama works", {
skip_if_no_ollama()
vec_reviews <- reviews_vec()
Expand Down
8 changes: 8 additions & 0 deletions tests/testthat/test-llm-summarize.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ test_that("Summarize works", {
)
})

test_that("Summarize translates expected Spark SQL", {
suppressPackageStartupMessages(library(dbplyr))
df <- data.frame(x = 1)
df_spark <- tbl_lazy(df, con = simulate_spark_sql())
expect_snapshot(llm_summarize(df_spark, x))
expect_snapshot(llm_summarize(df_spark, x, max_words = 50))
})

test_that("Summarize on Ollama works", {
skip_if_no_ollama()
expect_snapshot(llm_summarize(reviews_table(), review, max_words = 5))
Expand Down
12 changes: 12 additions & 0 deletions tests/testthat/test-llm-use.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
test_that("Ollama not found error", {
local_mocked_bindings(
test_connection = function() {
x <- list()
x$status_code <- 400
x
}
)
.env_llm$defaults <- list()
expect_error(llm_use())
})

test_that("Init code is covered", {
local_mocked_bindings(
test_connection = function() {
Expand Down
14 changes: 14 additions & 0 deletions tests/testthat/test-m-backend-prompt.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
test_that("Prompt handles list()", {
llm_use(
backend = "simulate_llm",
model = "prompt",
.silent = TRUE,
.force = TRUE,
.cache = "_prompt_cache"
)
test_text <- "Custom:{prompt}\n{{x}}"
expect_equal(
llm_vec_custom(x = "new test", prompt = test_text),
list(list(role = "user", content = test_text))
)
})
8 changes: 8 additions & 0 deletions tests/testthat/test-m-backend-submit.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,11 @@ test_that("Ollama code is covered", {
"positive"
)
})

test_that("No cache is saved if turned off", {
llm_use("simulate_llm", "echo", .silent = TRUE, .force = TRUE, .cache = "")
expect_equal(
llm_vec_custom("test text", "nothing new: "),
"test text"
)
})
2 changes: 1 addition & 1 deletion tests/testthat/test-zzz-cache.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ test_that("Ollama cache exists and delete", {
skip_if_no_ollama()
expect_equal(
length(fs::dir_ls("_ollama_cache", recurse = TRUE)),
53
55
)
fs::dir_delete("_ollama_cache")
})

0 comments on commit 85deefb

Please sign in to comment.