
Commit

Merge pull request #11 from edgararuiz/updates
Updates
edgararuiz authored Sep 18, 2024
2 parents 85deefb + 67dbddd commit 7d036a9
Showing 26 changed files with 296 additions and 185 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -1,7 +1,7 @@
Package: mall
Title: Run multiple 'Large Language Model' predictions against a table, or
vectors
Version: 0.0.0.9005
Version: 0.0.0.9006
Authors@R:
person("Edgar", "Ruiz", , "[email protected]", role = c("aut", "cre"))
Description: Run multiple 'Large Language Model' predictions against a table. The
1 change: 1 addition & 0 deletions NAMESPACE
@@ -41,4 +41,5 @@ importFrom(jsonlite,write_json)
importFrom(ollamar,chat)
importFrom(ollamar,list_models)
importFrom(ollamar,test_connection)
importFrom(utils,head)
importFrom(utils,menu)
18 changes: 12 additions & 6 deletions R/llm-classify.R
@@ -4,16 +4,20 @@
#' Use a Large Language Model (LLM) to classify the provided text as one of the
#' options provided via the `labels` argument.
#'
#' @param .data A `data.frame` or `tbl` object that contains the text to be analyzed
#' @param .data A `data.frame` or `tbl` object that contains the text to be
#' analyzed
#' @param col The name of the field to analyze, supports `tidy-eval`
#' @param x A vector that contains the text to be analyzed
#' @param additional_prompt Inserts this text into the prompt sent to the LLM
#' @param pred_name A character vector with the name of the new column where the
#' prediction will be placed
#' @param labels A character vector with at least 2 labels to classify the text
#' as
#' @returns `llm_classify` returns a `data.frame` or `tbl` object. `llm_vec_classify`
#' returns a vector that is the same length as `x`.
#' @param preview If `TRUE`, returns the R call that would have been used to
#'   run the prediction, based on the first record in `x` only. Defaults to
#'   `FALSE`. Applies to the vector function only.
#' @returns `llm_classify` returns a `data.frame` or `tbl` object.
#' `llm_vec_classify` returns a vector that is the same length as `x`.
#' @export
llm_classify <- function(.data,
col,
@@ -43,12 +47,14 @@ llm_classify.data.frame <- function(.data,
#' @export
llm_vec_classify <- function(x,
labels,
additional_prompt = "") {
l_vec_prompt(
additional_prompt = "",
preview = FALSE) {
m_vec_prompt(
x = x,
prompt_label = "classify",
additional_prompt = additional_prompt,
labels = labels,
valid_resps = labels
valid_resps = labels,
preview = preview
)
}
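
To make the new `preview` argument concrete, here is a minimal usage sketch. It assumes a local Ollama backend registered through `llm_use()`; the model name and sample text are illustrative and not part of this changeset.

library(mall)

# Assumes Ollama is running locally and the model below has been pulled;
# "llama3.2" is an illustrative choice, not something this commit prescribes.
llm_use("ollama", "llama3.2", .silent = TRUE)

reviews <- c("This has been the best TV I've ever used. Great screen, and sound.")

# Runs the prediction and returns one label per element of `x`
llm_vec_classify(reviews, labels = c("appliance", "computer"))

# New in this changeset: returns the R call that would have been submitted,
# built from the first record only, instead of contacting the model
llm_vec_classify(reviews, labels = c("appliance", "computer"), preview = TRUE)
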
12 changes: 6 additions & 6 deletions R/llm-custom.R
@@ -6,11 +6,11 @@
#'
#' @inheritParams llm_classify
#' @param prompt The prompt to append to each record sent to the LLM
#' @param valid_resps If the response from the LLM is not open, but deterministic,
#' provide the options in a vector. This function will set to `NA` any response
#' not in the options
#' @returns `llm_custom` returns a `data.frame` or `tbl` object. `llm_vec_custom`
#' returns a vector that is the same length as `x`.
#' @param valid_resps If the response from the LLM is not open, but
#' deterministic, provide the options in a vector. This function will set to
#' `NA` any response not in the options
#' @returns `llm_custom` returns a `data.frame` or `tbl` object.
#' `llm_vec_custom` returns a vector that is the same length as `x`.
#' @export
llm_custom <- function(
.data,
@@ -40,7 +40,7 @@ llm_custom.data.frame <- function(.data,
#' @rdname llm_custom
#' @export
llm_vec_custom <- function(x, prompt = "", valid_resps = NULL) {
l_vec_prompt(
m_vec_prompt(
x = x,
prompt = prompt,
valid_resps = valid_resps
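
A short sketch of the `valid_resps` behavior documented above (answers outside the supplied options are coerced to `NA`). It assumes the setup from the earlier sketch, and the prompt text is illustrative.

# assumes library(mall) and llm_use() from the earlier sketch
my_prompt <- paste(
  "Answer a question. Return only the answer, no explanations.",
  "Acceptable answers are 'yes', 'no'.",
  "Answer this about the following text, is this a happy customer?:"
)

llm_vec_custom(
  x = c("I'm very happy with this product", "Never buying from them again"),
  prompt = my_prompt,
  valid_resps = c("yes", "no")  # anything outside these options is coerced to NA
)
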
12 changes: 7 additions & 5 deletions R/llm-extract.R
@@ -11,8 +11,8 @@
#' `labels` is a named vector, this function will use those names as the
#' new column names, if not, the function will use a sanitized version of
#' the content as the name.
#' @returns `llm_extract` returns a `data.frame` or `tbl` object. `llm_vec_extract`
#' returns a vector that is the same length as `x`.
#' @returns `llm_extract` returns a `data.frame` or `tbl` object.
#' `llm_vec_extract` returns a vector that is the same length as `x`.
#' @export
llm_extract <- function(.data,
col,
@@ -76,11 +76,13 @@ llm_extract.data.frame <- function(.data,
#' @export
llm_vec_extract <- function(x,
labels = c(),
additional_prompt = "") {
l_vec_prompt(
additional_prompt = "",
preview = FALSE) {
m_vec_prompt(
x = x,
prompt_label = "extract",
labels = labels,
additional_prompt = additional_prompt
additional_prompt = additional_prompt,
preview = preview
)
}
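
For reference, a sketch of how `llm_vec_extract()` might be called with one or several labels; the sample string is illustrative, and the exact formatting of multi-label results is not visible in this diff. In the data-frame method, naming the labels (e.g. `c(maker = "brand")`) is what drives the new column names described above.

# assumes library(mall) and llm_use() from the earlier sketch

# Single label: the extracted value comes back as a character vector
llm_vec_extract("bob smith, 105 2nd street", labels = "name")

# Several labels: one element is still returned per input, with the
# extracted values combined into a single string
llm_vec_extract("bob smith, 105 2nd street", labels = c("name", "address"))
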
8 changes: 5 additions & 3 deletions R/llm-sentiment.R
@@ -52,12 +52,14 @@ globalVariables("ai_analyze_sentiment")
#' @export
llm_vec_sentiment <- function(x,
options = c("positive", "negative", "neutral"),
additional_prompt = "") {
l_vec_prompt(
additional_prompt = "",
preview = FALSE) {
m_vec_prompt(
x = x,
prompt_label = "sentiment",
additional_prompt = additional_prompt,
valid_resps = options,
options = options
options = options,
preview = preview
)
}
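
A usage sketch for the updated signature, with illustrative text. The new `process_labels()` helper further down (R/m-backend-prompt.R) suggests that `options` may also be supplied as formulas such as `"positive" ~ 1` to remap the returned values, although that wiring is not fully visible in this diff.

# assumes library(mall) and llm_use() from the earlier sketch
llm_vec_sentiment(
  c("I am very happy with the service", "The remote broke after two days"),
  options = c("positive", "negative", "neutral"),
  additional_prompt = "Consider sarcasm when deciding."
)
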
8 changes: 5 additions & 3 deletions R/llm-summarize.R
@@ -51,11 +51,13 @@ globalVariables("ai_summarize")
#' @export
llm_vec_summarize <- function(x,
max_words = 10,
additional_prompt = "") {
l_vec_prompt(
additional_prompt = "",
preview = FALSE) {
m_vec_prompt(
x = x,
prompt_label = "summarize",
additional_prompt = additional_prompt,
max_words = max_words
max_words = max_words,
preview = preview
)
}
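
And the equivalent sketch for the summarizer, where `max_words` caps the length of the result; again the input text is illustrative.

# assumes library(mall) and llm_use() from the earlier sketch
llm_vec_summarize(
  "This has been the best TV I've ever used. Great screen, and sound.",
  max_words = 5,
  additional_prompt = "Use plain, non-technical language."
)
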
8 changes: 5 additions & 3 deletions R/llm-translate.R
@@ -38,11 +38,13 @@ llm_translate.data.frame <- function(.data,
llm_vec_translate <- function(
x,
language,
additional_prompt = "") {
l_vec_prompt(
additional_prompt = "",
preview = FALSE) {
m_vec_prompt(
x = x,
prompt_label = "translate",
additional_prompt = additional_prompt,
language = language
language = language,
preview = preview
)
}
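
A sketch of the translation verbs; the data-frame method's `col` and `language` arguments are inferred from the hunk header above and from the vector signature, and the sample text is illustrative.

# assumes library(mall) and llm_use() from the earlier sketch
reviews <- data.frame(review = "Este es el mejor televisor que he comprado.")

# Table verb: appends the translation as a new column
reviews |>
  llm_translate(review, language = "english")

# Vector verb: returns one translated string per element of `x`
llm_vec_translate("Este es el mejor televisor que he comprado.", language = "english")
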
6 changes: 3 additions & 3 deletions R/llm-use.R
@@ -13,8 +13,8 @@
#' R session
#' @param .cache The path to save model results, so they can be re-used if
#' the same operation is run again. To turn off, set this argument to an empty
#' character: `""`. It defaults to '_mall_cache'. If this argument is left `NULL`
#' when calling this function, no changes to the path will be made.
#' character: `""`. It defaults to '_mall_cache'. If this argument is left
#' `NULL` when calling this function, no changes to the path will be made.
#'
#' @returns A `mall_defaults` object
#'
@@ -69,7 +69,7 @@ llm_use <- function(
...
)
}
if (!.silent | not_init) {
if (!.silent || not_init) {
print(defaults_get())
}
invisible(defaults_get())
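
The `|` to `||` switch in the last hunk deserves a brief note: inside `if ()` the condition must be a single logical, and `||` both short-circuits and (in current R releases) errors when an operand is longer than length one, so problems surface instead of being silently recycled. A tiny illustration:

.silent <- FALSE
not_init <- TRUE

!.silent | not_init   # element-wise OR: evaluates both sides, can return a vector
!.silent || not_init  # scalar OR: stops at the first TRUE, expects length-one inputs
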
80 changes: 24 additions & 56 deletions R/m-backend-prompt.R
@@ -8,13 +8,17 @@ m_backend_prompt <- function(backend, additional) {
m_backend_prompt.mall_defaults <- function(backend, additional = "") {
list(
sentiment = function(options) {
options <- paste0(options, collapse = ", ")
options <- process_labels(
x = options,
if_character = "Return only one of the following answers: {x}",
if_formula = "- If the text is {f_lhs(x)}, return {f_rhs(x)}"
)
list(
list(
role = "user",
content = glue(paste(
"You are a helpful sentiment engine.",
"Return only one of the following answers: {options}.",
"{options}.",
"No capitalization. No explanations.",
"{additional}",
"The answer is based on the following text:\n{{x}}"
@@ -37,13 +41,17 @@
)
},
classify = function(labels) {
labels <- paste0(labels, collapse = ", ")
labels <- process_labels(
x = labels,
if_character = "Determine if the text refers to one of the following: {x}",
if_formula = "- For {f_lhs(x)}, return {f_rhs(x)}"
)
list(
list(
role = "user",
content = glue(paste(
"You are a helpful classification engine.",
"Determine if the text refers to one of the following: {labels}.",
"{labels}.",
"No capitalization. No explanations.",
"{additional}",
"The answer is based on the following text:\n{{x}}"
@@ -54,8 +62,6 @@
extract = function(labels) {
no_labels <- length(labels)
col_labels <- paste0(labels, collapse = ", ")
json_labels <- paste0("\"", labels, "\":your answer", collapse = ",")
json_labels <- paste0("{{", json_labels, "}}")
plural <- ifelse(no_labels > 1, "s", "")
text_multi <- ifelse(
no_labels > 1,
@@ -94,55 +100,17 @@
)
}

l_vec_prompt <- function(x,
prompt_label = "",
additional_prompt = "",
valid_resps = NULL,
prompt = NULL,
...) {
# Initializes session LLM
backend <- llm_use(.silent = TRUE, .force = FALSE)
# If there is no 'prompt', then assumes that we're looking for a
# prompt label (sentiment, classify, etc) to set 'prompt'
if (is.null(prompt)) {
defaults <- m_backend_prompt(
backend = backend,
additional = additional_prompt
)
fn <- defaults[[prompt_label]]
prompt <- fn(...)
}
# If the prompt is a character, it will convert it to
# a list so it can be processed
if (!inherits(prompt, "list")) {
p_split <- strsplit(prompt, "\\{\\{x\\}\\}")[[1]]
if (length(p_split) == 1 && p_split == prompt) {
content <- glue("{prompt}\n{{x}}")
} else {
content <- prompt
}
prompt <- list(
list(role = "user", content = content)
)
}
# Submits final prompt to the LLM
resp <- m_backend_submit(
backend = backend,
x = x,
prompt = prompt
)
# Checks for invalid output and marks them as NA
if (!is.null(valid_resps)) {
errors <- !resp %in% valid_resps
resp[errors] <- NA
if (any(errors)) {
cli_alert_warning(
c(
"There were {sum(errors)} predictions with ",
"invalid output, they were coerced to NA"
)
)
}
all_formula <- function(x) {
all(map_lgl(x, inherits, "formula"))
}

process_labels <- function(x, if_character = "", if_formula = "") {
if (all_formula(x)) {
labels_mapped <- map_chr(x, \(x) glue(if_formula))
out <- paste0(labels_mapped, collapse = ", ")
} else {
x <- paste0(x, collapse = ", ")
out <- glue(if_character)
}
resp
out
}
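
To see what the two new helpers do in isolation, a sketch of calling `process_labels()` directly with the same templates used above. This assumes the function has been loaded (it is internal, so e.g. `devtools::load_all()`), and that `glue()`, `map_chr()`, `map_lgl()`, `f_lhs()`, and `f_rhs()` resolve to the glue, purrr, and rlang functions, which this diff does not show being imported.

# Character labels: collapsed and dropped into the character template
process_labels(
  c("appliance", "computer"),
  if_character = "Determine if the text refers to one of the following: {x}"
)
#> Determine if the text refers to one of the following: appliance, computer

# Formula labels: each lhs/rhs pair is expanded with the formula template,
# then the pieces are collapsed into one instruction string
process_labels(
  c("positive" ~ 1, "negative" ~ 0),
  if_formula = "- If the text is {f_lhs(x)}, return {f_rhs(x)}"
)
#> - If the text is positive, return 1, - If the text is negative, return 0
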