From d586a26b8f8e42f64477c161c1ef1e15d6e6a4f7 Mon Sep 17 00:00:00 2001 From: Stainless Bot Date: Wed, 31 Jul 2024 04:08:26 +0000 Subject: [PATCH] feat: publish --- .devcontainer/Dockerfile | 23 + .devcontainer/devcontainer.json | 20 + .github/workflows/ci.yml | 41 + .github/workflows/create-releases.yml | 29 + .../handle-release-pr-title-edit.yml | 25 + .gitignore | 4 + .release-please-manifest.json | 3 + .stats.yml | 2 + CONTRIBUTING.md | 59 + LICENSE | 8 + README.md | 366 ++++ SECURITY.md | 27 + aliases.go | 9 + api.md | 34 + client.go | 112 ++ client_test.go | 211 ++ completion.go | 213 ++ completion_test.go | 47 + examples/.keep | 4 + field.go | 50 + go.mod | 11 + go.sum | 12 + internal/apierror/apierror.go | 53 + internal/apiform/encoder.go | 381 ++++ internal/apiform/form.go | 5 + internal/apiform/form_test.go | 440 ++++ internal/apiform/tag.go | 48 + internal/apijson/decoder.go | 668 ++++++ internal/apijson/encoder.go | 391 ++++ internal/apijson/field.go | 41 + internal/apijson/field_test.go | 66 + internal/apijson/json_test.go | 554 +++++ internal/apijson/port.go | 107 + internal/apijson/port_test.go | 178 ++ internal/apijson/registry.go | 27 + internal/apijson/tag.go | 47 + internal/apiquery/encoder.go | 341 ++++ internal/apiquery/query.go | 50 + internal/apiquery/query_test.go | 335 +++ internal/apiquery/tag.go | 41 + internal/param/field.go | 29 + internal/requestconfig/requestconfig.go | 486 +++++ internal/testutil/testutil.go | 27 + internal/version.go | 5 + lib/.keep | 4 + message.go | 1788 +++++++++++++++++ message_test.go | 106 + option/requestoption.go | 245 +++ packages/ssestream/streaming.go | 172 ++ release-please-config.json | 70 + scripts/bootstrap | 16 + scripts/format | 8 + scripts/lint | 8 + scripts/mock | 41 + scripts/test | 56 + shared/union.go | 8 + usage_test.go | 39 + 57 files changed, 8191 insertions(+) create mode 100644 .devcontainer/Dockerfile create mode 100644 .devcontainer/devcontainer.json create mode 100644 .github/workflows/ci.yml create mode 100644 .github/workflows/create-releases.yml create mode 100644 .github/workflows/handle-release-pr-title-edit.yml create mode 100644 .gitignore create mode 100644 .release-please-manifest.json create mode 100644 .stats.yml create mode 100644 CONTRIBUTING.md create mode 100644 LICENSE create mode 100644 README.md create mode 100644 SECURITY.md create mode 100644 aliases.go create mode 100644 api.md create mode 100644 client.go create mode 100644 client_test.go create mode 100644 completion.go create mode 100644 completion_test.go create mode 100644 examples/.keep create mode 100644 field.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 internal/apierror/apierror.go create mode 100644 internal/apiform/encoder.go create mode 100644 internal/apiform/form.go create mode 100644 internal/apiform/form_test.go create mode 100644 internal/apiform/tag.go create mode 100644 internal/apijson/decoder.go create mode 100644 internal/apijson/encoder.go create mode 100644 internal/apijson/field.go create mode 100644 internal/apijson/field_test.go create mode 100644 internal/apijson/json_test.go create mode 100644 internal/apijson/port.go create mode 100644 internal/apijson/port_test.go create mode 100644 internal/apijson/registry.go create mode 100644 internal/apijson/tag.go create mode 100644 internal/apiquery/encoder.go create mode 100644 internal/apiquery/query.go create mode 100644 internal/apiquery/query_test.go create mode 100644 internal/apiquery/tag.go create mode 100644 internal/param/field.go 
create mode 100644 internal/requestconfig/requestconfig.go create mode 100644 internal/testutil/testutil.go create mode 100644 internal/version.go create mode 100644 lib/.keep create mode 100644 message.go create mode 100644 message_test.go create mode 100644 option/requestoption.go create mode 100644 packages/ssestream/streaming.go create mode 100644 release-please-config.json create mode 100755 scripts/bootstrap create mode 100755 scripts/format create mode 100755 scripts/lint create mode 100755 scripts/mock create mode 100755 scripts/test create mode 100644 shared/union.go create mode 100644 usage_test.go diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile new file mode 100644 index 0000000..1aa883d --- /dev/null +++ b/.devcontainer/Dockerfile @@ -0,0 +1,23 @@ +# syntax=docker/dockerfile:1 +FROM debian:bookworm-slim + +RUN apt-get update && apt-get install -y \ + libxkbcommon0 \ + ca-certificates \ + git \ + golang \ + unzip \ + libc++1 \ + vim \ + && apt-get clean autoclean + +# Ensure UTF-8 encoding +ENV LANG=C.UTF-8 +ENV LC_ALL=C.UTF-8 + +ENV GOPATH=/go +ENV PATH=$GOPATH/bin:$PATH + +WORKDIR /workspace + +COPY . /workspace diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000..d55fc4d --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,20 @@ +// For format details, see https://aka.ms/devcontainer.json. For config options, see the +// README at: https://github.com/devcontainers/templates/tree/main/src/debian +{ + "name": "Debian", + "build": { + "dockerfile": "Dockerfile" + } + + // Features to add to the dev container. More info: https://containers.dev/features. + // "features": {}, + + // Use 'forwardPorts' to make a list of ports inside the container available locally. + // "forwardPorts": [], + + // Configure tool-specific properties. + // "customizations": {}, + + // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. 
+ // "remoteUser": "root" +} diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..5725bbd --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,41 @@ +name: CI +on: + push: + branches: + - main + pull_request: + branches: + - main + - next + +jobs: + lint: + name: lint + runs-on: ubuntu-latest + if: github.repository == 'anthropics/anthropic-sdk-go' + + steps: + - uses: actions/checkout@v4 + + - name: Setup go + uses: actions/setup-go@v5 + + - name: Run lints + run: ./scripts/lint + test: + name: test + runs-on: ubuntu-latest + if: github.repository == 'anthropics/anthropic-sdk-go' + + steps: + - uses: actions/checkout@v4 + + - name: Setup go + uses: actions/setup-go@v5 + + - name: Bootstrap + run: ./scripts/bootstrap + + - name: Run tests + run: ./scripts/test + diff --git a/.github/workflows/create-releases.yml b/.github/workflows/create-releases.yml new file mode 100644 index 0000000..1ada6e1 --- /dev/null +++ b/.github/workflows/create-releases.yml @@ -0,0 +1,29 @@ +name: Create releases +on: + schedule: + - cron: '0 5 * * *' # every day at 5am UTC + push: + branches: + - main + +jobs: + release: + name: release + if: github.ref == 'refs/heads/main' && github.repository == 'anthropics/anthropic-sdk-go' + runs-on: ubuntu-latest + environment: production-release + + steps: + - uses: actions/checkout@v4 + + - uses: stainless-api/trigger-release-please@v1 + id: release + with: + repo: ${{ github.event.repository.full_name }} + stainless-api-key: ${{ secrets.STAINLESS_API_KEY }} + + - name: Generate godocs + if: ${{ steps.release.outputs.releases_created }} + run: | + version=$(jq -r '. | to_entries[0] | .value' .release-please-manifest.json) + curl -X POST https://pkg.go.dev/fetch/github.com/anthropics/anthropic-sdk-go@v${version} diff --git a/.github/workflows/handle-release-pr-title-edit.yml b/.github/workflows/handle-release-pr-title-edit.yml new file mode 100644 index 0000000..d3c6cd5 --- /dev/null +++ b/.github/workflows/handle-release-pr-title-edit.yml @@ -0,0 +1,25 @@ +name: Handle release PR title edits +on: + pull_request: + types: + - edited + - unlabeled + +jobs: + update_pr_content: + name: Update pull request content + if: | + ((github.event.action == 'edited' && github.event.changes.title.from != github.event.pull_request.title) || + (github.event.action == 'unlabeled' && github.event.label.name == 'autorelease: custom version')) && + startsWith(github.event.pull_request.head.ref, 'release-please--') && + github.event.pull_request.state == 'open' && + github.event.sender.login != 'stainless-bot' && + github.event.sender.login != 'stainless-app' && + github.repository == 'anthropics/anthropic-sdk-go' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: stainless-api/trigger-release-please@v1 + with: + repo: ${{ github.event.repository.full_name }} + stainless-api-key: ${{ secrets.STAINLESS_API_KEY }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c6d0501 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +.prism.log +codegen.log +Brewfile.lock.json +.idea/ diff --git a/.release-please-manifest.json b/.release-please-manifest.json new file mode 100644 index 0000000..c476280 --- /dev/null +++ b/.release-please-manifest.json @@ -0,0 +1,3 @@ +{ + ".": "0.0.1-alpha.0" +} \ No newline at end of file diff --git a/.stats.yml b/.stats.yml new file mode 100644 index 0000000..42a8d19 --- /dev/null +++ b/.stats.yml @@ -0,0 +1,2 @@ +configured_endpoints: 2 +openapi_spec_url: 
https://storage.googleapis.com/stainless-sdk-openapi-specs/anthropic-15089862b682046b13deff5bf5f09d786d9ec4aecc40b9f7ef40b84ef17d3348.yml diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..2214ed6 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,59 @@ +## Setting up the environment + +### Install Go 1.18+ + +Install go by following relevant directions [here](https://go.dev/doc/install). + +## Modifying/Adding code + +Most of the SDK is generated code, and any modified code will be overridden on the next generation. The +`examples/` directory is an exception and will never be overridden. + +## Adding and running examples + +All files in the `examples/` directory are not modified by the Stainless generator and can be freely edited or +added to. + +```bash +# add an example to examples//main.go + +package main + +func main() { + // ... +} +``` + +```bash +go run ./examples/ +``` + +## Using the repository from source + +To use a local version of this library from source in another project, edit the `go.mod` with a replace +directive. This can be done through the CLI with the following: + +```bash +go mod edit -replace github.com/anthropics/anthropic-sdk-go=/path/to/anthropic-sdk-go +``` + +## Running tests + +Most tests require you to [set up a mock server](https://github.com/stoplightio/prism) against the OpenAPI spec to run the tests. + +```bash +# you will need npm installed +npx prism mock path/to/your/openapi.yml +``` + +```bash +go test ./... +``` + +## Formatting + +This library uses the standard gofmt code formatter: + +```bash +gofmt -s -w . +``` diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..ac71a66 --- /dev/null +++ b/LICENSE @@ -0,0 +1,8 @@ +Copyright 2023 Anthropic, PBC. + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + diff --git a/README.md b/README.md new file mode 100644 index 0000000..6bef330 --- /dev/null +++ b/README.md @@ -0,0 +1,366 @@ +# Anthropic Go API Library + +Go Reference + +The Anthropic Go library provides convenient access to [the Anthropic REST +API](https://docs.anthropic.com/claude/reference/) from applications written in Go. The full API of this library can be found in [api.md](api.md). + +## Installation + + + +```go +import ( + "github.com/anthropics/anthropic-sdk-go" // imported as anthropic +) +``` + + + +Or to pin the version: + + + +```sh +go get -u 'github.com/anthropics/anthropic-sdk-go@v0.0.1-alpha.0' +``` + + + +## Requirements + +This library requires Go 1.18+. 
+
+## Usage
+
+The full API of this library can be found in [api.md](api.md).
+
+```go
+package main
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/anthropics/anthropic-sdk-go"
+ "github.com/anthropics/anthropic-sdk-go/option"
+)
+
+func main() {
+ client := anthropic.NewClient(
+ option.WithAPIKey("my-anthropic-api-key"), // defaults to os.LookupEnv("ANTHROPIC_API_KEY")
+ )
+ message, err := client.Messages.New(context.TODO(), anthropic.MessageNewParams{
+ MaxTokens: anthropic.F(int64(1024)),
+ Messages: anthropic.F([]anthropic.MessageParam{{
+ Role: anthropic.F(anthropic.MessageParamRoleUser),
+ Content: anthropic.F([]anthropic.MessageParamContentUnion{anthropic.TextBlockParam{Type: anthropic.F(anthropic.TextBlockParamTypeText), Text: anthropic.F("What is a quaternion?")}}),
+ }}),
+ Model: anthropic.F(anthropic.ModelClaude_3_5_Sonnet_20240620),
+ })
+ if err != nil {
+ panic(err.Error())
+ }
+ fmt.Printf("%+v\n", message.Content)
+}
+
+```
+
+### Request fields
+
+All request parameters are wrapped in a generic `Field` type,
+which we use to distinguish zero values from null or omitted fields.
+
+This prevents accidentally sending a zero value if you forget a required parameter,
+and enables explicitly sending `null`, `false`, `''`, or `0` on optional parameters.
+Any field not specified is not sent.
+
+To construct fields with values, use the helpers `String()`, `Int()`, `Float()`, or most commonly, the generic `F[T]()`.
+To send a null, use `Null[T]()`, and to send a nonconforming value, use `Raw[T](any)`. For example:
+
+```go
+params := FooParams{
+ Name: anthropic.F("hello"),
+
+ // Explicitly send `"description": null`
+ Description: anthropic.Null[string](),
+
+ Point: anthropic.F(anthropic.Point{
+ X: anthropic.Int(0),
+ Y: anthropic.Int(1),
+
+ // In cases where the API specifies a given type,
+ // but you want to send something else, use `Raw`:
+ Z: anthropic.Raw[int64](0.01), // sends a float
+ }),
+}
+```
+
+### Response objects
+
+All fields in response structs are value types (not pointers or wrappers).
+
+If a given field is `null`, not present, or invalid, the corresponding field
+will simply be its zero value.
+
+All response structs also include a special `JSON` field, containing more detailed
+information about each property, which you can use like so:
+
+```go
+if res.Name == "" {
+ // true if `"name"` is either not present or explicitly null
+ res.JSON.Name.IsNull()
+
+ // true if the `"name"` key was not present in the response JSON at all
+ res.JSON.Name.IsMissing()
+
+ // When the API returns data that cannot be coerced to the expected type:
+ if res.JSON.Name.IsInvalid() {
+ raw := res.JSON.Name.Raw()
+
+ legacyName := struct{
+ First string `json:"first"`
+ Last string `json:"last"`
+ }{}
+ json.Unmarshal([]byte(raw), &legacyName)
+ name = legacyName.First + " " + legacyName.Last
+ }
+}
+```
+
+These `.JSON` structs also include an `Extras` map containing
+any properties in the JSON response that were not specified
+in the struct. This can be useful for API features not yet
+present in the SDK.
+
+```go
+body := res.JSON.ExtraFields["my_unexpected_field"].Raw()
+```
+
+### RequestOptions
+
+This library uses the functional options pattern. Functions defined in the
+`option` package return a `RequestOption`, which is a closure that mutates a
+`RequestConfig`. These options can be supplied to the client or at individual
+requests.
For example:
+
+```go
+client := anthropic.NewClient(
+ // Adds a header to every request made by the client
+ option.WithHeader("X-Some-Header", "custom_header_info"),
+)
+
+client.Messages.New(context.TODO(), ...,
+ // Override the header
+ option.WithHeader("X-Some-Header", "some_other_custom_header_info"),
+ // Add an undocumented field to the request body, using sjson syntax
+ option.WithJSONSet("some.json.path", map[string]string{"my": "object"}),
+)
+```
+
+See the [full list of request options](https://pkg.go.dev/github.com/anthropics/anthropic-sdk-go/option).
+
+### Pagination
+
+This library provides some conveniences for working with paginated list endpoints.
+
+You can use `.ListAutoPaging()` methods to iterate through items across all pages.
+
+Alternatively, you can use simple `.List()` methods to fetch a single page and receive a standard response object
+with additional helper methods like `.GetNextPage()`.
+
+### Errors
+
+When the API returns a non-success status code, we return an error with type
+`*anthropic.Error`. This contains the `StatusCode`, `*http.Request`, and
+`*http.Response` values of the request, as well as the JSON of the error body
+(much like other response objects in the SDK).
+
+To handle errors, we recommend that you use the `errors.As` pattern:
+
+```go
+_, err := client.Messages.New(context.TODO(), anthropic.MessageNewParams{
+ MaxTokens: anthropic.F(int64(1024)),
+ Messages: anthropic.F([]anthropic.MessageParam{{
+ Role: anthropic.F(anthropic.MessageParamRoleUser),
+ Content: anthropic.F([]anthropic.MessageParamContentUnion{anthropic.TextBlockParam{Type: anthropic.F(anthropic.TextBlockParamTypeText), Text: anthropic.F("What is a quaternion?")}}),
+ }}),
+ Model: anthropic.F(anthropic.ModelClaude_3_5_Sonnet_20240620),
+})
+if err != nil {
+ var apierr *anthropic.Error
+ if errors.As(err, &apierr) {
+ println(string(apierr.DumpRequest(true))) // Prints the serialized HTTP request
+ println(string(apierr.DumpResponse(true))) // Prints the serialized HTTP response
+ }
+ panic(err.Error()) // POST "/v1/messages": 400 Bad Request { ... }
+}
+```
+
+When other errors occur, they are returned unwrapped; for example,
+if HTTP transport fails, you might receive `*url.Error` wrapping `*net.OpError`.
+
+### Timeouts
+
+Requests do not time out by default; use context to configure a timeout for a request lifecycle.
+
+Note that if a request is [retried](#retries), the context timeout does not start over.
+To set a per-retry timeout, use `option.WithRequestTimeout()`.
+
+```go
+// This sets the timeout for the request, including all the retries.
+ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
+defer cancel()
+client.Messages.New(
+ ctx,
+ anthropic.MessageNewParams{
+ MaxTokens: anthropic.F(int64(1024)),
+ Messages: anthropic.F([]anthropic.MessageParam{{
+ Role: anthropic.F(anthropic.MessageParamRoleUser),
+ Content: anthropic.F([]anthropic.MessageParamContentUnion{anthropic.TextBlockParam{Type: anthropic.F(anthropic.TextBlockParamTypeText), Text: anthropic.F("What is a quaternion?")}}),
+ }}),
+ Model: anthropic.F(anthropic.ModelClaude_3_5_Sonnet_20240620),
+ },
+ // This sets the per-retry timeout
+ option.WithRequestTimeout(20*time.Second),
+)
+```
+
+### File uploads
+
+Request parameters that correspond to file uploads in multipart requests are typed as
+`param.Field[io.Reader]`.
The contents of the `io.Reader` will by default be sent as a multipart form +part with the file name of "anonymous_file" and content-type of "application/octet-stream". + +The file name and content-type can be customized by implementing `Name() string` or `ContentType() +string` on the run-time type of `io.Reader`. Note that `os.File` implements `Name() string`, so a +file returned by `os.Open` will be sent with the file name on disk. + +We also provide a helper `anthropic.FileParam(reader io.Reader, filename string, contentType string)` +which can be used to wrap any `io.Reader` with the appropriate file name and content type. + +### Retries + +Certain errors will be automatically retried 2 times by default, with a short exponential backoff. +We retry by default all connection errors, 408 Request Timeout, 409 Conflict, 429 Rate Limit, +and >=500 Internal errors. + +You can use the `WithMaxRetries` option to configure or disable this: + +```go +// Configure the default for all requests: +client := anthropic.NewClient( + option.WithMaxRetries(0), // default is 2 +) + +// Override per-request: +client.Messages.New( + context.TODO(), + anthropic.MessageNewParams{ + MaxTokens: anthropic.F(int64(1024)), + Messages: anthropic.F([]anthropic.MessageParam{{ + Role: anthropic.F(anthropic.MessageParamRoleUser), + Content: anthropic.F([]anthropic.MessageParamContentUnion{anthropic.TextBlockParam{Type: anthropic.F(anthropic.TextBlockParamTypeText), Text: anthropic.F("What is a quaternion?")}}), + }}), + Model: anthropic.F(anthropic.ModelClaude_3_5_Sonnet_20240620), + }, + option.WithMaxRetries(5), +) +``` + +### Making custom/undocumented requests + +This library is typed for convenient access to the documented API. If you need to access undocumented +endpoints, params, or response properties, the library can still be used. + +#### Undocumented endpoints + +To make requests to undocumented endpoints, you can use `client.Get`, `client.Post`, and other HTTP verbs. +`RequestOptions` on the client, such as retries, will be respected when making these requests. + +```go +var ( + // params can be an io.Reader, a []byte, an encoding/json serializable object, + // or a "…Params" struct defined in this library. + params map[string]interface{} + + // result can be an []byte, *http.Response, a encoding/json deserializable object, + // or a model defined in this library. + result *http.Response +) +err := client.Post(context.Background(), "/unspecified", params, &result) +if err != nil { + … +} +``` + +#### Undocumented request params + +To make requests using undocumented parameters, you may use either the `option.WithQuerySet()` +or the `option.WithJSONSet()` methods. + +```go +params := FooNewParams{ + ID: anthropic.F("id_xxxx"), + Data: anthropic.F(FooNewParamsData{ + FirstName: anthropic.F("John"), + }), +} +client.Foo.New(context.Background(), params, option.WithJSONSet("data.last_name", "Doe")) +``` + +#### Undocumented response properties + +To access undocumented response properties, you may either access the raw JSON of the response as a string +with `result.JSON.RawJSON()`, or get the raw JSON of a particular field on the result with +`result.JSON.Foo.Raw()`. + +Any fields that are not present on the response struct will be saved and can be accessed by `result.JSON.ExtraFields()` which returns the extra fields as a `map[string]Field`. + +### Middleware + +We provide `option.WithMiddleware` which applies the given +middleware to requests. 
+
+```go
+func Logger(req *http.Request, next option.MiddlewareNext) (res *http.Response, err error) {
+ // Before the request
+ start := time.Now()
+ LogReq(req)
+
+ // Forward the request to the next handler
+ res, err = next(req)
+
+ // Handle stuff after the request
+ end := time.Now()
+ LogRes(res, err, end.Sub(start))
+
+ return res, err
+}
+
+client := anthropic.NewClient(
+ option.WithMiddleware(Logger),
+)
+```
+
+When multiple middlewares are provided as variadic arguments, the middlewares
+are applied left to right. If `option.WithMiddleware` is given
+multiple times, for example first in the client and then in the method, the
+middleware in the client will run first and the middleware given in the method
+will run next.
+
+You may also replace the default `http.Client` with
+`option.WithHTTPClient(client)`. Only one HTTP client is
+accepted (this overwrites any previous client) and receives requests after any
+middleware has been applied.
+
+## Semantic versioning
+
+This package generally follows [SemVer](https://semver.org/spec/v2.0.0.html) conventions, though certain backwards-incompatible changes may be released as minor versions:
+
+1. Changes to library internals which are technically public but not intended or documented for external use. _(Please open a GitHub issue to let us know if you are relying on such internals)_.
+2. Changes that we do not expect to impact the vast majority of users in practice.
+
+We take backwards-compatibility seriously and work hard to ensure you can rely on a smooth upgrade experience.
+
+We are keen for your feedback; please open an [issue](https://www.github.com/anthropics/anthropic-sdk-go/issues) with questions, bugs, or suggestions. diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000..2281532 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,27 @@
+# Security Policy
+
+## Reporting Security Issues
+
+This SDK is generated by [Stainless Software Inc](http://stainlessapi.com). Stainless takes security seriously, and encourages you to report any security vulnerability promptly so that appropriate action can be taken.
+
+To report a security issue, please contact the Stainless team at security@stainlessapi.com.
+
+## Responsible Disclosure
+
+We appreciate the efforts of security researchers and individuals who help us maintain the security of
+SDKs we generate. If you believe you have found a security vulnerability, please adhere to responsible
+disclosure practices by allowing us a reasonable amount of time to investigate and address the issue
+before making any information public.
+
+## Reporting Non-SDK Related Security Issues
+
+If you encounter security issues that are not directly related to SDKs but pertain to the services
+or products provided by Anthropic, please follow the respective company's security reporting guidelines.
+
+### Anthropic Terms and Policies
+
+Please contact support@anthropic.com for any questions or concerns regarding the security of our services.
+
+---
+
+Thank you for helping us keep the SDKs and systems they interact with secure. diff --git a/aliases.go b/aliases.go new file mode 100644 index 0000000..1439f7d --- /dev/null +++ b/aliases.go @@ -0,0 +1,9 @@
+// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+package anthropic
+
+import (
+ "github.com/anthropics/anthropic-sdk-go/internal/apierror"
+)
+
+type Error = apierror.Error diff --git a/api.md b/api.md new file mode 100644 index 0000000..7a3238a --- /dev/null +++ b/api.md @@ -0,0 +1,34 @@
+# Messages
+
+Params Types:
+
+- anthropic.ImageBlockParam
+- anthropic.MessageParam
+- anthropic.Model
+- anthropic.TextBlockParam
+- anthropic.ToolParam
+- anthropic.ToolResultBlockParam
+- anthropic.ToolUseBlockParam
+
+Response Types:
+
+- anthropic.ContentBlock
+- anthropic.InputJSONDelta
+- anthropic.Message
+- anthropic.MessageDeltaUsage
+- anthropic.Model
+- anthropic.ContentBlockDeltaEvent
+- anthropic.ContentBlockStartEvent
+- anthropic.ContentBlockStopEvent
+- anthropic.MessageDeltaEvent
+- anthropic.MessageStartEvent
+- anthropic.MessageStopEvent
+- anthropic.MessageStreamEvent
+- anthropic.TextBlock
+- anthropic.TextDelta
+- anthropic.ToolUseBlock
+- anthropic.Usage
+
+Methods:
+
+- client.Messages.New(ctx context.Context, body anthropic.MessageNewParams) (anthropic.Message, error) diff --git a/client.go b/client.go new file mode 100644 index 0000000..0cc61ed --- /dev/null +++ b/client.go @@ -0,0 +1,112 @@
+// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+package anthropic
+
+import (
+ "context"
+ "net/http"
+ "os"
+
+ "github.com/anthropics/anthropic-sdk-go/internal/requestconfig"
+ "github.com/anthropics/anthropic-sdk-go/option"
+)
+
+// Client creates a struct with services and top level methods that help with
+// interacting with the anthropic API. You should not instantiate this client
+// directly; use the [NewClient] method instead.
+type Client struct {
+ Options []option.RequestOption
+ Completions *CompletionService
+ Messages *MessageService
+}
+
+// NewClient generates a new client with default options read from the
+// environment (ANTHROPIC_API_KEY, ANTHROPIC_AUTH_TOKEN). The options passed in as
+// arguments are applied after these default arguments, and all options will be
+// passed down to the services and requests that this client makes.
+func NewClient(opts ...option.RequestOption) (r *Client) {
+ defaults := []option.RequestOption{option.WithEnvironmentProduction()}
+ if o, ok := os.LookupEnv("ANTHROPIC_API_KEY"); ok {
+ defaults = append(defaults, option.WithAPIKey(o))
+ }
+ if o, ok := os.LookupEnv("ANTHROPIC_AUTH_TOKEN"); ok {
+ defaults = append(defaults, option.WithAuthToken(o))
+ }
+ opts = append(defaults, opts...)
+
+ r = &Client{Options: opts}
+
+ r.Completions = NewCompletionService(opts...)
+ r.Messages = NewMessageService(opts...)
+
+ return
+}
+
+// Execute makes a request with the given context, method, URL, request params,
+// response, and request options. This is useful for hitting undocumented endpoints
+// while retaining the base URL, auth, retries, and other options from the client.
+//
+// If a byte slice or an [io.Reader] is supplied to params, it will be used as-is
+// for the request body.
+//
+// The params argument is by default serialized into the body using [encoding/json]. If your
+// type implements a MarshalJSON function, it will be used instead to serialize the
+// request. If a URLQuery method is implemented, the returned [url.Values] will be
+// used as query strings to the URL.
+//
+// If your params struct uses [param.Field], you must provide [MarshalJSON],
+// [URLQuery], and/or [MarshalForm] functions. It is undefined behavior to use a
+// struct that uses [param.Field] without specifying how it is serialized.
+// +// Any "…Params" object defined in this library can be used as the request +// argument. Note that 'path' arguments will not be forwarded into the url. +// +// The response body will be deserialized into the res variable, depending on its +// type: +// +// - A pointer to a [*http.Response] is populated by the raw response. +// - A pointer to a byte array will be populated with the contents of the request +// body. +// - A pointer to any other type uses this library's default JSON decoding, which +// respects UnmarshalJSON if it is defined on the type. +// - A nil value will not read the response body. +// +// For even greater flexibility, see [option.WithResponseInto] and +// [option.WithResponseBodyInto]. +func (r *Client) Execute(ctx context.Context, method string, path string, params interface{}, res interface{}, opts ...option.RequestOption) error { + opts = append(r.Options, opts...) + return requestconfig.ExecuteNewRequest(ctx, method, path, params, res, opts...) +} + +// Get makes a GET request with the given URL, params, and optionally deserializes +// to a response. See [Execute] documentation on the params and response. +func (r *Client) Get(ctx context.Context, path string, params interface{}, res interface{}, opts ...option.RequestOption) error { + return r.Execute(ctx, http.MethodGet, path, params, res, opts...) +} + +// Post makes a POST request with the given URL, params, and optionally +// deserializes to a response. See [Execute] documentation on the params and +// response. +func (r *Client) Post(ctx context.Context, path string, params interface{}, res interface{}, opts ...option.RequestOption) error { + return r.Execute(ctx, http.MethodPost, path, params, res, opts...) +} + +// Put makes a PUT request with the given URL, params, and optionally deserializes +// to a response. See [Execute] documentation on the params and response. +func (r *Client) Put(ctx context.Context, path string, params interface{}, res interface{}, opts ...option.RequestOption) error { + return r.Execute(ctx, http.MethodPut, path, params, res, opts...) +} + +// Patch makes a PATCH request with the given URL, params, and optionally +// deserializes to a response. See [Execute] documentation on the params and +// response. +func (r *Client) Patch(ctx context.Context, path string, params interface{}, res interface{}, opts ...option.RequestOption) error { + return r.Execute(ctx, http.MethodPatch, path, params, res, opts...) +} + +// Delete makes a DELETE request with the given URL, params, and optionally +// deserializes to a response. See [Execute] documentation on the params and +// response. +func (r *Client) Delete(ctx context.Context, path string, params interface{}, res interface{}, opts ...option.RequestOption) error { + return r.Execute(ctx, http.MethodDelete, path, params, res, opts...) +} diff --git a/client_test.go b/client_test.go new file mode 100644 index 0000000..c4c85e0 --- /dev/null +++ b/client_test.go @@ -0,0 +1,211 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
+ +package anthropic_test + +import ( + "context" + "fmt" + "net/http" + "testing" + "time" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/internal" + "github.com/anthropics/anthropic-sdk-go/option" +) + +type closureTransport struct { + fn func(req *http.Request) (*http.Response, error) +} + +func (t *closureTransport) RoundTrip(req *http.Request) (*http.Response, error) { + return t.fn(req) +} + +func TestUserAgentHeader(t *testing.T) { + var userAgent string + client := anthropic.NewClient( + option.WithHTTPClient(&http.Client{ + Transport: &closureTransport{ + fn: func(req *http.Request) (*http.Response, error) { + userAgent = req.Header.Get("User-Agent") + return &http.Response{ + StatusCode: http.StatusOK, + }, nil + }, + }, + }), + ) + client.Messages.New(context.Background(), anthropic.MessageNewParams{ + MaxTokens: anthropic.F(int64(1024)), + Messages: anthropic.F([]anthropic.MessageParam{{ + Role: anthropic.F(anthropic.MessageParamRoleUser), + Content: anthropic.F([]anthropic.MessageParamContentUnion{anthropic.TextBlockParam{Type: anthropic.F(anthropic.TextBlockParamTypeText), Text: anthropic.F("What is a quaternion?")}}), + }}), + Model: anthropic.F(anthropic.ModelClaude_3_5_Sonnet_20240620), + }) + if userAgent != fmt.Sprintf("Anthropic/Go %s", internal.PackageVersion) { + t.Errorf("Expected User-Agent to be correct, but got: %#v", userAgent) + } +} + +func TestRetryAfter(t *testing.T) { + attempts := 0 + client := anthropic.NewClient( + option.WithHTTPClient(&http.Client{ + Transport: &closureTransport{ + fn: func(req *http.Request) (*http.Response, error) { + attempts++ + return &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{ + http.CanonicalHeaderKey("Retry-After"): []string{"0.1"}, + }, + }, nil + }, + }, + }), + ) + res, err := client.Messages.New(context.Background(), anthropic.MessageNewParams{ + MaxTokens: anthropic.F(int64(1024)), + Messages: anthropic.F([]anthropic.MessageParam{{ + Role: anthropic.F(anthropic.MessageParamRoleUser), + Content: anthropic.F([]anthropic.MessageParamContentUnion{anthropic.TextBlockParam{Type: anthropic.F(anthropic.TextBlockParamTypeText), Text: anthropic.F("What is a quaternion?")}}), + }}), + Model: anthropic.F(anthropic.ModelClaude_3_5_Sonnet_20240620), + }) + if err == nil || res != nil { + t.Error("Expected there to be a cancel error and for the response to be nil") + } + if want := 3; attempts != want { + t.Errorf("Expected %d attempts, got %d", want, attempts) + } +} + +func TestRetryAfterMs(t *testing.T) { + attempts := 0 + client := anthropic.NewClient( + option.WithHTTPClient(&http.Client{ + Transport: &closureTransport{ + fn: func(req *http.Request) (*http.Response, error) { + attempts++ + return &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{ + http.CanonicalHeaderKey("Retry-After-Ms"): []string{"100"}, + }, + }, nil + }, + }, + }), + ) + res, err := client.Messages.New(context.Background(), anthropic.MessageNewParams{ + MaxTokens: anthropic.F(int64(1024)), + Messages: anthropic.F([]anthropic.MessageParam{{ + Role: anthropic.F(anthropic.MessageParamRoleUser), + Content: anthropic.F([]anthropic.MessageParamContentUnion{anthropic.TextBlockParam{Type: anthropic.F(anthropic.TextBlockParamTypeText), Text: anthropic.F("What is a quaternion?")}}), + }}), + Model: anthropic.F(anthropic.ModelClaude_3_5_Sonnet_20240620), + }) + if err == nil || res != nil { + t.Error("Expected there to be a cancel error and for the response to be 
nil") + } + if want := 3; attempts != want { + t.Errorf("Expected %d attempts, got %d", want, attempts) + } +} + +func TestContextCancel(t *testing.T) { + client := anthropic.NewClient( + option.WithHTTPClient(&http.Client{ + Transport: &closureTransport{ + fn: func(req *http.Request) (*http.Response, error) { + <-req.Context().Done() + return nil, req.Context().Err() + }, + }, + }), + ) + cancelCtx, cancel := context.WithCancel(context.Background()) + cancel() + res, err := client.Messages.New(cancelCtx, anthropic.MessageNewParams{ + MaxTokens: anthropic.F(int64(1024)), + Messages: anthropic.F([]anthropic.MessageParam{{ + Role: anthropic.F(anthropic.MessageParamRoleUser), + Content: anthropic.F([]anthropic.MessageParamContentUnion{anthropic.TextBlockParam{Type: anthropic.F(anthropic.TextBlockParamTypeText), Text: anthropic.F("What is a quaternion?")}}), + }}), + Model: anthropic.F(anthropic.ModelClaude_3_5_Sonnet_20240620), + }) + if err == nil || res != nil { + t.Error("Expected there to be a cancel error and for the response to be nil") + } +} + +func TestContextCancelDelay(t *testing.T) { + client := anthropic.NewClient( + option.WithHTTPClient(&http.Client{ + Transport: &closureTransport{ + fn: func(req *http.Request) (*http.Response, error) { + <-req.Context().Done() + return nil, req.Context().Err() + }, + }, + }), + ) + cancelCtx, cancel := context.WithTimeout(context.Background(), 2*time.Millisecond) + defer cancel() + res, err := client.Messages.New(cancelCtx, anthropic.MessageNewParams{ + MaxTokens: anthropic.F(int64(1024)), + Messages: anthropic.F([]anthropic.MessageParam{{ + Role: anthropic.F(anthropic.MessageParamRoleUser), + Content: anthropic.F([]anthropic.MessageParamContentUnion{anthropic.TextBlockParam{Type: anthropic.F(anthropic.TextBlockParamTypeText), Text: anthropic.F("What is a quaternion?")}}), + }}), + Model: anthropic.F(anthropic.ModelClaude_3_5_Sonnet_20240620), + }) + if err == nil || res != nil { + t.Error("expected there to be a cancel error and for the response to be nil") + } +} + +func TestContextDeadline(t *testing.T) { + testTimeout := time.After(3 * time.Second) + testDone := make(chan struct{}) + + deadline := time.Now().Add(100 * time.Millisecond) + deadlineCtx, cancel := context.WithDeadline(context.Background(), deadline) + defer cancel() + + go func() { + client := anthropic.NewClient( + option.WithHTTPClient(&http.Client{ + Transport: &closureTransport{ + fn: func(req *http.Request) (*http.Response, error) { + <-req.Context().Done() + return nil, req.Context().Err() + }, + }, + }), + ) + res, err := client.Messages.New(deadlineCtx, anthropic.MessageNewParams{ + MaxTokens: anthropic.F(int64(1024)), + Messages: anthropic.F([]anthropic.MessageParam{{ + Role: anthropic.F(anthropic.MessageParamRoleUser), + Content: anthropic.F([]anthropic.MessageParamContentUnion{anthropic.TextBlockParam{Type: anthropic.F(anthropic.TextBlockParamTypeText), Text: anthropic.F("What is a quaternion?")}}), + }}), + Model: anthropic.F(anthropic.ModelClaude_3_5_Sonnet_20240620), + }) + if err == nil || res != nil { + t.Error("expected there to be a deadline error and for the response to be nil") + } + close(testDone) + }() + + select { + case <-testTimeout: + t.Fatal("client didn't finish in time") + case <-testDone: + if diff := time.Since(deadline); diff < -30*time.Millisecond || 30*time.Millisecond < diff { + t.Fatalf("client did not return within 30ms of context deadline, got %s", diff) + } + } +} diff --git a/completion.go b/completion.go new file mode 100644 index 
0000000..42c9668 --- /dev/null +++ b/completion.go @@ -0,0 +1,213 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package anthropic + +import ( + "context" + "net/http" + + "github.com/anthropics/anthropic-sdk-go/internal/apijson" + "github.com/anthropics/anthropic-sdk-go/internal/param" + "github.com/anthropics/anthropic-sdk-go/internal/requestconfig" + "github.com/anthropics/anthropic-sdk-go/option" + "github.com/anthropics/anthropic-sdk-go/packages/ssestream" +) + +// CompletionService contains methods and other services that help with interacting +// with the anthropic API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewCompletionService] method instead. +type CompletionService struct { + Options []option.RequestOption +} + +// NewCompletionService generates a new service that applies the given options to +// each request. These options are applied after the parent client's options (if +// there is one), and before any request-specific options. +func NewCompletionService(opts ...option.RequestOption) (r *CompletionService) { + r = &CompletionService{} + r.Options = opts + return +} + +// [Legacy] Create a Text Completion. +// +// The Text Completions API is a legacy API. We recommend using the +// [Messages API](https://docs.anthropic.com/en/api/messages) going forward. +// +// Future models and features will not be compatible with Text Completions. See our +// [migration guide](https://docs.anthropic.com/en/api/migrating-from-text-completions-to-messages) +// for guidance in migrating from Text Completions to Messages. +// +// Note: If you choose to set a timeout for this request, we recommend 10 minutes. +func (r *CompletionService) New(ctx context.Context, body CompletionNewParams, opts ...option.RequestOption) (res *Completion, err error) { + opts = append(r.Options[:], opts...) + path := "v1/complete" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) + return +} + +// [Legacy] Create a Text Completion. +// +// The Text Completions API is a legacy API. We recommend using the +// [Messages API](https://docs.anthropic.com/en/api/messages) going forward. +// +// Future models and features will not be compatible with Text Completions. See our +// [migration guide](https://docs.anthropic.com/en/api/migrating-from-text-completions-to-messages) +// for guidance in migrating from Text Completions to Messages. +// +// Note: If you choose to set a timeout for this request, we recommend 10 minutes. +func (r *CompletionService) NewStreaming(ctx context.Context, body CompletionNewParams, opts ...option.RequestOption) (stream *ssestream.Stream[Completion]) { + var ( + raw *http.Response + err error + ) + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithJSONSet("stream", true)}, opts...) + path := "v1/complete" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &raw, opts...) + return ssestream.NewStream[Completion](ssestream.NewDecoder(raw), err) +} + +type Completion struct { + // Unique object identifier. + // + // The format and length of IDs may change over time. + ID string `json:"id,required"` + // The resulting completion up to and excluding the stop sequences. 
+ Completion string `json:"completion,required"`
+ // The model that will complete your prompt.\n\nSee
+ // [models](https://docs.anthropic.com/en/docs/models-overview) for additional
+ // details and options.
+ Model Model `json:"model,required"`
+ // The reason that we stopped.
+ //
+ // This may be one of the following values:
+ //
+ // - `"stop_sequence"`: we reached a stop sequence — either provided by you via the
+ // `stop_sequences` parameter, or a stop sequence built into the model
+ // - `"max_tokens"`: we exceeded `max_tokens_to_sample` or the model's maximum
+ StopReason string `json:"stop_reason,required,nullable"`
+ // Object type.
+ //
+ // For Text Completions, this is always `"completion"`.
+ Type CompletionType `json:"type,required"`
+ JSON completionJSON `json:"-"`
+}
+
+// completionJSON contains the JSON metadata for the struct [Completion]
+type completionJSON struct {
+ ID apijson.Field
+ Completion apijson.Field
+ Model apijson.Field
+ StopReason apijson.Field
+ Type apijson.Field
+ raw string
+ ExtraFields map[string]apijson.Field
+}
+
+func (r *Completion) UnmarshalJSON(data []byte) (err error) {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+func (r completionJSON) RawJSON() string {
+ return r.raw
+}
+
+// Object type.
+//
+// For Text Completions, this is always `"completion"`.
+type CompletionType string
+
+const (
+ CompletionTypeCompletion CompletionType = "completion"
+)
+
+func (r CompletionType) IsKnown() bool {
+ switch r {
+ case CompletionTypeCompletion:
+ return true
+ }
+ return false
+}
+
+type CompletionNewParams struct {
+ // The maximum number of tokens to generate before stopping.
+ //
+ // Note that our models may stop _before_ reaching this maximum. This parameter
+ // only specifies the absolute maximum number of tokens to generate.
+ MaxTokensToSample param.Field[int64] `json:"max_tokens_to_sample,required"`
+ // The model that will complete your prompt.\n\nSee
+ // [models](https://docs.anthropic.com/en/docs/models-overview) for additional
+ // details and options.
+ Model param.Field[Model] `json:"model,required"`
+ // The prompt that you want Claude to complete.
+ //
+ // For proper response generation you will need to format your prompt using
+ // alternating `\n\nHuman:` and `\n\nAssistant:` conversational turns. For example:
+ //
+ // ```
+ // "\n\nHuman: {userQuestion}\n\nAssistant:"
+ // ```
+ //
+ // See [prompt validation](https://docs.anthropic.com/en/api/prompt-validation) and
+ // our guide to
+ // [prompt design](https://docs.anthropic.com/en/docs/intro-to-prompting) for more
+ // details.
+ Prompt param.Field[string] `json:"prompt,required"`
+ // An object describing metadata about the request.
+ Metadata param.Field[CompletionNewParamsMetadata] `json:"metadata"`
+ // Sequences that will cause the model to stop generating.
+ //
+ // Our models stop on `"\n\nHuman:"`, and may include additional built-in stop
+ // sequences in the future. By providing the stop_sequences parameter, you may
+ // include additional strings that will cause the model to stop generating.
+ StopSequences param.Field[[]string] `json:"stop_sequences"`
+ // Amount of randomness injected into the response.
+ //
+ // Defaults to `1.0`. Ranges from `0.0` to `1.0`. Use `temperature` closer to `0.0`
+ // for analytical / multiple choice, and closer to `1.0` for creative and
+ // generative tasks.
+ //
+ // Note that even with `temperature` of `0.0`, the results will not be fully
+ // deterministic.
+ Temperature param.Field[float64] `json:"temperature"` + // Only sample from the top K options for each subsequent token. + // + // Used to remove "long tail" low probability responses. + // [Learn more technical details here](https://towardsdatascience.com/how-to-sample-from-language-models-682bceb97277). + // + // Recommended for advanced use cases only. You usually only need to use + // `temperature`. + TopK param.Field[int64] `json:"top_k"` + // Use nucleus sampling. + // + // In nucleus sampling, we compute the cumulative distribution over all the options + // for each subsequent token in decreasing probability order and cut it off once it + // reaches a particular probability specified by `top_p`. You should either alter + // `temperature` or `top_p`, but not both. + // + // Recommended for advanced use cases only. You usually only need to use + // `temperature`. + TopP param.Field[float64] `json:"top_p"` +} + +func (r CompletionNewParams) MarshalJSON() (data []byte, err error) { + return apijson.MarshalRoot(r) +} + +// An object describing metadata about the request. +type CompletionNewParamsMetadata struct { + // An external identifier for the user who is associated with the request. + // + // This should be a uuid, hash value, or other opaque identifier. Anthropic may use + // this id to help detect abuse. Do not include any identifying information such as + // name, email address, or phone number. + UserID param.Field[string] `json:"user_id"` +} + +func (r CompletionNewParamsMetadata) MarshalJSON() (data []byte, err error) { + return apijson.MarshalRoot(r) +} diff --git a/completion_test.go b/completion_test.go new file mode 100644 index 0000000..2b79ac9 --- /dev/null +++ b/completion_test.go @@ -0,0 +1,47 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package anthropic_test + +import ( + "context" + "errors" + "os" + "testing" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/internal/testutil" + "github.com/anthropics/anthropic-sdk-go/option" +) + +func TestCompletionNewWithOptionalParams(t *testing.T) { + baseURL := "http://localhost:4010" + if envURL, ok := os.LookupEnv("TEST_API_BASE_URL"); ok { + baseURL = envURL + } + if !testutil.CheckTestServer(t, baseURL) { + return + } + client := anthropic.NewClient( + option.WithBaseURL(baseURL), + option.WithAPIKey("my-anthropic-api-key"), + ) + _, err := client.Completions.New(context.TODO(), anthropic.CompletionNewParams{ + MaxTokensToSample: anthropic.F(int64(256)), + Model: anthropic.F(anthropic.ModelClaude_3_5_Sonnet_20240620), + Prompt: anthropic.F("\n\nHuman: Hello, world!\n\nAssistant:"), + Metadata: anthropic.F(anthropic.CompletionNewParamsMetadata{ + UserID: anthropic.F("13803d75-b4b5-4c3e-b2a2-6f21399b021b"), + }), + StopSequences: anthropic.F([]string{"string", "string", "string"}), + Temperature: anthropic.F(1.000000), + TopK: anthropic.F(int64(5)), + TopP: anthropic.F(0.700000), + }) + if err != nil { + var apierr *anthropic.Error + if errors.As(err, &apierr) { + t.Log(string(apierr.DumpRequest(true))) + } + t.Fatalf("err should be nil: %s", err.Error()) + } +} diff --git a/examples/.keep b/examples/.keep new file mode 100644 index 0000000..d8c73e9 --- /dev/null +++ b/examples/.keep @@ -0,0 +1,4 @@ +File generated from our OpenAPI spec by Stainless. + +This directory can be used to store example files demonstrating usage of this SDK. 
+It is ignored by Stainless code generation and its content (other than this keep file) won't be touched. \ No newline at end of file diff --git a/field.go b/field.go new file mode 100644 index 0000000..c79199f --- /dev/null +++ b/field.go @@ -0,0 +1,50 @@ +package anthropic + +import ( + "github.com/anthropics/anthropic-sdk-go/internal/param" + "io" +) + +// F is a param field helper used to initialize a [param.Field] generic struct. +// This helps specify null, zero values, and overrides, as well as normal values. +// You can read more about this in our [README]. +// +// [README]: https://pkg.go.dev/github.com/anthropics/anthropic-sdk-go#readme-request-fields +func F[T any](value T) param.Field[T] { return param.Field[T]{Value: value, Present: true} } + +// Null is a param field helper which explicitly sends null to the API. +func Null[T any]() param.Field[T] { return param.Field[T]{Null: true, Present: true} } + +// Raw is a param field helper for specifying values for fields when the +// type you are looking to send is different from the type that is specified in +// the SDK. For example, if the type of the field is an integer, but you want +// to send a float, you could do that by setting the corresponding field with +// Raw[int](0.5). +func Raw[T any](value any) param.Field[T] { return param.Field[T]{Raw: value, Present: true} } + +// Int is a param field helper which helps specify integers. This is +// particularly helpful when specifying integer constants for fields. +func Int(value int64) param.Field[int64] { return F(value) } + +// String is a param field helper which helps specify strings. +func String(value string) param.Field[string] { return F(value) } + +// Float is a param field helper which helps specify floats. +func Float(value float64) param.Field[float64] { return F(value) } + +// Bool is a param field helper which helps specify bools. +func Bool(value bool) param.Field[bool] { return F(value) } + +// FileParam is a param field helper which helps files with a mime content-type. 
+func FileParam(reader io.Reader, filename string, contentType string) param.Field[io.Reader] { + return F[io.Reader](&file{reader, filename, contentType}) +} + +type file struct { + io.Reader + name string + contentType string +} + +func (f *file) Name() string { return f.name } +func (f *file) ContentType() string { return f.contentType } diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..5b49773 --- /dev/null +++ b/go.mod @@ -0,0 +1,11 @@ +module github.com/anthropics/anthropic-sdk-go + +go 1.19 + +require ( + github.com/google/uuid v1.3.0 // indirect + github.com/tidwall/gjson v1.14.4 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..569e555 --- /dev/null +++ b/go.sum @@ -0,0 +1,12 @@ +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= +github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= diff --git a/internal/apierror/apierror.go b/internal/apierror/apierror.go new file mode 100644 index 0000000..c17aaf3 --- /dev/null +++ b/internal/apierror/apierror.go @@ -0,0 +1,53 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package apierror + +import ( + "fmt" + "net/http" + "net/http/httputil" + + "github.com/anthropics/anthropic-sdk-go/internal/apijson" +) + +// Error represents an error that originates from the API, i.e. when a request is +// made and the API returns a response with a HTTP status code. Other errors are +// not wrapped by this SDK. 
+type Error struct { + JSON errorJSON `json:"-"` + StatusCode int + Request *http.Request + Response *http.Response +} + +// errorJSON contains the JSON metadata for the struct [Error] +type errorJSON struct { + raw string + ExtraFields map[string]apijson.Field +} + +func (r *Error) UnmarshalJSON(data []byte) (err error) { + return apijson.UnmarshalRoot(data, r) +} + +func (r errorJSON) RawJSON() string { + return r.raw +} + +func (r *Error) Error() string { + // Attempt to re-populate the response body + return fmt.Sprintf("%s \"%s\": %d %s %s", r.Request.Method, r.Request.URL, r.Response.StatusCode, http.StatusText(r.Response.StatusCode), r.JSON.RawJSON()) +} + +func (r *Error) DumpRequest(body bool) []byte { + if r.Request.GetBody != nil { + r.Request.Body, _ = r.Request.GetBody() + } + out, _ := httputil.DumpRequestOut(r.Request, body) + return out +} + +func (r *Error) DumpResponse(body bool) []byte { + out, _ := httputil.DumpResponse(r.Response, body) + return out +} diff --git a/internal/apiform/encoder.go b/internal/apiform/encoder.go new file mode 100644 index 0000000..de985d3 --- /dev/null +++ b/internal/apiform/encoder.go @@ -0,0 +1,381 @@ +package apiform + +import ( + "fmt" + "io" + "mime/multipart" + "net/textproto" + "path" + "reflect" + "sort" + "strconv" + "strings" + "sync" + "time" + + "github.com/anthropics/anthropic-sdk-go/internal/param" +) + +var encoders sync.Map // map[encoderEntry]encoderFunc + +func Marshal(value interface{}, writer *multipart.Writer) error { + e := &encoder{dateFormat: time.RFC3339} + return e.marshal(value, writer) +} + +func MarshalRoot(value interface{}, writer *multipart.Writer) error { + e := &encoder{root: true, dateFormat: time.RFC3339} + return e.marshal(value, writer) +} + +type encoder struct { + dateFormat string + root bool +} + +type encoderFunc func(key string, value reflect.Value, writer *multipart.Writer) error + +type encoderField struct { + tag parsedStructTag + fn encoderFunc + idx []int +} + +type encoderEntry struct { + reflect.Type + dateFormat string + root bool +} + +func (e *encoder) marshal(value interface{}, writer *multipart.Writer) error { + val := reflect.ValueOf(value) + if !val.IsValid() { + return nil + } + typ := val.Type() + enc := e.typeEncoder(typ) + return enc("", val, writer) +} + +func (e *encoder) typeEncoder(t reflect.Type) encoderFunc { + entry := encoderEntry{ + Type: t, + dateFormat: e.dateFormat, + root: e.root, + } + + if fi, ok := encoders.Load(entry); ok { + return fi.(encoderFunc) + } + + // To deal with recursive types, populate the map with an + // indirect func before we build it. This type waits on the + // real func (f) to be ready and then calls it. This indirect + // func is only used for recursive types. + var ( + wg sync.WaitGroup + f encoderFunc + ) + wg.Add(1) + fi, loaded := encoders.LoadOrStore(entry, encoderFunc(func(key string, v reflect.Value, writer *multipart.Writer) error { + wg.Wait() + return f(key, v, writer) + })) + if loaded { + return fi.(encoderFunc) + } + + // Compute the real encoder and replace the indirect func with it. 
+ f = e.newTypeEncoder(t) + wg.Done() + encoders.Store(entry, f) + return f +} + +func (e *encoder) newTypeEncoder(t reflect.Type) encoderFunc { + if t.ConvertibleTo(reflect.TypeOf(time.Time{})) { + return e.newTimeTypeEncoder() + } + if t.ConvertibleTo(reflect.TypeOf((*io.Reader)(nil)).Elem()) { + return e.newReaderTypeEncoder() + } + e.root = false + switch t.Kind() { + case reflect.Pointer: + inner := t.Elem() + + innerEncoder := e.typeEncoder(inner) + return func(key string, v reflect.Value, writer *multipart.Writer) error { + if !v.IsValid() || v.IsNil() { + return nil + } + return innerEncoder(key, v.Elem(), writer) + } + case reflect.Struct: + return e.newStructTypeEncoder(t) + case reflect.Slice, reflect.Array: + return e.newArrayTypeEncoder(t) + case reflect.Map: + return e.newMapEncoder(t) + case reflect.Interface: + return e.newInterfaceEncoder() + default: + return e.newPrimitiveTypeEncoder(t) + } +} + +func (e *encoder) newPrimitiveTypeEncoder(t reflect.Type) encoderFunc { + switch t.Kind() { + // Note that we could use `gjson` to encode these types but it would complicate our + // code more and this current code shouldn't cause any issues + case reflect.String: + return func(key string, v reflect.Value, writer *multipart.Writer) error { + return writer.WriteField(key, v.String()) + } + case reflect.Bool: + return func(key string, v reflect.Value, writer *multipart.Writer) error { + if v.Bool() { + return writer.WriteField(key, "true") + } + return writer.WriteField(key, "false") + } + case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: + return func(key string, v reflect.Value, writer *multipart.Writer) error { + return writer.WriteField(key, strconv.FormatInt(v.Int(), 10)) + } + case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return func(key string, v reflect.Value, writer *multipart.Writer) error { + return writer.WriteField(key, strconv.FormatUint(v.Uint(), 10)) + } + case reflect.Float32: + return func(key string, v reflect.Value, writer *multipart.Writer) error { + return writer.WriteField(key, strconv.FormatFloat(v.Float(), 'f', -1, 32)) + } + case reflect.Float64: + return func(key string, v reflect.Value, writer *multipart.Writer) error { + return writer.WriteField(key, strconv.FormatFloat(v.Float(), 'f', -1, 64)) + } + default: + return func(key string, v reflect.Value, writer *multipart.Writer) error { + return fmt.Errorf("unknown type received at primitive encoder: %s", t.String()) + } + } +} + +func (e *encoder) newArrayTypeEncoder(t reflect.Type) encoderFunc { + itemEncoder := e.typeEncoder(t.Elem()) + + return func(key string, v reflect.Value, writer *multipart.Writer) error { + if key != "" { + key = key + "." + } + for i := 0; i < v.Len(); i++ { + err := itemEncoder(key+strconv.Itoa(i), v.Index(i), writer) + if err != nil { + return err + } + } + return nil + } +} + +func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc { + if t.Implements(reflect.TypeOf((*param.FieldLike)(nil)).Elem()) { + return e.newFieldTypeEncoder(t) + } + + encoderFields := []encoderField{} + extraEncoder := (*encoderField)(nil) + + // This helper allows us to recursively collect field encoders into a flat + // array. The parameter `index` keeps track of the access patterns necessary + // to get to some field. 
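+	// For example, a field tagged `form:"name"` inside an element of a slice
+	// field tagged `form:"items"` is written as a part named "items.0.name";
+	// see the fixtures in form_test.go for the exact wire format.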
+ var collectEncoderFields func(r reflect.Type, index []int) + collectEncoderFields = func(r reflect.Type, index []int) { + for i := 0; i < r.NumField(); i++ { + idx := append(index, i) + field := t.FieldByIndex(idx) + if !field.IsExported() { + continue + } + // If this is an embedded struct, traverse one level deeper to extract + // the field and get their encoders as well. + if field.Anonymous { + collectEncoderFields(field.Type, idx) + continue + } + // If json tag is not present, then we skip, which is intentionally + // different behavior from the stdlib. + ptag, ok := parseFormStructTag(field) + if !ok { + continue + } + // We only want to support unexported field if they're tagged with + // `extras` because that field shouldn't be part of the public API. We + // also want to only keep the top level extras + if ptag.extras && len(index) == 0 { + extraEncoder = &encoderField{ptag, e.typeEncoder(field.Type.Elem()), idx} + continue + } + if ptag.name == "-" { + continue + } + + dateFormat, ok := parseFormatStructTag(field) + oldFormat := e.dateFormat + if ok { + switch dateFormat { + case "date-time": + e.dateFormat = time.RFC3339 + case "date": + e.dateFormat = "2006-01-02" + } + } + encoderFields = append(encoderFields, encoderField{ptag, e.typeEncoder(field.Type), idx}) + e.dateFormat = oldFormat + } + } + collectEncoderFields(t, []int{}) + + // Ensure deterministic output by sorting by lexicographic order + sort.Slice(encoderFields, func(i, j int) bool { + return encoderFields[i].tag.name < encoderFields[j].tag.name + }) + + return func(key string, value reflect.Value, writer *multipart.Writer) error { + if key != "" { + key = key + "." + } + + for _, ef := range encoderFields { + field := value.FieldByIndex(ef.idx) + err := ef.fn(key+ef.tag.name, field, writer) + if err != nil { + return err + } + } + + if extraEncoder != nil { + err := e.encodeMapEntries(key, value.FieldByIndex(extraEncoder.idx), writer) + if err != nil { + return err + } + } + + return nil + } +} + +func (e *encoder) newFieldTypeEncoder(t reflect.Type) encoderFunc { + f, _ := t.FieldByName("Value") + enc := e.typeEncoder(f.Type) + + return func(key string, value reflect.Value, writer *multipart.Writer) error { + present := value.FieldByName("Present") + if !present.Bool() { + return nil + } + null := value.FieldByName("Null") + if null.Bool() { + return nil + } + raw := value.FieldByName("Raw") + if !raw.IsNil() { + return e.typeEncoder(raw.Type())(key, raw, writer) + } + return enc(key, value.FieldByName("Value"), writer) + } +} + +func (e *encoder) newTimeTypeEncoder() encoderFunc { + format := e.dateFormat + return func(key string, value reflect.Value, writer *multipart.Writer) error { + return writer.WriteField(key, value.Convert(reflect.TypeOf(time.Time{})).Interface().(time.Time).Format(format)) + } +} + +func (e encoder) newInterfaceEncoder() encoderFunc { + return func(key string, value reflect.Value, writer *multipart.Writer) error { + value = value.Elem() + if !value.IsValid() { + return nil + } + return e.typeEncoder(value.Type())(key, value, writer) + } +} + +var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"") + +func escapeQuotes(s string) string { + return quoteEscaper.Replace(s) +} + +func (e *encoder) newReaderTypeEncoder() encoderFunc { + return func(key string, value reflect.Value, writer *multipart.Writer) error { + reader := value.Convert(reflect.TypeOf((*io.Reader)(nil)).Elem()).Interface().(io.Reader) + filename := "anonymous_file" + contentType := "application/octet-stream" + 
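+		// Readers that also expose Name() or ContentType() (for example the
+		// wrapper created by FileParam) override these defaults below.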
if named, ok := reader.(interface{ Name() string }); ok { + filename = path.Base(named.Name()) + } + if typed, ok := reader.(interface{ ContentType() string }); ok { + contentType = path.Base(typed.ContentType()) + } + + // Below is taken almost 1-for-1 from [multipart.CreateFormFile] + h := make(textproto.MIMEHeader) + h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, escapeQuotes(key), escapeQuotes(filename))) + h.Set("Content-Type", contentType) + filewriter, err := writer.CreatePart(h) + if err != nil { + return err + } + _, err = io.Copy(filewriter, reader) + return err + } +} + +// Given a []byte of json (may either be an empty object or an object that already contains entries) +// encode all of the entries in the map to the json byte array. +func (e *encoder) encodeMapEntries(key string, v reflect.Value, writer *multipart.Writer) error { + type mapPair struct { + key string + value reflect.Value + } + + if key != "" { + key = key + "." + } + + pairs := []mapPair{} + + iter := v.MapRange() + for iter.Next() { + if iter.Key().Type().Kind() == reflect.String { + pairs = append(pairs, mapPair{key: iter.Key().String(), value: iter.Value()}) + } else { + return fmt.Errorf("cannot encode a map with a non string key") + } + } + + // Ensure deterministic output + sort.Slice(pairs, func(i, j int) bool { + return pairs[i].key < pairs[j].key + }) + + elementEncoder := e.typeEncoder(v.Type().Elem()) + for _, p := range pairs { + err := elementEncoder(key+string(p.key), p.value, writer) + if err != nil { + return err + } + } + + return nil +} + +func (e *encoder) newMapEncoder(t reflect.Type) encoderFunc { + return func(key string, value reflect.Value, writer *multipart.Writer) error { + return e.encodeMapEntries(key, value, writer) + } +} diff --git a/internal/apiform/form.go b/internal/apiform/form.go new file mode 100644 index 0000000..5445116 --- /dev/null +++ b/internal/apiform/form.go @@ -0,0 +1,5 @@ +package apiform + +type Marshaler interface { + MarshalMultipart() ([]byte, string, error) +} diff --git a/internal/apiform/form_test.go b/internal/apiform/form_test.go new file mode 100644 index 0000000..39d1460 --- /dev/null +++ b/internal/apiform/form_test.go @@ -0,0 +1,440 @@ +package apiform + +import ( + "bytes" + "mime/multipart" + "strings" + "testing" + "time" +) + +func P[T any](v T) *T { return &v } + +type Primitives struct { + A bool `form:"a"` + B int `form:"b"` + C uint `form:"c"` + D float64 `form:"d"` + E float32 `form:"e"` + F []int `form:"f"` +} + +type PrimitivePointers struct { + A *bool `form:"a"` + B *int `form:"b"` + C *uint `form:"c"` + D *float64 `form:"d"` + E *float32 `form:"e"` + F *[]int `form:"f"` +} + +type Slices struct { + Slice []Primitives `form:"slices"` +} + +type DateTime struct { + Date time.Time `form:"date" format:"date"` + DateTime time.Time `form:"date-time" format:"date-time"` +} + +type AdditionalProperties struct { + A bool `form:"a"` + Extras map[string]interface{} `form:"-,extras"` +} + +type TypedAdditionalProperties struct { + A bool `form:"a"` + Extras map[string]int `form:"-,extras"` +} + +type EmbeddedStructs struct { + AdditionalProperties + A *int `form:"number2"` + Extras map[string]interface{} `form:"-,extras"` +} + +type Recursive struct { + Name string `form:"name"` + Child *Recursive `form:"child"` +} + +type UnknownStruct struct { + Unknown interface{} `form:"unknown"` +} + +type UnionStruct struct { + Union Union `form:"union" format:"date"` +} + +type Union interface { + union() +} + +type 
UnionInteger int64 + +func (UnionInteger) union() {} + +type UnionStructA struct { + Type string `form:"type"` + A string `form:"a"` + B string `form:"b"` +} + +func (UnionStructA) union() {} + +type UnionStructB struct { + Type string `form:"type"` + A string `form:"a"` +} + +func (UnionStructB) union() {} + +type UnionTime time.Time + +func (UnionTime) union() {} + +type ReaderStruct struct { +} + +var tests = map[string]struct { + buf string + val interface{} +}{ + "map_string": { + `--xxx +Content-Disposition: form-data; name="foo" + +bar +--xxx-- +`, + map[string]string{"foo": "bar"}, + }, + + "map_interface": { + `--xxx +Content-Disposition: form-data; name="a" + +1 +--xxx +Content-Disposition: form-data; name="b" + +str +--xxx +Content-Disposition: form-data; name="c" + +false +--xxx-- +`, + map[string]interface{}{"a": float64(1), "b": "str", "c": false}, + }, + + "primitive_struct": { + `--xxx +Content-Disposition: form-data; name="a" + +false +--xxx +Content-Disposition: form-data; name="b" + +237628372683 +--xxx +Content-Disposition: form-data; name="c" + +654 +--xxx +Content-Disposition: form-data; name="d" + +9999.43 +--xxx +Content-Disposition: form-data; name="e" + +43.76 +--xxx +Content-Disposition: form-data; name="f.0" + +1 +--xxx +Content-Disposition: form-data; name="f.1" + +2 +--xxx +Content-Disposition: form-data; name="f.2" + +3 +--xxx +Content-Disposition: form-data; name="f.3" + +4 +--xxx-- +`, + Primitives{A: false, B: 237628372683, C: uint(654), D: 9999.43, E: 43.76, F: []int{1, 2, 3, 4}}, + }, + + "slices": { + `--xxx +Content-Disposition: form-data; name="slices.0.a" + +false +--xxx +Content-Disposition: form-data; name="slices.0.b" + +237628372683 +--xxx +Content-Disposition: form-data; name="slices.0.c" + +654 +--xxx +Content-Disposition: form-data; name="slices.0.d" + +9999.43 +--xxx +Content-Disposition: form-data; name="slices.0.e" + +43.76 +--xxx +Content-Disposition: form-data; name="slices.0.f.0" + +1 +--xxx +Content-Disposition: form-data; name="slices.0.f.1" + +2 +--xxx +Content-Disposition: form-data; name="slices.0.f.2" + +3 +--xxx +Content-Disposition: form-data; name="slices.0.f.3" + +4 +--xxx-- +`, + Slices{ + Slice: []Primitives{{A: false, B: 237628372683, C: uint(654), D: 9999.43, E: 43.76, F: []int{1, 2, 3, 4}}}, + }, + }, + + "primitive_pointer_struct": { + `--xxx +Content-Disposition: form-data; name="a" + +false +--xxx +Content-Disposition: form-data; name="b" + +237628372683 +--xxx +Content-Disposition: form-data; name="c" + +654 +--xxx +Content-Disposition: form-data; name="d" + +9999.43 +--xxx +Content-Disposition: form-data; name="e" + +43.76 +--xxx +Content-Disposition: form-data; name="f.0" + +1 +--xxx +Content-Disposition: form-data; name="f.1" + +2 +--xxx +Content-Disposition: form-data; name="f.2" + +3 +--xxx +Content-Disposition: form-data; name="f.3" + +4 +--xxx +Content-Disposition: form-data; name="f.4" + +5 +--xxx-- +`, + PrimitivePointers{ + A: P(false), + B: P(237628372683), + C: P(uint(654)), + D: P(9999.43), + E: P(float32(43.76)), + F: &[]int{1, 2, 3, 4, 5}, + }, + }, + + "datetime_struct": { + `--xxx +Content-Disposition: form-data; name="date" + +2006-01-02 +--xxx +Content-Disposition: form-data; name="date-time" + +2006-01-02T15:04:05Z +--xxx-- +`, + DateTime{ + Date: time.Date(2006, time.January, 2, 0, 0, 0, 0, time.UTC), + DateTime: time.Date(2006, time.January, 2, 15, 4, 5, 0, time.UTC), + }, + }, + + "additional_properties": { + `--xxx +Content-Disposition: form-data; name="a" + +true +--xxx +Content-Disposition: 
form-data; name="bar" + +value +--xxx +Content-Disposition: form-data; name="foo" + +true +--xxx-- +`, + AdditionalProperties{ + A: true, + Extras: map[string]interface{}{ + "bar": "value", + "foo": true, + }, + }, + }, + + "recursive_struct": { + `--xxx +Content-Disposition: form-data; name="child.name" + +Alex +--xxx +Content-Disposition: form-data; name="name" + +Robert +--xxx-- +`, + Recursive{Name: "Robert", Child: &Recursive{Name: "Alex"}}, + }, + + "unknown_struct_number": { + `--xxx +Content-Disposition: form-data; name="unknown" + +12 +--xxx-- +`, + UnknownStruct{ + Unknown: 12., + }, + }, + + "unknown_struct_map": { + `--xxx +Content-Disposition: form-data; name="unknown.foo" + +bar +--xxx-- +`, + UnknownStruct{ + Unknown: map[string]interface{}{ + "foo": "bar", + }, + }, + }, + + "union_integer": { + `--xxx +Content-Disposition: form-data; name="union" + +12 +--xxx-- +`, + UnionStruct{ + Union: UnionInteger(12), + }, + }, + + "union_struct_discriminated_a": { + `--xxx +Content-Disposition: form-data; name="union.a" + +foo +--xxx +Content-Disposition: form-data; name="union.b" + +bar +--xxx +Content-Disposition: form-data; name="union.type" + +typeA +--xxx-- +`, + + UnionStruct{ + Union: UnionStructA{ + Type: "typeA", + A: "foo", + B: "bar", + }, + }, + }, + + "union_struct_discriminated_b": { + `--xxx +Content-Disposition: form-data; name="union.a" + +foo +--xxx +Content-Disposition: form-data; name="union.type" + +typeB +--xxx-- +`, + UnionStruct{ + Union: UnionStructB{ + Type: "typeB", + A: "foo", + }, + }, + }, + + "union_struct_time": { + `--xxx +Content-Disposition: form-data; name="union" + +2010-05-23 +--xxx-- +`, + UnionStruct{ + Union: UnionTime(time.Date(2010, 05, 23, 0, 0, 0, 0, time.UTC)), + }, + }, +} + +func TestEncode(t *testing.T) { + for name, test := range tests { + t.Run(name, func(t *testing.T) { + buf := bytes.NewBuffer(nil) + writer := multipart.NewWriter(buf) + writer.SetBoundary("xxx") + err := Marshal(test.val, writer) + if err != nil { + t.Errorf("serialization of %v failed with error %v", test.val, err) + } + err = writer.Close() + if err != nil { + t.Errorf("serialization of %v failed with error %v", test.val, err) + } + raw := buf.Bytes() + if string(raw) != strings.ReplaceAll(test.buf, "\n", "\r\n") { + t.Errorf("expected %+#v to serialize to '%s' but got '%s'", test.val, test.buf, string(raw)) + } + }) + } +} diff --git a/internal/apiform/tag.go b/internal/apiform/tag.go new file mode 100644 index 0000000..b22e054 --- /dev/null +++ b/internal/apiform/tag.go @@ -0,0 +1,48 @@ +package apiform + +import ( + "reflect" + "strings" +) + +const jsonStructTag = "json" +const formStructTag = "form" +const formatStructTag = "format" + +type parsedStructTag struct { + name string + required bool + extras bool + metadata bool +} + +func parseFormStructTag(field reflect.StructField) (tag parsedStructTag, ok bool) { + raw, ok := field.Tag.Lookup(formStructTag) + if !ok { + raw, ok = field.Tag.Lookup(jsonStructTag) + } + if !ok { + return + } + parts := strings.Split(raw, ",") + if len(parts) == 0 { + return tag, false + } + tag.name = parts[0] + for _, part := range parts[1:] { + switch part { + case "required": + tag.required = true + case "extras": + tag.extras = true + case "metadata": + tag.metadata = true + } + } + return +} + +func parseFormatStructTag(field reflect.StructField) (format string, ok bool) { + format, ok = field.Tag.Lookup(formatStructTag) + return +} diff --git a/internal/apijson/decoder.go b/internal/apijson/decoder.go new file mode 100644 
index 0000000..e1b21b7 --- /dev/null +++ b/internal/apijson/decoder.go @@ -0,0 +1,668 @@ +package apijson + +import ( + "encoding/json" + "errors" + "fmt" + "reflect" + "strconv" + "sync" + "time" + "unsafe" + + "github.com/tidwall/gjson" +) + +// decoders is a synchronized map with roughly the following type: +// map[reflect.Type]decoderFunc +var decoders sync.Map + +// Unmarshal is similar to [encoding/json.Unmarshal] and parses the JSON-encoded +// data and stores it in the given pointer. +func Unmarshal(raw []byte, to any) error { + d := &decoderBuilder{dateFormat: time.RFC3339} + return d.unmarshal(raw, to) +} + +// UnmarshalRoot is like Unmarshal, but doesn't try to call MarshalJSON on the +// root element. Useful if a struct's UnmarshalJSON is overrode to use the +// behavior of this encoder versus the standard library. +func UnmarshalRoot(raw []byte, to any) error { + d := &decoderBuilder{dateFormat: time.RFC3339, root: true} + return d.unmarshal(raw, to) +} + +// decoderBuilder contains the 'compile-time' state of the decoder. +type decoderBuilder struct { + // Whether or not this is the first element and called by [UnmarshalRoot], see + // the documentation there to see why this is necessary. + root bool + // The dateFormat (a format string for [time.Format]) which is chosen by the + // last struct tag that was seen. + dateFormat string +} + +// decoderState contains the 'run-time' state of the decoder. +type decoderState struct { + strict bool + exactness exactness +} + +// Exactness refers to how close to the type the result was if deserialization +// was successful. This is useful in deserializing unions, where you want to try +// each entry, first with strict, then with looser validation, without actually +// having to do a lot of redundant work by marshalling twice (or maybe even more +// times). +type exactness int8 + +const ( + // Some values had to fudged a bit, for example by converting a string to an + // int, or an enum with extra values. + loose exactness = iota + // There are some extra arguments, but other wise it matches the union. + extras + // Exactly right. + exact +) + +type decoderFunc func(node gjson.Result, value reflect.Value, state *decoderState) error + +type decoderField struct { + tag parsedStructTag + fn decoderFunc + idx []int + goname string +} + +type decoderEntry struct { + reflect.Type + dateFormat string + root bool +} + +func (d *decoderBuilder) unmarshal(raw []byte, to any) error { + value := reflect.ValueOf(to).Elem() + result := gjson.ParseBytes(raw) + if !value.IsValid() { + return fmt.Errorf("apijson: cannot marshal into invalid value") + } + return d.typeDecoder(value.Type())(result, value, &decoderState{strict: false, exactness: exact}) +} + +func (d *decoderBuilder) typeDecoder(t reflect.Type) decoderFunc { + entry := decoderEntry{ + Type: t, + dateFormat: d.dateFormat, + root: d.root, + } + + if fi, ok := decoders.Load(entry); ok { + return fi.(decoderFunc) + } + + // To deal with recursive types, populate the map with an + // indirect func before we build it. This type waits on the + // real func (f) to be ready and then calls it. This indirect + // func is only used for recursive types. + var ( + wg sync.WaitGroup + f decoderFunc + ) + wg.Add(1) + fi, loaded := decoders.LoadOrStore(entry, decoderFunc(func(node gjson.Result, v reflect.Value, state *decoderState) error { + wg.Wait() + return f(node, v, state) + })) + if loaded { + return fi.(decoderFunc) + } + + // Compute the real decoder and replace the indirect func with it. 
+ f = d.newTypeDecoder(t) + wg.Done() + decoders.Store(entry, f) + return f +} + +func indirectUnmarshalerDecoder(n gjson.Result, v reflect.Value, state *decoderState) error { + return v.Addr().Interface().(json.Unmarshaler).UnmarshalJSON([]byte(n.Raw)) +} + +func unmarshalerDecoder(n gjson.Result, v reflect.Value, state *decoderState) error { + if v.Kind() == reflect.Pointer && v.CanSet() { + v.Set(reflect.New(v.Type().Elem())) + } + return v.Interface().(json.Unmarshaler).UnmarshalJSON([]byte(n.Raw)) +} + +func (d *decoderBuilder) newTypeDecoder(t reflect.Type) decoderFunc { + if t.ConvertibleTo(reflect.TypeOf(time.Time{})) { + return d.newTimeTypeDecoder(t) + } + if !d.root && t.Implements(reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()) { + return unmarshalerDecoder + } + if !d.root && reflect.PointerTo(t).Implements(reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()) { + return indirectUnmarshalerDecoder + } + d.root = false + + if _, ok := unionRegistry[t]; ok { + return d.newUnionDecoder(t) + } + + switch t.Kind() { + case reflect.Pointer: + inner := t.Elem() + innerDecoder := d.typeDecoder(inner) + + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + if !v.IsValid() { + return fmt.Errorf("apijson: unexpected invalid reflection value %+#v", v) + } + + newValue := reflect.New(inner).Elem() + err := innerDecoder(n, newValue, state) + if err != nil { + return err + } + + v.Set(newValue.Addr()) + return nil + } + case reflect.Struct: + return d.newStructTypeDecoder(t) + case reflect.Array: + fallthrough + case reflect.Slice: + return d.newArrayTypeDecoder(t) + case reflect.Map: + return d.newMapDecoder(t) + case reflect.Interface: + return func(node gjson.Result, value reflect.Value, state *decoderState) error { + if !value.IsValid() { + return fmt.Errorf("apijson: unexpected invalid value %+#v", value) + } + if node.Value() != nil && value.CanSet() { + value.Set(reflect.ValueOf(node.Value())) + } + return nil + } + default: + return d.newPrimitiveTypeDecoder(t) + } +} + +// newUnionDecoder returns a decoderFunc that deserializes into a union using an +// algorithm roughly similar to Pydantic's [smart algorithm]. +// +// Conceptually this is equivalent to choosing the best schema based on how 'exact' +// the deserialization is for each of the schemas. +// +// If there is a tie in the level of exactness, then the tie is broken +// left-to-right. 
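+//
+// For example, given two hypothetical struct variants A{Foo string} and
+// B{Foo, Bar string}, the input {"foo":"x","bar":"y"} decodes into B: decoding
+// it as A leaves an unknown "bar" field (exactness drops to "extras"), while B
+// matches it exactly.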
+// +// [smart algorithm]: https://docs.pydantic.dev/latest/concepts/unions/#smart-mode +func (d *decoderBuilder) newUnionDecoder(t reflect.Type) decoderFunc { + unionEntry, ok := unionRegistry[t] + if !ok { + panic("apijson: couldn't find union of type " + t.String() + " in union registry") + } + decoders := []decoderFunc{} + for _, variant := range unionEntry.variants { + decoder := d.typeDecoder(variant.Type) + decoders = append(decoders, decoder) + } + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + // If there is a discriminator match, circumvent the exactness logic entirely + for idx, variant := range unionEntry.variants { + decoder := decoders[idx] + if variant.TypeFilter != n.Type { + continue + } + + if len(unionEntry.discriminatorKey) != 0 { + discriminatorValue := n.Get(unionEntry.discriminatorKey).Value() + if discriminatorValue == variant.DiscriminatorValue { + inner := reflect.New(variant.Type).Elem() + err := decoder(n, inner, state) + v.Set(inner) + return err + } + } + } + + // Set bestExactness to worse than loose + bestExactness := loose - 1 + for idx, variant := range unionEntry.variants { + decoder := decoders[idx] + if variant.TypeFilter != n.Type { + continue + } + sub := decoderState{strict: state.strict, exactness: exact} + inner := reflect.New(variant.Type).Elem() + err := decoder(n, inner, &sub) + if err != nil { + continue + } + if sub.exactness == exact { + v.Set(inner) + return nil + } + if sub.exactness > bestExactness { + v.Set(inner) + bestExactness = sub.exactness + } + } + + if bestExactness < loose { + return errors.New("apijson: was not able to coerce type as union") + } + + if guardStrict(state, bestExactness != exact) { + return errors.New("apijson: was not able to coerce type as union strictly") + } + + return nil + } +} + +func (d *decoderBuilder) newMapDecoder(t reflect.Type) decoderFunc { + keyType := t.Key() + itemType := t.Elem() + itemDecoder := d.typeDecoder(itemType) + + return func(node gjson.Result, value reflect.Value, state *decoderState) (err error) { + mapValue := reflect.MakeMapWithSize(t, len(node.Map())) + + node.ForEach(func(key, value gjson.Result) bool { + // It's fine for us to just use `ValueOf` here because the key types will + // always be primitive types so we don't need to decode it using the standard pattern + keyValue := reflect.ValueOf(key.Value()) + if !keyValue.IsValid() { + if err == nil { + err = fmt.Errorf("apijson: received invalid key type %v", keyValue.String()) + } + return false + } + if keyValue.Type() != keyType { + if err == nil { + err = fmt.Errorf("apijson: expected key type %v but got %v", keyType, keyValue.Type()) + } + return false + } + + itemValue := reflect.New(itemType).Elem() + itemerr := itemDecoder(value, itemValue, state) + if itemerr != nil { + if err == nil { + err = itemerr + } + return false + } + + mapValue.SetMapIndex(keyValue, itemValue) + return true + }) + + if err != nil { + return err + } + value.Set(mapValue) + return nil + } +} + +func (d *decoderBuilder) newArrayTypeDecoder(t reflect.Type) decoderFunc { + itemDecoder := d.typeDecoder(t.Elem()) + + return func(node gjson.Result, value reflect.Value, state *decoderState) (err error) { + if !node.IsArray() { + return fmt.Errorf("apijson: could not deserialize to an array") + } + + arrayNode := node.Array() + + arrayValue := reflect.MakeSlice(reflect.SliceOf(t.Elem()), len(arrayNode), len(arrayNode)) + for i, itemNode := range arrayNode { + err = itemDecoder(itemNode, arrayValue.Index(i), state) + if err != 
nil { + return err + } + } + + value.Set(arrayValue) + return nil + } +} + +func (d *decoderBuilder) newStructTypeDecoder(t reflect.Type) decoderFunc { + // map of json field name to struct field decoders + decoderFields := map[string]decoderField{} + anonymousDecoders := []decoderField{} + extraDecoder := (*decoderField)(nil) + inlineDecoder := (*decoderField)(nil) + + for i := 0; i < t.NumField(); i++ { + idx := []int{i} + field := t.FieldByIndex(idx) + if !field.IsExported() { + continue + } + // If this is an embedded struct, traverse one level deeper to extract + // the fields and get their encoders as well. + if field.Anonymous { + anonymousDecoders = append(anonymousDecoders, decoderField{ + fn: d.typeDecoder(field.Type), + idx: idx[:], + }) + continue + } + // If json tag is not present, then we skip, which is intentionally + // different behavior from the stdlib. + ptag, ok := parseJSONStructTag(field) + if !ok { + continue + } + // We only want to support unexported fields if they're tagged with + // `extras` because that field shouldn't be part of the public API. + if ptag.extras { + extraDecoder = &decoderField{ptag, d.typeDecoder(field.Type.Elem()), idx, field.Name} + continue + } + if ptag.inline { + inlineDecoder = &decoderField{ptag, d.typeDecoder(field.Type), idx, field.Name} + continue + } + if ptag.metadata { + continue + } + + oldFormat := d.dateFormat + dateFormat, ok := parseFormatStructTag(field) + if ok { + switch dateFormat { + case "date-time": + d.dateFormat = time.RFC3339 + case "date": + d.dateFormat = "2006-01-02" + } + } + decoderFields[ptag.name] = decoderField{ptag, d.typeDecoder(field.Type), idx, field.Name} + d.dateFormat = oldFormat + } + + return func(node gjson.Result, value reflect.Value, state *decoderState) (err error) { + if field := value.FieldByName("JSON"); field.IsValid() { + if raw := field.FieldByName("raw"); raw.IsValid() { + setUnexportedField(raw, node.Raw) + } + } + + for _, decoder := range anonymousDecoders { + // ignore errors + decoder.fn(node, value.FieldByIndex(decoder.idx), state) + } + + if inlineDecoder != nil { + var meta Field + dest := value.FieldByIndex(inlineDecoder.idx) + isValid := false + if dest.IsValid() && node.Type != gjson.Null { + err = inlineDecoder.fn(node, dest, state) + if err == nil { + isValid = true + } + } + + if node.Type == gjson.Null { + meta = Field{ + raw: node.Raw, + status: null, + } + } else if !isValid { + meta = Field{ + raw: node.Raw, + status: invalid, + } + } else if isValid { + meta = Field{ + raw: node.Raw, + status: valid, + } + } + if metadata := getSubField(value, inlineDecoder.idx, inlineDecoder.goname); metadata.IsValid() { + metadata.Set(reflect.ValueOf(meta)) + } + return err + } + + typedExtraType := reflect.Type(nil) + typedExtraFields := reflect.Value{} + if extraDecoder != nil { + typedExtraType = value.FieldByIndex(extraDecoder.idx).Type() + typedExtraFields = reflect.MakeMap(typedExtraType) + } + untypedExtraFields := map[string]Field{} + + for fieldName, itemNode := range node.Map() { + df, explicit := decoderFields[fieldName] + var ( + dest reflect.Value + fn decoderFunc + meta Field + ) + if explicit { + fn = df.fn + dest = value.FieldByIndex(df.idx) + } + if !explicit && extraDecoder != nil { + dest = reflect.New(typedExtraType.Elem()).Elem() + fn = extraDecoder.fn + } + + isValid := false + if dest.IsValid() && itemNode.Type != gjson.Null { + err = fn(itemNode, dest, state) + if err == nil { + isValid = true + } + } + + if itemNode.Type == gjson.Null { + meta = Field{ + raw: 
itemNode.Raw, + status: null, + } + } else if !isValid { + meta = Field{ + raw: itemNode.Raw, + status: invalid, + } + } else if isValid { + meta = Field{ + raw: itemNode.Raw, + status: valid, + } + } + + if explicit { + if metadata := getSubField(value, df.idx, df.goname); metadata.IsValid() { + metadata.Set(reflect.ValueOf(meta)) + } + } + if !explicit { + untypedExtraFields[fieldName] = meta + } + if !explicit && extraDecoder != nil { + typedExtraFields.SetMapIndex(reflect.ValueOf(fieldName), dest) + } + } + + if extraDecoder != nil && typedExtraFields.Len() > 0 { + value.FieldByIndex(extraDecoder.idx).Set(typedExtraFields) + } + + // Set exactness to 'extras' if there are untyped, extra fields. + if len(untypedExtraFields) > 0 && state.exactness > extras { + state.exactness = extras + } + + if metadata := getSubField(value, []int{-1}, "ExtraFields"); metadata.IsValid() && len(untypedExtraFields) > 0 { + metadata.Set(reflect.ValueOf(untypedExtraFields)) + } + return nil + } +} + +func (d *decoderBuilder) newPrimitiveTypeDecoder(t reflect.Type) decoderFunc { + switch t.Kind() { + case reflect.String: + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + v.SetString(n.String()) + if guardStrict(state, n.Type != gjson.String) { + return fmt.Errorf("apijson: failed to parse string strictly") + } + // Everything that is not an object can be loosely stringified. + if n.Type == gjson.JSON { + return fmt.Errorf("apijson: failed to parse string") + } + if guardUnknown(state, v) { + return fmt.Errorf("apijson: failed string enum validation") + } + return nil + } + case reflect.Bool: + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + v.SetBool(n.Bool()) + if guardStrict(state, n.Type != gjson.True && n.Type != gjson.False) { + return fmt.Errorf("apijson: failed to parse bool strictly") + } + // Numbers and strings that are either 'true' or 'false' can be loosely + // deserialized as bool. + if n.Type == gjson.String && (n.Raw != "true" && n.Raw != "false") || n.Type == gjson.JSON { + return fmt.Errorf("apijson: failed to parse bool") + } + if guardUnknown(state, v) { + return fmt.Errorf("apijson: failed bool enum validation") + } + return nil + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + v.SetInt(n.Int()) + if guardStrict(state, n.Type != gjson.Number || n.Num != float64(int(n.Num))) { + return fmt.Errorf("apijson: failed to parse int strictly") + } + // Numbers, booleans, and strings that maybe look like numbers can be + // loosely deserialized as numbers. + if n.Type == gjson.JSON || (n.Type == gjson.String && !canParseAsNumber(n.Str)) { + return fmt.Errorf("apijson: failed to parse int") + } + if guardUnknown(state, v) { + return fmt.Errorf("apijson: failed int enum validation") + } + return nil + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + v.SetUint(n.Uint()) + if guardStrict(state, n.Type != gjson.Number || n.Num != float64(int(n.Num)) || n.Num < 0) { + return fmt.Errorf("apijson: failed to parse uint strictly") + } + // Numbers, booleans, and strings that maybe look like numbers can be + // loosely deserialized as uint. 
+ if n.Type == gjson.JSON || (n.Type == gjson.String && !canParseAsNumber(n.Str)) { + return fmt.Errorf("apijson: failed to parse uint") + } + if guardUnknown(state, v) { + return fmt.Errorf("apijson: failed uint enum validation") + } + return nil + } + case reflect.Float32, reflect.Float64: + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + v.SetFloat(n.Float()) + if guardStrict(state, n.Type != gjson.Number) { + return fmt.Errorf("apijson: failed to parse float strictly") + } + // Numbers, booleans, and strings that maybe look like numbers can be + // loosely deserialized as floats. + if n.Type == gjson.JSON || (n.Type == gjson.String && !canParseAsNumber(n.Str)) { + return fmt.Errorf("apijson: failed to parse float") + } + if guardUnknown(state, v) { + return fmt.Errorf("apijson: failed float enum validation") + } + return nil + } + default: + return func(node gjson.Result, v reflect.Value, state *decoderState) error { + return fmt.Errorf("unknown type received at primitive decoder: %s", t.String()) + } + } +} + +func (d *decoderBuilder) newTimeTypeDecoder(t reflect.Type) decoderFunc { + format := d.dateFormat + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + parsed, err := time.Parse(format, n.Str) + if err == nil { + v.Set(reflect.ValueOf(parsed).Convert(t)) + return nil + } + + if guardStrict(state, true) { + return err + } + + layouts := []string{ + "2006-01-02", + "2006-01-02T15:04:05Z07:00", + "2006-01-02T15:04:05Z0700", + "2006-01-02T15:04:05", + "2006-01-02 15:04:05Z07:00", + "2006-01-02 15:04:05Z0700", + "2006-01-02 15:04:05", + } + + for _, layout := range layouts { + parsed, err := time.Parse(layout, n.Str) + if err == nil { + v.Set(reflect.ValueOf(parsed).Convert(t)) + return nil + } + } + + return fmt.Errorf("unable to leniently parse date-time string: %s", n.Str) + } +} + +func setUnexportedField(field reflect.Value, value interface{}) { + reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Set(reflect.ValueOf(value)) +} + +func guardStrict(state *decoderState, cond bool) bool { + if !cond { + return false + } + + if state.strict { + return true + } + + state.exactness = loose + return false +} + +func canParseAsNumber(str string) bool { + _, err := strconv.ParseFloat(str, 64) + return err == nil +} + +func guardUnknown(state *decoderState, v reflect.Value) bool { + if have, ok := v.Interface().(interface{ IsKnown() bool }); guardStrict(state, ok && !have.IsKnown()) { + return true + } + return false +} diff --git a/internal/apijson/encoder.go b/internal/apijson/encoder.go new file mode 100644 index 0000000..abc92b1 --- /dev/null +++ b/internal/apijson/encoder.go @@ -0,0 +1,391 @@ +package apijson + +import ( + "bytes" + "encoding/json" + "fmt" + "reflect" + "sort" + "strconv" + "sync" + "time" + + "github.com/tidwall/sjson" + + "github.com/anthropics/anthropic-sdk-go/internal/param" +) + +var encoders sync.Map // map[encoderEntry]encoderFunc + +func Marshal(value interface{}) ([]byte, error) { + e := &encoder{dateFormat: time.RFC3339} + return e.marshal(value) +} + +func MarshalRoot(value interface{}) ([]byte, error) { + e := &encoder{root: true, dateFormat: time.RFC3339} + return e.marshal(value) +} + +type encoder struct { + dateFormat string + root bool +} + +type encoderFunc func(value reflect.Value) ([]byte, error) + +type encoderField struct { + tag parsedStructTag + fn encoderFunc + idx []int +} + +type encoderEntry struct { + reflect.Type + dateFormat string + root bool +} + +func (e 
*encoder) marshal(value interface{}) ([]byte, error) { + val := reflect.ValueOf(value) + if !val.IsValid() { + return nil, nil + } + typ := val.Type() + enc := e.typeEncoder(typ) + return enc(val) +} + +func (e *encoder) typeEncoder(t reflect.Type) encoderFunc { + entry := encoderEntry{ + Type: t, + dateFormat: e.dateFormat, + root: e.root, + } + + if fi, ok := encoders.Load(entry); ok { + return fi.(encoderFunc) + } + + // To deal with recursive types, populate the map with an + // indirect func before we build it. This type waits on the + // real func (f) to be ready and then calls it. This indirect + // func is only used for recursive types. + var ( + wg sync.WaitGroup + f encoderFunc + ) + wg.Add(1) + fi, loaded := encoders.LoadOrStore(entry, encoderFunc(func(v reflect.Value) ([]byte, error) { + wg.Wait() + return f(v) + })) + if loaded { + return fi.(encoderFunc) + } + + // Compute the real encoder and replace the indirect func with it. + f = e.newTypeEncoder(t) + wg.Done() + encoders.Store(entry, f) + return f +} + +func marshalerEncoder(v reflect.Value) ([]byte, error) { + return v.Interface().(json.Marshaler).MarshalJSON() +} + +func indirectMarshalerEncoder(v reflect.Value) ([]byte, error) { + return v.Addr().Interface().(json.Marshaler).MarshalJSON() +} + +func (e *encoder) newTypeEncoder(t reflect.Type) encoderFunc { + if t.ConvertibleTo(reflect.TypeOf(time.Time{})) { + return e.newTimeTypeEncoder() + } + if !e.root && t.Implements(reflect.TypeOf((*json.Marshaler)(nil)).Elem()) { + return marshalerEncoder + } + if !e.root && reflect.PointerTo(t).Implements(reflect.TypeOf((*json.Marshaler)(nil)).Elem()) { + return indirectMarshalerEncoder + } + e.root = false + switch t.Kind() { + case reflect.Pointer: + inner := t.Elem() + + innerEncoder := e.typeEncoder(inner) + return func(v reflect.Value) ([]byte, error) { + if !v.IsValid() || v.IsNil() { + return nil, nil + } + return innerEncoder(v.Elem()) + } + case reflect.Struct: + return e.newStructTypeEncoder(t) + case reflect.Array: + fallthrough + case reflect.Slice: + return e.newArrayTypeEncoder(t) + case reflect.Map: + return e.newMapEncoder(t) + case reflect.Interface: + return e.newInterfaceEncoder() + default: + return e.newPrimitiveTypeEncoder(t) + } +} + +func (e *encoder) newPrimitiveTypeEncoder(t reflect.Type) encoderFunc { + switch t.Kind() { + // Note that we could use `gjson` to encode these types but it would complicate our + // code more and this current code shouldn't cause any issues + case reflect.String: + return func(v reflect.Value) ([]byte, error) { + return []byte(fmt.Sprintf("%q", v.String())), nil + } + case reflect.Bool: + return func(v reflect.Value) ([]byte, error) { + if v.Bool() { + return []byte("true"), nil + } + return []byte("false"), nil + } + case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: + return func(v reflect.Value) ([]byte, error) { + return []byte(strconv.FormatInt(v.Int(), 10)), nil + } + case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return func(v reflect.Value) ([]byte, error) { + return []byte(strconv.FormatUint(v.Uint(), 10)), nil + } + case reflect.Float32: + return func(v reflect.Value) ([]byte, error) { + return []byte(strconv.FormatFloat(v.Float(), 'f', -1, 32)), nil + } + case reflect.Float64: + return func(v reflect.Value) ([]byte, error) { + return []byte(strconv.FormatFloat(v.Float(), 'f', -1, 64)), nil + } + default: + return func(v reflect.Value) ([]byte, error) { + return nil, fmt.Errorf("unknown type received at primitive encoder: %s", 
t.String()) + } + } +} + +func (e *encoder) newArrayTypeEncoder(t reflect.Type) encoderFunc { + itemEncoder := e.typeEncoder(t.Elem()) + + return func(value reflect.Value) ([]byte, error) { + json := []byte("[]") + for i := 0; i < value.Len(); i++ { + var value, err = itemEncoder(value.Index(i)) + if err != nil { + return nil, err + } + if value == nil { + // Assume that empty items should be inserted as `null` so that the output array + // will be the same length as the input array + value = []byte("null") + } + + json, err = sjson.SetRawBytes(json, "-1", value) + if err != nil { + return nil, err + } + } + + return json, nil + } +} + +func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc { + if t.Implements(reflect.TypeOf((*param.FieldLike)(nil)).Elem()) { + return e.newFieldTypeEncoder(t) + } + + encoderFields := []encoderField{} + extraEncoder := (*encoderField)(nil) + + // This helper allows us to recursively collect field encoders into a flat + // array. The parameter `index` keeps track of the access patterns necessary + // to get to some field. + var collectEncoderFields func(r reflect.Type, index []int) + collectEncoderFields = func(r reflect.Type, index []int) { + for i := 0; i < r.NumField(); i++ { + idx := append(index, i) + field := t.FieldByIndex(idx) + if !field.IsExported() { + continue + } + // If this is an embedded struct, traverse one level deeper to extract + // the field and get their encoders as well. + if field.Anonymous { + collectEncoderFields(field.Type, idx) + continue + } + // If json tag is not present, then we skip, which is intentionally + // different behavior from the stdlib. + ptag, ok := parseJSONStructTag(field) + if !ok { + continue + } + // We only want to support unexported field if they're tagged with + // `extras` because that field shouldn't be part of the public API. 
We + // also want to only keep the top level extras + if ptag.extras && len(index) == 0 { + extraEncoder = &encoderField{ptag, e.typeEncoder(field.Type.Elem()), idx} + continue + } + if ptag.name == "-" { + continue + } + + dateFormat, ok := parseFormatStructTag(field) + oldFormat := e.dateFormat + if ok { + switch dateFormat { + case "date-time": + e.dateFormat = time.RFC3339 + case "date": + e.dateFormat = "2006-01-02" + } + } + encoderFields = append(encoderFields, encoderField{ptag, e.typeEncoder(field.Type), idx}) + e.dateFormat = oldFormat + } + } + collectEncoderFields(t, []int{}) + + // Ensure deterministic output by sorting by lexicographic order + sort.Slice(encoderFields, func(i, j int) bool { + return encoderFields[i].tag.name < encoderFields[j].tag.name + }) + + return func(value reflect.Value) (json []byte, err error) { + json = []byte("{}") + + for _, ef := range encoderFields { + field := value.FieldByIndex(ef.idx) + encoded, err := ef.fn(field) + if err != nil { + return nil, err + } + if encoded == nil { + continue + } + json, err = sjson.SetRawBytes(json, ef.tag.name, encoded) + if err != nil { + return nil, err + } + } + + if extraEncoder != nil { + json, err = e.encodeMapEntries(json, value.FieldByIndex(extraEncoder.idx)) + if err != nil { + return nil, err + } + } + return + } +} + +func (e *encoder) newFieldTypeEncoder(t reflect.Type) encoderFunc { + f, _ := t.FieldByName("Value") + enc := e.typeEncoder(f.Type) + + return func(value reflect.Value) (json []byte, err error) { + present := value.FieldByName("Present") + if !present.Bool() { + return nil, nil + } + null := value.FieldByName("Null") + if null.Bool() { + return []byte("null"), nil + } + raw := value.FieldByName("Raw") + if !raw.IsNil() { + return e.typeEncoder(raw.Type())(raw) + } + return enc(value.FieldByName("Value")) + } +} + +func (e *encoder) newTimeTypeEncoder() encoderFunc { + format := e.dateFormat + return func(value reflect.Value) (json []byte, err error) { + return []byte(`"` + value.Convert(reflect.TypeOf(time.Time{})).Interface().(time.Time).Format(format) + `"`), nil + } +} + +func (e encoder) newInterfaceEncoder() encoderFunc { + return func(value reflect.Value) ([]byte, error) { + value = value.Elem() + if !value.IsValid() { + return nil, nil + } + return e.typeEncoder(value.Type())(value) + } +} + +// Given a []byte of json (may either be an empty object or an object that already contains entries) +// encode all of the entries in the map to the json byte array. 
+func (e *encoder) encodeMapEntries(json []byte, v reflect.Value) ([]byte, error) { + type mapPair struct { + key []byte + value reflect.Value + } + + pairs := []mapPair{} + keyEncoder := e.typeEncoder(v.Type().Key()) + + iter := v.MapRange() + for iter.Next() { + var encodedKey []byte + if iter.Key().Type().Kind() == reflect.String { + encodedKey = []byte(iter.Key().String()) + } else { + var err error + encodedKey, err = keyEncoder(iter.Key()) + if err != nil { + return nil, err + } + } + pairs = append(pairs, mapPair{key: encodedKey, value: iter.Value()}) + } + + // Ensure deterministic output + sort.Slice(pairs, func(i, j int) bool { + return bytes.Compare(pairs[i].key, pairs[j].key) < 0 + }) + + elementEncoder := e.typeEncoder(v.Type().Elem()) + for _, p := range pairs { + encodedValue, err := elementEncoder(p.value) + if err != nil { + return nil, err + } + if len(encodedValue) == 0 { + continue + } + json, err = sjson.SetRawBytes(json, string(p.key), encodedValue) + if err != nil { + return nil, err + } + } + + return json, nil +} + +func (e *encoder) newMapEncoder(t reflect.Type) encoderFunc { + return func(value reflect.Value) ([]byte, error) { + json := []byte("{}") + var err error + json, err = e.encodeMapEntries(json, value) + if err != nil { + return nil, err + } + return json, nil + } +} diff --git a/internal/apijson/field.go b/internal/apijson/field.go new file mode 100644 index 0000000..3ef207c --- /dev/null +++ b/internal/apijson/field.go @@ -0,0 +1,41 @@ +package apijson + +import "reflect" + +type status uint8 + +const ( + missing status = iota + null + invalid + valid +) + +type Field struct { + raw string + status status +} + +// Returns true if the field is explicitly `null` _or_ if it is not present at all (ie, missing). +// To check if the field's key is present in the JSON with an explicit null value, +// you must check `f.IsNull() && !f.IsMissing()`. 
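+//
+// For example, decoding `{"a":null}` into a struct that declares JSON metadata
+// for fields "a" and "b" leaves both a.IsNull() and b.IsNull() true, but only
+// b.IsMissing() true.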
+func (j Field) IsNull() bool { return j.status <= null } +func (j Field) IsMissing() bool { return j.status == missing } +func (j Field) IsInvalid() bool { return j.status == invalid } +func (j Field) Raw() string { return j.raw } + +func getSubField(root reflect.Value, index []int, name string) reflect.Value { + strct := root.FieldByIndex(index[:len(index)-1]) + if !strct.IsValid() { + panic("couldn't find encapsulating struct for field " + name) + } + meta := strct.FieldByName("JSON") + if !meta.IsValid() { + return reflect.Value{} + } + field := meta.FieldByName(name) + if !field.IsValid() { + return reflect.Value{} + } + return field +} diff --git a/internal/apijson/field_test.go b/internal/apijson/field_test.go new file mode 100644 index 0000000..6da7716 --- /dev/null +++ b/internal/apijson/field_test.go @@ -0,0 +1,66 @@ +package apijson + +import ( + "testing" + "time" + + "github.com/anthropics/anthropic-sdk-go/internal/param" +) + +type Struct struct { + A string `json:"a"` + B int64 `json:"b"` +} + +type FieldStruct struct { + A param.Field[string] `json:"a"` + B param.Field[int64] `json:"b"` + C param.Field[Struct] `json:"c"` + D param.Field[time.Time] `json:"d" format:"date"` + E param.Field[time.Time] `json:"e" format:"date-time"` + F param.Field[int64] `json:"f"` +} + +func TestFieldMarshal(t *testing.T) { + tests := map[string]struct { + value interface{} + expected string + }{ + "null_string": {param.Field[string]{Present: true, Null: true}, "null"}, + "null_int": {param.Field[int]{Present: true, Null: true}, "null"}, + "null_int64": {param.Field[int64]{Present: true, Null: true}, "null"}, + "null_struct": {param.Field[Struct]{Present: true, Null: true}, "null"}, + + "string": {param.Field[string]{Present: true, Value: "string"}, `"string"`}, + "int": {param.Field[int]{Present: true, Value: 123}, "123"}, + "int64": {param.Field[int64]{Present: true, Value: int64(123456789123456789)}, "123456789123456789"}, + "struct": {param.Field[Struct]{Present: true, Value: Struct{A: "yo", B: 123}}, `{"a":"yo","b":123}`}, + + "string_raw": {param.Field[int]{Present: true, Raw: "string"}, `"string"`}, + "int_raw": {param.Field[int]{Present: true, Raw: 123}, "123"}, + "int64_raw": {param.Field[int]{Present: true, Raw: int64(123456789123456789)}, "123456789123456789"}, + "struct_raw": {param.Field[int]{Present: true, Raw: Struct{A: "yo", B: 123}}, `{"a":"yo","b":123}`}, + + "param_struct": { + FieldStruct{ + A: param.Field[string]{Present: true, Value: "hello"}, + B: param.Field[int64]{Present: true, Value: int64(12)}, + D: param.Field[time.Time]{Present: true, Value: time.Date(2023, time.March, 18, 14, 47, 38, 0, time.UTC)}, + E: param.Field[time.Time]{Present: true, Value: time.Date(2023, time.March, 18, 14, 47, 38, 0, time.UTC)}, + }, + `{"a":"hello","b":12,"d":"2023-03-18","e":"2023-03-18T14:47:38Z"}`, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + b, err := Marshal(test.value) + if err != nil { + t.Fatalf("didn't expect error %v", err) + } + if string(b) != test.expected { + t.Fatalf("expected %s, received %s", test.expected, string(b)) + } + }) + } +} diff --git a/internal/apijson/json_test.go b/internal/apijson/json_test.go new file mode 100644 index 0000000..72bc4c2 --- /dev/null +++ b/internal/apijson/json_test.go @@ -0,0 +1,554 @@ +package apijson + +import ( + "reflect" + "strings" + "testing" + "time" + + "github.com/tidwall/gjson" +) + +func P[T any](v T) *T { return &v } + +type Primitives struct { + A bool `json:"a"` + B int `json:"b"` + C 
uint `json:"c"` + D float64 `json:"d"` + E float32 `json:"e"` + F []int `json:"f"` +} + +type PrimitivePointers struct { + A *bool `json:"a"` + B *int `json:"b"` + C *uint `json:"c"` + D *float64 `json:"d"` + E *float32 `json:"e"` + F *[]int `json:"f"` +} + +type Slices struct { + Slice []Primitives `json:"slices"` +} + +type DateTime struct { + Date time.Time `json:"date" format:"date"` + DateTime time.Time `json:"date-time" format:"date-time"` +} + +type AdditionalProperties struct { + A bool `json:"a"` + ExtraFields map[string]interface{} `json:"-,extras"` +} + +type TypedAdditionalProperties struct { + A bool `json:"a"` + ExtraFields map[string]int `json:"-,extras"` +} + +type EmbeddedStruct struct { + A bool `json:"a"` + B string `json:"b"` + + JSON EmbeddedStructJSON +} + +type EmbeddedStructJSON struct { + A Field + B Field + ExtraFields map[string]Field + raw string +} + +type EmbeddedStructs struct { + EmbeddedStruct + A *int `json:"a"` + ExtraFields map[string]interface{} `json:"-,extras"` + + JSON EmbeddedStructsJSON +} + +type EmbeddedStructsJSON struct { + A Field + ExtraFields map[string]Field + raw string +} + +type Recursive struct { + Name string `json:"name"` + Child *Recursive `json:"child"` +} + +type JSONFieldStruct struct { + A bool `json:"a"` + B int64 `json:"b"` + C string `json:"c"` + D string `json:"d"` + ExtraFields map[string]int64 `json:"-,extras"` + JSON JSONFieldStructJSON `json:"-,metadata"` +} + +type JSONFieldStructJSON struct { + A Field + B Field + C Field + D Field + ExtraFields map[string]Field + raw string +} + +type UnknownStruct struct { + Unknown interface{} `json:"unknown"` +} + +type UnionStruct struct { + Union Union `json:"union" format:"date"` +} + +type Union interface { + union() +} + +type Inline struct { + InlineField Primitives `json:"-,inline"` + JSON InlineJSON `json:"-,metadata"` +} + +type InlineArray struct { + InlineField []string `json:"-,inline"` + JSON InlineJSON `json:"-,metadata"` +} + +type InlineJSON struct { + InlineField Field + raw string +} + +type UnionInteger int64 + +func (UnionInteger) union() {} + +type UnionStructA struct { + Type string `json:"type"` + A string `json:"a"` + B string `json:"b"` +} + +func (UnionStructA) union() {} + +type UnionStructB struct { + Type string `json:"type"` + A string `json:"a"` +} + +func (UnionStructB) union() {} + +type UnionTime time.Time + +func (UnionTime) union() {} + +func init() { + RegisterUnion(reflect.TypeOf((*Union)(nil)).Elem(), "type", + UnionVariant{ + TypeFilter: gjson.String, + Type: reflect.TypeOf(UnionTime{}), + }, + UnionVariant{ + TypeFilter: gjson.Number, + Type: reflect.TypeOf(UnionInteger(0)), + }, + UnionVariant{ + TypeFilter: gjson.JSON, + DiscriminatorValue: "typeA", + Type: reflect.TypeOf(UnionStructA{}), + }, + UnionVariant{ + TypeFilter: gjson.JSON, + DiscriminatorValue: "typeB", + Type: reflect.TypeOf(UnionStructB{}), + }, + ) +} + +type ComplexUnionStruct struct { + Union ComplexUnion `json:"union"` +} + +type ComplexUnion interface { + complexUnion() +} + +type ComplexUnionA struct { + Boo string `json:"boo"` + Foo bool `json:"foo"` +} + +func (ComplexUnionA) complexUnion() {} + +type ComplexUnionB struct { + Boo bool `json:"boo"` + Foo string `json:"foo"` +} + +func (ComplexUnionB) complexUnion() {} + +type ComplexUnionC struct { + Boo int64 `json:"boo"` +} + +func (ComplexUnionC) complexUnion() {} + +type ComplexUnionTypeA struct { + Baz int64 `json:"baz"` + Type TypeA `json:"type"` +} + +func (ComplexUnionTypeA) complexUnion() {} + +type TypeA 
string + +func (t TypeA) IsKnown() bool { + return t == "a" +} + +type ComplexUnionTypeB struct { + Baz int64 `json:"baz"` + Type TypeB `json:"type"` +} + +type TypeB string + +func (t TypeB) IsKnown() bool { + return t == "b" +} + +type UnmarshalStruct struct { + Foo string `json:"foo"` + prop bool `json:"-"` +} + +func (r *UnmarshalStruct) UnmarshalJSON(json []byte) error { + r.prop = true + return UnmarshalRoot(json, r) +} + +func (ComplexUnionTypeB) complexUnion() {} + +func init() { + RegisterUnion(reflect.TypeOf((*ComplexUnion)(nil)).Elem(), "", + UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(ComplexUnionA{}), + }, + UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(ComplexUnionB{}), + }, + UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(ComplexUnionC{}), + }, + UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(ComplexUnionTypeA{}), + }, + UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(ComplexUnionTypeB{}), + }, + ) +} + +var tests = map[string]struct { + buf string + val interface{} +}{ + "true": {"true", true}, + "false": {"false", false}, + "int": {"1", 1}, + "int_bigger": {"12324", 12324}, + "int_string_coerce": {`"65"`, 65}, + "int_boolean_coerce": {"true", 1}, + "int64": {"1", int64(1)}, + "int64_huge": {"123456789123456789", int64(123456789123456789)}, + "uint": {"1", uint(1)}, + "uint_bigger": {"12324", uint(12324)}, + "uint_coerce": {`"65"`, uint(65)}, + "float_1.54": {"1.54", float32(1.54)}, + "float_1.89": {"1.89", float64(1.89)}, + "string": {`"str"`, "str"}, + "string_int_coerce": {`12`, "12"}, + "array_string": {`["foo","bar"]`, []string{"foo", "bar"}}, + "array_int": {`[1,2]`, []int{1, 2}}, + "array_int_coerce": {`["1",2]`, []int{1, 2}}, + + "ptr_true": {"true", P(true)}, + "ptr_false": {"false", P(false)}, + "ptr_int": {"1", P(1)}, + "ptr_int_bigger": {"12324", P(12324)}, + "ptr_int_string_coerce": {`"65"`, P(65)}, + "ptr_int_boolean_coerce": {"true", P(1)}, + "ptr_int64": {"1", P(int64(1))}, + "ptr_int64_huge": {"123456789123456789", P(int64(123456789123456789))}, + "ptr_uint": {"1", P(uint(1))}, + "ptr_uint_bigger": {"12324", P(uint(12324))}, + "ptr_uint_coerce": {`"65"`, P(uint(65))}, + "ptr_float_1.54": {"1.54", P(float32(1.54))}, + "ptr_float_1.89": {"1.89", P(float64(1.89))}, + + "date_time": {`"2007-03-01T13:00:00Z"`, time.Date(2007, time.March, 1, 13, 0, 0, 0, time.UTC)}, + "date_time_nano_coerce": {`"2007-03-01T13:03:05.123456789Z"`, time.Date(2007, time.March, 1, 13, 3, 5, 123456789, time.UTC)}, + + "date_time_missing_t_coerce": {`"2007-03-01 13:03:05Z"`, time.Date(2007, time.March, 1, 13, 3, 5, 0, time.UTC)}, + "date_time_missing_timezone_coerce": {`"2007-03-01T13:03:05"`, time.Date(2007, time.March, 1, 13, 3, 5, 0, time.UTC)}, + // note: using -1200 to minimize probability of conflicting with the local timezone of the test runner + // see https://en.wikipedia.org/wiki/UTC%E2%88%9212:00 + "date_time_missing_timezone_colon_coerce": {`"2007-03-01T13:03:05-1200"`, time.Date(2007, time.March, 1, 13, 3, 5, 0, time.FixedZone("", -12*60*60))}, + "date_time_nano_missing_t_coerce": {`"2007-03-01 13:03:05.123456789Z"`, time.Date(2007, time.March, 1, 13, 3, 5, 123456789, time.UTC)}, + + "map_string": {`{"foo":"bar"}`, map[string]string{"foo": "bar"}}, + "map_interface": {`{"a":1,"b":"str","c":false}`, map[string]interface{}{"a": float64(1), "b": "str", "c": false}}, + + "primitive_struct": { + `{"a":false,"b":237628372683,"c":654,"d":9999.43,"e":43.76,"f":[1,2,3,4]}`, + Primitives{A: false, B: 
237628372683, C: uint(654), D: 9999.43, E: 43.76, F: []int{1, 2, 3, 4}}, + }, + + "slices": { + `{"slices":[{"a":false,"b":237628372683,"c":654,"d":9999.43,"e":43.76,"f":[1,2,3,4]}]}`, + Slices{ + Slice: []Primitives{{A: false, B: 237628372683, C: uint(654), D: 9999.43, E: 43.76, F: []int{1, 2, 3, 4}}}, + }, + }, + + "primitive_pointer_struct": { + `{"a":false,"b":237628372683,"c":654,"d":9999.43,"e":43.76,"f":[1,2,3,4,5]}`, + PrimitivePointers{ + A: P(false), + B: P(237628372683), + C: P(uint(654)), + D: P(9999.43), + E: P(float32(43.76)), + F: &[]int{1, 2, 3, 4, 5}, + }, + }, + + "datetime_struct": { + `{"date":"2006-01-02","date-time":"2006-01-02T15:04:05Z"}`, + DateTime{ + Date: time.Date(2006, time.January, 2, 0, 0, 0, 0, time.UTC), + DateTime: time.Date(2006, time.January, 2, 15, 4, 5, 0, time.UTC), + }, + }, + + "additional_properties": { + `{"a":true,"bar":"value","foo":true}`, + AdditionalProperties{ + A: true, + ExtraFields: map[string]interface{}{ + "bar": "value", + "foo": true, + }, + }, + }, + + "embedded_struct": { + `{"a":1,"b":"bar"}`, + EmbeddedStructs{ + EmbeddedStruct: EmbeddedStruct{ + A: true, + B: "bar", + JSON: EmbeddedStructJSON{ + A: Field{raw: `1`, status: valid}, + B: Field{raw: `"bar"`, status: valid}, + raw: `{"a":1,"b":"bar"}`, + }, + }, + A: P(1), + ExtraFields: map[string]interface{}{"b": "bar"}, + JSON: EmbeddedStructsJSON{ + A: Field{raw: `1`, status: valid}, + ExtraFields: map[string]Field{ + "b": {raw: `"bar"`, status: valid}, + }, + raw: `{"a":1,"b":"bar"}`, + }, + }, + }, + + "recursive_struct": { + `{"child":{"name":"Alex"},"name":"Robert"}`, + Recursive{Name: "Robert", Child: &Recursive{Name: "Alex"}}, + }, + + "metadata_coerce": { + `{"a":"12","b":"12","c":null,"extra_typed":12,"extra_untyped":{"foo":"bar"}}`, + JSONFieldStruct{ + A: false, + B: 12, + C: "", + JSON: JSONFieldStructJSON{ + raw: `{"a":"12","b":"12","c":null,"extra_typed":12,"extra_untyped":{"foo":"bar"}}`, + A: Field{raw: `"12"`, status: invalid}, + B: Field{raw: `"12"`, status: valid}, + C: Field{raw: "null", status: null}, + D: Field{raw: "", status: missing}, + ExtraFields: map[string]Field{ + "extra_typed": { + raw: "12", + status: valid, + }, + "extra_untyped": { + raw: `{"foo":"bar"}`, + status: invalid, + }, + }, + }, + ExtraFields: map[string]int64{ + "extra_typed": 12, + "extra_untyped": 0, + }, + }, + }, + + "unknown_struct_number": { + `{"unknown":12}`, + UnknownStruct{ + Unknown: 12., + }, + }, + + "unknown_struct_map": { + `{"unknown":{"foo":"bar"}}`, + UnknownStruct{ + Unknown: map[string]interface{}{ + "foo": "bar", + }, + }, + }, + + "union_integer": { + `{"union":12}`, + UnionStruct{ + Union: UnionInteger(12), + }, + }, + + "union_struct_discriminated_a": { + `{"union":{"a":"foo","b":"bar","type":"typeA"}}`, + UnionStruct{ + Union: UnionStructA{ + Type: "typeA", + A: "foo", + B: "bar", + }, + }, + }, + + "union_struct_discriminated_b": { + `{"union":{"a":"foo","type":"typeB"}}`, + UnionStruct{ + Union: UnionStructB{ + Type: "typeB", + A: "foo", + }, + }, + }, + + "union_struct_time": { + `{"union":"2010-05-23"}`, + UnionStruct{ + Union: UnionTime(time.Date(2010, 05, 23, 0, 0, 0, 0, time.UTC)), + }, + }, + + "complex_union_a": { + `{"union":{"boo":"12","foo":true}}`, + ComplexUnionStruct{Union: ComplexUnionA{Boo: "12", Foo: true}}, + }, + + "complex_union_b": { + `{"union":{"boo":true,"foo":"12"}}`, + ComplexUnionStruct{Union: ComplexUnionB{Boo: true, Foo: "12"}}, + }, + + "complex_union_c": { + `{"union":{"boo":12}}`, + ComplexUnionStruct{Union: ComplexUnionC{Boo: 
12}}, + }, + + "complex_union_type_a": { + `{"union":{"baz":12,"type":"a"}}`, + ComplexUnionStruct{Union: ComplexUnionTypeA{Baz: 12, Type: TypeA("a")}}, + }, + + "complex_union_type_b": { + `{"union":{"baz":12,"type":"b"}}`, + ComplexUnionStruct{Union: ComplexUnionTypeB{Baz: 12, Type: TypeB("b")}}, + }, + + "unmarshal": { + `{"foo":"hello"}`, + &UnmarshalStruct{Foo: "hello", prop: true}, + }, + + "array_of_unmarshal": { + `[{"foo":"hello"}]`, + []UnmarshalStruct{{Foo: "hello", prop: true}}, + }, + + "inline_coerce": { + `{"a":false,"b":237628372683,"c":654,"d":9999.43,"e":43.76,"f":[1,2,3,4]}`, + Inline{ + InlineField: Primitives{A: false, B: 237628372683, C: 0x28e, D: 9999.43, E: 43.76, F: []int{1, 2, 3, 4}}, + JSON: InlineJSON{ + InlineField: Field{raw: "{\"a\":false,\"b\":237628372683,\"c\":654,\"d\":9999.43,\"e\":43.76,\"f\":[1,2,3,4]}", status: 3}, + raw: "{\"a\":false,\"b\":237628372683,\"c\":654,\"d\":9999.43,\"e\":43.76,\"f\":[1,2,3,4]}", + }, + }, + }, + + "inline_array_coerce": { + `["Hello","foo","bar"]`, + InlineArray{ + InlineField: []string{"Hello", "foo", "bar"}, + JSON: InlineJSON{ + InlineField: Field{raw: `["Hello","foo","bar"]`, status: 3}, + raw: `["Hello","foo","bar"]`, + }, + }, + }, +} + +func TestDecode(t *testing.T) { + for name, test := range tests { + t.Run(name, func(t *testing.T) { + result := reflect.New(reflect.TypeOf(test.val)) + if err := Unmarshal([]byte(test.buf), result.Interface()); err != nil { + t.Fatalf("deserialization of %v failed with error %v", result, err) + } + if !reflect.DeepEqual(result.Elem().Interface(), test.val) { + t.Fatalf("expected '%s' to deserialize to \n%#v\nbut got\n%#v", test.buf, test.val, result.Elem().Interface()) + } + }) + } +} + +func TestEncode(t *testing.T) { + for name, test := range tests { + if strings.HasSuffix(name, "_coerce") { + continue + } + t.Run(name, func(t *testing.T) { + raw, err := Marshal(test.val) + if err != nil { + t.Fatalf("serialization of %v failed with error %v", test.val, err) + } + if string(raw) != test.buf { + t.Fatalf("expected %+#v to serialize to %s but got %s", test.val, test.buf, string(raw)) + } + }) + } +} diff --git a/internal/apijson/port.go b/internal/apijson/port.go new file mode 100644 index 0000000..80b323b --- /dev/null +++ b/internal/apijson/port.go @@ -0,0 +1,107 @@ +package apijson + +import ( + "fmt" + "reflect" +) + +// Port copies over values from one struct to another struct. +func Port(from any, to any) error { + toVal := reflect.ValueOf(to) + fromVal := reflect.ValueOf(from) + + if toVal.Kind() != reflect.Ptr || toVal.IsNil() { + return fmt.Errorf("destination must be a non-nil pointer") + } + + for toVal.Kind() == reflect.Ptr { + toVal = toVal.Elem() + } + toType := toVal.Type() + + for fromVal.Kind() == reflect.Ptr { + fromVal = fromVal.Elem() + } + fromType := fromVal.Type() + + if toType.Kind() != reflect.Struct { + return fmt.Errorf("destination must be a non-nil pointer to a struct (%v %v)", toType, toType.Kind()) + } + + values := map[string]reflect.Value{} + fields := map[string]reflect.Value{} + + fromJSON := fromVal.FieldByName("JSON") + toJSON := toVal.FieldByName("JSON") + + // First, iterate through the from fields and load all the "normal" fields in the struct to the map of + // string to reflect.Value, as well as their raw .JSON.Foo counterpart. 
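+	// Fields whose parsed tag name is "-" (for example extra-field maps and the
+	// JSON metadata struct) are skipped here; the .JSON.raw string and the
+	// .JSON.ExtraFields map are copied over separately at the end of this function.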
+ for i := 0; i < fromType.NumField(); i++ { + field := fromType.Field(i) + ptag, ok := parseJSONStructTag(field) + if !ok { + continue + } + if ptag.name == "-" { + continue + } + values[ptag.name] = fromVal.Field(i) + if fromJSON.IsValid() { + fields[ptag.name] = fromJSON.FieldByName(field.Name) + } + } + + // Use the values from the previous step to populate the 'to' struct. + for i := 0; i < toType.NumField(); i++ { + field := toType.Field(i) + ptag, ok := parseJSONStructTag(field) + if !ok { + continue + } + if ptag.name == "-" { + continue + } + if value, ok := values[ptag.name]; ok { + delete(values, ptag.name) + if field.Type.Kind() == reflect.Interface { + toVal.Field(i).Set(value) + } else { + switch value.Kind() { + case reflect.String: + toVal.Field(i).SetString(value.String()) + case reflect.Bool: + toVal.Field(i).SetBool(value.Bool()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + toVal.Field(i).SetInt(value.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + toVal.Field(i).SetUint(value.Uint()) + case reflect.Float32, reflect.Float64: + toVal.Field(i).SetFloat(value.Float()) + default: + toVal.Field(i).Set(value) + } + } + } + + if fromJSONField, ok := fields[ptag.name]; ok { + if toJSONField := toJSON.FieldByName(field.Name); toJSONField.IsValid() { + toJSONField.Set(fromJSONField) + } + } + } + + // Finally, copy over the .JSON.raw and .JSON.ExtraFields + if toJSON.IsValid() { + if raw := toJSON.FieldByName("raw"); raw.IsValid() { + setUnexportedField(raw, fromJSON.Interface().(interface{ RawJSON() string }).RawJSON()) + } + + if toExtraFields := toJSON.FieldByName("ExtraFields"); toExtraFields.IsValid() { + if fromExtraFields := fromJSON.FieldByName("ExtraFields"); fromExtraFields.IsValid() { + setUnexportedField(toExtraFields, fromExtraFields.Interface()) + } + } + } + + return nil +} diff --git a/internal/apijson/port_test.go b/internal/apijson/port_test.go new file mode 100644 index 0000000..f9b6e3f --- /dev/null +++ b/internal/apijson/port_test.go @@ -0,0 +1,178 @@ +package apijson + +import ( + "reflect" + "testing" +) + +type Metadata struct { + CreatedAt string `json:"created_at"` +} + +// Card is the "combined" type of CardVisa and CardMastercard +type Card struct { + Processor CardProcessor `json:"processor"` + Data any `json:"data"` + IsFoo bool `json:"is_foo"` + IsBar bool `json:"is_bar"` + Metadata Metadata `json:"metadata"` + Value interface{} `json:"value"` + + JSON cardJSON +} + +type cardJSON struct { + Processor Field + Data Field + IsFoo Field + IsBar Field + Metadata Field + Value Field + ExtraFields map[string]Field + raw string +} + +func (r cardJSON) RawJSON() string { return r.raw } + +type CardProcessor string + +// CardVisa +type CardVisa struct { + Processor CardVisaProcessor `json:"processor"` + Data CardVisaData `json:"data"` + IsFoo bool `json:"is_foo"` + Metadata Metadata `json:"metadata"` + Value string `json:"value"` + + JSON cardVisaJSON +} + +type cardVisaJSON struct { + Processor Field + Data Field + IsFoo Field + Metadata Field + Value Field + ExtraFields map[string]Field + raw string +} + +func (r cardVisaJSON) RawJSON() string { return r.raw } + +type CardVisaProcessor string + +type CardVisaData struct { + Foo string `json:"foo"` +} + +// CardMastercard +type CardMastercard struct { + Processor CardMastercardProcessor `json:"processor"` + Data CardMastercardData `json:"data"` + IsBar bool `json:"is_bar"` + Metadata Metadata `json:"metadata"` + Value bool 
`json:"value"` + + JSON cardMastercardJSON +} + +type cardMastercardJSON struct { + Processor Field + Data Field + IsBar Field + Metadata Field + Value Field + ExtraFields map[string]Field + raw string +} + +func (r cardMastercardJSON) RawJSON() string { return r.raw } + +type CardMastercardProcessor string + +type CardMastercardData struct { + Bar int64 `json:"bar"` +} + +var portTests = map[string]struct { + from any + to any +}{ + "visa to card": { + CardVisa{ + Processor: "visa", + IsFoo: true, + Data: CardVisaData{ + Foo: "foo", + }, + Metadata: Metadata{ + CreatedAt: "Mar 29 2024", + }, + Value: "value", + JSON: cardVisaJSON{ + raw: `{"processor":"visa","is_foo":true,"data":{"foo":"foo"}}`, + Processor: Field{raw: `"visa"`, status: valid}, + IsFoo: Field{raw: `true`, status: valid}, + Data: Field{raw: `{"foo":"foo"}`, status: valid}, + Value: Field{raw: `"value"`, status: valid}, + ExtraFields: map[string]Field{"extra": {raw: `"yo"`, status: valid}}, + }, + }, + Card{ + Processor: "visa", + IsFoo: true, + IsBar: false, + Data: CardVisaData{ + Foo: "foo", + }, + Metadata: Metadata{ + CreatedAt: "Mar 29 2024", + }, + Value: "value", + JSON: cardJSON{ + raw: `{"processor":"visa","is_foo":true,"data":{"foo":"foo"}}`, + Processor: Field{raw: `"visa"`, status: valid}, + IsFoo: Field{raw: `true`, status: valid}, + Data: Field{raw: `{"foo":"foo"}`, status: valid}, + Value: Field{raw: `"value"`, status: valid}, + ExtraFields: map[string]Field{"extra": {raw: `"yo"`, status: valid}}, + }, + }, + }, + "mastercard to card": { + CardMastercard{ + Processor: "mastercard", + IsBar: true, + Data: CardMastercardData{ + Bar: 13, + }, + Value: false, + }, + Card{ + Processor: "mastercard", + IsFoo: false, + IsBar: true, + Data: CardMastercardData{ + Bar: 13, + }, + Value: false, + }, + }, +} + +func TestPort(t *testing.T) { + for name, test := range portTests { + t.Run(name, func(t *testing.T) { + toVal := reflect.New(reflect.TypeOf(test.to)) + + err := Port(test.from, toVal.Interface()) + if err != nil { + t.Fatalf("port of %v failed with error %v", test.from, err) + } + + if !reflect.DeepEqual(toVal.Elem().Interface(), test.to) { + t.Fatalf("expected:\n%+#v\n\nto port to:\n%+#v\n\nbut got:\n%+#v", test.from, test.to, toVal.Elem().Interface()) + } + }) + } +} diff --git a/internal/apijson/registry.go b/internal/apijson/registry.go new file mode 100644 index 0000000..fcc518b --- /dev/null +++ b/internal/apijson/registry.go @@ -0,0 +1,27 @@ +package apijson + +import ( + "reflect" + + "github.com/tidwall/gjson" +) + +type UnionVariant struct { + TypeFilter gjson.Type + DiscriminatorValue interface{} + Type reflect.Type +} + +var unionRegistry = map[reflect.Type]unionEntry{} + +type unionEntry struct { + discriminatorKey string + variants []UnionVariant +} + +func RegisterUnion(typ reflect.Type, discriminator string, variants ...UnionVariant) { + unionRegistry[typ] = unionEntry{ + discriminatorKey: discriminator, + variants: variants, + } +} diff --git a/internal/apijson/tag.go b/internal/apijson/tag.go new file mode 100644 index 0000000..812fb3c --- /dev/null +++ b/internal/apijson/tag.go @@ -0,0 +1,47 @@ +package apijson + +import ( + "reflect" + "strings" +) + +const jsonStructTag = "json" +const formatStructTag = "format" + +type parsedStructTag struct { + name string + required bool + extras bool + metadata bool + inline bool +} + +func parseJSONStructTag(field reflect.StructField) (tag parsedStructTag, ok bool) { + raw, ok := field.Tag.Lookup(jsonStructTag) + if !ok { + return + } + parts := 
strings.Split(raw, ",") + if len(parts) == 0 { + return tag, false + } + tag.name = parts[0] + for _, part := range parts[1:] { + switch part { + case "required": + tag.required = true + case "extras": + tag.extras = true + case "metadata": + tag.metadata = true + case "inline": + tag.inline = true + } + } + return +} + +func parseFormatStructTag(field reflect.StructField) (format string, ok bool) { + format, ok = field.Tag.Lookup(formatStructTag) + return +} diff --git a/internal/apiquery/encoder.go b/internal/apiquery/encoder.go new file mode 100644 index 0000000..9db8518 --- /dev/null +++ b/internal/apiquery/encoder.go @@ -0,0 +1,341 @@ +package apiquery + +import ( + "encoding/json" + "fmt" + "reflect" + "strconv" + "strings" + "sync" + "time" + + "github.com/anthropics/anthropic-sdk-go/internal/param" +) + +var encoders sync.Map // map[reflect.Type]encoderFunc + +type encoder struct { + dateFormat string + root bool + settings QuerySettings +} + +type encoderFunc func(key string, value reflect.Value) []Pair + +type encoderField struct { + tag parsedStructTag + fn encoderFunc + idx []int +} + +type encoderEntry struct { + reflect.Type + dateFormat string + root bool + settings QuerySettings +} + +type Pair struct { + key string + value string +} + +func (e *encoder) typeEncoder(t reflect.Type) encoderFunc { + entry := encoderEntry{ + Type: t, + dateFormat: e.dateFormat, + root: e.root, + settings: e.settings, + } + + if fi, ok := encoders.Load(entry); ok { + return fi.(encoderFunc) + } + + // To deal with recursive types, populate the map with an + // indirect func before we build it. This type waits on the + // real func (f) to be ready and then calls it. This indirect + // func is only used for recursive types. + var ( + wg sync.WaitGroup + f encoderFunc + ) + wg.Add(1) + fi, loaded := encoders.LoadOrStore(entry, encoderFunc(func(key string, v reflect.Value) []Pair { + wg.Wait() + return f(key, v) + })) + if loaded { + return fi.(encoderFunc) + } + + // Compute the real encoder and replace the indirect func with it. + f = e.newTypeEncoder(t) + wg.Done() + encoders.Store(entry, f) + return f +} + +func marshalerEncoder(key string, value reflect.Value) []Pair { + s, _ := value.Interface().(json.Marshaler).MarshalJSON() + return []Pair{{key, string(s)}} +} + +func (e *encoder) newTypeEncoder(t reflect.Type) encoderFunc { + if t.ConvertibleTo(reflect.TypeOf(time.Time{})) { + return e.newTimeTypeEncoder(t) + } + if !e.root && t.Implements(reflect.TypeOf((*json.Marshaler)(nil)).Elem()) { + return marshalerEncoder + } + e.root = false + switch t.Kind() { + case reflect.Pointer: + encoder := e.typeEncoder(t.Elem()) + return func(key string, value reflect.Value) (pairs []Pair) { + if !value.IsValid() || value.IsNil() { + return + } + pairs = encoder(key, value.Elem()) + return + } + case reflect.Struct: + return e.newStructTypeEncoder(t) + case reflect.Array: + fallthrough + case reflect.Slice: + return e.newArrayTypeEncoder(t) + case reflect.Map: + return e.newMapEncoder(t) + case reflect.Interface: + return e.newInterfaceEncoder() + default: + return e.newPrimitiveTypeEncoder(t) + } +} + +func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc { + if t.Implements(reflect.TypeOf((*param.FieldLike)(nil)).Elem()) { + return e.newFieldTypeEncoder(t) + } + + encoderFields := []encoderField{} + + // This helper allows us to recursively collect field encoders into a flat + // array. The parameter `index` keeps track of the access patterns necessary + // to get to some field. 
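+	// Embedded (anonymous) struct fields are flattened by recursing with the
+	// extended index path, and a `format` struct tag temporarily overrides the
+	// encoder's date format while that field's encoder is constructed.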
+ var collectEncoderFields func(r reflect.Type, index []int) + collectEncoderFields = func(r reflect.Type, index []int) { + for i := 0; i < r.NumField(); i++ { + idx := append(index, i) + field := t.FieldByIndex(idx) + if !field.IsExported() { + continue + } + // If this is an embedded struct, traverse one level deeper to extract + // the field and get their encoders as well. + if field.Anonymous { + collectEncoderFields(field.Type, idx) + continue + } + // If query tag is not present, then we skip, which is intentionally + // different behavior from the stdlib. + ptag, ok := parseQueryStructTag(field) + if !ok { + continue + } + + if ptag.name == "-" && !ptag.inline { + continue + } + + dateFormat, ok := parseFormatStructTag(field) + oldFormat := e.dateFormat + if ok { + switch dateFormat { + case "date-time": + e.dateFormat = time.RFC3339 + case "date": + e.dateFormat = "2006-01-02" + } + } + encoderFields = append(encoderFields, encoderField{ptag, e.typeEncoder(field.Type), idx}) + e.dateFormat = oldFormat + } + } + collectEncoderFields(t, []int{}) + + return func(key string, value reflect.Value) (pairs []Pair) { + for _, ef := range encoderFields { + var subkey string = e.renderKeyPath(key, ef.tag.name) + if ef.tag.inline { + subkey = key + } + + field := value.FieldByIndex(ef.idx) + pairs = append(pairs, ef.fn(subkey, field)...) + } + return + } +} + +func (e *encoder) newMapEncoder(t reflect.Type) encoderFunc { + keyEncoder := e.typeEncoder(t.Key()) + elementEncoder := e.typeEncoder(t.Elem()) + return func(key string, value reflect.Value) (pairs []Pair) { + iter := value.MapRange() + for iter.Next() { + encodedKey := keyEncoder("", iter.Key()) + if len(encodedKey) != 1 { + panic("Unexpected number of parts for encoded map key. Are you using a non-primitive for this map?") + } + subkey := encodedKey[0].value + keyPath := e.renderKeyPath(key, subkey) + pairs = append(pairs, elementEncoder(keyPath, iter.Value())...) + } + return + } +} + +func (e *encoder) renderKeyPath(key string, subkey string) string { + if len(key) == 0 { + return subkey + } + if e.settings.NestedFormat == NestedQueryFormatDots { + return fmt.Sprintf("%s.%s", key, subkey) + } + return fmt.Sprintf("%s[%s]", key, subkey) +} + +func (e *encoder) newArrayTypeEncoder(t reflect.Type) encoderFunc { + switch e.settings.ArrayFormat { + case ArrayQueryFormatComma: + innerEncoder := e.typeEncoder(t.Elem()) + return func(key string, v reflect.Value) []Pair { + elements := []string{} + for i := 0; i < v.Len(); i++ { + for _, pair := range innerEncoder("", v.Index(i)) { + elements = append(elements, pair.value) + } + } + if len(elements) == 0 { + return []Pair{} + } + return []Pair{{key, strings.Join(elements, ",")}} + } + case ArrayQueryFormatRepeat: + innerEncoder := e.typeEncoder(t.Elem()) + return func(key string, value reflect.Value) (pairs []Pair) { + for i := 0; i < value.Len(); i++ { + pairs = append(pairs, innerEncoder(key, value.Index(i))...) + } + return pairs + } + case ArrayQueryFormatIndices: + panic("The array indices format is not supported yet") + case ArrayQueryFormatBrackets: + innerEncoder := e.typeEncoder(t.Elem()) + return func(key string, value reflect.Value) []Pair { + pairs := []Pair{} + for i := 0; i < value.Len(); i++ { + pairs = append(pairs, innerEncoder(key+"[]", value.Index(i))...) 
+ } + return pairs + } + default: + panic(fmt.Sprintf("Unknown ArrayFormat value: %d", e.settings.ArrayFormat)) + } +} + +func (e *encoder) newPrimitiveTypeEncoder(t reflect.Type) encoderFunc { + switch t.Kind() { + case reflect.Pointer: + inner := t.Elem() + + innerEncoder := e.newPrimitiveTypeEncoder(inner) + return func(key string, v reflect.Value) []Pair { + if !v.IsValid() || v.IsNil() { + return nil + } + return innerEncoder(key, v.Elem()) + } + case reflect.String: + return func(key string, v reflect.Value) []Pair { + return []Pair{{key, v.String()}} + } + case reflect.Bool: + return func(key string, v reflect.Value) []Pair { + if v.Bool() { + return []Pair{{key, "true"}} + } + return []Pair{{key, "false"}} + } + case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: + return func(key string, v reflect.Value) []Pair { + return []Pair{{key, strconv.FormatInt(v.Int(), 10)}} + } + case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return func(key string, v reflect.Value) []Pair { + return []Pair{{key, strconv.FormatUint(v.Uint(), 10)}} + } + case reflect.Float32, reflect.Float64: + return func(key string, v reflect.Value) []Pair { + return []Pair{{key, strconv.FormatFloat(v.Float(), 'f', -1, 64)}} + } + case reflect.Complex64, reflect.Complex128: + bitSize := 64 + if t.Kind() == reflect.Complex128 { + bitSize = 128 + } + return func(key string, v reflect.Value) []Pair { + return []Pair{{key, strconv.FormatComplex(v.Complex(), 'f', -1, bitSize)}} + } + default: + return func(key string, v reflect.Value) []Pair { + return nil + } + } +} + +func (e *encoder) newFieldTypeEncoder(t reflect.Type) encoderFunc { + f, _ := t.FieldByName("Value") + enc := e.typeEncoder(f.Type) + + return func(key string, value reflect.Value) []Pair { + present := value.FieldByName("Present") + if !present.Bool() { + return nil + } + null := value.FieldByName("Null") + if null.Bool() { + // TODO: Error? 
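+			// A field explicitly set to null has no query-string representation,
+			// so it is currently dropped rather than reported as an error.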
+ return nil + } + raw := value.FieldByName("Raw") + if !raw.IsNil() { + return e.typeEncoder(raw.Type())(key, raw) + } + return enc(key, value.FieldByName("Value")) + } +} + +func (e *encoder) newTimeTypeEncoder(t reflect.Type) encoderFunc { + format := e.dateFormat + return func(key string, value reflect.Value) []Pair { + return []Pair{{ + key, + value.Convert(reflect.TypeOf(time.Time{})).Interface().(time.Time).Format(format), + }} + } +} + +func (e encoder) newInterfaceEncoder() encoderFunc { + return func(key string, value reflect.Value) []Pair { + value = value.Elem() + if !value.IsValid() { + return nil + } + return e.typeEncoder(value.Type())(key, value) + } + +} diff --git a/internal/apiquery/query.go b/internal/apiquery/query.go new file mode 100644 index 0000000..6f90e99 --- /dev/null +++ b/internal/apiquery/query.go @@ -0,0 +1,50 @@ +package apiquery + +import ( + "net/url" + "reflect" + "time" +) + +func MarshalWithSettings(value interface{}, settings QuerySettings) url.Values { + e := encoder{time.RFC3339, true, settings} + kv := url.Values{} + val := reflect.ValueOf(value) + if !val.IsValid() { + return nil + } + typ := val.Type() + for _, pair := range e.typeEncoder(typ)("", val) { + kv.Add(pair.key, pair.value) + } + return kv +} + +func Marshal(value interface{}) url.Values { + return MarshalWithSettings(value, QuerySettings{}) +} + +type Queryer interface { + URLQuery() url.Values +} + +type QuerySettings struct { + NestedFormat NestedQueryFormat + ArrayFormat ArrayQueryFormat +} + +type NestedQueryFormat int + +const ( + NestedQueryFormatBrackets NestedQueryFormat = iota + NestedQueryFormatDots +) + +type ArrayQueryFormat int + +const ( + ArrayQueryFormatComma ArrayQueryFormat = iota + ArrayQueryFormatRepeat + ArrayQueryFormatIndices + ArrayQueryFormatBrackets +) diff --git a/internal/apiquery/query_test.go b/internal/apiquery/query_test.go new file mode 100644 index 0000000..1e740d6 --- /dev/null +++ b/internal/apiquery/query_test.go @@ -0,0 +1,335 @@ +package apiquery + +import ( + "net/url" + "testing" + "time" +) + +func P[T any](v T) *T { return &v } + +type Primitives struct { + A bool `query:"a"` + B int `query:"b"` + C uint `query:"c"` + D float64 `query:"d"` + E float32 `query:"e"` + F []int `query:"f"` +} + +type PrimitivePointers struct { + A *bool `query:"a"` + B *int `query:"b"` + C *uint `query:"c"` + D *float64 `query:"d"` + E *float32 `query:"e"` + F *[]int `query:"f"` +} + +type Slices struct { + Slice []Primitives `query:"slices"` + Mixed []interface{} `query:"mixed"` +} + +type DateTime struct { + Date time.Time `query:"date" format:"date"` + DateTime time.Time `query:"date-time" format:"date-time"` +} + +type AdditionalProperties struct { + A bool `query:"a"` + Extras map[string]interface{} `query:"-,inline"` +} + +type Recursive struct { + Name string `query:"name"` + Child *Recursive `query:"child"` +} + +type UnknownStruct struct { + Unknown interface{} `query:"unknown"` +} + +type UnionStruct struct { + Union Union `query:"union" format:"date"` +} + +type Union interface { + union() +} + +type UnionInteger int64 + +func (UnionInteger) union() {} + +type UnionString string + +func (UnionString) union() {} + +type UnionStructA struct { + Type string `query:"type"` + A string `query:"a"` + B string `query:"b"` +} + +func (UnionStructA) union() {} + +type UnionStructB struct { + Type string `query:"type"` + A string `query:"a"` +} + +func (UnionStructB) union() {} + +type UnionTime time.Time + +func (UnionTime) union() {} + +type DeeplyNested struct 
{ + A DeeplyNested1 `query:"a"` +} + +type DeeplyNested1 struct { + B DeeplyNested2 `query:"b"` +} + +type DeeplyNested2 struct { + C DeeplyNested3 `query:"c"` +} + +type DeeplyNested3 struct { + D *string `query:"d"` +} + +var tests = map[string]struct { + enc string + val interface{} + settings QuerySettings +}{ + "primitives": { + "a=false&b=237628372683&c=654&d=9999.43&e=43.7599983215332&f=1,2,3,4", + Primitives{A: false, B: 237628372683, C: uint(654), D: 9999.43, E: 43.76, F: []int{1, 2, 3, 4}}, + QuerySettings{}, + }, + + "slices_brackets": { + `mixed[]=1&mixed[]=2.3&mixed[]=hello&slices[][a]=false&slices[][a]=false&slices[][b]=237628372683&slices[][b]=237628372683&slices[][c]=654&slices[][c]=654&slices[][d]=9999.43&slices[][d]=9999.43&slices[][e]=43.7599983215332&slices[][e]=43.7599983215332&slices[][f][]=1&slices[][f][]=2&slices[][f][]=3&slices[][f][]=4&slices[][f][]=1&slices[][f][]=2&slices[][f][]=3&slices[][f][]=4`, + Slices{ + Slice: []Primitives{ + {A: false, B: 237628372683, C: uint(654), D: 9999.43, E: 43.76, F: []int{1, 2, 3, 4}}, + {A: false, B: 237628372683, C: uint(654), D: 9999.43, E: 43.76, F: []int{1, 2, 3, 4}}, + }, + Mixed: []interface{}{1, 2.3, "hello"}, + }, + QuerySettings{ArrayFormat: ArrayQueryFormatBrackets}, + }, + + "slices_comma": { + `mixed=1,2.3,hello`, + Slices{ + Mixed: []interface{}{1, 2.3, "hello"}, + }, + QuerySettings{ArrayFormat: ArrayQueryFormatComma}, + }, + + "slices_repeat": { + `mixed=1&mixed=2.3&mixed=hello&slices[a]=false&slices[a]=false&slices[b]=237628372683&slices[b]=237628372683&slices[c]=654&slices[c]=654&slices[d]=9999.43&slices[d]=9999.43&slices[e]=43.7599983215332&slices[e]=43.7599983215332&slices[f]=1&slices[f]=2&slices[f]=3&slices[f]=4&slices[f]=1&slices[f]=2&slices[f]=3&slices[f]=4`, + Slices{ + Slice: []Primitives{ + {A: false, B: 237628372683, C: uint(654), D: 9999.43, E: 43.76, F: []int{1, 2, 3, 4}}, + {A: false, B: 237628372683, C: uint(654), D: 9999.43, E: 43.76, F: []int{1, 2, 3, 4}}, + }, + Mixed: []interface{}{1, 2.3, "hello"}, + }, + QuerySettings{ArrayFormat: ArrayQueryFormatRepeat}, + }, + + "primitive_pointer_struct": { + "a=false&b=237628372683&c=654&d=9999.43&e=43.7599983215332&f=1,2,3,4,5", + PrimitivePointers{ + A: P(false), + B: P(237628372683), + C: P(uint(654)), + D: P(9999.43), + E: P(float32(43.76)), + F: &[]int{1, 2, 3, 4, 5}, + }, + QuerySettings{}, + }, + + "datetime_struct": { + `date=2006-01-02&date-time=2006-01-02T15:04:05Z`, + DateTime{ + Date: time.Date(2006, time.January, 2, 0, 0, 0, 0, time.UTC), + DateTime: time.Date(2006, time.January, 2, 15, 4, 5, 0, time.UTC), + }, + QuerySettings{}, + }, + + "additional_properties": { + `a=true&bar=value&foo=true`, + AdditionalProperties{ + A: true, + Extras: map[string]interface{}{ + "bar": "value", + "foo": true, + }, + }, + QuerySettings{}, + }, + + "recursive_struct_brackets": { + `child[name]=Alex&name=Robert`, + Recursive{Name: "Robert", Child: &Recursive{Name: "Alex"}}, + QuerySettings{NestedFormat: NestedQueryFormatBrackets}, + }, + + "recursive_struct_dots": { + `child.name=Alex&name=Robert`, + Recursive{Name: "Robert", Child: &Recursive{Name: "Alex"}}, + QuerySettings{NestedFormat: NestedQueryFormatDots}, + }, + + "unknown_struct_number": { + `unknown=12`, + UnknownStruct{ + Unknown: 12., + }, + QuerySettings{}, + }, + + "unknown_struct_map_brackets": { + `unknown[foo]=bar`, + UnknownStruct{ + Unknown: map[string]interface{}{ + "foo": "bar", + }, + }, + QuerySettings{NestedFormat: NestedQueryFormatBrackets}, + }, + + "unknown_struct_map_dots": { + 
`unknown.foo=bar`, + UnknownStruct{ + Unknown: map[string]interface{}{ + "foo": "bar", + }, + }, + QuerySettings{NestedFormat: NestedQueryFormatDots}, + }, + + "union_string": { + `union=hello`, + UnionStruct{ + Union: UnionString("hello"), + }, + QuerySettings{}, + }, + + "union_integer": { + `union=12`, + UnionStruct{ + Union: UnionInteger(12), + }, + QuerySettings{}, + }, + + "union_struct_discriminated_a": { + `union[a]=foo&union[b]=bar&union[type]=typeA`, + UnionStruct{ + Union: UnionStructA{ + Type: "typeA", + A: "foo", + B: "bar", + }, + }, + QuerySettings{}, + }, + + "union_struct_discriminated_b": { + `union[a]=foo&union[type]=typeB`, + UnionStruct{ + Union: UnionStructB{ + Type: "typeB", + A: "foo", + }, + }, + QuerySettings{}, + }, + + "union_struct_time": { + `union=2010-05-23`, + UnionStruct{ + Union: UnionTime(time.Date(2010, 05, 23, 0, 0, 0, 0, time.UTC)), + }, + QuerySettings{}, + }, + + "deeply_nested_brackets": { + `a[b][c][d]=hello`, + DeeplyNested{ + A: DeeplyNested1{ + B: DeeplyNested2{ + C: DeeplyNested3{ + D: P("hello"), + }, + }, + }, + }, + QuerySettings{NestedFormat: NestedQueryFormatBrackets}, + }, + + "deeply_nested_dots": { + `a.b.c.d=hello`, + DeeplyNested{ + A: DeeplyNested1{ + B: DeeplyNested2{ + C: DeeplyNested3{ + D: P("hello"), + }, + }, + }, + }, + QuerySettings{NestedFormat: NestedQueryFormatDots}, + }, + + "deeply_nested_brackets_empty": { + ``, + DeeplyNested{ + A: DeeplyNested1{ + B: DeeplyNested2{ + C: DeeplyNested3{ + D: nil, + }, + }, + }, + }, + QuerySettings{NestedFormat: NestedQueryFormatBrackets}, + }, + + "deeply_nested_dots_empty": { + ``, + DeeplyNested{ + A: DeeplyNested1{ + B: DeeplyNested2{ + C: DeeplyNested3{ + D: nil, + }, + }, + }, + }, + QuerySettings{NestedFormat: NestedQueryFormatDots}, + }, +} + +func TestEncode(t *testing.T) { + for name, test := range tests { + t.Run(name, func(t *testing.T) { + values := MarshalWithSettings(test.val, test.settings) + str, _ := url.QueryUnescape(values.Encode()) + if str != test.enc { + t.Fatalf("expected %+#v to serialize to %s but got %s", test.val, test.enc, str) + } + }) + } +} diff --git a/internal/apiquery/tag.go b/internal/apiquery/tag.go new file mode 100644 index 0000000..7ccd739 --- /dev/null +++ b/internal/apiquery/tag.go @@ -0,0 +1,41 @@ +package apiquery + +import ( + "reflect" + "strings" +) + +const queryStructTag = "query" +const formatStructTag = "format" + +type parsedStructTag struct { + name string + omitempty bool + inline bool +} + +func parseQueryStructTag(field reflect.StructField) (tag parsedStructTag, ok bool) { + raw, ok := field.Tag.Lookup(queryStructTag) + if !ok { + return + } + parts := strings.Split(raw, ",") + if len(parts) == 0 { + return tag, false + } + tag.name = parts[0] + for _, part := range parts[1:] { + switch part { + case "omitempty": + tag.omitempty = true + case "inline": + tag.inline = true + } + } + return +} + +func parseFormatStructTag(field reflect.StructField) (format string, ok bool) { + format, ok = field.Tag.Lookup(formatStructTag) + return +} diff --git a/internal/param/field.go b/internal/param/field.go new file mode 100644 index 0000000..4d0fd9c --- /dev/null +++ b/internal/param/field.go @@ -0,0 +1,29 @@ +package param + +import ( + "fmt" +) + +type FieldLike interface{ field() } + +// Field is a wrapper used for all values sent to the API, +// to distinguish zero values from null or omitted fields. +// +// It also allows sending arbitrary deserializable values. 
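+//
+// Present reports whether the field was set at all (unset fields are omitted
+// during serialization), Null requests an explicit null, and Raw carries an
+// arbitrary value that is serialized in place of Value.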
+// +// To instantiate a Field, use the helpers exported from +// the package root: `F()`, `Null()`, `Raw()`, etc. +type Field[T any] struct { + FieldLike + Value T + Null bool + Present bool + Raw any +} + +func (f Field[T]) String() string { + if s, ok := any(f.Value).(fmt.Stringer); ok { + return s.String() + } + return fmt.Sprintf("%v", f.Value) +} diff --git a/internal/requestconfig/requestconfig.go b/internal/requestconfig/requestconfig.go new file mode 100644 index 0000000..b02acfc --- /dev/null +++ b/internal/requestconfig/requestconfig.go @@ -0,0 +1,486 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package requestconfig + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "math" + "math/rand" + "net/http" + "net/url" + "runtime" + "strconv" + "strings" + "time" + + "github.com/anthropics/anthropic-sdk-go/internal" + "github.com/anthropics/anthropic-sdk-go/internal/apierror" + "github.com/anthropics/anthropic-sdk-go/internal/apiform" + "github.com/anthropics/anthropic-sdk-go/internal/apiquery" +) + +func getDefaultHeaders() map[string]string { + return map[string]string{ + "User-Agent": fmt.Sprintf("Anthropic/Go %s", internal.PackageVersion), + } +} + +func getNormalizedOS() string { + switch runtime.GOOS { + case "ios": + return "iOS" + case "android": + return "Android" + case "darwin": + return "MacOS" + case "window": + return "Windows" + case "freebsd": + return "FreeBSD" + case "openbsd": + return "OpenBSD" + case "linux": + return "Linux" + default: + return fmt.Sprintf("Other:%s", runtime.GOOS) + } +} + +func getNormalizedArchitecture() string { + switch runtime.GOARCH { + case "386": + return "x32" + case "amd64": + return "x64" + case "arm": + return "arm" + case "arm64": + return "arm64" + default: + return fmt.Sprintf("other:%s", runtime.GOARCH) + } +} + +func getPlatformProperties() map[string]string { + return map[string]string{ + "X-Stainless-Lang": "go", + "X-Stainless-Package-Version": internal.PackageVersion, + "X-Stainless-OS": getNormalizedOS(), + "X-Stainless-Arch": getNormalizedArchitecture(), + "X-Stainless-Runtime": "go", + "X-Stainless-Runtime-Version": runtime.Version(), + } +} + +func NewRequestConfig(ctx context.Context, method string, u string, body interface{}, dst interface{}, opts ...func(*RequestConfig) error) (*RequestConfig, error) { + var reader io.Reader + + contentType := "application/json" + hasSerializationFunc := false + + if body, ok := body.(json.Marshaler); ok { + content, err := body.MarshalJSON() + if err != nil { + return nil, err + } + reader = bytes.NewBuffer(content) + hasSerializationFunc = true + } + if body, ok := body.(apiform.Marshaler); ok { + var ( + content []byte + err error + ) + content, contentType, err = body.MarshalMultipart() + if err != nil { + return nil, err + } + reader = bytes.NewBuffer(content) + hasSerializationFunc = true + } + if body, ok := body.(apiquery.Queryer); ok { + hasSerializationFunc = true + params := body.URLQuery().Encode() + if params != "" { + u = u + "?" + params + } + } + if body, ok := body.([]byte); ok { + reader = bytes.NewBuffer(body) + hasSerializationFunc = true + } + if body, ok := body.(io.Reader); ok { + reader = body + hasSerializationFunc = true + } + + // Fallback to json serialization if none of the serialization functions that we expect + // to see is present. 
+ if body != nil && !hasSerializationFunc { + content, err := json.Marshal(body) + if err != nil { + return nil, err + } + reader = bytes.NewBuffer(content) + } + + req, err := http.NewRequestWithContext(ctx, method, u, nil) + if err != nil { + return nil, err + } + if reader != nil { + req.Header.Set("Content-Type", contentType) + } + + req.Header.Set("Accept", "application/json") + for k, v := range getDefaultHeaders() { + req.Header.Add(k, v) + } + req.Header.Set("anthropic-version", "2023-06-01") + for k, v := range getPlatformProperties() { + req.Header.Add(k, v) + } + cfg := RequestConfig{ + MaxRetries: 2, + Context: ctx, + Request: req, + HTTPClient: http.DefaultClient, + Body: reader, + } + cfg.ResponseBodyInto = dst + err = cfg.Apply(opts...) + if err != nil { + return nil, err + } + return &cfg, nil +} + +// RequestConfig represents all the state related to one request. +// +// Editing the variables inside RequestConfig directly is unstable api. Prefer +// composing func(\*RequestConfig) error instead if possible. +type RequestConfig struct { + MaxRetries int + RequestTimeout time.Duration + Context context.Context + Request *http.Request + BaseURL *url.URL + HTTPClient *http.Client + Middlewares []middleware + APIKey string + AuthToken string + // If ResponseBodyInto not nil, then we will attempt to deserialize into + // ResponseBodyInto. If Destination is a []byte, then it will return the body as + // is. + ResponseBodyInto interface{} + // ResponseInto copies the \*http.Response of the corresponding request into the + // given address + ResponseInto **http.Response + Body io.Reader +} + +// middleware is exactly the same type as the Middleware type found in the [option] package, +// but it is redeclared here for circular dependency issues. +type middleware = func(*http.Request, middlewareNext) (*http.Response, error) + +// middlewareNext is exactly the same type as the MiddlewareNext type found in the [option] package, +// but it is redeclared here for circular dependency issues. +type middlewareNext = func(*http.Request) (*http.Response, error) + +func applyMiddleware(middleware middleware, next middlewareNext) middlewareNext { + return func(req *http.Request) (res *http.Response, err error) { + return middleware(req, next) + } +} + +func shouldRetry(req *http.Request, res *http.Response) bool { + // If there is no way to recover the Body, then we shouldn't retry. + if req.Body != nil && req.GetBody == nil { + return false + } + + // If there is no response, that indicates that there is a connection error + // so we retry the request. + if res == nil { + return true + } + + // If the header explictly wants a retry behavior, respect that over the + // http status code. + if res.Header.Get("x-should-retry") == "true" { + return true + } + if res.Header.Get("x-should-retry") == "false" { + return false + } + + return res.StatusCode == http.StatusRequestTimeout || + res.StatusCode == http.StatusConflict || + res.StatusCode == http.StatusTooManyRequests || + res.StatusCode >= http.StatusInternalServerError +} + +func parseRetryAfterHeader(resp *http.Response) (time.Duration, bool) { + if resp == nil { + return 0, false + } + + type retryData struct { + header string + units time.Duration + + // custom is used when the regular algorithm failed and is optional. + // the returned duration is used verbatim (units is not applied). 
+ custom func(string) (time.Duration, bool) + } + + nop := func(string) (time.Duration, bool) { return 0, false } + + // the headers are listed in order of preference + retries := []retryData{ + { + header: "Retry-After-Ms", + units: time.Millisecond, + custom: nop, + }, + { + header: "Retry-After", + units: time.Second, + + // retry-after values are expressed in either number of + // seconds or an HTTP-date indicating when to try again + custom: func(ra string) (time.Duration, bool) { + t, err := time.Parse(time.RFC1123, ra) + if err != nil { + return 0, false + } + return time.Until(t), true + }, + }, + } + + for _, retry := range retries { + v := resp.Header.Get(retry.header) + if v == "" { + continue + } + if retryAfter, err := strconv.ParseFloat(v, 64); err == nil { + return time.Duration(retryAfter * float64(retry.units)), true + } + if d, ok := retry.custom(v); ok { + return d, true + } + } + + return 0, false +} + +func retryDelay(res *http.Response, retryCount int) time.Duration { + // If the API asks us to wait a certain amount of time (and it's a reasonable amount), + // just do what it says. + + if retryAfterDelay, ok := parseRetryAfterHeader(res); ok && 0 <= retryAfterDelay && retryAfterDelay < time.Minute { + return retryAfterDelay + } + + maxDelay := 8 * time.Second + delay := time.Duration(0.5 * float64(time.Second) * math.Pow(2, float64(retryCount))) + if delay > maxDelay { + delay = maxDelay + } + + jitter := rand.Int63n(int64(delay / 4)) + delay -= time.Duration(jitter) + return delay +} + +func (cfg *RequestConfig) Execute() (err error) { + cfg.Request.URL, err = cfg.BaseURL.Parse(strings.TrimLeft(cfg.Request.URL.String(), "/")) + if err != nil { + return err + } + + if cfg.Body != nil && cfg.Request.Body == nil { + switch body := cfg.Body.(type) { + case *bytes.Buffer: + b := body.Bytes() + cfg.Request.ContentLength = int64(body.Len()) + cfg.Request.GetBody = func() (io.ReadCloser, error) { return io.NopCloser(bytes.NewReader(b)), nil } + cfg.Request.Body, _ = cfg.Request.GetBody() + case *bytes.Reader: + cfg.Request.ContentLength = int64(body.Len()) + cfg.Request.GetBody = func() (io.ReadCloser, error) { + _, err := body.Seek(0, 0) + return io.NopCloser(body), err + } + cfg.Request.Body, _ = cfg.Request.GetBody() + default: + if rc, ok := body.(io.ReadCloser); ok { + cfg.Request.Body = rc + } else { + cfg.Request.Body = io.NopCloser(body) + } + } + } + + handler := cfg.HTTPClient.Do + for i := len(cfg.Middlewares) - 1; i >= 0; i -= 1 { + handler = applyMiddleware(cfg.Middlewares[i], handler) + } + + var res *http.Response + for retryCount := 0; retryCount <= cfg.MaxRetries; retryCount += 1 { + ctx := cfg.Request.Context() + if cfg.RequestTimeout != time.Duration(0) { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, cfg.RequestTimeout) + defer cancel() + } + + res, err = handler(cfg.Request.Clone(ctx)) + if ctx != nil && ctx.Err() != nil { + return ctx.Err() + } + if !shouldRetry(cfg.Request, res) || retryCount >= cfg.MaxRetries { + break + } + + // Prepare next request and wait for the retry delay + if cfg.Request.GetBody != nil { + cfg.Request.Body, err = cfg.Request.GetBody() + if err != nil { + return err + } + } + + // Can't actually refresh the body, so we don't attempt to retry here + if cfg.Request.GetBody == nil && cfg.Request.Body != nil { + break + } + + time.Sleep(retryDelay(res, retryCount)) + } + + // Save *http.Response if it is requested to, even if there was an error making the request. 
This is + // useful in cases where you might want to debug by inspecting the response. Note that if err != nil, + // the response should be generally be empty, but there are edge cases. + if cfg.ResponseInto != nil { + *cfg.ResponseInto = res + } + if responseBodyInto, ok := cfg.ResponseBodyInto.(**http.Response); ok { + *responseBodyInto = res + } + + // If there was a connection error in the final request or any other transport error, + // return that early without trying to coerce into an APIError. + if err != nil { + return err + } + + if res.StatusCode >= 400 { + contents, err := io.ReadAll(res.Body) + res.Body.Close() + if err != nil { + return err + } + + // If there is an APIError, re-populate the response body so that debugging + // utilities can conveniently dump the response without issue. + res.Body = io.NopCloser(bytes.NewBuffer(contents)) + + // Load the contents into the error format if it is provided. + aerr := apierror.Error{Request: cfg.Request, Response: res, StatusCode: res.StatusCode} + err = aerr.UnmarshalJSON(contents) + if err != nil { + return err + } + return &aerr + } + + if cfg.ResponseBodyInto == nil { + return nil + } + if _, ok := cfg.ResponseBodyInto.(**http.Response); ok { + return nil + } + + contents, err := io.ReadAll(res.Body) + if err != nil { + return fmt.Errorf("error reading response body: %w", err) + } + + // If we are not json, return plaintext + contentType := res.Header.Get("content-type") + isJSON := strings.Contains(contentType, "application/json") || strings.Contains(contentType, "application/vnd.api+json") + if !isJSON { + switch dst := cfg.ResponseBodyInto.(type) { + case *string: + *dst = string(contents) + case **string: + tmp := string(contents) + *dst = &tmp + case *[]byte: + *dst = contents + default: + return fmt.Errorf("expected destination type of 'string' or '[]byte' for responses with content-type that is not 'application/json'") + } + return nil + } + + // If the response happens to be a byte array, deserialize the body as-is. + switch dst := cfg.ResponseBodyInto.(type) { + case *[]byte: + *dst = contents + } + + err = json.NewDecoder(bytes.NewReader(contents)).Decode(cfg.ResponseBodyInto) + if err != nil { + err = fmt.Errorf("error parsing response json: %w", err) + } + + return nil +} + +func ExecuteNewRequest(ctx context.Context, method string, u string, body interface{}, dst interface{}, opts ...func(*RequestConfig) error) error { + cfg, err := NewRequestConfig(ctx, method, u, body, dst, opts...) 
+ if err != nil { + return err + } + return cfg.Execute() +} + +func (cfg *RequestConfig) Clone(ctx context.Context) *RequestConfig { + if cfg == nil { + return nil + } + req := cfg.Request.Clone(ctx) + var err error + if req.Body != nil { + req.Body, err = req.GetBody() + } + if err != nil { + return nil + } + new := &RequestConfig{ + MaxRetries: cfg.MaxRetries, + Context: ctx, + Request: req, + HTTPClient: cfg.HTTPClient, + } + + return new +} + +func (cfg *RequestConfig) Apply(opts ...func(*RequestConfig) error) error { + for _, opt := range opts { + err := opt(cfg) + if err != nil { + return err + } + } + return nil +} diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go new file mode 100644 index 0000000..826d266 --- /dev/null +++ b/internal/testutil/testutil.go @@ -0,0 +1,27 @@ +package testutil + +import ( + "net/http" + "os" + "strconv" + "testing" +) + +func CheckTestServer(t *testing.T, url string) bool { + if _, err := http.Get(url); err != nil { + const SKIP_MOCK_TESTS = "SKIP_MOCK_TESTS" + if str, ok := os.LookupEnv(SKIP_MOCK_TESTS); ok { + skip, err := strconv.ParseBool(str) + if err != nil { + t.Fatalf("strconv.ParseBool(os.LookupEnv(%s)) failed: %s", SKIP_MOCK_TESTS, err) + } + if skip { + t.Skip("The test will not run without a mock Prism server running against your OpenAPI spec") + return false + } + t.Errorf("The test will not run without a mock Prism server running against your OpenAPI spec. You can set the environment variable %s to true to skip running any tests that require the mock server", SKIP_MOCK_TESTS) + return false + } + } + return true +} diff --git a/internal/version.go b/internal/version.go new file mode 100644 index 0000000..4ff68e4 --- /dev/null +++ b/internal/version.go @@ -0,0 +1,5 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package internal + +const PackageVersion = "0.0.1-alpha.0" // x-release-please-version diff --git a/lib/.keep b/lib/.keep new file mode 100644 index 0000000..5e2c99f --- /dev/null +++ b/lib/.keep @@ -0,0 +1,4 @@ +File generated from our OpenAPI spec by Stainless. + +This directory can be used to store custom files to expand the SDK. +It is ignored by Stainless code generation and its content (other than this keep file) won't be touched. \ No newline at end of file diff --git a/message.go b/message.go new file mode 100644 index 0000000..c0be88e --- /dev/null +++ b/message.go @@ -0,0 +1,1788 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package anthropic + +import ( + "context" + "net/http" + "reflect" + + "github.com/anthropics/anthropic-sdk-go/internal/apijson" + "github.com/anthropics/anthropic-sdk-go/internal/param" + "github.com/anthropics/anthropic-sdk-go/internal/requestconfig" + "github.com/anthropics/anthropic-sdk-go/option" + "github.com/anthropics/anthropic-sdk-go/packages/ssestream" + "github.com/tidwall/gjson" +) + +// MessageService contains methods and other services that help with interacting +// with the anthropic API. +// +// Note, unlike clients, this service does not read variables from the environment +// automatically. You should not instantiate this service directly, and instead use +// the [NewMessageService] method instead. +type MessageService struct { + Options []option.RequestOption +} + +// NewMessageService generates a new service that applies the given options to each +// request. 
These options are applied after the parent client's options (if there +// is one), and before any request-specific options. +func NewMessageService(opts ...option.RequestOption) (r *MessageService) { + r = &MessageService{} + r.Options = opts + return +} + +// Create a Message. +// +// Send a structured list of input messages with text and/or image content, and the +// model will generate the next message in the conversation. +// +// The Messages API can be used for either single queries or stateless multi-turn +// conversations. +// +// Note: If you choose to set a timeout for this request, we recommend 10 minutes. +func (r *MessageService) New(ctx context.Context, body MessageNewParams, opts ...option.RequestOption) (res *Message, err error) { + opts = append(r.Options[:], opts...) + path := "v1/messages" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) + return +} + +// Create a Message. +// +// Send a structured list of input messages with text and/or image content, and the +// model will generate the next message in the conversation. +// +// The Messages API can be used for either single queries or stateless multi-turn +// conversations. +// +// Note: If you choose to set a timeout for this request, we recommend 10 minutes. +func (r *MessageService) NewStreaming(ctx context.Context, body MessageNewParams, opts ...option.RequestOption) (stream *ssestream.Stream[MessageStreamEvent]) { + var ( + raw *http.Response + err error + ) + opts = append(r.Options[:], opts...) + opts = append([]option.RequestOption{option.WithJSONSet("stream", true)}, opts...) + path := "v1/messages" + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &raw, opts...) + return ssestream.NewStream[MessageStreamEvent](ssestream.NewDecoder(raw), err) +} + +type ContentBlock struct { + Type ContentBlockType `json:"type,required"` + Text string `json:"text"` + ID string `json:"id"` + Name string `json:"name"` + // This field can have the runtime type of [interface{}]. + Input interface{} `json:"input,required"` + JSON contentBlockJSON `json:"-"` + union ContentBlockUnion +} + +// contentBlockJSON contains the JSON metadata for the struct [ContentBlock] +type contentBlockJSON struct { + Type apijson.Field + Text apijson.Field + ID apijson.Field + Name apijson.Field + Input apijson.Field + raw string + ExtraFields map[string]apijson.Field +} + +func (r contentBlockJSON) RawJSON() string { + return r.raw +} + +func (r *ContentBlock) UnmarshalJSON(data []byte) (err error) { + *r = ContentBlock{} + err = apijson.UnmarshalRoot(data, &r.union) + if err != nil { + return err + } + return apijson.Port(r.union, &r) +} + +// AsUnion returns a [ContentBlockUnion] interface which you can cast to the +// specific types for more type safety. +// +// Possible runtime types of the union are [TextBlock], [ToolUseBlock]. +func (r ContentBlock) AsUnion() ContentBlockUnion { + return r.union +} + +// Union satisfied by [TextBlock] or [ToolUseBlock]. 
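+//
+// A minimal sketch of switching on the concrete variant; the field names used
+// below follow the [TextBlock] and [ToolUseBlock] definitions elsewhere in
+// this package:
+//
+//	for _, block := range message.Content {
+//		switch variant := block.AsUnion().(type) {
+//		case TextBlock:
+//			fmt.Println(variant.Text)
+//		case ToolUseBlock:
+//			fmt.Println(variant.Name, variant.Input)
+//		}
+//	}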
+type ContentBlockUnion interface { + implementsContentBlock() +} + +func init() { + apijson.RegisterUnion( + reflect.TypeOf((*ContentBlockUnion)(nil)).Elem(), + "type", + apijson.UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(TextBlock{}), + DiscriminatorValue: "text", + }, + apijson.UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(ToolUseBlock{}), + DiscriminatorValue: "tool_use", + }, + ) +} + +type ContentBlockType string + +const ( + ContentBlockTypeText ContentBlockType = "text" + ContentBlockTypeToolUse ContentBlockType = "tool_use" +) + +func (r ContentBlockType) IsKnown() bool { + switch r { + case ContentBlockTypeText, ContentBlockTypeToolUse: + return true + } + return false +} + +type ImageBlockParam struct { + Source param.Field[ImageBlockParamSource] `json:"source,required"` + Type param.Field[ImageBlockParamType] `json:"type,required"` +} + +func (r ImageBlockParam) MarshalJSON() (data []byte, err error) { + return apijson.MarshalRoot(r) +} + +func (r ImageBlockParam) implementsMessageParamContentUnion() {} + +func (r ImageBlockParam) implementsToolResultBlockParamContentUnion() {} + +type ImageBlockParamSource struct { + Data param.Field[string] `json:"data,required" format:"byte"` + MediaType param.Field[ImageBlockParamSourceMediaType] `json:"media_type,required"` + Type param.Field[ImageBlockParamSourceType] `json:"type,required"` +} + +func (r ImageBlockParamSource) MarshalJSON() (data []byte, err error) { + return apijson.MarshalRoot(r) +} + +type ImageBlockParamSourceMediaType string + +const ( + ImageBlockParamSourceMediaTypeImageJPEG ImageBlockParamSourceMediaType = "image/jpeg" + ImageBlockParamSourceMediaTypeImagePNG ImageBlockParamSourceMediaType = "image/png" + ImageBlockParamSourceMediaTypeImageGIF ImageBlockParamSourceMediaType = "image/gif" + ImageBlockParamSourceMediaTypeImageWebP ImageBlockParamSourceMediaType = "image/webp" +) + +func (r ImageBlockParamSourceMediaType) IsKnown() bool { + switch r { + case ImageBlockParamSourceMediaTypeImageJPEG, ImageBlockParamSourceMediaTypeImagePNG, ImageBlockParamSourceMediaTypeImageGIF, ImageBlockParamSourceMediaTypeImageWebP: + return true + } + return false +} + +type ImageBlockParamSourceType string + +const ( + ImageBlockParamSourceTypeBase64 ImageBlockParamSourceType = "base64" +) + +func (r ImageBlockParamSourceType) IsKnown() bool { + switch r { + case ImageBlockParamSourceTypeBase64: + return true + } + return false +} + +type ImageBlockParamType string + +const ( + ImageBlockParamTypeImage ImageBlockParamType = "image" +) + +func (r ImageBlockParamType) IsKnown() bool { + switch r { + case ImageBlockParamTypeImage: + return true + } + return false +} + +type InputJSONDelta struct { + PartialJSON string `json:"partial_json,required"` + Type InputJSONDeltaType `json:"type,required"` + JSON inputJSONDeltaJSON `json:"-"` +} + +// inputJSONDeltaJSON contains the JSON metadata for the struct [InputJSONDelta] +type inputJSONDeltaJSON struct { + PartialJSON apijson.Field + Type apijson.Field + raw string + ExtraFields map[string]apijson.Field +} + +func (r *InputJSONDelta) UnmarshalJSON(data []byte) (err error) { + return apijson.UnmarshalRoot(data, r) +} + +func (r inputJSONDeltaJSON) RawJSON() string { + return r.raw +} + +func (r InputJSONDelta) implementsContentBlockDeltaEventDelta() {} + +type InputJSONDeltaType string + +const ( + InputJSONDeltaTypeInputJSONDelta InputJSONDeltaType = "input_json_delta" +) + +func (r InputJSONDeltaType) IsKnown() bool { + switch r { + case 
InputJSONDeltaTypeInputJSONDelta: + return true + } + return false +} + +type Message struct { + // Unique object identifier. + // + // The format and length of IDs may change over time. + ID string `json:"id,required"` + // Content generated by the model. + // + // This is an array of content blocks, each of which has a `type` that determines + // its shape. + // + // Example: + // + // ```json + // [{ "type": "text", "text": "Hi, I'm Claude." }] + // ``` + // + // If the request input `messages` ended with an `assistant` turn, then the + // response `content` will continue directly from that last turn. You can use this + // to constrain the model's output. + // + // For example, if the input `messages` were: + // + // ```json + // [ + // + // { + // "role": "user", + // "content": "What's the Greek name for Sun? (A) Sol (B) Helios (C) Sun" + // }, + // { "role": "assistant", "content": "The best answer is (" } + // + // ] + // ``` + // + // Then the response `content` might be: + // + // ```json + // [{ "type": "text", "text": "B)" }] + // ``` + Content []ContentBlock `json:"content,required"` + // The model that will complete your prompt.\n\nSee + // [models](https://docs.anthropic.com/en/docs/models-overview) for additional + // details and options. + Model Model `json:"model,required"` + // Conversational role of the generated message. + // + // This will always be `"assistant"`. + Role MessageRole `json:"role,required"` + // The reason that we stopped. + // + // This may be one the following values: + // + // - `"end_turn"`: the model reached a natural stopping point + // - `"max_tokens"`: we exceeded the requested `max_tokens` or the model's maximum + // - `"stop_sequence"`: one of your provided custom `stop_sequences` was generated + // - `"tool_use"`: the model invoked one or more tools + // + // In non-streaming mode this value is always non-null. In streaming mode, it is + // null in the `message_start` event and non-null otherwise. + StopReason MessageStopReason `json:"stop_reason,required,nullable"` + // Which custom stop sequence was generated, if any. + // + // This value will be a non-null string if one of your custom stop sequences was + // generated. + StopSequence string `json:"stop_sequence,required,nullable"` + // Object type. + // + // For Messages, this is always `"message"`. + Type MessageType `json:"type,required"` + // Billing and rate-limit usage. + // + // Anthropic's API bills and rate-limits by token counts, as tokens represent the + // underlying cost to our systems. + // + // Under the hood, the API transforms requests into a format suitable for the + // model. The model's output then goes through a parsing stage before becoming an + // API response. As a result, the token counts in `usage` will not match one-to-one + // with the exact visible content of an API request or response. + // + // For example, `output_tokens` will be non-zero, even for an empty string response + // from Claude. 
+ Usage Usage `json:"usage,required"` + JSON messageJSON `json:"-"` +} + +// messageJSON contains the JSON metadata for the struct [Message] +type messageJSON struct { + ID apijson.Field + Content apijson.Field + Model apijson.Field + Role apijson.Field + StopReason apijson.Field + StopSequence apijson.Field + Type apijson.Field + Usage apijson.Field + raw string + ExtraFields map[string]apijson.Field +} + +func (r *Message) UnmarshalJSON(data []byte) (err error) { + return apijson.UnmarshalRoot(data, r) +} + +func (r messageJSON) RawJSON() string { + return r.raw +} + +// Conversational role of the generated message. +// +// This will always be `"assistant"`. +type MessageRole string + +const ( + MessageRoleAssistant MessageRole = "assistant" +) + +func (r MessageRole) IsKnown() bool { + switch r { + case MessageRoleAssistant: + return true + } + return false +} + +// The reason that we stopped. +// +// This may be one of the following values: +// +// - `"end_turn"`: the model reached a natural stopping point +// - `"max_tokens"`: we exceeded the requested `max_tokens` or the model's maximum +// - `"stop_sequence"`: one of your provided custom `stop_sequences` was generated +// - `"tool_use"`: the model invoked one or more tools +// +// In non-streaming mode this value is always non-null. In streaming mode, it is +// null in the `message_start` event and non-null otherwise. +type MessageStopReason string + +const ( + MessageStopReasonEndTurn MessageStopReason = "end_turn" + MessageStopReasonMaxTokens MessageStopReason = "max_tokens" + MessageStopReasonStopSequence MessageStopReason = "stop_sequence" + MessageStopReasonToolUse MessageStopReason = "tool_use" +) + +func (r MessageStopReason) IsKnown() bool { + switch r { + case MessageStopReasonEndTurn, MessageStopReasonMaxTokens, MessageStopReasonStopSequence, MessageStopReasonToolUse: + return true + } + return false +} + +// Object type. +// +// For Messages, this is always `"message"`. +type MessageType string + +const ( + MessageTypeMessage MessageType = "message" +) + +func (r MessageType) IsKnown() bool { + switch r { + case MessageTypeMessage: + return true + } + return false +} + +type MessageDeltaUsage struct { + // The cumulative number of output tokens which were used.
+ OutputTokens int64 `json:"output_tokens,required"` + JSON messageDeltaUsageJSON `json:"-"` +} + +// messageDeltaUsageJSON contains the JSON metadata for the struct +// [MessageDeltaUsage] +type messageDeltaUsageJSON struct { + OutputTokens apijson.Field + raw string + ExtraFields map[string]apijson.Field +} + +func (r *MessageDeltaUsage) UnmarshalJSON(data []byte) (err error) { + return apijson.UnmarshalRoot(data, r) +} + +func (r messageDeltaUsageJSON) RawJSON() string { + return r.raw +} + +type MessageParam struct { + Content param.Field[[]MessageParamContentUnion] `json:"content,required"` + Role param.Field[MessageParamRole] `json:"role,required"` +} + +func (r MessageParam) MarshalJSON() (data []byte, err error) { + return apijson.MarshalRoot(r) +} + +type MessageParamContent struct { + Type param.Field[MessageParamContentType] `json:"type,required"` + Text param.Field[string] `json:"text"` + Source param.Field[interface{}] `json:"source,required"` + ID param.Field[string] `json:"id"` + Name param.Field[string] `json:"name"` + Input param.Field[interface{}] `json:"input,required"` + ToolUseID param.Field[string] `json:"tool_use_id"` + IsError param.Field[bool] `json:"is_error"` + Content param.Field[interface{}] `json:"content,required"` +} + +func (r MessageParamContent) MarshalJSON() (data []byte, err error) { + return apijson.MarshalRoot(r) +} + +func (r MessageParamContent) implementsMessageParamContentUnion() {} + +// Satisfied by [TextBlockParam], [ImageBlockParam], [ToolUseBlockParam], +// [ToolResultBlockParam], [MessageParamContent]. +type MessageParamContentUnion interface { + implementsMessageParamContentUnion() +} + +type MessageParamContentType string + +const ( + MessageParamContentTypeText MessageParamContentType = "text" + MessageParamContentTypeImage MessageParamContentType = "image" + MessageParamContentTypeToolUse MessageParamContentType = "tool_use" + MessageParamContentTypeToolResult MessageParamContentType = "tool_result" +) + +func (r MessageParamContentType) IsKnown() bool { + switch r { + case MessageParamContentTypeText, MessageParamContentTypeImage, MessageParamContentTypeToolUse, MessageParamContentTypeToolResult: + return true + } + return false +} + +type MessageParamRole string + +const ( + MessageParamRoleUser MessageParamRole = "user" + MessageParamRoleAssistant MessageParamRole = "assistant" +) + +func (r MessageParamRole) IsKnown() bool { + switch r { + case MessageParamRoleUser, MessageParamRoleAssistant: + return true + } + return false +} + +type Model = string + +const ( + // Our most intelligent model + ModelClaude_3_5_Sonnet_20240620 Model = "claude-3-5-sonnet-20240620" + // Excels at writing and complex tasks + ModelClaude_3_Opus_20240229 Model = "claude-3-opus-20240229" + // Balance of speed and intelligence + ModelClaude_3_Sonnet_20240229 Model = "claude-3-sonnet-20240229" + // Fast and cost-effective + ModelClaude_3_Haiku_20240307 Model = "claude-3-haiku-20240307" + ModelClaude_2_1 Model = "claude-2.1" + ModelClaude_2_0 Model = "claude-2.0" + ModelClaude_Instant_1_2 Model = "claude-instant-1.2" +) + +type ContentBlockDeltaEvent struct { + Delta ContentBlockDeltaEventDelta `json:"delta,required"` + Index int64 `json:"index,required"` + Type ContentBlockDeltaEventType `json:"type,required"` + JSON contentBlockDeltaEventJSON `json:"-"` +} + +// contentBlockDeltaEventJSON contains the JSON metadata for the struct +// [ContentBlockDeltaEvent] +type contentBlockDeltaEventJSON struct { + Delta apijson.Field + Index apijson.Field + Type 
apijson.Field + raw string + ExtraFields map[string]apijson.Field +} + +func (r *ContentBlockDeltaEvent) UnmarshalJSON(data []byte) (err error) { + return apijson.UnmarshalRoot(data, r) +} + +func (r contentBlockDeltaEventJSON) RawJSON() string { + return r.raw +} + +func (r ContentBlockDeltaEvent) implementsMessageStreamEvent() {} + +type ContentBlockDeltaEventDelta struct { + Type ContentBlockDeltaEventDeltaType `json:"type,required"` + Text string `json:"text"` + PartialJSON string `json:"partial_json"` + JSON contentBlockDeltaEventDeltaJSON `json:"-"` + union ContentBlockDeltaEventDeltaUnion +} + +// contentBlockDeltaEventDeltaJSON contains the JSON metadata for the struct +// [ContentBlockDeltaEventDelta] +type contentBlockDeltaEventDeltaJSON struct { + Type apijson.Field + Text apijson.Field + PartialJSON apijson.Field + raw string + ExtraFields map[string]apijson.Field +} + +func (r contentBlockDeltaEventDeltaJSON) RawJSON() string { + return r.raw +} + +func (r *ContentBlockDeltaEventDelta) UnmarshalJSON(data []byte) (err error) { + *r = ContentBlockDeltaEventDelta{} + err = apijson.UnmarshalRoot(data, &r.union) + if err != nil { + return err + } + return apijson.Port(r.union, &r) +} + +// AsUnion returns a [ContentBlockDeltaEventDeltaUnion] interface which you can +// cast to the specific types for more type safety. +// +// Possible runtime types of the union are [TextDelta], [InputJSONDelta]. +func (r ContentBlockDeltaEventDelta) AsUnion() ContentBlockDeltaEventDeltaUnion { + return r.union +} + +// Union satisfied by [TextDelta] or [InputJSONDelta]. +type ContentBlockDeltaEventDeltaUnion interface { + implementsContentBlockDeltaEventDelta() +} + +func init() { + apijson.RegisterUnion( + reflect.TypeOf((*ContentBlockDeltaEventDeltaUnion)(nil)).Elem(), + "type", + apijson.UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(TextDelta{}), + DiscriminatorValue: "text_delta", + }, + apijson.UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(InputJSONDelta{}), + DiscriminatorValue: "input_json_delta", + }, + ) +} + +type ContentBlockDeltaEventDeltaType string + +const ( + ContentBlockDeltaEventDeltaTypeTextDelta ContentBlockDeltaEventDeltaType = "text_delta" + ContentBlockDeltaEventDeltaTypeInputJSONDelta ContentBlockDeltaEventDeltaType = "input_json_delta" +) + +func (r ContentBlockDeltaEventDeltaType) IsKnown() bool { + switch r { + case ContentBlockDeltaEventDeltaTypeTextDelta, ContentBlockDeltaEventDeltaTypeInputJSONDelta: + return true + } + return false +} + +type ContentBlockDeltaEventType string + +const ( + ContentBlockDeltaEventTypeContentBlockDelta ContentBlockDeltaEventType = "content_block_delta" +) + +func (r ContentBlockDeltaEventType) IsKnown() bool { + switch r { + case ContentBlockDeltaEventTypeContentBlockDelta: + return true + } + return false +} + +type ContentBlockStartEvent struct { + ContentBlock ContentBlockStartEventContentBlock `json:"content_block,required"` + Index int64 `json:"index,required"` + Type ContentBlockStartEventType `json:"type,required"` + JSON contentBlockStartEventJSON `json:"-"` +} + +// contentBlockStartEventJSON contains the JSON metadata for the struct +// [ContentBlockStartEvent] +type contentBlockStartEventJSON struct { + ContentBlock apijson.Field + Index apijson.Field + Type apijson.Field + raw string + ExtraFields map[string]apijson.Field +} + +func (r *ContentBlockStartEvent) UnmarshalJSON(data []byte) (err error) { + return apijson.UnmarshalRoot(data, r) +} + +func (r contentBlockStartEventJSON) RawJSON() 
string { + return r.raw +} + +func (r ContentBlockStartEvent) implementsMessageStreamEvent() {} + +type ContentBlockStartEventContentBlock struct { + Type ContentBlockStartEventContentBlockType `json:"type,required"` + Text string `json:"text"` + ID string `json:"id"` + Name string `json:"name"` + // This field can have the runtime type of [interface{}]. + Input interface{} `json:"input,required"` + JSON contentBlockStartEventContentBlockJSON `json:"-"` + union ContentBlockStartEventContentBlockUnion +} + +// contentBlockStartEventContentBlockJSON contains the JSON metadata for the struct +// [ContentBlockStartEventContentBlock] +type contentBlockStartEventContentBlockJSON struct { + Type apijson.Field + Text apijson.Field + ID apijson.Field + Name apijson.Field + Input apijson.Field + raw string + ExtraFields map[string]apijson.Field +} + +func (r contentBlockStartEventContentBlockJSON) RawJSON() string { + return r.raw +} + +func (r *ContentBlockStartEventContentBlock) UnmarshalJSON(data []byte) (err error) { + *r = ContentBlockStartEventContentBlock{} + err = apijson.UnmarshalRoot(data, &r.union) + if err != nil { + return err + } + return apijson.Port(r.union, &r) +} + +// AsUnion returns a [ContentBlockStartEventContentBlockUnion] interface which you +// can cast to the specific types for more type safety. +// +// Possible runtime types of the union are [TextBlock], [ToolUseBlock]. +func (r ContentBlockStartEventContentBlock) AsUnion() ContentBlockStartEventContentBlockUnion { + return r.union +} + +// Union satisfied by [TextBlock] or [ToolUseBlock]. +type ContentBlockStartEventContentBlockUnion interface { + implementsContentBlockStartEventContentBlock() +} + +func init() { + apijson.RegisterUnion( + reflect.TypeOf((*ContentBlockStartEventContentBlockUnion)(nil)).Elem(), + "type", + apijson.UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(TextBlock{}), + DiscriminatorValue: "text", + }, + apijson.UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(ToolUseBlock{}), + DiscriminatorValue: "tool_use", + }, + ) +} + +type ContentBlockStartEventContentBlockType string + +const ( + ContentBlockStartEventContentBlockTypeText ContentBlockStartEventContentBlockType = "text" + ContentBlockStartEventContentBlockTypeToolUse ContentBlockStartEventContentBlockType = "tool_use" +) + +func (r ContentBlockStartEventContentBlockType) IsKnown() bool { + switch r { + case ContentBlockStartEventContentBlockTypeText, ContentBlockStartEventContentBlockTypeToolUse: + return true + } + return false +} + +type ContentBlockStartEventType string + +const ( + ContentBlockStartEventTypeContentBlockStart ContentBlockStartEventType = "content_block_start" +) + +func (r ContentBlockStartEventType) IsKnown() bool { + switch r { + case ContentBlockStartEventTypeContentBlockStart: + return true + } + return false +} + +type ContentBlockStopEvent struct { + Index int64 `json:"index,required"` + Type ContentBlockStopEventType `json:"type,required"` + JSON contentBlockStopEventJSON `json:"-"` +} + +// contentBlockStopEventJSON contains the JSON metadata for the struct +// [ContentBlockStopEvent] +type contentBlockStopEventJSON struct { + Index apijson.Field + Type apijson.Field + raw string + ExtraFields map[string]apijson.Field +} + +func (r *ContentBlockStopEvent) UnmarshalJSON(data []byte) (err error) { + return apijson.UnmarshalRoot(data, r) +} + +func (r contentBlockStopEventJSON) RawJSON() string { + return r.raw +} + +func (r ContentBlockStopEvent) implementsMessageStreamEvent() {} + +type 
ContentBlockStopEventType string + +const ( + ContentBlockStopEventTypeContentBlockStop ContentBlockStopEventType = "content_block_stop" +) + +func (r ContentBlockStopEventType) IsKnown() bool { + switch r { + case ContentBlockStopEventTypeContentBlockStop: + return true + } + return false +} + +type MessageDeltaEvent struct { + Delta MessageDeltaEventDelta `json:"delta,required"` + Type MessageDeltaEventType `json:"type,required"` + // Billing and rate-limit usage. + // + // Anthropic's API bills and rate-limits by token counts, as tokens represent the + // underlying cost to our systems. + // + // Under the hood, the API transforms requests into a format suitable for the + // model. The model's output then goes through a parsing stage before becoming an + // API response. As a result, the token counts in `usage` will not match one-to-one + // with the exact visible content of an API request or response. + // + // For example, `output_tokens` will be non-zero, even for an empty string response + // from Claude. + Usage MessageDeltaUsage `json:"usage,required"` + JSON messageDeltaEventJSON `json:"-"` +} + +// messageDeltaEventJSON contains the JSON metadata for the struct +// [MessageDeltaEvent] +type messageDeltaEventJSON struct { + Delta apijson.Field + Type apijson.Field + Usage apijson.Field + raw string + ExtraFields map[string]apijson.Field +} + +func (r *MessageDeltaEvent) UnmarshalJSON(data []byte) (err error) { + return apijson.UnmarshalRoot(data, r) +} + +func (r messageDeltaEventJSON) RawJSON() string { + return r.raw +} + +func (r MessageDeltaEvent) implementsMessageStreamEvent() {} + +type MessageDeltaEventDelta struct { + StopReason MessageDeltaEventDeltaStopReason `json:"stop_reason,required,nullable"` + StopSequence string `json:"stop_sequence,required,nullable"` + JSON messageDeltaEventDeltaJSON `json:"-"` +} + +// messageDeltaEventDeltaJSON contains the JSON metadata for the struct +// [MessageDeltaEventDelta] +type messageDeltaEventDeltaJSON struct { + StopReason apijson.Field + StopSequence apijson.Field + raw string + ExtraFields map[string]apijson.Field +} + +func (r *MessageDeltaEventDelta) UnmarshalJSON(data []byte) (err error) { + return apijson.UnmarshalRoot(data, r) +} + +func (r messageDeltaEventDeltaJSON) RawJSON() string { + return r.raw +} + +type MessageDeltaEventDeltaStopReason string + +const ( + MessageDeltaEventDeltaStopReasonEndTurn MessageDeltaEventDeltaStopReason = "end_turn" + MessageDeltaEventDeltaStopReasonMaxTokens MessageDeltaEventDeltaStopReason = "max_tokens" + MessageDeltaEventDeltaStopReasonStopSequence MessageDeltaEventDeltaStopReason = "stop_sequence" + MessageDeltaEventDeltaStopReasonToolUse MessageDeltaEventDeltaStopReason = "tool_use" +) + +func (r MessageDeltaEventDeltaStopReason) IsKnown() bool { + switch r { + case MessageDeltaEventDeltaStopReasonEndTurn, MessageDeltaEventDeltaStopReasonMaxTokens, MessageDeltaEventDeltaStopReasonStopSequence, MessageDeltaEventDeltaStopReasonToolUse: + return true + } + return false +} + +type MessageDeltaEventType string + +const ( + MessageDeltaEventTypeMessageDelta MessageDeltaEventType = "message_delta" +) + +func (r MessageDeltaEventType) IsKnown() bool { + switch r { + case MessageDeltaEventTypeMessageDelta: + return true + } + return false +} + +type MessageStartEvent struct { + Message Message `json:"message,required"` + Type MessageStartEventType `json:"type,required"` + JSON messageStartEventJSON `json:"-"` +} + +// messageStartEventJSON contains the JSON metadata for the struct +// 
[MessageStartEvent] +type messageStartEventJSON struct { + Message apijson.Field + Type apijson.Field + raw string + ExtraFields map[string]apijson.Field +} + +func (r *MessageStartEvent) UnmarshalJSON(data []byte) (err error) { + return apijson.UnmarshalRoot(data, r) +} + +func (r messageStartEventJSON) RawJSON() string { + return r.raw +} + +func (r MessageStartEvent) implementsMessageStreamEvent() {} + +type MessageStartEventType string + +const ( + MessageStartEventTypeMessageStart MessageStartEventType = "message_start" +) + +func (r MessageStartEventType) IsKnown() bool { + switch r { + case MessageStartEventTypeMessageStart: + return true + } + return false +} + +type MessageStopEvent struct { + Type MessageStopEventType `json:"type,required"` + JSON messageStopEventJSON `json:"-"` +} + +// messageStopEventJSON contains the JSON metadata for the struct +// [MessageStopEvent] +type messageStopEventJSON struct { + Type apijson.Field + raw string + ExtraFields map[string]apijson.Field +} + +func (r *MessageStopEvent) UnmarshalJSON(data []byte) (err error) { + return apijson.UnmarshalRoot(data, r) +} + +func (r messageStopEventJSON) RawJSON() string { + return r.raw +} + +func (r MessageStopEvent) implementsMessageStreamEvent() {} + +type MessageStopEventType string + +const ( + MessageStopEventTypeMessageStop MessageStopEventType = "message_stop" +) + +func (r MessageStopEventType) IsKnown() bool { + switch r { + case MessageStopEventTypeMessageStop: + return true + } + return false +} + +type MessageStreamEvent struct { + Type MessageStreamEventType `json:"type,required"` + Message Message `json:"message"` + // This field can have the runtime type of [MessageDeltaEventDelta], + // [ContentBlockDeltaEventDelta]. + Delta interface{} `json:"delta,required"` + // Billing and rate-limit usage. + // + // Anthropic's API bills and rate-limits by token counts, as tokens represent the + // underlying cost to our systems. + // + // Under the hood, the API transforms requests into a format suitable for the + // model. The model's output then goes through a parsing stage before becoming an + // API response. As a result, the token counts in `usage` will not match one-to-one + // with the exact visible content of an API request or response. + // + // For example, `output_tokens` will be non-zero, even for an empty string response + // from Claude. + Usage MessageDeltaUsage `json:"usage"` + Index int64 `json:"index"` + // This field can have the runtime type of [ContentBlockStartEventContentBlock]. + ContentBlock interface{} `json:"content_block,required"` + JSON messageStreamEventJSON `json:"-"` + union MessageStreamEventUnion +} + +// messageStreamEventJSON contains the JSON metadata for the struct +// [MessageStreamEvent] +type messageStreamEventJSON struct { + Type apijson.Field + Message apijson.Field + Delta apijson.Field + Usage apijson.Field + Index apijson.Field + ContentBlock apijson.Field + raw string + ExtraFields map[string]apijson.Field +} + +func (r messageStreamEventJSON) RawJSON() string { + return r.raw +} + +func (r *MessageStreamEvent) UnmarshalJSON(data []byte) (err error) { + *r = MessageStreamEvent{} + err = apijson.UnmarshalRoot(data, &r.union) + if err != nil { + return err + } + return apijson.Port(r.union, &r) +} + +// AsUnion returns a [MessageStreamEventUnion] interface which you can cast to the +// specific types for more type safety. 
+// +// Possible runtime types of the union are [MessageStartEvent], +// [MessageDeltaEvent], [MessageStopEvent], [ContentBlockStartEvent], +// [ContentBlockDeltaEvent], [ContentBlockStopEvent]. +func (r MessageStreamEvent) AsUnion() MessageStreamEventUnion { + return r.union +} + +// Union satisfied by [MessageStartEvent], [MessageDeltaEvent], [MessageStopEvent], +// [ContentBlockStartEvent], [ContentBlockDeltaEvent] or [ContentBlockStopEvent]. +type MessageStreamEventUnion interface { + implementsMessageStreamEvent() +} + +func init() { + apijson.RegisterUnion( + reflect.TypeOf((*MessageStreamEventUnion)(nil)).Elem(), + "type", + apijson.UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(MessageStartEvent{}), + DiscriminatorValue: "message_start", + }, + apijson.UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(MessageDeltaEvent{}), + DiscriminatorValue: "message_delta", + }, + apijson.UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(MessageStopEvent{}), + DiscriminatorValue: "message_stop", + }, + apijson.UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(ContentBlockStartEvent{}), + DiscriminatorValue: "content_block_start", + }, + apijson.UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(ContentBlockDeltaEvent{}), + DiscriminatorValue: "content_block_delta", + }, + apijson.UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(ContentBlockStopEvent{}), + DiscriminatorValue: "content_block_stop", + }, + ) +} + +type MessageStreamEventType string + +const ( + MessageStreamEventTypeMessageStart MessageStreamEventType = "message_start" + MessageStreamEventTypeMessageDelta MessageStreamEventType = "message_delta" + MessageStreamEventTypeMessageStop MessageStreamEventType = "message_stop" + MessageStreamEventTypeContentBlockStart MessageStreamEventType = "content_block_start" + MessageStreamEventTypeContentBlockDelta MessageStreamEventType = "content_block_delta" + MessageStreamEventTypeContentBlockStop MessageStreamEventType = "content_block_stop" +) + +func (r MessageStreamEventType) IsKnown() bool { + switch r { + case MessageStreamEventTypeMessageStart, MessageStreamEventTypeMessageDelta, MessageStreamEventTypeMessageStop, MessageStreamEventTypeContentBlockStart, MessageStreamEventTypeContentBlockDelta, MessageStreamEventTypeContentBlockStop: + return true + } + return false +} + +type TextBlock struct { + Text string `json:"text,required"` + Type TextBlockType `json:"type,required"` + JSON textBlockJSON `json:"-"` +} + +// textBlockJSON contains the JSON metadata for the struct [TextBlock] +type textBlockJSON struct { + Text apijson.Field + Type apijson.Field + raw string + ExtraFields map[string]apijson.Field +} + +func (r *TextBlock) UnmarshalJSON(data []byte) (err error) { + return apijson.UnmarshalRoot(data, r) +} + +func (r textBlockJSON) RawJSON() string { + return r.raw +} + +func (r TextBlock) implementsContentBlock() {} + +func (r TextBlock) implementsContentBlockStartEventContentBlock() {} + +type TextBlockType string + +const ( + TextBlockTypeText TextBlockType = "text" +) + +func (r TextBlockType) IsKnown() bool { + switch r { + case TextBlockTypeText: + return true + } + return false +} + +type TextBlockParam struct { + Text param.Field[string] `json:"text,required"` + Type param.Field[TextBlockParamType] `json:"type,required"` +} + +func (r TextBlockParam) MarshalJSON() (data []byte, err error) { + return apijson.MarshalRoot(r) +} + +func (r TextBlockParam) implementsMessageParamContentUnion() {} + 
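Editor's note: the stream event types above are discriminated unions dispatched on the `type` field, and `AsUnion()` recovers the concrete variant registered in the `init()` block. The sketch below is illustrative only and is not part of this patch; the package name `example` and the helper name `printText` are assumptions, while all other identifiers come from the code in this diff.

```go
// Illustrative sketch only; not part of the generated SDK in this patch.
// printText drains a stream of MessageStreamEvent values and prints the text
// deltas, switching on the union variants registered above.
package example

import (
	"fmt"

	"github.com/anthropics/anthropic-sdk-go"
	"github.com/anthropics/anthropic-sdk-go/packages/ssestream"
)

func printText(stream *ssestream.Stream[anthropic.MessageStreamEvent]) error {
	for stream.Next() {
		// AsUnion yields the concrete event behind the "type" discriminator.
		switch event := stream.Current().AsUnion().(type) {
		case anthropic.ContentBlockDeltaEvent:
			// Content deltas are themselves a union of TextDelta and InputJSONDelta.
			switch delta := event.Delta.AsUnion().(type) {
			case anthropic.TextDelta:
				fmt.Print(delta.Text)
			case anthropic.InputJSONDelta:
				fmt.Print(delta.PartialJSON)
			}
		case anthropic.MessageStopEvent:
			fmt.Println()
		}
	}
	return stream.Err()
}
```

Because `apijson.RegisterUnion` registers value types (e.g. `reflect.TypeOf(TextDelta{})`), the type switch matches value types rather than pointers.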
+func (r TextBlockParam) implementsToolResultBlockParamContentUnion() {} + +type TextBlockParamType string + +const ( + TextBlockParamTypeText TextBlockParamType = "text" +) + +func (r TextBlockParamType) IsKnown() bool { + switch r { + case TextBlockParamTypeText: + return true + } + return false +} + +type TextDelta struct { + Text string `json:"text,required"` + Type TextDeltaType `json:"type,required"` + JSON textDeltaJSON `json:"-"` +} + +// textDeltaJSON contains the JSON metadata for the struct [TextDelta] +type textDeltaJSON struct { + Text apijson.Field + Type apijson.Field + raw string + ExtraFields map[string]apijson.Field +} + +func (r *TextDelta) UnmarshalJSON(data []byte) (err error) { + return apijson.UnmarshalRoot(data, r) +} + +func (r textDeltaJSON) RawJSON() string { + return r.raw +} + +func (r TextDelta) implementsContentBlockDeltaEventDelta() {} + +type TextDeltaType string + +const ( + TextDeltaTypeTextDelta TextDeltaType = "text_delta" +) + +func (r TextDeltaType) IsKnown() bool { + switch r { + case TextDeltaTypeTextDelta: + return true + } + return false +} + +type ToolParam struct { + // [JSON schema](https://json-schema.org/) for this tool's input. + // + // This defines the shape of the `input` that your tool accepts and that the model + // will produce. + InputSchema param.Field[ToolInputSchemaParam] `json:"input_schema,required"` + Name param.Field[string] `json:"name,required"` + // Description of what this tool does. + // + // Tool descriptions should be as detailed as possible. The more information that + // the model has about what the tool is and how to use it, the better it will + // perform. You can use natural language descriptions to reinforce important + // aspects of the tool input JSON schema. + Description param.Field[string] `json:"description"` +} + +func (r ToolParam) MarshalJSON() (data []byte, err error) { + return apijson.MarshalRoot(r) +} + +// [JSON schema](https://json-schema.org/) for this tool's input. +// +// This defines the shape of the `input` that your tool accepts and that the model +// will produce. 
+type ToolInputSchemaParam struct { + Type param.Field[ToolInputSchemaType] `json:"type,required"` + Properties param.Field[interface{}] `json:"properties"` + ExtraFields map[string]interface{} `json:"-,extras"` +} + +func (r ToolInputSchemaParam) MarshalJSON() (data []byte, err error) { + return apijson.MarshalRoot(r) +} + +type ToolInputSchemaType string + +const ( + ToolInputSchemaTypeObject ToolInputSchemaType = "object" +) + +func (r ToolInputSchemaType) IsKnown() bool { + switch r { + case ToolInputSchemaTypeObject: + return true + } + return false +} + +type ToolResultBlockParam struct { + ToolUseID param.Field[string] `json:"tool_use_id,required"` + Type param.Field[ToolResultBlockParamType] `json:"type,required"` + Content param.Field[[]ToolResultBlockParamContentUnion] `json:"content"` + IsError param.Field[bool] `json:"is_error"` +} + +func (r ToolResultBlockParam) MarshalJSON() (data []byte, err error) { + return apijson.MarshalRoot(r) +} + +func (r ToolResultBlockParam) implementsMessageParamContentUnion() {} + +type ToolResultBlockParamType string + +const ( + ToolResultBlockParamTypeToolResult ToolResultBlockParamType = "tool_result" +) + +func (r ToolResultBlockParamType) IsKnown() bool { + switch r { + case ToolResultBlockParamTypeToolResult: + return true + } + return false +} + +type ToolResultBlockParamContent struct { + Type param.Field[ToolResultBlockParamContentType] `json:"type,required"` + Text param.Field[string] `json:"text"` + Source param.Field[interface{}] `json:"source,required"` +} + +func (r ToolResultBlockParamContent) MarshalJSON() (data []byte, err error) { + return apijson.MarshalRoot(r) +} + +func (r ToolResultBlockParamContent) implementsToolResultBlockParamContentUnion() {} + +// Satisfied by [TextBlockParam], [ImageBlockParam], [ToolResultBlockParamContent]. 
+type ToolResultBlockParamContentUnion interface { + implementsToolResultBlockParamContentUnion() +} + +type ToolResultBlockParamContentType string + +const ( + ToolResultBlockParamContentTypeText ToolResultBlockParamContentType = "text" + ToolResultBlockParamContentTypeImage ToolResultBlockParamContentType = "image" +) + +func (r ToolResultBlockParamContentType) IsKnown() bool { + switch r { + case ToolResultBlockParamContentTypeText, ToolResultBlockParamContentTypeImage: + return true + } + return false +} + +type ToolUseBlock struct { + ID string `json:"id,required"` + Input interface{} `json:"input,required"` + Name string `json:"name,required"` + Type ToolUseBlockType `json:"type,required"` + JSON toolUseBlockJSON `json:"-"` +} + +// toolUseBlockJSON contains the JSON metadata for the struct [ToolUseBlock] +type toolUseBlockJSON struct { + ID apijson.Field + Input apijson.Field + Name apijson.Field + Type apijson.Field + raw string + ExtraFields map[string]apijson.Field +} + +func (r *ToolUseBlock) UnmarshalJSON(data []byte) (err error) { + return apijson.UnmarshalRoot(data, r) +} + +func (r toolUseBlockJSON) RawJSON() string { + return r.raw +} + +func (r ToolUseBlock) implementsContentBlock() {} + +func (r ToolUseBlock) implementsContentBlockStartEventContentBlock() {} + +type ToolUseBlockType string + +const ( + ToolUseBlockTypeToolUse ToolUseBlockType = "tool_use" +) + +func (r ToolUseBlockType) IsKnown() bool { + switch r { + case ToolUseBlockTypeToolUse: + return true + } + return false +} + +type ToolUseBlockParam struct { + ID param.Field[string] `json:"id,required"` + Input param.Field[interface{}] `json:"input,required"` + Name param.Field[string] `json:"name,required"` + Type param.Field[ToolUseBlockParamType] `json:"type,required"` +} + +func (r ToolUseBlockParam) MarshalJSON() (data []byte, err error) { + return apijson.MarshalRoot(r) +} + +func (r ToolUseBlockParam) implementsMessageParamContentUnion() {} + +type ToolUseBlockParamType string + +const ( + ToolUseBlockParamTypeToolUse ToolUseBlockParamType = "tool_use" +) + +func (r ToolUseBlockParamType) IsKnown() bool { + switch r { + case ToolUseBlockParamTypeToolUse: + return true + } + return false +} + +type Usage struct { + // The number of input tokens which were used. + InputTokens int64 `json:"input_tokens,required"` + // The number of output tokens which were used. + OutputTokens int64 `json:"output_tokens,required"` + JSON usageJSON `json:"-"` +} + +// usageJSON contains the JSON metadata for the struct [Usage] +type usageJSON struct { + InputTokens apijson.Field + OutputTokens apijson.Field + raw string + ExtraFields map[string]apijson.Field +} + +func (r *Usage) UnmarshalJSON(data []byte) (err error) { + return apijson.UnmarshalRoot(data, r) +} + +func (r usageJSON) RawJSON() string { + return r.raw +} + +type MessageNewParams struct { + // The maximum number of tokens to generate before stopping. + // + // Note that our models may stop _before_ reaching this maximum. This parameter + // only specifies the absolute maximum number of tokens to generate. + // + // Different models have different maximum values for this parameter. See + // [models](https://docs.anthropic.com/en/docs/models-overview) for details. + MaxTokens param.Field[int64] `json:"max_tokens,required"` + // Input messages. + // + // Our models are trained to operate on alternating `user` and `assistant` + // conversational turns. 
When creating a new `Message`, you specify the prior + // conversational turns with the `messages` parameter, and the model then generates + // the next `Message` in the conversation. + // + // Each input message must be an object with a `role` and `content`. You can + // specify a single `user`-role message, or you can include multiple `user` and + // `assistant` messages. The first message must always use the `user` role. + // + // If the final message uses the `assistant` role, the response content will + // continue immediately from the content in that message. This can be used to + // constrain part of the model's response. + // + // Example with a single `user` message: + // + // ```json + // [{ "role": "user", "content": "Hello, Claude" }] + // ``` + // + // Example with multiple conversational turns: + // + // ```json + // [ + // + // { "role": "user", "content": "Hello there." }, + // { "role": "assistant", "content": "Hi, I'm Claude. How can I help you?" }, + // { "role": "user", "content": "Can you explain LLMs in plain English?" } + // + // ] + // ``` + // + // Example with a partially-filled response from Claude: + // + // ```json + // [ + // + // { + // "role": "user", + // "content": "What's the Greek name for Sun? (A) Sol (B) Helios (C) Sun" + // }, + // { "role": "assistant", "content": "The best answer is (" } + // + // ] + // ``` + // + // Each input message `content` may be either a single `string` or an array of + // content blocks, where each block has a specific `type`. Using a `string` for + // `content` is shorthand for an array of one content block of type `"text"`. The + // following input messages are equivalent: + // + // ```json + // { "role": "user", "content": "Hello, Claude" } + // ``` + // + // ```json + // { "role": "user", "content": [{ "type": "text", "text": "Hello, Claude" }] } + // ``` + // + // Starting with Claude 3 models, you can also send image content blocks: + // + // ```json + // + // { + // "role": "user", + // "content": [ + // { + // "type": "image", + // "source": { + // "type": "base64", + // "media_type": "image/jpeg", + // "data": "/9j/4AAQSkZJRg..." + // } + // }, + // { "type": "text", "text": "What is in this image?" } + // ] + // } + // + // ``` + // + // We currently support the `base64` source type for images, and the `image/jpeg`, + // `image/png`, `image/gif`, and `image/webp` media types. + // + // See [examples](https://docs.anthropic.com/en/api/messages-examples) for more + // input examples. + // + // Note that if you want to include a + // [system prompt](https://docs.anthropic.com/en/docs/system-prompts), you can use + // the top-level `system` parameter — there is no `"system"` role for input + // messages in the Messages API. + Messages param.Field[[]MessageParam] `json:"messages,required"` + // The model that will complete your prompt.\n\nSee + // [models](https://docs.anthropic.com/en/docs/models-overview) for additional + // details and options. + Model param.Field[Model] `json:"model,required"` + // An object describing metadata about the request. + Metadata param.Field[MessageNewParamsMetadata] `json:"metadata"` + // Custom text sequences that will cause the model to stop generating. + // + // Our models will normally stop when they have naturally completed their turn, + // which will result in a response `stop_reason` of `"end_turn"`. + // + // If you want the model to stop generating when it encounters custom strings of + // text, you can use the `stop_sequences` parameter. 
If the model encounters one of + // the custom sequences, the response `stop_reason` value will be `"stop_sequence"` + // and the response `stop_sequence` value will contain the matched stop sequence. + StopSequences param.Field[[]string] `json:"stop_sequences"` + // System prompt. + // + // A system prompt is a way of providing context and instructions to Claude, such + // as specifying a particular goal or role. See our + // [guide to system prompts](https://docs.anthropic.com/en/docs/system-prompts). + System param.Field[MessageNewParamsSystemUnion] `json:"system"` + // Amount of randomness injected into the response. + // + // Defaults to `1.0`. Ranges from `0.0` to `1.0`. Use `temperature` closer to `0.0` + // for analytical / multiple choice, and closer to `1.0` for creative and + // generative tasks. + // + // Note that even with `temperature` of `0.0`, the results will not be fully + // deterministic. + Temperature param.Field[float64] `json:"temperature"` + // How the model should use the provided tools. The model can use a specific tool, + // any available tool, or decide by itself. + ToolChoice param.Field[MessageNewParamsToolChoiceUnion] `json:"tool_choice"` + // Definitions of tools that the model may use. + // + // If you include `tools` in your API request, the model may return `tool_use` + // content blocks that represent the model's use of those tools. You can then run + // those tools using the tool input generated by the model and then optionally + // return results back to the model using `tool_result` content blocks. + // + // Each tool definition includes: + // + // - `name`: Name of the tool. + // - `description`: Optional, but strongly-recommended description of the tool. + // - `input_schema`: [JSON schema](https://json-schema.org/) for the tool `input` + // shape that the model will produce in `tool_use` output content blocks. + // + // For example, if you defined `tools` as: + // + // ```json + // [ + // + // { + // "name": "get_stock_price", + // "description": "Get the current stock price for a given ticker symbol.", + // "input_schema": { + // "type": "object", + // "properties": { + // "ticker": { + // "type": "string", + // "description": "The stock ticker symbol, e.g. AAPL for Apple Inc." + // } + // }, + // "required": ["ticker"] + // } + // } + // + // ] + // ``` + // + // And then asked the model "What's the S&P 500 at today?", the model might produce + // `tool_use` content blocks in the response like this: + // + // ```json + // [ + // + // { + // "type": "tool_use", + // "id": "toolu_01D7FLrfh4GYq7yT1ULFeyMV", + // "name": "get_stock_price", + // "input": { "ticker": "^GSPC" } + // } + // + // ] + // ``` + // + // You might then run your `get_stock_price` tool with `{"ticker": "^GSPC"}` as an + // input, and return the following back to the model in a subsequent `user` + // message: + // + // ```json + // [ + // + // { + // "type": "tool_result", + // "tool_use_id": "toolu_01D7FLrfh4GYq7yT1ULFeyMV", + // "content": "259.75 USD" + // } + // + // ] + // ``` + // + // Tools can be used for workflows that include running client-side tools and + // functions, or more generally whenever you want the model to produce a particular + // JSON structure of output. + // + // See our [guide](https://docs.anthropic.com/en/docs/tool-use) for more details. + Tools param.Field[[]ToolParam] `json:"tools"` + // Only sample from the top K options for each subsequent token. + // + // Used to remove "long tail" low probability responses. 
+ // [Learn more technical details here](https://towardsdatascience.com/how-to-sample-from-language-models-682bceb97277). + // + // Recommended for advanced use cases only. You usually only need to use + // `temperature`. + TopK param.Field[int64] `json:"top_k"` + // Use nucleus sampling. + // + // In nucleus sampling, we compute the cumulative distribution over all the options + // for each subsequent token in decreasing probability order and cut it off once it + // reaches a particular probability specified by `top_p`. You should either alter + // `temperature` or `top_p`, but not both. + // + // Recommended for advanced use cases only. You usually only need to use + // `temperature`. + TopP param.Field[float64] `json:"top_p"` +} + +func (r MessageNewParams) MarshalJSON() (data []byte, err error) { + return apijson.MarshalRoot(r) +} + +// An object describing metadata about the request. +type MessageNewParamsMetadata struct { + // An external identifier for the user who is associated with the request. + // + // This should be a uuid, hash value, or other opaque identifier. Anthropic may use + // this id to help detect abuse. Do not include any identifying information such as + // name, email address, or phone number. + UserID param.Field[string] `json:"user_id"` +} + +func (r MessageNewParamsMetadata) MarshalJSON() (data []byte, err error) { + return apijson.MarshalRoot(r) +} + +// System prompt. +// +// A system prompt is a way of providing context and instructions to Claude, such +// as specifying a particular goal or role. See our +// [guide to system prompts](https://docs.anthropic.com/en/docs/system-prompts). +// +// Satisfied by [shared.UnionString], [MessageNewParamsSystemArray]. +type MessageNewParamsSystemUnion interface { + ImplementsMessageNewParamsSystemUnion() +} + +type MessageNewParamsSystemArray []TextBlockParam + +func (r MessageNewParamsSystemArray) ImplementsMessageNewParamsSystemUnion() {} + +// How the model should use the provided tools. The model can use a specific tool, +// any available tool, or decide by itself. +type MessageNewParamsToolChoice struct { + Type param.Field[MessageNewParamsToolChoiceType] `json:"type,required"` + // The name of the tool to use. + Name param.Field[string] `json:"name"` +} + +func (r MessageNewParamsToolChoice) MarshalJSON() (data []byte, err error) { + return apijson.MarshalRoot(r) +} + +func (r MessageNewParamsToolChoice) implementsMessageNewParamsToolChoiceUnion() {} + +// How the model should use the provided tools. The model can use a specific tool, +// any available tool, or decide by itself. +// +// Satisfied by [MessageNewParamsToolChoiceToolChoiceAuto], +// [MessageNewParamsToolChoiceToolChoiceAny], +// [MessageNewParamsToolChoiceToolChoiceTool], [MessageNewParamsToolChoice]. +type MessageNewParamsToolChoiceUnion interface { + implementsMessageNewParamsToolChoiceUnion() +} + +// The model will automatically decide whether to use tools. 
+type MessageNewParamsToolChoiceToolChoiceAuto struct { + Type param.Field[MessageNewParamsToolChoiceToolChoiceAutoType] `json:"type,required"` +} + +func (r MessageNewParamsToolChoiceToolChoiceAuto) MarshalJSON() (data []byte, err error) { + return apijson.MarshalRoot(r) +} + +func (r MessageNewParamsToolChoiceToolChoiceAuto) implementsMessageNewParamsToolChoiceUnion() {} + +type MessageNewParamsToolChoiceToolChoiceAutoType string + +const ( + MessageNewParamsToolChoiceToolChoiceAutoTypeAuto MessageNewParamsToolChoiceToolChoiceAutoType = "auto" +) + +func (r MessageNewParamsToolChoiceToolChoiceAutoType) IsKnown() bool { + switch r { + case MessageNewParamsToolChoiceToolChoiceAutoTypeAuto: + return true + } + return false +} + +// The model will use any available tools. +type MessageNewParamsToolChoiceToolChoiceAny struct { + Type param.Field[MessageNewParamsToolChoiceToolChoiceAnyType] `json:"type,required"` +} + +func (r MessageNewParamsToolChoiceToolChoiceAny) MarshalJSON() (data []byte, err error) { + return apijson.MarshalRoot(r) +} + +func (r MessageNewParamsToolChoiceToolChoiceAny) implementsMessageNewParamsToolChoiceUnion() {} + +type MessageNewParamsToolChoiceToolChoiceAnyType string + +const ( + MessageNewParamsToolChoiceToolChoiceAnyTypeAny MessageNewParamsToolChoiceToolChoiceAnyType = "any" +) + +func (r MessageNewParamsToolChoiceToolChoiceAnyType) IsKnown() bool { + switch r { + case MessageNewParamsToolChoiceToolChoiceAnyTypeAny: + return true + } + return false +} + +// The model will use the specified tool with `tool_choice.name`. +type MessageNewParamsToolChoiceToolChoiceTool struct { + // The name of the tool to use. + Name param.Field[string] `json:"name,required"` + Type param.Field[MessageNewParamsToolChoiceToolChoiceToolType] `json:"type,required"` +} + +func (r MessageNewParamsToolChoiceToolChoiceTool) MarshalJSON() (data []byte, err error) { + return apijson.MarshalRoot(r) +} + +func (r MessageNewParamsToolChoiceToolChoiceTool) implementsMessageNewParamsToolChoiceUnion() {} + +type MessageNewParamsToolChoiceToolChoiceToolType string + +const ( + MessageNewParamsToolChoiceToolChoiceToolTypeTool MessageNewParamsToolChoiceToolChoiceToolType = "tool" +) + +func (r MessageNewParamsToolChoiceToolChoiceToolType) IsKnown() bool { + switch r { + case MessageNewParamsToolChoiceToolChoiceToolTypeTool: + return true + } + return false +} + +type MessageNewParamsToolChoiceType string + +const ( + MessageNewParamsToolChoiceTypeAuto MessageNewParamsToolChoiceType = "auto" + MessageNewParamsToolChoiceTypeAny MessageNewParamsToolChoiceType = "any" + MessageNewParamsToolChoiceTypeTool MessageNewParamsToolChoiceType = "tool" +) + +func (r MessageNewParamsToolChoiceType) IsKnown() bool { + switch r { + case MessageNewParamsToolChoiceTypeAuto, MessageNewParamsToolChoiceTypeAny, MessageNewParamsToolChoiceTypeTool: + return true + } + return false +} diff --git a/message_test.go b/message_test.go new file mode 100644 index 0000000..0cb7067 --- /dev/null +++ b/message_test.go @@ -0,0 +1,106 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
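Editor's note: the `Tools` documentation in `message.go` above describes a round trip in which the model emits a `tool_use` block and the caller returns a `tool_result` block in the next `user` message. The following is a small, hedged sketch (not part of this patch) of building that follow-up message; the function name, `main` wrapper, and placeholder values are assumptions, while the parameter types and constants come from the generated code above.

```go
// Illustrative sketch only; not part of the generated SDK in this patch.
package main

import (
	"fmt"

	"github.com/anthropics/anthropic-sdk-go"
)

// toolResultMessage wraps a tool result string in a tool_result content block
// inside a user-role message, mirroring the round trip described in the Tools
// field documentation.
func toolResultMessage(toolUseID, result string) anthropic.MessageParam {
	return anthropic.MessageParam{
		Role: anthropic.F(anthropic.MessageParamRoleUser),
		Content: anthropic.F([]anthropic.MessageParamContentUnion{
			anthropic.ToolResultBlockParam{
				Type:      anthropic.F(anthropic.ToolResultBlockParamTypeToolResult),
				ToolUseID: anthropic.F(toolUseID),
				Content: anthropic.F([]anthropic.ToolResultBlockParamContentUnion{
					anthropic.TextBlockParam{
						Type: anthropic.F(anthropic.TextBlockParamTypeText),
						Text: anthropic.F(result),
					},
				}),
			},
		}),
	}
}

func main() {
	msg := toolResultMessage("toolu_01D7FLrfh4GYq7yT1ULFeyMV", "259.75 USD")
	// MarshalJSON is generated for MessageParam in message.go.
	if data, err := msg.MarshalJSON(); err == nil {
		fmt.Println(string(data))
	}
}
```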
+ +package anthropic_test + +import ( + "context" + "errors" + "os" + "testing" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/internal/testutil" + "github.com/anthropics/anthropic-sdk-go/option" +) + +func TestMessageNewWithOptionalParams(t *testing.T) { + baseURL := "http://localhost:4010" + if envURL, ok := os.LookupEnv("TEST_API_BASE_URL"); ok { + baseURL = envURL + } + if !testutil.CheckTestServer(t, baseURL) { + return + } + client := anthropic.NewClient( + option.WithBaseURL(baseURL), + option.WithAPIKey("my-anthropic-api-key"), + ) + _, err := client.Messages.New(context.TODO(), anthropic.MessageNewParams{ + MaxTokens: anthropic.F(int64(1024)), + Messages: anthropic.F([]anthropic.MessageParam{{ + Role: anthropic.F(anthropic.MessageParamRoleUser), + Content: anthropic.F([]anthropic.MessageParamContentUnion{anthropic.TextBlockParam{Type: anthropic.F(anthropic.TextBlockParamTypeText), Text: anthropic.F("What is a quaternion?")}}), + }}), + Model: anthropic.F(anthropic.ModelClaude_3_5_Sonnet_20240620), + Metadata: anthropic.F(anthropic.MessageNewParamsMetadata{ + UserID: anthropic.F("13803d75-b4b5-4c3e-b2a2-6f21399b021b"), + }), + StopSequences: anthropic.F([]string{"string", "string", "string"}), + System: anthropic.F[anthropic.MessageNewParamsSystemUnion](anthropic.MessageNewParamsSystemArray([]anthropic.TextBlockParam{{ + Type: anthropic.F(anthropic.TextBlockParamTypeText), + Text: anthropic.F("Today's date is 2024-06-01."), + }})), + Temperature: anthropic.F(1.000000), + ToolChoice: anthropic.F[anthropic.MessageNewParamsToolChoiceUnion](anthropic.MessageNewParamsToolChoiceToolChoiceAuto{ + Type: anthropic.F(anthropic.MessageNewParamsToolChoiceToolChoiceAutoTypeAuto), + }), + Tools: anthropic.F([]anthropic.ToolParam{{ + Description: anthropic.F("Get the current weather in a given location"), + Name: anthropic.F("x"), + InputSchema: anthropic.F(anthropic.ToolInputSchemaParam{ + Type: anthropic.F(anthropic.ToolInputSchemaTypeObject), + Properties: anthropic.F[any](map[string]interface{}{ + "location": map[string]interface{}{ + "description": "The city and state, e.g. San Francisco, CA", + "type": "string", + }, + "unit": map[string]interface{}{ + "description": "Unit for the output - one of (celsius, fahrenheit)", + "type": "string", + }, + }), + }), + }, { + Description: anthropic.F("Get the current weather in a given location"), + Name: anthropic.F("x"), + InputSchema: anthropic.F(anthropic.ToolInputSchemaParam{ + Type: anthropic.F(anthropic.ToolInputSchemaTypeObject), + Properties: anthropic.F[any](map[string]interface{}{ + "location": map[string]interface{}{ + "description": "The city and state, e.g. San Francisco, CA", + "type": "string", + }, + "unit": map[string]interface{}{ + "description": "Unit for the output - one of (celsius, fahrenheit)", + "type": "string", + }, + }), + }), + }, { + Description: anthropic.F("Get the current weather in a given location"), + Name: anthropic.F("x"), + InputSchema: anthropic.F(anthropic.ToolInputSchemaParam{ + Type: anthropic.F(anthropic.ToolInputSchemaTypeObject), + Properties: anthropic.F[any](map[string]interface{}{ + "location": map[string]interface{}{ + "description": "The city and state, e.g. 
San Francisco, CA", + "type": "string", + }, + "unit": map[string]interface{}{ + "description": "Unit for the output - one of (celsius, fahrenheit)", + "type": "string", + }, + }), + }), + }}), + TopK: anthropic.F(int64(5)), + TopP: anthropic.F(0.700000), + }) + if err != nil { + var apierr *anthropic.Error + if errors.As(err, &apierr) { + t.Log(string(apierr.DumpRequest(true))) + } + t.Fatalf("err should be nil: %s", err.Error()) + } +} diff --git a/option/requestoption.go b/option/requestoption.go new file mode 100644 index 0000000..5594c6d --- /dev/null +++ b/option/requestoption.go @@ -0,0 +1,245 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package option + +import ( + "bytes" + "fmt" + "io" + "log" + "net/http" + "net/url" + "time" + + "github.com/anthropics/anthropic-sdk-go/internal/requestconfig" + "github.com/tidwall/sjson" +) + +// RequestOption is an option for the requests made by the anthropic API Client +// which can be supplied to clients, services, and methods. You can read more about this functional +// options pattern in our [README]. +// +// [README]: https://pkg.go.dev/github.com/anthropics/anthropic-sdk-go#readme-requestoptions +type RequestOption = func(*requestconfig.RequestConfig) error + +// WithBaseURL returns a RequestOption that sets the BaseURL for the client. +func WithBaseURL(base string) RequestOption { + u, err := url.Parse(base) + if err != nil { + log.Fatalf("failed to parse BaseURL: %s\n", err) + } + return func(r *requestconfig.RequestConfig) error { + r.BaseURL = u + return nil + } +} + +// WithHTTPClient returns a RequestOption that changes the underlying [http.Client] used to make this +// request, which by default is [http.DefaultClient]. +func WithHTTPClient(client *http.Client) RequestOption { + return func(r *requestconfig.RequestConfig) error { + r.HTTPClient = client + return nil + } +} + +// MiddlewareNext is a function which is called by a middleware to pass an HTTP request +// to the next stage in the middleware chain. +type MiddlewareNext = func(*http.Request) (*http.Response, error) + +// Middleware is a function which intercepts HTTP requests, processing or modifying +// them, and then passing the request to the next middleware or handler +// in the chain by calling the provided MiddlewareNext function. +type Middleware = func(*http.Request, MiddlewareNext) (*http.Response, error) + +// WithMiddleware returns a RequestOption that applies the given middleware +// to the requests made. Each middleware will execute in the order they were given. +func WithMiddleware(middlewares ...Middleware) RequestOption { + return func(r *requestconfig.RequestConfig) error { + r.Middlewares = append(r.Middlewares, middlewares...) + return nil + } +} + +// WithMaxRetries returns a RequestOption that sets the maximum number of retries that the client +// attempts to make. When given 0, the client only makes one request. By +// default, the client retries two times. +// +// WithMaxRetries panics when retries is negative. +func WithMaxRetries(retries int) RequestOption { + if retries < 0 { + panic("option: cannot have fewer than 0 retries") + } + return func(r *requestconfig.RequestConfig) error { + r.MaxRetries = retries + return nil + } +} + +// WithHeader returns a RequestOption that sets the header value to the associated key. It overwrites +// any value if there was one already present. 
+func WithHeader(key, value string) RequestOption { + return func(r *requestconfig.RequestConfig) error { + r.Request.Header.Set(key, value) + return nil + } +} + +// WithHeaderAdd returns a RequestOption that adds the header value to the associated key. It appends +// onto any existing values. +func WithHeaderAdd(key, value string) RequestOption { + return func(r *requestconfig.RequestConfig) error { + r.Request.Header.Add(key, value) + return nil + } +} + +// WithHeaderDel returns a RequestOption that deletes the header value(s) associated with the given key. +func WithHeaderDel(key string) RequestOption { + return func(r *requestconfig.RequestConfig) error { + r.Request.Header.Del(key) + return nil + } +} + +// WithQuery returns a RequestOption that sets the query value to the associated key. It overwrites +// any value if there was one already present. +func WithQuery(key, value string) RequestOption { + return func(r *requestconfig.RequestConfig) error { + query := r.Request.URL.Query() + query.Set(key, value) + r.Request.URL.RawQuery = query.Encode() + return nil + } +} + +// WithQueryAdd returns a RequestOption that adds the query value to the associated key. It appends +// onto any existing values. +func WithQueryAdd(key, value string) RequestOption { + return func(r *requestconfig.RequestConfig) error { + query := r.Request.URL.Query() + query.Add(key, value) + r.Request.URL.RawQuery = query.Encode() + return nil + } +} + +// WithQueryDel returns a RequestOption that deletes the query value(s) associated with the key. +func WithQueryDel(key string) RequestOption { + return func(r *requestconfig.RequestConfig) error { + query := r.Request.URL.Query() + query.Del(key) + r.Request.URL.RawQuery = query.Encode() + return nil + } +} + +// WithJSONSet returns a RequestOption that sets the body's JSON value associated with the key. +// The key accepts a string as defined by the [sjson format]. +// +// [sjson format]: https://github.com/tidwall/sjson +func WithJSONSet(key string, value interface{}) RequestOption { + return func(r *requestconfig.RequestConfig) (err error) { + if buffer, ok := r.Body.(*bytes.Buffer); ok { + b := buffer.Bytes() + b, err = sjson.SetBytes(b, key, value) + if err != nil { + return err + } + r.Body = bytes.NewBuffer(b) + return nil + } + + return fmt.Errorf("cannot use WithJSONSet on a body that is not serialized as *bytes.Buffer") + } +} + +// WithJSONDel returns a RequestOption that deletes the body's JSON value associated with the key. +// The key accepts a string as defined by the [sjson format]. +// +// [sjson format]: https://github.com/tidwall/sjson +func WithJSONDel(key string) RequestOption { + return func(r *requestconfig.RequestConfig) (err error) { + if buffer, ok := r.Body.(*bytes.Buffer); ok { + b := buffer.Bytes() + b, err = sjson.DeleteBytes(b, key) + if err != nil { + return err + } + r.Body = bytes.NewBuffer(b) + return nil + } + + return fmt.Errorf("cannot use WithJSONDel on a body that is not serialized as *bytes.Buffer") + } +} + +// WithResponseBodyInto returns a RequestOption that overwrites the deserialization target with +// the given destination. If provided, we don't deserialize into the default struct. +func WithResponseBodyInto(dst any) RequestOption { + return func(r *requestconfig.RequestConfig) error { + r.ResponseBodyInto = dst + return nil + } +} + +// WithResponseInto returns a RequestOption that copies the [*http.Response] into the given address. 
+func WithResponseInto(dst **http.Response) RequestOption { + return func(r *requestconfig.RequestConfig) error { + r.ResponseInto = dst + return nil + } +} + +// WithRequestBody returns a RequestOption that provides a custom serialized body with the given +// content type. +// +// body accepts an io.Reader or raw []bytes. +func WithRequestBody(contentType string, body any) RequestOption { + return func(r *requestconfig.RequestConfig) error { + if reader, ok := body.(io.Reader); ok { + r.Body = reader + return r.Apply(WithHeader("Content-Type", contentType)) + } + + if b, ok := body.([]byte); ok { + r.Body = bytes.NewBuffer(b) + return r.Apply(WithHeader("Content-Type", contentType)) + } + + return fmt.Errorf("body must be a byte slice or implement io.Reader") + } +} + +// WithRequestTimeout returns a RequestOption that sets the timeout for +// each request attempt. This should be smaller than the timeout defined in +// the context, which spans all retries. +func WithRequestTimeout(dur time.Duration) RequestOption { + return func(r *requestconfig.RequestConfig) error { + r.RequestTimeout = dur + return nil + } +} + +// WithEnvironmentProduction returns a RequestOption that sets the current +// environment to be the "production" environment. An environment specifies which base URL +// to use by default. +func WithEnvironmentProduction() RequestOption { + return WithBaseURL("https://api.anthropic.com/") +} + +// WithAPIKey returns a RequestOption that sets the client setting "api_key". +func WithAPIKey(value string) RequestOption { + return func(r *requestconfig.RequestConfig) error { + r.APIKey = value + return r.Apply(WithHeader("X-Api-Key", r.APIKey)) + } +} + +// WithAuthToken returns a RequestOption that sets the client setting "auth_token". +func WithAuthToken(value string) RequestOption { + return func(r *requestconfig.RequestConfig) error { + r.AuthToken = value + return r.Apply(WithHeader("authorization", fmt.Sprintf("Bearer %s", r.AuthToken))) + } +} diff --git a/packages/ssestream/streaming.go b/packages/ssestream/streaming.go new file mode 100644 index 0000000..8a36567 --- /dev/null +++ b/packages/ssestream/streaming.go @@ -0,0 +1,172 @@ +// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +package ssestream + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" +) + +type Decoder interface { + Event() Event + Next() bool + Close() error + Err() error +} + +func NewDecoder(res *http.Response) Decoder { + if res == nil || res.Body == nil { + return nil + } + + var decoder Decoder + contentType := res.Header.Get("content-type") + if t, ok := decoderTypes[contentType]; ok { + decoder = t(res.Body) + } else { + scanner := bufio.NewScanner(res.Body) + decoder = &eventStreamDecoder{rc: res.Body, scn: scanner} + } + return decoder +} + +var decoderTypes = map[string](func(io.ReadCloser) Decoder){} + +func RegisterDecoder(contentType string, decoder func(io.ReadCloser) Decoder) { + decoderTypes[strings.ToLower(contentType)] = decoder +} + +type Event struct { + Type string + Data []byte +} + +// A base implementation of a Decoder for text/event-stream. 
+type eventStreamDecoder struct {
+	evt Event
+	rc  io.ReadCloser
+	scn *bufio.Scanner
+	err error
+}
+
+func (s *eventStreamDecoder) Next() bool {
+	if s.err != nil {
+		return false
+	}
+
+	event := ""
+	data := bytes.NewBuffer(nil)
+
+	for s.scn.Scan() {
+		txt := s.scn.Bytes()
+
+		// Dispatch event on an empty line
+		if len(txt) == 0 {
+			s.evt = Event{
+				Type: event,
+				Data: data.Bytes(),
+			}
+			return true
+		}
+
+		// Split a string like "event: bar" into name="event" and value=" bar".
+		name, value, _ := bytes.Cut(txt, []byte(":"))
+
+		// Consume an optional space after the colon if it exists.
+		if len(value) > 0 && value[0] == ' ' {
+			value = value[1:]
+		}
+
+		switch string(name) {
+		case "":
+			// A line in the form ": something" is a comment and should be ignored.
+			continue
+		case "event":
+			event = string(value)
+		case "data":
+			_, s.err = data.Write(value)
+			if s.err != nil {
+				break
+			}
+			_, s.err = data.WriteRune('\n')
+			if s.err != nil {
+				break
+			}
+		}
+	}
+
+	return false
+}
+
+func (s *eventStreamDecoder) Event() Event {
+	return s.evt
+}
+
+func (s *eventStreamDecoder) Close() error {
+	return s.rc.Close()
+}
+
+func (s *eventStreamDecoder) Err() error {
+	return s.err
+}
+
+type Stream[T any] struct {
+	decoder Decoder
+	cur     T
+	err     error
+	done    bool
+}
+
+func NewStream[T any](decoder Decoder, err error) *Stream[T] {
+	return &Stream[T]{
+		decoder: decoder,
+		err:     err,
+	}
+}
+
+func (s *Stream[T]) Next() bool {
+	if s.err != nil {
+		return false
+	}
+
+	for s.decoder.Next() {
+		switch s.decoder.Event().Type {
+		case "completion":
+			s.err = json.Unmarshal(s.decoder.Event().Data, &s.cur)
+			if s.err != nil {
+				return false
+			}
+			return true
+		case "message_start", "message_delta", "message_stop", "content_block_start", "content_block_delta", "content_block_stop":
+			s.err = json.Unmarshal(s.decoder.Event().Data, &s.cur)
+			if s.err != nil {
+				return false
+			}
+			return true
+		case "ping":
+			continue
+		case "error":
+			s.err = fmt.Errorf("%s", string(s.decoder.Event().Data))
+			return false
+		}
+	}
+
+	return false
+}
+
+func (s *Stream[T]) Current() T {
+	return s.cur
+}
+
+func (s *Stream[T]) Err() error {
+	return s.err
+}
+
+func (s *Stream[T]) Close() error {
+	return s.decoder.Close()
+}
diff --git a/release-please-config.json b/release-please-config.json
new file mode 100644
index 0000000..5391ba3
--- /dev/null
+++ b/release-please-config.json
@@ -0,0 +1,70 @@
+{
+  "packages": {
+    ".": {}
+  },
+  "$schema": "https://raw.githubusercontent.com/stainless-api/release-please/main/schemas/config.json",
+  "include-v-in-tag": true,
+  "include-component-in-tag": false,
+  "versioning": "prerelease",
+  "prerelease": true,
+  "bump-minor-pre-major": true,
+  "bump-patch-for-minor-pre-major": false,
+  "pull-request-header": "Automated Release PR",
+  "pull-request-title-pattern": "release: ${version}",
+  "changelog-sections": [
+    {
+      "type": "feat",
+      "section": "Features"
+    },
+    {
+      "type": "fix",
+      "section": "Bug Fixes"
+    },
+    {
+      "type": "perf",
+      "section": "Performance Improvements"
+    },
+    {
+      "type": "revert",
+      "section": "Reverts"
+    },
+    {
+      "type": "chore",
+      "section": "Chores"
+    },
+    {
+      "type": "docs",
+      "section": "Documentation"
+    },
+    {
+      "type": "style",
+      "section": "Styles"
+    },
+    {
+      "type": "refactor",
+      "section": "Refactors"
+    },
+    {
+      "type": "test",
+      "section": "Tests",
+      "hidden": true
+    },
+    {
+      "type": "build",
+      "section": "Build System"
+    },
+    {
+      "type": "ci",
+      "section": "Continuous Integration",
+      "hidden": true
+    }
+  ],
+  "reviewers": [
+
"@anthropics/sdk" + ], + "release-type": "go", + "extra-files": [ + "internal/version.go", + "README.md" + ] +} \ No newline at end of file diff --git a/scripts/bootstrap b/scripts/bootstrap new file mode 100755 index 0000000..ed03e52 --- /dev/null +++ b/scripts/bootstrap @@ -0,0 +1,16 @@ +#!/usr/bin/env bash + +set -e + +cd "$(dirname "$0")/.." + +if [ -f "Brewfile" ] && [ "$(uname -s)" = "Darwin" ]; then + brew bundle check >/dev/null 2>&1 || { + echo "==> Installing Homebrew dependencies…" + brew bundle + } +fi + +echo "==> Installing Go dependencies…" + +go mod tidy diff --git a/scripts/format b/scripts/format new file mode 100755 index 0000000..db2a3fa --- /dev/null +++ b/scripts/format @@ -0,0 +1,8 @@ +#!/usr/bin/env bash + +set -e + +cd "$(dirname "$0")/.." + +echo "==> Running gofmt -s -w" +gofmt -s -w . diff --git a/scripts/lint b/scripts/lint new file mode 100755 index 0000000..fa7ba1f --- /dev/null +++ b/scripts/lint @@ -0,0 +1,8 @@ +#!/usr/bin/env bash + +set -e + +cd "$(dirname "$0")/.." + +echo "==> Running Go build" +go build ./... diff --git a/scripts/mock b/scripts/mock new file mode 100755 index 0000000..f586157 --- /dev/null +++ b/scripts/mock @@ -0,0 +1,41 @@ +#!/usr/bin/env bash + +set -e + +cd "$(dirname "$0")/.." + +if [[ -n "$1" && "$1" != '--'* ]]; then + URL="$1" + shift +else + URL="$(grep 'openapi_spec_url' .stats.yml | cut -d' ' -f2)" +fi + +# Check if the URL is empty +if [ -z "$URL" ]; then + echo "Error: No OpenAPI spec path/url provided or found in .stats.yml" + exit 1 +fi + +echo "==> Starting mock server with URL ${URL}" + +# Run prism mock on the given spec +if [ "$1" == "--daemon" ]; then + npm exec --package=@stainless-api/prism-cli@5.8.4 -- prism mock "$URL" &> .prism.log & + + # Wait for server to come online + echo -n "Waiting for server" + while ! grep -q "✖ fatal\|Prism is listening" ".prism.log" ; do + echo -n "." + sleep 0.1 + done + + if grep -q "✖ fatal" ".prism.log"; then + cat .prism.log + exit 1 + fi + + echo +else + npm exec --package=@stainless-api/prism-cli@5.8.4 -- prism mock "$URL" +fi diff --git a/scripts/test b/scripts/test new file mode 100755 index 0000000..efebcea --- /dev/null +++ b/scripts/test @@ -0,0 +1,56 @@ +#!/usr/bin/env bash + +set -e + +cd "$(dirname "$0")/.." + +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[0;33m' +NC='\033[0m' # No Color + +function prism_is_running() { + curl --silent "http://localhost:4010" >/dev/null 2>&1 +} + +kill_server_on_port() { + pids=$(lsof -t -i tcp:"$1" || echo "") + if [ "$pids" != "" ]; then + kill "$pids" + echo "Stopped $pids." + fi +} + +function is_overriding_api_base_url() { + [ -n "$TEST_API_BASE_URL" ] +} + +if ! is_overriding_api_base_url && ! prism_is_running ; then + # When we exit this script, make sure to kill the background mock server process + trap 'kill_server_on_port 4010' EXIT + + # Start the dev server + ./scripts/mock --daemon +fi + +if is_overriding_api_base_url ; then + echo -e "${GREEN}✔ Running tests against ${TEST_API_BASE_URL}${NC}" + echo +elif ! prism_is_running ; then + echo -e "${RED}ERROR:${NC} The test suite will not run without a mock Prism server" + echo -e "running against your OpenAPI spec." 
+  echo
+  echo -e "To run the server, pass in the path or URL of your OpenAPI"
+  echo -e "spec to the prism command:"
+  echo
+  echo -e "  \$ ${YELLOW}npm exec --package=@stainless-api/prism-cli@5.8.4 -- prism mock path/to/your.openapi.yml${NC}"
+  echo
+
+  exit 1
+else
+  echo -e "${GREEN}✔ Mock prism server is running with your OpenAPI spec${NC}"
+  echo
+fi
+
+echo "==> Running tests"
+go test ./... "$@"
diff --git a/shared/union.go b/shared/union.go
new file mode 100644
index 0000000..f5c7461
--- /dev/null
+++ b/shared/union.go
@@ -0,0 +1,8 @@
+// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+package shared
+
+type UnionString string
+
+func (UnionString) ImplementsModel() {}
+func (UnionString) ImplementsMessageNewParamsSystemUnion() {}
diff --git a/usage_test.go b/usage_test.go
new file mode 100644
index 0000000..c2d2be5
--- /dev/null
+++ b/usage_test.go
@@ -0,0 +1,39 @@
+// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+package anthropic_test
+
+import (
+	"context"
+	"os"
+	"testing"
+
+	"github.com/anthropics/anthropic-sdk-go"
+	"github.com/anthropics/anthropic-sdk-go/internal/testutil"
+	"github.com/anthropics/anthropic-sdk-go/option"
+)
+
+func TestUsage(t *testing.T) {
+	baseURL := "http://localhost:4010"
+	if envURL, ok := os.LookupEnv("TEST_API_BASE_URL"); ok {
+		baseURL = envURL
+	}
+	if !testutil.CheckTestServer(t, baseURL) {
+		return
+	}
+	client := anthropic.NewClient(
+		option.WithBaseURL(baseURL),
+		option.WithAPIKey("my-anthropic-api-key"),
+	)
+	message, err := client.Messages.New(context.TODO(), anthropic.MessageNewParams{
+		MaxTokens: anthropic.F(int64(1024)),
+		Messages: anthropic.F([]anthropic.MessageParam{{
+			Role:    anthropic.F(anthropic.MessageParamRoleUser),
+			Content: anthropic.F([]anthropic.MessageParamContentUnion{anthropic.TextBlockParam{Type: anthropic.F(anthropic.TextBlockParamTypeText), Text: anthropic.F("What is a quaternion?")}}),
+		}}),
+		Model: anthropic.F(anthropic.ModelClaude_3_5_Sonnet_20240620),
+	})
+	if err != nil {
+		t.Error(err)
+	}
+	t.Logf("%+v\n", message.Content)
+}
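
The streaming pieces added in packages/ssestream/streaming.go compose roughly as follows. This is an illustrative, hand-written sketch rather than part of the generated patch: it assumes you already hold a text/event-stream *http.Response for the Messages API (the localhost URL below is only a hypothetical placeholder) and decodes each event into a plain map, since the SDK's typed streaming helpers are outside this section.

package main

import (
	"fmt"
	"net/http"

	"github.com/anthropics/anthropic-sdk-go/packages/ssestream"
)

// consumeSSE drains a server-sent-event response using the Decoder and the
// generic Stream type from packages/ssestream. Each decoded payload is printed.
func consumeSSE(res *http.Response) error {
	// NewDecoder falls back to the text/event-stream decoder for this response;
	// the nil error mirrors the (decoder, err) pair NewStream expects.
	stream := ssestream.NewStream[map[string]any](ssestream.NewDecoder(res), nil)
	defer stream.Close()

	for stream.Next() {
		fmt.Printf("%+v\n", stream.Current())
	}
	// Err reports a JSON decoding failure or an "error" event from the stream.
	return stream.Err()
}

func main() {
	// Hypothetical endpoint purely for illustration; a real call would target
	// the Messages API with streaming enabled and proper authentication.
	res, err := http.Get("http://localhost:4010/v1/messages")
	if err != nil {
		panic(err)
	}
	if err := consumeSSE(res); err != nil {
		panic(err)
	}
}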