Skip to content

Commit

Permalink
Merge commit '6b42fc78a8d355bd2ba53490eda8af833bde17cc'
Browse files Browse the repository at this point in the history
  • Loading branch information
hadley committed Nov 27, 2024
2 parents 4538ecb + 6b42fc7 commit d92f5f3
Show file tree
Hide file tree
Showing 15 changed files with 92 additions and 92 deletions.
28 changes: 14 additions & 14 deletions R/chat.R
Original file line number Diff line number Diff line change
Expand Up @@ -133,18 +133,18 @@ Chat <- R6::R6Class("Chat",
#' @description Extract structured data
#' @param ... The input to send to the chatbot. Will typically include
#' the phrase "extract structured data".
#' @param spec A type specification for the extracted data. Should be
#' @param type A type specification for the extracted data. Should be
#' created with a [`type_()`][type_boolean] function.
#' @param echo Whether to emit the response to stdout as it is received.
#' Set to "text" to stream JSON data as it's generated (not supported by
#' all providers).
extract_data = function(..., spec, echo = "none") {
extract_data = function(..., type, echo = "none") {
turn <- user_turn(...)
echo <- check_echo(echo %||% private$echo)

coro::collect(private$submit_turns(
turn,
spec = spec,
type = type,
stream = echo != "none",
echo = echo
))
Expand All @@ -164,18 +164,18 @@ Chat <- R6::R6Class("Chat",
#' that resolves to an object matching the type specification.
#' @param ... The input to send to the chatbot. Will typically include
#' the phrase "extract structured data".
#' @param spec A type specification for the extracted data. Should be
#' @param type A type specification for the extracted data. Should be
#' created with a [`type_()`][type_boolean] function.
#' @param echo Whether to emit the response to stdout as it is received.
#' Set to "text" to stream JSON data as it's generated (not supported by
#' all providers).
extract_data_async = function(..., spec, echo = "none") {
extract_data_async = function(..., type, echo = "none") {
turn <- user_turn(...)
echo <- check_echo(echo %||% private$echo)

done <- coro::async_collect(private$submit_turns_async(
turn,
spec = spec,
type = type,
stream = echo != "none",
echo = echo
))
Expand Down Expand Up @@ -302,7 +302,7 @@ Chat <- R6::R6Class("Chat",

# If stream = TRUE, yields completion deltas. If stream = FALSE, yields
# complete assistant turns.
submit_turns = generator_method(function(self, private, user_turn, stream, echo, spec = NULL) {
submit_turns = generator_method(function(self, private, user_turn, stream, echo, type = NULL) {

if (echo == "all") {
cat_line(format(user_turn), prefix = "> ")
Expand All @@ -313,7 +313,7 @@ Chat <- R6::R6Class("Chat",
mode = if (stream) "stream" else "value",
turns = c(private$.turns, list(user_turn)),
tools = private$tools,
spec = spec
type = type
)
emit <- emitter(echo)

Expand All @@ -331,7 +331,7 @@ Chat <- R6::R6Class("Chat",

result <- stream_merge_chunks(private$provider, result, chunk)
}
turn <- value_turn(private$provider, result, has_spec = !is.null(spec))
turn <- value_turn(private$provider, result, has_type = !is.null(type))

# Ensure turns always end in a newline
if (any_text) {
Expand All @@ -345,7 +345,7 @@ Chat <- R6::R6Class("Chat",
cat_line(formatted, prefix = "< ")
}
} else {
turn <- value_turn(private$provider, response, has_spec = !is.null(spec))
turn <- value_turn(private$provider, response, has_type = !is.null(type))
text <- turn@text
if (!is.null(text)) {
text <- paste0(text, "\n")
Expand All @@ -364,13 +364,13 @@ Chat <- R6::R6Class("Chat",

# If stream = TRUE, yields completion deltas. If stream = FALSE, yields
# complete assistant turns.
submit_turns_async = async_generator_method(function(self, private, user_turn, stream, echo, spec = NULL) {
submit_turns_async = async_generator_method(function(self, private, user_turn, stream, echo, type = NULL) {
response <- chat_perform(
provider = private$provider,
mode = if (stream) "async-stream" else "async-value",
turns = c(private$.turns, list(user_turn)),
tools = private$tools,
spec = spec
type = type
)
emit <- emitter(echo)

Expand All @@ -387,7 +387,7 @@ Chat <- R6::R6Class("Chat",

result <- stream_merge_chunks(private$provider, result, chunk)
}
turn <- value_turn(private$provider, result, has_spec = !is.null(spec))
turn <- value_turn(private$provider, result, has_type = !is.null(type))

# Ensure turns always end in a newline
if (any_text) {
Expand All @@ -397,7 +397,7 @@ Chat <- R6::R6Class("Chat",
} else {
result <- await(response)

turn <- value_turn(private$provider, result, has_spec = !is.null(spec))
turn <- value_turn(private$provider, result, has_type = !is.null(type))
text <- turn@text
if (!is.null(text)) {
text <- paste0(text, "\n")
Expand Down
4 changes: 2 additions & 2 deletions R/httr2.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ chat_perform <- function(provider,
mode = c("value", "stream", "async-stream", "async-value"),
turns,
tools = list(),
spec = NULL,
type = NULL,
extra_args = list()) {

mode <- arg_match(mode)
Expand All @@ -16,7 +16,7 @@ chat_perform <- function(provider,
turns = turns,
tools = tools,
stream = stream,
spec = spec,
type = type,
extra_args = extra_args
)

Expand Down
6 changes: 3 additions & 3 deletions R/provider-azure.R
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ method(chat_request, ProviderAzure) <- function(provider,
stream = TRUE,
turns = list(),
tools = list(),
spec = NULL,
type = NULL,
extra_args = list()) {

req <- request(provider@base_url)
Expand All @@ -95,12 +95,12 @@ method(chat_request, ProviderAzure) <- function(provider,
tools <- as_json(provider, unname(tools))
extra_args <- utils::modifyList(provider@extra_args, extra_args)

if (!is.null(spec)) {
if (!is.null(type)) {
response_format <- list(
type = "json_schema",
json_schema = list(
name = "structured_data",
schema = as_json(provider, spec),
schema = as_json(provider, type),
strict = TRUE
)
)
Expand Down
10 changes: 5 additions & 5 deletions R/provider-bedrock.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ method(chat_request, ProviderBedrock) <- function(provider,
stream = TRUE,
turns = list(),
tools = list(),
spec = NULL,
type = NULL,
extra_args = list()) {

req <- request(paste0(
Expand Down Expand Up @@ -90,12 +90,12 @@ method(chat_request, ProviderBedrock) <- function(provider,

messages <- compact(as_json(provider, turns))

if (!is.null(spec)) {
if (!is.null(type)) {
tool_def <- ToolDef(
fun = function(...) {},
name = "structured_tool_call__",
description = "Extract structured data",
arguments = type_object(data = spec)
arguments = type_object(data = type)
)
tools[[tool_def@name]] <- tool_def
tool_choice <- list(tool = list(name = tool_def@name))
Expand Down Expand Up @@ -188,12 +188,12 @@ method(stream_merge_chunks, ProviderBedrock) <- function(provider, result, chunk
result
}

method(value_turn, ProviderBedrock) <- function(provider, result, has_spec = FALSE) {
method(value_turn, ProviderBedrock) <- function(provider, result, has_type = FALSE) {
contents <- lapply(result$output$message$content, function(content) {
if (has_name(content, "text")) {
ContentText(content$text)
} else if (has_name(content, "toolUse")) {
if (has_spec) {
if (has_type) {
ContentJson(content$toolUse$input$data)
} else {
ContentToolRequest(
Expand Down
10 changes: 5 additions & 5 deletions R/provider-claude.R
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ method(chat_request, ProviderClaude) <- function(provider,
stream = TRUE,
turns = list(),
tools = list(),
spec = NULL,
type = NULL,
extra_args = list()) {

req <- request(provider@base_url)
Expand Down Expand Up @@ -108,12 +108,12 @@ method(chat_request, ProviderClaude) <- function(provider,

messages <- compact(as_json(provider, turns))

if (!is.null(spec)) {
if (!is.null(type)) {
tool_def <- ToolDef(
fun = function(...) {},
name = "_structured_tool_call",
description = "Extract structured data",
arguments = type_object(data = spec)
arguments = type_object(data = type)
)
tools[[tool_def@name]] <- tool_def
tool_choice <- list(type = "tool", name = tool_def@name)
Expand Down Expand Up @@ -187,12 +187,12 @@ method(stream_merge_chunks, ProviderClaude) <- function(provider, result, chunk)
result
}

method(value_turn, ProviderClaude) <- function(provider, result, has_spec = FALSE) {
method(value_turn, ProviderClaude) <- function(provider, result, has_type = FALSE) {
contents <- lapply(result$content, function(content) {
if (content$type == "text") {
ContentText(content$text)
} else if (content$type == "tool_use") {
if (has_spec) {
if (has_type) {
ContentJson(content$input$data)
} else {
if (is_string(content$input)) {
Expand Down
6 changes: 3 additions & 3 deletions R/provider-cortex.R
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,12 @@ method(chat_request, ProviderCortex) <- function(provider,
stream = TRUE,
turns = list(),
tools = list(),
spec = NULL,
type = NULL,
extra_args = list()) {
if (length(tools) != 0) {
cli::cli_abort("Tools are not supported by Cortex.")
}
if (!is.null(spec) != 0) {
if (!is.null(type) != 0) {
cli::cli_abort("Structured data extraction is not supported by Cortex.")
}

Expand Down Expand Up @@ -234,7 +234,7 @@ cortex_chunk_to_message <- function(x) {
}
}

method(value_turn, ProviderCortex) <- function(provider, result, has_spec = FALSE) {
method(value_turn, ProviderCortex) <- function(provider, result, has_type = FALSE) {
if (!is_named(result)) { # streaming
role <- "assistant"
content <- result
Expand Down
6 changes: 3 additions & 3 deletions R/provider-databricks.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ method(chat_request, ProviderDatabricks) <- function(provider,
stream = TRUE,
turns = list(),
tools = list(),
spec = NULL,
type = NULL,
extra_args = list()) {
req <- request(provider@base_url)
# Note: this API endpoint is undocumented and seems to exist primarily for
Expand All @@ -95,12 +95,12 @@ method(chat_request, ProviderDatabricks) <- function(provider,
tools <- as_json(provider, unname(tools))
extra_args <- utils::modifyList(provider@extra_args, extra_args)

if (!is.null(spec)) {
if (!is.null(type)) {
response_format <- list(
type = "json_schema",
json_schema = list(
name = "structured_data",
schema = as_json(provider, spec),
schema = as_json(provider, type),
strict = TRUE
)
)
Expand Down
10 changes: 5 additions & 5 deletions R/provider-gemini.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ method(chat_request, ProviderGemini) <- function(provider,
stream = TRUE,
turns = list(),
tools = list(),
spec = NULL,
type = NULL,
extra_args = list()) {


Expand Down Expand Up @@ -83,10 +83,10 @@ method(chat_request, ProviderGemini) <- function(provider,
system <- list(parts = list(text = ""))
}

if (!is.null(spec)) {
if (!is.null(type)) {
generation_config <- list(
response_mime_type = "application/json",
response_schema = as_json(provider, spec)
response_schema = as_json(provider, type)
)
} else {
generation_config <- NULL
Expand Down Expand Up @@ -134,12 +134,12 @@ method(stream_merge_chunks, ProviderGemini) <- function(provider, result, chunk)
merge_dicts(result, chunk)
}
}
method(value_turn, ProviderGemini) <- function(provider, result, has_spec = FALSE) {
method(value_turn, ProviderGemini) <- function(provider, result, has_type = FALSE) {
message <- result$candidates[[1]]$content

contents <- lapply(message$parts, function(content) {
if (has_name(content, "text")) {
if (has_spec) {
if (has_type) {
data <- jsonlite::parse_json(content$text)
ContentJson(data)
} else {
Expand Down
10 changes: 5 additions & 5 deletions R/provider-openai.R
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ method(chat_request, ProviderOpenAI) <- function(provider,
stream = TRUE,
turns = list(),
tools = list(),
spec = NULL,
type = NULL,
extra_args = list()) {

req <- request(provider@base_url)
Expand All @@ -115,12 +115,12 @@ method(chat_request, ProviderOpenAI) <- function(provider,
tools <- as_json(provider, unname(tools))
extra_args <- utils::modifyList(provider@extra_args, extra_args)

if (!is.null(spec)) {
if (!is.null(type)) {
response_format <- list(
type = "json_schema",
json_schema = list(
name = "structured_data",
schema = as_json(provider, spec),
schema = as_json(provider, type),
strict = TRUE
)
)
Expand Down Expand Up @@ -171,14 +171,14 @@ method(stream_merge_chunks, ProviderOpenAI) <- function(provider, result, chunk)
merge_dicts(result, chunk)
}
}
method(value_turn, ProviderOpenAI) <- function(provider, result, has_spec = FALSE) {
method(value_turn, ProviderOpenAI) <- function(provider, result, has_type = FALSE) {
if (has_name(result$choices[[1]], "delta")) { # streaming
message <- result$choices[[1]]$delta
} else {
message <- result$choices[[1]]$message
}

if (has_spec) {
if (has_type) {
json <- jsonlite::parse_json(message$content[[1]])
content <- list(ContentJson(json))
} else {
Expand Down
2 changes: 1 addition & 1 deletion R/provider.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ Provider <- new_class(
# Create a request------------------------------------

chat_request <- new_generic("chat_request", "provider",
function(provider, stream = TRUE, turns = list(), tools = list(), spec = NULL, extra_args = list()) {
function(provider, stream = TRUE, turns = list(), tools = list(), type = NULL, extra_args = list()) {
S7_dispatch()
}
)
Expand Down
Loading

0 comments on commit d92f5f3

Please sign in to comment.