diff --git a/R/chat.R b/R/chat.R index a8f53e4f..754c9500 100644 --- a/R/chat.R +++ b/R/chat.R @@ -133,18 +133,18 @@ Chat <- R6::R6Class("Chat", #' @description Extract structured data #' @param ... The input to send to the chatbot. Will typically include #' the phrase "extract structured data". - #' @param spec A type specification for the extracted data. Should be + #' @param type A type specification for the extracted data. Should be #' created with a [`type_()`][type_boolean] function. #' @param echo Whether to emit the response to stdout as it is received. #' Set to "text" to stream JSON data as it's generated (not supported by #' all providers). - extract_data = function(..., spec, echo = "none") { + extract_data = function(..., type, echo = "none") { turn <- user_turn(...) echo <- check_echo(echo %||% private$echo) coro::collect(private$submit_turns( turn, - spec = spec, + type = type, stream = echo != "none", echo = echo )) @@ -164,18 +164,18 @@ Chat <- R6::R6Class("Chat", #' that resolves to an object matching the type specification. #' @param ... The input to send to the chatbot. Will typically include #' the phrase "extract structured data". - #' @param spec A type specification for the extracted data. Should be + #' @param type A type specification for the extracted data. Should be #' created with a [`type_()`][type_boolean] function. #' @param echo Whether to emit the response to stdout as it is received. #' Set to "text" to stream JSON data as it's generated (not supported by #' all providers). - extract_data_async = function(..., spec, echo = "none") { + extract_data_async = function(..., type, echo = "none") { turn <- user_turn(...) echo <- check_echo(echo %||% private$echo) done <- coro::async_collect(private$submit_turns_async( turn, - spec = spec, + type = type, stream = echo != "none", echo = echo )) @@ -302,7 +302,7 @@ Chat <- R6::R6Class("Chat", # If stream = TRUE, yields completion deltas. If stream = FALSE, yields # complete assistant turns. - submit_turns = generator_method(function(self, private, user_turn, stream, echo, spec = NULL) { + submit_turns = generator_method(function(self, private, user_turn, stream, echo, type = NULL) { if (echo == "all") { cat_line(format(user_turn), prefix = "> ") @@ -313,7 +313,7 @@ Chat <- R6::R6Class("Chat", mode = if (stream) "stream" else "value", turns = c(private$.turns, list(user_turn)), tools = private$tools, - spec = spec + type = type ) emit <- emitter(echo) @@ -331,7 +331,7 @@ Chat <- R6::R6Class("Chat", result <- stream_merge_chunks(private$provider, result, chunk) } - turn <- value_turn(private$provider, result, has_spec = !is.null(spec)) + turn <- value_turn(private$provider, result, has_type = !is.null(type)) # Ensure turns always end in a newline if (any_text) { @@ -345,7 +345,7 @@ Chat <- R6::R6Class("Chat", cat_line(formatted, prefix = "< ") } } else { - turn <- value_turn(private$provider, response, has_spec = !is.null(spec)) + turn <- value_turn(private$provider, response, has_type = !is.null(type)) text <- turn@text if (!is.null(text)) { text <- paste0(text, "\n") @@ -364,13 +364,13 @@ Chat <- R6::R6Class("Chat", # If stream = TRUE, yields completion deltas. If stream = FALSE, yields # complete assistant turns. - submit_turns_async = async_generator_method(function(self, private, user_turn, stream, echo, spec = NULL) { + submit_turns_async = async_generator_method(function(self, private, user_turn, stream, echo, type = NULL) { response <- chat_perform( provider = private$provider, mode = if (stream) "async-stream" else "async-value", turns = c(private$.turns, list(user_turn)), tools = private$tools, - spec = spec + type = type ) emit <- emitter(echo) @@ -387,7 +387,7 @@ Chat <- R6::R6Class("Chat", result <- stream_merge_chunks(private$provider, result, chunk) } - turn <- value_turn(private$provider, result, has_spec = !is.null(spec)) + turn <- value_turn(private$provider, result, has_type = !is.null(type)) # Ensure turns always end in a newline if (any_text) { @@ -397,7 +397,7 @@ Chat <- R6::R6Class("Chat", } else { result <- await(response) - turn <- value_turn(private$provider, result, has_spec = !is.null(spec)) + turn <- value_turn(private$provider, result, has_type = !is.null(type)) text <- turn@text if (!is.null(text)) { text <- paste0(text, "\n") diff --git a/R/httr2.R b/R/httr2.R index 20e52b0e..b28cee01 100644 --- a/R/httr2.R +++ b/R/httr2.R @@ -5,7 +5,7 @@ chat_perform <- function(provider, mode = c("value", "stream", "async-stream", "async-value"), turns, tools = list(), - spec = NULL, + type = NULL, extra_args = list()) { mode <- arg_match(mode) @@ -16,7 +16,7 @@ chat_perform <- function(provider, turns = turns, tools = tools, stream = stream, - spec = spec, + type = type, extra_args = extra_args ) diff --git a/R/provider-azure.R b/R/provider-azure.R index fa6c3b49..97bdc369 100644 --- a/R/provider-azure.R +++ b/R/provider-azure.R @@ -78,7 +78,7 @@ method(chat_request, ProviderAzure) <- function(provider, stream = TRUE, turns = list(), tools = list(), - spec = NULL, + type = NULL, extra_args = list()) { req <- request(provider@base_url) @@ -95,12 +95,12 @@ method(chat_request, ProviderAzure) <- function(provider, tools <- as_json(provider, unname(tools)) extra_args <- utils::modifyList(provider@extra_args, extra_args) - if (!is.null(spec)) { + if (!is.null(type)) { response_format <- list( type = "json_schema", json_schema = list( name = "structured_data", - schema = as_json(provider, spec), + schema = as_json(provider, type), strict = TRUE ) ) diff --git a/R/provider-bedrock.R b/R/provider-bedrock.R index acea4c7a..9a779c23 100644 --- a/R/provider-bedrock.R +++ b/R/provider-bedrock.R @@ -58,7 +58,7 @@ method(chat_request, ProviderBedrock) <- function(provider, stream = TRUE, turns = list(), tools = list(), - spec = NULL, + type = NULL, extra_args = list()) { req <- request(paste0( @@ -90,12 +90,12 @@ method(chat_request, ProviderBedrock) <- function(provider, messages <- compact(as_json(provider, turns)) - if (!is.null(spec)) { + if (!is.null(type)) { tool_def <- ToolDef( fun = function(...) {}, name = "structured_tool_call__", description = "Extract structured data", - arguments = type_object(data = spec) + arguments = type_object(data = type) ) tools[[tool_def@name]] <- tool_def tool_choice <- list(tool = list(name = tool_def@name)) @@ -188,12 +188,12 @@ method(stream_merge_chunks, ProviderBedrock) <- function(provider, result, chunk result } -method(value_turn, ProviderBedrock) <- function(provider, result, has_spec = FALSE) { +method(value_turn, ProviderBedrock) <- function(provider, result, has_type = FALSE) { contents <- lapply(result$output$message$content, function(content) { if (has_name(content, "text")) { ContentText(content$text) } else if (has_name(content, "toolUse")) { - if (has_spec) { + if (has_type) { ContentJson(content$toolUse$input$data) } else { ContentToolRequest( diff --git a/R/provider-claude.R b/R/provider-claude.R index 240f984b..a66b2726 100644 --- a/R/provider-claude.R +++ b/R/provider-claude.R @@ -69,7 +69,7 @@ method(chat_request, ProviderClaude) <- function(provider, stream = TRUE, turns = list(), tools = list(), - spec = NULL, + type = NULL, extra_args = list()) { req <- request(provider@base_url) @@ -106,12 +106,12 @@ method(chat_request, ProviderClaude) <- function(provider, messages <- compact(as_json(provider, turns)) - if (!is.null(spec)) { + if (!is.null(type)) { tool_def <- ToolDef( fun = function(...) {}, name = "_structured_tool_call", description = "Extract structured data", - arguments = type_object(data = spec) + arguments = type_object(data = type) ) tools[[tool_def@name]] <- tool_def tool_choice <- list(type = "tool", name = tool_def@name) @@ -185,12 +185,12 @@ method(stream_merge_chunks, ProviderClaude) <- function(provider, result, chunk) result } -method(value_turn, ProviderClaude) <- function(provider, result, has_spec = FALSE) { +method(value_turn, ProviderClaude) <- function(provider, result, has_type = FALSE) { contents <- lapply(result$content, function(content) { if (content$type == "text") { ContentText(content$text) } else if (content$type == "tool_use") { - if (has_spec) { + if (has_type) { ContentJson(content$input$data) } else { if (is_string(content$input)) { diff --git a/R/provider-cortex.R b/R/provider-cortex.R index ae9c4da5..367381e5 100644 --- a/R/provider-cortex.R +++ b/R/provider-cortex.R @@ -99,12 +99,12 @@ method(chat_request, ProviderCortex) <- function(provider, stream = TRUE, turns = list(), tools = list(), - spec = NULL, + type = NULL, extra_args = list()) { if (length(tools) != 0) { cli::cli_abort("Tools are not supported by Cortex.") } - if (!is.null(spec) != 0) { + if (!is.null(type) != 0) { cli::cli_abort("Structured data extraction is not supported by Cortex.") } @@ -234,7 +234,7 @@ cortex_chunk_to_message <- function(x) { } } -method(value_turn, ProviderCortex) <- function(provider, result, has_spec = FALSE) { +method(value_turn, ProviderCortex) <- function(provider, result, has_type = FALSE) { if (!is_named(result)) { # streaming role <- "assistant" content <- result diff --git a/R/provider-databricks.R b/R/provider-databricks.R index df602b16..9c5ed28b 100644 --- a/R/provider-databricks.R +++ b/R/provider-databricks.R @@ -72,7 +72,7 @@ method(chat_request, ProviderDatabricks) <- function(provider, stream = TRUE, turns = list(), tools = list(), - spec = NULL, + type = NULL, extra_args = list()) { req <- request(provider@base_url) # Note: this API endpoint is undocumented and seems to exist primarily for @@ -95,12 +95,12 @@ method(chat_request, ProviderDatabricks) <- function(provider, tools <- as_json(provider, unname(tools)) extra_args <- utils::modifyList(provider@extra_args, extra_args) - if (!is.null(spec)) { + if (!is.null(type)) { response_format <- list( type = "json_schema", json_schema = list( name = "structured_data", - schema = as_json(provider, spec), + schema = as_json(provider, type), strict = TRUE ) ) diff --git a/R/provider-gemini.R b/R/provider-gemini.R index 4de26cd0..bbafcf37 100644 --- a/R/provider-gemini.R +++ b/R/provider-gemini.R @@ -50,7 +50,7 @@ method(chat_request, ProviderGemini) <- function(provider, stream = TRUE, turns = list(), tools = list(), - spec = NULL, + type = NULL, extra_args = list()) { @@ -78,10 +78,10 @@ method(chat_request, ProviderGemini) <- function(provider, system <- list(parts = list(text = "")) } - if (!is.null(spec)) { + if (!is.null(type)) { generation_config <- list( response_mime_type = "application/json", - response_schema = as_json(provider, spec) + response_schema = as_json(provider, type) ) } else { generation_config <- NULL @@ -129,12 +129,12 @@ method(stream_merge_chunks, ProviderGemini) <- function(provider, result, chunk) merge_dicts(result, chunk) } } -method(value_turn, ProviderGemini) <- function(provider, result, has_spec = FALSE) { +method(value_turn, ProviderGemini) <- function(provider, result, has_type = FALSE) { message <- result$candidates[[1]]$content contents <- lapply(message$parts, function(content) { if (has_name(content, "text")) { - if (has_spec) { + if (has_type) { data <- jsonlite::parse_json(content$text) ContentJson(data) } else { diff --git a/R/provider-openai.R b/R/provider-openai.R index accecbab..7b275a96 100644 --- a/R/provider-openai.R +++ b/R/provider-openai.R @@ -96,7 +96,7 @@ method(chat_request, ProviderOpenAI) <- function(provider, stream = TRUE, turns = list(), tools = list(), - spec = NULL, + type = NULL, extra_args = list()) { req <- request(provider@base_url) @@ -113,12 +113,12 @@ method(chat_request, ProviderOpenAI) <- function(provider, tools <- as_json(provider, unname(tools)) extra_args <- utils::modifyList(provider@extra_args, extra_args) - if (!is.null(spec)) { + if (!is.null(type)) { response_format <- list( type = "json_schema", json_schema = list( name = "structured_data", - schema = as_json(provider, spec), + schema = as_json(provider, type), strict = TRUE ) ) @@ -169,14 +169,14 @@ method(stream_merge_chunks, ProviderOpenAI) <- function(provider, result, chunk) merge_dicts(result, chunk) } } -method(value_turn, ProviderOpenAI) <- function(provider, result, has_spec = FALSE) { +method(value_turn, ProviderOpenAI) <- function(provider, result, has_type = FALSE) { if (has_name(result$choices[[1]], "delta")) { # streaming message <- result$choices[[1]]$delta } else { message <- result$choices[[1]]$message } - if (has_spec) { + if (has_type) { json <- jsonlite::parse_json(message$content[[1]]) content <- list(ContentJson(json)) } else { diff --git a/R/provider.R b/R/provider.R index b247c4ad..62b3eaa9 100644 --- a/R/provider.R +++ b/R/provider.R @@ -26,7 +26,7 @@ Provider <- new_class( # Create a request------------------------------------ chat_request <- new_generic("chat_request", "provider", - function(provider, stream = TRUE, turns = list(), tools = list(), spec = NULL, extra_args = list()) { + function(provider, stream = TRUE, turns = list(), tools = list(), type = NULL, extra_args = list()) { S7_dispatch() } ) diff --git a/man/Chat.Rd b/man/Chat.Rd index e6538f31..7be08582 100644 --- a/man/Chat.Rd +++ b/man/Chat.Rd @@ -193,7 +193,7 @@ will be used.} \subsection{Method \code{extract_data()}}{ Extract structured data \subsection{Usage}{ -\if{html}{\out{