Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into outline
Browse files Browse the repository at this point in the history
  • Loading branch information
metaboulie committed Aug 29, 2024
2 parents 4801ef2 + 8560fee commit 4b08fe2
Show file tree
Hide file tree
Showing 19 changed files with 554 additions and 22 deletions.
16 changes: 16 additions & 0 deletions docs/guides/editor_features/ai_completion.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,22 @@ cell. This will open an input to modify the cell using AI.
</figure>
</div>

### Using Anthropic

To use Anthropic with marimo:

1. Sign up for an account at [Anthropic](https://console.anthropic.com/) and grab your [Anthropic Key](https://console.anthropic.com/settings/keys).
2. Add the following to your `~/.marimo.toml`:

```toml
[ai.open_ai]
model = "claude-3-5-sonnet-20240620"
# or any model from https://docs.anthropic.com/en/docs/about-claude/models

[ai.anthropic]
api_key = "sk-..."
```

### Using other AI providers

marimo supports OpenAI's GPT-3.5 API by default. If your provider is compatible with OpenAI's API, you can use it by changing the `base_url` in the configuration.
Expand Down
4 changes: 4 additions & 0 deletions frontend/src/components/app-config/user-config-form.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,10 @@ export const UserConfigForm: React.FC = () => {
/>
</FormControl>
<FormMessage />
<FormDescription>
If the model starts with "claude-", we will use your Anthropic
API key. Otherwise, we will use your OpenAI API key.
</FormDescription>
</FormItem>
)}
/>
Expand Down
8 changes: 5 additions & 3 deletions frontend/src/components/editor/cell/code/cell-editor.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import { aiCompletionCellAtom } from "@/core/ai/state";
import { mergeRefs } from "@/utils/mergeRefs";
import { useSetLastFocusedCellId } from "@/core/cells/focus";
import type { LanguageAdapterType } from "@/core/codemirror/language/types";
import { autoInstantiateAtom } from "@/core/config/config";
import { autoInstantiateAtom, isAiEnabled } from "@/core/config/config";
import { maybeAddMarimoImport } from "@/core/cells/add-missing-import";
import { OverridingHotkeyProvider } from "@/core/hotkeys/hotkeys";
import { useSplitCellCallback } from "../useSplitCell";
Expand Down Expand Up @@ -143,11 +143,13 @@ const CellEditorInternal = ({
maybeAddMarimoImport(autoInstantiate, createNewCell);
});

const aiEnabled = isAiEnabled(userConfig);

const extensions = useMemo(() => {
const extensions = setupCodeMirror({
cellId,
showPlaceholder,
enableAI: Boolean(userConfig.ai.open_ai?.api_key),
enableAI: aiEnabled,
cellCodeCallbacks: {
updateCellCode,
afterToggleMarkdown,
Expand Down Expand Up @@ -208,7 +210,7 @@ const CellEditorInternal = ({
cellId,
userConfig.keymap,
userConfig.completion,
userConfig.ai.open_ai?.api_key,
aiEnabled,
theme,
showPlaceholder,
createAbove,
Expand Down
5 changes: 5 additions & 0 deletions frontend/src/core/config/config-schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ export const UserConfigSchema = z
model: z.string().optional(),
})
.optional(),
anthropic: z
.object({
api_key: z.string().optional(),
})
.optional(),
})
.default({}),
experimental: z
Expand Down
8 changes: 7 additions & 1 deletion frontend/src/core/config/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,15 @@ export function getUserConfig() {
}

export const aiEnabledAtom = atom<boolean>((get) => {
return Boolean(get(userConfigAtom).ai.open_ai?.api_key);
return isAiEnabled(get(userConfigAtom));
});

export function isAiEnabled(config: UserConfig) {
return (
Boolean(config.ai.open_ai?.api_key) || Boolean(config.ai.anthropic?.api_key)
);
}

/**
* Atom for storing the app config.
*/
Expand Down
6 changes: 5 additions & 1 deletion frontend/src/css/admonition.css
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,19 @@
--admonition-heading-color: var(--blue-11);

display: flex;
position: relative;
flex-direction: column;
padding: 1rem;
padding-left: 3rem;
border-left: 4px solid;
margin-bottom: 1rem;
background-color: var(--admonition-bg);
border-color: var(--admonition-border);

&::before {
position: absolute;
margin-top: 0.35rem;
margin-right: 1rem;
margin-left: -2rem;
}
}

Expand Down
2 changes: 1 addition & 1 deletion marimo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
"video",
"vstack",
]
__version__ = "0.8.3"
__version__ = "0.8.4"

from marimo._ast.app import App
from marimo._ast.cell import Cell
Expand Down
19 changes: 17 additions & 2 deletions marimo/_config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,11 @@ class AiConfig(TypedDict):
**Keys.**
- `open_ai`: the OpenAI config
- `anthropic`: the Anthropic config
"""

open_ai: OpenAiConfig
anthropic: AnthropicConfig


@dataclass
Expand All @@ -186,7 +188,8 @@ class OpenAiConfig(TypedDict):
**Keys.**
- `api_key`: the OpenAI API key
- `model`: the model to use
- `model`: the model to use.
if model starts with `claude-` we use the AnthropicConfig
- `base_url`: the base URL for the API
"""

Expand All @@ -195,6 +198,18 @@ class OpenAiConfig(TypedDict):
base_url: NotRequired[str]


@dataclass
class AnthropicConfig(TypedDict):
"""Configuration options for Anthropic.
**Keys.**
- `api_key`: the Anthropic
"""

api_key: str


@mddoc
@dataclass
class MarimoConfig(TypedDict):
Expand Down Expand Up @@ -303,7 +318,7 @@ def deep_remove_from_path(path: list[str], obj: Dict[str, Any]) -> None:
else:
deep_remove_from_path(path[1:], cast(Dict[str, Any], obj[key]))

secrets = [["ai", "open_ai", "api_key"]]
secrets = [["ai", "open_ai", "api_key"], ["ai", "anthropic", "api_key"]]

new_config = _deep_copy(config)
for secret in secrets:
Expand Down
106 changes: 98 additions & 8 deletions marimo/_server/api/endpoints/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,17 @@
from marimo._utils.assert_never import assert_never

if TYPE_CHECKING:
from openai import OpenAI, Stream # type: ignore[import-not-found]
from anthropic import ( # type: ignore[import-not-found]
Client,
Stream as AnthropicStream,
)
from anthropic.types import ( # type: ignore[import-not-found]
RawMessageStreamEvent,
)
from openai import ( # type: ignore[import-not-found]
OpenAI,
Stream as OpenAiStream,
)
from openai.types.chat import ( # type: ignore[import-not-found]
ChatCompletionChunk,
)
Expand All @@ -36,7 +46,8 @@ def get_openai_client(config: MarimoConfig) -> "OpenAI":
from openai import OpenAI # type: ignore[import-not-found]
except ImportError:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, detail="OpenAI not installed"
status_code=HTTPStatus.BAD_REQUEST,
detail="OpenAI not installed. Run `pip install openai`",
) from None

if "ai" not in config:
Expand Down Expand Up @@ -68,6 +79,42 @@ def get_openai_client(config: MarimoConfig) -> "OpenAI":
return OpenAI(api_key=key, base_url=base_url)


def get_anthropic_client(config: MarimoConfig) -> "Client":
try:
from anthropic import Client # type: ignore[import-not-found]
except ImportError:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail="Anthropic not installed. Run `pip install anthropic`",
) from None

if "ai" not in config:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail="Anthropic not configured",
)
if "anthropic" not in config["ai"]:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail="Anthropic not configured",
)
if "api_key" not in config["ai"]["anthropic"]:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail="Anthropic API key not configured",
)

key: str = config["ai"]["anthropic"]["api_key"]

if not key:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail="Anthropic API key not configured",
)

return Client(api_key=key)


def get_model(config: MarimoConfig) -> str:
model: str = (
config.get("ai", {}).get("open_ai", {}).get("model", "gpt-4-turbo")
Expand All @@ -77,15 +124,34 @@ def get_model(config: MarimoConfig) -> str:
return model


def get_content(
response: RawMessageStreamEvent | ChatCompletionChunk,
) -> str | None:
if hasattr(response, "choices"):
return response.choices[0].delta.content # type: ignore

from anthropic.types import (
RawContentBlockDeltaEvent,
TextDelta,
)

if isinstance(response, RawContentBlockDeltaEvent):
if isinstance(response.delta, TextDelta):
return response.delta.text # type: ignore

return None


def make_stream_response(
response: Stream[ChatCompletionChunk],
response: OpenAiStream[ChatCompletionChunk]
| AnthropicStream[RawMessageStreamEvent],
) -> Generator[str, None, None]:
original_content = ""
buffer: str = ""
buffer = ""
in_code_fence = False
# If it starts or ends with markdown, remove it

for chunk in response:
content = chunk.choices[0].delta.content
content = get_content(chunk)
if not content:
continue

Expand Down Expand Up @@ -158,7 +224,6 @@ async def ai_completion(
app_state.require_current_session()
config = app_state.config_manager.get_config(hide_secrets=False)
body = await parse_request(request, cls=AiCompletionRequest)
client = get_openai_client(config)

if body.language == "python":
system_prompt = (
Expand Down Expand Up @@ -193,7 +258,32 @@ async def ai_completion(
if body.code.strip():
prompt = f"{prompt}\n\nCurrent code:\n{body.code}"

response = client.chat.completions.create(
model = get_model(config)

# If the model starts with claude, use anthropic
if model.startswith("claude"):
anthropic_client = get_anthropic_client(config)
response = anthropic_client.messages.create(
model=model,
max_tokens=1000,
messages=[
{
"role": "user",
"content": prompt,
}
],
system=system_prompt,
stream=True,
temperature=0,
)

return StreamingResponse(
content=make_stream_response(response),
media_type="application/json",
)

openai_client = get_openai_client(config)
response = openai_client.chat.completions.create(
model=get_model(config),
messages=[
{
Expand Down
28 changes: 25 additions & 3 deletions marimo/_smoke_tests/admonitions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Copyright 2024 Marimo. All rights reserved.

import marimo

__generated_with = "0.6.0"
__generated_with = "0.8.3"
app = marimo.App()


Expand Down Expand Up @@ -32,22 +33,24 @@ def __(mo):
"warning",
]


def create(kind):
return mo.md(
rf"""
rf"""
!!! {kind} "{kind} admonition"
This is an admonition for {kind}
"""
)


mo.vstack([create(kind) for kind in kinds])
return create, kinds


@app.cell
def __(mo):
mo.md("# Misc")
mo.md("""# Misc""")
return


Expand All @@ -62,5 +65,24 @@ def __(mo):
return


@app.cell
def __(mo):
mo.md(
r"""
!!! tip ""
Importa recordar as seguintes regras de diferenciação de matrizes:
$$\frac{\partial\, u'v}{\partial\, v} = \frac{\partial\, v'u}{\partial\, v} = u$$
sendo $u$ e $v$ dois vetores.
$$\frac{\partial\, v'Av}{\partial\, v}=2Av=2v'A$$
em que $A$ é uma matriz simétrica. No nosso caso, $A=X'X$ e $v=\hat{\boldsymbol{\beta}}$.import marimo as mo
"""
)
return


if __name__ == "__main__":
app.run()
Loading

0 comments on commit 4b08fe2

Please sign in to comment.