): ExpandableProps =>
- Object({
- label: "hi",
- expandable: true,
- expanded: true,
- ...props,
- })
-
-describe("withExpandable HOC", () => {
- it("renders without crashing", () => {
- const props = getProps()
- const WithHoc = withExpandable(testComponent)
- const wrapper = mount(<WithHoc {...props} />)
- expect(wrapper.find(StatelessAccordion).exists()).toBe(true)
- })
-
- it("renders expander label as expected", () => {
- const props = getProps()
- const WithHoc = withExpandable(testComponent)
- const wrapper = mount(<WithHoc {...props} />)
- const wrappedExpandLabel = wrapper.find(StreamlitMarkdown)
-
- expect(wrappedExpandLabel.props().source).toBe(getProps().label)
- expect(wrappedExpandLabel.props().isLabel).toBe(true)
- })
-
- it("should render a expanded component", () => {
- const props = getProps()
- const WithHoc = withExpandable(testComponent)
- const wrapper = mount()
- const accordion = wrapper.find(StatelessAccordion)
-
- expect(accordion.prop("expanded").length).toBe(1)
- })
-
- it("should render a collapsed component", () => {
- const props = getProps({
- expanded: false,
- })
- const WithHoc = withExpandable(testComponent)
- const wrapper = mount(<WithHoc {...props} />)
- const accordion = wrapper.find(StatelessAccordion)
-
- expect(accordion.prop("expanded").length).toBe(0)
- })
-
- it("should become stale", () => {
- const props = getProps({
- isStale: true,
- })
- const WithHoc = withExpandable(testComponent)
- const wrapper = mount(<WithHoc {...props} />)
- const accordion = wrapper.find(StatelessAccordion)
- const overrides = accordion.prop("overrides")
-
- // @ts-expect-error
- expect(overrides.Header.props.isStale).toBeTruthy()
- })
-})
diff --git a/frontend/lib/src/hocs/withExpandable/withExpandable.tsx b/frontend/lib/src/hocs/withExpandable/withExpandable.tsx
deleted file mode 100644
index fae9c70ca14c..000000000000
--- a/frontend/lib/src/hocs/withExpandable/withExpandable.tsx
+++ /dev/null
@@ -1,205 +0,0 @@
-/**
- * Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-import React, { ComponentType, ReactElement, useEffect, useState } from "react"
-import { ExpandMore, ExpandLess } from "@emotion-icons/material-outlined"
-import Icon from "@streamlit/lib/src/components/shared/Icon"
-import StreamlitMarkdown from "@streamlit/lib/src/components/shared/StreamlitMarkdown"
-
-import classNames from "classnames"
-import {
- StatelessAccordion as Accordion,
- Panel,
- SharedStylePropsArg,
-} from "baseui/accordion"
-import { useTheme } from "@emotion/react"
-import { StyledExpandableContainer } from "./styled-components"
-
-export interface ExpandableProps {
- expandable: boolean
- label: string
- expanded: boolean
- empty: boolean
- widgetsDisabled: boolean
- isStale: boolean
-}
-
-// Our wrapper takes the wrapped component's props plus ExpandableProps
-type WrapperProps<P> = P & ExpandableProps
-
-// TODO: there's no reason for this to be a HOC. Adapt it to follow the same
-// pattern as the `Tabs` and `ChatMessage` containers that simply parent their
-// children.
-function withExpandable<P>(
- WrappedComponent: ComponentType<P>
-): ComponentType<WrapperProps<P>> {
- const ExpandableComponent = (props: WrapperProps<P>): ReactElement => {
- const {
- label,
- expanded: initialExpanded,
- empty,
- widgetsDisabled,
- isStale,
- ...componentProps
- } = props
-
- const [expanded, setExpanded] = useState(initialExpanded)
- useEffect(() => {
- setExpanded(initialExpanded)
- // Having `label` in the dependency array here is necessary because
- // sometimes two distinct expanders look so similar that even the react
- // diffing algorithm decides that they're the same element with updated
- // props (this happens when something in the app removes one expander and
- // replaces it with another in the same position).
- //
- // By adding `label` as a dependency, we ensure that we reset the
- // expander's `expanded` state in this edge case.
- }, [label, initialExpanded])
-
- const toggle = (): void => setExpanded(!expanded)
- const { colors, radii, spacing, fontSizes } = useTheme()
-
- return (
- <StyledExpandableContainer data-testid="stExpander">
- <Accordion
- onChange={toggle}
- expanded={expanded ? ["panel"] : []}
- disabled={widgetsDisabled}
- overrides={{
- Content: {
- style: ({ $expanded }: SharedStylePropsArg) => ({
- backgroundColor: colors.transparent,
- marginLeft: spacing.none,
- marginRight: spacing.none,
- marginTop: spacing.none,
- marginBottom: spacing.none,
- overflow: "visible",
- paddingLeft: spacing.lg,
- paddingRight: spacing.lg,
- paddingTop: 0,
- paddingBottom: $expanded ? spacing.lg : 0,
- borderTopStyle: "none",
- borderBottomStyle: "none",
- borderRightStyle: "none",
- borderLeftStyle: "none",
- }),
- props: { className: "streamlit-expanderContent" },
- },
- // Allow fullscreen button to overflow the expander
- ContentAnimationContainer: {
- style: ({ $expanded }: SharedStylePropsArg) => ({
- overflow: $expanded ? "visible" : "hidden",
- }),
- },
- PanelContainer: {
- style: () => ({
- marginLeft: `${spacing.none} !important`,
- marginRight: `${spacing.none} !important`,
- marginTop: `${spacing.none} !important`,
- marginBottom: `${spacing.none} !important`,
- paddingLeft: `${spacing.none} !important`,
- paddingRight: `${spacing.none} !important`,
- paddingTop: `${spacing.none} !important`,
- paddingBottom: `${spacing.none} !important`,
- borderTopStyle: "none !important",
- borderBottomStyle: "none !important",
- borderRightStyle: "none !important",
- borderLeftStyle: "none !important",
- }),
- },
- Header: {
- style: ({ $disabled }: SharedStylePropsArg) => ({
- marginBottom: spacing.none,
- marginLeft: spacing.none,
- marginRight: spacing.none,
- marginTop: spacing.none,
- backgroundColor: colors.transparent,
- color: $disabled ? colors.disabled : colors.bodyText,
- fontSize: fontSizes.sm,
- borderTopStyle: "none",
- borderBottomStyle: "none",
- borderRightStyle: "none",
- borderLeftStyle: "none",
- paddingBottom: spacing.md,
- paddingTop: spacing.md,
- paddingRight: spacing.lg,
- paddingLeft: spacing.lg,
- ...(isStale
- ? {
- opacity: 0.33,
- transition: "opacity 1s ease-in 0.5s",
- }
- : {}),
- }),
- props: {
- className: "streamlit-expanderHeader",
- isStale,
- },
- },
- ToggleIcon: {
- style: ({ $disabled }: SharedStylePropsArg) => ({
- color: $disabled ? colors.disabled : colors.bodyText,
- }),
- // eslint-disable-next-line react/display-name
- component: () => {
- if (expanded) {
- return <Icon content={ExpandLess} size="lg" />
- }
- return <Icon content={ExpandMore} size="lg" />
- },
- },
- Root: {
- props: {
- className: classNames("streamlit-expander", { empty }),
- isStale,
- },
- style: {
- borderStyle: "solid",
- borderWidth: "1px",
- borderColor: colors.fadedText10,
- borderRadius: radii.lg,
- ...(isStale
- ? {
- borderColor: colors.fadedText05,
- transition: "border 1s ease-in 0.5s",
- }
- : {}),
- },
- },
- }}
- >
- <Panel
- title={
- <StreamlitMarkdown source={label} allowHTML={false} isLabel />
- }
- key="panel"
- >
- <WrappedComponent {...(componentProps as P)} disabled={widgetsDisabled} />
- </Panel>
- </Accordion>
- </StyledExpandableContainer>
- )
- }
-
- return ExpandableComponent
-}
-
-export default withExpandable
diff --git a/frontend/lib/src/hocs/withFullScreenWrapper/withFullScreenWrapper.tsx b/frontend/lib/src/hocs/withFullScreenWrapper/withFullScreenWrapper.tsx
index 8299b3017e45..a40ddc8aef99 100644
--- a/frontend/lib/src/hocs/withFullScreenWrapper/withFullScreenWrapper.tsx
+++ b/frontend/lib/src/hocs/withFullScreenWrapper/withFullScreenWrapper.tsx
@@ -19,7 +19,7 @@ import hoistNonReactStatics from "hoist-non-react-statics"
import FullScreenWrapper from "@streamlit/lib/src/components/shared/FullScreenWrapper"
-interface Props {
+export interface Props {
width: number
height?: number
}
diff --git a/frontend/lib/src/hooks/useScrollAnimation.test.ts b/frontend/lib/src/hooks/useScrollAnimation.test.ts
index 5fbb8d4c6f02..b355ec770783 100644
--- a/frontend/lib/src/hooks/useScrollAnimation.test.ts
+++ b/frontend/lib/src/hooks/useScrollAnimation.test.ts
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-import { renderHook, act } from "@testing-library/react-hooks"
+import { renderHook } from "@testing-library/react-hooks"
import useScrollAnimation from "./useScrollAnimation"
describe("useScrollAnimation", () => {
@@ -51,25 +51,17 @@ describe("useScrollAnimation", () => {
renderHook(() => useScrollAnimation(targetElement, onEndMock, true))
// Simulate scroll animation
- act(() => {
- jest.advanceTimersByTime(5)
- // Trigger the callback of requestAnimationFrame
- act(() => {
- jest.runOnlyPendingTimers()
- })
- })
+ jest.advanceTimersByTime(5)
+ // Trigger the callback of requestAnimationFrame
+ jest.runOnlyPendingTimers()
// Assert the updated scrollTop value
expect(targetElement.scrollTop).toBeGreaterThan(0)
// Simulate reaching the end of animation
- act(() => {
- jest.advanceTimersByTime(100)
- // Trigger the callback of requestAnimationFrame
- act(() => {
- jest.runOnlyPendingTimers()
- })
- })
+ jest.advanceTimersByTime(100)
+ // Trigger the callback of requestAnimationFrame
+ jest.runOnlyPendingTimers()
// Assert that onEnd callback is called
expect(onEndMock).toHaveBeenCalled()
diff --git a/frontend/lib/src/index.ts b/frontend/lib/src/index.ts
index dd881aa0cdc3..6d7f5663105f 100644
--- a/frontend/lib/src/index.ts
+++ b/frontend/lib/src/index.ts
@@ -18,7 +18,6 @@
export {
IS_DEV_ENV,
RERUN_PROMPT_MODAL_DIALOG,
- SHOW_DEPLOY_BUTTON,
WEBSOCKET_PORT_DEV,
} from "./baseconsts"
export { default as VerticalBlock } from "./components/core/Block"
@@ -113,7 +112,6 @@ export {
isPaddingDisplayed,
isScrollingHidden,
isToolbarDisplayed,
- isTesting,
notUndefined,
setCookie,
extractPageNameFromPathName,
@@ -136,4 +134,5 @@ export { mockTheme } from "./mocks/mockTheme"
export { default as AlertElement } from "./components/elements/AlertElement"
export { default as TextElement } from "./components/elements/TextElement"
export { default as useScrollToBottom } from "./hooks/useScrollToBottom"
+export { RootStyleProvider } from "./RootStyleProvider"
export * from "./proto"
diff --git a/frontend/lib/src/mocks/mocks.ts b/frontend/lib/src/mocks/mocks.ts
index 8c0809d13b99..0856ad8d579e 100644
--- a/frontend/lib/src/mocks/mocks.ts
+++ b/frontend/lib/src/mocks/mocks.ts
@@ -54,10 +54,14 @@ export function mockEndpoints(
return {
buildComponentURL: jest.fn(),
buildMediaURL: jest.fn(),
+ buildFileUploadURL: jest.fn(),
buildAppPageURL: jest.fn(),
uploadFileUploaderFile: jest
.fn()
.mockRejectedValue(new Error("unimplemented mock endpoint")),
+ deleteFileAtURL: jest
+ .fn()
+ .mockRejectedValue(new Error("unimplemented mock endpoint")),
fetchCachedForwardMsg: jest
.fn()
.mockRejectedValue(new Error("unimplemented mock endpoint")),
diff --git a/frontend/lib/src/theme/utils.ts b/frontend/lib/src/theme/utils.ts
index 49264c128dab..57fb7f175555 100644
--- a/frontend/lib/src/theme/utils.ts
+++ b/frontend/lib/src/theme/utils.ts
@@ -417,6 +417,49 @@ export function hasLightBackgroundColor(theme: EmotionTheme): boolean {
return getLuminance(theme.colors.bgColor) > 0.5
}
+export function getDividerColors(theme: EmotionTheme): any {
+ const lightTheme = hasLightBackgroundColor(theme)
+ const blue = lightTheme ? theme.colors.blue60 : theme.colors.blue90
+ const green = lightTheme ? theme.colors.green60 : theme.colors.green90
+ const orange = lightTheme ? theme.colors.orange60 : theme.colors.orange90
+ const red = lightTheme ? theme.colors.red60 : theme.colors.red90
+ const violet = lightTheme ? theme.colors.purple60 : theme.colors.purple80
+ const gray = lightTheme ? theme.colors.gray40 : theme.colors.gray70
+
+ return {
+ blue: blue,
+ green: green,
+ orange: orange,
+ red: red,
+ violet: violet,
+ gray: gray,
+ grey: gray,
+ rainbow: `linear-gradient(to right, ${red}, ${orange}, ${green}, ${blue}, ${violet})`,
+ }
+}
+
+export function getMarkdownTextColors(theme: EmotionTheme): any {
+ const lightTheme = hasLightBackgroundColor(theme)
+ const red = lightTheme ? theme.colors.red80 : theme.colors.red70
+ const orange = lightTheme ? theme.colors.orange100 : theme.colors.orange60
+ const yellow = lightTheme ? theme.colors.yellow100 : theme.colors.yellow40
+ const green = lightTheme ? theme.colors.green90 : theme.colors.green60
+ const blue = lightTheme ? theme.colors.blue80 : theme.colors.blue50
+ const violet = lightTheme ? theme.colors.purple80 : theme.colors.purple50
+ const purple = lightTheme ? theme.colors.purple100 : theme.colors.purple80
+ const gray = lightTheme ? theme.colors.gray80 : theme.colors.gray70
+ return {
+ red: red,
+ orange: orange,
+ yellow: yellow,
+ green: green,
+ blue: blue,
+ violet: violet,
+ purple: purple,
+ gray: gray,
+ }
+}
+
export function getGray70(theme: EmotionTheme): string {
return hasLightBackgroundColor(theme)
? theme.colors.gray70
@@ -435,36 +478,6 @@ export function getGray90(theme: EmotionTheme): string {
: theme.colors.gray10
}
-export function getMdRed(theme: EmotionTheme): string {
- return hasLightBackgroundColor(theme)
- ? theme.colors.red80
- : theme.colors.red70
-}
-
-export function getMdBlue(theme: EmotionTheme): string {
- return hasLightBackgroundColor(theme)
- ? theme.colors.blue80
- : theme.colors.blue50
-}
-
-export function getMdGreen(theme: EmotionTheme): string {
- return hasLightBackgroundColor(theme)
- ? theme.colors.green90
- : theme.colors.green60
-}
-
-export function getMdViolet(theme: EmotionTheme): string {
- return hasLightBackgroundColor(theme)
- ? theme.colors.purple80
- : theme.colors.purple50
-}
-
-export function getMdOrange(theme: EmotionTheme): string {
- return hasLightBackgroundColor(theme)
- ? theme.colors.orange100
- : theme.colors.orange60
-}
-
function getBlueArrayAsc(theme: EmotionTheme): string[] {
const { colors } = theme
return [
diff --git a/frontend/lib/src/util/Resolver.ts b/frontend/lib/src/util/Resolver.ts
index 2afe63d58ad1..dc5efb91f2e8 100644
--- a/frontend/lib/src/util/Resolver.ts
+++ b/frontend/lib/src/util/Resolver.ts
@@ -18,19 +18,26 @@
* A promise wrapper that makes resolve/reject functions public.
*/
export default class Resolver<T> {
- public resolve: (value: T | PromiseLike<T>) => void
+ public readonly resolve: (value: T | PromiseLike<T>) => void
- public reject: (reason?: any) => void | Promise<void>
+ public readonly reject: (reason?: any) => void | Promise<void>
- public promise: Promise<T>
+ public readonly promise: Promise<T>
constructor() {
- // Initialize to something so TS is happy.
+ // Initialize to something so that TS is happy, then use @ts-expect-error
+ // so that we can assign the actually desired values to resolve and reject.
+ //
+ // This is necessary because TS isn't able to deduce that resolve and
+ // reject will always be set in the callback passed to the Promise
+ // constructor below.
this.resolve = () => {}
this.reject = () => {}
this.promise = new Promise<T>((resFn, rejFn) => {
+ // @ts-expect-error
this.resolve = resFn
+ // @ts-expect-error
this.reject = rejFn
})
}
diff --git a/frontend/lib/src/util/utils.ts b/frontend/lib/src/util/utils.ts
index 41d2983bfd8a..44cc07d7a9bf 100644
--- a/frontend/lib/src/util/utils.ts
+++ b/frontend/lib/src/util/utils.ts
@@ -447,21 +447,3 @@ export function extractPageNameFromPathName(
.replace(new RegExp("/$"), "")
)
}
-
-export const TESTING_QUERY_PARAM_KEY = "_stcore_testing"
-export function isTesting(): boolean {
- const urlParams = new URLSearchParams(window.location.search)
- let isTesting = false
- urlParams.forEach((paramValue, paramKey) => {
- paramKey = paramKey.toString().toLowerCase()
- paramValue = paramValue.toString().toLowerCase()
- if (
- paramKey === TESTING_QUERY_PARAM_KEY.toLowerCase() &&
- paramValue === "true"
- ) {
- isTesting = true
- return isTesting
- }
- })
- return isTesting
-}
diff --git a/frontend/package.json b/frontend/package.json
index 6f7109a4c48e..2e74f4d667fd 100644
--- a/frontend/package.json
+++ b/frontend/package.json
@@ -1,8 +1,11 @@
{
"name": "streamlit",
- "version": "1.25.0.dev1",
+ "version": "1.26.0.dev1",
"private": true,
- "workspaces": ["app", "lib"],
+ "workspaces": [
+ "app",
+ "lib"
+ ],
"scripts": {
"start": "yarn workspace @streamlit/app start",
"build": "yarn workspace @streamlit/lib build && yarn workspace @streamlit/app build",
diff --git a/frontend/yarn.lock b/frontend/yarn.lock
index 048c8c2c668d..10e9c0a4ad0d 100644
--- a/frontend/yarn.lock
+++ b/frontend/yarn.lock
@@ -3916,6 +3916,11 @@
resolved "https://registry.yarnpkg.com/@types/unist/-/unist-2.0.6.tgz#250a7b16c3b91f672a24552ec64678eeb1d3a08d"
integrity sha512-PBjIUxZHOuj0R15/xuwJYjFi+KZdNFrehocChv4g5hu6aFroHue8m0lBP0POdK2nKzbw0cgV1mws8+V/JAcEkQ==
+"@types/uuid@^9.0.2":
+ version "9.0.2"
+ resolved "https://registry.yarnpkg.com/@types/uuid/-/uuid-9.0.2.tgz#ede1d1b1e451548d44919dc226253e32a6952c4b"
+ integrity sha512-kNnC1GFBLuhImSnV7w4njQkUiJi0ZXUycu1rUaouPqiKlXkh77JKgdRnTAp1x5eBwcIwbtI+3otwzuIDEuDoxQ==
+
"@types/ws@^8.5.1":
version "8.5.4"
resolved "https://registry.yarnpkg.com/@types/ws/-/ws-8.5.4.tgz#bb10e36116d6e570dd943735f86c933c1587b8a5"
@@ -16912,6 +16917,11 @@ uuid@^8.3.2:
resolved "https://registry.yarnpkg.com/uuid/-/uuid-8.3.2.tgz#80d5b5ced271bb9af6c445f21a1a04c606cefbe2"
integrity sha512-+NYs2QeMWy+GWFOEm9xnn6HCDp0l7QBD7ml8zLUmJ+93Q5NF0NocErnwkTkXVFNiX3/fpC6afS8Dhb/gz7R7eg==
+uuid@^9.0.0:
+ version "9.0.0"
+ resolved "https://registry.yarnpkg.com/uuid/-/uuid-9.0.0.tgz#592f550650024a38ceb0c562f2f6aa435761efb5"
+ integrity sha512-MXcSTerfPa4uqyzStbRoTgt5XIe3x5+42+q1sDuy3R5MDk66URdLMOZe5aPX/SQd+kuYAh0FdP/pO28IkQyTeg==
+
uvu@^0.5.0:
version "0.5.6"
resolved "https://registry.yarnpkg.com/uvu/-/uvu-0.5.6.tgz#2754ca20bcb0bb59b64e9985e84d2e81058502df"
diff --git a/lib/setup.py b/lib/setup.py
index f9caced83105..85d2a1bd0ee6 100644
--- a/lib/setup.py
+++ b/lib/setup.py
@@ -21,7 +21,7 @@
THIS_DIRECTORY = Path(__file__).parent
-VERSION = "1.25.0.dev1" # PEP-440
+VERSION = "1.26.0.dev1" # PEP-440
NAME = "streamlit"
@@ -58,7 +58,7 @@
"tenacity>=8.1.0, <9",
"toml>=0.10.1, <2",
"typing-extensions>=4.1.0, <5",
- "tzlocal>=1.1, <5",
+ "tzlocal>=1.1, <6",
"validators>=0.2, <1",
# Don't require watchdog on MacOS, since it'll fail without xcode tools.
# Without watchdog, we fallback to a polling file watcher to check for app changes.
diff --git a/lib/streamlit/__init__.py b/lib/streamlit/__init__.py
index 79786b35cb2a..1d8c184726a3 100644
--- a/lib/streamlit/__init__.py
+++ b/lib/streamlit/__init__.py
@@ -166,6 +166,7 @@ def _update_logger() -> None:
text = _main.text
text_area = _main.text_area
text_input = _main.text_input
+toggle = _main.toggle
time_input = _main.time_input
title = _main.title
vega_lite_chart = _main.vega_lite_chart
@@ -173,6 +174,7 @@ def _update_logger() -> None:
warning = _main.warning
write = _main.write
color_picker = _main.color_picker
+status = _main.status
# Events - Note: these methods cannot be called directly on sidebar (ex: st.sidebar.toast)
toast = event.toast
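The two bindings added above expose `st.toggle` and `st.status` at the top level. A minimal usage sketch, assuming the 1.26-era API where `st.toggle` returns the switch's boolean state and `st.status` yields a container whose `update()` accepts `label` and `state`:

```python
import streamlit as st

# st.toggle renders an on/off switch and returns its current boolean state.
advanced = st.toggle("Enable advanced mode")
if advanced:
    st.write("Advanced mode is on.")

# st.status renders an expander-like container that reports long-running work.
with st.status("Fetching data...") as status:
    st.write("Downloading...")
    status.update(label="Done!", state="complete")
```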
diff --git a/lib/streamlit/commands/execution_control.py b/lib/streamlit/commands/execution_control.py
index 31b3bb389625..811ad65f0e56 100644
--- a/lib/streamlit/commands/execution_control.py
+++ b/lib/streamlit/commands/execution_control.py
@@ -14,6 +14,7 @@
from typing import NoReturn
+from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.scriptrunner import (
RerunData,
OnPurposeRerunException,
@@ -43,6 +44,7 @@ def stop() -> NoReturn:
raise StopException()
+@gather_metrics("experimental_rerun")
def rerun() -> NoReturn:
"""Rerun the script immediately.
diff --git a/lib/streamlit/components/v1/components.py b/lib/streamlit/components/v1/components.py
index 89fca77d2f06..3595f6d54488 100644
--- a/lib/streamlit/components/v1/components.py
+++ b/lib/streamlit/components/v1/components.py
@@ -29,6 +29,7 @@
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.scriptrunner import get_script_run_ctx
from streamlit.runtime.state import NoValue, register_widget
+from streamlit.runtime.state.common import compute_widget_id
from streamlit.type_util import to_bytes
LOGGER = get_logger(__name__)
@@ -162,10 +163,9 @@ def marshall_component(dg, element: Element) -> Union[Any, Type[NoValue]]:
# Normally, a widget's element_hash (which determines
# its identity across multiple runs of an app) is computed
- # by hashing the entirety of its protobuf. This means that,
- # if any of the arguments to the widget are changed, Streamlit
- # considers it a new widget instance and it loses its previous
- # state.
+ # by hashing its arguments. This means that, if any of the arguments
+ # to the widget are changed, Streamlit considers it a new widget
+ # instance and it loses its previous state.
#
# However! If a *component* has a `key` argument, then the
# component's hash identity is determined entirely by
@@ -173,10 +173,6 @@ def marshall_component(dg, element: Element) -> Union[Any, Type[NoValue]]:
# exists, the component will maintain its identity even when its
# other arguments change, and the component's iframe won't be
# remounted on the frontend.
- #
- # So: if `key` is None, we marshall the element's arguments
- # *before* computing its widget_ui_value (which creates its hash).
- # If `key` is not None, we marshall the arguments *after*.
def marshall_element_args():
element.component_instance.json_args = serialized_json_args
@@ -184,6 +180,26 @@ def marshall_element_args():
if key is None:
marshall_element_args()
+ id = compute_widget_id(
+ "component_instance",
+ user_key=key,
+ name=self.name,
+ form_id=current_form_id(dg),
+ url=self.url,
+ key=key,
+ json_args=serialized_json_args,
+ special_args=special_args,
+ )
+ else:
+ id = compute_widget_id(
+ "component_instance",
+ user_key=key,
+ name=self.name,
+ form_id=current_form_id(dg),
+ url=self.url,
+ key=key,
+ )
+ element.component_instance.id = id
def deserialize_component(ui_value, widget_id=""):
# ui_value is an object from json, an ArrowTable proto, or a bytearray
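The hunk above makes a component's identity explicit via `compute_widget_id`: without a `key`, the id also hashes `json_args`/`special_args`, so any argument change yields a new widget; with a `key`, the id depends only on the stable fields. A sketch of the user-visible effect (the component name and URL are hypothetical):

```python
import streamlit.components.v1 as components

# Hypothetical custom component, for illustration only.
my_widget = components.declare_component("my_widget", url="http://localhost:3001")

# No key: the id hashes all args, so changing `threshold` across reruns
# creates a new widget identity and the component's iframe remounts.
value_a = my_widget(threshold=0.5)

# With a key: the id ignores the JSON args, so `threshold` can change
# while the component keeps its state and stays mounted.
value_b = my_widget(threshold=0.5, key="detector")
```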
diff --git a/lib/streamlit/delta_generator.py b/lib/streamlit/delta_generator.py
index 32816c8d5c4e..c600edd96efe 100644
--- a/lib/streamlit/delta_generator.py
+++ b/lib/streamlit/delta_generator.py
@@ -24,6 +24,7 @@
Hashable,
Iterable,
NoReturn,
+ Optional,
Type,
TypeVar,
cast,
@@ -36,11 +37,12 @@
from streamlit import config, cursor, env_util, logger, runtime, type_util, util
from streamlit.cursor import Cursor
from streamlit.elements.alert import AlertMixin
+from streamlit.elements.altair_utils import AddRowsMetadata
# DataFrame elements come in two flavors: "Legacy" and "Arrow".
# We select between them with the DataFrameElementSelectorMixin.
from streamlit.elements.arrow import ArrowMixin
-from streamlit.elements.arrow_altair import ArrowAltairMixin
+from streamlit.elements.arrow_altair import ArrowAltairMixin, prep_data
from streamlit.elements.arrow_vega_lite import ArrowVegaLiteMixin
from streamlit.elements.balloons import BalloonsMixin
from streamlit.elements.bokeh_chart import BokehMixin
@@ -57,7 +59,7 @@
from streamlit.elements.image import ImageMixin
from streamlit.elements.json import JsonMixin
from streamlit.elements.layouts import LayoutsMixin
-from streamlit.elements.legacy_altair import LegacyAltairMixin
+from streamlit.elements.legacy_altair import ArrowNotSupportedError, LegacyAltairMixin
from streamlit.elements.legacy_data_frame import LegacyDataFrameMixin
from streamlit.elements.legacy_vega_lite import LegacyVegaLiteMixin
from streamlit.elements.map import MapMixin
@@ -113,6 +115,7 @@
"arrow_line_chart",
"arrow_area_chart",
"arrow_bar_chart",
+ "arrow_scatter_chart",
)
Value = TypeVar("Value")
@@ -417,7 +420,7 @@ def _enqueue( # type: ignore[misc]
delta_type: str,
element_proto: Message,
return_value: None,
- last_index: Hashable | None = None,
+ add_rows_metadata: Optional[AddRowsMetadata] = None,
element_width: int | None = None,
element_height: int | None = None,
) -> DeltaGenerator:
@@ -429,7 +432,7 @@ def _enqueue( # type: ignore[misc]
delta_type: str,
element_proto: Message,
return_value: Type[NoValue],
- last_index: Hashable | None = None,
+ add_rows_metadata: Optional[AddRowsMetadata] = None,
element_width: int | None = None,
element_height: int | None = None,
) -> None:
@@ -441,7 +444,7 @@ def _enqueue( # type: ignore[misc]
delta_type: str,
element_proto: Message,
return_value: Value,
- last_index: Hashable | None = None,
+ add_rows_metadata: Optional[AddRowsMetadata] = None,
element_width: int | None = None,
element_height: int | None = None,
) -> Value:
@@ -453,7 +456,7 @@ def _enqueue(
delta_type: str,
element_proto: Message,
return_value: None = None,
- last_index: Hashable | None = None,
+ add_rows_metadata: Optional[AddRowsMetadata] = None,
element_width: int | None = None,
element_height: int | None = None,
) -> DeltaGenerator:
@@ -465,7 +468,7 @@ def _enqueue(
delta_type: str,
element_proto: Message,
return_value: Type[NoValue] | Value | None = None,
- last_index: Hashable | None = None,
+ add_rows_metadata: Optional[AddRowsMetadata] = None,
element_width: int | None = None,
element_height: int | None = None,
) -> DeltaGenerator | Value | None:
@@ -476,7 +479,7 @@ def _enqueue(
delta_type: str,
element_proto: Message,
return_value: Type[NoValue] | Value | None = None,
- last_index: Hashable | None = None,
+ add_rows_metadata: Optional[AddRowsMetadata] = None,
element_width: int | None = None,
element_height: int | None = None,
) -> DeltaGenerator | Value | None:
@@ -549,7 +552,7 @@ def _enqueue(
# position.
new_cursor = (
dg._cursor.get_locked_cursor(
- delta_type=delta_type, last_index=last_index
+ delta_type=delta_type, add_rows_metadata=add_rows_metadata
)
if dg._cursor is not None
else None
@@ -579,6 +582,7 @@ def _enqueue(
def _block(
self,
block_proto: Block_pb2.Block = Block_pb2.Block(),
+ dg_type: type | None = None,
) -> DeltaGenerator:
# Operate on the active DeltaGenerator, in case we're in a `with` block.
dg = self._active_dg
@@ -626,18 +630,27 @@ def _block(
root_container=dg._root_container,
parent_path=dg._cursor.parent_path + (dg._cursor.index,),
)
- block_dg = DeltaGenerator(
- root_container=dg._root_container,
- cursor=block_cursor,
- parent=dg,
- block_type=block_type,
+
+ # `dg_type` param added for st.status container. It allows us to
+ # instantiate DeltaGenerator subclasses from the function.
+ if dg_type is None:
+ dg_type = DeltaGenerator
+
+ block_dg = cast(
+ DeltaGenerator,
+ dg_type(
+ root_container=dg._root_container,
+ cursor=block_cursor,
+ parent=dg,
+ block_type=block_type,
+ ),
)
# Blocks inherit their parent form ids.
# NOTE: Container form ids aren't set in proto.
block_dg._form_data = FormData(current_form_id(dg))
# Must be called to increment this cursor's index.
- dg._cursor.get_locked_cursor(last_index=None)
+ dg._cursor.get_locked_cursor(add_rows_metadata=None)
_enqueue_message(msg)
caching.save_block_message(
@@ -732,12 +745,16 @@ def _legacy_add_rows(
"Command requires exactly one dataset"
)
+ # The legacy add_rows does not support Arrow tables.
+ if type_util.is_type(data, "pyarrow.lib.Table"):
+ raise ArrowNotSupportedError()
+
# When doing _legacy_add_rows on an element that does not already have data
# (for example, st._legacy_line_chart() without any args), call the original
# st._legacy_foo() element with new data instead of doing a _legacy_add_rows().
if (
self._cursor.props["delta_type"] in DELTA_TYPES_THAT_MELT_DATAFRAMES
- and self._cursor.props["last_index"] is None
+ and self._cursor.props["add_rows_metadata"].last_index is None
):
# IMPORTANT: This assumes delta types and st method names always
# match!
@@ -747,8 +764,11 @@ def _legacy_add_rows(
st_method(data, **kwargs)
return None
- data, self._cursor.props["last_index"] = _maybe_melt_data_for_add_rows(
- data, self._cursor.props["delta_type"], self._cursor.props["last_index"]
+ new_data, self._cursor.props["add_rows_metadata"] = _prep_data_for_add_rows(
+ data,
+ self._cursor.props["delta_type"],
+ self._cursor.props["add_rows_metadata"],
+ is_legacy=True,
)
msg = ForwardMsg_pb2.ForwardMsg()
@@ -756,7 +776,7 @@ def _legacy_add_rows(
import streamlit.elements.legacy_data_frame as data_frame
- data_frame.marshall_data_frame(data, msg.delta.add_rows.data)
+ data_frame.marshall_data_frame(new_data, msg.delta.add_rows.data)
if name:
msg.delta.add_rows.name = name
@@ -853,7 +873,7 @@ def _arrow_add_rows(
# st._arrow_foo() element with new data instead of doing a _arrow_add_rows().
if (
self._cursor.props["delta_type"] in ARROW_DELTA_TYPES_THAT_MELT_DATAFRAMES
- and self._cursor.props["last_index"] is None
+ and self._cursor.props["add_rows_metadata"].last_index is None
):
# IMPORTANT: This assumes delta types and st method names always
# match!
@@ -863,8 +883,11 @@ def _arrow_add_rows(
st_method(data, **kwargs)
return None
- data, self._cursor.props["last_index"] = _maybe_melt_data_for_add_rows(
- data, self._cursor.props["delta_type"], self._cursor.props["last_index"]
+ new_data, self._cursor.props["add_rows_metadata"] = _prep_data_for_add_rows(
+ data,
+ self._cursor.props["delta_type"],
+ self._cursor.props["add_rows_metadata"],
+ is_legacy=False,
)
msg = ForwardMsg_pb2.ForwardMsg()
@@ -873,7 +896,7 @@ def _arrow_add_rows(
import streamlit.elements.arrow as arrow_proto
default_uuid = str(hash(self._get_delta_path_str()))
- arrow_proto.marshall(msg.delta.arrow_add_rows.data, data, default_uuid)
+ arrow_proto.marshall(msg.delta.arrow_add_rows.data, new_data, default_uuid)
if name:
msg.delta.arrow_add_rows.name = name
@@ -884,17 +907,26 @@ def _arrow_add_rows(
return self
-DFT = TypeVar("DFT", bound=type_util.DataFrameCompatible)
+def _prep_data_for_add_rows(
+ data: Data,
+ delta_type: str,
+ add_rows_metadata: AddRowsMetadata,
+ is_legacy: bool,
+) -> tuple[Data, AddRowsMetadata]:
+ out_data: Data
+ # For some delta types we have to reshape the data structure
+ # otherwise the input data and the actual data used
+ # by vega_lite will be different, and it will throw an error.
+ if (
+ delta_type in DELTA_TYPES_THAT_MELT_DATAFRAMES
+ or delta_type in ARROW_DELTA_TYPES_THAT_MELT_DATAFRAMES
+ ):
+ import pandas as pd
-def _maybe_melt_data_for_add_rows(
- data: DFT,
- delta_type: str,
- last_index: Any,
-) -> tuple[DFT | DataFrame, int | Any]:
- import pandas as pd
+ df = cast(pd.DataFrame, type_util.convert_anything_to_df(data))
- def _melt_data(df: DataFrame, last_index: Any) -> tuple[DataFrame, int | Any]:
+ # Make range indices start at last_index.
if isinstance(df.index, pd.RangeIndex):
old_step = _get_pandas_index_attr(df, "step")
@@ -908,35 +940,26 @@ def _melt_data(df: DataFrame, last_index: Any) -> tuple[DataFrame, int | Any]:
"'RangeIndex' object has no attribute 'step'"
)
- start = last_index + old_step
- stop = last_index + old_step + old_stop
+ start = add_rows_metadata.last_index + old_step
+ stop = add_rows_metadata.last_index + old_step + old_stop
df.index = pd.RangeIndex(start=start, stop=stop, step=old_step)
- last_index = stop - 1
-
- index_name = df.index.name
- if index_name is None:
- index_name = "index"
+ add_rows_metadata.last_index = stop - 1
- df = pd.melt(df.reset_index(), id_vars=[index_name])
- return df, last_index
+ if is_legacy:
+ index_name = df.index.name
+ if index_name is None:
+ index_name = "index"
- # For some delta types we have to reshape the data structure
- # otherwise the input data and the actual data used
- # by vega_lite will be different, and it will throw an error.
- if (
- delta_type in DELTA_TYPES_THAT_MELT_DATAFRAMES
- or delta_type in ARROW_DELTA_TYPES_THAT_MELT_DATAFRAMES
- ):
- if not isinstance(data, pd.DataFrame):
- return _melt_data(
- df=type_util.convert_anything_to_df(data),
- last_index=last_index,
- )
+ out_data = pd.melt(df.reset_index(), id_vars=[index_name])
else:
- return _melt_data(df=data, last_index=last_index)
+ out_data, *_ = prep_data(df, **add_rows_metadata.columns)
+
+ else:
+ # When calling add_rows on st.table or st.dataframe we want styles to pass through.
+ out_data = type_util.convert_anything_to_df(data, allow_styler=True)
- return data, last_index
+ return out_data, add_rows_metadata
def _get_pandas_index_attr(
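In effect, `_prep_data_for_add_rows` now funnels appended data through the same `prep_data` reshaping that built the chart, while `AddRowsMetadata.last_index` keeps RangeIndex-based charts contiguous. A short sketch of the behavior this preserves:

```python
import numpy as np
import pandas as pd
import streamlit as st

df = pd.DataFrame(np.random.randn(10, 2), columns=["a", "b"])
chart = st.line_chart(df)  # wide data is melted behind the scenes

# The appended rows get the same prep step, and their RangeIndex is shifted
# to start where the original data ended (tracked via AddRowsMetadata).
chart.add_rows(pd.DataFrame(np.random.randn(5, 2), columns=["a", "b"]))
```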
diff --git a/lib/streamlit/elements/altair_utils.py b/lib/streamlit/elements/altair_utils.py
new file mode 100644
index 000000000000..4bd9742907b3
--- /dev/null
+++ b/lib/streamlit/elements/altair_utils.py
@@ -0,0 +1,38 @@
+# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Useful classes for our native Altair-based charts.
+
+These classes are used in both Arrow-based and legacy-based charting code to pass some
+important info to add_rows.
+"""
+
+from dataclasses import dataclass
+from typing import Hashable, List, Optional, TypedDict
+
+
+class PrepDataColumns(TypedDict):
+ """Columns used for the prep_data step in Altair Arrow charts."""
+
+ x_column: Optional[str]
+ y_column_list: List[str]
+ color_column: Optional[str]
+
+
+@dataclass
+class AddRowsMetadata:
+ """Metadata needed by add_rows on native charts."""
+
+ last_index: Optional[Hashable]
+ columns: PrepDataColumns
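To make the shape concrete, here is roughly what a chart command records for later `add_rows` calls (illustrative values; in practice `_generate_chart` fills this in):

```python
from streamlit.elements.altair_utils import AddRowsMetadata

meta = AddRowsMetadata(
    # Last value of the charted dataframe's RangeIndex, so add_rows can
    # continue numbering where the original data left off.
    last_index=19,
    # The non-dataframe inputs to prep_data, replayed on every add_rows call.
    columns={
        "x_column": None,  # None means: use the index for x
        "y_column_list": ["a", "b"],
        "color_column": None,
    },
)
```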
diff --git a/lib/streamlit/elements/arrow_altair.py b/lib/streamlit/elements/arrow_altair.py
index 5be17a303741..a0f6f9c3de41 100644
--- a/lib/streamlit/elements/arrow_altair.py
+++ b/lib/streamlit/elements/arrow_altair.py
@@ -24,6 +24,7 @@
from typing import (
TYPE_CHECKING,
Any,
+ Collection,
Dict,
List,
Optional,
@@ -39,6 +40,14 @@
import streamlit.elements.arrow_vega_lite as arrow_vega_lite
from streamlit import type_util
+from streamlit.color_util import (
+ Color,
+ is_color_like,
+ is_color_tuple_like,
+ is_hex_color_like,
+ to_css_color,
+)
+from streamlit.elements.altair_utils import AddRowsMetadata
from streamlit.elements.arrow import Data
from streamlit.elements.utils import last_index_for_melted_dataframes
from streamlit.errors import StreamlitAPIException
@@ -48,7 +57,7 @@
from streamlit.runtime.metrics_util import gather_metrics
if TYPE_CHECKING:
- from altair import Chart
+ import altair as alt
from streamlit.delta_generator import DeltaGenerator
@@ -59,14 +68,37 @@ class ChartType(Enum):
LINE = "line"
+COLOR_LEGEND_SETTINGS = dict(titlePadding=0, offset=10, orient="bottom")
+
+# User-readable names to give the index and melted columns.
+SEPARATED_INDEX_COLUMN_TITLE = "index"
+MELTED_Y_COLUMN_TITLE = "value"
+MELTED_COLOR_COLUMN_TITLE = "color"
+
+# Crazy internal (non-user-visible) names for the index and melted columns, in order to
+# avoid collision with existing column names. The suffix below was generated with an
+# online random number generator. Rationale: a random suffix is even less likely
+# to collide with a real column name than something human-readable (like
+# "--streamlit-fake-field").
+PROTECTION_SUFFIX = "--p5bJXXpQgvPz6yvQMFiy"
+SEPARATED_INDEX_COLUMN_NAME = SEPARATED_INDEX_COLUMN_TITLE + PROTECTION_SUFFIX
+MELTED_Y_COLUMN_NAME = MELTED_Y_COLUMN_TITLE + PROTECTION_SUFFIX
+MELTED_COLOR_COLUMN_NAME = MELTED_COLOR_COLUMN_TITLE + PROTECTION_SUFFIX
+
+# Name we use for a column we know doesn't exist in the data, to address a Vega-Lite rendering bug
+# where empty charts need x, y encodings set in order to take up space.
+NON_EXISTENT_COLUMN_NAME = "DOES_NOT_EXIST" + PROTECTION_SUFFIX
+
+
class ArrowAltairMixin:
@gather_metrics("_arrow_line_chart")
def _arrow_line_chart(
self,
data: Data = None,
*,
- x: Union[str, None] = None,
+ x: Optional[str] = None,
y: Union[str, Sequence[str], None] = None,
+ color: Union[str, Color, List[Color], None] = None,
width: int = 0,
height: int = 0,
use_container_width: bool = True,
@@ -87,14 +119,56 @@ def _arrow_line_chart(
Data to be plotted.
x : str or None
- Column name to use for the x-axis. If None, uses the data index for the x-axis.
- This argument can only be supplied by keyword.
+ Column name to use for the x-axis. If None, uses the data index for
+ the x-axis. This argument can only be supplied by keyword.
y : str, sequence of str, or None
- Column name(s) to use for the y-axis. If a sequence of strings, draws several series
- on the same chart by melting your wide-format table into a long-format table behind
- the scenes. If None, draws the data of all remaining columns as data series.
- This argument can only be supplied by keyword.
+ Column name(s) to use for the y-axis. If a sequence of strings,
+ draws several series on the same chart by melting your wide-format
+ table into a long-format table behind the scenes. If None, draws
+ the data of all remaining columns as data series. This argument
+ can only be supplied by keyword.
+
+ color : str, tuple, sequence of str, sequence of tuple, or None
+ The color to use for different lines in this chart. This argument
+ can only be supplied by keyword.
+
+ For a line chart with just one line, this can be:
+
+ * None, to use the default color.
+ * A hex string like "#ffaa00" or "#ffaa0088".
+ * An RGB or RGBA tuple with the red, green, blue, and alpha
+ components specified as ints from 0 to 255 or floats from 0.0 to
+ 1.0.
+
+ For a line chart with multiple lines, where the dataframe is in
+ long format (that is, y is None or just one column), this can be:
+
+ * None, to use the default colors.
+ * The name of a column in the dataset. Data points will be grouped
+ into lines of the same color based on the value of this column.
+ In addition, if the values in this column match one of the color
+ formats above (hex string or color tuple), then that color will
+ be used.
+
+ For example: if the dataset has 1000 rows, but this column can
+ only contains the values "adult", "child", "baby", then
+ those 1000 datapoints will be grouped into three lines, whose
+ colors will be automatically selected from the default palette.
+
+ But, if for the same 1000-row dataset, this column contained
+ the values "#ffaa00", "#f0f", "#0000ff", then then those 1000
+ datapoints would still be grouped into three lines, but their
+ colors would be "#ffaa00", "#f0f", "#0000ff" this time around.
+
+ For a line chart with multiple lines, where the dataframe is in
+ wide format (that is, y is a sequence of columns), this can be:
+
+ * None, to use the default colors.
+ * A list of string colors or color tuples to be used for each of
+ the lines in the chart. This list should have the same length
+ as the number of y values (e.g. ``color=["#fd0", "#f0f", "#04f"]``
+ for three lines).
width : int
The chart width in pixels. If 0, selects the width automatically.
@@ -109,8 +183,8 @@ def _arrow_line_chart(
precedence over the width argument.
This argument can only be supplied by keyword.
- Example
- -------
+ Examples
+ --------
>>> import streamlit as st
>>> import pandas as pd
>>> import numpy as np
@@ -118,28 +192,78 @@ def _arrow_line_chart(
>>> chart_data = pd.DataFrame(
... np.random.randn(20, 3),
... columns=['a', 'b', 'c'])
- ...
+ >>>
>>> st._arrow_line_chart(chart_data)
.. output::
https://static.streamlit.io/0.50.0-td2L/index.html?id=BdxXG3MmrVBfJyqS2R2ki8
height: 220px
+ You can also choose different columns to use for x and y, as well as set
+ the color dynamically based on a 3rd column (assuming your dataframe is in
+ long format):
+
+ >>> import streamlit as st
+ >>> import pandas as pd
+ >>> import numpy as np
+ >>>
+ >>> chart_data = pd.DataFrame({
+ ... 'col1' : np.random.randn(20),
+ ... 'col2' : np.random.randn(20),
+ ... 'col3' : np.random.choice(['A','B','C'], 20)
+ ... })
+ >>>
+ >>> st._arrow_line_chart(
+ ... chart_data,
+ ... x = 'col1',
+ ... y = 'col2',
+ ... color = 'col3'
+ ... )
+
+ Finally, if your dataframe is in wide format, you can group multiple
+ columns under the y argument to show multiple lines with different
+ colors:
+
+ >>> import streamlit as st
+ >>> import pandas as pd
+ >>> import numpy as np
+ >>>
+ >>> chart_data = pd.DataFrame(
+ ... np.random.randn(20, 3),
+ ... columns = ['col1', 'col2', 'col3'])
+ >>>
+ >>> st._arrow_line_chart(
+ ... chart_data,
+ ... x = 'col1',
+ ... y = ['col2', 'col3'],
+ ... color = ['#FF0000', '#0000FF'] # Optional
+ ... )
+
"""
proto = ArrowVegaLiteChartProto()
- chart = _generate_chart(ChartType.LINE, data, x, y, width, height)
+ chart, add_rows_metadata = _generate_chart(
+ chart_type=ChartType.LINE,
+ data=data,
+ x_from_user=x,
+ y_from_user=y,
+ color_from_user=color,
+ width=width,
+ height=height,
+ )
marshall(proto, chart, use_container_width, theme="streamlit")
- last_index = last_index_for_melted_dataframes(data)
- return self.dg._enqueue("arrow_line_chart", proto, last_index=last_index)
+ return self.dg._enqueue(
+ "arrow_line_chart", proto, add_rows_metadata=add_rows_metadata
+ )
@gather_metrics("_arrow_area_chart")
def _arrow_area_chart(
self,
data: Data = None,
*,
- x: Union[str, None] = None,
+ x: Optional[str] = None,
y: Union[str, Sequence[str], None] = None,
+ color: Union[str, Color, List[Color], None] = None,
width: int = 0,
height: int = 0,
use_container_width: bool = True,
@@ -160,14 +284,56 @@ def _arrow_area_chart(
Data to be plotted.
x : str or None
- Column name to use for the x-axis. If None, uses the data index for the x-axis.
- This argument can only be supplied by keyword.
+ Column name to use for the x-axis. If None, uses the data index for
+ the x-axis. This argument can only be supplied by keyword.
y : str, sequence of str, or None
- Column name(s) to use for the y-axis. If a sequence of strings, draws several series
- on the same chart by melting your wide-format table into a long-format table behind
- the scenes. If None, draws the data of all remaining columns as data series.
- This argument can only be supplied by keyword.
+ Column name(s) to use for the y-axis. If a sequence of strings,
+ draws several series on the same chart by melting your wide-format
+ table into a long-format table behind the scenes. If None, draws
+ the data of all remaining columns as data series. This argument can
+ only be supplied by keyword.
+
+ color : str, tuple, sequence of str, sequence of tuple, or None
+ The color to use for different series in this chart. This argument
+ can only be supplied by keyword.
+
+ For an area chart with just 1 series, this can be:
+
+ * None, to use the default color.
+ * A hex string like "#ffaa00" or "#ffaa0088".
+ * An RGB or RGBA tuple with the red, green, blue, and alpha
+ components specified as ints from 0 to 255 or floats from 0.0 to
+ 1.0.
+
+ For an area chart with multiple series, where the dataframe is in
+ long format (that is, y is None or just one column), this can be:
+
+ * None, to use the default colors.
+ * The name of a column in the dataset. Data points will be grouped
+ into series of the same color based on the value of this column.
+ In addition, if the values in this column match one of the color
+ formats above (hex string or color tuple), then that color will
+ be used.
+
+ For example: if the dataset has 1000 rows, but this column can
+ only contains the values "adult", "child", "baby",
+ then those 1000 datapoints will be grouped into 3 series, whose
+ colors will be automatically selected from the default palette.
+
+ But, if for the same 1000-row dataset, this column contained
+ the values "#ffaa00", "#f0f", "#0000ff", then then those 1000
+ datapoints would still be grouped into 3 series, but their
+ colors would be "#ffaa00", "#f0f", "#0000ff" this time around.
+
+ For an area chart with multiple series, where the dataframe is in
+ wide format (that is, y is a sequence of columns), this can be:
+
+ * None, to use the default colors.
+ * A list of string colors or color tuples to be used for each of
+ the series in the chart. This list should have the same length
+ as the number of y values (e.g. ``color=["#fd0", "#f0f", "#04f"]``
+ for three series).
width : int
The chart width in pixels. If 0, selects the width automatically.
@@ -189,30 +355,80 @@ def _arrow_area_chart(
>>>
>>> chart_data = pd.DataFrame(
... np.random.randn(20, 3),
- ... columns=['a', 'b', 'c'])
- ...
+ ... columns = ['a', 'b', 'c'])
+ >>>
>>> st._arrow_area_chart(chart_data)
.. output::
https://static.streamlit.io/0.50.0-td2L/index.html?id=Pp65STuFj65cJRDfhGh4Jt
height: 220px
+ You can also choose different columns to use for x and y, as well as set
+ the color dynamically based on a 3rd column (assuming your dataframe is in
+ long format):
+
+ >>> import streamlit as st
+ >>> import pandas as pd
+ >>> import numpy as np
+ >>>
+ >>> chart_data = pd.DataFrame({
+ ... 'col1' : np.random.randn(20),
+ ... 'col2' : np.random.randn(20),
+ ... 'col3' : np.random.choice(['A', 'B', 'C'], 20)
+ ... })
+ >>>
+ >>> st._arrow_area_chart(
+ ... chart_data,
+ ... x = 'col1',
+ ... y = 'col2',
+ ... color = 'col3'
+ ... )
+
+ Finally, if your dataframe is in wide format, you can group multiple
+ columns under the y argument to show multiple series with different
+ colors:
+
+ >>> import streamlit as st
+ >>> import pandas as pd
+ >>> import numpy as np
+ >>>
+ >>> chart_data = pd.DataFrame(
+ ... np.random.randn(20, 3),
+ ... columns=['col1', 'col2', 'col3'])
+ ...
+ >>> st._arrow_area_chart(
+ ... chart_data,
+ ... x='col1',
+ ... y=['col2', 'col3'],
+ ... color=['#FF0000','#0000FF']
+ ... )
+
"""
proto = ArrowVegaLiteChartProto()
- chart = _generate_chart(ChartType.AREA, data, x, y, width, height)
+ chart, add_rows_metadata = _generate_chart(
+ chart_type=ChartType.AREA,
+ data=data,
+ x_from_user=x,
+ y_from_user=y,
+ color_from_user=color,
+ width=width,
+ height=height,
+ )
marshall(proto, chart, use_container_width, theme="streamlit")
- last_index = last_index_for_melted_dataframes(data)
- return self.dg._enqueue("arrow_area_chart", proto, last_index=last_index)
+ return self.dg._enqueue(
+ "arrow_area_chart", proto, add_rows_metadata=add_rows_metadata
+ )
@gather_metrics("_arrow_bar_chart")
def _arrow_bar_chart(
self,
data: Data = None,
*,
- x: Union[str, None] = None,
+ x: Optional[str] = None,
y: Union[str, Sequence[str], None] = None,
+ color: Union[str, Color, List[Color], None] = None,
width: int = 0,
height: int = 0,
use_container_width: bool = True,
@@ -233,14 +449,56 @@ def _arrow_bar_chart(
Data to be plotted.
x : str or None
- Column name to use for the x-axis. If None, uses the data index for the x-axis.
- This argument can only be supplied by keyword.
+ Column name to use for the x-axis. If None, uses the data index
+ for the x-axis. This argument can only be supplied by keyword.
y : str, sequence of str, or None
- Column name(s) to use for the y-axis. If a sequence of strings, draws several series
- on the same chart by melting your wide-format table into a long-format table behind
- the scenes. If None, draws the data of all remaining columns as data series.
- This argument can only be supplied by keyword.
+ Column name(s) to use for the y-axis. If a sequence of strings,
+ draws several series on the same chart by melting your wide-format
+ table into a long-format table behind the scenes. If None, draws
+ the data of all remaining columns as data series. This argument
+ can only be supplied by keyword.
+
+ color : str, tuple, sequence of str, sequence of tuple, or None
+ The color to use for different series in this chart. This argument
+ can only be supplied by keyword.
+
+ For a bar chart with just 1 series, this can be:
+
+ * None, to use the default color.
+ * A hex string like "#ffaa00" or "#ffaa0088".
+ * An RGB or RGBA tuple with the red, green, blue, and alpha
+ components specified as ints from 0 to 255 or floats from 0.0 to
+ 1.0.
+
+ For a bar chart with multiple series, where the dataframe is in
+ long format (that is, y is None or just one column), this can be:
+
+ * None, to use the default colors.
+ * The name of a column in the dataset. Data points will be grouped
+ into series of the same color based on the value of this column.
+ In addition, if the values in this column match one of the color
+ formats above (hex string or color tuple), then that color will
+ be used.
+
+ For example: if the dataset has 1000 rows, but this column can
+ only contains the values "adult", "child", "baby",
+ then those 1000 datapoints will be grouped into 3 series, whose
+ colors will be automatically selected from the default palette.
+
+ But, if for the same 1000-row dataset, this column contained
+ the values "#ffaa00", "#f0f", "#0000ff", then then those 1000
+ datapoints would still be grouped into 3 series, but their
+ colors would be "#ffaa00", "#f0f", "#0000ff" this time around.
+
+ For a bar chart with multiple series, where the dataframe is in
+ wide format (that is, y is a sequence of columns), this can be:
+
+ * None, to use the default colors.
+ * A list of string colors or color tuples to be used for each of
+ the series in the chart. This list should have the same length
+ as the number of y values (e.g. ``color=["#fd0", "#f0f", "#04f"]``
+ for three series).
width : int
The chart width in pixels. If 0, selects the width automatically.
@@ -271,19 +529,67 @@ def _arrow_bar_chart(
https://static.streamlit.io/0.66.0-2BLtg/index.html?id=GaYDn6vxskvBUkBwsGVEaL
height: 220px
+ You can also choose different columns to use for x and y, as well as set
+ the color dynamically based on a 3rd column (assuming your dataframe is in
+ long format):
+
+ >>> import streamlit as st
+ >>> import pandas as pd
+ >>> import numpy as np
+ >>>
+ >>> chart_data = pd.DataFrame({
+ ... 'col1' : np.random.randn(20),
+ ... 'col2' : np.random.randn(20),
+ ... 'col3' : np.random.choice(['A','B','C'],20)
+ ... })
+ >>>
+ >>> st._arrow_bar_chart(
+ ... chart_data,
+ ... x='col1',
+ ... y='col2',
+ ... color='col3'
+ ... )
+
+ Finally, if your dataframe is in wide format, you can group multiple
+ columns under the y argument to show multiple series with different
+ colors:
+
+ >>> import streamlit as st
+ >>> import pandas as pd
+ >>> import numpy as np
+ >>>
+ >>> chart_data = pd.DataFrame(
+ ... np.random.randn(20, 3),
+ ... columns=['col1', 'col2', 'col3'])
+ ...
+ >>> st._arrow_bar_chart(
+ ... chart_data,
+ ... x='col1',
+ ... y=['col2', 'col3'],
+ ... color=['#FF0000','#0000FF']
+ ... )
+
"""
proto = ArrowVegaLiteChartProto()
- chart = _generate_chart(ChartType.BAR, data, x, y, width, height)
+ chart, add_rows_metadata = _generate_chart(
+ chart_type=ChartType.BAR,
+ data=data,
+ x_from_user=x,
+ y_from_user=y,
+ color_from_user=color,
+ width=width,
+ height=height,
+ )
marshall(proto, chart, use_container_width, theme="streamlit")
- last_index = last_index_for_melted_dataframes(data)
- return self.dg._enqueue("arrow_bar_chart", proto, last_index=last_index)
+ return self.dg._enqueue(
+ "arrow_bar_chart", proto, add_rows_metadata=add_rows_metadata
+ )
@gather_metrics("_arrow_altair_chart")
def _arrow_altair_chart(
self,
- altair_chart: Chart,
+ altair_chart: alt.Chart,
use_container_width: bool = False,
theme: Union[None, Literal["streamlit"]] = "streamlit",
) -> DeltaGenerator:
@@ -342,7 +648,7 @@ def dg(self) -> DeltaGenerator:
return cast("DeltaGenerator", self)
-def _is_date_column(df: pd.DataFrame, name: str) -> bool:
+def _is_date_column(df: pd.DataFrame, name: Optional[str]) -> bool:
"""True if the column with the given name stores datetime.date values.
This function just checks the first value in the given column, so
@@ -359,6 +665,9 @@ def _is_date_column(df: pd.DataFrame, name: str) -> bool:
bool
"""
+ if name is None:
+ return False
+
column = df[name]
if column.size == 0:
return False
@@ -367,23 +676,23 @@ def _is_date_column(df: pd.DataFrame, name: str) -> bool:
def _melt_data(
- data_df: pd.DataFrame,
+ df: pd.DataFrame,
x_column: str,
- y_column: str,
- color_column: str,
- value_columns: Optional[List[str]] = None,
+ y_column_list: Optional[List[str]],
+ new_y_column_name: str,
+ new_color_column_name: str,
) -> pd.DataFrame:
"""Converts a wide-format dataframe to a long-format dataframe."""
- data_df = pd.melt(
- data_df,
+ melted_df = pd.melt(
+ df,
id_vars=[x_column],
- value_vars=value_columns,
- var_name=color_column,
- value_name=y_column,
+ value_vars=y_column_list,
+ var_name=new_color_column_name,
+ value_name=new_y_column_name,
)
- y_series = data_df[y_column]
+ y_series = melted_df[new_y_column_name]
if (
y_series.dtype == "object"
and "mixed" in infer_dtype(y_series)
@@ -395,190 +704,576 @@ def _melt_data(
# Arrow has problems with object types after melting two different dtypes
# pyarrow.lib.ArrowTypeError: "Expected a object, got a object"
- data_df = type_util.fix_arrow_incompatible_column_types(
- data_df, selected_columns=[x_column, color_column, y_column]
+ fixed_df = type_util.fix_arrow_incompatible_column_types(
+ melted_df, selected_columns=[x_column, new_color_column_name, new_y_column_name]
)
- return data_df
+ return fixed_df
-def _maybe_melt(
- data_df: pd.DataFrame,
- x: Union[str, None] = None,
- y: Union[str, Sequence[str], None] = None,
-) -> Tuple[pd.DataFrame, str, str, str, str, Optional[str], Optional[str]]:
- """Determines based on the selected x & y parameter, if the data needs to
- be converted to a long-format dataframe. If so, it returns the melted dataframe
- and the x, y, and color columns used for rendering the chart.
+def prep_data(
+ df: pd.DataFrame,
+ x_column: Optional[str],
+ y_column_list: List[str],
+ color_column: Optional[str],
+) -> Tuple[pd.DataFrame, Optional[str], Optional[str], Optional[str]]:
+ """Prepares the data for charting. This is also used in add_rows.
+
+ Returns the prepared dataframe and the new names of the x column (taking the index reset into
+ consideration) as well as the y and color columns.
"""
- color_column: Optional[str]
- # This has to contain an empty space, otherwise the
- # full y-axis disappears (maybe a bug in vega-lite)?
- color_title: Optional[str] = " "
+ # If y is provided, but x is not, we'll use the index as x.
+ # So we need to pull the index into its own column.
+ x_column = _maybe_reset_index_in_place(df, x_column, y_column_list)
- y_column = "value"
- y_title = ""
+ # Drop columns we're not using.
+ selected_data = _drop_unused_columns(df, x_column, color_column, *y_column_list)
- if x and isinstance(x, str):
- # x is a single string -> use for x-axis
- x_column = x
- x_title = x
- if x_column not in data_df.columns:
- raise StreamlitAPIException(
- f"{x_column} (x parameter) was not found in the data columns or keys”."
- )
- else:
- # use index for x-axis
- x_column = data_df.index.name or "index"
- x_title = ""
- data_df = data_df.reset_index()
-
- if y and type_util.is_sequence(y) and len(y) == 1:
- # Sequence is only a single element
- y = str(y[0])
+ # Maybe convert color to Vega colors.
+ _maybe_convert_color_column_in_place(selected_data, color_column)
- if y and isinstance(y, str):
- # y is a single string -> use for y-axis
- y_column = y
- y_title = y
- if y_column not in data_df.columns:
- raise StreamlitAPIException(
- f"{y_column} (y parameter) was not found in the data columns or keys”."
- )
-
- # Set var name to None since it should not be used
- color_column = None
- elif y and type_util.is_sequence(y):
- color_column = "variable"
- # y is a list -> melt dataframe into value vars provided in y
- value_columns: List[str] = []
- for col in y:
- if str(col) not in data_df.columns:
- raise StreamlitAPIException(
- f"{str(col)} in y parameter was not found in the data columns or keys”."
- )
- value_columns.append(str(col))
+ # Make sure all columns have string names.
+ x_column, y_column_list, color_column = _convert_col_names_to_str_in_place(
+ selected_data, x_column, y_column_list, color_column
+ )
- if x_column in [y_column, color_column]:
- raise StreamlitAPIException(
- f"Unable to melt the table. Please rename the columns used for x ({x_column}) or y ({y_column})."
- )
+ # Maybe melt data from wide format into long format.
+ melted_data, y_column, color_column = _maybe_melt(
+ selected_data, x_column, y_column_list, color_column
+ )
- data_df = _melt_data(data_df, x_column, y_column, color_column, value_columns)
- else:
- color_column = "variable"
- # -> data will be melted into the value prop for y
- data_df = _melt_data(data_df, x_column, y_column, color_column)
-
- relevant_columns = []
- if x_column and x_column not in relevant_columns:
- relevant_columns.append(x_column)
- if color_column and color_column not in relevant_columns:
- relevant_columns.append(color_column)
- if y_column and y_column not in relevant_columns:
- relevant_columns.append(y_column)
- # Only select the relevant columns required for the chart
- # Other columns can be ignored
- data_df = data_df[relevant_columns]
- return data_df, x_column, x_title, y_column, y_title, color_column, color_title
+ # Return the data, but also the new names to use for x, y, and color.
+ return melted_data, x_column, y_column, color_column
def _generate_chart(
chart_type: ChartType,
- data: Data,
- x: Union[str, None] = None,
- y: Union[str, Sequence[str], None] = None,
+ data: Optional[Data],
+ x_from_user: Optional[str] = None,
+ y_from_user: Union[str, Sequence[str], None] = None,
+ color_from_user: Union[str, Color, List[Color], None] = None,
width: int = 0,
height: int = 0,
-) -> Chart:
+) -> Tuple[alt.Chart, AddRowsMetadata]:
"""Function to use the chart's type, data columns and indices to figure out the chart's spec."""
import altair as alt
- if data is None:
- # Use an empty-ish dict because if we use None the x axis labels rotate
- # 90 degrees. No idea why. Need to debug.
- data = {"": []}
+ df = type_util.convert_anything_to_df(data, ensure_copy=True)
+
+ # From now on, use "df" instead of "data". Deleting "data" to guarantee we follow this.
+ del data
+
+ # Convert arguments received from the user to things Vega-Lite understands.
+ # Get name of column to use for x.
+ x_column = _parse_x_column(df, x_from_user)
+ # Get name of columns to use for y.
+ y_column_list = _parse_y_columns(df, y_from_user, x_column)
+    # Get the name of the column to use for color, or a constant color value. Either or both may be None.
+ color_column, color_value = _parse_color_column(df, color_from_user)
+
+ # Store some info so we can use it in add_rows.
+ add_rows_metadata = AddRowsMetadata(
+ # The last index of df so we can adjust the input df in add_rows:
+ last_index=last_index_for_melted_dataframes(df),
+ # This is the input to prep_data (except for the df):
+ columns=dict(
+ x_column=x_column,
+ y_column_list=y_column_list,
+ color_column=color_column,
+ ),
+ )
+
+ # At this point, all foo_column variables are either None/empty or contain actual
+ # columns that are guaranteed to exist.
- if not isinstance(data, pd.DataFrame):
- data = type_util.convert_anything_to_df(data)
+ df, x_column, y_column, color_column = prep_data(
+ df, x_column, y_column_list, color_column
+ )
+
+    # At this point, x_column is only None if the user did not provide one AND df is empty.
+
+ # Create a Chart with x and y encodings.
+ chart = alt.Chart(
+ data=df,
+ mark=chart_type.value,
+ width=width,
+ height=height,
+ ).encode(
+ x=_get_x_encoding(df, x_column, x_from_user, chart_type),
+ y=_get_y_encoding(df, y_column, y_from_user),
+ )
- data, x_column, x_title, y_column, y_title, color_column, color_title = _maybe_melt(
- data, x, y
+ # Set up opacity encoding.
+ opacity_enc = _get_opacity_encoding(chart_type, color_column)
+ if opacity_enc is not None:
+ chart = chart.encode(opacity=opacity_enc)
+
+ # Set up color encoding.
+ color_enc = _get_color_encoding(
+ df, color_value, color_column, y_column_list, color_from_user
+ )
+ if color_enc is not None:
+ chart = chart.encode(color=color_enc)
+
+ # Set up tooltip encoding.
+ if x_column is not None and y_column is not None:
+ chart = chart.encode(
+ tooltip=_get_tooltip_encoding(
+ x_column,
+ y_column,
+ color_column,
+ color_enc,
+ )
+ )
+
+ return chart.interactive(), add_rows_metadata
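
For orientation, here is a minimal, self-contained Altair sketch (illustrative names and data, not Streamlit's code) of the pattern `_generate_chart` follows: build the base chart with the x/y encodings that always exist, then layer on the optional encodings only when they resolve to something.

import altair as alt
import pandas as pd

df = pd.DataFrame({"x": [0, 1, 2], "y": [3, 1, 2], "group": ["a", "b", "a"]})

# Base chart with the two encodings that are always present.
chart = alt.Chart(df, mark="line", width=300, height=200).encode(
    x=alt.X("x", type="quantitative"),
    y=alt.Y("y", type="quantitative"),
)

# Optional encodings are chained on only when applicable.
color_enc = alt.Color("group", type="nominal")
if color_enc is not None:
    chart = chart.encode(color=color_enc)

chart = chart.interactive()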
+
+
+def _maybe_reset_index_in_place(
+ df: pd.DataFrame, x_column: Optional[str], y_column_list: List[str]
+) -> Optional[str]:
+ if x_column is None and len(y_column_list) > 0:
+ if df.index.name is None:
+ # Pick column name that is unlikely to collide with user-given names.
+ x_column = SEPARATED_INDEX_COLUMN_NAME
+ else:
+ # Reuse index's name for the new column.
+ x_column = df.index.name
+
+ df.index.name = x_column
+ df.reset_index(inplace=True)
+
+ return x_column
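
A quick pandas illustration of the two branches above, with a placeholder standing in for the real `SEPARATED_INDEX_COLUMN_NAME` constant:

import pandas as pd

SEPARATED_INDEX_COLUMN_NAME = "index--placeholder"  # stand-in, not the real value

df = pd.DataFrame({"y": [1, 2]})            # unnamed RangeIndex
df.index.name = SEPARATED_INDEX_COLUMN_NAME
df.reset_index(inplace=True)
print(df.columns.tolist())                  # ['index--placeholder', 'y']

df2 = pd.DataFrame({"y": [1, 2]})
df2.index.name = "date"                     # named index: the name is reused
df2.reset_index(inplace=True)
print(df2.columns.tolist())                 # ['date', 'y']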
+
+
+def _drop_unused_columns(
+ df: pd.DataFrame, *column_names: Optional[str]
+) -> pd.DataFrame:
+ """Returns a subset of df, selecting only column_names that aren't None."""
+
+    # We can't just call set(column_names) because sets don't have stable ordering,
+ # which means tests that depend on ordering will fail.
+ # Performance-wise, it's not a problem, though, since this function is only ever
+ # used on very small lists.
+ seen = set()
+ keep = []
+
+ for x in column_names:
+ if x is None:
+ continue
+ if x in seen:
+ continue
+ seen.add(x)
+ keep.append(x)
+
+ return df[keep]
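
As an aside, an equivalent order-preserving de-duplication can be written with `dict.fromkeys`, since dicts keep insertion order (Python 3.7+); a sketch:

columns = ["x", "color", "x", None, "y"]
keep = [c for c in dict.fromkeys(columns) if c is not None]
print(keep)  # ['x', 'color', 'y']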
+
+
+def _maybe_convert_color_column_in_place(df: pd.DataFrame, color_column: Optional[str]) -> None:
+ """If needed, convert color column to a format Vega understands."""
+ if color_column is None or len(df[color_column]) == 0:
+ return
+
+ first_color_datum = df[color_column][0]
+
+ if is_hex_color_like(first_color_datum):
+ # Hex is already CSS-valid.
+ pass
+ elif is_color_tuple_like(first_color_datum):
+ # Tuples need to be converted to CSS-valid.
+ df[color_column] = df[color_column].map(to_css_color)
+ else:
+        # Other kinds of color columns (e.g. pure numbers or nominal strings) shouldn't
+ # be converted since they are treated by Vega-Lite as sequential or categorical colors.
+ pass
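
To make the tuple branch concrete, here is a hedged sketch of the kind of mapping `to_css_color` performs for RGB tuples (the real helper lives in `streamlit.color_util` and also handles RGBA and float components):

import pandas as pd

def rgb_tuple_to_css(color):  # simplified stand-in for to_css_color
    r, g, b = color
    return f"rgb({r}, {g}, {b})"

colors = pd.Series([(255, 0, 0), (0, 0, 255)])
print(colors.map(rgb_tuple_to_css).tolist())  # ['rgb(255, 0, 0)', 'rgb(0, 0, 255)']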
+
+
+def _convert_col_names_to_str_in_place(
+ df: pd.DataFrame,
+ x_column: Optional[str],
+ y_column_list: List[str],
+ color_column: Optional[str],
+) -> Tuple[Optional[str], List[str], Optional[str]]:
+ """Converts column names to strings, since Vega-Lite does not accept ints, etc."""
+ column_names = list(df.columns) # list() converts RangeIndex, etc, to regular list.
+ str_column_names = [str(c) for c in column_names]
+ df.columns = pd.Index(str_column_names)
+
+ return (
+ None if x_column is None else str(x_column),
+ [str(c) for c in y_column_list],
+ None if color_column is None else str(color_column),
)
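
For example, integer column labels (as produced by a default RangeIndex) become strings:

import pandas as pd

df = pd.DataFrame([[1, 2]], columns=[0, 1])   # integer labels from a RangeIndex
df.columns = pd.Index([str(c) for c in df.columns])
print(df.columns.tolist())                    # ['0', '1'] -- safe for Vega-Lite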
- opacity = None
- if chart_type == ChartType.AREA and color_column:
- opacity = {y_column: 0.7}
+
+def _parse_color_column(
+ df: pd.DataFrame, column_or_value: Any
+) -> Tuple[Optional[str], Any]:
+ if isinstance(column_or_value, str) and column_or_value in df.columns:
+ column_name = column_or_value
+ value = None
+ else:
+ column_name = None
+ value = column_or_value
+
+ return column_name, value
+
+
+def _parse_x_column(df: pd.DataFrame, x_from_user: Optional[str]) -> Optional[str]:
+ if x_from_user is None:
+ return None
+
+ elif isinstance(x_from_user, str):
+ if x_from_user not in df.columns:
+ raise StreamlitColumnNotFoundError(df, x_from_user)
+
+ return x_from_user
+
+ else:
+ raise StreamlitAPIException(
+ "x parameter should be a column name (str) or None to use the "
+            f"dataframe's index. Value given: {x_from_user} "
+ f"(type {type(x_from_user)})"
+ )
+
+
+def _parse_y_columns(
+ df: pd.DataFrame,
+ y_from_user: Union[str, Sequence[str], None],
+ x_column: Union[str, None],
+) -> List[str]:
+
+ y_column_list: List[str] = []
+
+ if y_from_user is None:
+ y_column_list = list(df.columns)
+
+ elif isinstance(y_from_user, str):
+ y_column_list = [y_from_user]
+
+ elif type_util.is_sequence(y_from_user):
+ y_column_list = list(str(col) for col in y_from_user)
+
+ else:
+ raise StreamlitAPIException(
+ "y parameter should be a column name (str) or list thereof. "
+ f"Value given: {y_from_user} (type {type(y_from_user)})"
+ )
+
+ for col in y_column_list:
+ if col not in df.columns:
+ raise StreamlitColumnNotFoundError(df, col)
+
+ # y_column_list should only include x_column when user explicitly asked for it.
+ if x_column in y_column_list and (not y_from_user or x_column not in y_from_user):
+ y_column_list.remove(x_column)
+
+ return y_column_list
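
A self-contained sketch of the selection rules above (names are illustrative; the real function also validates each column against the dataframe):

import pandas as pd

def select_y_columns(df, y_from_user, x_column):
    # Mirrors the rules above: y=None selects every column; the x column is
    # dropped unless the user asked for it explicitly.
    cols = list(df.columns) if y_from_user is None else [str(c) for c in y_from_user]
    if x_column in cols and (not y_from_user or x_column not in y_from_user):
        cols.remove(x_column)
    return cols

df = pd.DataFrame({"x": [1], "a": [2], "b": [3]})
print(select_y_columns(df, None, "x"))        # ['a', 'b']  (implicit x is dropped)
print(select_y_columns(df, ["x", "a"], "x"))  # ['x', 'a']  (explicit x is kept)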
+
+
+def _get_opacity_encoding(
+ chart_type: ChartType, color_column: Optional[str]
+) -> Optional[alt.OpacityValue]:
+ import altair as alt
+
+ if color_column and chart_type == ChartType.AREA:
+ return alt.OpacityValue(0.7)
+
+ return None
+
+
+def _get_scale(df: pd.DataFrame, column_name: Optional[str]) -> alt.Scale:
+ import altair as alt
+
# Set the X and Y axes' scale to "utc" if they contain date values.
# This causes time data to be displayed in UTC, rather than the user's local
# time zone. (By default, vega-lite displays time data in the browser's
# local time zone, regardless of which time zone the data specifies:
# https://vega.github.io/vega-lite/docs/timeunit.html#output).
- x_scale = (
- alt.Scale(type="utc") if _is_date_column(data, x_column) else alt.Undefined
- )
- y_scale = (
- alt.Scale(type="utc") if _is_date_column(data, y_column) else alt.Undefined
- )
+ if _is_date_column(df, column_name):
+ return alt.Scale(type="utc")
- x_type = alt.Undefined
- # Bar charts should have a discrete (ordinal) x-axis, UNLESS type is date/time
- # https://github.com/streamlit/streamlit/pull/2097#issuecomment-714802475
- if chart_type == ChartType.BAR and not _is_date_column(data, x_column):
- x_type = "ordinal"
+ return alt.Scale()
+
+
+def _get_axis_config(
+ df: pd.DataFrame, column_name: Optional[str], grid: bool
+) -> alt.Axis:
+ import altair as alt
+
+ if column_name is not None and is_integer_dtype(df[column_name]):
+        # Use a minimum tick step of 1 for integer columns, which prevents the
+        # axis from showing fractional ticks when zooming in.
+ return alt.Axis(tickMinStep=1, grid=grid)
+
+ return alt.Axis(grid=grid)
- # Use a max tick size of 1 for integer columns (prevents zoom into float numbers)
- # and deactivate grid lines for x-axis
- x_axis_config = alt.Axis(
- tickMinStep=1 if is_integer_dtype(data[x_column]) else alt.Undefined, grid=False
+
+def _maybe_melt(
+ df: pd.DataFrame,
+ x_column: Optional[str],
+ y_column_list: List[str],
+ color_column: Optional[str],
+) -> Tuple[pd.DataFrame, Optional[str], Optional[str]]:
+ """If multiple columns are set for y, melt the dataframe into long format."""
+ y_column: Optional[str]
+
+ if len(y_column_list) == 0:
+ y_column = None
+ elif len(y_column_list) == 1:
+ y_column = y_column_list[0]
+ elif x_column is not None:
+ # Pick column names that are unlikely to collide with user-given names.
+ y_column = MELTED_Y_COLUMN_NAME
+ color_column = MELTED_COLOR_COLUMN_NAME
+
+ df = _melt_data(
+ df=df,
+ x_column=x_column,
+ y_column_list=y_column_list,
+ new_y_column_name=y_column,
+ new_color_column_name=color_column,
+ )
+
+ return df, y_column, color_column
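
The melt itself is plain pandas; a minimal illustration with generic `variable`/`value` names (the real code passes the anti-collision constants instead):

import pandas as pd

wide = pd.DataFrame({"x": [0, 1], "a": [1, 2], "b": [3, 4]})
long = wide.melt(id_vars=["x"], value_vars=["a", "b"],
                 var_name="variable", value_name="value")
print(long)
#    x variable  value
# 0  0        a      1
# 1  1        a      2
# 2  0        b      3
# 3  1        b      4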
+
+
+def _get_x_encoding(
+ df: pd.DataFrame,
+ x_column: Optional[str],
+ x_from_user: Optional[str],
+ chart_type: ChartType,
+) -> alt.X:
+ import altair as alt
+
+ if x_column is None:
+ # If no field is specified, the full axis disappears when no data is present.
+ # Maybe a bug in vega-lite? So we pass a field that doesn't exist.
+ x_field = NON_EXISTENT_COLUMN_NAME
+ x_title = ""
+ elif x_column == SEPARATED_INDEX_COLUMN_NAME:
+        # If the x column name is the crazy anti-collision name we gave it, then we need to set
+ # up a title so we never show the crazy name to the user.
+ x_field = x_column
+ # Don't show a label in the x axis (not even a nice label like
+ # SEPARATED_INDEX_COLUMN_TITLE) when we pull the x axis from the index.
+ x_title = ""
+ else:
+ x_field = x_column
+
+ # Only show a label in the x axis if the user passed a column explicitly. We
+ # could go either way here, but I'm keeping this to avoid breaking the existing
+ # behavior.
+ if x_from_user is None:
+ x_title = ""
+ else:
+ x_title = x_column
+
+ return alt.X(
+ x_field,
+ title=x_title,
+ type=_get_x_encoding_type(df, chart_type, x_column),
+ scale=_get_scale(df, x_column),
+ axis=_get_axis_config(df, x_column, grid=False),
)
- y_axis_config = alt.Axis(
- tickMinStep=1 if is_integer_dtype(data[y_column]) else alt.Undefined
+
+
+def _get_y_encoding(
+ df: pd.DataFrame,
+ y_column: Optional[str],
+ y_from_user: Union[str, Sequence[str], None],
+) -> alt.Y:
+ import altair as alt
+
+ if y_column is None:
+ # If no field is specified, the full axis disappears when no data is present.
+ # Maybe a bug in vega-lite? So we pass a field that doesn't exist.
+ y_field = NON_EXISTENT_COLUMN_NAME
+ y_title = ""
+ elif y_column == MELTED_Y_COLUMN_NAME:
+        # If the y column name is the crazy anti-collision name we gave it, then we need to set
+ # up a title so we never show the crazy name to the user.
+ y_field = y_column
+ # Don't show a label in the y axis (not even a nice label like
+ # MELTED_Y_COLUMN_TITLE) when we pull the x axis from the index.
+        # MELTED_Y_COLUMN_TITLE) when the dataframe was melted.
+ else:
+ y_field = y_column
+
+ # Only show a label in the y axis if the user passed a column explicitly. We
+ # could go either way here, but I'm keeping this to avoid breaking the existing
+ # behavior.
+ if y_from_user is None:
+ y_title = ""
+ else:
+ y_title = y_column
+
+ return alt.Y(
+ field=y_field,
+ title=y_title,
+ type=_get_y_encoding_type(df, y_column),
+ scale=_get_scale(df, y_column),
+ axis=_get_axis_config(df, y_column, grid=True),
)
- tooltips = [
- alt.Tooltip(x_column, title=x_column),
- alt.Tooltip(y_column, title=y_column),
- ]
- color = None
-
- if color_column:
- color = alt.Color(
- color_column,
- title=color_title,
- type="nominal",
- legend=alt.Legend(titlePadding=0, offset=10, orient="bottom"),
+
+def _get_color_encoding(
+ df: pd.DataFrame,
+ color_value: Optional[Color],
+ color_column: Optional[str],
+ y_column_list: List[str],
+ color_from_user: Union[str, Color, List[Color], None],
+) -> Optional[alt.Color]:
+ import altair as alt
+
+ has_color_value = color_value not in [None, [], tuple()]
+
+ # If user passed a color value, that should win over colors coming from the
+ # color column (be they manual or auto-assigned due to melting)
+ if has_color_value:
+
+ # If the color value is color-like, return that.
+ if is_color_like(cast(Any, color_value)):
+ if len(y_column_list) != 1:
+ raise StreamlitColorLengthError([color_value], y_column_list)
+
+ return alt.ColorValue(to_css_color(cast(Any, color_value)))
+
+        # If the color value is a list of colors of appropriate length, return that.
+ elif isinstance(color_value, (list, tuple)):
+ color_values = cast(Collection[Color], color_value)
+
+ if len(color_values) != len(y_column_list):
+ raise StreamlitColorLengthError(color_values, y_column_list)
+
+ if len(color_value) == 1:
+ return alt.ColorValue(to_css_color(cast(Any, color_value[0])))
+ else:
+ return alt.Color(
+ field=color_column,
+ scale=alt.Scale(range=[to_css_color(c) for c in color_values]),
+ legend=COLOR_LEGEND_SETTINGS,
+ type="nominal",
+ title=" ",
+ )
+
+ raise StreamlitInvalidColorError(df, color_from_user)
+
+ elif color_column is not None:
+ column_type: Union[str, Tuple[str, List[Any]]]
+
+ if color_column == MELTED_COLOR_COLUMN_NAME:
+ column_type = "nominal"
+ else:
+ column_type = type_util.infer_vegalite_type(df[color_column])
+
+ color_enc = alt.Color(
+ field=color_column, legend=COLOR_LEGEND_SETTINGS, type=column_type
)
- tooltips.append(alt.Tooltip(color_column, title="label"))
-
- chart = getattr(
- alt.Chart(data, width=width, height=height),
- "mark_" + chart_type.value,
- )().encode(
- x=alt.X(
- x_column,
- title=x_title,
- scale=x_scale,
- type=x_type,
- axis=x_axis_config,
- ),
- y=alt.Y(y_column, title=y_title, scale=y_scale, axis=y_axis_config),
- tooltip=tooltips,
- )
- if color:
- chart = chart.encode(color=color)
+ # Fix title if DF was melted
+ if color_column == MELTED_COLOR_COLUMN_NAME:
+ # This has to contain an empty space, otherwise the
+ # full y-axis disappears (maybe a bug in vega-lite)?
+ color_enc["title"] = " "
+
+ # If the 0th element in the color column looks like a color, we'll use the color column's
+ # values as the colors in our chart.
+ elif len(df[color_column]) and is_color_like(df[color_column][0]):
+ color_range = [to_css_color(c) for c in df[color_column].unique()]
+ color_enc["scale"] = alt.Scale(range=color_range)
+ # Don't show the color legend, because it will just show text with the color values,
+ # like #f00, #00f, etc, which are not user-readable.
+ color_enc["legend"] = None
+
+ # Otherwise, let Vega-Lite auto-assign colors.
+ # This codepath is typically reached when the color column contains numbers (in which case
+ # Vega-Lite uses a color gradient to represent them) or strings (in which case Vega-Lite
+ # assigns one color for each unique value).
+ else:
+ pass
- if opacity:
- chart = chart.encode(opacity=opacity)
+ return color_enc
- return chart.interactive()
+ return None
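
A short Altair sketch of the wide-format, list-of-colors branch above: the melted variable column is encoded as nominal with a fixed scale range, one CSS color per y column (field name is a placeholder for MELTED_COLOR_COLUMN_NAME):

import altair as alt

color_enc = alt.Color(
    field="variable",           # placeholder for MELTED_COLOR_COLUMN_NAME
    type="nominal",
    scale=alt.Scale(range=["#ff0000", "#0000ff"]),
    title=" ",                  # single space; see the vega-lite quirk noted above
)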
+
+
+def _get_tooltip_encoding(
+ x_column: str,
+ y_column: str,
+ color_column: Optional[str],
+ color_enc: alt.Color,
+) -> List[alt.Tooltip]:
+ import altair as alt
+
+ tooltip = []
+
+    # If the x column name is the crazy anti-collision name we gave it, then we need to set
+ # up a tooltip title so we never show the crazy name to the user.
+ if x_column == SEPARATED_INDEX_COLUMN_NAME:
+ tooltip.append(alt.Tooltip(x_column, title=SEPARATED_INDEX_COLUMN_TITLE))
+ else:
+ tooltip.append(alt.Tooltip(x_column))
+
+    # If the y column name is the crazy anti-collision name we gave it, then we need to set
+ # up a tooltip title so we never show the crazy name to the user.
+ if y_column == MELTED_Y_COLUMN_NAME:
+ tooltip.append(
+ alt.Tooltip(
+ y_column,
+ title=MELTED_Y_COLUMN_TITLE,
+ type="quantitative", # Just picked something random. Doesn't really matter!
+ )
+ )
+ else:
+ tooltip.append(alt.Tooltip(y_column))
+
+ # If we earlier decided that there should be no color legend, that's because the
+ # user passed a color column with actual color values (like "#ff0"), so we should
+ # not show the color values in the tooltip.
+ if color_column and getattr(color_enc, "legend", True) is not None:
+ # Use a human-readable title for the color.
+ if color_column == MELTED_COLOR_COLUMN_NAME:
+ tooltip.append(
+ alt.Tooltip(
+ color_column,
+ title=MELTED_COLOR_COLUMN_TITLE,
+ type="nominal",
+ )
+ )
+ else:
+ tooltip.append(alt.Tooltip(color_column))
+
+ return tooltip
+
+
+def _get_x_encoding_type(
+ df: pd.DataFrame, chart_type: ChartType, x_column: Optional[str]
+) -> Union[str, Tuple[str, List[Any]]]:
+ if x_column is None:
+ return "quantitative" # Anything. If None, Vega-Lite may hide the axis.
+
+ # Bar charts should have a discrete (ordinal) x-axis, UNLESS type is date/time
+ # https://github.com/streamlit/streamlit/pull/2097#issuecomment-714802475
+ if chart_type == ChartType.BAR and not _is_date_column(df, x_column):
+ return "ordinal"
+
+ return type_util.infer_vegalite_type(df[x_column])
+
+
+def _get_y_encoding_type(
+ df: pd.DataFrame, y_column: Optional[str]
+) -> Union[str, Tuple[str, List[Any]]]:
+ if y_column:
+ return type_util.infer_vegalite_type(df[y_column])
+
+ return "quantitative" # Pick anything. If undefined, Vega-Lite may hide the axis.
def marshall(
vega_lite_chart: ArrowVegaLiteChartProto,
- altair_chart: Chart,
+ altair_chart: alt.Chart,
use_container_width: bool = False,
theme: Union[None, Literal["streamlit"]] = "streamlit",
**kwargs: Any,
@@ -597,8 +1292,9 @@ def id_transform(data) -> Dict[str, str]:
"""Altair data transformer that returns a fake named dataset with the
object id.
"""
- datasets[id(data)] = data
- return {"name": str(id(data))}
+ name = str(id(data))
+ datasets[name] = data
+ return {"name": name}
alt.data_transformers.register("id", id_transform)
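
A self-contained sketch of how such a named-dataset transformer behaves once enabled: `to_dict()` emits only a named reference, and the rows stay in the side-table keyed by that name.

import altair as alt
import pandas as pd

datasets = {}

def id_transform(data):
    # Same idea as above: stash the data and return a named reference.
    name = str(id(data))
    datasets[name] = data
    return {"name": name}

alt.data_transformers.register("id", id_transform)

df = pd.DataFrame({"x": [1, 2], "y": [3, 4]})
chart = alt.Chart(df).mark_line().encode(x="x", y="y")

with alt.data_transformers.enable("id"):
    spec = chart.to_dict()

print(spec["data"])    # {'name': '<object id>'}
print(list(datasets))  # the same id, mapping back to df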
@@ -620,3 +1316,41 @@ def id_transform(data) -> Dict[str, str]:
theme=theme,
**kwargs,
)
+
+
+class StreamlitColumnNotFoundError(StreamlitAPIException):
+ def __init__(self, df, col_name, *args):
+ available_columns = ", ".join(str(c) for c in list(df.columns))
+ message = (
+ f'Data does not have a column named `"{col_name}"`. '
+ f"Available columns are `{available_columns}`"
+ )
+ super().__init__(message, *args)
+
+
+class StreamlitInvalidColorError(StreamlitAPIException):
+ def __init__(self, df, color_from_user, *args):
+ message = f"""
+This does not look like a valid color argument: `{color_from_user}`.
+
+The color argument can be:
+
+* A hex string like "#ffaa00" or "#ffaa0088".
+* An RGB or RGBA tuple with the red, green, blue, and alpha
+ components specified as ints from 0 to 255 or floats from 0.0 to
+ 1.0.
+* The name of a column.
+* Or a list of colors, matching the number of y columns to draw.
+ """
+ super().__init__(message, *args)
+
+
+class StreamlitColorLengthError(StreamlitAPIException):
+ def __init__(self, color_values, y_column_list, *args):
+ message = (
+ f"The list of colors `{color_values}` must have the same "
+ "length as the list of columns to be colored "
+ f"`{y_column_list}`."
+ )
+ super().__init__(message, *args)
diff --git a/lib/streamlit/elements/dataframe_selector.py b/lib/streamlit/elements/dataframe_selector.py
index 7cb18a7c5685..14d4ccb9acd7 100644
--- a/lib/streamlit/elements/dataframe_selector.py
+++ b/lib/streamlit/elements/dataframe_selector.py
@@ -22,6 +22,7 @@
from typing_extensions import Literal
from streamlit import config
+from streamlit.color_util import Color
from streamlit.elements.lib.column_config_utils import ColumnConfigMappingInput
from streamlit.runtime.metrics_util import gather_metrics
@@ -234,6 +235,7 @@ def line_chart(
*,
x: Union[str, None] = None,
y: Union[str, Sequence[str], None] = None,
+ color: Union[str, Color, None] = None,
width: int = 0,
height: int = 0,
use_container_width: bool = True,
@@ -267,6 +269,47 @@ def line_chart(
the scenes. If None, draws the data of all remaining columns as data series.
This argument can only be supplied by keyword.
+ color : str, tuple, sequence of str, sequence of tuple, or None
+ The color to use for different lines in this chart. This argument
+ can only be supplied by keyword.
+
+ For a line chart with just one line, this can be:
+
+ * None, to use the default color.
+ * A hex string like "#ffaa00" or "#ffaa0088".
+ * An RGB or RGBA tuple with the red, green, blue, and alpha
+ components specified as ints from 0 to 255 or floats from 0.0 to
+ 1.0.
+
+ For a line chart with multiple lines, where the dataframe is in
+ long format (that is, y is None or just one column), this can be:
+
+ * None, to use the default colors.
+ * The name of a column in the dataset. Data points will be grouped
+ into lines of the same color based on the value of this column.
+ In addition, if the values in this column match one of the color
+ formats above (hex string or color tuple), then that color will
+ be used.
+
+ For example: if the dataset has 1000 rows, but this column can
+              only contain the values "adult", "child", "baby", then
+ those 1000 datapoints will be grouped into three lines, whose
+ colors will be automatically selected from the default palette.
+
+ But, if for the same 1000-row dataset, this column contained
+              the values "#ffaa00", "#f0f", "#0000ff", then those 1000
+ datapoints would still be grouped into three lines, but their
+ colors would be "#ffaa00", "#f0f", "#0000ff" this time around.
+
+ For a line chart with multiple lines, where the dataframe is in
+ wide format (that is, y is a sequence of columns), this can be:
+
+ * None, to use the default colors.
+ * A list of string colors or color tuples to be used for each of
+ the lines in the chart. This list should have the same length
+ as the number of y values (e.g. ``color=["#fd0", "#f0f", "#04f"]``
+ for three lines).
+
width : int
The chart width in pixels. If 0, selects the width automatically.
This argument can only be supplied by keyword.
@@ -289,12 +332,60 @@ def line_chart(
>>> chart_data = pd.DataFrame(
... np.random.randn(20, 3),
... columns=['a', 'b', 'c'])
- ...
+ >>>
>>> st.line_chart(chart_data)
.. output::
https://doc-line-chart.streamlit.app/
- height: 400px
+ height: 440px
+
+ You can also choose different columns to use for x and y, as well as set
+ the color dynamically based on a 3rd column (assuming your dataframe is in
+ long format):
+
+ >>> import streamlit as st
+ >>> import pandas as pd
+ >>> import numpy as np
+ >>>
+ >>> chart_data = pd.DataFrame({
+ ... 'col1' : np.random.randn(20),
+ ... 'col2' : np.random.randn(20),
+ ... 'col3' : np.random.choice(['A','B','C'], 20)
+ ... })
+ >>>
+ >>> st.line_chart(
+ ... chart_data,
+ ... x = 'col1',
+ ... y = 'col2',
+ ... color = 'col3'
+ ... )
+
+ .. output::
+ https://doc-line-chart1.streamlit.app/
+ height: 440px
+
+ Finally, if your dataframe is in wide format, you can group multiple
+ columns under the y argument to show multiple lines with different
+ colors:
+
+ >>> import streamlit as st
+ >>> import pandas as pd
+ >>> import numpy as np
+ >>>
+ >>> chart_data = pd.DataFrame(
+ ... np.random.randn(20, 3),
+ ... columns = ['col1', 'col2', 'col3'])
+ >>>
+ >>> st.line_chart(
+ ... chart_data,
+ ... x = 'col1',
+ ... y = ['col2', 'col3'],
+ ... color = ['#FF0000', '#0000FF'] # Optional
+ ... )
+
+ .. output::
+ https://doc-line-chart2.streamlit.app/
+ height: 440px
"""
if _use_arrow():
@@ -302,6 +393,7 @@ def line_chart(
data,
x=x,
y=y,
+ color=color,
width=width,
height=height,
use_container_width=use_container_width,
@@ -321,6 +413,7 @@ def area_chart(
*,
x: Union[str, None] = None,
y: Union[str, Sequence[str], None] = None,
+ color: Union[str, Color, None] = None,
width: int = 0,
height: int = 0,
use_container_width: bool = True,
@@ -354,6 +447,47 @@ def area_chart(
the scenes. If None, draws the data of all remaining columns as data series.
This argument can only be supplied by keyword.
+ color : str, tuple, sequence of str, sequence of tuple, or None
+ The color to use for different series in this chart. This argument
+ can only be supplied by keyword.
+
+ For an area chart with just 1 series, this can be:
+
+ * None, to use the default color.
+ * A hex string like "#ffaa00" or "#ffaa0088".
+ * An RGB or RGBA tuple with the red, green, blue, and alpha
+ components specified as ints from 0 to 255 or floats from 0.0 to
+ 1.0.
+
+ For an area chart with multiple series, where the dataframe is in
+ long format (that is, y is None or just one column), this can be:
+
+ * None, to use the default colors.
+ * The name of a column in the dataset. Data points will be grouped
+ into series of the same color based on the value of this column.
+ In addition, if the values in this column match one of the color
+ formats above (hex string or color tuple), then that color will
+ be used.
+
+ For example: if the dataset has 1000 rows, but this column can
+              only contain the values "adult", "child", "baby",
+ then those 1000 datapoints will be grouped into 3 series, whose
+ colors will be automatically selected from the default palette.
+
+ But, if for the same 1000-row dataset, this column contained
+              the values "#ffaa00", "#f0f", "#0000ff", then those 1000
+ datapoints would still be grouped into 3 series, but their
+ colors would be "#ffaa00", "#f0f", "#0000ff" this time around.
+
+ For an area chart with multiple series, where the dataframe is in
+ wide format (that is, y is a sequence of columns), this can be:
+
+ * None, to use the default colors.
+ * A list of string colors or color tuples to be used for each of
+ the series in the chart. This list should have the same length
+ as the number of y values (e.g. ``color=["#fd0", "#f0f", "#04f"]``
+              for three series).
+
width : int
The chart width in pixels. If 0, selects the width automatically.
This argument can only be supplied by keyword.
@@ -375,13 +509,61 @@ def area_chart(
>>>
>>> chart_data = pd.DataFrame(
... np.random.randn(20, 3),
- ... columns=['a', 'b', 'c'])
- ...
+ ... columns = ['a', 'b', 'c'])
+ >>>
>>> st.area_chart(chart_data)
.. output::
https://doc-area-chart.streamlit.app/
- height: 400px
+ height: 440px
+
+ You can also choose different columns to use for x and y, as well as set
+ the color dynamically based on a 3rd column (assuming your dataframe is in
+ long format):
+
+ >>> import streamlit as st
+ >>> import pandas as pd
+ >>> import numpy as np
+ >>>
+ >>> chart_data = pd.DataFrame({
+ ... 'col1' : np.random.randn(20),
+ ... 'col2' : np.random.randn(20),
+ ... 'col3' : np.random.choice(['A', 'B', 'C'], 20)
+ ... })
+ >>>
+ >>> st.area_chart(
+ ... chart_data,
+ ... x = 'col1',
+ ... y = 'col2',
+ ... color = 'col3'
+ ... )
+
+ .. output::
+ https://doc-area-chart1.streamlit.app/
+ height: 440px
+
+ Finally, if your dataframe is in wide format, you can group multiple
+ columns under the y argument to show multiple series with different
+ colors:
+
+ >>> import streamlit as st
+ >>> import pandas as pd
+ >>> import numpy as np
+ >>>
+ >>> chart_data = pd.DataFrame(
+ ... np.random.randn(20, 3),
+ ... columns=['col1', 'col2', 'col3'])
+        >>>
+ >>> st.area_chart(
+ ... chart_data,
+ ... x='col1',
+ ... y=['col2', 'col3'],
+ ... color=['#FF0000','#0000FF'] # Optional
+ ... )
+
+ .. output::
+ https://doc-area-chart2.streamlit.app/
+ height: 440px
"""
if _use_arrow():
@@ -389,6 +571,7 @@ def area_chart(
data,
x=x,
y=y,
+ color=color,
width=width,
height=height,
use_container_width=use_container_width,
@@ -408,6 +591,7 @@ def bar_chart(
*,
x: Union[str, None] = None,
y: Union[str, Sequence[str], None] = None,
+ color: Union[str, Color, None] = None,
width: int = 0,
height: int = 0,
use_container_width: bool = True,
@@ -441,6 +625,47 @@ def bar_chart(
the scenes. If None, draws the data of all remaining columns as data series.
This argument can only be supplied by keyword.
+ color : str, tuple, sequence of str, sequence of tuple, or None
+ The color to use for different series in this chart. This argument
+ can only be supplied by keyword.
+
+ For a bar chart with just 1 series, this can be:
+
+ * None, to use the default color.
+ * A hex string like "#ffaa00" or "#ffaa0088".
+ * An RGB or RGBA tuple with the red, green, blue, and alpha
+ components specified as ints from 0 to 255 or floats from 0.0 to
+ 1.0.
+
+ For a bar chart with multiple series, where the dataframe is in
+ long format (that is, y is None or just one column), this can be:
+
+ * None, to use the default colors.
+ * The name of a column in the dataset. Data points will be grouped
+ into series of the same color based on the value of this column.
+ In addition, if the values in this column match one of the color
+ formats above (hex string or color tuple), then that color will
+ be used.
+
+ For example: if the dataset has 1000 rows, but this column can
+              only contain the values "adult", "child", "baby",
+ then those 1000 datapoints will be grouped into 3 series, whose
+ colors will be automatically selected from the default palette.
+
+ But, if for the same 1000-row dataset, this column contained
+              the values "#ffaa00", "#f0f", "#0000ff", then those 1000
+ datapoints would still be grouped into 3 series, but their
+ colors would be "#ffaa00", "#f0f", "#0000ff" this time around.
+
+ For a bar chart with multiple series, where the dataframe is in
+ wide format (that is, y is a sequence of columns), this can be:
+
+ * None, to use the default colors.
+ * A list of string colors or color tuples to be used for each of
+ the series in the chart. This list should have the same length
+ as the number of y values (e.g. ``color=["#fd0", "#f0f", "#04f"]``
+              for three series).
+
width : int
The chart width in pixels. If 0, selects the width automatically.
This argument can only be supplied by keyword.
@@ -468,7 +693,55 @@ def bar_chart(
.. output::
https://doc-bar-chart.streamlit.app/
- height: 400px
+ height: 440px
+
+ You can also choose different columns to use for x and y, as well as set
+ the color dynamically based on a 3rd column (assuming your dataframe is in
+ long format):
+
+ >>> import streamlit as st
+ >>> import pandas as pd
+ >>> import numpy as np
+ >>>
+ >>> chart_data = pd.DataFrame({
+ ... 'col1' : np.random.randn(20),
+ ... 'col2' : np.random.randn(20),
+        ...     'col3' : np.random.choice(['A', 'B', 'C'], 20)
+ ... })
+ >>>
+ >>> st.bar_chart(
+ ... chart_data,
+ ... x='col1',
+ ... y='col2',
+ ... color='col3'
+ ... )
+
+ .. output::
+ https://doc-bar-chart1.streamlit.app/
+ height: 440px
+
+ Finally, if your dataframe is in wide format, you can group multiple
+ columns under the y argument to show multiple series with different
+ colors:
+
+ >>> import streamlit as st
+ >>> import pandas as pd
+ >>> import numpy as np
+ >>>
+ >>> chart_data = pd.DataFrame(
+ ... np.random.randn(20, 3),
+ ... columns=['col1', 'col2', 'col3'])
+        >>>
+ >>> st.bar_chart(
+ ... chart_data,
+ ... x='col1',
+ ... y=['col2', 'col3'],
+ ... color=['#FF0000','#0000FF'] # Optional
+ ... )
+
+ .. output::
+ https://doc-bar-chart2.streamlit.app/
+ height: 440px
"""
@@ -477,6 +750,7 @@ def bar_chart(
data,
x=x,
y=y,
+ color=color,
width=width,
height=height,
use_container_width=use_container_width,
diff --git a/lib/streamlit/elements/deck_gl_json_chart.py b/lib/streamlit/elements/deck_gl_json_chart.py
index 34dda1e63ca7..632929e9b344 100644
--- a/lib/streamlit/elements/deck_gl_json_chart.py
+++ b/lib/streamlit/elements/deck_gl_json_chart.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import hashlib
import json
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, cast
@@ -156,12 +157,18 @@ def marshall(
) -> None:
if pydeck_obj is None:
spec = json.dumps(EMPTY_MAP)
+ id = ""
else:
spec = pydeck_obj.to_json()
+ json_string = json.dumps(spec)
+ json_bytes = json_string.encode("utf-8")
+ id = hashlib.md5(json_bytes).hexdigest()
pydeck_proto.json = spec
pydeck_proto.use_container_width = use_container_width
+ pydeck_proto.id = id
+
tooltip = _get_pydeck_tooltip(pydeck_obj)
if tooltip:
pydeck_proto.tooltip = json.dumps(tooltip)
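
The id computed above is just a content hash of the serialized spec, so identical charts get identical ids across reruns; a standalone sketch of the idea (MD5 is used for identity here, not security):

import hashlib
import json

spec = json.dumps({"initialViewState": {"zoom": 4, "latitude": 0.0}})
element_id = hashlib.md5(spec.encode("utf-8")).hexdigest()
print(element_id)  # deterministic: identical specs hash to the same id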
diff --git a/lib/streamlit/elements/heading.py b/lib/streamlit/elements/heading.py
index 357fb42e7ea2..798cf6e7ab8c 100644
--- a/lib/streamlit/elements/heading.py
+++ b/lib/streamlit/elements/heading.py
@@ -34,6 +34,7 @@ class HeadingProtoTag(Enum):
Anchor = Optional[Union[str, Literal[False]]]
+Divider = Optional[Union[bool, str]]
class HeadingMixin:
@@ -44,6 +45,7 @@ def header(
anchor: Anchor = None,
*, # keyword-only arguments:
help: Optional[str] = None,
+ divider: Divider = False,
) -> "DeltaGenerator":
"""Display text in header formatting.
@@ -65,7 +67,7 @@ def header(
* Colored text, using the syntax ``:color[text to be colored]``,
where ``color`` needs to be replaced with any of the following
- supported colors: blue, green, orange, red, violet.
+ supported colors: blue, green, orange, red, violet, gray/grey, rainbow.
anchor : str or False
The anchor name of the header that can be accessed with #anchor
@@ -75,18 +77,34 @@ def header(
help : str
An optional tooltip that gets displayed next to the header.
+    divider : bool or "blue", "green", "orange", "red", "violet", "gray"/"grey", or "rainbow"
+ Shows a colored divider below the header. If True, successive
+ headers will cycle through divider colors. That is, the first
+ header will have a blue line, the second header will have a
+ green line, and so on. If a string, the color can be set to one of
+ the following: blue, green, orange, red, violet, gray/grey, or
+ rainbow.
+
Examples
--------
>>> import streamlit as st
>>>
- >>> st.header('This is a header')
- >>> st.header('A header with _italics_ :blue[colors] and emojis :sunglasses:')
+ >>> st.header('This is a header with a divider', divider='rainbow')
+ >>> st.header('_Streamlit_ is :blue[cool] :sunglasses:')
+
+ .. output::
+ https://doc-header.streamlit.app/
+ height: 220px
"""
return self.dg._enqueue(
"heading",
HeadingMixin._create_heading_proto(
- tag=HeadingProtoTag.HEADER_TAG, body=body, anchor=anchor, help=help
+ tag=HeadingProtoTag.HEADER_TAG,
+ body=body,
+ anchor=anchor,
+ help=help,
+ divider=divider,
),
)
@@ -97,6 +115,7 @@ def subheader(
anchor: Anchor = None,
*, # keyword-only arguments:
help: Optional[str] = None,
+ divider: Divider = False,
) -> "DeltaGenerator":
"""Display text in subheader formatting.
@@ -118,7 +137,7 @@ def subheader(
* Colored text, using the syntax ``:color[text to be colored]``,
where ``color`` needs to be replaced with any of the following
- supported colors: blue, green, orange, red, violet.
+ supported colors: blue, green, orange, red, violet, gray/grey, rainbow.
anchor : str or False
The anchor name of the header that can be accessed with #anchor
@@ -128,18 +147,34 @@ def subheader(
help : str
An optional tooltip that gets displayed next to the subheader.
+    divider : bool or "blue", "green", "orange", "red", "violet", "gray"/"grey", or "rainbow"
+        Shows a colored divider below the subheader. If True, successive
+        subheaders will cycle through divider colors. That is, the first
+        subheader will have a blue line, the second subheader will have a
+        green line, and so on. If a string, the color can be set to one of
+ the following: blue, green, orange, red, violet, gray/grey, or
+ rainbow.
+
Examples
--------
>>> import streamlit as st
>>>
- >>> st.subheader('This is a subheader')
- >>> st.subheader('A subheader with _italics_ :blue[colors] and emojis :sunglasses:')
+ >>> st.subheader('This is a subheader with a divider', divider='rainbow')
+ >>> st.subheader('_Streamlit_ is :blue[cool] :sunglasses:')
+
+ .. output::
+ https://doc-subheader.streamlit.app/
+ height: 220px
"""
return self.dg._enqueue(
"heading",
HeadingMixin._create_heading_proto(
- tag=HeadingProtoTag.SUBHEADER_TAG, body=body, anchor=anchor, help=help
+ tag=HeadingProtoTag.SUBHEADER_TAG,
+ body=body,
+ anchor=anchor,
+ help=help,
+ divider=divider,
),
)
@@ -174,7 +209,7 @@ def title(
* Colored text, using the syntax ``:color[text to be colored]``,
where ``color`` needs to be replaced with any of the following
- supported colors: blue, green, orange, red, violet.
+ supported colors: blue, green, orange, red, violet, gray/grey, rainbow.
anchor : str or False
The anchor name of the header that can be accessed with #anchor
@@ -189,7 +224,11 @@ def title(
>>> import streamlit as st
>>>
>>> st.title('This is a title')
- >>> st.title('A title with _italics_ :blue[colors] and emojis :sunglasses:')
+ >>> st.title('_Streamlit_ is :blue[cool] :sunglasses:')
+
+ .. output::
+ https://doc-title.streamlit.app/
+ height: 220px
"""
return self.dg._enqueue(
@@ -204,16 +243,40 @@ def dg(self) -> "DeltaGenerator":
"""Get our DeltaGenerator."""
return cast("DeltaGenerator", self)
+ @staticmethod
+    def _handle_divider_color(divider: Divider) -> str:
+ if divider is True:
+ return "auto"
+ valid_colors = [
+ "blue",
+ "green",
+ "orange",
+ "red",
+ "violet",
+ "gray",
+ "grey",
+ "rainbow",
+ ]
+ if divider in valid_colors:
+ return divider
+ else:
+ raise StreamlitAPIException(
+ f"Divider parameter has invalid value: `{divider}`. Please choose from: {', '.join(valid_colors)}."
+ )
+
@staticmethod
def _create_heading_proto(
tag: HeadingProtoTag,
body: SupportsStr,
anchor: Anchor = None,
help: Optional[str] = None,
+ divider: Divider = False,
) -> HeadingProto:
proto = HeadingProto()
proto.tag = tag.value
proto.body = clean_text(body)
+ if divider:
+ proto.divider = HeadingMixin._handle_divider_color(divider)
if anchor is not None:
if anchor is False:
proto.hide_anchor = True
diff --git a/lib/streamlit/elements/image.py b/lib/streamlit/elements/image.py
index 6cc7c007a0fd..60789a4ede4e 100644
--- a/lib/streamlit/elements/image.py
+++ b/lib/streamlit/elements/image.py
@@ -19,7 +19,6 @@
"""Image marshalling."""
-import imghdr
import io
import mimetypes
import re
@@ -288,8 +287,8 @@ def _ensure_image_size_and_format(
MAXIMUM_CONTENT_WIDTH. Ensure the image's format corresponds to the given
ImageFormat. Return the (possibly resized and reformatted) image bytes.
"""
- image = Image.open(io.BytesIO(image_data))
- actual_width, actual_height = image.size
+ pil_image = Image.open(io.BytesIO(image_data))
+ actual_width, actual_height = pil_image.size
if width < 0 and actual_width > MAXIMUM_CONTENT_WIDTH:
width = MAXIMUM_CONTENT_WIDTH
@@ -297,13 +296,12 @@ def _ensure_image_size_and_format(
if width > 0 and actual_width > width:
# We need to resize the image.
new_height = int(1.0 * actual_height * width / actual_width)
- image = image.resize((width, new_height), resample=Image.BILINEAR)
- return _PIL_to_bytes(image, format=image_format, quality=90)
+ pil_image = pil_image.resize((width, new_height), resample=Image.BILINEAR)
+ return _PIL_to_bytes(pil_image, format=image_format, quality=90)
- ext = imghdr.what(None, image_data)
- if ext != image_format.lower():
+ if pil_image.format != image_format:
# We need to reformat the image.
- return _PIL_to_bytes(image, format=image_format, quality=90)
+ return _PIL_to_bytes(pil_image, format=image_format, quality=90)
# No resizing or reformatting necessary - return the original bytes.
return image_data
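
A condensed sketch of the PIL-based format check that replaces the removed `imghdr` call: PIL reports the container format (e.g. "PNG" or "JPEG") for images opened from bytes, which can be compared against the target directly.

import io
from PIL import Image

def needs_reformat(image_data: bytes, image_format: str) -> bool:
    pil_image = Image.open(io.BytesIO(image_data))
    # pil_image.format is e.g. "PNG" or "JPEG"; reformat when it differs
    # from the requested target format.
    return pil_image.format != image_format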
diff --git a/lib/streamlit/elements/layouts.py b/lib/streamlit/elements/layouts.py
index 1137c0aba57a..f9fe3e2b5149 100644
--- a/lib/streamlit/elements/layouts.py
+++ b/lib/streamlit/elements/layouts.py
@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, List, Optional, Sequence, Union, cast
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, List, Literal, Optional, Sequence, Union, cast
from streamlit.errors import StreamlitAPIException
from streamlit.proto.Block_pb2 import Block as BlockProto
@@ -19,6 +22,7 @@
if TYPE_CHECKING:
from streamlit.delta_generator import DeltaGenerator
+ from streamlit.elements.lib.mutable_status_container import StatusContainer
SpecType = Union[int, Sequence[Union[int, float]]]
@@ -236,7 +240,7 @@ def tabs(self, tabs: Sequence[str]) -> Sequence["DeltaGenerator"]:
* Colored text, using the syntax ``:color[text to be colored]``,
where ``color`` needs to be replaced with any of the following
- supported colors: blue, green, orange, red, violet.
+ supported colors: blue, green, orange, red, violet, gray/grey, rainbow.
Unsupported elements are unwrapped so only their children (text contents) render.
Display unsupported elements as literal characters by
@@ -346,7 +350,7 @@ def expander(self, label: str, expanded: bool = False) -> "DeltaGenerator":
* Colored text, using the syntax ``:color[text to be colored]``,
where ``color`` needs to be replaced with any of the following
- supported colors: blue, green, orange, red, violet.
+ supported colors: blue, green, orange, red, violet, gray/grey, rainbow.
Unsupported elements are unwrapped so only their children (text contents) render.
Display unsupported elements as literal characters by
@@ -402,11 +406,133 @@ def expander(self, label: str, expanded: bool = False) -> "DeltaGenerator":
expandable_proto.label = label
block_proto = BlockProto()
- block_proto.allow_empty = True
+ block_proto.allow_empty = False
block_proto.expandable.CopyFrom(expandable_proto)
return self.dg._block(block_proto=block_proto)
+ @gather_metrics("status")
+ def status(
+ self,
+ label: str,
+ *,
+ expanded: bool = False,
+ state: Literal["running", "complete", "error"] = "running",
+ ) -> "StatusContainer":
+ """Insert a status container to display output from long-running tasks.
+
+ Inserts a container into your app that is typically used to show the status and
+ details of a process or task. The container can hold multiple elements and can
+ be expanded or collapsed by the user similar to ``st.expander``.
+ When collapsed, all that is visible is the status icon and label.
+
+ The label, state, and expanded state can all be updated by calling ``.update()``
+ on the returned object. To add elements to the returned container, you can
+ use "with" notation (preferred) or just call methods directly on the returned
+ object.
+
+ By default, ``st.status()`` initializes in the "running" state. When called using
+ "with" notation, it automatically updates to the "complete" state at the end
+ of the "with" block. See examples below for more details.
+
+ Parameters
+ ----------
+
+ label : str
+ The initial label of the status container. The label can optionally
+ contain Markdown and supports the following elements: Bold,
+ Italics, Strikethroughs, Inline Code, Emojis, and Links.
+
+ This also supports:
+
+ * Emoji shortcodes, such as ``:+1:`` and ``:sunglasses:``.
+ For a list of all supported codes,
+ see https://share.streamlit.io/streamlit/emoji-shortcodes.
+
+ * LaTeX expressions, by wrapping them in "$" or "$$" (the "$$"
+ must be on their own lines). Supported LaTeX functions are listed
+ at https://katex.org/docs/supported.html.
+
+ * Colored text, using the syntax ``:color[text to be colored]``,
+ where ``color`` needs to be replaced with any of the following
+ supported colors: blue, green, orange, red, violet, gray/grey, rainbow.
+
+ Unsupported elements are unwrapped so only their children (text contents)
+ render. Display unsupported elements as literal characters by
+ backslash-escaping them. E.g. ``1\. Not an ordered list``.
+
+ expanded : bool
+ If True, initializes the status container in "expanded" state. Defaults to
+ False (collapsed).
+
+ state : "running", "complete", or "error"
+ The initial state of the status container which determines which icon is
+ shown:
+
+ * ``running`` (default): A spinner icon is shown.
+
+ * ``complete``: A checkmark icon is shown.
+
+ * ``error``: An error icon is shown.
+
+ Returns
+ -------
+
+ StatusContainer
+ A mutable status container that can hold multiple elements. The label, state,
+ and expanded state can be updated after creation via ``.update()``.
+
+ Examples
+ --------
+
+        You can use ``with`` notation to insert any element into a status container:
+
+ >>> import time
+ >>> import streamlit as st
+ >>>
+ >>> with st.status("Downloading data..."):
+ ... st.write("Searching for data...")
+ ... time.sleep(2)
+ ... st.write("Found URL.")
+ ... time.sleep(1)
+ ... st.write("Downloading data...")
+ ... time.sleep(1)
+ >>>
+ >>> st.button('Rerun')
+
+        .. output::
+ https://doc-status.streamlit.app/
+ height: 300px
+
+        You can also use ``.update()`` on the container to change the label, state,
+ or expanded state:
+
+ >>> import time
+ >>> import streamlit as st
+ >>>
+ >>> with st.status("Downloading data...", expanded=True) as status:
+ ... st.write("Searching for data...")
+ ... time.sleep(2)
+ ... st.write("Found URL.")
+ ... time.sleep(1)
+ ... st.write("Downloading data...")
+ ... time.sleep(1)
+ ... status.update(label="Download complete!", state="complete", expanded=False)
+ >>>
+ >>> st.button('Rerun')
+
+        .. output::
+ https://doc-status-update.streamlit.app/
+ height: 300px
+
+ """
+ # We need to import StatusContainer here to avoid a circular import
+ from streamlit.elements.lib.mutable_status_container import StatusContainer
+
+ return StatusContainer._create(
+ self.dg, label=label, expanded=expanded, state=state
+ )
+
@property
def dg(self) -> "DeltaGenerator":
"""Get our DeltaGenerator."""
diff --git a/lib/streamlit/elements/legacy_altair.py b/lib/streamlit/elements/legacy_altair.py
index 38db5fc6a251..1014376c94cc 100644
--- a/lib/streamlit/elements/legacy_altair.py
+++ b/lib/streamlit/elements/legacy_altair.py
@@ -18,13 +18,14 @@
"""
from datetime import date
-from typing import TYPE_CHECKING, Hashable, cast
+from typing import TYPE_CHECKING, Hashable, Optional, cast
import pandas as pd
import pyarrow as pa
import streamlit.elements.legacy_vega_lite as vega_lite
from streamlit import errors, type_util
+from streamlit.elements.altair_utils import AddRowsMetadata
from streamlit.elements.utils import last_index_for_melted_dataframes
from streamlit.proto.VegaLiteChart_pb2 import VegaLiteChart as VegaLiteChartProto
from streamlit.runtime.metrics_util import gather_metrics
@@ -90,12 +91,11 @@ def _legacy_line_chart(
"""
vega_lite_chart_proto = VegaLiteChartProto()
- chart = generate_chart("line", data, width, height)
+ chart, add_rows_metadata = generate_chart("line", data, width, height)
marshall(vega_lite_chart_proto, chart, use_container_width)
- last_index = last_index_for_melted_dataframes(data)
return self.dg._enqueue(
- "line_chart", vega_lite_chart_proto, last_index=last_index
+ "line_chart", vega_lite_chart_proto, add_rows_metadata=add_rows_metadata
)
@gather_metrics("_legacy_area_chart")
@@ -150,12 +150,11 @@ def _legacy_area_chart(
"""
vega_lite_chart_proto = VegaLiteChartProto()
- chart = generate_chart("area", data, width, height)
+ chart, add_rows_metadata = generate_chart("area", data, width, height)
marshall(vega_lite_chart_proto, chart, use_container_width)
- last_index = last_index_for_melted_dataframes(data)
return self.dg._enqueue(
- "area_chart", vega_lite_chart_proto, last_index=last_index
+ "area_chart", vega_lite_chart_proto, add_rows_metadata=add_rows_metadata
)
@gather_metrics("_legacy_bar_chart")
@@ -210,12 +209,11 @@ def _legacy_bar_chart(
"""
vega_lite_chart_proto = VegaLiteChartProto()
- chart = generate_chart("bar", data, width, height)
+ chart, add_rows_metadata = generate_chart("bar", data, width, height)
marshall(vega_lite_chart_proto, chart, use_container_width)
- last_index = last_index_for_melted_dataframes(data)
return self.dg._enqueue(
- "bar_chart", vega_lite_chart_proto, last_index=last_index
+ "bar_chart", vega_lite_chart_proto, add_rows_metadata=add_rows_metadata
)
@gather_metrics("_legacy_altair_chart")
@@ -273,7 +271,7 @@ def dg(self) -> "DeltaGenerator":
return cast("DeltaGenerator", self)
-def _is_date_column(df: pd.DataFrame, name: Hashable) -> bool:
+def _is_date_column(df: pd.DataFrame, name: Optional[Hashable]) -> bool:
"""True if the column with the given name stores datetime.date values.
This function just checks the first value in the given column, so
@@ -305,14 +303,7 @@ def generate_chart(chart_type, data, width: int = 0, height: int = 0):
data = {"": []}
if isinstance(data, pa.Table):
- raise errors.StreamlitAPIException(
- """
-pyarrow tables are not supported by Streamlit's legacy DataFrame serialization (i.e. with `config.dataFrameSerialization = "legacy"`).
-
-To be able to use pyarrow tables, please enable pyarrow by changing the config setting,
-`config.dataFrameSerialization = "arrow"`
-"""
- )
+ raise ArrowNotSupportedError()
if not isinstance(data, pd.DataFrame):
data = type_util.convert_anything_to_df(data)
@@ -321,6 +312,16 @@ def generate_chart(chart_type, data, width: int = 0, height: int = 0):
if index_name is None:
index_name = "index"
+ add_rows_metadata = AddRowsMetadata(
+ last_index=last_index_for_melted_dataframes(data),
+ # Not used:
+ columns=dict(
+ x_column=index_name,
+ y_column_list=[],
+ color_column=None,
+ ),
+ )
+
data = pd.melt(data.reset_index(), id_vars=[index_name])
if chart_type == "area":
@@ -355,7 +356,7 @@ def generate_chart(chart_type, data, width: int = 0, height: int = 0):
)
.interactive()
)
- return chart
+ return chart, add_rows_metadata
def marshall(
@@ -395,3 +396,15 @@ def id_transform(data):
use_container_width=use_container_width,
**kwargs,
)
+
+
+class ArrowNotSupportedError(errors.StreamlitAPIException):
+ def __init__(self, *args):
+ message = """
+pyarrow tables are not supported by Streamlit's legacy DataFrame serialization (i.e.
+with `config.dataFrameSerialization = "legacy"`).
+
+To be able to use pyarrow tables, please enable pyarrow by changing the config setting,
+`config.dataFrameSerialization = "arrow"`.
+"""
+ super().__init__(message, *args)
diff --git a/lib/streamlit/elements/lib/column_types.py b/lib/streamlit/elements/lib/column_types.py
index 3e594e5f7623..1e081d4aa6f0 100644
--- a/lib/streamlit/elements/lib/column_types.py
+++ b/lib/streamlit/elements/lib/column_types.py
@@ -719,6 +719,7 @@ def SelectboxColumn(
>>> "📈 Data Visualization",
>>> "🤖 LLM",
>>> ],
+ >>> required=True,
>>> )
>>> },
>>> hide_index=True,
diff --git a/lib/streamlit/elements/lib/mutable_status_container.py b/lib/streamlit/elements/lib/mutable_status_container.py
new file mode 100644
index 000000000000..897a07d97b6d
--- /dev/null
+++ b/lib/streamlit/elements/lib/mutable_status_container.py
@@ -0,0 +1,179 @@
+# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import time
+from types import TracebackType
+from typing import List, Optional, Type, cast
+
+from typing_extensions import Literal, TypeAlias
+
+from streamlit.cursor import Cursor
+from streamlit.delta_generator import DeltaGenerator, _enqueue_message
+from streamlit.errors import StreamlitAPIException
+from streamlit.proto.Block_pb2 import Block as BlockProto
+from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
+
+States: TypeAlias = Literal["running", "complete", "error"]
+
+
+class StatusContainer(DeltaGenerator):
+ @staticmethod
+ def _create(
+ parent: DeltaGenerator,
+ label: str,
+ expanded: bool = False,
+ state: States = "running",
+ ) -> StatusContainer:
+ expandable_proto = BlockProto.Expandable()
+ expandable_proto.expanded = expanded
+ expandable_proto.label = label or ""
+
+ if state == "running":
+ expandable_proto.icon = "spinner"
+ elif state == "complete":
+ expandable_proto.icon = "check"
+ elif state == "error":
+ expandable_proto.icon = "error"
+ else:
+ raise StreamlitAPIException(
+ f"Unknown state ({state}). Must be one of 'running', 'complete', or 'error'."
+ )
+
+ block_proto = BlockProto()
+ block_proto.allow_empty = True
+ block_proto.expandable.CopyFrom(expandable_proto)
+
+ delta_path: List[int] = (
+ parent._active_dg._cursor.delta_path if parent._active_dg._cursor else []
+ )
+
+ status_container = cast(
+ StatusContainer,
+ parent._block(block_proto=block_proto, dg_type=StatusContainer),
+ )
+
+ # Apply initial configuration
+ status_container._delta_path = delta_path
+ status_container._current_proto = block_proto
+ status_container._current_state = state
+
+ # We need to sleep here for a very short time to prevent issues when
+        # the status is updated too quickly. If an .update() directly follows
+        # the initialization, sometimes only the latest update is applied.
+        # Adding a short timeout here allows the frontend to render the initial
+        # state first.
+ time.sleep(0.05)
+
+ return status_container
+
+ def __init__(
+ self,
+ root_container: int | None,
+ cursor: Cursor | None,
+ parent: DeltaGenerator | None,
+ block_type: str | None,
+ ):
+ super().__init__(root_container, cursor, parent, block_type)
+
+ # Initialized in `_create()`:
+ self._current_proto: BlockProto | None = None
+ self._current_state: States | None = None
+ self._delta_path: List[int] | None = None
+
+ def update(
+ self,
+ *,
+ label: str | None = None,
+ expanded: bool | None = None,
+ state: States | None = None,
+ ) -> None:
+ """Update the status container.
+
+ Only specified arguments are updated. Container contents and unspecified
+ arguments remain unchanged.
+
+ Parameters
+ ----------
+ label : str or None
+ A new label of the status container. If None, the label is not
+ changed.
+
+ expanded : bool or None
+ The new expanded state of the status container. If None,
+ the expanded state is not changed.
+
+ state : "running", "complete", "error", or None
+ The new state of the status container. This mainly changes the
+ icon. If None, the state is not changed.
+ """
+ assert self._current_proto is not None, "Status not correctly initialized!"
+ assert self._delta_path is not None, "Status not correctly initialized!"
+
+ msg = ForwardMsg()
+ msg.metadata.delta_path[:] = self._delta_path
+ msg.delta.add_block.CopyFrom(self._current_proto)
+
+ if expanded is not None:
+ msg.delta.add_block.expandable.expanded = expanded
+ else:
+ msg.delta.add_block.expandable.ClearField("expanded")
+
+ if label is not None:
+ msg.delta.add_block.expandable.label = label
+
+ if state is not None:
+ if state == "running":
+ msg.delta.add_block.expandable.icon = "spinner"
+ elif state == "complete":
+ msg.delta.add_block.expandable.icon = "check"
+ elif state == "error":
+ msg.delta.add_block.expandable.icon = "error"
+ else:
+ raise StreamlitAPIException(
+ f"Unknown state ({state}). Must be one of 'running', 'complete', or 'error'."
+ )
+ self._current_state = state
+
+ self._current_proto = msg.delta.add_block
+ _enqueue_message(msg)
+
+ def __enter__(self) -> StatusContainer: # type: ignore[override]
+ # This is a little dubious: we're returning a different type than
+ # our superclass' `__enter__` function. Maybe DeltaGenerator.__enter__
+ # should always return `self`?
+ super().__enter__()
+ return self
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> Literal[False]:
+ # Only update if the current state is running
+ if self._current_state == "running":
+ # We need to sleep here for a very short time to prevent issues when
+ # the status is updated too quickly. If an .update() is directly followed
+ # by the exit of the context manager, sometimes only the last update
+ # (to complete) is applied. Adding a short timeout here allows the frontend
+ # to render the update first.
+ time.sleep(0.05)
+ if exc_type is not None:
+ # If an exception was raised in the context,
+ # we want to update the status to error.
+ self.update(state="error")
+ else:
+ self.update(state="complete")
+ return super().__exit__(exc_type, exc_val, exc_tb)
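Taken together, StatusContainer is a context manager whose label, icon, and expanded state can be mutated mid-run via update(), with __exit__ resolving a still-running status to "complete" or "error". A minimal usage sketch, assuming this class is exposed through an st.status() factory that calls StatusContainer._create() (the factory itself is not part of this hunk):

    import time
    import streamlit as st

    # Assumption: `st.status` wraps StatusContainer._create().
    with st.status("Downloading data...", expanded=True) as status:
        st.write("Fetching rows...")
        time.sleep(1)
        # Only the passed arguments change; the container's contents stay put.
        status.update(label="Download complete", state="complete", expanded=False)

An exception raised inside the block would instead hit the __exit__ path above and flip the status to "error".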
diff --git a/lib/streamlit/elements/markdown.py b/lib/streamlit/elements/markdown.py
index 161d4273de16..3213e0ccc926 100644
--- a/lib/streamlit/elements/markdown.py
+++ b/lib/streamlit/elements/markdown.py
@@ -56,7 +56,7 @@ def markdown(
* Colored text, using the syntax ``:color[text to be colored]``,
where ``color`` needs to be replaced with any of the following
- supported colors: blue, green, orange, red, violet.
+ supported colors: blue, green, orange, red, violet, gray/grey, rainbow.
unsafe_allow_html : bool
By default, any HTML tags found in the body will be escaped and
@@ -76,9 +76,23 @@ def markdown(
--------
>>> import streamlit as st
>>>
- >>> st.markdown('Streamlit is **_really_ cool**.')
- >>> st.markdown("This text is :red[colored red], and this is **:blue[colored]** and bold.")
- >>> st.markdown(":green[$\sqrt{x^2+y^2}=1$] is a Pythagorean identity. :pencil:")
+ >>> st.markdown("*Streamlit* is **really** ***cool***.")
+ >>> st.markdown('''
+ ... :red[Streamlit] :orange[can] :green[write] :blue[text] :violet[in]
+ ... :gray[pretty] :rainbow[colors].''')
+ >>> st.markdown("Here's a bouquet —\
+ ... :tulip::cherry_blossom::rose::hibiscus::sunflower::blossom:")
+ >>>
+ >>> multi = '''If you end a line with two spaces,
+ ... a soft return is used for the next line.
+ ...
+ ... Two (or more) newline characters in a row will result in a hard return.
+ ... '''
+ >>> st.markdown(multi)
+
+ .. output::
+ https://doc-markdown.streamlit.app/
+ height: 260px
"""
markdown_proto = MarkdownProto()
@@ -160,7 +174,7 @@ def caption(
* Colored text, using the syntax ``:color[text to be colored]``,
where ``color`` needs to be replaced with any of the following
- supported colors: blue, green, orange, red, violet.
+ supported colors: blue, green, orange, red, violet, gray/grey, rainbow.
unsafe_allow_html : bool
By default, any HTML tags found in strings will be escaped and
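The markdown and caption docstring updates above document two new color tokens; a quick sketch of both:

    import streamlit as st

    st.caption("A :gray[muted note] next to a :rainbow[multicolored flourish].")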
diff --git a/lib/streamlit/elements/metric.py b/lib/streamlit/elements/metric.py
index f0242e309da9..843a8e4f10bd 100644
--- a/lib/streamlit/elements/metric.py
+++ b/lib/streamlit/elements/metric.py
@@ -79,7 +79,7 @@ def metric(
* Colored text, using the syntax ``:color[text to be colored]``,
where ``color`` needs to be replaced with any of the following
- supported colors: blue, green, orange, red, violet.
+ supported colors: blue, green, orange, red, violet, gray/grey, rainbow.
Unsupported elements are unwrapped so only their children (text contents) render.
Display unsupported elements as literal characters by
diff --git a/lib/streamlit/elements/progress.py b/lib/streamlit/elements/progress.py
index 6c36877ca73c..56135933802d 100644
--- a/lib/streamlit/elements/progress.py
+++ b/lib/streamlit/elements/progress.py
@@ -92,7 +92,7 @@ def progress(
* Colored text, using the syntax ``:color[text to be colored]``,
where ``color`` needs to be replaced with any of the following
- supported colors: blue, green, orange, red, violet.
+ supported colors: blue, green, orange, red, violet, gray/grey, rainbow.
Unsupported elements are unwrapped so only their children (text contents) render.
Display unsupported elements as literal characters by
diff --git a/lib/streamlit/elements/toast.py b/lib/streamlit/elements/toast.py
index 0582a47807e2..81d908719b58 100644
--- a/lib/streamlit/elements/toast.py
+++ b/lib/streamlit/elements/toast.py
@@ -67,7 +67,7 @@ def toast(
* Colored text, using the syntax ``:color[text to be colored]``,
where ``color`` needs to be replaced with any of the following
- supported colors: blue, green, orange, red, violet.
+ supported colors: blue, green, orange, red, violet, gray/grey, rainbow.
icon : str or None
An optional, keyword-only argument that specifies an emoji to use as
the icon for the toast. Shortcodes are not allowed, please use a
diff --git a/lib/streamlit/elements/widgets/button.py b/lib/streamlit/elements/widgets/button.py
index 025482aa9f49..dc38608ea75c 100644
--- a/lib/streamlit/elements/widgets/button.py
+++ b/lib/streamlit/elements/widgets/button.py
@@ -33,12 +33,12 @@
WidgetKwargs,
register_widget,
)
+from streamlit.runtime.state.common import compute_widget_id
from streamlit.type_util import Key, to_key
if TYPE_CHECKING:
from streamlit.delta_generator import DeltaGenerator
-
FORM_DOCS_INFO: Final = """
For more information, refer to the
@@ -93,7 +93,7 @@ def button(
* Colored text, using the syntax ``:color[text to be colored]``,
where ``color`` needs to be replaced with any of the following
- supported colors: blue, green, orange, red, violet.
+ supported colors: blue, green, orange, red, violet, gray/grey, rainbow.
Unsupported elements are unwrapped so only their children (text contents) render.
Display unsupported elements as literal characters by
@@ -132,6 +132,7 @@ def button(
-------
>>> import streamlit as st
>>>
+ >>> st.button("Reset", type="primary")
>>> if st.button('Say hello'):
... st.write('Why hello there')
... else:
@@ -179,6 +180,7 @@ def download_button(
args: Optional[WidgetArgs] = None,
kwargs: Optional[WidgetKwargs] = None,
*, # keyword-only arguments:
+ type: Literal["primary", "secondary"] = "secondary",
disabled: bool = False,
use_container_width: bool = False,
) -> bool:
@@ -210,7 +212,7 @@ def download_button(
* Colored text, using the syntax ``:color[text to be colored]``,
where ``color`` needs to be replaced with any of the following
- supported colors: blue, green, orange, red, violet.
+ supported colors: blue, green, orange, red, violet, gray/grey, rainbow.
Unsupported elements are unwrapped so only their children (text contents) render.
Display unsupported elements as literal characters by
@@ -241,6 +243,10 @@ def download_button(
An optional tuple of args to pass to the callback.
kwargs : dict
An optional dict of kwargs to pass to the callback.
+ type : "secondary" or "primary"
+ An optional string that specifies the button type. Can be "primary" for a
+ button with additional emphasis or "secondary" for a normal button. This
+ argument can only be supplied by keyword. Defaults to "secondary".
disabled : bool
An optional boolean, which disables the download button if set to
True. The default is False. This argument can only be supplied by
@@ -308,6 +314,13 @@ def download_button(
"""
ctx = get_script_run_ctx()
+
+ if type not in ["primary", "secondary"]:
+ raise StreamlitAPIException(
+ 'The type argument to st.download_button must be "primary" or "secondary". \n'
+ f'The argument passed was "{type}".'
+ )
+
return self._download_button(
label=label,
data=data,
@@ -319,6 +332,7 @@ def download_button(
args=args,
kwargs=kwargs,
disabled=disabled,
+ type=type,
use_container_width=use_container_width,
ctx=ctx,
)
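A short sketch of the keyword added above; any value other than the two allowed strings raises the StreamlitAPIException from the validation block:

    import streamlit as st

    st.download_button(
        "Download CSV",
        data="a,b\n1,2\n",
        file_name="data.csv",
        mime="text/csv",
        type="primary",  # new keyword; "secondary" remains the default
    )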
@@ -335,26 +349,42 @@ def _download_button(
args: Optional[WidgetArgs] = None,
kwargs: Optional[WidgetKwargs] = None,
*, # keyword-only arguments:
+ type: Literal["primary", "secondary"] = "secondary",
disabled: bool = False,
use_container_width: bool = False,
ctx: Optional[ScriptRunContext] = None,
) -> bool:
-
key = to_key(key)
check_session_state_rules(default_value=None, key=key, writes_allowed=False)
+
+ id = compute_widget_id(
+ "download_button",
+ user_key=key,
+ label=label,
+ data=str(data),
+ file_name=file_name,
+ mime=mime,
+ key=key,
+ help=help,
+ type=type,
+ use_container_width=use_container_width,
+ )
+
if is_in_form(self.dg):
raise StreamlitAPIException(
f"`st.download_button()` can't be used in an `st.form()`.{FORM_DOCS_INFO}"
)
download_button_proto = DownloadButtonProto()
-
+ download_button_proto.id = id
download_button_proto.use_container_width = use_container_width
download_button_proto.label = label
download_button_proto.default = False
+ download_button_proto.type = type
marshall_file(
self.dg._get_delta_path_str(), data, download_button_proto, mime, file_name
)
+ download_button_proto.disabled = disabled
if help is not None:
download_button_proto.help = dedent(help)
@@ -373,10 +403,6 @@ def _download_button(
ctx=ctx,
)
- # This needs to be done after register_widget because we don't want
- # the following proto fields to affect a widget's ID.
- download_button_proto.disabled = disabled
-
self.dg._enqueue("download_button", download_button_proto)
return button_state.value
@@ -399,6 +425,17 @@ def _button(
check_callback_rules(self.dg, on_click)
check_session_state_rules(default_value=None, key=key, writes_allowed=False)
+ id = compute_widget_id(
+ "button",
+ user_key=key,
+ label=label,
+ key=key,
+ help=help,
+ is_form_submitter=is_form_submitter,
+ type=type,
+ use_container_width=use_container_width,
+ )
+
# It doesn't make sense to create a button inside a form (except
# for the "Form Submitter" button that's automatically created in
# every form). We throw an error to warn the user about this.
@@ -415,12 +452,15 @@ def _button(
)
button_proto = ButtonProto()
+ button_proto.id = id
button_proto.label = label
button_proto.default = False
button_proto.is_form_submitter = is_form_submitter
button_proto.form_id = current_form_id(self.dg)
button_proto.type = type
button_proto.use_container_width = use_container_width
+ button_proto.disabled = disabled
+
if help is not None:
button_proto.help = dedent(help)
@@ -438,10 +478,6 @@ def _button(
ctx=ctx,
)
- # This needs to be done after register_widget because we don't want
- # the following proto fields to affect a widget's ID.
- button_proto.disabled = disabled
-
self.dg._enqueue("button", button_proto)
return button_state.value
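Note that compute_widget_id is called with an explicit parameter list that excludes disabled, so even though the proto field is now set before register_widget, a button's identity is unchanged by disabling it. A sketch of the behavior this preserves:

    import streamlit as st

    # Toggling `disabled` across reruns keeps the same widget ID, so the
    # button's state survives; changing `label` or `type` would mint a new
    # ID and reset it.
    locked = st.checkbox("Lock form")
    st.button("Submit", disabled=locked)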
diff --git a/lib/streamlit/elements/widgets/camera_input.py b/lib/streamlit/elements/widgets/camera_input.py
index 22c0380380f7..77e317293dc3 100644
--- a/lib/streamlit/elements/widgets/camera_input.py
+++ b/lib/streamlit/elements/widgets/camera_input.py
@@ -14,7 +14,7 @@
from dataclasses import dataclass
from textwrap import dedent
-from typing import TYPE_CHECKING, List, Optional, cast
+from typing import TYPE_CHECKING, Optional, Union, cast
from streamlit.elements.form import current_form_id
from streamlit.elements.utils import (
@@ -22,6 +22,7 @@
check_session_state_rules,
get_label_visibility_proto_value,
)
+from streamlit.elements.widgets.file_uploader import _get_upload_files
from streamlit.proto.CameraInput_pb2 import CameraInput as CameraInputProto
from streamlit.proto.Common_pb2 import FileUploaderState as FileUploaderStateProto
from streamlit.proto.Common_pb2 import UploadedFileInfo as UploadedFileInfoProto
@@ -33,37 +34,14 @@
WidgetKwargs,
register_widget,
)
-from streamlit.runtime.uploaded_file_manager import UploadedFile, UploadedFileRec
+from streamlit.runtime.state.common import compute_widget_id
+from streamlit.runtime.uploaded_file_manager import DeletedFile, UploadedFile
from streamlit.type_util import Key, LabelVisibility, maybe_raise_label_warnings, to_key
if TYPE_CHECKING:
from streamlit.delta_generator import DeltaGenerator
-SomeUploadedSnapshotFile = Optional[UploadedFile]
-
-
-def _get_file_recs_for_camera_input_widget(
- widget_id: str, widget_value: Optional[FileUploaderStateProto]
-) -> List[UploadedFileRec]:
- if widget_value is None:
- return []
-
- ctx = get_script_run_ctx()
- if ctx is None:
- return []
-
- uploaded_file_info = widget_value.uploaded_file_info
- if len(uploaded_file_info) == 0:
- return []
-
- active_file_ids = [f.id for f in uploaded_file_info]
-
- # Grab the files that correspond to our active file IDs.
- return ctx.uploaded_file_mgr.get_files(
- session_id=ctx.session_id,
- widget_id=widget_id,
- file_ids=active_file_ids,
- )
+SomeUploadedSnapshotFile = Union[UploadedFile, DeletedFile, None]
@dataclass
@@ -74,34 +52,25 @@ def serialize(
) -> FileUploaderStateProto:
state_proto = FileUploaderStateProto()
- ctx = get_script_run_ctx()
- if ctx is None:
- return state_proto
-
- # ctx.uploaded_file_mgr._file_id_counter stores the id to use for
- # the *next* uploaded file, so the current highest file id is the
- # counter minus 1.
- state_proto.max_file_id = ctx.uploaded_file_mgr._file_id_counter - 1
-
- if not snapshot:
+ if snapshot is None or isinstance(snapshot, DeletedFile):
return state_proto
file_info: UploadedFileInfoProto = state_proto.uploaded_file_info.add()
- file_info.id = snapshot.id
+ file_info.file_id = snapshot.file_id
file_info.name = snapshot.name
file_info.size = snapshot.size
+ file_info.file_urls.CopyFrom(snapshot._file_urls)
return state_proto
def deserialize(
self, ui_value: Optional[FileUploaderStateProto], widget_id: str
) -> SomeUploadedSnapshotFile:
- file_recs = _get_file_recs_for_camera_input_widget(widget_id, ui_value)
-
- if len(file_recs) == 0:
+ upload_files = _get_upload_files(ui_value)
+ if len(upload_files) == 0:
return_value = None
else:
- return_value = UploadedFile(file_recs[0])
+ return_value = upload_files[0]
return return_value
@@ -118,7 +87,7 @@ def camera_input(
*, # keyword-only arguments:
disabled: bool = False,
label_visibility: LabelVisibility = "visible",
- ) -> SomeUploadedSnapshotFile:
+ ) -> Optional[UploadedFile]:
r"""Display a widget that returns pictures from the user's webcam.
Parameters
@@ -140,7 +109,7 @@ def camera_input(
* Colored text, using the syntax ``:color[text to be colored]``,
where ``color`` needs to be replaced with any of the following
- supported colors: blue, green, orange, red, violet.
+ supported colors: blue, green, orange, red, violet, gray/grey, rainbow.
Unsupported elements are unwrapped so only their children (text contents) render.
Display unsupported elements as literal characters by
@@ -221,15 +190,29 @@ def _camera_input(
disabled: bool = False,
label_visibility: LabelVisibility = "visible",
ctx: Optional[ScriptRunContext] = None,
- ) -> SomeUploadedSnapshotFile:
+ ) -> Optional[UploadedFile]:
key = to_key(key)
check_callback_rules(self.dg, on_change)
check_session_state_rules(default_value=None, key=key, writes_allowed=False)
maybe_raise_label_warnings(label, label_visibility)
+ id = compute_widget_id(
+ "camera_input",
+ user_key=key,
+ label=label,
+ key=key,
+ help=help,
+ form_id=current_form_id(self.dg),
+ )
+
camera_input_proto = CameraInputProto()
+ camera_input_proto.id = id
camera_input_proto.label = label
camera_input_proto.form_id = current_form_id(self.dg)
+ camera_input_proto.disabled = disabled
+ camera_input_proto.label_visibility.value = get_label_visibility_proto_value(
+ label_visibility
+ )
if help is not None:
camera_input_proto.help = dedent(help)
@@ -248,30 +231,10 @@ def _camera_input(
ctx=ctx,
)
- # This needs to be done after register_widget because we don't want
- # the following proto fields to affect a widget's ID.
- camera_input_proto.disabled = disabled
- camera_input_proto.label_visibility.value = get_label_visibility_proto_value(
- label_visibility
- )
-
- ctx = get_script_run_ctx()
- camera_image_input_state = serde.serialize(camera_input_state.value)
-
- uploaded_shapshot_info = camera_image_input_state.uploaded_file_info
-
- if ctx is not None and len(uploaded_shapshot_info) != 0:
- newest_file_id = camera_image_input_state.max_file_id
- active_file_ids = [f.id for f in uploaded_shapshot_info]
-
- ctx.uploaded_file_mgr.remove_orphaned_files(
- session_id=ctx.session_id,
- widget_id=camera_input_proto.id,
- newest_file_id=newest_file_id,
- active_file_ids=active_file_ids,
- )
-
self.dg._enqueue("camera_input", camera_input_proto)
+
+ if isinstance(camera_input_state.value, DeletedFile):
+ return None
return camera_input_state.value
@property
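With the serde rewrite above, camera_input returns a plain Optional[UploadedFile], and a DeletedFile sentinel coming back from session state is mapped to None before the caller sees it:

    import streamlit as st

    picture = st.camera_input("Take a picture")
    if picture is not None:  # None also covers a snapshot that was deleted
        st.image(picture)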
diff --git a/lib/streamlit/elements/widgets/chat.py b/lib/streamlit/elements/widgets/chat.py
index 0a1627dbf7ac..ecb7b0cfd7b2 100644
--- a/lib/streamlit/elements/widgets/chat.py
+++ b/lib/streamlit/elements/widgets/chat.py
@@ -35,6 +35,7 @@
WidgetKwargs,
register_widget,
)
+from streamlit.runtime.state.common import compute_widget_id
from streamlit.string_util import is_emoji
from streamlit.type_util import Key, to_key
@@ -45,10 +46,12 @@
class PresetNames(str, Enum):
USER = "user"
ASSISTANT = "assistant"
+ AI = "ai" # Equivalent to assistant
+ HUMAN = "human" # Equivalent to user
def _process_avatar_input(
- avatar: str | AtomicImage | None = None,
+ avatar: str | AtomicImage | None, delta_path: str
) -> Tuple[BlockProto.ChatMessage.AvatarType.ValueType, str]:
"""Detects the avatar type and prepares the avatar data for the frontend.
@@ -56,6 +59,9 @@ def _process_avatar_input(
----------
avatar :
The avatar that was provided by the user.
+ delta_path : str
+ The delta path is used as media ID when a local image is served via the media
+ file manager.
Returns
-------
@@ -66,11 +72,14 @@ def _process_avatar_input(
if avatar is None:
return AvatarType.ICON, ""
- elif isinstance(avatar, str) and avatar in [
- PresetNames.USER,
- PresetNames.ASSISTANT,
- ]:
- return AvatarType.ICON, avatar
+ elif isinstance(avatar, str) and avatar in {item.value for item in PresetNames}:
+ # On the frontend, we only support "assistant" and "user" for the avatar.
+ return (
+ AvatarType.ICON,
+ "assistant"
+ if avatar in [PresetNames.AI, PresetNames.ASSISTANT]
+ else "user",
+ )
elif isinstance(avatar, str) and is_emoji(avatar):
return AvatarType.EMOJI, avatar
else:
@@ -83,7 +92,7 @@ def _process_avatar_input(
clamp=False,
channels="RGB",
output_format="auto",
- image_id="",
+ image_id=delta_path,
)
except Exception as ex:
raise StreamlitAPIException(
@@ -112,7 +121,7 @@ class ChatMixin:
@gather_metrics("chat_message")
def chat_message(
self,
- name: Literal["user", "assistant"] | str,
+ name: Literal["user", "assistant", "ai", "human"] | str,
*,
avatar: Literal["user", "assistant"] | str | AtomicImage | None = None,
) -> "DeltaGenerator":
@@ -124,9 +133,9 @@ def chat_message(
Parameters
----------
- name : "user", "assistant", or str
- The name of the message author. Can be “user” or “assistant” to
- enable preset styling and avatars.
+ name : "user", "assistant", "ai", "human", or str
+ The name of the message author. Can be "human"/"user" or
+ "ai"/"assistant" to enable preset styling and avatars.
Currently, the name is not shown in the UI but is only set as an
accessibility label. For accessibility reasons, you should not use
@@ -141,8 +150,8 @@ def chat_message(
image file; URL to fetch the image from; array of shape (w,h) or (w,h,1)
for a monochrome image, (w,h,3) for a color image, or (w,h,4) for an RGBA image.
- If None (default), uses default icons if ``name`` is "user" or
- "assistant", or the first letter of the ``name`` value.
+ If None (default), uses default icons if ``name`` is "user",
+ "assistant", "ai", "human" or the first letter of the ``name`` value.
Returns
-------
@@ -184,16 +193,13 @@ def chat_message(
)
if avatar is None and (
- name.lower()
- in [
- PresetNames.USER,
- PresetNames.ASSISTANT,
- ]
- or is_emoji(name)
+ name.lower() in {item.value for item in PresetNames} or is_emoji(name)
):
# For selected labels, we are mapping the label to an avatar
avatar = name.lower()
- avatar_type, converted_avatar = _process_avatar_input(avatar)
+ avatar_type, converted_avatar = _process_avatar_input(
+ avatar, self.dg._get_delta_path_str()
+ )
message_container_proto = BlockProto.ChatMessage()
message_container_proto.name = name
@@ -277,6 +283,14 @@ def chat_input(
check_callback_rules(self.dg, on_submit)
check_session_state_rules(default_value=default, key=key, writes_allowed=False)
+ id = compute_widget_id(
+ "chat_input",
+ user_key=key,
+ key=key,
+ placeholder=placeholder,
+ max_chars=max_chars,
+ )
+
# We omit this check for scripts running outside streamlit, because
# they will have no script_run_ctx.
if runtime.exists():
@@ -293,6 +307,7 @@ def chat_input(
raise StreamlitAPIException(DISALLOWED_CONTAINERS_ERROR_TEXT)
chat_input_proto = ChatInputProto()
+ chat_input_proto.id = id
chat_input_proto.placeholder = str(placeholder)
if max_chars is not None:
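The new PresetNames members make "human" and "ai" aliases for "user" and "assistant"; _process_avatar_input maps all four onto the two avatars the frontend supports:

    import streamlit as st

    with st.chat_message("ai"):  # styled identically to "assistant"
        st.write("Hello! How can I help?")
    with st.chat_message("human"):  # styled identically to "user"
        st.write("What's the weather like?")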
diff --git a/lib/streamlit/elements/widgets/checkbox.py b/lib/streamlit/elements/widgets/checkbox.py
index da6b1cfb8960..13d549eb2904 100644
--- a/lib/streamlit/elements/widgets/checkbox.py
+++ b/lib/streamlit/elements/widgets/checkbox.py
@@ -31,6 +31,7 @@
WidgetKwargs,
register_widget,
)
+from streamlit.runtime.state.common import compute_widget_id
from streamlit.type_util import Key, LabelVisibility, maybe_raise_label_warnings, to_key
if TYPE_CHECKING:
@@ -84,11 +85,15 @@ def checkbox(
* Colored text, using the syntax ``:color[text to be colored]``,
where ``color`` needs to be replaced with any of the following
- supported colors: blue, green, orange, red, violet.
+ supported colors: blue, green, orange, red, violet, gray/grey, rainbow.
Unsupported elements are unwrapped so only their children (text contents) render.
Display unsupported elements as literal characters by
backslash-escaping them. E.g. ``1\. Not an ordered list``.
+
+ For accessibility reasons, you should never set an empty label (label="")
+ but hide it with label_visibility if needed. In the future, we may disallow
+ empty labels by raising an exception.
value : bool
Preselect the checkbox when it first renders. This will be
cast to bool internally.
@@ -144,6 +149,110 @@ def checkbox(
kwargs=kwargs,
disabled=disabled,
label_visibility=label_visibility,
+ type=CheckboxProto.StyleType.DEFAULT,
+ ctx=ctx,
+ )
+
+ @gather_metrics("toggle")
+ def toggle(
+ self,
+ label: str,
+ value: bool = False,
+ key: Optional[Key] = None,
+ help: Optional[str] = None,
+ on_change: Optional[WidgetCallback] = None,
+ args: Optional[WidgetArgs] = None,
+ kwargs: Optional[WidgetKwargs] = None,
+ *, # keyword-only arguments:
+ disabled: bool = False,
+ label_visibility: LabelVisibility = "visible",
+ ) -> bool:
+ r"""Display a toggle widget.
+
+ Parameters
+ ----------
+ label : str
+ A short label explaining to the user what this toggle is for.
+ The label can optionally contain Markdown and supports the following
+ elements: Bold, Italics, Strikethroughs, Inline Code, Emojis, and Links.
+
+ This also supports:
+
+ * Emoji shortcodes, such as ``:+1:`` and ``:sunglasses:``.
+ For a list of all supported codes,
+ see https://share.streamlit.io/streamlit/emoji-shortcodes.
+
+ * LaTeX expressions, by wrapping them in "$" or "$$" (the "$$"
+ must be on their own lines). Supported LaTeX functions are listed
+ at https://katex.org/docs/supported.html.
+
+ * Colored text, using the syntax ``:color[text to be colored]``,
+ where ``color`` needs to be replaced with any of the following
+ supported colors: blue, green, orange, red, violet, gray/grey, rainbow.
+
+ Unsupported elements are unwrapped so only their children (text contents) render.
+ Display unsupported elements as literal characters by
+ backslash-escaping them. E.g. ``1\. Not an ordered list``.
+
+ For accessibility reasons, you should never set an empty label (label="")
+ but hide it with label_visibility if needed. In the future, we may disallow
+ empty labels by raising an exception.
+ value : bool
+ Preselect the toggle when it first renders. This will be
+ cast to bool internally.
+ key : str or int
+ An optional string or integer to use as the unique key for the widget.
+ If this is omitted, a key will be generated for the widget
+ based on its content. Multiple widgets of the same type may
+ not share the same key.
+ help : str
+ An optional tooltip that gets displayed next to the toggle.
+ on_change : callable
+ An optional callback invoked when this toggle's value changes.
+ args : tuple
+ An optional tuple of args to pass to the callback.
+ kwargs : dict
+ An optional dict of kwargs to pass to the callback.
+ disabled : bool
+ An optional boolean, which disables the toggle if set to True.
+ The default is False. This argument can only be supplied by keyword.
+ label_visibility : "visible", "hidden", or "collapsed"
+ The visibility of the label. If "hidden", the label doesn't show but there
+ is still empty space for it (equivalent to label="").
+ If "collapsed", both the label and the space are removed. Default is
+ "visible". This argument can only be supplied by keyword.
+
+ Returns
+ -------
+ bool
+ Whether or not the toggle is checked.
+
+ Example
+ -------
+ >>> import streamlit as st
+ >>>
+ >>> on = st.toggle('Activate feature')
+ >>>
+ >>> if on:
+ ... st.write('Feature activated!')
+
+ .. output::
+ https://doc-toggle.streamlit.app/
+ height: 220px
+
+ """
+ ctx = get_script_run_ctx()
+ return self._checkbox(
+ label=label,
+ value=value,
+ key=key,
+ help=help,
+ on_change=on_change,
+ args=args,
+ kwargs=kwargs,
+ disabled=disabled,
+ label_visibility=label_visibility,
+ type=CheckboxProto.StyleType.TOGGLE,
ctx=ctx,
)
@@ -159,6 +268,7 @@ def _checkbox(
*, # keyword-only arguments:
disabled: bool = False,
label_visibility: LabelVisibility = "visible",
+ type: CheckboxProto.StyleType.ValueType = CheckboxProto.StyleType.DEFAULT,
ctx: Optional[ScriptRunContext] = None,
) -> bool:
key = to_key(key)
@@ -169,10 +279,27 @@ def _checkbox(
maybe_raise_label_warnings(label, label_visibility)
+ id = compute_widget_id(
+ "toggle" if type == CheckboxProto.StyleType.TOGGLE else "checkbox",
+ user_key=key,
+ label=label,
+ value=bool(value),
+ key=key,
+ help=help,
+ form_id=current_form_id(self.dg),
+ )
+
checkbox_proto = CheckboxProto()
+ checkbox_proto.id = id
checkbox_proto.label = label
checkbox_proto.default = bool(value)
+ checkbox_proto.type = type
checkbox_proto.form_id = current_form_id(self.dg)
+ checkbox_proto.disabled = disabled
+ checkbox_proto.label_visibility.value = get_label_visibility_proto_value(
+ label_visibility
+ )
+
if help is not None:
checkbox_proto.help = dedent(help)
@@ -190,13 +317,6 @@ def _checkbox(
ctx=ctx,
)
- # This needs to be done after register_widget because we don't want
- # the following proto fields to affect a widget's ID.
- checkbox_proto.disabled = disabled
- checkbox_proto.label_visibility.value = get_label_visibility_proto_value(
- label_visibility
- )
-
if checkbox_state.value_changed:
checkbox_proto.value = checkbox_state.value
checkbox_proto.set_value = True
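st.toggle above is a thin wrapper over _checkbox that only swaps in CheckboxProto.StyleType.TOGGLE, so the two widgets share an identical API:

    import streamlit as st

    agreed = st.checkbox("I agree")           # StyleType.DEFAULT
    beta_on = st.toggle("Enable beta mode")   # StyleType.TOGGLE, same semantics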
diff --git a/lib/streamlit/elements/widgets/color_picker.py b/lib/streamlit/elements/widgets/color_picker.py
index e3bfc33bef6d..bc7a8ffa2103 100644
--- a/lib/streamlit/elements/widgets/color_picker.py
+++ b/lib/streamlit/elements/widgets/color_picker.py
@@ -34,6 +34,7 @@
WidgetKwargs,
register_widget,
)
+from streamlit.runtime.state.common import compute_widget_id
from streamlit.type_util import Key, LabelVisibility, maybe_raise_label_warnings, to_key
@@ -84,7 +85,7 @@ def color_picker(
* Colored text, using the syntax ``:color[text to be colored]``,
where ``color`` needs to be replaced with any of the following
- supported colors: blue, green, orange, red, violet.
+ supported colors: blue, green, orange, red, violet, gray/grey, rainbow.
Unsupported elements are unwrapped so only their children (text contents) render.
Display unsupported elements as literal characters by
@@ -170,6 +171,16 @@ def _color_picker(
check_session_state_rules(default_value=value, key=key)
maybe_raise_label_warnings(label, label_visibility)
+ id = compute_widget_id(
+ "color_picker",
+ user_key=key,
+ label=label,
+ value=str(value),
+ key=key,
+ help=help,
+ form_id=current_form_id(self.dg),
+ )
+
# set value default
if value is None:
value = "#000000"
@@ -197,9 +208,15 @@ def _color_picker(
)
color_picker_proto = ColorPickerProto()
+ color_picker_proto.id = id
color_picker_proto.label = label
color_picker_proto.default = str(value)
color_picker_proto.form_id = current_form_id(self.dg)
+ color_picker_proto.disabled = disabled
+ color_picker_proto.label_visibility.value = get_label_visibility_proto_value(
+ label_visibility
+ )
+
if help is not None:
color_picker_proto.help = dedent(help)
@@ -217,12 +234,6 @@ def _color_picker(
ctx=ctx,
)
- # This needs to be done after register_widget because we don't want
- # the following proto fields to affect a widget's ID.
- color_picker_proto.disabled = disabled
- color_picker_proto.label_visibility.value = get_label_visibility_proto_value(
- label_visibility
- )
if widget_state.value_changed:
color_picker_proto.value = widget_state.value
color_picker_proto.set_value = True
diff --git a/lib/streamlit/elements/widgets/data_editor.py b/lib/streamlit/elements/widgets/data_editor.py
index 6d342c282052..c827646e6990 100644
--- a/lib/streamlit/elements/widgets/data_editor.py
+++ b/lib/streamlit/elements/widgets/data_editor.py
@@ -65,6 +65,7 @@
WidgetKwargs,
register_widget,
)
+from streamlit.runtime.state.common import compute_widget_id
from streamlit.type_util import DataFormat, DataFrameGenericAlias, Key, is_type, to_key
from streamlit.util import calc_md5
@@ -727,6 +728,9 @@ def data_editor(
check_callback_rules(self.dg, on_change)
check_session_state_rules(default_value=None, key=key, writes_allowed=False)
+ if column_order is not None:
+ column_order = list(column_order)
+
column_config_mapping: ColumnConfigMapping = {}
data_format = type_util.determine_data_format(data)
@@ -787,7 +791,28 @@ def data_editor(
# Throws an exception if any of the configured types are incompatible.
_check_type_compatibilities(data_df, column_config_mapping, dataframe_schema)
+ arrow_bytes = type_util.pyarrow_table_to_bytes(arrow_table)
+
+ # We want to do this as early as possible to avoid introducing nondeterminism,
+ # but it isn't clear how much processing is needed to have the data in a
+ # format that will hash consistently, so we do it late here to have it
+ # as close as possible to how it used to be.
+ id = compute_widget_id(
+ "data_editor",
+ user_key=key,
+ data=arrow_bytes,
+ width=width,
+ height=height,
+ use_container_width=use_container_width,
+ column_order=column_order,
+ column_config_mapping=str(column_config_mapping),
+ num_rows=num_rows,
+ key=key,
+ form_id=current_form_id(self.dg),
+ )
+
proto = ArrowProto()
+ proto.id = id
proto.use_container_width = use_container_width
@@ -825,7 +850,7 @@ def data_editor(
data.set_uuid(styler_uuid)
marshall_styler(proto, data, styler_uuid)
- proto.data = type_util.pyarrow_table_to_bytes(arrow_table)
+ proto.data = arrow_bytes
marshall_column_config(proto, column_config_mapping)
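Because column_order is normalized to a list before it reaches compute_widget_id, tuples and other sequences now hash consistently; a small sketch:

    import pandas as pd
    import streamlit as st

    df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
    # A tuple behaves the same as a list after the normalization above.
    edited = st.data_editor(df, column_order=("b", "a"), num_rows="dynamic")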
diff --git a/lib/streamlit/elements/widgets/file_uploader.py b/lib/streamlit/elements/widgets/file_uploader.py
index b42201d27132..8174ea6f1be1 100644
--- a/lib/streamlit/elements/widgets/file_uploader.py
+++ b/lib/streamlit/elements/widgets/file_uploader.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from dataclasses import dataclass
from textwrap import dedent
from typing import List, Optional, Sequence, Union, cast, overload
@@ -37,11 +36,16 @@
WidgetKwargs,
register_widget,
)
-from streamlit.runtime.uploaded_file_manager import UploadedFile, UploadedFileRec
+from streamlit.runtime.state.common import compute_widget_id
+from streamlit.runtime.uploaded_file_manager import DeletedFile, UploadedFile
from streamlit.type_util import Key, LabelVisibility, maybe_raise_label_warnings, to_key
-SomeUploadedFiles = Optional[Union[UploadedFile, List[UploadedFile]]]
-
+SomeUploadedFiles = Union[
+ UploadedFile,
+ DeletedFile,
+ List[Union[UploadedFile, DeletedFile]],
+ None,
+]
TYPE_PAIRS = [
(".jpg", ".jpeg"),
@@ -52,9 +56,9 @@
]
-def _get_file_recs(
- widget_id: str, widget_value: Optional[FileUploaderStateProto]
-) -> List[UploadedFileRec]:
+def _get_upload_files(
+ widget_value: Optional[FileUploaderStateProto],
+) -> List[Union[UploadedFile, DeletedFile]]:
if widget_value is None:
return []
@@ -66,15 +70,25 @@ def _get_file_recs(
if len(uploaded_file_info) == 0:
return []
- active_file_ids = [f.id for f in uploaded_file_info]
-
- # Grab the files that correspond to our active file IDs.
- return ctx.uploaded_file_mgr.get_files(
+ file_recs_list = ctx.uploaded_file_mgr.get_files(
session_id=ctx.session_id,
- widget_id=widget_id,
- file_ids=active_file_ids,
+ file_ids=[f.file_id for f in uploaded_file_info],
)
+ file_recs = {f.file_id: f for f in file_recs_list}
+
+ collected_files: List[Union[UploadedFile, DeletedFile]] = []
+
+ for f in uploaded_file_info:
+ maybe_file_rec = file_recs.get(f.file_id)
+ if maybe_file_rec is not None:
+ uploaded_file = UploadedFile(maybe_file_rec, f.file_urls)
+ collected_files.append(uploaded_file)
+ else:
+ collected_files.append(DeletedFile(f.file_id))
+
+ return collected_files
+
@dataclass
class FileUploaderSerde:
@@ -83,38 +97,32 @@ class FileUploaderSerde:
def deserialize(
self, ui_value: Optional[FileUploaderStateProto], widget_id: str
) -> SomeUploadedFiles:
- file_recs = _get_file_recs(widget_id, ui_value)
- if len(file_recs) == 0:
- return_value: Optional[Union[List[UploadedFile], UploadedFile]] = (
- [] if self.accept_multiple_files else None
- )
+ upload_files = _get_upload_files(ui_value)
+
+ if len(upload_files) == 0:
+ return_value: SomeUploadedFiles = [] if self.accept_multiple_files else None
else:
- files = [UploadedFile(rec) for rec in file_recs]
- return_value = files if self.accept_multiple_files else files[0]
+ return_value = (
+ upload_files if self.accept_multiple_files else upload_files[0]
+ )
return return_value
def serialize(self, files: SomeUploadedFiles) -> FileUploaderStateProto:
state_proto = FileUploaderStateProto()
- ctx = get_script_run_ctx()
- if ctx is None:
- return state_proto
-
- # ctx.uploaded_file_mgr._file_id_counter stores the id to use for
- # the *next* uploaded file, so the current highest file id is the
- # counter minus 1.
- state_proto.max_file_id = ctx.uploaded_file_mgr._file_id_counter - 1
-
if not files:
return state_proto
elif not isinstance(files, list):
files = [files]
for f in files:
+ if isinstance(f, DeletedFile):
+ continue
file_info: UploadedFileInfoProto = state_proto.uploaded_file_info.add()
- file_info.id = f.id
+ file_info.file_id = f.file_id
file_info.name = f.name
file_info.size = f.size
+ file_info.file_urls.CopyFrom(f._file_urls)
return state_proto
@@ -136,7 +144,7 @@ class FileUploaderMixin:
def file_uploader(
self,
label: str,
- type: Optional[Union[str, Sequence[str]]],
+ type: Union[str, Sequence[str], None],
accept_multiple_files: Literal[True],
key: Optional[Key] = None,
help: Optional[str] = None,
@@ -155,7 +163,7 @@ def file_uploader(
def file_uploader(
self,
label: str,
- type: Optional[Union[str, Sequence[str]]],
+ type: Union[str, Sequence[str], None],
accept_multiple_files: Literal[False] = False,
key: Optional[Key] = None,
help: Optional[str] = None,
@@ -181,7 +189,7 @@ def file_uploader(
label: str,
*,
accept_multiple_files: Literal[True],
- type: Optional[Union[str, Sequence[str]]] = None,
+ type: Union[str, Sequence[str], None] = None,
key: Optional[Key] = None,
help: Optional[str] = None,
on_change: Optional[WidgetCallback] = None,
@@ -200,7 +208,7 @@ def file_uploader(
label: str,
*,
accept_multiple_files: Literal[False] = False,
- type: Optional[Union[str, Sequence[str]]] = None,
+ type: Union[str, Sequence[str], None] = None,
key: Optional[Key] = None,
help: Optional[str] = None,
on_change: Optional[WidgetCallback] = None,
@@ -215,7 +223,7 @@ def file_uploader(
def file_uploader(
self,
label: str,
- type: Optional[Union[str, Sequence[str]]] = None,
+ type: Union[str, Sequence[str], None] = None,
accept_multiple_files: bool = False,
key: Optional[Key] = None,
help: Optional[str] = None,
@@ -225,7 +233,7 @@ def file_uploader(
*, # keyword-only arguments:
disabled: bool = False,
label_visibility: LabelVisibility = "visible",
- ) -> SomeUploadedFiles:
+ ) -> Union[UploadedFile, List[UploadedFile], None]:
r"""Display a file uploader widget.
By default, uploaded files are limited to 200MB. You can configure
this using the `server.maxUploadSize` config option. For more info
@@ -251,7 +259,7 @@ def file_uploader(
* Colored text, using the syntax ``:color[text to be colored]``,
where ``color`` needs to be replaced with any of the following
- supported colors: blue, green, orange, red, violet.
+ supported colors: blue, green, orange, red, violet, gray/grey, rainbow.
Unsupported elements are unwrapped so only their children (text contents) render.
Display unsupported elements as literal characters by
@@ -371,7 +379,7 @@ def file_uploader(
def _file_uploader(
self,
label: str,
- type: Optional[Union[str, Sequence[str]]] = None,
+ type: Union[str, Sequence[str], None] = None,
accept_multiple_files: bool = False,
key: Optional[Key] = None,
help: Optional[str] = None,
@@ -382,12 +390,23 @@ def _file_uploader(
label_visibility: LabelVisibility = "visible",
disabled: bool = False,
ctx: Optional[ScriptRunContext] = None,
- ) -> SomeUploadedFiles:
+ ) -> Union[UploadedFile, List[UploadedFile], None]:
key = to_key(key)
check_callback_rules(self.dg, on_change)
check_session_state_rules(default_value=None, key=key, writes_allowed=False)
maybe_raise_label_warnings(label, label_visibility)
+ id = compute_widget_id(
+ "file_uploader",
+ user_key=key,
+ label=label,
+ type=type,
+ accept_multiple_files=accept_multiple_files,
+ key=key,
+ help=help,
+ form_id=current_form_id(self.dg),
+ )
+
if type:
if isinstance(type, str):
type = [type]
@@ -408,6 +427,7 @@ def _file_uploader(
type.append(x)
file_uploader_proto = FileUploaderProto()
+ file_uploader_proto.id = id
file_uploader_proto.label = label
file_uploader_proto.type[:] = type if type is not None else []
file_uploader_proto.max_upload_size_mb = config.get_option(
@@ -415,6 +435,11 @@ def _file_uploader(
)
file_uploader_proto.multiple_files = accept_multiple_files
file_uploader_proto.form_id = current_form_id(self.dg)
+ file_uploader_proto.disabled = disabled
+ file_uploader_proto.label_visibility.value = get_label_visibility_proto_value(
+ label_visibility
+ )
+
if help is not None:
file_uploader_proto.help = dedent(help)
@@ -435,27 +460,15 @@ def _file_uploader(
ctx=ctx,
)
- # This needs to be done after register_widget because we don't want
- # the following proto fields to affect a widget's ID.
- file_uploader_proto.disabled = disabled
- file_uploader_proto.label_visibility.value = get_label_visibility_proto_value(
- label_visibility
- )
+ self.dg._enqueue("file_uploader", file_uploader_proto)
- file_uploader_state = serde.serialize(widget_state.value)
- uploaded_file_info = file_uploader_state.uploaded_file_info
- if ctx is not None and len(uploaded_file_info) != 0:
- newest_file_id = file_uploader_state.max_file_id
- active_file_ids = [f.id for f in uploaded_file_info]
-
- ctx.uploaded_file_mgr.remove_orphaned_files(
- session_id=ctx.session_id,
- widget_id=file_uploader_proto.id,
- newest_file_id=newest_file_id,
- active_file_ids=active_file_ids,
- )
+ filtered_value: Union[UploadedFile, List[UploadedFile], None]
+
+ if isinstance(widget_state.value, DeletedFile):
+ return None
+ elif isinstance(widget_state.value, list):
+ return [f for f in widget_state.value if not isinstance(f, DeletedFile)]
- self.dg._enqueue("file_uploader", file_uploader_proto)
return widget_state.value
@property
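Internally, uploads whose backing records are gone now surface as DeletedFile markers, but _file_uploader filters them out before returning, so callers never see them:

    import streamlit as st

    uploads = st.file_uploader("Attach files", accept_multiple_files=True)
    # Always a list here; deleted files were already filtered out.
    for f in uploads:
        st.write(f.name, f.size)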
diff --git a/lib/streamlit/elements/widgets/multiselect.py b/lib/streamlit/elements/widgets/multiselect.py
index 13eab37451c4..2f0174eb8792 100644
--- a/lib/streamlit/elements/widgets/multiselect.py
+++ b/lib/streamlit/elements/widgets/multiselect.py
@@ -43,6 +43,7 @@
WidgetKwargs,
register_widget,
)
+from streamlit.runtime.state.common import compute_widget_id
from streamlit.type_util import (
Key,
LabelVisibility,
@@ -182,7 +183,7 @@ def multiselect(
* Colored text, using the syntax ``:color[text to be colored]``,
where ``color`` needs to be replaced with any of the following
- supported colors: blue, green, orange, red, violet.
+ supported colors: blue, green, orange, red, violet, gray/grey, rainbow.
Unsupported elements are unwrapped so only their children (text contents) render.
Display unsupported elements as literal characters by
@@ -294,14 +295,35 @@ def _multiselect(
maybe_raise_label_warnings(label, label_visibility)
indices = _check_and_convert_to_indices(opt, default)
+
+ id = compute_widget_id(
+ "multiselect",
+ user_key=key,
+ label=label,
+ options=[str(format_func(option)) for option in opt],
+ default=indices,
+ key=key,
+ help=help,
+ max_selections=max_selections,
+ placeholder=placeholder,
+ form_id=current_form_id(self.dg),
+ )
+
+ default_value: List[int] = [] if indices is None else indices
+
multiselect_proto = MultiSelectProto()
+ multiselect_proto.id = id
multiselect_proto.label = label
- default_value: List[int] = [] if indices is None else indices
multiselect_proto.default[:] = default_value
multiselect_proto.options[:] = [str(format_func(option)) for option in opt]
multiselect_proto.form_id = current_form_id(self.dg)
multiselect_proto.max_selections = max_selections or 0
multiselect_proto.placeholder = placeholder
+ multiselect_proto.disabled = disabled
+ multiselect_proto.label_visibility.value = get_label_visibility_proto_value(
+ label_visibility
+ )
+
if help is not None:
multiselect_proto.help = dedent(help)
@@ -324,12 +346,6 @@ def _multiselect(
_get_over_max_options_message(default_count, max_selections)
)
- # This needs to be done after register_widget because we don't want
- # the following proto fields to affect a widget's ID.
- multiselect_proto.disabled = disabled
- multiselect_proto.label_visibility.value = get_label_visibility_proto_value(
- label_visibility
- )
if widget_state.value_changed:
multiselect_proto.value[:] = serde.serialize(widget_state.value)
multiselect_proto.set_value = True
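max_selections and placeholder now feed both the proto and the widget ID, so changing either yields a fresh widget; a sketch using both:

    import streamlit as st

    picks = st.multiselect(
        "Toppings",
        ["cheese", "mushrooms", "olives"],
        max_selections=2,
        placeholder="Pick up to two",
    )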
diff --git a/lib/streamlit/elements/widgets/number_input.py b/lib/streamlit/elements/widgets/number_input.py
index bbcdcc424cca..8d7db13ad45b 100644
--- a/lib/streamlit/elements/widgets/number_input.py
+++ b/lib/streamlit/elements/widgets/number_input.py
@@ -36,6 +36,7 @@
WidgetKwargs,
register_widget,
)
+from streamlit.runtime.state.common import compute_widget_id
from streamlit.type_util import Key, LabelVisibility, maybe_raise_label_warnings, to_key
Number = Union[int, float]
@@ -108,7 +109,7 @@ def number_input(
* Colored text, using the syntax ``:color[text to be colored]``,
where ``color`` needs to be replaced with any of the following
- supported colors: blue, green, orange, red, violet.
+ supported colors: blue, green, orange, red, violet, gray/grey, rainbow.
Unsupported elements are unwrapped so only their children (text contents) render.
Display unsupported elements as literal characters by
@@ -217,6 +218,21 @@ def _number_input(
default_value=None if isinstance(value, NoValue) else value, key=key
)
maybe_raise_label_warnings(label, label_visibility)
+
+ id = compute_widget_id(
+ "number_input",
+ user_key=key,
+ label=label,
+ min_value=min_value,
+ max_value=max_value,
+ value=value,
+ step=step,
+ format=format,
+ key=key,
+ help=help,
+ form_id=current_form_id(self.dg),
+ )
+
# Ensure that all arguments are of the same type.
number_input_args = [min_value, max_value, value, step]
@@ -328,10 +344,16 @@ def _number_input(
data_type = NumberInputProto.INT if all_ints else NumberInputProto.FLOAT
number_input_proto = NumberInputProto()
+ number_input_proto.id = id
number_input_proto.data_type = data_type
number_input_proto.label = label
number_input_proto.default = value
number_input_proto.form_id = current_form_id(self.dg)
+ number_input_proto.disabled = disabled
+ number_input_proto.label_visibility.value = get_label_visibility_proto_value(
+ label_visibility
+ )
+
if help is not None:
number_input_proto.help = dedent(help)
@@ -362,13 +384,6 @@ def _number_input(
ctx=ctx,
)
- # This needs to be done after register_widget because we don't want
- # the following proto fields to affect a widget's ID.
- number_input_proto.disabled = disabled
- number_input_proto.label_visibility.value = get_label_visibility_proto_value(
- label_visibility
- )
-
if widget_state.value_changed:
number_input_proto.value = widget_state.value
number_input_proto.set_value = True
diff --git a/lib/streamlit/elements/widgets/radio.py b/lib/streamlit/elements/widgets/radio.py
index 3f09648decb3..247fcb0a0f39 100644
--- a/lib/streamlit/elements/widgets/radio.py
+++ b/lib/streamlit/elements/widgets/radio.py
@@ -32,6 +32,7 @@
WidgetKwargs,
register_widget,
)
+from streamlit.runtime.state.common import compute_widget_id
from streamlit.type_util import (
Key,
LabelVisibility,
@@ -87,6 +88,7 @@ def radio(
*, # keyword-only args:
disabled: bool = False,
horizontal: bool = False,
+ captions: Optional[Sequence[str]] = None,
label_visibility: LabelVisibility = "visible",
) -> Optional[T]:
r"""Display a radio button widget.
@@ -110,7 +112,7 @@ def radio(
* Colored text, using the syntax ``:color[text to be colored]``,
where ``color`` needs to be replaced with any of the following
- supported colors: blue, green, orange, red, violet.
+ supported colors: blue, green, orange, red, violet, gray/grey, rainbow.
Unsupported elements are unwrapped so only their children (text contents) render.
Display unsupported elements as literal characters by
@@ -120,8 +122,10 @@ def radio(
but hide it with label_visibility if needed. In the future, we may disallow
empty labels by raising an exception.
options : Sequence, numpy.ndarray, pandas.Series, pandas.DataFrame, or pandas.Index
- Labels for the radio options. This will be cast to str internally
- by default. For pandas.DataFrame, the first column is selected.
+ Labels for the radio options. Labels can include markdown as
+ described in the ``label`` parameter and will be cast to str
+ internally by default. For pandas.DataFrame, the first column is
+ selected.
index : int
The index of the preselected option on first render.
format_func : function
@@ -150,7 +154,9 @@ def radio(
An optional boolean, which orients the radio group horizontally.
The default is false (vertical buttons). This argument can only
be supplied by keyword.
-
+ captions : iterable of str or None
+ A list of captions to show below each radio button. If None (default),
+ no captions are shown.
label_visibility : "visible", "hidden", or "collapsed"
The visibility of the label. If "hidden", the label doesn't show but there
is still empty space for it above the widget (equivalent to label="").
@@ -167,17 +173,18 @@ def radio(
>>> import streamlit as st
>>>
>>> genre = st.radio(
- ... "What\'s your favorite movie genre",
- ... ('Comedy', 'Drama', 'Documentary'))
+ ... "What's your favorite movie genre",
+ ... [":rainbow[Comedy]", "***Drama***", "Documentary :movie_camera:"],
+ ... captions = ["Laugh out loud.", "Get the popcorn.", "Never stop learning."])
>>>
- >>> if genre == 'Comedy':
+ >>> if genre == ':rainbow[Comedy]':
... st.write('You selected comedy.')
... else:
... st.write("You didn\'t select comedy.")
.. output::
https://doc-radio.streamlit.app/
- height: 260px
+ height: 300px
"""
ctx = get_script_run_ctx()
@@ -193,8 +200,9 @@ def radio(
kwargs=kwargs,
disabled=disabled,
horizontal=horizontal,
- ctx=ctx,
+ captions=captions,
label_visibility=label_visibility,
+ ctx=ctx,
)
def _radio(
@@ -212,6 +220,7 @@ def _radio(
disabled: bool = False,
horizontal: bool = False,
label_visibility: LabelVisibility = "visible",
+ captions: Optional[Sequence[str]] = None,
ctx: Optional[ScriptRunContext],
) -> Optional[T]:
key = to_key(key)
@@ -220,6 +229,19 @@ def _radio(
maybe_raise_label_warnings(label, label_visibility)
opt = ensure_indexable(options)
+ id = compute_widget_id(
+ "radio",
+ user_key=key,
+ label=label,
+ options=[str(format_func(option)) for option in opt],
+ index=index,
+ key=key,
+ help=help,
+ horizontal=horizontal,
+ captions=captions,
+ form_id=current_form_id(self.dg),
+ )
+
if not isinstance(index, int):
raise StreamlitAPIException(
"Radio Value has invalid type: %s" % type(index).__name__
@@ -230,12 +252,31 @@ def _radio(
"Radio index must be between 0 and length of options"
)
+ def handle_captions(caption: Optional[str]) -> str:
+ if caption is None:
+ return ""
+ elif isinstance(caption, str):
+ return caption
+ else:
+ raise StreamlitAPIException(
+ f"Radio captions must be strings. Passed type: {type(caption).__name__}"
+ )
+
radio_proto = RadioProto()
+ radio_proto.id = id
radio_proto.label = label
radio_proto.default = index
radio_proto.options[:] = [str(format_func(option)) for option in opt]
radio_proto.form_id = current_form_id(self.dg)
radio_proto.horizontal = horizontal
+ radio_proto.disabled = disabled
+ radio_proto.label_visibility.value = get_label_visibility_proto_value(
+ label_visibility
+ )
+
+ if captions is not None:
+ radio_proto.captions[:] = map(handle_captions, captions)
+
if help is not None:
radio_proto.help = dedent(help)
@@ -253,13 +294,6 @@ def _radio(
ctx=ctx,
)
- # This needs to be done after register_widget because we don't want
- # the following proto fields to affect a widget's ID.
- radio_proto.disabled = disabled
- radio_proto.label_visibility.value = get_label_visibility_proto_value(
- label_visibility
- )
-
if widget_state.value_changed:
radio_proto.value = serde.serialize(widget_state.value)
radio_proto.set_value = True
diff --git a/lib/streamlit/elements/widgets/select_slider.py b/lib/streamlit/elements/widgets/select_slider.py
index b2cf4a9a40c8..ba7c79666baf 100644
--- a/lib/streamlit/elements/widgets/select_slider.py
+++ b/lib/streamlit/elements/widgets/select_slider.py
@@ -45,6 +45,7 @@
WidgetKwargs,
register_widget,
)
+from streamlit.runtime.state.common import compute_widget_id
from streamlit.type_util import (
Key,
LabelVisibility,
@@ -149,7 +150,7 @@ def select_slider(
* Colored text, using the syntax ``:color[text to be colored]``,
where ``color`` needs to be replaced with any of the following
- supported colors: blue, green, orange, red, violet.
+ supported colors: blue, green, orange, red, violet, gray/grey, rainbow.
Unsupported elements are unwrapped so only their children (text contents) render.
Display unsupported elements as literal characters by
@@ -284,7 +285,19 @@ def as_index_list(v: object) -> List[int]:
# Convert element to index of the elements
slider_value = as_index_list(value)
+ id = compute_widget_id(
+ "select_slider",
+ user_key=key,
+ label=label,
+ options=[str(format_func(option)) for option in opt],
+ value=slider_value,
+ key=key,
+ help=help,
+ form_id=current_form_id(self.dg),
+ )
+
slider_proto = SliderProto()
+ slider_proto.id = id
slider_proto.type = SliderProto.Type.SELECT_SLIDER
slider_proto.label = label
slider_proto.format = "%s"
@@ -295,6 +308,10 @@ def as_index_list(v: object) -> List[int]:
slider_proto.data_type = SliderProto.INT
slider_proto.options[:] = [str(format_func(option)) for option in opt]
slider_proto.form_id = current_form_id(self.dg)
+ slider_proto.disabled = disabled
+ slider_proto.label_visibility.value = get_label_visibility_proto_value(
+ label_visibility
+ )
if help is not None:
slider_proto.help = dedent(help)
@@ -312,12 +329,6 @@ def as_index_list(v: object) -> List[int]:
ctx=ctx,
)
- # This needs to be done after register_widget because we don't want
- # the following proto fields to affect a widget's ID.
- slider_proto.disabled = disabled
- slider_proto.label_visibility.value = get_label_visibility_proto_value(
- label_visibility
- )
if widget_state.value_changed:
slider_proto.value[:] = serde.serialize(widget_state.value)
slider_proto.set_value = True
diff --git a/lib/streamlit/elements/widgets/selectbox.py b/lib/streamlit/elements/widgets/selectbox.py
index 69046c162b4a..f4d1352913f3 100644
--- a/lib/streamlit/elements/widgets/selectbox.py
+++ b/lib/streamlit/elements/widgets/selectbox.py
@@ -32,6 +32,7 @@
WidgetKwargs,
register_widget,
)
+from streamlit.runtime.state.common import compute_widget_id
from streamlit.type_util import (
Key,
LabelVisibility,
@@ -106,7 +107,7 @@ def selectbox(
* Colored text, using the syntax ``:color[text to be colored]``,
where ``color`` needs to be replaced with any of the following
- supported colors: blue, green, orange, red, violet.
+ supported colors: blue, green, orange, red, violet, gray/grey, rainbow.
Unsupported elements are unwrapped so only their children (text contents) render.
Display unsupported elements as literal characters by
@@ -213,6 +214,18 @@ def _selectbox(
opt = ensure_indexable(options)
+ id = compute_widget_id(
+ "selectbox",
+ user_key=key,
+ label=label,
+ options=[str(format_func(option)) for option in opt],
+ index=index,
+ key=key,
+ help=help,
+ placeholder=placeholder,
+ form_id=current_form_id(self.dg),
+ )
+
if not isinstance(index, int):
raise StreamlitAPIException(
"Selectbox Value has invalid type: %s" % type(index).__name__
@@ -224,11 +237,17 @@ def _selectbox(
)
selectbox_proto = SelectboxProto()
+ selectbox_proto.id = id
selectbox_proto.label = label
selectbox_proto.default = index
selectbox_proto.options[:] = [str(format_func(option)) for option in opt]
selectbox_proto.form_id = current_form_id(self.dg)
selectbox_proto.placeholder = placeholder
+ selectbox_proto.disabled = disabled
+ selectbox_proto.label_visibility.value = get_label_visibility_proto_value(
+ label_visibility
+ )
+
if help is not None:
selectbox_proto.help = dedent(help)
@@ -246,12 +265,6 @@ def _selectbox(
ctx=ctx,
)
- # This needs to be done after register_widget because we don't want
- # the following proto fields to affect a widget's ID.
- selectbox_proto.disabled = disabled
- selectbox_proto.label_visibility.value = get_label_visibility_proto_value(
- label_visibility
- )
if widget_state.value_changed:
selectbox_proto.value = serde.serialize(widget_state.value)
selectbox_proto.set_value = True
diff --git a/lib/streamlit/elements/widgets/slider.py b/lib/streamlit/elements/widgets/slider.py
index 8fa1b89ebcc1..63a22eee02f4 100644
--- a/lib/streamlit/elements/widgets/slider.py
+++ b/lib/streamlit/elements/widgets/slider.py
@@ -48,6 +48,7 @@
get_session_state,
register_widget,
)
+from streamlit.runtime.state.common import compute_widget_id
from streamlit.type_util import Key, LabelVisibility, maybe_raise_label_warnings, to_key
if TYPE_CHECKING:
@@ -229,7 +230,7 @@ def slider(
* Colored text, using the syntax ``:color[text to be colored]``,
where ``color`` needs to be replaced with any of the following
- supported colors: blue, green, orange, red, violet.
+ supported colors: blue, green, orange, red, violet, gray/grey, rainbow.
Unsupported elements are unwrapped so only their children (text contents) render.
Display unsupported elements as literal characters by
@@ -376,18 +377,19 @@ def _slider(
maybe_raise_label_warnings(label, label_visibility)
- if value is None:
- # Set value from session_state if exists.
- session_state = get_session_state().filtered_state
-
- # we look first to session_state value of the widget because
- # depending on the value (single value or list/tuple) the slider should be
- # initializing differently (either as range or single value slider)
- if key is not None and key in session_state:
- value = session_state[key]
- else:
- # Set value default.
- value = min_value if min_value is not None else 0
+ id = compute_widget_id(
+ "slider",
+ user_key=key,
+ label=label,
+ min_value=min_value,
+ max_value=max_value,
+ value=value,
+ step=step,
+ format=format,
+ key=key,
+ help=help,
+ form_id=current_form_id(self.dg),
+ )
SUPPORTED_TYPES = {
Integral: SliderProto.INT,
@@ -398,6 +400,27 @@ def _slider(
}
TIMELIKE_TYPES = (SliderProto.DATETIME, SliderProto.TIME, SliderProto.DATE)
+ if value is None:
+ # We need to know if this is a single or range slider, but don't have
+ # a default value, so we check if session_state can tell us.
+ # We already calculated the id, so there is no risk of this causing
+ # the id to change.
+
+ single_value = True
+
+ session_state = get_session_state().filtered_state
+
+ if key is not None and key in session_state:
+ state_value = session_state[key]
+ single_value = isinstance(state_value, tuple(SUPPORTED_TYPES.keys()))
+
+ if single_value:
+ value = min_value if min_value is not None else 0
+ else:
+ mn = min_value if min_value is not None else 0
+ mx = max_value if max_value is not None else 100
+ value = [mn, mx]
+
# Ensure that the value is either a single value or a range of values.
single_value = isinstance(value, tuple(SUPPORTED_TYPES.keys()))
range_value = isinstance(value, (list, tuple)) and len(value) in (0, 1, 2)
@@ -609,6 +632,7 @@ def all_same_type(items):
slider_proto = SliderProto()
slider_proto.type = SliderProto.Type.SLIDER
+ slider_proto.id = id
slider_proto.label = label
slider_proto.format = format
slider_proto.default[:] = value
@@ -618,6 +642,11 @@ def all_same_type(items):
slider_proto.data_type = data_type
slider_proto.options[:] = []
slider_proto.form_id = current_form_id(self.dg)
+ slider_proto.disabled = disabled
+ slider_proto.label_visibility.value = get_label_visibility_proto_value(
+ label_visibility
+ )
+
if help is not None:
slider_proto.help = dedent(help)
@@ -635,13 +664,6 @@ def all_same_type(items):
ctx=ctx,
)
- # This needs to be done after register_widget because we don't want
- # the following proto fields to affect a widget's ID.
- slider_proto.disabled = disabled
- slider_proto.label_visibility.value = get_label_visibility_proto_value(
- label_visibility
- )
-
if widget_state.value_changed:
slider_proto.value[:] = serde.serialize(widget_state.value)
slider_proto.set_value = True
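The default-value block added above decides between single-value and range initialization before any default exists. A minimal standalone sketch of that decision, assuming a scalar-type tuple standing in for the keys of SUPPORTED_TYPES (names here are illustrative, not Streamlit internals):

    from datetime import date, datetime, time
    from numbers import Integral, Real

    # Assumed stand-in for tuple(SUPPORTED_TYPES.keys()) in the diff above.
    SCALAR_TYPES = (Integral, Real, datetime, date, time)

    def default_slider_value(session_value, min_value, max_value):
        # A scalar in session_state (or no session value at all) means a
        # single-value slider; anything else initializes a range slider.
        if session_value is None or isinstance(session_value, SCALAR_TYPES):
            return min_value if min_value is not None else 0
        mn = min_value if min_value is not None else 0
        mx = max_value if max_value is not None else 100
        return [mn, mx]

    assert default_slider_value(None, 5, 50) == 5
    assert default_slider_value((1, 3), None, None) == [0, 100]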
diff --git a/lib/streamlit/elements/widgets/text_widgets.py b/lib/streamlit/elements/widgets/text_widgets.py
index 69b583540ca8..719766011bb3 100644
--- a/lib/streamlit/elements/widgets/text_widgets.py
+++ b/lib/streamlit/elements/widgets/text_widgets.py
@@ -36,6 +36,7 @@
WidgetKwargs,
register_widget,
)
+from streamlit.runtime.state.common import compute_widget_id
from streamlit.type_util import (
Key,
LabelVisibility,
@@ -107,7 +108,7 @@ def text_input(
* Colored text, using the syntax ``:color[text to be colored]``,
where ``color`` needs to be replaced with any of the following
- supported colors: blue, green, orange, red, violet.
+ supported colors: blue, green, orange, red, violet, gray/grey, rainbow.
Unsupported elements are unwrapped so only their children (text contents) render.
Display unsupported elements as literal characters by
@@ -214,10 +215,29 @@ def _text_input(
maybe_raise_label_warnings(label, label_visibility)
+ id = compute_widget_id(
+ "text_input",
+ user_key=key,
+ label=label,
+ value=str(value),
+ max_chars=max_chars,
+ key=key,
+ type=type,
+ help=help,
+ autocomplete=autocomplete,
+ placeholder=str(placeholder),
+ form_id=current_form_id(self.dg),
+ )
+
text_input_proto = TextInputProto()
+ text_input_proto.id = id
text_input_proto.label = label
text_input_proto.default = str(value)
text_input_proto.form_id = current_form_id(self.dg)
+ text_input_proto.disabled = disabled
+ text_input_proto.label_visibility.value = get_label_visibility_proto_value(
+ label_visibility
+ )
if help is not None:
text_input_proto.help = dedent(help)
@@ -258,12 +278,6 @@ def _text_input(
ctx=ctx,
)
- # This needs to be done after register_widget because we don't want
- # the following proto fields to affect a widget's ID.
- text_input_proto.disabled = disabled
- text_input_proto.label_visibility.value = get_label_visibility_proto_value(
- label_visibility
- )
if widget_state.value_changed:
text_input_proto.value = widget_state.value
text_input_proto.set_value = True
@@ -309,7 +323,7 @@ def text_area(
* Colored text, using the syntax ``:color[text to be colored]``,
where ``color`` needs to be replaced with any of the following
- supported colors: blue, green, orange, red, violet.
+ supported colors: blue, green, orange, red, violet, gray/grey, rainbow.
Unsupported elements are unwrapped so only their children (text contents) render.
Display unsupported elements as literal characters by
@@ -410,10 +424,28 @@ def _text_area(
maybe_raise_label_warnings(label, label_visibility)
+ id = compute_widget_id(
+ "text_area",
+ user_key=key,
+ label=label,
+ value=str(value),
+ height=height,
+ max_chars=max_chars,
+ key=key,
+ help=help,
+ placeholder=str(placeholder),
+ form_id=current_form_id(self.dg),
+ )
+
text_area_proto = TextAreaProto()
+ text_area_proto.id = id
text_area_proto.label = label
text_area_proto.default = str(value)
text_area_proto.form_id = current_form_id(self.dg)
+ text_area_proto.disabled = disabled
+ text_area_proto.label_visibility.value = get_label_visibility_proto_value(
+ label_visibility
+ )
if help is not None:
text_area_proto.help = dedent(help)
@@ -440,12 +472,6 @@ def _text_area(
ctx=ctx,
)
- # This needs to be done after register_widget because we don't want
- # the following proto fields to affect a widget's ID.
- text_area_proto.disabled = disabled
- text_area_proto.label_visibility.value = get_label_visibility_proto_value(
- label_visibility
- )
if widget_state.value_changed:
text_area_proto.value = widget_state.value
text_area_proto.set_value = True
diff --git a/lib/streamlit/elements/widgets/time_widgets.py b/lib/streamlit/elements/widgets/time_widgets.py
index 75fe0f962ac1..ad40dd34bb0b 100644
--- a/lib/streamlit/elements/widgets/time_widgets.py
+++ b/lib/streamlit/elements/widgets/time_widgets.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import annotations
import re
from dataclasses import dataclass
@@ -38,6 +39,7 @@
WidgetKwargs,
register_widget,
)
+from streamlit.runtime.state.common import compute_widget_id
from streamlit.type_util import Key, LabelVisibility, maybe_raise_label_warnings, to_key
if TYPE_CHECKING:
@@ -250,7 +252,7 @@ def time_input(
* Colored text, using the syntax ``:color[text to be colored]``,
where ``color`` needs to be replaced with any of the following
- supported colors: blue, green, orange, red, violet.
+ supported colors: blue, green, orange, red, violet, gray/grey, rainbow.
Unsupported elements are unwrapped so only their children (text contents) render.
Display unsupported elements as literal characters by
@@ -353,9 +355,21 @@ def _time_input(
raise StreamlitAPIException(
"The type of value should be one of datetime, time or None"
)
+
+ id = compute_widget_id(
+ "time_input",
+ user_key=key,
+ label=label,
+ value=(None if value is None else parsed_time),
+ key=key,
+ help=help,
+ step=step,
+ form_id=current_form_id(self.dg),
+ )
del value
time_input_proto = TimeInputProto()
+ time_input_proto.id = id
time_input_proto.label = label
time_input_proto.default = time.strftime(parsed_time, "%H:%M")
time_input_proto.form_id = current_form_id(self.dg)
@@ -370,6 +384,11 @@ def _time_input(
f"`step` must be between 60 seconds and 23 hours but is currently set to {step} seconds."
)
time_input_proto.step = step
+ time_input_proto.disabled = disabled
+ time_input_proto.label_visibility.value = get_label_visibility_proto_value(
+ label_visibility
+ )
+
if help is not None:
time_input_proto.help = dedent(help)
@@ -386,12 +405,6 @@ def _time_input(
ctx=ctx,
)
- # This needs to be done after register_widget because we don't want
- # the following proto fields to affect a widget's ID.
- time_input_proto.disabled = disabled
- time_input_proto.label_visibility.value = get_label_visibility_proto_value(
- label_visibility
- )
if widget_state.value_changed:
time_input_proto.value = serde.serialize(widget_state.value)
time_input_proto.set_value = True
@@ -437,7 +450,7 @@ def date_input(
* Colored text, using the syntax ``:color[text to be colored]``,
where ``color`` needs to be replaced with any of the following
- supported colors: blue, green, orange, red, violet.
+ supported colors: blue, green, orange, red, violet, gray/grey, rainbow.
Unsupported elements are unwrapped so only their children (text contents) render.
Display unsupported elements as literal characters by
@@ -562,6 +575,36 @@ def _date_input(
maybe_raise_label_warnings(label, label_visibility)
+ def parse_date_deterministic(v: SingleDateValue) -> str | None:
+ if v is None:
+ return None
+ elif isinstance(v, datetime):
+ return date.strftime(v.date(), "%Y/%m/%d")
+ elif isinstance(v, date):
+ return date.strftime(v, "%Y/%m/%d")
+
+ parsed_min_date = parse_date_deterministic(min_value)
+ parsed_max_date = parse_date_deterministic(max_value)
+
+ if isinstance(value, datetime) or isinstance(value, date) or value is None:
+ parsed: str | None | List[str | None] = parse_date_deterministic(value)
+ else:
+ parsed = [parse_date_deterministic(v) for v in value]
+
+ # TODO: this is missing the error path; integrate with the dateinputvalues parsing.
+
+ id = compute_widget_id(
+ "date_input",
+ user_key=key,
+ label=label,
+ value=parsed,
+ min_value=parsed_min_date,
+ max_value=parsed_max_date,
+ key=key,
+ help=help,
+ format=format,
+ form_id=current_form_id(self.dg),
+ )
if not bool(ALLOWED_DATE_FORMATS.match(format)):
raise StreamlitAPIException(
f"The provided format (`{format}`) is not valid. DateInput format "
@@ -577,21 +620,24 @@ def _date_input(
del value, min_value, max_value
date_input_proto = DateInputProto()
+ date_input_proto.id = id
date_input_proto.is_range = parsed_values.is_range
- if help is not None:
- date_input_proto.help = dedent(help)
-
+ date_input_proto.disabled = disabled
+ date_input_proto.label_visibility.value = get_label_visibility_proto_value(
+ label_visibility
+ )
date_input_proto.format = format
date_input_proto.label = label
date_input_proto.default[:] = [
date.strftime(v, "%Y/%m/%d") for v in parsed_values.value
]
-
date_input_proto.min = date.strftime(parsed_values.min, "%Y/%m/%d")
date_input_proto.max = date.strftime(parsed_values.max, "%Y/%m/%d")
-
date_input_proto.form_id = current_form_id(self.dg)
+ if help is not None:
+ date_input_proto.help = dedent(help)
+
serde = DateInputSerde(parsed_values)
widget_state = register_widget(
@@ -606,12 +652,6 @@ def _date_input(
ctx=ctx,
)
- # This needs to be done after register_widget because we don't want
- # the following proto fields to affect a widget's ID.
- date_input_proto.disabled = disabled
- date_input_proto.label_visibility.value = get_label_visibility_proto_value(
- label_visibility
- )
if widget_state.value_changed:
date_input_proto.value[:] = serde.serialize(widget_state.value)
date_input_proto.set_value = True
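`parse_date_deterministic` above normalizes date-like inputs to a fixed-format string so the id hash never sees a run-dependent value. A self-contained sketch of the same normalization; as the TODO notes, string shortcuts such as "today" fall through (to None here) and are resolved later:

    from datetime import date, datetime
    from typing import Optional

    def parse_date_deterministic(v) -> Optional[str]:
        # datetime must be checked before date: every datetime is a date.
        if v is None:
            return None
        if isinstance(v, datetime):
            return v.date().strftime("%Y/%m/%d")
        if isinstance(v, date):
            return v.strftime("%Y/%m/%d")
        return None  # e.g. "today"; resolved by the regular value parsing

    assert parse_date_deterministic(date(2023, 7, 4)) == "2023/07/04"
    assert parse_date_deterministic(datetime(2023, 7, 4, 12, 30)) == "2023/07/04"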
diff --git a/lib/streamlit/external/langchain/mutable_expander.py b/lib/streamlit/external/langchain/mutable_expander.py
deleted file mode 100644
index 3f36c2eebd2e..000000000000
--- a/lib/streamlit/external/langchain/mutable_expander.py
+++ /dev/null
@@ -1,172 +0,0 @@
-# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import annotations
-
-from enum import Enum
-from typing import Any, NamedTuple, Optional
-
-from streamlit.delta_generator import DeltaGenerator
-from streamlit.type_util import SupportsStr
-
-
-class ChildType(Enum):
- MARKDOWN = "MARKDOWN"
- EXCEPTION = "EXCEPTION"
-
-
-class ChildRecord(NamedTuple):
- type: ChildType
- kwargs: dict[str, Any]
- dg: DeltaGenerator
-
-
-class MutableExpander:
- """A Streamlit expander that can be renamed and dynamically expanded/collapsed.
- Used internally by StreamlitCallbackHandler.
-
- NB: this class's functionality is tested only indirectly by `streamlit_callback_handler_test.py`.
- It's not currently intended for use outside of StreamlitCallbackHandler. If it
- becomes more broadly useful, we should consider turning it into a "proper" Streamlit
- that can support mutations without rebuilding its entire state.
- """
-
- def __init__(self, parent_container: DeltaGenerator, label: str, expanded: bool):
- """Create a new MutableExpander.
-
- Parameters
- ----------
- parent_container
- The `st.container` that the expander will be created inside.
-
- The expander transparently deletes and recreates its underlying
- `st.expander` instance when its label changes, and it uses
- `parent_container` to ensure it recreates this underlying expander in the
- same location onscreen.
- label
- The expander's initial label.
- expanded
- The expander's initial `expanded` value.
- """
- self._label = label
- self._expanded = expanded
- self._parent_cursor = parent_container.empty()
- self._container = self._parent_cursor.expander(label, expanded)
- self._child_records: list[ChildRecord] = []
-
- @property
- def label(self) -> str:
- """The expander's label string."""
- return self._label
-
- @property
- def expanded(self) -> bool:
- """True if the expander was created with `expanded=True`."""
- return self._expanded
-
- def clear(self) -> None:
- """Remove the container and its contents entirely. A cleared container can't
- be reused.
- """
- self._container = self._parent_cursor.empty()
- self._child_records.clear()
-
- def append_copy(self, other: MutableExpander) -> None:
- """Append a copy of another MutableExpander's children to this
- MutableExpander.
- """
- other_records = other._child_records.copy()
- for record in other_records:
- self._create_child(record.type, record.kwargs)
-
- def update(
- self, *, new_label: Optional[str] = None, new_expanded: Optional[bool] = None
- ) -> None:
- """Change the expander's label and expanded state"""
- if new_label is None:
- new_label = self._label
- if new_expanded is None:
- new_expanded = self._expanded
-
- if self._label == new_label and self._expanded == new_expanded:
- # No change!
- return
-
- self._label = new_label
- self._expanded = new_expanded
- self._container = self._parent_cursor.expander(new_label, new_expanded)
-
- prev_records = self._child_records
- self._child_records = []
-
- # Replay all children into the new container
- for record in prev_records:
- self._create_child(record.type, record.kwargs)
-
- def markdown(
- self,
- body: SupportsStr,
- unsafe_allow_html: bool = False,
- *,
- help: Optional[str] = None,
- index: Optional[int] = None,
- ) -> int:
- """Add a Markdown element to the container and return its index."""
- kwargs = {"body": body, "unsafe_allow_html": unsafe_allow_html, "help": help}
- new_dg = self._get_dg(index).markdown(**kwargs) # type: ignore[arg-type]
- record = ChildRecord(ChildType.MARKDOWN, kwargs, new_dg)
- return self._add_record(record, index)
-
- def exception(
- self, exception: BaseException, *, index: Optional[int] = None
- ) -> int:
- """Add an Exception element to the container and return its index."""
- kwargs = {"exception": exception}
- new_dg = self._get_dg(index).exception(**kwargs)
- record = ChildRecord(ChildType.EXCEPTION, kwargs, new_dg)
- return self._add_record(record, index)
-
- def _create_child(self, type: ChildType, kwargs: dict[str, Any]) -> None:
- """Create a new child with the given params"""
- if type == ChildType.MARKDOWN:
- self.markdown(**kwargs)
- elif type == ChildType.EXCEPTION:
- self.exception(**kwargs)
- else:
- raise RuntimeError(f"Unexpected child type {type}")
-
- def _add_record(self, record: ChildRecord, index: Optional[int]) -> int:
- """Add a ChildRecord to self._children. If `index` is specified, replace
- the existing record at that index. Otherwise, append the record to the
- end of the list.
-
- Return the index of the added record.
- """
- if index is not None:
- # Replace existing child
- self._child_records[index] = record
- return index
-
- # Append new child
- self._child_records.append(record)
- return len(self._child_records) - 1
-
- def _get_dg(self, index: Optional[int]) -> DeltaGenerator:
- if index is not None:
- # Existing index: reuse child's DeltaGenerator
- assert 0 <= index < len(self._child_records), f"Bad index: {index}"
- return self._child_records[index].dg
-
- # No index: use container's DeltaGenerator
- return self._container
diff --git a/lib/streamlit/external/langchain/streamlit_callback_handler.py b/lib/streamlit/external/langchain/streamlit_callback_handler.py
index 5570f08186f5..70b96dcaad35 100644
--- a/lib/streamlit/external/langchain/streamlit_callback_handler.py
+++ b/lib/streamlit/external/langchain/streamlit_callback_handler.py
@@ -31,16 +31,19 @@
from __future__ import annotations
+import time
from enum import Enum
-from typing import Any, Dict, List, NamedTuple, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
-from streamlit.delta_generator import DeltaGenerator
-from streamlit.external.langchain.mutable_expander import MutableExpander
from streamlit.runtime.metrics_util import gather_metrics
+if TYPE_CHECKING:
+ from streamlit.delta_generator import DeltaGenerator
+ from streamlit.elements.lib.mutable_status_container import StatusContainer
+
def _convert_newlines(text: str) -> str:
"""Convert newline characters to markdown newline sequences
@@ -49,11 +52,6 @@ def _convert_newlines(text: str) -> str:
return text.replace("\n", " \n")
-CHECKMARK_EMOJI = "✅"
-THINKING_EMOJI = "🤔"
-HISTORY_EMOJI = "📚"
-EXCEPTION_EMOJI = "⚠️"
-
# The maximum length of the "input_str" portion of a tool label.
# Strings that are longer than this will be truncated with "..."
MAX_TOOL_INPUT_STR_LENGTH = 60
@@ -66,6 +64,8 @@ class LLMThoughtState(Enum):
RUNNING_TOOL = "RUNNING_TOOL"
# We have results from the tool.
COMPLETE = "COMPLETE"
+ # The LLM completed with an error.
+ ERROR = "ERROR"
class ToolRecord(NamedTuple):
@@ -84,7 +84,7 @@ def get_initial_label(self) -> str:
"""Return the markdown label for a new LLMThought that doesn't have
an associated tool yet.
"""
- return f"{THINKING_EMOJI} **Thinking...**"
+ return "Thinking..."
def get_tool_label(self, tool: ToolRecord, is_complete: bool) -> str:
"""Return the label for an LLMThought that has an associated
@@ -104,32 +104,23 @@ def get_tool_label(self, tool: ToolRecord, is_complete: bool) -> str:
The markdown label for the thought's container.
"""
- input = tool.input_str
+ input_str = tool.input_str
name = tool.name
- emoji = CHECKMARK_EMOJI if is_complete else THINKING_EMOJI
if name == "_Exception":
- emoji = EXCEPTION_EMOJI
name = "Parsing error"
- input_str_len = min(MAX_TOOL_INPUT_STR_LENGTH, len(input))
- input = input[0:input_str_len]
+ input_str_len = min(MAX_TOOL_INPUT_STR_LENGTH, len(input_str))
+ input_str = input_str[:input_str_len]
if len(tool.input_str) > input_str_len:
- input = input + "..."
- input = input.replace("\n", " ")
- label = f"{emoji} **{name}:** {input}"
- return label
-
- def get_history_label(self) -> str:
- """Return a markdown label for the special 'history' container
- that contains overflow thoughts.
- """
- return f"{HISTORY_EMOJI} **History**"
+ input_str = input_str + "..."
+ input_str = input_str.replace("\n", " ")
+ return f"**{name}:** {input_str}"
def get_final_agent_thought_label(self) -> str:
"""Return the markdown label for the agent's final thought -
the "Now I have the answer" thought, that doesn't involve
a tool.
"""
- return f"{CHECKMARK_EMOJI} **Complete!**"
+ return "**Complete!**"
class LLMThought:
@@ -147,20 +138,19 @@ def __init__(
expanded: bool,
collapse_on_complete: bool,
):
- self._container = MutableExpander(
- parent_container=parent_container,
- label=labeler.get_initial_label(),
- expanded=expanded,
+ self._container = parent_container.status(
+ labeler.get_initial_label(), expanded=expanded
)
+
self._state = LLMThoughtState.THINKING
self._llm_token_stream = ""
- self._llm_token_writer_idx: Optional[int] = None
+ self._llm_token_stream_placeholder: Optional[DeltaGenerator] = None
self._last_tool: Optional[ToolRecord] = None
self._collapse_on_complete = collapse_on_complete
self._labeler = labeler
@property
- def container(self) -> MutableExpander:
+ def container(self) -> "StatusContainer":
"""The container we're writing into."""
return self._container
@@ -170,8 +160,11 @@ def last_tool(self) -> Optional[ToolRecord]:
return self._last_tool
def _reset_llm_token_stream(self) -> None:
+ if self._llm_token_stream_placeholder is not None:
+ self._llm_token_stream_placeholder.markdown(self._llm_token_stream)
+
self._llm_token_stream = ""
- self._llm_token_writer_idx = None
+ self._llm_token_stream_placeholder = None
def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str]) -> None:
self._reset_llm_token_stream()
@@ -179,9 +172,9 @@ def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str]) -> None:
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
# This is only called when the LLM is initialized with `streaming=True`
self._llm_token_stream += _convert_newlines(token)
- self._llm_token_writer_idx = self._container.markdown(
- self._llm_token_stream, index=self._llm_token_writer_idx
- )
+ if self._llm_token_stream_placeholder is None:
+ self._llm_token_stream_placeholder = self._container.empty()
+ self._llm_token_stream_placeholder.markdown(self._llm_token_stream + "▕")
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
# `response` is the concatenation of all the tokens received by the LLM.
@@ -194,6 +187,7 @@ def on_llm_error(
) -> None:
self._container.markdown("**LLM encountered an error...**")
self._container.exception(error)
+ self._state = LLMThoughtState.ERROR
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
@@ -204,7 +198,8 @@ def on_tool_start(
tool_name = serialized["name"]
self._last_tool = ToolRecord(name=tool_name, input_str=input_str)
self._container.update(
- new_label=self._labeler.get_tool_label(self._last_tool, is_complete=False)
+ label=self._labeler.get_tool_label(self._last_tool, is_complete=False),
+ state="running",
)
def on_tool_end(
@@ -215,13 +210,14 @@ def on_tool_end(
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
- self._container.markdown(f"**{output}**")
+ self._container.markdown(output)
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
self._container.markdown("**Tool encountered an error...**")
self._container.exception(error)
+ self._container.update(state="error")
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
@@ -241,15 +237,21 @@ def complete(self, final_label: Optional[str] = None) -> None:
final_label = self._labeler.get_tool_label(
self._last_tool, is_complete=True
)
- self._state = LLMThoughtState.COMPLETE
+
+ if self._last_tool and self._last_tool.name == "_Exception":
+ self._state = LLMThoughtState.ERROR
+ elif self._state != LLMThoughtState.ERROR:
+ self._state = LLMThoughtState.COMPLETE
+
if self._collapse_on_complete:
- self._container.update(new_label=final_label, new_expanded=False)
- else:
- self._container.update(new_label=final_label)
+ # Add a quick delay to show the user the final output before we collapse
+ time.sleep(0.25)
- def clear(self) -> None:
- """Remove the thought from the screen. A cleared thought can't be reused."""
- self._container.clear()
+ self._container.update(
+ label=final_label,
+ expanded=False if self._collapse_on_complete else None,
+ state="error" if self._state == LLMThoughtState.ERROR else "complete",
+ )
class StreamlitCallbackHandler(BaseCallbackHandler):
@@ -259,8 +261,8 @@ def __init__(
parent_container: DeltaGenerator,
*,
max_thought_containers: int = 4,
- expand_new_thoughts: bool = True,
- collapse_completed_thoughts: bool = True,
+ expand_new_thoughts: bool = False,
+ collapse_completed_thoughts: bool = False,
thought_labeler: Optional[LLMThoughtLabeler] = None,
):
"""Construct a new StreamlitCallbackHandler. This CallbackHandler is geared
@@ -269,26 +271,35 @@ def __init__(
Parameters
----------
+
parent_container
The `st.container` that will contain all the Streamlit elements that the
Handler creates.
+
max_thought_containers
+
+ .. note::
+ This parameter is deprecated and is ignored in the latest version of
+ the callback handler.
+
The max number of completed LLM thought containers to show at once. When
this threshold is reached, a new thought will cause the oldest thoughts to
be collapsed into a "History" expander. Defaults to 4.
+
expand_new_thoughts
Each LLM "thought" gets its own `st.expander`. This param controls whether
- that expander is expanded by default. Defaults to True.
+ that expander is expanded by default. Defaults to False.
+
collapse_completed_thoughts
If True, LLM thought expanders will be collapsed when completed.
- Defaults to True.
+ Defaults to False.
+
thought_labeler
An optional custom LLMThoughtLabeler instance. If unspecified, the handler
will use the default thought labeling logic. Defaults to None.
"""
self._parent_container = parent_container
self._history_parent = parent_container.container()
- self._history_container: Optional[MutableExpander] = None
self._current_thought: Optional[LLMThought] = None
self._completed_thoughts: List[LLMThought] = []
self._max_thought_containers = max(max_thought_containers, 1)
@@ -310,19 +321,6 @@ def _get_last_completed_thought(self) -> Optional[LLMThought]:
return self._completed_thoughts[len(self._completed_thoughts) - 1]
return None
- @property
- def _num_thought_containers(self) -> int:
- """The number of 'thought containers' we're currently showing: the
- number of completed thought containers, the history container (if it exists),
- and the current thought container (if it exists).
- """
- count = len(self._completed_thoughts)
- if self._history_container is not None:
- count += 1
- if self._current_thought is not None:
- count += 1
- return count
-
def _complete_current_thought(self, final_label: Optional[str] = None) -> None:
"""Complete the current thought, optionally assigning it a new label.
Add it to our _completed_thoughts list.
@@ -332,30 +330,6 @@ def _complete_current_thought(self, final_label: Optional[str] = None) -> None:
self._completed_thoughts.append(thought)
self._current_thought = None
- def _prune_old_thought_containers(self) -> None:
- """If we have too many thoughts onscreen, move older thoughts to the
- 'history container.'
- """
- while (
- self._num_thought_containers > self._max_thought_containers
- and len(self._completed_thoughts) > 0
- ):
- # Create our history container if it doesn't exist, and if
- # max_thought_containers is > 1. (if max_thought_containers is 1, we don't
- # have room to show history.)
- if self._history_container is None and self._max_thought_containers > 1:
- self._history_container = MutableExpander(
- self._history_parent,
- label=self._thought_labeler.get_history_label(),
- expanded=False,
- )
-
- oldest_thought = self._completed_thoughts.pop(0)
- if self._history_container is not None:
- self._history_container.markdown(oldest_thought.container.label)
- self._history_container.append_copy(oldest_thought.container)
- oldest_thought.clear()
-
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
@@ -374,23 +348,19 @@ def on_llm_start(
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
self._require_current_thought().on_llm_new_token(token, **kwargs)
- self._prune_old_thought_containers()
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
self._require_current_thought().on_llm_end(response, **kwargs)
- self._prune_old_thought_containers()
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
self._require_current_thought().on_llm_error(error, **kwargs)
- self._prune_old_thought_containers()
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
self._require_current_thought().on_tool_start(serialized, input_str, **kwargs)
- self._prune_old_thought_containers()
def on_tool_end(
self,
@@ -409,13 +379,11 @@ def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
self._require_current_thought().on_tool_error(error, **kwargs)
- self._prune_old_thought_containers()
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
self._require_current_thought().on_agent_action(action, color, **kwargs)
- self._prune_old_thought_containers()
def on_agent_finish(
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
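The `_llm_token_stream_placeholder` change above swaps index-based markdown rewrites for a single `st.empty()` slot that is overwritten on every token. A rough usage sketch of the same pattern as a standalone script, assuming the `st.status` container API this PR moves to:

    import streamlit as st

    status = st.status("Thinking...", expanded=True)
    placeholder = None
    stream = ""
    for token in ["The", " answer", " is", " 42."]:  # stand-in for LLM tokens
        stream += token
        if placeholder is None:
            placeholder = status.empty()
        placeholder.markdown(stream + "▕")  # trailing block acts as a cursor
    placeholder.markdown(stream)  # final rewrite drops the cursor
    status.update(label="**Complete!**", state="complete", expanded=False)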
diff --git a/lib/streamlit/runtime/app_session.py b/lib/streamlit/runtime/app_session.py
index 117b84ad6aa0..e0c80bfd6164 100644
--- a/lib/streamlit/runtime/app_session.py
+++ b/lib/streamlit/runtime/app_session.py
@@ -23,6 +23,7 @@
from streamlit.logger import get_logger
from streamlit.proto.BackMsg_pb2 import BackMsg
from streamlit.proto.ClientState_pb2 import ClientState
+from streamlit.proto.Common_pb2 import FileURLs, FileURLsRequest
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.proto.GitInfo_pb2 import GitInfo
from streamlit.proto.NewSession_pb2 import (
@@ -287,6 +288,8 @@ def handle_backmsg(self, msg: BackMsg) -> None:
self._handle_set_run_on_save_request(msg.set_run_on_save)
elif msg_type == "stop_script":
self._handle_stop_script_request()
+ elif msg_type == "file_urls_request":
+ self._handle_file_urls_request(msg.file_urls_request)
else:
LOGGER.warning('No handler for "%s"', msg_type)
@@ -677,6 +680,10 @@ def _handle_git_information_request(self) -> None:
repository_name, branch, module = repo_info
+ if repository_name.endswith(".git"):
+ # Remove the .git extension from the repository name
+ repository_name = repository_name[:-4]
+
msg.git_info_changed.repository = repository_name
msg.git_info_changed.branch = branch
msg.git_info_changed.module = module
@@ -740,6 +747,26 @@ def _handle_set_run_on_save_request(self, new_value: bool) -> None:
self._run_on_save = new_value
self._enqueue_forward_msg(self._create_session_status_changed_message())
+ def _handle_file_urls_request(self, file_urls_request: FileURLsRequest) -> None:
+ """Handle a file_urls_request BackMsg sent by the client."""
+ msg = ForwardMsg()
+ msg.file_urls_response.response_id = file_urls_request.request_id
+
+ upload_url_infos = self._uploaded_file_mgr.get_upload_urls(
+ self.id, file_urls_request.file_names
+ )
+
+ for upload_url_info in upload_url_infos:
+ msg.file_urls_response.file_urls.append(
+ FileURLs(
+ file_id=upload_url_info.file_id,
+ upload_url=upload_url_info.upload_url,
+ delete_url=upload_url_info.delete_url,
+ )
+ )
+
+ self._enqueue_forward_msg(msg)
+
# Config.ToolbarMode.ValueType does not exist at runtime (only in the pyi stubs), so
# we need to use quotes.
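A sketch of the round trip `_handle_file_urls_request` implements, using only the proto fields visible in this diff (URLs stubbed; in the real flow they come from the uploaded file manager):

    from streamlit.proto.BackMsg_pb2 import BackMsg
    from streamlit.proto.Common_pb2 import FileURLs
    from streamlit.proto.ForwardMsg_pb2 import ForwardMsg

    # The client asks for upload URLs for a set of file names...
    back_msg = BackMsg()
    back_msg.file_urls_request.request_id = "req-1"
    back_msg.file_urls_request.file_names.extend(["a.csv", "b.png"])

    # ...and the session answers with a matching response_id plus one
    # FileURLs entry per requested file.
    msg = ForwardMsg()
    msg.file_urls_response.response_id = back_msg.file_urls_request.request_id
    for name in back_msg.file_urls_request.file_names:
        msg.file_urls_response.file_urls.append(
            FileURLs(file_id=name, upload_url=f"/up/{name}", delete_url=f"/up/{name}")
        )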
diff --git a/lib/streamlit/runtime/forward_msg_cache.py b/lib/streamlit/runtime/forward_msg_cache.py
index a18dde967912..548a68e50310 100644
--- a/lib/streamlit/runtime/forward_msg_cache.py
+++ b/lib/streamlit/runtime/forward_msg_cache.py
@@ -13,6 +13,7 @@
# limitations under the License.
import hashlib
+import sys
from typing import TYPE_CHECKING, Dict, List, MutableMapping, Optional
from weakref import WeakKeyDictionary
@@ -51,7 +52,10 @@ def populate_hash_if_needed(msg: ForwardMsg) -> str:
msg.ClearField("metadata")
# MD5 is good enough for what we need, which is uniqueness.
- hasher = hashlib.md5()
+ if sys.version_info >= (3, 9):
+ hasher = hashlib.md5(usedforsecurity=False)
+ else:
+ hasher = hashlib.md5()
hasher.update(msg.SerializeToString())
msg.hash = hasher.hexdigest()
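For context on the change above: `usedforsecurity=False` declares the digest non-cryptographic, which keeps MD5 usable on FIPS-restricted OpenSSL builds; the keyword only exists on Python 3.9+, hence the version gate. The same helper in isolation:

    import hashlib
    import sys

    def content_hash(payload: bytes) -> str:
        # Non-cryptographic hash, used only for uniqueness/deduplication.
        if sys.version_info >= (3, 9):
            hasher = hashlib.md5(usedforsecurity=False)
        else:
            hasher = hashlib.md5()
        hasher.update(payload)
        return hasher.hexdigest()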
diff --git a/lib/streamlit/runtime/memory_uploaded_file_manager.py b/lib/streamlit/runtime/memory_uploaded_file_manager.py
new file mode 100644
index 000000000000..6a12f5162da0
--- /dev/null
+++ b/lib/streamlit/runtime/memory_uploaded_file_manager.py
@@ -0,0 +1,135 @@
+# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import uuid
+from collections import defaultdict
+from typing import Dict, List, Sequence
+
+from streamlit import util
+from streamlit.logger import get_logger
+from streamlit.runtime.stats import CacheStat
+from streamlit.runtime.uploaded_file_manager import (
+ UploadedFileManager,
+ UploadedFileRec,
+ UploadFileUrlInfo,
+)
+
+LOGGER = get_logger(__name__)
+
+
+class MemoryUploadedFileManager(UploadedFileManager):
+ """Holds files uploaded by users of the running Streamlit app.
+ This class can be used safely from multiple threads simultaneously.
+ """
+
+ def __init__(self, upload_endpoint: str):
+ self.file_storage: Dict[str, Dict[str, UploadedFileRec]] = defaultdict(dict)
+ self.endpoint = upload_endpoint
+
+ def get_files(
+ self, session_id: str, file_ids: Sequence[str]
+ ) -> List[UploadedFileRec]:
+ """Return a list of UploadedFileRec for a given sequence of file_ids.
+
+ Parameters
+ ----------
+ session_id
+ The ID of the session that owns the files.
+ file_ids
+ The sequence of ids associated with files to retrieve.
+
+ Returns
+ -------
+ List[UploadedFileRec]
+ A list of UploadedFileRec instances; each instance contains information
+ about an uploaded file.
+ """
+ session_storage = self.file_storage[session_id]
+ file_recs = []
+
+ for file_id in file_ids:
+ file_rec = session_storage.get(file_id, None)
+ if file_rec is not None:
+ file_recs.append(file_rec)
+
+ return file_recs
+
+ def remove_session_files(self, session_id: str) -> None:
+ """Remove all files associated with a given session."""
+ self.file_storage.pop(session_id, None)
+
+ def __repr__(self) -> str:
+ return util.repr_(self)
+
+ def add_file(
+ self,
+ session_id: str,
+ file: UploadedFileRec,
+ ) -> None:
+ """
+ Safe to call from any thread.
+
+ Parameters
+ ----------
+ session_id
+ The ID of the session that owns the file.
+ file
+ The file to add.
+ """
+
+ self.file_storage[session_id][file.file_id] = file
+
+ def remove_file(self, session_id, file_id):
+ """Remove file with given file_id associated with a given session."""
+ session_storage = self.file_storage[session_id]
+ session_storage.pop(file_id, None)
+
+ def get_upload_urls(
+ self, session_id: str, file_names: Sequence[str]
+ ) -> List[UploadFileUrlInfo]:
+ """Return a list of UploadFileUrlInfo for a given sequence of file_names."""
+ result = []
+ for _ in file_names:
+ file_id = str(uuid.uuid4())
+ result.append(
+ UploadFileUrlInfo(
+ file_id=file_id,
+ upload_url=f"{self.endpoint}/{session_id}/{file_id}",
+ delete_url=f"{self.endpoint}/{session_id}/{file_id}",
+ )
+ )
+ return result
+
+ def get_stats(self) -> List[CacheStat]:
+ """Return the manager's CacheStats.
+
+ Safe to call from any thread.
+ """
+ # Flatten all files into a single list
+ all_files: List[UploadedFileRec] = []
+ # Copy self.file_storage for thread safety, to be sure the main
+ # storage won't be changed from another thread while we iterate.
+ file_storage_copy = self.file_storage.copy()
+
+ for session_storage in file_storage_copy.values():
+ all_files.extend(session_storage.values())
+
+ return [
+ CacheStat(
+ category_name="UploadedFileManager",
+ cache_name="",
+ byte_length=len(file.data),
+ )
+ for file in all_files
+ ]
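A quick usage sketch of the new manager; the endpoint string is illustrative:

    from streamlit.runtime.memory_uploaded_file_manager import MemoryUploadedFileManager
    from streamlit.runtime.uploaded_file_manager import UploadedFileRec

    mgr = MemoryUploadedFileManager("/_upload")  # endpoint path assumed
    [info] = mgr.get_upload_urls("session-1", ["report.csv"])
    # The frontend uploads the bytes to info.upload_url; the server-side
    # handler then records the file under the same file_id:
    mgr.add_file(
        "session-1",
        UploadedFileRec(
            file_id=info.file_id, name="report.csv", type="text/csv", data=b"a,b\n1,2\n"
        ),
    )
    assert mgr.get_files("session-1", [info.file_id])[0].name == "report.csv"
    mgr.remove_session_files("session-1")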
diff --git a/lib/streamlit/runtime/metrics_util.py b/lib/streamlit/runtime/metrics_util.py
index 210c0620438e..2b2962ac84ad 100644
--- a/lib/streamlit/runtime/metrics_util.py
+++ b/lib/streamlit/runtime/metrics_util.py
@@ -72,7 +72,6 @@
"pymssql",
"cassandra",
"azure",
- "google",
"redis",
"sqlite3",
"neo4j",
@@ -319,8 +318,9 @@ def wrapper(f: F) -> F:
@wraps(non_optional_func)
def wrapped_func(*args, **kwargs):
exec_start = timer()
- # get_script_run_ctx gets imported here to prevent circular dependencies
+ # Local imports to prevent circular dependencies
from streamlit.runtime.scriptrunner import get_script_run_ctx
+ from streamlit.runtime.scriptrunner.script_runner import RerunException
ctx = get_script_run_ctx(suppress_warning=True)
@@ -331,6 +331,8 @@ def wrapped_func(*args, **kwargs):
and len(ctx.tracked_commands)
< _MAX_TRACKED_COMMANDS # Prevent too much memory usage
)
+
+ deferred_exception: Optional[RerunException] = None
command_telemetry: Optional[Command] = None
if ctx and tracking_activated:
@@ -354,6 +356,12 @@ def wrapped_func(*args, **kwargs):
_LOGGER.debug("Failed to collect command telemetry", exc_info=ex)
try:
result = non_optional_func(*args, **kwargs)
+ except RerunException as ex:
+ # Duplicated from below, because static analysis tools get confused
+ # by deferring the rethrow.
+ if tracking_activated and command_telemetry:
+ command_telemetry.time = to_microseconds(timer() - exec_start)
+ raise ex
finally:
# Activate tracking again if command executes without any exceptions
if ctx:
@@ -362,6 +370,7 @@ def wrapped_func(*args, **kwargs):
if tracking_activated and command_telemetry:
# Set the execution time to the measured value
command_telemetry.time = to_microseconds(timer() - exec_start)
+
return result
with contextlib.suppress(AttributeError):
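The RerunException branch above duplicates the timing write because deferring the re-raise confuses static analysis. A minimal decorator sketch of the same record-then-reraise pattern (RerunException stubbed; this is not the real gather_metrics):

    import time
    from functools import wraps

    class RerunException(Exception):
        """Stand-in for streamlit.runtime.scriptrunner.script_runner.RerunException."""

    def timed(record):
        def decorator(fn):
            @wraps(fn)
            def wrapped(*args, **kwargs):
                start = time.perf_counter()
                try:
                    result = fn(*args, **kwargs)
                except RerunException:
                    # Record elapsed time before re-raising so reruns
                    # still produce a timing measurement.
                    record(time.perf_counter() - start)
                    raise
                record(time.perf_counter() - start)
                return result
            return wrapped
        return decorator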
diff --git a/lib/streamlit/runtime/runtime.py b/lib/streamlit/runtime/runtime.py
index 927acaaea8bc..faa1e074b134 100644
--- a/lib/streamlit/runtime/runtime.py
+++ b/lib/streamlit/runtime/runtime.py
@@ -15,7 +15,6 @@
from __future__ import annotations
import asyncio
-import sys
import time
import traceback
from dataclasses import dataclass, field
@@ -91,6 +90,9 @@ class RuntimeConfig:
# The storage backend for Streamlit's MediaFileManager.
media_file_storage: MediaFileStorage
+ # The upload file manager
+ uploaded_file_manager: UploadedFileManager
+
# The cache storage backend for Streamlit's st.cache_data.
cache_storage_manager: CacheStorageManager = field(
default_factory=LocalDiskCacheStorageManager
@@ -187,8 +189,7 @@ def __init__(self, config: RuntimeConfig):
# Initialize managers
self._message_cache = ForwardMsgCache()
- self._uploaded_file_mgr = UploadedFileManager()
- self._uploaded_file_mgr.on_files_updated.connect(self._on_files_updated)
+ self._uploaded_file_mgr = config.uploaded_file_manager
self._media_file_mgr = MediaFileManager(storage=config.media_file_storage)
self._cache_storage_manager = config.cache_storage_manager
self._script_cache = ScriptCache()
@@ -237,8 +238,8 @@ def stopped(self) -> Awaitable[None]:
"""A Future that completes when the Runtime's run loop has exited."""
return self._get_async_objs().stopped
- # NOTE: A few Runtime methods listed as threadsafe (get_client, _on_files_updated,
- # and is_active_session) currently rely on the implementation detail that
+ # NOTE: A few Runtime methods listed as threadsafe (get_client and
+ # is_active_session) currently rely on the implementation detail that
# WebsocketSessionManager's get_active_session_info and is_active_session methods
# happen to be threadsafe. This may change with future SessionManager implementations,
# at which point we'll need to formalize our thread safety rules for each
@@ -256,19 +257,6 @@ def get_client(self, session_id: str) -> Optional[SessionClient]:
return None
return session_info.client
- def _on_files_updated(self, session_id: str) -> None:
- """Event handler for UploadedFileManager.on_file_added.
- Ensures that uploaded files from stale sessions get deleted.
-
- Notes
- -----
- Threading: SAFE. May be called on any thread.
- """
- if not self._session_mgr.is_active_session(session_id):
- # If an uploaded file doesn't belong to an active session,
- # remove it so it doesn't stick around forever.
- self._uploaded_file_mgr.remove_session_files(session_id)
-
async def start(self) -> None:
"""Start the runtime. This must be called only once, before
any other functions are called.
@@ -292,14 +280,9 @@ async def start(self) -> None:
)
self._async_objs = async_objs
- if sys.version_info >= (3, 8, 0):
- # Python 3.8+ supports a create_task `name` parameter, which can
- # make debugging a bit easier.
- self._loop_coroutine_task = asyncio.create_task(
- self._loop_coroutine(), name="Runtime.loop_coroutine"
- )
- else:
- self._loop_coroutine_task = asyncio.create_task(self._loop_coroutine())
+ self._loop_coroutine_task = asyncio.create_task(
+ self._loop_coroutine(), name="Runtime.loop_coroutine"
+ )
await async_objs.started
diff --git a/lib/streamlit/runtime/scriptrunner/magic.py b/lib/streamlit/runtime/scriptrunner/magic.py
index 8c89cfd5b492..facb4bd33a6c 100644
--- a/lib/streamlit/runtime/scriptrunner/magic.py
+++ b/lib/streamlit/runtime/scriptrunner/magic.py
@@ -13,7 +13,6 @@
# limitations under the License.
import ast
-import sys
from typing_extensions import Final
@@ -157,7 +156,7 @@ def _get_st_write_from_expr(node, i, parent_type):
if (
i == 0
and _is_docstring_node(node.value)
- and parent_type in (ast.FunctionDef, ast.Module)
+ and parent_type in (ast.FunctionDef, ast.AsyncFunctionDef, ast.Module)
):
return None
@@ -196,7 +195,4 @@ def _get_st_write_from_expr(node, i, parent_type):
def _is_docstring_node(node):
- if sys.version_info >= (3, 8, 0):
- return type(node) is ast.Constant and type(node.value) is str
- else:
- return type(node) is ast.Str
+ return type(node) is ast.Constant and type(node.value) is str
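With Python 3.7 support dropped, docstrings always parse as `ast.Constant` nodes wrapping a `str` (`ast.Str` was the pre-3.8 representation), and the earlier hunk also starts recognizing docstrings in `async def` bodies. A quick standalone check:

    import ast

    tree = ast.parse('"""module doc"""\n\nasync def f():\n    """fn doc"""\n')
    mod_doc = tree.body[0].value
    fn = tree.body[1]
    fn_doc = fn.body[0].value

    assert type(mod_doc) is ast.Constant and type(mod_doc.value) is str
    assert isinstance(fn, ast.AsyncFunctionDef)  # now also skipped by magic
    assert type(fn_doc) is ast.Constant and type(fn_doc.value) is str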
diff --git a/lib/streamlit/runtime/state/common.py b/lib/streamlit/runtime/state/common.py
index 7f02a2732836..57f541b943ca 100644
--- a/lib/streamlit/runtime/state/common.py
+++ b/lib/streamlit/runtime/state/common.py
@@ -17,8 +17,21 @@
import hashlib
from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeVar, Union
-
+from datetime import date, datetime, time, timedelta
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Generic,
+ Optional,
+ Sequence,
+ Tuple,
+ TypeVar,
+ Union,
+)
+
+from google.protobuf.message import Message
from typing_extensions import Final, TypeAlias
from streamlit import util
@@ -43,6 +56,10 @@
from streamlit.proto.TimeInput_pb2 import TimeInput
from streamlit.type_util import ValueFieldName
+if TYPE_CHECKING:
+ from streamlit.runtime.state.widgets import NoValue
+
+
# Protobuf types for all widgets.
WidgetProto: TypeAlias = Union[
Arrow,
@@ -138,23 +155,37 @@ def failure(
return cls(value=deserializer(None, ""), value_changed=False)
+PROTO_SCALAR_VALUE = Union[float, int, bool, str, bytes]
+SAFE_VALUES = Union[
+ date, time, datetime, timedelta, None, "NoValue", Message, PROTO_SCALAR_VALUE
+]
+
+
def compute_widget_id(
- element_type: str, element_proto: WidgetProto, user_key: Optional[str] = None
+ element_type: str,
+ user_key: str | None = None,
+ **kwargs: SAFE_VALUES | Sequence[SAFE_VALUES],
) -> str:
"""Compute the widget id for the given widget. This id is stable: a given
set of inputs to this function will always produce the same widget id output.
+ Only stable, deterministic values should be used to compute widget ids. Using
+ nondeterministic values as inputs can cause the resulting widget id to
+ change between runs.
+
The widget id includes the user_key so widgets with identical arguments can
use it to be distinct.
The widget id includes an easily identified prefix, and the user_key as a
suffix, to make it easy to identify it and know if a key maps to it.
-
- Does not mutate the element_proto object.
"""
h = hashlib.new("md5")
h.update(element_type.encode("utf-8"))
- h.update(element_proto.SerializeToString())
+ # This will iterate in a consistent order when the provided arguments have
+ # consistent order; dicts are always in insertion order.
+ for k, v in kwargs.items():
+ h.update(str(k).encode("utf-8"))
+ h.update(str(v).encode("utf-8"))
return f"{GENERATED_WIDGET_ID_PREFIX}-{h.hexdigest()}-{user_key}"
diff --git a/lib/streamlit/runtime/state/session_state.py b/lib/streamlit/runtime/state/session_state.py
index 74d6bf4f2b20..22775d2cb8d2 100644
--- a/lib/streamlit/runtime/state/session_state.py
+++ b/lib/streamlit/runtime/state/session_state.py
@@ -179,6 +179,8 @@ def set_widget_metadata(self, widget_meta: WidgetMetadata[Any]) -> None:
def remove_stale_widgets(self, active_widget_ids: set[str]) -> None:
"""Remove widget state for widgets whose ids aren't in `active_widget_ids`."""
+ # TODO(vdonato / kajarenc): Remove files corresponding to an inactive file
+ # uploader.
self.states = {k: v for k, v in self.states.items() if k in active_widget_ids}
def get_serialized(self, k: str) -> WidgetStateProto | None:
diff --git a/lib/streamlit/runtime/state/widgets.py b/lib/streamlit/runtime/state/widgets.py
index 2b051d6987da..cf305bbf5b1f 100644
--- a/lib/streamlit/runtime/state/widgets.py
+++ b/lib/streamlit/runtime/state/widgets.py
@@ -30,7 +30,6 @@
WidgetMetadata,
WidgetProto,
WidgetSerializer,
- compute_widget_id,
user_key_from_widget_id,
)
from streamlit.type_util import ValueFieldName
@@ -38,6 +37,7 @@
if TYPE_CHECKING:
from streamlit.runtime.scriptrunner import ScriptRunContext
+
ElementType: TypeAlias = str
# NOTE: We use this table to start with a best-effort guess for the value_type
@@ -148,12 +148,9 @@ def register_widget(
For both paths a widget return value is provided, allowing the widgets
to be used in a non-streamlit setting.
"""
- widget_id = compute_widget_id(element_type, element_proto, user_key)
- element_proto.id = widget_id
-
# Create the widget's updated metadata, and register it with session_state.
metadata = WidgetMetadata(
- widget_id,
+ element_proto.id,
deserializer,
serializer,
value_type=ELEMENT_TYPE_TO_VALUE_TYPE[element_type],
diff --git a/lib/streamlit/runtime/uploaded_file_manager.py b/lib/streamlit/runtime/uploaded_file_manager.py
index e1c1072371ee..4946869b2cdc 100644
--- a/lib/streamlit/runtime/uploaded_file_manager.py
+++ b/lib/streamlit/runtime/uploaded_file_manager.py
@@ -11,29 +11,47 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import io
-import threading
-from typing import Dict, List, NamedTuple, Tuple
+from abc import abstractmethod
+from typing import List, NamedTuple, Sequence
-from blinker import Signal
+from typing_extensions import Protocol
from streamlit import util
-from streamlit.logger import get_logger
-from streamlit.runtime.stats import CacheStat, CacheStatsProvider
-
-LOGGER = get_logger(__name__)
+from streamlit.proto.Common_pb2 import FileURLs as FileURLsProto
+from streamlit.runtime.stats import CacheStatsProvider
class UploadedFileRec(NamedTuple):
"""Metadata and raw bytes for an uploaded file. Immutable."""
- id: int
+ file_id: str
name: str
type: str
data: bytes
+class UploadFileUrlInfo(NamedTuple):
+ """Information we provide for single file in get_upload_urls"""
+
+ file_id: str
+ upload_url: str
+ delete_url: str
+
+
+class DeletedFile(NamedTuple):
+ """Represents a deleted file in deserialized values for st.file_uploader and
+ st.camera_input
+
+ Return this from st.file_uploader and st.camera_input deserialize (so they can
+ be used in session_state), when widget value contains file record that is missing
+ from the storage.
+ DeleteFile instances filtered out before return final value to the user in script,
+ or before sending to frontend."""
+
+ file_id: str
+
+
class UploadedFile(io.BytesIO):
"""A mutable uploaded file.
@@ -41,294 +59,85 @@ class UploadedFile(io.BytesIO):
initialized with `bytes`.
"""
- def __init__(self, record: UploadedFileRec):
+ def __init__(self, record: UploadedFileRec, file_urls: FileURLsProto):
# BytesIO's copy-on-write semantics doesn't seem to be mentioned in
# the Python docs - possibly because it's a CPython-only optimization
# and not guaranteed to be in other Python runtimes. But it's detailed
# here: https://hg.python.org/cpython/rev/79a5fbe2c78f
- super(UploadedFile, self).__init__(record.data)
- self.id = record.id
+ super().__init__(record.data)
+ self.file_id = record.file_id
self.name = record.name
self.type = record.type
self.size = len(record.data)
+ self._file_urls = file_urls
def __eq__(self, other: object) -> bool:
if not isinstance(other, UploadedFile):
return NotImplemented
- return self.id == other.id
-
- def __repr__(self) -> str:
- return util.repr_(self)
-
-
-class UploadedFileManager(CacheStatsProvider):
- """Holds files uploaded by users of the running Streamlit app,
- and emits an event signal when a file is added.
-
- This class can be used safely from multiple threads simultaneously.
- """
-
- def __init__(self):
- # List of files for a given widget in a given session.
- self._files_by_id: Dict[Tuple[str, str], List[UploadedFileRec]] = {}
-
- # A counter that generates unique file IDs. Each file ID is greater
- # than the previous ID, which means we can use IDs to compare files
- # by age.
- self._file_id_counter = 1
- self._file_id_lock = threading.Lock()
-
- # Prevents concurrent access to the _files_by_id dict.
- # In remove_session_files(), we iterate over the dict's keys. It's
- # an error to mutate a dict while iterating; this lock prevents that.
- self._files_lock = threading.Lock()
- self.on_files_updated = Signal(
- doc="""Emitted when a file list is added to the manager or updated.
-
- Parameters
- ----------
- session_id : str
- The session_id for the session whose files were updated.
- """
- )
+ return self.file_id == other.file_id
def __repr__(self) -> str:
return util.repr_(self)
- def add_file(
- self,
- session_id: str,
- widget_id: str,
- file: UploadedFileRec,
- ) -> UploadedFileRec:
- """Add a file to the FileManager, and return a new UploadedFileRec
- with its ID assigned.
-
- The "on_files_updated" Signal will be emitted.
-
- Safe to call from any thread.
-
- Parameters
- ----------
- session_id
- The ID of the session that owns the file.
- widget_id
- The widget ID of the FileUploader that created the file.
- file
- The file to add.
- Returns
- -------
- UploadedFileRec
- The added file, which has its unique ID assigned.
- """
- files_by_widget = session_id, widget_id
+class UploadedFileManager(CacheStatsProvider, Protocol):
+ """UploadedFileManager protocol, that should be implemented by the concrete
+ uploaded file managers.
- # Assign the file a unique ID
- file_id = self._get_next_file_id()
- file = UploadedFileRec(
- id=file_id, name=file.name, type=file.type, data=file.data
- )
+ It is responsible for:
+ - retrieving files by session_id and file_id for st.file_uploader and
+ st.camera_input
+ - cleaning up uploaded files associated with a session when the session ends
- with self._files_lock:
- file_list = self._files_by_id.get(files_by_widget, None)
- if file_list is not None:
- file_list.append(file)
- else:
- self._files_by_id[files_by_widget] = [file]
+ It should be created during Runtime initialization.
- self.on_files_updated.send(session_id)
- return file
-
- def get_all_files(self, session_id: str, widget_id: str) -> List[UploadedFileRec]:
- """Return all the files stored for the given widget.
-
- Safe to call from any thread.
-
- Parameters
- ----------
- session_id
- The ID of the session that owns the files.
- widget_id
- The widget ID of the FileUploader that created the files.
- """
- file_list_id = (session_id, widget_id)
- with self._files_lock:
- return self._files_by_id.get(file_list_id, []).copy()
+ Optionally, the UploadedFileManager may also be responsible for issuing the URLs
+ that the frontend uploads files to.
+ """
+ @abstractmethod
def get_files(
- self, session_id: str, widget_id: str, file_ids: List[int]
+ self, session_id: str, file_ids: Sequence[str]
) -> List[UploadedFileRec]:
- """Return the files with the given widget_id and file_ids.
-
- Safe to call from any thread.
+ """Return a list of UploadedFileRec for a given sequence of file_ids.
Parameters
----------
session_id
The ID of the session that owns the files.
- widget_id
- The widget ID of the FileUploader that created the files.
file_ids
- List of file IDs. Only files whose IDs are in this list will be
- returned.
- """
- return [
- f for f in self.get_all_files(session_id, widget_id) if f.id in file_ids
- ]
-
- def remove_orphaned_files(
- self,
- session_id: str,
- widget_id: str,
- newest_file_id: int,
- active_file_ids: List[int],
- ) -> None:
- """Remove 'orphaned' files: files that have been uploaded and
- subsequently deleted, but haven't yet been removed from memory.
-
- Because FileUploader can live inside forms, file deletion is made a
- bit tricky: a file deletion should only happen after the form is
- submitted.
-
- FileUploader's widget value is an array of numbers that has two parts:
- - The first number is always 'this.state.newestServerFileId'.
- - The remaining 0 or more numbers are the file IDs of all the
- uploader's uploaded files.
-
- When the server receives the widget value, it deletes "orphaned"
- uploaded files. An orphaned file is any file associated with a given
- FileUploader whose file ID is not in the active_file_ids, and whose
- ID is <= `newestServerFileId`.
-
- This logic ensures that a FileUploader within a form doesn't have any
- of its "unsubmitted" uploads prematurely deleted when the script is
- re-run.
-
- Safe to call from any thread.
- """
- file_list_id = (session_id, widget_id)
- with self._files_lock:
- file_list = self._files_by_id.get(file_list_id)
- if file_list is None:
- return
-
- # Remove orphaned files from the list:
- # - `f.id in active_file_ids`:
- # File is currently tracked by the widget. DON'T remove.
- # - `f.id > newest_file_id`:
- # file was uploaded *after* the widget was most recently
- # updated. (It's probably in a form.) DON'T remove.
- # - `f.id < newest_file_id and f.id not in active_file_ids`:
- # File is not currently tracked by the widget, and was uploaded
- # *before* this most recent update. This means it's been deleted
- # by the user on the frontend, and is now "orphaned". Remove!
- new_list = [
- f for f in file_list if f.id > newest_file_id or f.id in active_file_ids
- ]
- self._files_by_id[file_list_id] = new_list
- num_removed = len(file_list) - len(new_list)
-
- if num_removed > 0:
- LOGGER.debug("Removed %s orphaned files" % num_removed)
-
- def remove_file(self, session_id: str, widget_id: str, file_id: int) -> bool:
- """Remove the file list with the given ID, if it exists.
-
- The "on_files_updated" Signal will be emitted.
-
- Safe to call from any thread.
+ The sequence of ids associated with files to retrieve.
Returns
-------
- bool
- True if the file was removed, or False if no such file exists.
+ List[UploadedFileRec]
+ A list of UploadedFileRec instances; each instance contains information
+ about an uploaded file.
"""
- file_list_id = (session_id, widget_id)
- with self._files_lock:
- file_list = self._files_by_id.get(file_list_id, None)
- if file_list is None:
- return False
-
- # Remove the file from its list.
- new_file_list = [file for file in file_list if file.id != file_id]
- self._files_by_id[file_list_id] = new_file_list
-
- self.on_files_updated.send(session_id)
- return True
-
- def _remove_files(self, session_id: str, widget_id: str) -> None:
- """Remove the file list for the provided widget in the
- provided session, if it exists.
-
- Does not emit any signals.
-
- Safe to call from any thread.
- """
- files_by_widget = session_id, widget_id
- with self._files_lock:
- self._files_by_id.pop(files_by_widget, None)
-
- def remove_files(self, session_id: str, widget_id: str) -> None:
- """Remove the file list for the provided widget in the
- provided session, if it exists.
-
- The "on_files_updated" Signal will be emitted.
-
- Safe to call from any thread.
-
- Parameters
- ----------
- session_id : str
- The ID of the session that owns the files.
- widget_id : str
- The widget ID of the FileUploader that created the files.
- """
- self._remove_files(session_id, widget_id)
- self.on_files_updated.send(session_id)
+ raise NotImplementedError
+ @abstractmethod
def remove_session_files(self, session_id: str) -> None:
- """Remove all files that belong to the given session.
+ """Remove all files associated with a given session."""
+ raise NotImplementedError
- Safe to call from any thread.
+ def get_upload_urls(
+ self, session_id: str, file_names: Sequence[str]
+ ) -> List[UploadFileUrlInfo]:
+ """Return a list of UploadFileUrlInfo for a given sequence of file_names.
+ Optional to implement; issuing URLs may be handled by another service.
Parameters
----------
- session_id : str
- The ID of the session whose files we're removing.
-
- """
- # Copy the keys into a list, because we'll be mutating the dictionary.
- with self._files_lock:
- all_ids = list(self._files_by_id.keys())
-
- for files_id in all_ids:
- if files_id[0] == session_id:
- self.remove_files(*files_id)
-
- def _get_next_file_id(self) -> int:
- """Return the next file ID and increment our ID counter."""
- with self._file_id_lock:
- file_id = self._file_id_counter
- self._file_id_counter += 1
- return file_id
-
- def get_stats(self) -> List[CacheStat]:
- """Return the manager's CacheStats.
+ session_id
+ The ID of the session that requests URLs.
+ file_names
+ The sequence of file names for which URLs are requested.
- Safe to call from any thread.
+ Returns
+ -------
+ List[UploadFileUrlInfo]
+ A list of UploadFileUrlInfo instances; each instance contains URL
+ information for one requested file.
"""
- with self._files_lock:
- # Flatten all files into a single list
- all_files: List[UploadedFileRec] = []
- for file_list in self._files_by_id.values():
- all_files.extend(file_list)
-
- return [
- CacheStat(
- category_name="UploadedFileManager",
- cache_name="",
- byte_length=len(file.data),
- )
- for file in all_files
- ]
+ raise NotImplementedError
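For orientation, here is a minimal sketch of how a concrete in-memory manager might satisfy `get_upload_urls`. It assumes `UploadFileUrlInfo` is a simple record of `file_id` plus upload/delete URLs (its definition is not shown in this diff) and that both URLs share the `endpoint/session_id/file_id` shape used by the server changes below; treat it as an illustration, not the shipped implementation.

```python
import uuid
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class UploadFileUrlInfo:
    """Stand-in for the runtime's UploadFileUrlInfo (assumed fields)."""

    file_id: str
    upload_url: str
    delete_url: str


def get_upload_urls(
    endpoint: str, session_id: str, file_names: Sequence[str]
) -> List[UploadFileUrlInfo]:
    urls = []
    for _ in file_names:
        # A server-generated unique ID keeps file names out of the URL.
        file_id = str(uuid.uuid4())
        url = f"{endpoint}/{session_id}/{file_id}"
        urls.append(UploadFileUrlInfo(file_id=file_id, upload_url=url, delete_url=url))
    return urls
```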
diff --git a/lib/streamlit/testing/local_script_runner.py b/lib/streamlit/testing/local_script_runner.py
index 6a6f01698b31..109d7b403a3a 100644
--- a/lib/streamlit/testing/local_script_runner.py
+++ b/lib/streamlit/testing/local_script_runner.py
@@ -22,10 +22,10 @@
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.proto.WidgetStates_pb2 import WidgetStates
from streamlit.runtime.forward_msg_queue import ForwardMsgQueue
+from streamlit.runtime.memory_uploaded_file_manager import MemoryUploadedFileManager
from streamlit.runtime.scriptrunner import RerunData, ScriptRunner, ScriptRunnerEvent
from streamlit.runtime.scriptrunner.script_cache import ScriptCache
from streamlit.runtime.state.session_state import SessionState
-from streamlit.runtime.uploaded_file_manager import UploadedFileManager
from streamlit.testing.element_tree import ElementTree, parse_tree_from_messages
@@ -53,7 +53,7 @@ def __init__(
main_script_path=script_path,
client_state=ClientState(),
session_state=self.session_state,
- uploaded_file_mgr=UploadedFileManager(),
+ uploaded_file_mgr=MemoryUploadedFileManager("/mock/upload"),
script_cache=ScriptCache(),
initial_rerun_data=RerunData(),
user_info={"email": "test@test.com"},
diff --git a/lib/streamlit/type_util.py b/lib/streamlit/type_util.py
index c95543aa0136..a1298ad139ee 100644
--- a/lib/streamlit/type_util.py
+++ b/lib/streamlit/type_util.py
@@ -17,6 +17,7 @@
from __future__ import annotations
import contextlib
+import copy
import re
import types
from enum import Enum, auto
@@ -53,6 +54,7 @@
import sympy
from pandas.core.indexing import _iLocIndexer
from pandas.io.formats.style import Styler
+ from pandas.io.formats.style_render import StylerRenderer
from plotly.graph_objs import Figure
from pydeck import Deck
@@ -480,11 +482,31 @@ def is_sequence(seq: Any) -> bool:
return True
+@overload
def convert_anything_to_df(
data: Any,
max_unevaluated_rows: int = MAX_UNEVALUATED_DF_ROWS,
ensure_copy: bool = False,
) -> DataFrame:
+ ...
+
+
+@overload
+def convert_anything_to_df(
+ data: Any,
+ max_unevaluated_rows: int = MAX_UNEVALUATED_DF_ROWS,
+ ensure_copy: bool = False,
+ allow_styler: bool = False,
+) -> Union[DataFrame, "Styler"]:
+ ...
+
+
+def convert_anything_to_df(
+ data: Any,
+ max_unevaluated_rows: int = MAX_UNEVALUATED_DF_ROWS,
+ ensure_copy: bool = False,
+ allow_styler: bool = False,
+) -> Union[DataFrame, "Styler"]:
"""Try to convert different formats to a Pandas Dataframe.
Parameters
@@ -499,16 +521,33 @@ def convert_anything_to_df(
If True, make sure to always return a copy of the data. If False, it depends on the
type of the data. For example, a Pandas DataFrame will be returned as-is.
+ allow_styler: bool
+ If True, allows this to return a Pandas Styler object as well. If False, returns
+ a plain Pandas DataFrame (which, of course, won't contain the Styler's styles).
+
Returns
-------
- pandas.DataFrame
+ pandas.DataFrame or pandas.Styler
"""
if is_type(data, _PANDAS_DF_TYPE_STR):
- return data.copy() if ensure_copy else data
+ return data.copy() if ensure_copy else cast(DataFrame, data)
if is_pandas_styler(data):
- return data.data.copy() if ensure_copy else data.data
+ # Every Styler is a StylerRenderer. I'm casting to StylerRenderer here rather than to
+ # the more correct Styler because mypy doesn't like when we cast to Styler. It complains
+ # .data doesn't exist, when it does in fact exist in the parent class StylerRenderer!
+ sr = cast("StylerRenderer", data)
+
+ if allow_styler:
+ if ensure_copy:
+ out = copy.deepcopy(sr)
+ out.data = sr.data.copy()
+ return cast("Styler", out)
+ else:
+ return data
+ else:
+ return cast("Styler", sr.data.copy() if ensure_copy else sr.data)
if is_type(data, "numpy.ndarray"):
if len(data.shape) == 0:
@@ -529,13 +568,13 @@ def convert_anything_to_df(
f"⚠️ Showing only {string_util.simplify_number(max_unevaluated_rows)} rows. "
"Call `collect()` on the dataframe to show more."
)
- return data
+ return cast(DataFrame, data)
# This is inefficient when data is a pyarrow.Table as it will be converted
# back to Arrow when marshalled to protobuf, but area/bar/line charts need
# DataFrame magic to generate the correct output.
if hasattr(data, "to_pandas"):
- return data.to_pandas()
+ return cast(DataFrame, data.to_pandas())
# Try to convert to pandas.DataFrame. This will raise an error if df is not
# compatible with the pandas.DataFrame constructor.
@@ -563,11 +602,11 @@ def ensure_iterable(obj: Iterable[V_co]) -> Iterable[V_co]:
@overload
-def ensure_iterable(obj: DataFrame) -> Iterable[Any]:
+def ensure_iterable(obj: OptionSequence[V_co]) -> Iterable[Any]:
...
-def ensure_iterable(obj: Union[DataFrame, Iterable[V_co]]) -> Iterable[Any]:
+def ensure_iterable(obj: Union[OptionSequence[V_co], Iterable[V_co]]) -> Iterable[Any]:
"""Try to convert different formats to something iterable. Most inputs
are assumed to be iterable, but if we have a DataFrame, we can just
select the first column to iterate over. If the input is not iterable,
@@ -582,6 +621,7 @@ def ensure_iterable(obj: Union[DataFrame, Iterable[V_co]]) -> Iterable[Any]:
iterable
"""
+
if is_snowpark_or_pyspark_data_object(obj):
obj = convert_anything_to_df(obj)
@@ -648,7 +688,7 @@ def pyarrow_table_to_bytes(table: pa.Table) -> bytes:
return cast(bytes, sink.getvalue().to_pybytes())
-def is_colum_type_arrow_incompatible(column: Union[Series, Index]) -> bool:
+def is_colum_type_arrow_incompatible(column: Union[Series[Any], Index]) -> bool:
"""Return True if the column type is known to cause issues during Arrow conversion."""
if column.dtype.kind in [
# timedelta is supported by pyarrow but not in the Arrow JS:
@@ -776,7 +816,7 @@ def bytes_to_data_frame(source: bytes) -> DataFrame:
"""
reader = pa.RecordBatchStreamReader(source)
- return reader.read_pandas()
+ return cast(DataFrame, reader.read_pandas())
def determine_data_format(input_data: Any) -> DataFormat:
@@ -864,7 +904,7 @@ def convert_df_to_data_format(
df: DataFrame, data_format: DataFormat
) -> Union[
DataFrame,
- Series,
+ Series[Any],
pa.Table,
np.ndarray[Any, np.dtype[Any]],
Tuple[Any],
@@ -982,3 +1022,55 @@ def maybe_raise_label_warnings(label: Optional[str], label_visibility: Optional[
f"Unsupported label_visibility option '{label_visibility}'. "
f"Valid values are 'visible', 'hidden' or 'collapsed'."
)
+
+
+# The code below is copied from Altair, and slightly modified.
+# We copy this code here so we don't depend on private Altair functions.
+# Source: https://github.com/altair-viz/altair/blob/62ca5e37776f5cecb27e83c1fbd5d685a173095d/altair/utils/core.py#L193
+
+# STREAMLIT MOD: I changed the type for the data argument from "pd.Series" to Series,
+# and the return type to a Union including a (str, list) tuple, since the function does
+# return that in some situations.
+def infer_vegalite_type(data: Series[Any]) -> Union[str, Tuple[str, List[Any]]]:
+ """
+ From an array-like input, infer the correct vega typecode
+ ('ordinal', 'nominal', 'quantitative', or 'temporal')
+
+ Parameters
+ ----------
+ data: Numpy array or Pandas Series
+ """
+ # STREAMLIT MOD: I'm using infer_dtype directly here, rather than using Altair's wrapper. Their
+ # wrapper is only there to support Pandas < 0.20, but Streamlit requires Pandas >= 1.3.
+ typ = infer_dtype(data)
+
+ if typ in [
+ "floating",
+ "mixed-integer-float",
+ "integer",
+ "mixed-integer",
+ "complex",
+ ]:
+ return "quantitative"
+ elif typ == "categorical" and data.cat.ordered:
+ return ("ordinal", data.cat.categories.tolist())
+ elif typ in ["string", "bytes", "categorical", "boolean", "mixed", "unicode"]:
+ return "nominal"
+ elif typ in [
+ "datetime",
+ "datetime64",
+ "timedelta",
+ "timedelta64",
+ "date",
+ "time",
+ "period",
+ ]:
+ return "temporal"
+ else:
+ # STREAMLIT MOD: I commented this out since Streamlit doesn't have a warnings object.
+ # warnings.warn(
+ # "I don't know how to infer vegalite type from '{}'. "
+ # "Defaulting to nominal.".format(typ),
+ # stacklevel=1,
+ # )
+ return "nominal"
diff --git a/lib/streamlit/watcher/util.py b/lib/streamlit/watcher/util.py
index e820b16c2e6e..a8c0eab36770 100644
--- a/lib/streamlit/watcher/util.py
+++ b/lib/streamlit/watcher/util.py
@@ -20,6 +20,7 @@
import hashlib
import os
+import sys
import time
from pathlib import Path
from typing import Optional
@@ -55,7 +56,10 @@ def calc_md5_with_blocking_retries(
else:
content = _get_file_content_with_blocking_retries(path)
- md5 = hashlib.md5()
+ if sys.version_info >= (3, 9):
+ md5 = hashlib.md5(usedforsecurity=False)
+ else:
+ md5 = hashlib.md5()
md5.update(content)
# Use hexdigest() instead of digest(), so it's easier to debug.
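The version gate matters on FIPS-enabled systems, where a bare `hashlib.md5()` can raise; `usedforsecurity=False` (new in Python 3.9) declares the hash a fingerprint rather than a cryptographic primitive. A standalone sketch of the same pattern:

```python
import hashlib
import sys


def non_crypto_md5(content: bytes) -> str:
    # usedforsecurity=False is only accepted on Python >= 3.9; older
    # interpreters fall back to the plain constructor.
    if sys.version_info >= (3, 9):
        md5 = hashlib.md5(usedforsecurity=False)
    else:
        md5 = hashlib.md5()
    md5.update(content)
    return md5.hexdigest()


print(non_crypto_md5(b"file contents"))
```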
diff --git a/lib/streamlit/web/bootstrap.py b/lib/streamlit/web/bootstrap.py
index b7a9636dd6bd..bd56e7b793d1 100644
--- a/lib/streamlit/web/bootstrap.py
+++ b/lib/streamlit/web/bootstrap.py
@@ -130,7 +130,7 @@ def _fix_tornado_crash() -> None:
FIXME: if/when tornado supports the defaults in asyncio,
remove and bump tornado requirement for py38
"""
- if env_util.IS_WINDOWS and sys.version_info >= (3, 8):
+ if env_util.IS_WINDOWS:
try:
from asyncio import ( # type: ignore[attr-defined]
WindowsProactorEventLoopPolicy,
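The hunk truncates the function body, but the workaround it implements is the well-known one: Tornado's event loop integration does not work with Windows' default Proactor policy, so the Selector policy is swapped in. A minimal sketch of that pattern (an assumption based on the import above, not a quote of the elided code):

```python
import asyncio
import sys

# Tornado needs the selector-based loop on Windows; the Proactor loop
# (the default since Python 3.8) lacks the APIs Tornado relies on.
if sys.platform == "win32":
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
```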
diff --git a/lib/streamlit/web/server/server.py b/lib/streamlit/web/server/server.py
index aba0ff40aed6..f16a94232d7e 100644
--- a/lib/streamlit/web/server/server.py
+++ b/lib/streamlit/web/server/server.py
@@ -36,6 +36,7 @@
from streamlit.logger import get_logger
from streamlit.runtime import Runtime, RuntimeConfig, RuntimeState
from streamlit.runtime.memory_media_file_storage import MemoryMediaFileStorage
+from streamlit.runtime.memory_uploaded_file_manager import MemoryUploadedFileManager
from streamlit.runtime.runtime_util import get_max_message_size_bytes
from streamlit.web.cache_storage_manager_config import (
create_default_cache_storage_manager,
@@ -53,10 +54,7 @@
)
from streamlit.web.server.server_util import make_url_path_regex
from streamlit.web.server.stats_request_handler import StatsRequestHandler
-from streamlit.web.server.upload_file_request_handler import (
- UPLOAD_FILE_ROUTE,
- UploadFileRequestHandler,
-)
+from streamlit.web.server.upload_file_request_handler import UploadFileRequestHandler
LOGGER = get_logger(__name__)
@@ -83,6 +81,7 @@
UNIX_SOCKET_PREFIX = "unix://"
MEDIA_ENDPOINT: Final = "/media"
+UPLOAD_FILE_ENDPOINT: Final = "/_stcore/upload_file"
STREAM_ENDPOINT: Final = r"_stcore/stream"
METRIC_ENDPOINT: Final = r"(?:st-metrics|_stcore/metrics)"
MESSAGE_ENDPOINT: Final = r"_stcore/message"
@@ -228,11 +227,14 @@ def __init__(self, main_script_path: str, command_line: Optional[str]):
media_file_storage = MemoryMediaFileStorage(MEDIA_ENDPOINT)
MediaFileHandler.initialize_storage(media_file_storage)
+ uploaded_file_mgr = MemoryUploadedFileManager(UPLOAD_FILE_ENDPOINT)
+
self._runtime = Runtime(
RuntimeConfig(
script_path=main_script_path,
command_line=command_line,
media_file_storage=media_file_storage,
+ uploaded_file_manager=uploaded_file_mgr,
cache_storage_manager=create_default_cache_storage_manager(),
),
)
@@ -299,7 +301,7 @@ def _create_app(self) -> tornado.web.Application:
(
make_url_path_regex(
base,
- UPLOAD_FILE_ROUTE,
+ rf"{UPLOAD_FILE_ENDPOINT}/(?P[^/]+)/(?P[^/]+)",
),
UploadFileRequestHandler,
dict(
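Tornado hands named regex groups to the handler via `self.path_kwargs`, which is what the `put`/`delete` methods below read. A quick standalone check of the new route shape (the session and file IDs here are made up):

```python
import re

UPLOAD_FILE_ENDPOINT = "/_stcore/upload_file"
pattern = rf"{UPLOAD_FILE_ENDPOINT}/(?P<session_id>[^/]+)/(?P<file_id>[^/]+)"

match = re.fullmatch(pattern, "/_stcore/upload_file/sess-123/file-456")
assert match is not None
assert match.groupdict() == {"session_id": "sess-123", "file_id": "file-456"}
```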
diff --git a/lib/streamlit/web/server/upload_file_request_handler.py b/lib/streamlit/web/server/upload_file_request_handler.py
index 41392e2bada4..8741f27c9e4c 100644
--- a/lib/streamlit/web/server/upload_file_request_handler.py
+++ b/lib/streamlit/web/server/upload_file_request_handler.py
@@ -19,13 +19,10 @@
from streamlit import config
from streamlit.logger import get_logger
+from streamlit.runtime.memory_uploaded_file_manager import MemoryUploadedFileManager
from streamlit.runtime.uploaded_file_manager import UploadedFileManager, UploadedFileRec
from streamlit.web.server import routes, server_util
-# /_stcore/upload_file/(optional session id)/(optional widget id)
-UPLOAD_FILE_ROUTE = (
- r"/_stcore/upload_file/?(?P[^/]*)?/?(?P[^/]*)?"
-)
LOGGER = get_logger(__name__)
@@ -33,7 +30,9 @@ class UploadFileRequestHandler(tornado.web.RequestHandler):
"""Implements the POST /upload_file endpoint."""
def initialize(
- self, file_mgr: UploadedFileManager, is_active_session: Callable[[str], bool]
+ self,
+ file_mgr: MemoryUploadedFileManager,
+ is_active_session: Callable[[str], bool],
):
"""
Parameters
@@ -49,7 +48,7 @@ def initialize(
self._is_active_session = is_active_session
def set_default_headers(self):
- self.set_header("Access-Control-Allow-Methods", "POST, OPTIONS")
+ self.set_header("Access-Control-Allow-Methods", "PUT, OPTIONS, DELETE")
self.set_header("Access-Control-Allow-Headers", "Content-Type")
if config.get_option("server.enableXsrfProtection"):
self.set_header(
@@ -82,32 +81,15 @@ def options(self, **kwargs):
self.set_status(204)
self.finish()
- @staticmethod
- def _require_arg(args: Dict[str, List[bytes]], name: str) -> str:
- """Return the value of the argument with the given name.
+ def put(self, **kwargs):
+ """Receive an uploaded file and add it to our UploadedFileManager."""
- A human-readable exception will be raised if the argument doesn't
- exist. This will be used as the body for the error response returned
- from the request.
- """
- try:
- arg = args[name]
- except KeyError:
- raise Exception(f"Missing '{name}'")
-
- if len(arg) != 1:
- raise Exception(f"Expected 1 '{name}' arg, but got {len(arg)}")
-
- # Convert bytes to string
- return arg[0].decode("utf-8")
-
- def post(self, **kwargs):
- """Receive an uploaded file and add it to our UploadedFileManager.
- Return the file's ID, so that the client can refer to it.
- """
args: Dict[str, List[bytes]] = {}
files: Dict[str, List[Any]] = {}
+ session_id = self.path_kwargs["session_id"]
+ file_id = self.path_kwargs["file_id"]
+
tornado.httputil.parse_body_arguments(
content_type=self.request.headers["Content-Type"],
body=self.request.body,
@@ -116,25 +98,19 @@ def post(self, **kwargs):
)
try:
- session_id = self._require_arg(args, "sessionId")
- widget_id = self._require_arg(args, "widgetId")
if not self._is_active_session(session_id):
raise Exception(f"Invalid session_id: '{session_id}'")
-
except Exception as e:
self.send_error(400, reason=str(e))
return
- # Create an UploadedFile object for each file.
- # We assign an initial, invalid file_id to each file in this loop.
- # The file_mgr will assign unique file IDs and return in `add_file`,
- # below.
uploaded_files: List[UploadedFileRec] = []
+
for _, flist in files.items():
for file in flist:
uploaded_files.append(
UploadedFileRec(
- id=0,
+ file_id=file_id,
name=file["filename"],
type=file["content_type"],
data=file["body"],
@@ -147,11 +123,13 @@ def post(self, **kwargs):
)
return
- added_file = self._file_mgr.add_file(
- session_id=session_id, widget_id=widget_id, file=uploaded_files[0]
- )
+ self._file_mgr.add_file(session_id=session_id, file=uploaded_files[0])
+ self.set_status(204)
+
+ def delete(self, **kwargs):
+ """Delete file request handler."""
+ session_id = self.path_kwargs["session_id"]
+ file_id = self.path_kwargs["file_id"]
- # Return the file_id to the client. (The client will parse
- # the string back to an int.)
- self.write(str(added_file.id))
- self.set_status(200)
+ self._file_mgr.remove_file(session_id=session_id, file_id=file_id)
+ self.set_status(204)
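Put together, the endpoint's lifecycle is now client-driven: the client obtains an upload URL (with a server-chosen `file_id`), PUTs the bytes, and DELETEs when the file is removed. A hypothetical flow with `requests`, assuming a local server with `server.enableXsrfProtection=false` (otherwise an XSRF token header would also be needed):

```python
import requests

# Hypothetical IDs; in practice they come from get_upload_urls().
url = "http://localhost:8501/_stcore/upload_file/sess-123/file-456"

with open("data.csv", "rb") as f:
    resp = requests.put(url, files={"file": ("data.csv", f, "text/csv")})
assert resp.status_code == 204  # no body: the client already knows file_id

resp = requests.delete(url)
assert resp.status_code == 204
```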
diff --git a/lib/test-requirements.txt b/lib/test-requirements.txt
index 0e8df6c9396d..f6c0442d355d 100644
--- a/lib/test-requirements.txt
+++ b/lib/test-requirements.txt
@@ -24,9 +24,11 @@ pytest
pytest-cov
requests-mock
testfixtures
+pytest-playwright>=0.1.2
+pixelmatch>=0.3.0
+pytest-xdist
-
-mypy-protobuf>=3.2
+mypy-protobuf>=3.2, <3.4
# These requirements exist only for `@st.cache` tests. st.cache is deprecated, and
# we're not going to update its associated tests anymore. Please don't modify
diff --git a/lib/tests/delta_generator_test_case.py b/lib/tests/delta_generator_test_case.py
index f915e7a9a819..d9b6304dabd3 100644
--- a/lib/tests/delta_generator_test_case.py
+++ b/lib/tests/delta_generator_test_case.py
@@ -28,14 +28,14 @@
from streamlit.runtime.forward_msg_queue import ForwardMsgQueue
from streamlit.runtime.media_file_manager import MediaFileManager
from streamlit.runtime.memory_media_file_storage import MemoryMediaFileStorage
+from streamlit.runtime.memory_uploaded_file_manager import MemoryUploadedFileManager
from streamlit.runtime.scriptrunner import (
ScriptRunContext,
add_script_run_ctx,
get_script_run_ctx,
)
from streamlit.runtime.state import SafeSessionState, SessionState
-from streamlit.runtime.uploaded_file_manager import UploadedFileManager
-from streamlit.web.server.server import MEDIA_ENDPOINT
+from streamlit.web.server.server import MEDIA_ENDPOINT, UPLOAD_FILE_ENDPOINT
class DeltaGeneratorTestCase(unittest.TestCase):
@@ -51,7 +51,7 @@ def setUp(self):
_enqueue=self.forward_msg_queue.enqueue,
query_string="",
session_state=SafeSessionState(SessionState()),
- uploaded_file_mgr=UploadedFileManager(),
+ uploaded_file_mgr=MemoryUploadedFileManager(UPLOAD_FILE_ENDPOINT),
page_script_hash="",
user_info={"email": "test@test.com"},
)
@@ -64,6 +64,7 @@ def setUp(self):
mock_runtime = MagicMock(spec=Runtime)
mock_runtime.cache_storage_manager = MemoryCacheStorageManager()
mock_runtime.media_file_mgr = MediaFileManager(self.media_file_storage)
+ mock_runtime.uploaded_file_mgr = self.script_run_ctx.uploaded_file_mgr
Runtime._instance = mock_runtime
def tearDown(self):
diff --git a/lib/tests/exception_capturing_thread.py b/lib/tests/exception_capturing_thread.py
index 5dc61357e8ab..336afb59b875 100644
--- a/lib/tests/exception_capturing_thread.py
+++ b/lib/tests/exception_capturing_thread.py
@@ -16,9 +16,9 @@
from typing import Any, Callable, Optional
from streamlit.runtime.forward_msg_queue import ForwardMsgQueue
+from streamlit.runtime.memory_uploaded_file_manager import MemoryUploadedFileManager
from streamlit.runtime.scriptrunner import ScriptRunContext, add_script_run_ctx
from streamlit.runtime.state import SafeSessionState, SessionState
-from streamlit.runtime.uploaded_file_manager import UploadedFileManager
def call_on_threads(
@@ -63,7 +63,7 @@ def call_on_threads(
_enqueue=ForwardMsgQueue().enqueue,
query_string="",
session_state=SafeSessionState(SessionState()),
- uploaded_file_mgr=UploadedFileManager(),
+ uploaded_file_mgr=MemoryUploadedFileManager("/mock/upload"),
page_script_hash="",
user_info={"email": "test@test.com"},
)
diff --git a/lib/tests/streamlit/dataframe_selector_test.py b/lib/tests/streamlit/dataframe_selector_test.py
index c1fdb4730372..d2c6eadc1823 100644
--- a/lib/tests/streamlit/dataframe_selector_test.py
+++ b/lib/tests/streamlit/dataframe_selector_test.py
@@ -23,10 +23,11 @@
import streamlit
from streamlit.delta_generator import DeltaGenerator
+from streamlit.errors import StreamlitAPIException
from tests.streamlit import pyspark_mocks
from tests.streamlit.snowpark_mocks import DataFrame as MockSnowparkDataFrame
from tests.streamlit.snowpark_mocks import Table as MockSnowparkTable
-from tests.testutil import patch_config_options, should_skip_pyspark_tests
+from tests.testutil import patch_config_options
DATAFRAME = pd.DataFrame([["A", "B", "C", "D"], [28, 55, 43, 91]], index=["a", "b"]).T
ALTAIR_CHART = alt.Chart(DATAFRAME).mark_bar().encode(x="a", y="b")
@@ -80,9 +81,6 @@ def test_arrow_dataframe_with_snowpark_dataframe(
column_config=None,
)
- @pytest.mark.skipif(
- should_skip_pyspark_tests(), reason="pyspark is incompatible with Python3.11"
- )
@patch.object(DeltaGenerator, "_legacy_dataframe")
@patch.object(DeltaGenerator, "_arrow_dataframe")
@patch_config_options({"global.dataFrameSerialization": "arrow"})
@@ -159,6 +157,7 @@ def test_arrow_line_chart(self, arrow_line_chart, legacy_line_chart):
DATAFRAME,
x=None,
y=None,
+ color=None,
width=100,
height=200,
use_container_width=True,
@@ -184,6 +183,7 @@ def test_arrow_area_chart(self, arrow_area_chart, legacy_area_chart):
DATAFRAME,
x=None,
y=None,
+ color=None,
width=100,
height=200,
use_container_width=True,
@@ -209,6 +209,7 @@ def test_arrow_bar_chart(self, arrow_bar_chart, legacy_bar_chart):
DATAFRAME,
x=None,
y=None,
+ color=None,
width=100,
height=200,
use_container_width=True,
diff --git a/lib/tests/streamlit/delta_generator_test.py b/lib/tests/streamlit/delta_generator_test.py
index 966930b4feb4..d74d64939bca 100644
--- a/lib/tests/streamlit/delta_generator_test.py
+++ b/lib/tests/streamlit/delta_generator_test.py
@@ -38,6 +38,7 @@
from streamlit.proto.Text_pb2 import Text as TextProto
from streamlit.proto.TextArea_pb2 import TextArea
from streamlit.proto.TextInput_pb2 import TextInput
+from streamlit.runtime.state.common import compute_widget_id
from streamlit.runtime.state.widgets import _build_duplicate_widget_message
from tests.delta_generator_test_case import DeltaGeneratorTestCase
@@ -142,6 +143,7 @@ def test_public_api(self):
"snow",
"subheader",
"success",
+ "status",
"table",
"tabs",
"text",
@@ -150,6 +152,7 @@ def test_public_api(self):
"time_input",
"title",
"toast",
+ "toggle",
"vega_lite_chart",
"video",
"warning",
@@ -608,97 +611,67 @@ def test_empty(self):
class AutogeneratedWidgetIdTests(DeltaGeneratorTestCase):
- def test_ids_are_equal_when_proto_is_equal(self):
- text_input1 = TextInput()
- text_input1.label = "Label #1"
- text_input1.default = "Value #1"
-
- text_input2 = TextInput()
- text_input2.label = "Label #1"
- text_input2.default = "Value #1"
-
- element1 = Element()
- element1.text_input.CopyFrom(text_input1)
-
- element2 = Element()
- element2.text_input.CopyFrom(text_input2)
-
- register_widget("text_input", element1.text_input, ctx=self.script_run_ctx)
+ def test_ids_are_equal_when_inputs_are_equal(self):
+ id1 = compute_widget_id(
+ "text_input",
+ label="Label #1",
+ default="Value #1",
+ )
- with self.assertRaises(DuplicateWidgetID):
- register_widget("text_input", element2.text_input, ctx=self.script_run_ctx)
+ id2 = compute_widget_id(
+ "text_input",
+ label="Label #1",
+ default="Value #1",
+ )
+ assert id1 == id2
def test_ids_are_diff_when_labels_are_diff(self):
- text_input1 = TextInput()
- text_input1.label = "Label #1"
- text_input1.default = "Value #1"
-
- text_input2 = TextInput()
- text_input2.label = "Label #2"
- text_input2.default = "Value #1"
-
- element1 = Element()
- element1.text_input.CopyFrom(text_input1)
-
- element2 = Element()
- element2.text_input.CopyFrom(text_input2)
-
- register_widget("text_input", element1.text_input, ctx=self.script_run_ctx)
- register_widget("text_input", element2.text_input, ctx=self.script_run_ctx)
+ id1 = compute_widget_id(
+ "text_input",
+ label="Label #1",
+ default="Value #1",
+ )
+ id2 = compute_widget_id(
+ "text_input",
+ label="Label #2",
+ default="Value #1",
+ )
- self.assertNotEqual(element1.text_input.id, element2.text_input.id)
+ assert id1 != id2
def test_ids_are_diff_when_types_are_diff(self):
- text_input1 = TextInput()
- text_input1.label = "Label #1"
- text_input1.default = "Value #1"
-
- text_area2 = TextArea()
- text_area2.label = "Label #1"
- text_area2.default = "Value #1"
-
- element1 = Element()
- element1.text_input.CopyFrom(text_input1)
-
- element2 = Element()
- element2.text_area.CopyFrom(text_area2)
-
- register_widget("text_input", element1.text_input, ctx=self.script_run_ctx)
- register_widget("text_input", element2.text_input, ctx=self.script_run_ctx)
-
- self.assertNotEqual(element1.text_input.id, element2.text_area.id)
+ id1 = compute_widget_id(
+ "text_input",
+ label="Label #1",
+ default="Value #1",
+ )
+ id2 = compute_widget_id(
+ "text_area",
+ label="Label #1",
+ default="Value #1",
+ )
+ assert id1 != id2
class KeyWidgetIdTests(DeltaGeneratorTestCase):
def test_ids_are_diff_when_keys_are_diff(self):
- text_input1 = TextInput()
- text_input1.label = "Label #1"
- text_input1.default = "Value #1"
-
- text_input2 = TextInput()
- text_input2.label = "Label #1"
- text_input2.default = "Value #1"
-
- element1 = Element()
- element1.text_input.CopyFrom(text_input1)
-
- element2 = Element()
- element2.text_input.CopyFrom(text_input2)
-
- register_widget(
+ id1 = compute_widget_id(
"text_input",
- element1.text_input,
user_key="some_key1",
- ctx=self.script_run_ctx,
+ label="Label #1",
+ default="Value #1",
+ key="some_key1",
)
- register_widget(
+
+ id2 = compute_widget_id(
"text_input",
- element2.text_input,
user_key="some_key2",
- ctx=self.script_run_ctx,
+ label="Label #1",
+ default="Value #1",
+ key="some_key2",
)
- self.assertNotEqual(element1.text_input.id, element2.text_input.id)
+ assert id1 != id2
class DeltaGeneratorImageTest(DeltaGeneratorTestCase):
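These rewritten tests lean on the new `compute_widget_id`, which derives the ID purely from the element type and its keyword arguments instead of registering protos. A conceptual stand-in (the real function hashes more inputs and uses a different prefix; this only shows why equal inputs yield equal IDs and any differing kwarg yields a new one):

```python
import hashlib


def compute_widget_id_sketch(element_type: str, **kwargs) -> str:
    h = hashlib.new("md5")
    h.update(element_type.encode())
    for name in sorted(kwargs):  # order-independent hashing
        h.update(f"{name}={kwargs[name]!r}".encode())
    return f"$$SKETCH-{h.hexdigest()}"


a = compute_widget_id_sketch("text_input", label="Label #1", default="Value #1")
b = compute_widget_id_sketch("text_input", label="Label #1", default="Value #1")
c = compute_widget_id_sketch("text_area", label="Label #1", default="Value #1")
assert a == b and a != c
```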
diff --git a/lib/tests/streamlit/elements/arrow_add_rows_test.py b/lib/tests/streamlit/elements/arrow_add_rows_test.py
index 6465e2795769..e2ef9f9d005a 100644
--- a/lib/tests/streamlit/elements/arrow_add_rows_test.py
+++ b/lib/tests/streamlit/elements/arrow_add_rows_test.py
@@ -15,41 +15,201 @@
"""Unit test of dg._arrow_add_rows()."""
import pandas as pd
+from parameterized import parameterized
import streamlit as st
from streamlit.type_util import bytes_to_data_frame
from tests.delta_generator_test_case import DeltaGeneratorTestCase
-DATAFRAME = pd.DataFrame({"a": [1], "b": [10]})
-NEW_ROWS = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
-MELTED_DATAFRAME = pd.DataFrame(
- {
- "index": [1, 2, 3, 1, 2, 3],
- "variable": ["a", "a", "a", "b", "b", "b"],
- "value": [1, 2, 3, 4, 5, 6],
- }
-)
+DATAFRAME = pd.DataFrame({"a": [10], "b": [20], "c": [30]})
+NEW_ROWS = pd.DataFrame({"a": [11, 12, 13], "b": [21, 22, 23], "c": [31, 32, 33]})
class DeltaGeneratorAddRowsTest(DeltaGeneratorTestCase):
"""Test dg._arrow_add_rows."""
- def _get_deltas_that_melt_dataframes(self):
- return [
- lambda df: st._arrow_line_chart(df),
- lambda df: st._arrow_bar_chart(df),
- lambda df: st._arrow_area_chart(df),
+ @parameterized.expand(
+ [
+ st._arrow_area_chart,
+ st._arrow_bar_chart,
+ st._arrow_line_chart,
]
+ )
+ def test_charts_with_implicit_x_and_y(self, chart_command):
+ expected = pd.DataFrame(
+ {
+ "index--p5bJXXpQgvPz6yvQMFiy": [1, 2, 3, 1, 2, 3, 1, 2, 3],
+ "color--p5bJXXpQgvPz6yvQMFiy": [
+ "a",
+ "a",
+ "a",
+ "b",
+ "b",
+ "b",
+ "c",
+ "c",
+ "c",
+ ],
+ "value--p5bJXXpQgvPz6yvQMFiy": [11, 12, 13, 21, 22, 23, 31, 32, 33],
+ }
+ )
- def test_deltas_that_melt_dataframes(self):
- deltas = self._get_deltas_that_melt_dataframes()
+ element = chart_command(DATAFRAME)
+ element._arrow_add_rows(NEW_ROWS)
- for delta in deltas:
- element = delta(DATAFRAME)
- element._arrow_add_rows(NEW_ROWS)
+ proto = bytes_to_data_frame(
+ self.get_delta_from_queue().arrow_add_rows.data.data
+ )
- proto = bytes_to_data_frame(
- self.get_delta_from_queue().arrow_add_rows.data.data
- )
+ pd.testing.assert_frame_equal(proto, expected)
- pd.testing.assert_frame_equal(proto, MELTED_DATAFRAME)
+ @parameterized.expand(
+ [
+ st._arrow_area_chart,
+ st._arrow_bar_chart,
+ st._arrow_line_chart,
+ ]
+ )
+ def test_charts_with_explicit_x_and_y(self, chart_command):
+ expected = pd.DataFrame(
+ {
+ "b": [21, 22, 23],
+ "c": [31, 32, 33],
+ }
+ )
+ expected.index = pd.RangeIndex(1, 4)
+
+ element = chart_command(DATAFRAME, x="b", y="c")
+ element._arrow_add_rows(NEW_ROWS)
+
+ proto = bytes_to_data_frame(
+ self.get_delta_from_queue().arrow_add_rows.data.data
+ )
+
+ pd.testing.assert_frame_equal(proto, expected)
+
+ @parameterized.expand(
+ [
+ st._arrow_area_chart,
+ st._arrow_bar_chart,
+ st._arrow_line_chart,
+ ]
+ )
+ def test_charts_with_implicit_x_and_explicit_y(self, chart_command):
+ expected = pd.DataFrame(
+ {
+ "index--p5bJXXpQgvPz6yvQMFiy": [1, 2, 3],
+ "b": [21, 22, 23],
+ }
+ )
+
+ element = chart_command(DATAFRAME, y="b")
+ element._arrow_add_rows(NEW_ROWS)
+
+ proto = bytes_to_data_frame(
+ self.get_delta_from_queue().arrow_add_rows.data.data
+ )
+
+ pd.testing.assert_frame_equal(proto, expected)
+
+ @parameterized.expand(
+ [
+ st._arrow_area_chart,
+ st._arrow_bar_chart,
+ st._arrow_line_chart,
+ ]
+ )
+ def test_charts_with_explicit_x_and_implicit_y(self, chart_command):
+ expected = pd.DataFrame(
+ {
+ "b": [21, 22, 23, 21, 22, 23],
+ "color--p5bJXXpQgvPz6yvQMFiy": ["a", "a", "a", "c", "c", "c"],
+ "value--p5bJXXpQgvPz6yvQMFiy": [11, 12, 13, 31, 32, 33],
+ }
+ )
+
+ element = chart_command(DATAFRAME, x="b")
+ element._arrow_add_rows(NEW_ROWS)
+
+ proto = bytes_to_data_frame(
+ self.get_delta_from_queue().arrow_add_rows.data.data
+ )
+
+ pd.testing.assert_frame_equal(proto, expected)
+
+ @parameterized.expand(
+ [
+ st._arrow_area_chart,
+ st._arrow_bar_chart,
+ st._arrow_line_chart,
+ ]
+ )
+ def test_charts_with_explicit_x_and_y_sequence(self, chart_command):
+ expected = pd.DataFrame(
+ {
+ "b": [21, 22, 23, 21, 22, 23],
+ "color--p5bJXXpQgvPz6yvQMFiy": ["a", "a", "a", "c", "c", "c"],
+ "value--p5bJXXpQgvPz6yvQMFiy": [11, 12, 13, 31, 32, 33],
+ }
+ )
+
+ element = chart_command(DATAFRAME, x="b", y=["a", "c"])
+ element._arrow_add_rows(NEW_ROWS)
+
+ proto = bytes_to_data_frame(
+ self.get_delta_from_queue().arrow_add_rows.data.data
+ )
+
+ pd.testing.assert_frame_equal(proto, expected)
+
+ @parameterized.expand(
+ [
+ st._arrow_area_chart,
+ st._arrow_bar_chart,
+ st._arrow_line_chart,
+ ]
+ )
+ def test_charts_with_explicit_x_and_y_sequence_and_static_color(
+ self, chart_command
+ ):
+ expected = pd.DataFrame(
+ {
+ "b": [21, 22, 23, 21, 22, 23],
+ "color--p5bJXXpQgvPz6yvQMFiy": ["a", "a", "a", "c", "c", "c"],
+ "value--p5bJXXpQgvPz6yvQMFiy": [11, 12, 13, 31, 32, 33],
+ }
+ )
+
+ element = chart_command(DATAFRAME, x="b", y=["a", "c"], color=["#f00", "#0f0"])
+ element._arrow_add_rows(NEW_ROWS)
+
+ proto = bytes_to_data_frame(
+ self.get_delta_from_queue().arrow_add_rows.data.data
+ )
+
+ pd.testing.assert_frame_equal(proto, expected)
+
+ @parameterized.expand(
+ [
+ st._arrow_area_chart,
+ st._arrow_bar_chart,
+ st._arrow_line_chart,
+ ]
+ )
+ def test_charts_with_fewer_args_than_cols(self, chart_command):
+ expected = pd.DataFrame(
+ {
+ "b": [21, 22, 23],
+ "a": [11, 12, 13],
+ }
+ )
+ expected.index = pd.RangeIndex(start=1, stop=4, step=1)
+
+ element = chart_command(DATAFRAME, x="b", y="a")
+ element._arrow_add_rows(NEW_ROWS)
+
+ proto = bytes_to_data_frame(
+ self.get_delta_from_queue().arrow_add_rows.data.data
+ )
+
+ pd.testing.assert_frame_equal(proto, expected)
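The `--p5bJXXpQgvPz6yvQMFiy` suffix in the expected frames is a collision-avoidance token appended to the synthetic `index`/`color`/`value` columns. Stripped of that detail, the wide-to-long reshape the chart code performs is an ordinary pandas melt; a plain-pandas equivalent (not the real helper in `arrow_altair`):

```python
import pandas as pd

df = pd.DataFrame({"a": [10], "b": [20], "c": [30]})
melted = df.reset_index().melt(
    id_vars=["index"], var_name="color", value_name="value"
)
print(melted)
#    index color  value
# 0      0     a     10
# 1      0     b     20
# 2      0     c     30
```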
diff --git a/lib/tests/streamlit/elements/arrow_altair_test.py b/lib/tests/streamlit/elements/arrow_altair_test.py
index 3a0d3a7dd88c..ccd59b5b438c 100644
--- a/lib/tests/streamlit/elements/arrow_altair_test.py
+++ b/lib/tests/streamlit/elements/arrow_altair_test.py
@@ -15,7 +15,7 @@
import json
from datetime import date
from functools import reduce
-from typing import Callable
+from typing import Any, Callable
import altair as alt
import pandas as pd
@@ -28,6 +28,7 @@
from streamlit.errors import StreamlitAPIException
from streamlit.type_util import bytes_to_data_frame
from tests.delta_generator_test_case import DeltaGeneratorTestCase
+from tests.streamlit import pyspark_mocks, snowpark_mocks
def _deep_get(dictionary, *keys):
@@ -80,7 +81,7 @@ def test_date_column_utc_scale(self):
{"index": [date(2019, 8, 9), date(2019, 8, 10)], "numbers": [1, 10]}
).set_index("index")
- chart = altair._generate_chart(ChartType.LINE, df)
+ chart, _ = altair._generate_chart(ChartType.LINE, df)
st._arrow_altair_chart(chart)
proto = self.get_delta_from_queue().new_element.arrow_vega_lite_chart
spec_dict = json.loads(proto.spec)
@@ -105,7 +106,7 @@ def test_theme(self, theme_value, proto_value):
{"index": [date(2019, 8, 9), date(2019, 8, 10)], "numbers": [1, 10]}
).set_index("index")
- chart = altair._generate_chart(ChartType.LINE, df)
+ chart, _ = altair._generate_chart(ChartType.LINE, df)
st._arrow_altair_chart(chart, theme=theme_value)
el = self.get_delta_from_queue().new_element
@@ -116,7 +117,7 @@ def test_bad_theme(self):
{"index": [date(2019, 8, 9), date(2019, 8, 10)], "numbers": [1, 10]}
).set_index("index")
- chart = altair._generate_chart(ChartType.LINE, df)
+ chart, _ = altair._generate_chart(ChartType.LINE, df)
with self.assertRaises(StreamlitAPIException) as exc:
st._arrow_altair_chart(chart, theme="bad_theme")
@@ -129,22 +130,31 @@ def test_bad_theme(self):
class ArrowChartsTest(DeltaGeneratorTestCase):
"""Test Arrow charts."""
- def test_arrow_line_chart(self):
- """Test st._arrow_line_chart."""
- df = pd.DataFrame([[20, 30, 50]], columns=["a", "b", "c"])
- EXPECTED_DATAFRAME = pd.DataFrame(
- [[0, "a", 20], [0, "b", 30], [0, "c", 50]],
- index=[0, 1, 2],
- columns=["index", "variable", "value"],
- )
+ @parameterized.expand(
+ [
+ (st._arrow_area_chart, "area"),
+ (st._arrow_bar_chart, "bar"),
+ (st._arrow_line_chart, "line"),
+ ]
+ )
+ def test_empty_arrow_chart(self, chart_command: Callable, altair_type: str):
+ """Test arrow chart with no arguments."""
+ EXPECTED_DATAFRAME = pd.DataFrame()
+
+ # Make some mutations that arrow_altair.prep_data() does.
+ column_names = list(
+ EXPECTED_DATAFRAME.columns
+ ) # list() converts RangeIndex, etc, to regular list.
+ str_column_names = [str(c) for c in column_names]
+ EXPECTED_DATAFRAME.columns = pd.Index(str_column_names)
- st._arrow_line_chart(df, width=640, height=480)
+ chart_command()
proto = self.get_delta_from_queue().new_element.arrow_vega_lite_chart
+
chart_spec = json.loads(proto.spec)
- self.assertIn(chart_spec["mark"], ["line", {"type": "line"}])
- self.assertEqual(chart_spec["width"], 640)
- self.assertEqual(chart_spec["height"], 480)
+ self.assertIn(chart_spec["mark"], [altair_type, {"type": altair_type}])
+
pd.testing.assert_frame_equal(
bytes_to_data_frame(proto.datasets[0].data.data),
EXPECTED_DATAFRAME,
@@ -157,7 +167,217 @@ def test_arrow_line_chart(self):
(st._arrow_line_chart, "line"),
]
)
- def test_arrow_chart_with_x_y(self, chart_command: Callable, altair_type: str):
+ def test_arrow_chart_with_implicit_x_and_y(
+ self, chart_command: Callable, altair_type: str
+ ):
+ """Test st._arrow_line_chart with implicit x and y."""
+ df = pd.DataFrame([[20, 30, 50]], columns=["a", "b", "c"])
+ EXPECTED_DATAFRAME = pd.DataFrame(
+ [[0, "a", 20], [0, "b", 30], [0, "c", 50]],
+ columns=[
+ "index--p5bJXXpQgvPz6yvQMFiy",
+ "color--p5bJXXpQgvPz6yvQMFiy",
+ "value--p5bJXXpQgvPz6yvQMFiy",
+ ],
+ )
+
+ chart_command(df)
+
+ proto = self.get_delta_from_queue().new_element.arrow_vega_lite_chart
+ chart_spec = json.loads(proto.spec)
+
+ self.assertIn(chart_spec["mark"], [altair_type, {"type": altair_type}])
+ self.assertEqual(
+ chart_spec["encoding"]["x"]["field"], "index--p5bJXXpQgvPz6yvQMFiy"
+ )
+ self.assertEqual(
+ chart_spec["encoding"]["y"]["field"], "value--p5bJXXpQgvPz6yvQMFiy"
+ )
+ self.assertEqual(
+ chart_spec["encoding"]["color"]["field"], "color--p5bJXXpQgvPz6yvQMFiy"
+ )
+
+ self.assert_output_df_is_correct_and_input_is_untouched(
+ orig_df=df, expected_df=EXPECTED_DATAFRAME, chart_proto=proto
+ )
+
+ @parameterized.expand(
+ [
+ (st._arrow_area_chart, "area"),
+ (st._arrow_bar_chart, "bar"),
+ (st._arrow_line_chart, "line"),
+ ]
+ )
+ def test_arrow_chart_with_pyspark_dataframe(
+ self, chart_command: Callable, altair_type: str
+ ):
+ spark_df = pyspark_mocks.DataFrame(is_numpy_arr=True)
+
+ chart_command(spark_df)
+
+ proto = self.get_delta_from_queue().new_element.arrow_vega_lite_chart
+ chart_spec = json.loads(proto.spec)
+ self.assertIn(chart_spec["mark"], [altair_type, {"type": altair_type}])
+ self.assertEqual(
+ chart_spec["encoding"]["x"]["field"], "index--p5bJXXpQgvPz6yvQMFiy"
+ )
+ self.assertEqual(
+ chart_spec["encoding"]["y"]["field"], "value--p5bJXXpQgvPz6yvQMFiy"
+ )
+ self.assertEqual(
+ chart_spec["encoding"]["color"]["field"], "color--p5bJXXpQgvPz6yvQMFiy"
+ )
+
+ output_df = bytes_to_data_frame(proto.datasets[0].data.data)
+
+ self.assertEqual(len(output_df.columns), 3)
+ self.assertEqual(output_df.columns[0], "index--p5bJXXpQgvPz6yvQMFiy")
+ self.assertEqual(output_df.columns[1], "color--p5bJXXpQgvPz6yvQMFiy")
+ self.assertEqual(output_df.columns[2], "value--p5bJXXpQgvPz6yvQMFiy")
+
+ @parameterized.expand(
+ [
+ (st._arrow_area_chart, "area"),
+ (st._arrow_bar_chart, "bar"),
+ (st._arrow_line_chart, "line"),
+ ]
+ )
+ def test_arrow_chart_with_snowpark_dataframe(
+ self, chart_command: Callable, altair_type: str
+ ):
+ snow_df = snowpark_mocks.DataFrame()
+
+ chart_command(snow_df)
+
+ proto = self.get_delta_from_queue().new_element.arrow_vega_lite_chart
+ chart_spec = json.loads(proto.spec)
+ self.assertIn(chart_spec["mark"], [altair_type, {"type": altair_type}])
+ self.assertEqual(
+ chart_spec["encoding"]["x"]["field"], "index--p5bJXXpQgvPz6yvQMFiy"
+ )
+ self.assertEqual(
+ chart_spec["encoding"]["y"]["field"], "value--p5bJXXpQgvPz6yvQMFiy"
+ )
+ self.assertEqual(
+ chart_spec["encoding"]["color"]["field"], "color--p5bJXXpQgvPz6yvQMFiy"
+ )
+
+ output_df = bytes_to_data_frame(proto.datasets[0].data.data)
+
+ self.assertEqual(len(output_df.columns), 3)
+ self.assertEqual(output_df.columns[0], "index--p5bJXXpQgvPz6yvQMFiy")
+ self.assertEqual(output_df.columns[1], "color--p5bJXXpQgvPz6yvQMFiy")
+ self.assertEqual(output_df.columns[2], "value--p5bJXXpQgvPz6yvQMFiy")
+
+ @parameterized.expand(
+ [
+ (st._arrow_area_chart, "area"),
+ (st._arrow_bar_chart, "bar"),
+ (st._arrow_line_chart, "line"),
+ ]
+ )
+ def test_arrow_chart_with_explicit_x_and_implicit_y(
+ self, chart_command: Callable, altair_type: str
+ ):
+ """Test st._arrow_line_chart with explicit x and implicit y."""
+ df = pd.DataFrame([[20, 30, 50]], columns=["a", "b", "c"])
+ EXPECTED_DATAFRAME = pd.DataFrame(
+ [[20, "b", 30], [20, "c", 50]],
+ columns=["a", "color--p5bJXXpQgvPz6yvQMFiy", "value--p5bJXXpQgvPz6yvQMFiy"],
+ )
+
+ chart_command(df, x="a")
+
+ proto = self.get_delta_from_queue().new_element.arrow_vega_lite_chart
+ chart_spec = json.loads(proto.spec)
+ self.assertIn(chart_spec["mark"], [altair_type, {"type": altair_type}])
+ self.assertEqual(chart_spec["encoding"]["x"]["field"], "a")
+ self.assertEqual(
+ chart_spec["encoding"]["y"]["field"], "value--p5bJXXpQgvPz6yvQMFiy"
+ )
+ self.assertEqual(
+ chart_spec["encoding"]["color"]["field"], "color--p5bJXXpQgvPz6yvQMFiy"
+ )
+
+ self.assert_output_df_is_correct_and_input_is_untouched(
+ orig_df=df, expected_df=EXPECTED_DATAFRAME, chart_proto=proto
+ )
+
+ @parameterized.expand(
+ [
+ (st._arrow_area_chart, "area"),
+ (st._arrow_bar_chart, "bar"),
+ (st._arrow_line_chart, "line"),
+ ]
+ )
+ def test_arrow_chart_with_implicit_x_and_explicit_y(
+ self, chart_command: Callable, altair_type: str
+ ):
+ """Test st._arrow_line_chart with implicit x and explicit y."""
+ df = pd.DataFrame([[20, 30, 50]], columns=["a", "b", "c"])
+ EXPECTED_DATAFRAME = pd.DataFrame(
+ [[0, 30]], columns=["index--p5bJXXpQgvPz6yvQMFiy", "b"]
+ )
+
+ chart_command(df, y="b")
+
+ proto = self.get_delta_from_queue().new_element.arrow_vega_lite_chart
+ chart_spec = json.loads(proto.spec)
+ self.assertIn(chart_spec["mark"], [altair_type, {"type": altair_type}])
+ self.assertEqual(
+ chart_spec["encoding"]["x"]["field"], "index--p5bJXXpQgvPz6yvQMFiy"
+ )
+ self.assertEqual(chart_spec["encoding"]["y"]["field"], "b")
+ self.assertFalse("color" in chart_spec["encoding"])
+
+ self.assert_output_df_is_correct_and_input_is_untouched(
+ orig_df=df, expected_df=EXPECTED_DATAFRAME, chart_proto=proto
+ )
+
+ @parameterized.expand(
+ [
+ (st._arrow_area_chart, "area"),
+ (st._arrow_bar_chart, "bar"),
+ (st._arrow_line_chart, "line"),
+ ]
+ )
+ def test_arrow_chart_with_implicit_x_and_explicit_y_sequence(
+ self, chart_command: Callable, altair_type: str
+ ):
+ """Test st._arrow_line_chart with implicit x and explicit y sequence."""
+ df = pd.DataFrame([[20, 30, 50, 60]], columns=["a", "b", "c", "d"])
+ EXPECTED_DATAFRAME = pd.DataFrame(
+ [[0, "b", 30], [0, "c", 50]],
+ columns=[
+ "index--p5bJXXpQgvPz6yvQMFiy",
+ "color--p5bJXXpQgvPz6yvQMFiy",
+ "value--p5bJXXpQgvPz6yvQMFiy",
+ ],
+ )
+
+ chart_command(df, y=["b", "c"])
+
+ proto = self.get_delta_from_queue().new_element.arrow_vega_lite_chart
+ chart_spec = json.loads(proto.spec)
+ self.assertIn(chart_spec["mark"], [altair_type, {"type": altair_type}])
+ self.assertEqual(
+ chart_spec["encoding"]["x"]["field"], "index--p5bJXXpQgvPz6yvQMFiy"
+ )
+ self.assertEqual(
+ chart_spec["encoding"]["y"]["field"], "value--p5bJXXpQgvPz6yvQMFiy"
+ )
+ self.assertEqual(
+ chart_spec["encoding"]["color"]["field"], "color--p5bJXXpQgvPz6yvQMFiy"
+ )
+
+ self.assert_output_df_is_correct_and_input_is_untouched(
+ orig_df=df, expected_df=EXPECTED_DATAFRAME, chart_proto=proto
+ )
+
+ @parameterized.expand(
+ [
+ (st._arrow_area_chart, "area"),
+ (st._arrow_bar_chart, "bar"),
+ (st._arrow_line_chart, "line"),
+ ]
+ )
+ def test_arrow_chart_with_explicit_x_and_y(
+ self, chart_command: Callable, altair_type: str
+ ):
"""Test x/y-support for built-in charts."""
df = pd.DataFrame([[20, 30, 50]], columns=["a", "b", "c"])
EXPECTED_DATAFRAME = pd.DataFrame([[20, 30]], columns=["a", "b"])
@@ -172,9 +392,9 @@ def test_arrow_chart_with_x_y(self, chart_command: Callable, altair_type: str):
self.assertEqual(chart_spec["height"], 480)
self.assertEqual(chart_spec["encoding"]["x"]["field"], "a")
self.assertEqual(chart_spec["encoding"]["y"]["field"], "b")
- pd.testing.assert_frame_equal(
- bytes_to_data_frame(proto.datasets[0].data.data),
- EXPECTED_DATAFRAME,
+
+ self.assert_output_df_is_correct_and_input_is_untouched(
+ orig_df=df, expected_df=EXPECTED_DATAFRAME, chart_proto=proto
)
@parameterized.expand(
@@ -184,13 +404,14 @@ def test_arrow_chart_with_x_y(self, chart_command: Callable, altair_type: str):
(st._arrow_line_chart, "line"),
]
)
- def test_arrow_chart_with_x_y_sequence(
+ def test_arrow_chart_with_explicit_x_and_y_sequence(
self, chart_command: Callable, altair_type: str
):
- """Test x/y-sequence support for built-in charts."""
- df = pd.DataFrame([[20, 30, 50]], columns=["a", "b", "c"])
+ """Test support for explicit wide-format tables (i.e. y is a sequence)."""
+ df = pd.DataFrame([[20, 30, 50, 60]], columns=["a", "b", "c", "d"])
EXPECTED_DATAFRAME = pd.DataFrame(
- [[20, "b", 30], [20, "c", 50]], columns=["a", "variable", "value"]
+ [[20, "b", 30], [20, "c", 50]],
+ columns=["a", "color--p5bJXXpQgvPz6yvQMFiy", "value--p5bJXXpQgvPz6yvQMFiy"],
)
chart_command(df, x="a", y=["b", "c"])
@@ -200,11 +421,192 @@ def test_arrow_chart_with_x_y_sequence(
self.assertIn(chart_spec["mark"], [altair_type, {"type": altair_type}])
self.assertEqual(chart_spec["encoding"]["x"]["field"], "a")
- self.assertEqual(chart_spec["encoding"]["y"]["field"], "value")
+ self.assertEqual(
+ chart_spec["encoding"]["y"]["field"], "value--p5bJXXpQgvPz6yvQMFiy"
+ )
+ self.assertEqual(
+ chart_spec["encoding"]["color"]["field"], "color--p5bJXXpQgvPz6yvQMFiy"
+ )
- pd.testing.assert_frame_equal(
- bytes_to_data_frame(proto.datasets[0].data.data),
- EXPECTED_DATAFRAME,
+ self.assert_output_df_is_correct_and_input_is_untouched(
+ orig_df=df, expected_df=EXPECTED_DATAFRAME, chart_proto=proto
+ )
+
+ @parameterized.expand(
+ [
+ (st._arrow_area_chart, "area"),
+ (st._arrow_bar_chart, "bar"),
+ (st._arrow_line_chart, "line"),
+ ]
+ )
+ def test_arrow_chart_with_color_value(
+ self, chart_command: Callable, altair_type: str
+ ):
+ """Test color support for built-in charts."""
+ df = pd.DataFrame([[20, 30]], columns=["a", "b"])
+ EXPECTED_DATAFRAME = pd.DataFrame([[20, 30]], columns=["a", "b"])
+
+ chart_command(df, x="a", y="b", color="#f00")
+
+ proto = self.get_delta_from_queue().new_element.arrow_vega_lite_chart
+ chart_spec = json.loads(proto.spec)
+
+ self.assertEqual(chart_spec["encoding"]["color"]["value"], "#f00")
+
+ self.assert_output_df_is_correct_and_input_is_untouched(
+ orig_df=df, expected_df=EXPECTED_DATAFRAME, chart_proto=proto
+ )
+
+ @parameterized.expand(
+ [
+ (st._arrow_area_chart, "area"),
+ (st._arrow_bar_chart, "bar"),
+ (st._arrow_line_chart, "line"),
+ ]
+ )
+ def test_arrow_chart_with_color_column(
+ self, chart_command: Callable, altair_type: str
+ ):
+ """Test color support for built-in charts."""
+ df = pd.DataFrame(
+ {
+ "x": [0, 1, 2],
+ "y": [22, 21, 20],
+ "tuple3_int_color": [[255, 0, 0], [0, 255, 0], [0, 0, 255]],
+ "tuple4_int_int_color": [
+ [255, 0, 0, 51],
+ [0, 255, 0, 51],
+ [0, 0, 255, 51],
+ ],
+ "tuple4_int_float_color": [
+ [255, 0, 0, 0.2],
+ [0, 255, 0, 0.2],
+ [0, 0, 255, 0.2],
+ ],
+ "tuple3_float_color": [
+ [1.0, 0.0, 0.0],
+ [0.0, 1.0, 0.0],
+ [0.0, 0.0, 1.0],
+ ],
+ "tuple4_float_float_color": [
+ [1.0, 0.0, 0.0, 0.2],
+ [0.0, 1.0, 0.0, 0.2],
+ [0.0, 0.0, 1.0, 0.2],
+ ],
+ "hex3_color": ["#f00", "#0f0", "#00f"],
+ "hex4_color": ["#f008", "#0f08", "#00f8"],
+ "hex6_color": ["#ff0000", "#00ff00", "#0000ff"],
+ "hex8_color": ["#ff000088", "#00ff0088", "#0000ff88"],
+ }
+ )
+
+ color_columns = sorted(set(df.columns))
+ color_columns.remove("x")
+ color_columns.remove("y")
+
+ expected_values = pd.DataFrame(
+ {
+ "tuple3": ["rgb(255, 0, 0)", "rgb(0, 255, 0)", "rgb(0, 0, 255)"],
+ "tuple4": [
+ "rgba(255, 0, 0, 0.2)",
+ "rgba(0, 255, 0, 0.2)",
+ "rgba(0, 0, 255, 0.2)",
+ ],
+ "hex3": ["#f00", "#0f0", "#00f"],
+ "hex6": ["#ff0000", "#00ff00", "#0000ff"],
+ "hex4": ["#f008", "#0f08", "#00f8"],
+ "hex8": ["#ff000088", "#00ff0088", "#0000ff88"],
+ }
+ )
+
+ def get_expected_color_values(col_name):
+ for prefix, expected_color_values in expected_values.items():
+ if col_name.startswith(prefix):
+ return expected_color_values
+
+ for color_column in color_columns:
+ expected_color_values = get_expected_color_values(color_column)
+
+ chart_command(df, x="x", y="y", color=color_column)
+
+ proto = self.get_delta_from_queue().new_element.arrow_vega_lite_chart
+ chart_spec = json.loads(proto.spec)
+
+ self.assertEqual(chart_spec["encoding"]["color"]["field"], color_column)
+
+ # Manually-specified colors should not have a legend
+ self.assertEqual(chart_spec["encoding"]["color"]["legend"], None)
+
+ # Manually-specified colors are set via the color scale's range property.
+ self.assertTrue(chart_spec["encoding"]["color"]["scale"]["range"])
+
+ proto_df = bytes_to_data_frame(proto.datasets[0].data.data)
+
+ pd.testing.assert_series_equal(
+ proto_df[color_column],
+ expected_color_values,
+ check_names=False,
+ )
+
+ @parameterized.expand(
+ [
+ (st._arrow_area_chart, "area"),
+ (st._arrow_bar_chart, "bar"),
+ (st._arrow_line_chart, "line"),
+ ]
+ )
+ def test_arrow_chart_with_explicit_x_plus_y_and_color_sequence(
+ self, chart_command: Callable, altair_type: str
+ ):
+ """Test color support for built-in charts with wide-format table."""
+ df = pd.DataFrame([[20, 30, 50]], columns=["a", "b", "c"])
+
+ EXPECTED_DATAFRAME = pd.DataFrame(
+ [[20, "b", 30], [20, "c", 50]],
+ columns=["a", "color--p5bJXXpQgvPz6yvQMFiy", "value--p5bJXXpQgvPz6yvQMFiy"],
+ )
+
+ chart_command(df, x="a", y=["b", "c"], color=["#f00", "#0ff"])
+
+ proto = self.get_delta_from_queue().new_element.arrow_vega_lite_chart
+ chart_spec = json.loads(proto.spec)
+
+ self.assertIn(chart_spec["mark"], [altair_type, {"type": altair_type}])
+
+ # Color should be set to the melted column name.
+ self.assertEqual(
+ chart_spec["encoding"]["color"]["field"], "color--p5bJXXpQgvPz6yvQMFiy"
+ )
+
+ # Automatically-specified colors should have no legend title.
+ self.assertEqual(chart_spec["encoding"]["color"]["title"], " ")
+
+ # Automatically-specified colors should have a legend
+ self.assertNotEqual(chart_spec["encoding"]["color"]["legend"], None)
+
+ self.assert_output_df_is_correct_and_input_is_untouched(
+ orig_df=df, expected_df=EXPECTED_DATAFRAME, chart_proto=proto
+ )
+
+ @parameterized.expand(
+ [[None], [[]], [tuple()]],
+ )
+ def test_arrow_chart_with_empty_color(self, color_arg: Any):
+ """Test color support for built-in charts with wide-format table."""
+ df = pd.DataFrame([[20, 30, 50]], columns=["a", "b", "c"])
+
+ EXPECTED_DATAFRAME = pd.DataFrame([[20, 30]], columns=["a", "b"])
+
+ st.line_chart(df, x="a", y="b", color=color_arg)
+
+ proto = self.get_delta_from_queue().new_element.arrow_vega_lite_chart
+ chart_spec = json.loads(proto.spec)
+
+ # Color should not be present in the spec when the color arg is empty.
+ self.assertIsNone(chart_spec["encoding"].get("color"))
+
+ self.assert_output_df_is_correct_and_input_is_untouched(
+ orig_df=df, expected_df=EXPECTED_DATAFRAME, chart_proto=proto
)
@parameterized.expand(
@@ -247,19 +649,19 @@ def test_arrow_chart_with_x_y_on_sliced_data(
self.assertEqual(chart_spec["encoding"]["x"]["field"], "a")
self.assertEqual(chart_spec["encoding"]["y"]["field"], "b")
- pd.testing.assert_frame_equal(
- bytes_to_data_frame(proto.datasets[0].data.data),
- EXPECTED_DATAFRAME,
+ self.assert_output_df_is_correct_and_input_is_untouched(
+ orig_df=df, expected_df=EXPECTED_DATAFRAME, chart_proto=proto
)
- def test_arrow_line_chart_with_generic_index(self):
- """Test st._arrow_line_chart with a generic index."""
+ def test_arrow_line_chart_with_named_index(self):
+ """Test st._arrow_line_chart with a named index."""
df = pd.DataFrame([[20, 30, 50]], columns=["a", "b", "c"])
df.set_index("a", inplace=True)
+
EXPECTED_DATAFRAME = pd.DataFrame(
[[20, "b", 30], [20, "c", 50]],
index=[0, 1],
- columns=["a", "variable", "value"],
+ columns=["a", "color--p5bJXXpQgvPz6yvQMFiy", "value--p5bJXXpQgvPz6yvQMFiy"],
)
st._arrow_line_chart(df)
@@ -267,48 +669,84 @@ def test_arrow_line_chart_with_generic_index(self):
proto = self.get_delta_from_queue().new_element.arrow_vega_lite_chart
chart_spec = json.loads(proto.spec)
self.assertIn(chart_spec["mark"], ["line", {"type": "line"}])
- pd.testing.assert_frame_equal(
- bytes_to_data_frame(proto.datasets[0].data.data),
- EXPECTED_DATAFRAME,
+
+ self.assert_output_df_is_correct_and_input_is_untouched(
+ orig_df=df, expected_df=EXPECTED_DATAFRAME, chart_proto=proto
)
- def test_arrow_area_chart(self):
- """Test st._arrow_area_chart."""
- df = pd.DataFrame([[20, 30, 50]], columns=["a", "b", "c"])
- EXPECTED_DATAFRAME = pd.DataFrame(
- [[0, "a", 20], [0, "b", 30], [0, "c", 50]],
- index=[0, 1, 2],
- columns=["index", "variable", "value"],
+ @parameterized.expand(
+ [
+ (st._arrow_area_chart, "area"),
+ (st._arrow_bar_chart, "bar"),
+ (st._arrow_line_chart, "line"),
+ ]
+ )
+ def test_unused_columns_are_dropped(
+ self, chart_command: Callable, altair_type: str
+ ):
+ """Test built-in charts drop columns that are not used."""
+
+ df = pd.DataFrame(
+ [[5, 10, 20, 30, 35, 40, 50, 60]],
+ columns=["z", "a", "b", "c", "x", "d", "e", "f"],
)
- st._arrow_area_chart(df)
+ chart_command(df, x="a", y="c", color="d")
+
+ EXPECTED_DATAFRAME = pd.DataFrame([[10, 40, 30]], columns=["a", "d", "c"])
proto = self.get_delta_from_queue().new_element.arrow_vega_lite_chart
- chart_spec = json.loads(proto.spec)
- self.assertIn(chart_spec["mark"], ["area", {"type": "area"}])
- pd.testing.assert_frame_equal(
- bytes_to_data_frame(proto.datasets[0].data.data),
- EXPECTED_DATAFRAME,
+ json.loads(proto.spec)
+
+ self.assert_output_df_is_correct_and_input_is_untouched(
+ orig_df=df, expected_df=EXPECTED_DATAFRAME, chart_proto=proto
)
- def test_arrow_bar_chart(self):
- """Test st._arrow_bar_chart."""
+ @parameterized.expand(
+ [
+ (st._arrow_area_chart, "area"),
+ (st._arrow_bar_chart, "bar"),
+ (st._arrow_line_chart, "line"),
+ ]
+ )
+ def test_arrow_chart_with_bad_color_arg(
+ self, chart_command: Callable, altair_type: str
+ ):
+ """Test that we throw a pretty exception when colors arg is wrong."""
df = pd.DataFrame([[20, 30, 50]], columns=["a", "b", "c"])
- EXPECTED_DATAFRAME = pd.DataFrame(
- [[0, "a", 20], [0, "b", 30], [0, "c", 50]],
- index=[0, 1, 2],
- columns=["index", "variable", "value"],
- )
- st._arrow_bar_chart(df, width=640, height=480)
+ too_few_args = ["#f00", ["#f00"], (1, 0, 0, 0.5)]
+ too_many_args = [["#f00", "#0ff"], [(1, 0, 0), (0, 0, 1)]]
+ bad_args = ["foo", "blue"]
- proto = self.get_delta_from_queue().new_element.arrow_vega_lite_chart
- chart_spec = json.loads(proto.spec)
+ for color_arg in too_few_args:
+ with self.assertRaises(StreamlitAPIException) as exc:
+ chart_command(df, y=["a", "b"], color=color_arg)
- self.assertIn(chart_spec["mark"], ["bar", {"type": "bar"}])
- self.assertEqual(chart_spec["width"], 640)
- self.assertEqual(chart_spec["height"], 480)
- pd.testing.assert_frame_equal(
- bytes_to_data_frame(proto.datasets[0].data.data),
- EXPECTED_DATAFRAME,
- )
+ self.assertTrue("The list of colors" in str(exc.exception))
+
+ for color_arg in too_many_args:
+ with self.assertRaises(StreamlitAPIException) as exc:
+ chart_command(df, y="a", color=color_arg)
+
+ self.assertTrue("The list of colors" in str(exc.exception))
+
+ for color_arg in bad_args:
+ with self.assertRaises(StreamlitAPIException) as exc:
+ chart_command(df, y="a", color=color_arg)
+
+ self.assertTrue(
+ "This does not look like a valid color argument" in str(exc.exception)
+ )
+
+ def assert_output_df_is_correct_and_input_is_untouched(
+ self, orig_df, expected_df, chart_proto
+ ):
+ """Test that when we modify the outgoing DF we don't mutate the input DF."""
+ output_df = bytes_to_data_frame(chart_proto.datasets[0].data.data)
+
+ self.assertNotEqual(id(orig_df), id(output_df))
+ self.assertNotEqual(id(orig_df), id(expected_df))
+ self.assertNotEqual(id(output_df), id(expected_df))
+
+ pd.testing.assert_frame_equal(output_df, expected_df)
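For readers less familiar with Altair: the assertions about manual colors (`legend` is None, colors land in the scale's `range`) correspond to an encoding like the following sketch; the column and data are invented for illustration:

```python
import altair as alt
import pandas as pd

df = pd.DataFrame({"x": [0, 1, 2], "y": [2, 1, 0], "c": ["#f00", "#0f0", "#00f"]})

chart = (
    alt.Chart(df)
    .mark_line()
    .encode(
        x="x",
        y="y",
        # Manual colors: feed them through the scale's range and hide the
        # legend, which is exactly what the tests above assert on.
        color=alt.Color("c", scale=alt.Scale(range=df["c"].tolist()), legend=None),
    )
)
```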
diff --git a/lib/tests/streamlit/elements/chat_test.py b/lib/tests/streamlit/elements/chat_test.py
index 0714c8a61ee4..05b5e5279dc6 100644
--- a/lib/tests/streamlit/elements/chat_test.py
+++ b/lib/tests/streamlit/elements/chat_test.py
@@ -73,6 +73,38 @@ def test_assistant_message(self):
BlockProto.ChatMessage.AvatarType.ICON,
)
+ def test_ai_message(self):
+ """Test that the ai preset is mapped to assistant avatar."""
+ message = st.chat_message("ai")
+
+ with message:
+ pass
+
+ message_block = self.get_delta_from_queue()
+
+ self.assertEqual(message_block.add_block.chat_message.name, "ai")
+ self.assertEqual(message_block.add_block.chat_message.avatar, "assistant")
+ self.assertEqual(
+ message_block.add_block.chat_message.avatar_type,
+ BlockProto.ChatMessage.AvatarType.ICON,
+ )
+
+ def test_human_message(self):
+ """Test that the human preset is mapped to user avatar."""
+ message = st.chat_message("human")
+
+ with message:
+ pass
+
+ message_block = self.get_delta_from_queue()
+
+ self.assertEqual(message_block.add_block.chat_message.name, "human")
+ self.assertEqual(message_block.add_block.chat_message.avatar, "user")
+ self.assertEqual(
+ message_block.add_block.chat_message.avatar_type,
+ BlockProto.ChatMessage.AvatarType.ICON,
+ )
+
def test_emoji_avatar(self):
"""Test that it is possible to set an emoji as avatar."""
diff --git a/lib/tests/streamlit/elements/checkbox_test.py b/lib/tests/streamlit/elements/checkbox_test.py
index 422a5e6de21a..1646d92877bc 100644
--- a/lib/tests/streamlit/elements/checkbox_test.py
+++ b/lib/tests/streamlit/elements/checkbox_test.py
@@ -20,6 +20,7 @@
import streamlit as st
from streamlit.errors import StreamlitAPIException
+from streamlit.proto.Checkbox_pb2 import Checkbox as CheckboxProto
from streamlit.proto.LabelVisibilityMessage_pb2 import LabelVisibilityMessage
from streamlit.type_util import _LOGGER
from tests.delta_generator_test_case import DeltaGeneratorTestCase
@@ -44,6 +45,7 @@ def test_just_label(self):
c.label_visibility.value,
LabelVisibilityMessage.LabelVisibilityOptions.VISIBLE,
)
+ self.assertEqual(c.type, CheckboxProto.StyleType.DEFAULT)
def test_just_disabled(self):
"""Test that it can be called with disabled param."""
@@ -141,3 +143,17 @@ def test_empty_label_warning(self):
"`label` got an empty value. This is discouraged for accessibility reasons",
logs.records[0].msg,
)
+
+ def test_toggle_widget(self):
+ """Test that the usage of `st.toggle` uses the correct checkbox proto config."""
+ st.toggle("the label")
+
+ c = self.get_delta_from_queue().new_element.checkbox
+ self.assertEqual(c.label, "the label")
+ self.assertEqual(c.default, False)
+ self.assertEqual(c.disabled, False)
+ self.assertEqual(
+ c.label_visibility.value,
+ LabelVisibilityMessage.LabelVisibilityOptions.VISIBLE,
+ )
+ self.assertEqual(c.type, CheckboxProto.StyleType.TOGGLE)
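As the proto config shows, `st.toggle` is a checkbox rendered with `StyleType.TOGGLE`, so it shares the checkbox API one-for-one (a usage sketch):

```python
import streamlit as st

# Same signature as st.checkbox; only the rendering differs.
enabled = st.toggle("Enable experimental mode", value=False)
if enabled:
    st.write("Experimental mode is on.")
```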
diff --git a/lib/tests/streamlit/elements/download_button_test.py b/lib/tests/streamlit/elements/download_button_test.py
index bd358d23d87e..01e14b621a0d 100644
--- a/lib/tests/streamlit/elements/download_button_test.py
+++ b/lib/tests/streamlit/elements/download_button_test.py
@@ -30,6 +30,7 @@ def test_just_label(self, data):
c = self.get_delta_from_queue().new_element.download_button
self.assertEqual(c.label, "the label")
+ self.assertEqual(c.type, "secondary")
self.assertEqual(c.disabled, False)
def test_just_disabled(self):
@@ -46,6 +47,13 @@ def test_url_exist(self):
c = self.get_delta_from_queue().new_element.download_button
self.assertTrue("/media/" in c.url)
+ def test_type(self):
+ """Test that it can be called with type param."""
+ st.download_button("the label", data="Streamlit", type="primary")
+
+ c = self.get_delta_from_queue().new_element.download_button
+ self.assertEqual(c.type, "primary")
+
def test_use_container_width_can_be_set_to_true(self):
"""Test use_container_width can be set to true."""
st.download_button("the label", data="juststring", use_container_width=True)
diff --git a/lib/tests/streamlit/elements/file_uploader_test.py b/lib/tests/streamlit/elements/file_uploader_test.py
index cccd451d32b9..a8346e0f1e6b 100644
--- a/lib/tests/streamlit/elements/file_uploader_test.py
+++ b/lib/tests/streamlit/elements/file_uploader_test.py
@@ -21,9 +21,13 @@
import streamlit as st
from streamlit import config
from streamlit.errors import StreamlitAPIException
+from streamlit.proto.Common_pb2 import FileURLs as FileURLsProto
from streamlit.proto.LabelVisibilityMessage_pb2 import LabelVisibilityMessage
-from streamlit.runtime.scriptrunner import get_script_run_ctx
-from streamlit.runtime.uploaded_file_manager import UploadedFile, UploadedFileRec
+from streamlit.runtime.uploaded_file_manager import (
+ DeletedFile,
+ UploadedFile,
+ UploadedFileRec,
+)
from tests.delta_generator_test_case import DeltaGeneratorTestCase
@@ -80,16 +84,23 @@ def test_uppercase_expansion(self):
c = self.get_delta_from_queue().new_element.file_uploader
self.assertEqual(c.type, [".png", ".jpg", ".jpeg"])
- @patch("streamlit.elements.widgets.file_uploader._get_file_recs")
- def test_multiple_files(self, get_file_recs_patch):
+ @patch("streamlit.elements.widgets.file_uploader._get_upload_files")
+ def test_multiple_files(self, get_upload_files_patch):
"""Test the accept_multiple_files flag"""
# Patch UploadFileManager to return two files
- file_recs = [
- UploadedFileRec(1, "file1", "type", b"123"),
- UploadedFileRec(2, "file2", "type", b"456"),
+ rec1 = UploadedFileRec("file1", "file1", "type", b"123")
+ rec2 = UploadedFileRec("file2", "file2", "type", b"456")
+
+ uploaded_files = [
+ UploadedFile(
+ rec1, FileURLsProto(file_id="file1", delete_url="d1", upload_url="u1")
+ ),
+ UploadedFile(
+ rec2, FileURLsProto(file_id="file2", delete_url="d1", upload_url="u1")
+ ),
]
- get_file_recs_patch.return_value = file_recs
+ get_upload_files_patch.return_value = uploaded_files
for accept_multiple in [True, False]:
return_val = st.file_uploader(
@@ -101,25 +112,21 @@ def test_multiple_files(self, get_file_recs_patch):
# If "accept_multiple_files" is True, then we should get a list of
# values back. Otherwise, we should just get a single value.
- # Because file_uploader returns unique UploadedFile instances
- # each time it's called, we convert the return value back
- # from UploadedFile -> UploadedFileRec (which implements
- # equals()) to test equality.
-
if accept_multiple:
- results = [
- UploadedFileRec(file.id, file.name, file.type, file.getvalue())
- for file in return_val
- ]
- self.assertEqual(file_recs, results)
+ self.assertEqual(return_val, uploaded_files)
+
+ for actual, expected in zip(return_val, uploaded_files):
+ self.assertEqual(actual.name, expected.name)
+ self.assertEqual(actual.type, expected.type)
+ self.assertEqual(actual.size, expected.size)
+ self.assertEqual(actual.getvalue(), expected.getvalue())
else:
- results = UploadedFileRec(
- return_val.id,
- return_val.name,
- return_val.type,
- return_val.getvalue(),
- )
- self.assertEqual(file_recs[0], results)
+ first_uploaded_file = uploaded_files[0]
+ self.assertEqual(return_val, first_uploaded_file)
+ self.assertEqual(return_val.name, first_uploaded_file.name)
+ self.assertEqual(return_val.type, first_uploaded_file.type)
+ self.assertEqual(return_val.size, first_uploaded_file.size)
+ self.assertEqual(return_val.getvalue(), first_uploaded_file.getvalue())
def test_max_upload_size_mb(self):
"""Test that the max upload size is the configuration value."""
@@ -130,18 +137,25 @@ def test_max_upload_size_mb(self):
c.max_upload_size_mb, config.get_option("server.maxUploadSize")
)
- @patch("streamlit.elements.widgets.file_uploader._get_file_recs")
- def test_unique_uploaded_file_instance(self, get_file_recs_patch):
+ @patch("streamlit.elements.widgets.file_uploader._get_upload_files")
+ def test_unique_uploaded_file_instance(self, get_upload_files_patch):
"""We should get a unique UploadedFile instance each time we access
the file_uploader widget."""
# Patch UploadFileManager to return two files
- file_recs = [
- UploadedFileRec(1, "file1", "type", b"123"),
- UploadedFileRec(2, "file2", "type", b"456"),
+ rec1 = UploadedFileRec("file1", "file1", "type", b"123")
+ rec2 = UploadedFileRec("file2", "file2", "type", b"456")
+
+ uploaded_files = [
+ UploadedFile(
+ rec1, FileURLsProto(file_id="file1", delete_url="d1", upload_url="u1")
+ ),
+ UploadedFile(
+ rec2, FileURLsProto(file_id="file2", delete_url="d1", upload_url="u1")
+ ),
]
- get_file_recs_patch.return_value = file_recs
+ get_upload_files_patch.return_value = uploaded_files
# These file_uploaders have different labels so that we don't cause
# a DuplicateKey error - but because we're patching the get_files
@@ -156,40 +170,32 @@ def test_unique_uploaded_file_instance(self, get_file_recs_patch):
self.assertEqual(b"3", file1.read())
self.assertEqual(b"123", file2.read())
- @patch(
- "streamlit.runtime.uploaded_file_manager.UploadedFileManager.remove_orphaned_files"
- )
- @patch("streamlit.elements.widgets.file_uploader._get_file_recs")
- def test_remove_orphaned_files(
- self, get_file_recs_patch, remove_orphaned_files_patch
- ):
- """When file_uploader is accessed, it should call
- UploadedFileManager.remove_orphaned_files.
- """
- ctx = get_script_run_ctx()
- ctx.uploaded_file_mgr._file_id_counter = 101
-
- file_recs = [
- UploadedFileRec(1, "file1", "type", b"123"),
- UploadedFileRec(2, "file2", "type", b"456"),
+ @patch("streamlit.elements.widgets.file_uploader._get_upload_files")
+ def test_deleted_files_filtered_out(self, get_upload_files_patch):
+ """We should filter out DeletedFile objects for final user value."""
+
+ rec1 = UploadedFileRec("file1", "file1", "type", b"1234")
+ rec2 = UploadedFileRec("file2", "file2", "type", b"5678")
+
+ uploaded_files = [
+ DeletedFile(file_id="a"),
+ UploadedFile(
+ rec1, FileURLsProto(file_id="file1", delete_url="d1", upload_url="u1")
+ ),
+ DeletedFile(file_id="b"),
+ UploadedFile(
+ rec2, FileURLsProto(file_id="file2", delete_url="d1", upload_url="u1")
+ ),
+ DeletedFile(file_id="c"),
]
- get_file_recs_patch.return_value = file_recs
-
- st.file_uploader("foo", accept_multiple_files=True)
- args, kwargs = remove_orphaned_files_patch.call_args
- self.assertEqual(len(args), 0)
- self.assertEqual(kwargs["session_id"], "test session id")
- self.assertEqual(kwargs["newest_file_id"], 100)
- self.assertEqual(kwargs["active_file_ids"], [1, 2])
+ get_upload_files_patch.return_value = uploaded_files
- # Patch _get_file_recs to return [] instead. remove_orphaned_files
- # should not be called when file_uploader is accessed.
- get_file_recs_patch.return_value = []
- remove_orphaned_files_patch.reset_mock()
+ result_1: UploadedFile = st.file_uploader("a", accept_multiple_files=False)
+ result_2: UploadedFile = st.file_uploader("b", accept_multiple_files=True)
- st.file_uploader("foo")
- remove_orphaned_files_patch.assert_not_called()
+ self.assertEqual(result_1, None)
+ self.assertEqual(result_2, [uploaded_files[1], uploaded_files[3]])
@parameterized.expand(
[
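The rewritten uploader tests reflect the new upload protocol: widget code calls _get_upload_files and receives UploadedFile objects that wrap an UploadedFileRec together with its presigned FileURLs, while files deleted mid-session surface as DeletedFile markers and are filtered out of the user-facing value. A minimal sketch of constructing one of these objects, mirroring the tests:

    from streamlit.proto.Common_pb2 import FileURLs
    from streamlit.runtime.uploaded_file_manager import UploadedFile, UploadedFileRec

    rec = UploadedFileRec("file1", "file1", "text/plain", b"hello")
    uploaded = UploadedFile(
        rec, FileURLs(file_id="file1", upload_url="u1", delete_url="d1")
    )
    assert uploaded.read() == b"hello"  # UploadedFile stays file-like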
diff --git a/lib/tests/streamlit/elements/heading_test.py b/lib/tests/streamlit/elements/heading_test.py
index 54606a1956bf..5387be6b63fb 100644
--- a/lib/tests/streamlit/elements/heading_test.py
+++ b/lib/tests/streamlit/elements/heading_test.py
@@ -30,6 +30,7 @@ def test_st_header(self):
self.assertEqual(el.heading.body, "some header")
self.assertEqual(el.heading.tag, "h2")
self.assertFalse(el.heading.hide_anchor, False)
+ self.assertFalse(el.heading.divider)
def test_st_header_with_anchor(self):
"""Test st.header with anchor."""
@@ -40,6 +41,7 @@ def test_st_header_with_anchor(self):
self.assertEqual(el.heading.tag, "h2")
self.assertEqual(el.heading.anchor, "some-anchor")
self.assertFalse(el.heading.hide_anchor, False)
+ self.assertFalse(el.heading.divider)
def test_st_header_with_hidden_anchor(self):
"""Test st.header with hidden anchor."""
@@ -50,6 +52,7 @@ def test_st_header_with_hidden_anchor(self):
self.assertEqual(el.heading.tag, "h2")
self.assertEqual(el.heading.anchor, "")
self.assertTrue(el.heading.hide_anchor, True)
+ self.assertFalse(el.heading.divider)
def test_st_header_with_invalid_anchor(self):
"""Test st.header with invalid anchor."""
@@ -63,6 +66,32 @@ def test_st_header_with_help(self):
self.assertEqual(el.heading.body, "some header")
self.assertEqual(el.heading.tag, "h2")
self.assertEqual(el.heading.help, "help text")
+ self.assertFalse(el.heading.divider)
+
+ def test_st_header_with_divider_true(self):
+ """Test st.header with divider True."""
+ st.header("some header", divider=True)
+
+ el = self.get_delta_from_queue().new_element
+ self.assertEqual(el.heading.body, "some header")
+ self.assertEqual(el.heading.tag, "h2")
+        self.assertFalse(el.heading.hide_anchor)
+ self.assertEqual(el.heading.divider, "auto")
+
+ def test_st_header_with_divider_color(self):
+ """Test st.header with divider color."""
+ st.header("some header", divider="blue")
+
+ el = self.get_delta_from_queue().new_element
+ self.assertEqual(el.heading.body, "some header")
+ self.assertEqual(el.heading.tag, "h2")
+        self.assertFalse(el.heading.hide_anchor)
+ self.assertEqual(el.heading.divider, "blue")
+
+ def test_st_header_with_invalid_divider(self):
+ """Test st.header with invalid divider."""
+ with pytest.raises(StreamlitAPIException):
+ st.header("some header", divider="corgi")
class StSubheaderTest(DeltaGeneratorTestCase):
@@ -76,6 +105,7 @@ def test_st_subheader(self):
self.assertEqual(el.heading.body, "some subheader")
self.assertEqual(el.heading.tag, "h3")
self.assertFalse(el.heading.hide_anchor)
+ self.assertFalse(el.heading.divider)
def test_st_subheader_with_anchor(self):
"""Test st.subheader with anchor."""
@@ -86,6 +116,7 @@ def test_st_subheader_with_anchor(self):
self.assertEqual(el.heading.tag, "h3")
self.assertEqual(el.heading.anchor, "some-anchor")
self.assertFalse(el.heading.hide_anchor)
+ self.assertFalse(el.heading.divider)
def test_st_subheader_with_hidden_anchor(self):
"""Test st.subheader with hidden anchor."""
@@ -96,6 +127,7 @@ def test_st_subheader_with_hidden_anchor(self):
self.assertEqual(el.heading.tag, "h3")
self.assertEqual(el.heading.anchor, "")
self.assertTrue(el.heading.hide_anchor, True)
+ self.assertFalse(el.heading.divider)
def test_st_subheader_with_invalid_anchor(self):
"""Test st.subheader with invalid anchor."""
@@ -109,10 +141,36 @@ def test_st_subheader_with_help(self):
self.assertEqual(el.heading.body, "some subheader")
self.assertEqual(el.heading.tag, "h3")
self.assertEqual(el.heading.help, "help text")
+ self.assertFalse(el.heading.divider)
+
+ def test_st_subheader_with_divider_true(self):
+ """Test st.subheader with divider True."""
+ st.subheader("some subheader", divider=True)
+
+ el = self.get_delta_from_queue().new_element
+ self.assertEqual(el.heading.body, "some subheader")
+ self.assertEqual(el.heading.tag, "h3")
+ self.assertFalse(el.heading.hide_anchor)
+ self.assertEqual(el.heading.divider, "auto")
+
+ def test_st_subheader_with_divider_color(self):
+ """Test st.subheader with divider color."""
+ st.subheader("some subheader", divider="blue")
+
+ el = self.get_delta_from_queue().new_element
+ self.assertEqual(el.heading.body, "some subheader")
+ self.assertEqual(el.heading.tag, "h3")
+ self.assertFalse(el.heading.hide_anchor)
+ self.assertEqual(el.heading.divider, "blue")
+
+ def test_st_subheader_with_invalid_divider(self):
+ """Test st.subheader with invalid divider."""
+ with pytest.raises(StreamlitAPIException):
+ st.subheader("some header", divider="corgi")
class StTitleTest(DeltaGeneratorTestCase):
- """Test ability to marshall subheader protos."""
+ """Test ability to marshall title protos."""
def test_st_title(self):
"""Test st.title."""
@@ -122,6 +180,7 @@ def test_st_title(self):
self.assertEqual(el.heading.body, "some title")
self.assertEqual(el.heading.tag, "h1")
self.assertFalse(el.heading.hide_anchor)
+ self.assertFalse(el.heading.divider)
def test_st_title_with_anchor(self):
"""Test st.title with anchor."""
@@ -132,6 +191,7 @@ def test_st_title_with_anchor(self):
self.assertEqual(el.heading.tag, "h1")
self.assertEqual(el.heading.anchor, "some-anchor")
self.assertFalse(el.heading.hide_anchor)
+ self.assertFalse(el.heading.divider)
def test_st_title_with_hidden_anchor(self):
"""Test st.title with hidden anchor."""
@@ -142,6 +202,7 @@ def test_st_title_with_hidden_anchor(self):
self.assertEqual(el.heading.tag, "h1")
self.assertEqual(el.heading.anchor, "")
self.assertTrue(el.heading.hide_anchor)
+ self.assertFalse(el.heading.divider)
def test_st_title_with_invalid_anchor(self):
"""Test st.title with invalid anchor."""
@@ -162,3 +223,11 @@ def test_st_title_with_help(self):
self.assertEqual(el.heading.body, "some title")
self.assertEqual(el.heading.tag, "h1")
self.assertEqual(el.heading.help, "help text")
+ self.assertFalse(el.heading.divider)
+
+ def test_st_title_with_invalid_divider(self):
+ """Test st.title with invalid divider."""
+ with pytest.raises(TypeError):
+ st.title("some header", divider=True)
+ with pytest.raises(TypeError):
+ st.title("some header", divider="blue")
diff --git a/lib/tests/streamlit/elements/help_test.py b/lib/tests/streamlit/elements/help_test.py
index 99f305baaf4b..b95124a80d85 100644
--- a/lib/tests/streamlit/elements/help_test.py
+++ b/lib/tests/streamlit/elements/help_test.py
@@ -394,9 +394,6 @@ def test_constant_should_have_no_name(self):
actual = _get_variable_name_from_code_str(code)
self.assertEqual(actual, None)
- @pytest.mark.skipif(
- sys.version_info < (3, 8), reason="Walrus was introduced in Python 3.8"
- )
def test_walrus_should_return_var_name(self):
for st_call in st_calls:
# Wrap test in an st call.
diff --git a/lib/tests/streamlit/elements/legacy_altair_test.py b/lib/tests/streamlit/elements/legacy_altair_test.py
index 1ea8ef6a7b9d..a5469273164b 100644
--- a/lib/tests/streamlit/elements/legacy_altair_test.py
+++ b/lib/tests/streamlit/elements/legacy_altair_test.py
@@ -68,7 +68,7 @@ def test_date_column_utc_scale(self):
{"index": [date(2019, 8, 9), date(2019, 8, 10)], "numbers": [1, 10]}
).set_index("index")
- chart = altair.generate_chart("line", df)
+ chart, _ = altair.generate_chart("line", df)
st._legacy_altair_chart(chart)
c = self.get_delta_from_queue().new_element.vega_lite_chart
spec_dict = json.loads(c.spec)
diff --git a/lib/tests/streamlit/elements/map_test.py b/lib/tests/streamlit/elements/map_test.py
index 0deb02b19972..6a6cf0c0493e 100644
--- a/lib/tests/streamlit/elements/map_test.py
+++ b/lib/tests/streamlit/elements/map_test.py
@@ -28,11 +28,7 @@
from tests.streamlit import pyspark_mocks
from tests.streamlit.snowpark_mocks import DataFrame as MockedSnowparkDataFrame
from tests.streamlit.snowpark_mocks import Table as MockedSnowparkTable
-from tests.testutil import (
- create_snowpark_session,
- patch_config_options,
- should_skip_pyspark_tests,
-)
+from tests.testutil import create_snowpark_session, patch_config_options
df1 = pd.DataFrame({"lat": [1, 2, 3, 4], "lon": [10, 20, 30, 40]})
@@ -368,9 +364,6 @@ def test_unevaluated_snowpark_dataframe_integration(self):
"""Check if map data have 4 rows"""
self.assertEqual(len(c["layers"][0]["data"]), 4)
- @pytest.mark.skipif(
- should_skip_pyspark_tests(), reason="pyspark is incompatible with Python3.11"
- )
def test_pyspark_dataframe(self):
"""Test st.map with pyspark.sql.DataFrame"""
pyspark_map_dataframe = (
diff --git a/lib/tests/streamlit/elements/radio_test.py b/lib/tests/streamlit/elements/radio_test.py
index 3f8747517bc5..2d23c149806f 100644
--- a/lib/tests/streamlit/elements/radio_test.py
+++ b/lib/tests/streamlit/elements/radio_test.py
@@ -41,6 +41,7 @@ def test_just_label(self):
)
self.assertEqual(c.default, 0)
self.assertEqual(c.disabled, False)
+ self.assertEqual(c.captions, [])
def test_just_disabled(self):
"""Test that it can be called with disabled param."""
@@ -212,3 +213,25 @@ def test_label_visibility_wrong_value(self):
"Unsupported label_visibility option 'wrong_value'. Valid values are "
"'visible', 'hidden' or 'collapsed'.",
)
+
+ def test_no_captions(self):
+ """Test that it can be called with no captions."""
+ st.radio("the label", ("option1", "option2", "option3"), captions=None)
+
+ c = self.get_delta_from_queue().new_element.radio
+ self.assertEqual(c.label, "the label")
+ self.assertEqual(c.default, 0)
+ self.assertEqual(c.captions, [])
+
+ def test_some_captions(self):
+ """Test that it can be called with some captions."""
+ st.radio(
+ "the label",
+ ("option1", "option2", "option3", "option4"),
+ captions=("first caption", None, "", "last caption"),
+ )
+
+ c = self.get_delta_from_queue().new_element.radio
+ self.assertEqual(c.label, "the label")
+ self.assertEqual(c.default, 0)
+ self.assertEqual(c.captions, ["first caption", "", "", "last caption"])
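The captions tests show that the new captions param stays aligned with the options, with both None and "" serializing to an empty caption slot. A sketch:

    import streamlit as st

    st.radio(
        "Pick one",
        ("option1", "option2", "option3"),
        captions=("first caption", None, "last caption"),  # None -> ""
    )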
diff --git a/lib/tests/streamlit/elements/slider_test.py b/lib/tests/streamlit/elements/slider_test.py
index 672697f3c2a9..7bb1ddce6ae8 100644
--- a/lib/tests/streamlit/elements/slider_test.py
+++ b/lib/tests/streamlit/elements/slider_test.py
@@ -25,6 +25,7 @@
from streamlit.errors import StreamlitAPIException
from streamlit.js_number import JSNumber
from streamlit.proto.LabelVisibilityMessage_pb2 import LabelVisibilityMessage
+from streamlit.testing.script_interactions import InteractiveScriptTests
from tests.delta_generator_test_case import DeltaGeneratorTestCase
@@ -300,3 +301,20 @@ def test_label_visibility_wrong_value(self):
"Unsupported label_visibility option 'wrong_value'. Valid values are "
"'visible', 'hidden' or 'collapsed'.",
)
+
+
+class SliderInteractiveTest(InteractiveScriptTests):
+ def test_id_stability(self):
+ script = self.script_from_string(
+ """
+ import streamlit as st
+
+ st.slider("slider", key="slider")
+ """
+ )
+ sr = script.run()
+ s1 = sr.slider[0]
+ sr2 = s1.set_value(5).run()
+ s2 = sr2.slider[0]
+
+ assert s1.id == s2.id
diff --git a/lib/tests/streamlit/external/langchain/streamlit_callback_handler_test.py b/lib/tests/streamlit/external/langchain/streamlit_callback_handler_test.py
index 812497ad166b..e82049ba2d57 100644
--- a/lib/tests/streamlit/external/langchain/streamlit_callback_handler_test.py
+++ b/lib/tests/streamlit/external/langchain/streamlit_callback_handler_test.py
@@ -95,32 +95,19 @@ def test_agent_run(self):
expected_deltas = [
{'addBlock': {}},
{'addBlock': {}},
- {'addBlock': {'expandable': {'label': '🤔 **Thinking...**', 'expanded': True}, 'allowEmpty': True}},
+ {'addBlock': {'expandable': {'label': 'Thinking...', 'expanded': True, 'icon': 'spinner'}, 'allowEmpty': True}},
{'newElement': {'markdown': {'body': 'I need to find out the artist\'s full name and then search the FooBar database for their albums. \nAction: Search \nAction Input: "The Storm Before the Calm" artist', 'elementType': 'NATIVE'}}},
- {'addBlock': {'expandable': {'label': '🤔 **Search:** The Storm Before the Calm" artist', 'expanded': True}, 'allowEmpty': True}},
- {'newElement': {'markdown': {'body': '**Alanis Morissette**', 'elementType': 'NATIVE'}}},
- {'addBlock': {'expandable': {'label': '✅ **Search:** The Storm Before the Calm" artist'}, 'allowEmpty': True}},
- {'addBlock': {'expandable': {'label': '🤔 **Thinking...**', 'expanded': True}, 'allowEmpty': True}},
- {'newElement': {'markdown': {'body': "I now need to search the FooBar database for Alanis Morissette's albums. \nAction: FooBar DB \nAction Input: What albums of Alanis Morissette are in the FooBar database?", 'elementType': 'NATIVE'}}},
- {'addBlock': {'expandable': {'label': '🤔 **FooBar DB:** What albums of Alanis Morissette are in the FooBar database?', 'expanded': True}, 'allowEmpty': True}},
+ {'addBlock': {'expandable': {'label': '**Search:** The Storm Before the Calm" artist', 'icon': 'spinner'}, 'allowEmpty': True}},
+ {'newElement': {'markdown': {'body': 'Art Film Music Theater TV "Storm Before the Calm" Brings Climate Dystopia to Praz-Delavallade Praz-Delavallade Ricky Amadour Oct 8, 2022 Storm Before the Calm at Praz-Delavallade Los... Alanis Morissette The Storm Before The Calm on Collectors\' Choice Music The Storm Before The Calm CD Artist: Alanis Morissette Genre: Pop Release Date: 8/26/2022 Qty: Add to Cart List Price: $16.98 Price: $14.43 You Save: $2.55 (15%) Add to Wish List Product Description 2022 release. Choose your favorite the calm before the storm paintings from 176 available designs. All the calm before the storm paintings ship within 48 hours and include a 30-day money-back guarantee. ... Calm Before the Storm Painting. Vanaja\'s Fine-Art. $35. $28. More from This Artist Similar Designs. Storm Before the Calm Painting. Lorie McClung. $22. $18. Choose your favorite calm before the storm paintings from 178 available designs. All calm before the storm paintings ship within 48 hours and include a 30-day money-back guarantee. ... Calm Before the Storm Painting. Vanaja\'s Fine-Art. $35. More from This Artist Similar Designs. Calm Before the Storm in Imagination Harbor Painting. Katheryn ... the storm before the calm. Alanis Morissette. 11 SONGS • 1 HOUR AND 46 MINUTES • JUN 17 2022. Purchase Options. 1. light—the lightworker\'s lament. 05:28. 2. heart—power of a soft heart.', 'elementType': 'NATIVE'}}},
+    {'addBlock': {'expandable': {'label': '**Search:** The Storm Before the Calm" artist', 'expanded': False, 'icon': 'check'}, 'allowEmpty': True}},
+    {'addBlock': {'expandable': {'label': 'Thinking...', 'expanded': True, 'icon': 'spinner'}, 'allowEmpty': True}},
+ {'newElement': {'markdown': {'body': "I now know the artist's full name is Alanis Morissette. \nAction: FooBar DB \nAction Input: What albums of Alanis Morissette are in the FooBar database?", 'elementType': 'NATIVE'}}},
+ {'addBlock': {'expandable': {'label': '**FooBar DB:** What albums of Alanis Morissette are in the FooBar database?', 'icon': 'spinner'}, 'allowEmpty': True}},
{'newElement': {'markdown': {'body': 'SELECT "Title" FROM "Album" INNER JOIN "Artist" ON "Album"."ArtistId" = "Artist"."ArtistId" WHERE "Name" = \'Alanis Morissette\' LIMIT 5;', 'elementType': 'NATIVE'}}},
{'newElement': {'markdown': {'body': 'The albums of Alanis Morissette in the FooBar database are Jagged Little Pill.', 'elementType': 'NATIVE'}}},
- {'newElement': {'markdown': {'body': '**The albums of Alanis Morissette in the FooBar database are Jagged Little Pill.**', 'elementType': 'NATIVE'}}},
- {'addBlock': {'expandable': {'label': '✅ **FooBar DB:** What albums of Alanis Morissette are in the FooBar database?'}, 'allowEmpty': True}},
- {'addBlock': {'expandable': {'label': '🤔 **Thinking...**', 'expanded': True}, 'allowEmpty': True}},
- {'newElement': {'markdown': {'body': "I now know the final answer. \nFinal Answer: The artist who recently released an album called 'The Storm Before the Calm' is Alanis Morissette and the albums of hers in the FooBar database are Jagged Little Pill.", 'elementType': 'NATIVE'}}},
- {'addBlock': {'expandable': {'label': '📚 **History**'}, 'allowEmpty': True}},
- {'newElement': {'markdown': {'body': '✅ **Search:** The Storm Before the Calm" artist', 'elementType': 'NATIVE'}}},
- {'newElement': {'markdown': {'body': 'I need to find out the artist\'s full name and then search the FooBar database for their albums. \nAction: Search \nAction Input: "The Storm Before the Calm" artist', 'elementType': 'NATIVE'}}},
- {'newElement': {'markdown': {'body': '**Alanis Morissette**', 'elementType': 'NATIVE'}}},
- {'newElement': {'empty': {}}},
- {'newElement': {'markdown': {'body': '✅ **FooBar DB:** What albums of Alanis Morissette are in the FooBar database?', 'elementType': 'NATIVE'}}},
- {'newElement': {'markdown': {'body': "I now need to search the FooBar database for Alanis Morissette's albums. \nAction: FooBar DB \nAction Input: What albums of Alanis Morissette are in the FooBar database?", 'elementType': 'NATIVE'}}},
- {'newElement': {'markdown': {'body': 'SELECT "Title" FROM "Album" INNER JOIN "Artist" ON "Album"."ArtistId" = "Artist"."ArtistId" WHERE "Name" = \'Alanis Morissette\' LIMIT 5;', 'elementType': 'NATIVE'}}},
{'newElement': {'markdown': {'body': 'The albums of Alanis Morissette in the FooBar database are Jagged Little Pill.', 'elementType': 'NATIVE'}}},
- {'newElement': {'markdown': {'body': '**The albums of Alanis Morissette in the FooBar database are Jagged Little Pill.**', 'elementType': 'NATIVE'}}},
- {'newElement': {'empty': {}}},
- {'addBlock': {'expandable': {'label': '✅ **Complete!**'}, 'allowEmpty': True}}
+ {'addBlock': {'expandable': {'label': '**FooBar DB:** What albums of Alanis Morissette are in the FooBar database?', 'expanded': False, 'icon': 'check'}, 'allowEmpty': True}},
+    {'addBlock': {'expandable': {'label': 'Thinking...', 'expanded': True, 'icon': 'spinner'}, 'allowEmpty': True}},
+    {'newElement': {'markdown': {'body': "I now know the final answer. \nFinal Answer: The artist who recently released an album called 'The Storm Before the Calm' is Alanis Morissette and the albums of hers in the FooBar database are Jagged Little Pill.", 'elementType': 'NATIVE'}}},
+ {'addBlock': {'expandable': {'label': '**Complete!**', 'expanded': False, 'icon': 'check'}, 'allowEmpty': True}}
]
# fmt: on
@@ -128,4 +115,5 @@ def test_agent_run(self):
actual_deltas = [
MessageToDict(delta) for delta in self.get_all_deltas_from_queue()
]
+
self.assertEqual(expected_deltas, actual_deltas)
diff --git a/lib/tests/streamlit/external/langchain/test_data/alanis.pickle b/lib/tests/streamlit/external/langchain/test_data/alanis.pickle
index d35fe8621ece..cecda6215251 100644
Binary files a/lib/tests/streamlit/external/langchain/test_data/alanis.pickle and b/lib/tests/streamlit/external/langchain/test_data/alanis.pickle differ
diff --git a/lib/tests/streamlit/layouts_test.py b/lib/tests/streamlit/layouts_test.py
index 0aacecaf1825..3899b3bf3728 100644
--- a/lib/tests/streamlit/layouts_test.py
+++ b/lib/tests/streamlit/layouts_test.py
@@ -146,6 +146,85 @@ def test_just_label(self):
self.assertEqual(expander_block.add_block.expandable.expanded, False)
+class StatusContainerTest(DeltaGeneratorTestCase):
+ def test_label_required(self):
+ """Test that label is required"""
+ with self.assertRaises(TypeError):
+ st.status()
+
+ def test_throws_error_on_wrong_state(self):
+ """Test that it throws an error on unknown state."""
+ with self.assertRaises(StreamlitAPIException):
+ st.status("label", state="unknown")
+
+ def test_just_label(self):
+ """Test that it correctly applies label param."""
+ st.status("label")
+ status_block = self.get_delta_from_queue()
+ self.assertEqual(status_block.add_block.expandable.label, "label")
+ self.assertEqual(status_block.add_block.expandable.expanded, False)
+ self.assertEqual(status_block.add_block.expandable.icon, "spinner")
+
+ def test_expanded_param(self):
+ """Test that it correctly applies expanded param."""
+ st.status("label", expanded=True)
+
+ status_block = self.get_delta_from_queue()
+ self.assertEqual(status_block.add_block.expandable.label, "label")
+ self.assertEqual(status_block.add_block.expandable.expanded, True)
+ self.assertEqual(status_block.add_block.expandable.icon, "spinner")
+
+ def test_state_param_complete(self):
+ """Test that it correctly applies state param with `complete`."""
+ st.status("label", state="complete")
+
+ status_block = self.get_delta_from_queue()
+ self.assertEqual(status_block.add_block.expandable.label, "label")
+ self.assertEqual(status_block.add_block.expandable.expanded, False)
+ self.assertEqual(status_block.add_block.expandable.icon, "check")
+
+ def test_state_param_error(self):
+ """Test that it correctly applies state param with `error`."""
+ st.status("label", state="error")
+
+ status_block = self.get_delta_from_queue()
+ self.assertEqual(status_block.add_block.expandable.label, "label")
+ self.assertEqual(status_block.add_block.expandable.expanded, False)
+ self.assertEqual(status_block.add_block.expandable.icon, "error")
+
+ def test_usage_with_context_manager(self):
+ """Test that it correctly switches to complete state when used as context manager."""
+ status = st.status("label")
+
+ with status:
+ pass
+
+ status_block = self.get_delta_from_queue()
+ self.assertEqual(status_block.add_block.expandable.label, "label")
+ self.assertEqual(status_block.add_block.expandable.expanded, False)
+ self.assertEqual(status_block.add_block.expandable.icon, "check")
+
+ def test_mutation_via_update(self):
+ """Test that update can be used to change the label, state and expand."""
+ status = st.status("label", expanded=False)
+ status.update(label="new label", state="error", expanded=True)
+
+ status_block = self.get_delta_from_queue()
+ self.assertEqual(status_block.add_block.expandable.label, "new label")
+ self.assertEqual(status_block.add_block.expandable.expanded, True)
+ self.assertEqual(status_block.add_block.expandable.icon, "error")
+
+ def test_mutation_via_update_in_cm(self):
+ """Test that update can be used in context manager to change the label, state and expand."""
+ with st.status("label", expanded=False) as status:
+ status.update(label="new label", state="error", expanded=True)
+
+ status_block = self.get_delta_from_queue()
+ self.assertEqual(status_block.add_block.expandable.label, "new label")
+ self.assertEqual(status_block.add_block.expandable.expanded, True)
+ self.assertEqual(status_block.add_block.expandable.icon, "error")
+
+
class TabsTest(DeltaGeneratorTestCase):
def test_tab_required(self):
"""Test that at least one tab is required."""
diff --git a/lib/tests/streamlit/runtime/app_session_test.py b/lib/tests/streamlit/runtime/app_session_test.py
index 13d923bfe463..f47407296a6a 100644
--- a/lib/tests/streamlit/runtime/app_session_test.py
+++ b/lib/tests/streamlit/runtime/app_session_test.py
@@ -27,6 +27,7 @@
from streamlit import config
from streamlit.proto.AppPage_pb2 import AppPage
from streamlit.proto.BackMsg_pb2 import BackMsg
+from streamlit.proto.Common_pb2 import FileURLs, FileURLsRequest, FileURLsResponse
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.runtime import Runtime
from streamlit.runtime.app_session import AppSession, AppSessionState
@@ -46,7 +47,10 @@
get_script_run_ctx,
)
from streamlit.runtime.state import SessionState
-from streamlit.runtime.uploaded_file_manager import UploadedFileManager
+from streamlit.runtime.uploaded_file_manager import (
+ UploadedFileManager,
+ UploadFileUrlInfo,
+)
from streamlit.watcher.local_sources_watcher import LocalSourcesWatcher
from tests.testutil import patch_config_options
@@ -480,6 +484,57 @@ def test_disconnect_file_watchers_removes_refs(self):
self.assertEqual(len(gc.get_referrers(session)), 0)
+ @patch("streamlit.runtime.app_session.AppSession._enqueue_forward_msg")
+ def test_handle_file_urls_request(self, mock_enqueue):
+ session = _create_test_session()
+
+ upload_file_urls = [
+ UploadFileUrlInfo(
+ file_id="file_1",
+ upload_url="upload_file_url_1",
+ delete_url="delete_file_url_1",
+ ),
+ UploadFileUrlInfo(
+ file_id="file_2",
+ upload_url="upload_file_url_2",
+ delete_url="delete_file_url_2",
+ ),
+ UploadFileUrlInfo(
+ file_id="file_3",
+ upload_url="upload_file_url_3",
+ delete_url="delete_file_url_3",
+ ),
+ ]
+ session._uploaded_file_mgr.get_upload_urls.return_value = upload_file_urls
+
+ session._handle_file_urls_request(
+ FileURLsRequest(
+ request_id="my_id",
+ file_names=["file_1", "file_2", "file_3"],
+ session_id=session.id,
+ )
+ )
+
+ session._uploaded_file_mgr.get_upload_urls.assert_called_once_with(
+ session.id, ["file_1", "file_2", "file_3"]
+ )
+
+ expected_msg = ForwardMsg(
+ file_urls_response=FileURLsResponse(
+ response_id="my_id",
+ file_urls=[
+ FileURLs(
+ file_id=url.file_id,
+ upload_url=url.upload_url,
+ delete_url=url.delete_url,
+ )
+ for url in upload_file_urls
+ ],
+ )
+ )
+
+ mock_enqueue.assert_called_once_with(expected_msg)
+
def _mock_get_options_for_section(overrides=None) -> Callable[..., Any]:
if not overrides:
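The new _handle_file_urls_request test documents the upload handshake: the frontend sends a FileURLsRequest naming the files it is about to upload, and the session replies with a FileURLsResponse whose FileURLs entries carry the presigned upload and delete URLs. A sketch of the request side, mirroring the test (the session ID value is a placeholder):

    from streamlit.proto.Common_pb2 import FileURLsRequest

    req = FileURLsRequest(
        request_id="my_id",
        file_names=["file_1", "file_2"],
        session_id="some-session-id",  # placeholder session ID
    )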
diff --git a/lib/tests/streamlit/runtime/caching/cache_errors_test.py b/lib/tests/streamlit/runtime/caching/cache_errors_test.py
index 8146cfb5034b..91c778519978 100644
--- a/lib/tests/streamlit/runtime/caching/cache_errors_test.py
+++ b/lib/tests/streamlit/runtime/caching/cache_errors_test.py
@@ -29,7 +29,6 @@
from tests import testutil
from tests.delta_generator_test_case import DeltaGeneratorTestCase
from tests.streamlit import pyspark_mocks, snowpark_mocks
-from tests.testutil import should_skip_pyspark_tests
class CacheErrorsTest(DeltaGeneratorTestCase):
@@ -112,10 +111,6 @@ def test_unevaluated_dataframe_error(self, type_name):
elif "snowpark.dataframe.DataFrame" in type_name:
to_return = snowpark_mocks.DataFrame()
else:
- if should_skip_pyspark_tests():
- # Python 3.11 is incompatible with Pyspark
- return
-
to_return = (
pyspark_mocks.create_pyspark_dataframe_with_mocked_personal_data()
)
diff --git a/lib/tests/streamlit/runtime/caching/common_cache_test.py b/lib/tests/streamlit/runtime/caching/common_cache_test.py
index 4da99c9c05c6..4df44dc0e268 100644
--- a/lib/tests/streamlit/runtime/caching/common_cache_test.py
+++ b/lib/tests/streamlit/runtime/caching/common_cache_test.py
@@ -42,6 +42,7 @@
MemoryCacheStorageManager,
)
from streamlit.runtime.forward_msg_queue import ForwardMsgQueue
+from streamlit.runtime.memory_uploaded_file_manager import MemoryUploadedFileManager
from streamlit.runtime.scriptrunner import (
ScriptRunContext,
add_script_run_ctx,
@@ -256,7 +257,7 @@ def test_cached_st_function_warning(self, _, cache_decorator, call_stack):
_enqueue=forward_msg_queue.enqueue,
query_string="",
session_state=SafeSessionState(SessionState()),
- uploaded_file_mgr=UploadedFileManager(),
+ uploaded_file_mgr=MemoryUploadedFileManager("/mock/upload"),
page_script_hash="",
user_info={"email": "test@test.com"},
),
diff --git a/lib/tests/streamlit/runtime/caching/hashing_test.py b/lib/tests/streamlit/runtime/caching/hashing_test.py
index 35c6f89dbccd..f8bcbd92a39d 100644
--- a/lib/tests/streamlit/runtime/caching/hashing_test.py
+++ b/lib/tests/streamlit/runtime/caching/hashing_test.py
@@ -37,6 +37,7 @@
from parameterized import parameterized
from PIL import Image
+from streamlit.proto.Common_pb2 import FileURLs
from streamlit.runtime.caching import cache_data, cache_resource
from streamlit.runtime.caching.cache_errors import UnhashableTypeError
from streamlit.runtime.caching.cache_type import CacheType
@@ -328,12 +329,6 @@ def test_PIL_image(self):
[
(BytesIO, b"123", b"456", b"123"),
(StringIO, "123", "456", "123"),
- (
- UploadedFile,
- UploadedFileRec(0, "name", "type", b"123"),
- UploadedFileRec(0, "name", "type", b"456"),
- UploadedFileRec(0, "name", "type", b"123"),
- ),
]
)
def test_io(self, io_type, io_data1, io_data2, io_data3):
@@ -349,6 +344,28 @@ def test_io(self, io_type, io_data1, io_data2, io_data3):
io3.seek(0)
self.assertNotEqual(get_hash(io1), get_hash(io3))
+ def test_uploaded_file_io(self):
+ rec1 = UploadedFileRec("file1", "name", "type", b"123")
+ rec2 = UploadedFileRec("file1", "name", "type", b"456")
+ rec3 = UploadedFileRec("file1", "name", "type", b"123")
+ io1 = UploadedFile(
+ rec1, FileURLs(file_id=rec1.file_id, upload_url="u1", delete_url="d1")
+ )
+ io2 = UploadedFile(
+ rec2, FileURLs(file_id=rec2.file_id, upload_url="u2", delete_url="d2")
+ )
+ io3 = UploadedFile(
+ rec3, FileURLs(file_id=rec3.file_id, upload_url="u3", delete_url="u3")
+ )
+
+ self.assertEqual(get_hash(io1), get_hash(io3))
+ self.assertNotEqual(get_hash(io1), get_hash(io2))
+
+ # Changing the stream position should change the hash
+ io1.seek(1)
+ io3.seek(0)
+ self.assertNotEqual(get_hash(io1), get_hash(io3))
+
def test_partial(self):
p1 = functools.partial(int, base=2)
p2 = functools.partial(int, base=3)
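The replacement test_uploaded_file_io makes the hashing contract explicit: two UploadedFile instances hash equally when their content and stream position match, even with different attached URLs, and diverge as soon as either differs. A hedged restatement using the objects from the test above:

    # Same bytes, same position, different URLs -> equal hashes.
    io1.seek(0)
    io3.seek(0)
    assert get_hash(io1) == get_hash(io3)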
diff --git a/lib/tests/streamlit/runtime/legacy_caching/hashing_test.py b/lib/tests/streamlit/runtime/legacy_caching/hashing_test.py
index 13109b2194fd..a8e3c09c9f3b 100644
--- a/lib/tests/streamlit/runtime/legacy_caching/hashing_test.py
+++ b/lib/tests/streamlit/runtime/legacy_caching/hashing_test.py
@@ -36,6 +36,7 @@
from parameterized import parameterized
import streamlit as st
+from streamlit.proto.Common_pb2 import FileURLs
from streamlit.runtime.legacy_caching.hashing import (
_FFI_TYPE_NAMES,
_NP_SIZE_LARGE,
@@ -322,12 +323,6 @@ def test_numpy(self):
[
(BytesIO, b"123", b"456", b"123"),
(StringIO, "123", "456", "123"),
- (
- UploadedFile,
- UploadedFileRec("id", "name", "type", b"123"),
- UploadedFileRec("id", "name", "type", b"456"),
- UploadedFileRec("id", "name", "type", b"123"),
- ),
]
)
def test_io(self, io_type, io_data1, io_data2, io_data3):
@@ -343,6 +338,28 @@ def test_io(self, io_type, io_data1, io_data2, io_data3):
io3.seek(0)
self.assertNotEqual(get_hash(io1), get_hash(io3))
+ def test_uploaded_file_io(self):
+ rec1 = UploadedFileRec("file1", "name", "type", b"123")
+ rec2 = UploadedFileRec("file1", "name", "type", b"456")
+ rec3 = UploadedFileRec("file1", "name", "type", b"123")
+ io1 = UploadedFile(
+ rec1, FileURLs(file_id=rec1.file_id, upload_url="u1", delete_url="d1")
+ )
+ io2 = UploadedFile(
+ rec2, FileURLs(file_id=rec2.file_id, upload_url="u2", delete_url="d2")
+ )
+ io3 = UploadedFile(
+ rec3, FileURLs(file_id=rec3.file_id, upload_url="u3", delete_url="u3")
+ )
+
+ self.assertEqual(get_hash(io1), get_hash(io3))
+ self.assertNotEqual(get_hash(io1), get_hash(io2))
+
+ # Changing the stream position should change the hash
+ io1.seek(1)
+ io3.seek(0)
+ self.assertNotEqual(get_hash(io1), get_hash(io3))
+
def test_partial(self):
p1 = functools.partial(int, base=2)
p2 = functools.partial(int, base=3)
diff --git a/lib/tests/streamlit/runtime/runtime_test.py b/lib/tests/streamlit/runtime/runtime_test.py
index 9c746c0e1f80..d62a85ecae8f 100644
--- a/lib/tests/streamlit/runtime/runtime_test.py
+++ b/lib/tests/streamlit/runtime/runtime_test.py
@@ -40,8 +40,8 @@
from streamlit.runtime.forward_msg_cache import populate_hash_if_needed
from streamlit.runtime.memory_media_file_storage import MemoryMediaFileStorage
from streamlit.runtime.memory_session_storage import MemorySessionStorage
+from streamlit.runtime.memory_uploaded_file_manager import MemoryUploadedFileManager
from streamlit.runtime.runtime import AsyncObjects, RuntimeStoppedError
-from streamlit.runtime.uploaded_file_manager import UploadedFileRec
from streamlit.runtime.websocket_session_manager import WebsocketSessionManager
from streamlit.watcher import event_based_path_watcher
from tests.streamlit.message_mocks import (
@@ -68,7 +68,10 @@ def write_forward_msg(self, msg: ForwardMsg) -> None:
class RuntimeConfigTests(unittest.TestCase):
def test_runtime_config_defaults(self):
config = RuntimeConfig(
- "/my/script.py", None, MemoryMediaFileStorage("/mock/media")
+ "/my/script.py",
+ None,
+ MemoryMediaFileStorage("/mock/media"),
+ MemoryUploadedFileManager("/mock/upload"),
)
self.assertIsInstance(
@@ -553,50 +556,6 @@ async def send_data_msg() -> None:
await finish_script(True)
self.assertFalse(is_data_msg_cached())
- async def test_orphaned_upload_file_deletion(self):
- """An uploaded file with no associated AppSession should be
- deleted.
- """
- await self.runtime.start()
-
- client = MockSessionClient()
- session_id = self.runtime.connect_session(client=client, user_info=MagicMock())
-
- file = UploadedFileRec(0, "file.txt", "type", b"123")
-
- # Upload a file for our connected session.
- added_file = self.runtime._uploaded_file_mgr.add_file(
- session_id=session_id,
- widget_id="widget_id",
- file=UploadedFileRec(0, "file.txt", "type", b"123"),
- )
-
- # The file should exist.
- self.assertEqual(
- self.runtime._uploaded_file_mgr.get_all_files(session_id, "widget_id"),
- [added_file],
- )
-
- # Disconnect the session. The file should be deleted.
- self.runtime.disconnect_session(session_id)
- self.assertEqual(
- self.runtime._uploaded_file_mgr.get_all_files(session_id, "widget_id"),
- [],
- )
-
- # Upload a file for a session that doesn't exist.
- self.runtime._uploaded_file_mgr.add_file(
- session_id="no_such_session", widget_id="widget_id", file=file
- )
-
- # The file should be immediately deleted.
- self.assertEqual(
- self.runtime._uploaded_file_mgr.get_all_files(
- "no_such_session", "widget_id"
- ),
- [],
- )
-
async def test_get_async_objs(self):
"""Runtime._get_async_objs() will raise an error if called before the
Runtime is started, and will return the Runtime's AsyncObjects instance otherwise.
@@ -631,6 +590,7 @@ async def asyncSetUp(self):
script_path=self._path,
command_line="mock command line",
media_file_storage=MemoryMediaFileStorage("/mock/media"),
+ uploaded_file_manager=MemoryUploadedFileManager("/mock/upload"),
session_manager_class=MagicMock,
session_storage=MagicMock(),
cache_storage_manager=MagicMock(),
diff --git a/lib/tests/streamlit/runtime/runtime_test_case.py b/lib/tests/streamlit/runtime/runtime_test_case.py
index 154f4983d807..dee9f4d6d4bb 100644
--- a/lib/tests/streamlit/runtime/runtime_test_case.py
+++ b/lib/tests/streamlit/runtime/runtime_test_case.py
@@ -23,6 +23,7 @@
MemoryCacheStorageManager,
)
from streamlit.runtime.memory_media_file_storage import MemoryMediaFileStorage
+from streamlit.runtime.memory_uploaded_file_manager import MemoryUploadedFileManager
from streamlit.runtime.script_data import ScriptData
from streamlit.runtime.scriptrunner.script_cache import ScriptCache
from streamlit.runtime.session_manager import (
@@ -104,6 +105,7 @@ async def asyncSetUp(self):
script_path="mock/script/path.py",
command_line="",
media_file_storage=MemoryMediaFileStorage("/mock/media"),
+ uploaded_file_manager=MemoryUploadedFileManager("/mock/upload"),
session_manager_class=MockSessionManager,
session_storage=mock.MagicMock(),
cache_storage_manager=MemoryCacheStorageManager(),
diff --git a/lib/tests/streamlit/runtime/runtime_threading_test.py b/lib/tests/streamlit/runtime/runtime_threading_test.py
index 698f0cfac028..de479d84a601 100644
--- a/lib/tests/streamlit/runtime/runtime_threading_test.py
+++ b/lib/tests/streamlit/runtime/runtime_threading_test.py
@@ -48,6 +48,7 @@ def create_runtime_on_another_thread():
"mock/script/path.py",
"",
media_file_storage=MagicMock(),
+ uploaded_file_manager=MagicMock(),
session_manager_class=MagicMock,
session_storage=MagicMock(),
cache_storage_manager=MagicMock(),
diff --git a/lib/tests/streamlit/runtime/scriptrunner/magic_test.py b/lib/tests/streamlit/runtime/scriptrunner/magic_test.py
index 22cb8824b060..98842489b1b1 100644
--- a/lib/tests/streamlit/runtime/scriptrunner/magic_test.py
+++ b/lib/tests/streamlit/runtime/scriptrunner/magic_test.py
@@ -175,11 +175,20 @@ async def myfunc(a):
"""
self._testCode(CODE_ASYNC_FOR, 1)
- def test_docstring_is_ignored(self):
+ def test_docstring_is_ignored_func(self):
"""Test that docstrings don't print in the app"""
CODE = """
def myfunc(a):
'''This is the docstring'''
return 42
+"""
+ self._testCode(CODE, 0)
+
+ def test_docstring_is_ignored_async_func(self):
+ """Test that async function docstrings don't print in the app"""
+ CODE = """
+async def myfunc(a):
+ '''This is the docstring for async func'''
+ return 43
"""
self._testCode(CODE, 0)
diff --git a/lib/tests/streamlit/runtime/scriptrunner/script_runner_test.py b/lib/tests/streamlit/runtime/scriptrunner/script_runner_test.py
index 99aa307fa6d0..56b20b3b902a 100644
--- a/lib/tests/streamlit/runtime/scriptrunner/script_runner_test.py
+++ b/lib/tests/streamlit/runtime/scriptrunner/script_runner_test.py
@@ -36,6 +36,7 @@
from streamlit.runtime.legacy_caching import caching
from streamlit.runtime.media_file_manager import MediaFileManager
from streamlit.runtime.memory_media_file_storage import MemoryMediaFileStorage
+from streamlit.runtime.memory_uploaded_file_manager import MemoryUploadedFileManager
from streamlit.runtime.scriptrunner import (
RerunData,
RerunException,
@@ -50,7 +51,6 @@
ScriptRequestType,
)
from streamlit.runtime.state.session_state import SessionState
-from streamlit.runtime.uploaded_file_manager import UploadedFileManager
from tests import testutil
text_utf = "complete! 👨🎤"
@@ -1069,7 +1069,7 @@ def __init__(self, script_name: str):
main_script_path=main_script_path,
client_state=ClientState(),
session_state=SessionState(),
- uploaded_file_mgr=UploadedFileManager(),
+ uploaded_file_mgr=MemoryUploadedFileManager("/mock/upload"),
script_cache=ScriptCache(),
initial_rerun_data=RerunData(),
user_info={"email": "test@test.com"},
diff --git a/lib/tests/streamlit/runtime/state/session_state_test.py b/lib/tests/streamlit/runtime/state/session_state_test.py
index 5c613517fb09..de9144a63f2e 100644
--- a/lib/tests/streamlit/runtime/state/session_state_test.py
+++ b/lib/tests/streamlit/runtime/state/session_state_test.py
@@ -26,6 +26,7 @@
import streamlit as st
import tests.streamlit.runtime.state.strategies as stst
from streamlit.errors import StreamlitAPIException
+from streamlit.proto.Common_pb2 import FileURLs as FileURLsProto
from streamlit.proto.WidgetStates_pb2 import WidgetState as WidgetStateProto
from streamlit.proto.WidgetStates_pb2 import WidgetStates as WidgetStatesProto
from streamlit.runtime.scriptrunner import get_script_run_ctx
@@ -37,7 +38,7 @@
WidgetMetadata,
WStates,
)
-from streamlit.runtime.uploaded_file_manager import UploadedFileRec
+from streamlit.runtime.uploaded_file_manager import UploadedFile, UploadedFileRec
from streamlit.testing.script_interactions import InteractiveScriptTests
from tests.delta_generator_test_case import DeltaGeneratorTestCase
from tests.testutil import patch_config_options
@@ -387,12 +388,16 @@ def test_date_input_serde(self):
)
check_roundtrip("date_interval", date_interval)
- @patch("streamlit.elements.widgets.file_uploader._get_file_recs")
- def test_file_uploader_serde(self, get_file_recs_patch):
- file_recs = [
- UploadedFileRec(1, "file1", "type", b"123"),
+ @patch("streamlit.elements.widgets.file_uploader._get_upload_files")
+ def test_file_uploader_serde(self, get_upload_files_patch):
+ file_rec = UploadedFileRec("file1", "file1", "type", b"123")
+ uploaded_files = [
+ UploadedFile(
+ file_rec, FileURLsProto(file_id="1", delete_url="d1", upload_url="u1")
+ )
]
- get_file_recs_patch.return_value = file_recs
+
+ get_upload_files_patch.return_value = uploaded_files
uploaded_file = st.file_uploader("file_uploader", key="file_uploader")
check_roundtrip("file_uploader", uploaded_file)
diff --git a/lib/tests/streamlit/runtime/state/widgets_test.py b/lib/tests/streamlit/runtime/state/widgets_test.py
index 7a9d4a5b0ad4..c7d3776cbab7 100644
--- a/lib/tests/streamlit/runtime/state/widgets_test.py
+++ b/lib/tests/streamlit/runtime/state/widgets_test.py
@@ -14,21 +14,21 @@
"""Tests widget-related functionality"""
+import inspect
import unittest
-from unittest.mock import MagicMock, call, patch
+from unittest.mock import ANY, MagicMock, call, patch
from parameterized import parameterized
import streamlit as st
from streamlit import errors
-from streamlit.proto.Button_pb2 import Button as ButtonProto
from streamlit.proto.Common_pb2 import StringTriggerValue as StringTriggerValueProto
from streamlit.proto.WidgetStates_pb2 import WidgetStates
from streamlit.runtime.scriptrunner.script_run_context import get_script_run_ctx
from streamlit.runtime.state import coalesce_widget_states
-from streamlit.runtime.state.common import GENERATED_WIDGET_ID_PREFIX
+from streamlit.runtime.state.common import GENERATED_WIDGET_ID_PREFIX, compute_widget_id
from streamlit.runtime.state.session_state import SessionState, WidgetMetadata
-from streamlit.runtime.state.widgets import compute_widget_id, user_key_from_widget_id
+from streamlit.runtime.state.widgets import user_key_from_widget_id
from tests.delta_generator_test_case import DeltaGeneratorTestCase
@@ -301,60 +301,155 @@ def test_coalesce_widget_states(self):
class WidgetHelperTests(unittest.TestCase):
def test_get_widget_with_generated_key(self):
- button_proto = ButtonProto()
- button_proto.label = "the label"
- self.assertTrue(
- compute_widget_id("button", button_proto).startswith(
- GENERATED_WIDGET_ID_PREFIX
- )
- )
-
+ id = compute_widget_id("button", label="the label")
+ assert id.startswith(GENERATED_WIDGET_ID_PREFIX)
+
+
+class ComputeWidgetIdTests(DeltaGeneratorTestCase):
+ """Enforce that new arguments added to the signature of a widget function are taken
+ into account when computing widget IDs unless explicitly excluded.
+ """
+
+ def signature_to_expected_kwargs(self, sig):
+ # These widget kwargs aren't used for widget ID calculation, meaning that they
+ # can be changed without resetting the widget.
+ excluded_kwargs = {
+ # Internal stuff
+ "ctx",
+ # Formatting/display stuff
+ "disabled",
+ "format_func",
+ "label_visibility",
+ # on_change callbacks and similar/related parameters.
+ "args",
+ "kwargs",
+ "on_change",
+ "on_click",
+ "on_submit",
+ }
+
+ kwargs = {
+ kwarg: ANY
+ for kwarg in sig.parameters.keys()
+ if kwarg not in excluded_kwargs
+ }
+
+ # Add some kwargs that are passed to compute_widget_id but don't appear in widget
+ # signatures.
+ for kwarg in ["form_id", "user_key"]:
+ kwargs[kwarg] = ANY
+
+ return kwargs
-class WidgetIdDisabledTests(DeltaGeneratorTestCase):
@parameterized.expand(
[
- (st.button,),
- (st.camera_input,),
- (st.checkbox,),
- (st.color_picker,),
- (st.file_uploader,),
- (st.number_input,),
- (st.slider,),
- (st.text_area,),
- (st.text_input,),
- (st.date_input,),
- (st.time_input,),
+ (st.camera_input, "camera_input"),
+ (st.checkbox, "checkbox"),
+ (st.color_picker, "color_picker"),
+ (st.date_input, "time_widgets"),
+ (st.file_uploader, "file_uploader"),
+ (st.number_input, "number_input"),
+ (st.slider, "slider"),
+ (st.text_area, "text_widgets"),
+ (st.text_input, "text_widgets"),
+ (st.time_input, "time_widgets"),
]
)
- def test_disabled_parameter_id(self, widget_func):
- widget_func("my_widget")
+ def test_widget_id_computation(self, widget_func, module_name):
+ with patch(
+ f"streamlit.elements.widgets.{module_name}.compute_widget_id",
+ wraps=compute_widget_id,
+ ) as patched_compute_widget_id:
+ widget_func("my_widget")
- # The `disabled` argument shouldn't affect a widget's ID, so we
- # expect a DuplicateWidgetID error.
- with self.assertRaises(errors.DuplicateWidgetID):
- widget_func("my_widget", disabled=True)
+ sig = inspect.signature(widget_func)
+ expected_sig = self.signature_to_expected_kwargs(sig)
- def test_disabled_parameter_id_download_button(self):
- st.download_button("my_widget", data="")
+ patched_compute_widget_id.assert_called_with(ANY, **expected_sig)
+ # Double check that we get a DuplicateWidgetID error since the `disabled`
+ # argument shouldn't affect a widget's ID.
with self.assertRaises(errors.DuplicateWidgetID):
- st.download_button("my_widget", data="", disabled=True)
+ widget_func("my_widget", disabled=True)
+
+ @parameterized.expand(
+ [
+ (st.button, "button"),
+ (st.chat_input, "chat"),
+ (st.download_button, "button"),
+ ]
+ )
+ def test_widget_id_computation_no_form_widgets(self, widget_func, module_name):
+ with patch(
+ f"streamlit.elements.widgets.{module_name}.compute_widget_id",
+ wraps=compute_widget_id,
+ ) as patched_compute_widget_id:
+ if widget_func == st.download_button:
+ widget_func("my_widget", data="")
+ else:
+ widget_func("my_widget")
+
+ sig = inspect.signature(widget_func)
+ expected_sig = self.signature_to_expected_kwargs(sig)
+
+ # button and chat widgets don't include a form_id param in their calls to
+ # compute_widget_id because having either in forms (aside from the form's
+ # submit button) is illegal.
+ del expected_sig["form_id"]
+ if widget_func == st.button:
+ expected_sig["is_form_submitter"] = ANY
+
+ patched_compute_widget_id.assert_called_with(ANY, **expected_sig)
@parameterized.expand(
[
- (st.multiselect,),
- (st.radio,),
- (st.select_slider,),
- (st.selectbox,),
+ (st.multiselect, "multiselect"),
+ (st.radio, "radio"),
+ (st.select_slider, "select_slider"),
+ (st.selectbox, "selectbox"),
]
)
- def test_disabled_parameter_id_options_widgets(self, widget_func):
+ def test_widget_id_computation_options_widgets(self, widget_func, module_name):
options = ["a", "b", "c"]
- widget_func("my_widget", options)
+ with patch(
+ f"streamlit.elements.widgets.{module_name}.compute_widget_id",
+ wraps=compute_widget_id,
+ ) as patched_compute_widget_id:
+ widget_func("my_widget", options)
+
+ sig = inspect.signature(widget_func)
+ patched_compute_widget_id.assert_called_with(
+ ANY, **self.signature_to_expected_kwargs(sig)
+ )
+
+ # Double check that we get a DuplicateWidgetID error since the `disabled`
+ # argument shouldn't affect a widget's ID.
with self.assertRaises(errors.DuplicateWidgetID):
widget_func("my_widget", options, disabled=True)
+ def test_widget_id_computation_data_editor(self):
+ with patch(
+ f"streamlit.elements.widgets.data_editor.compute_widget_id",
+ wraps=compute_widget_id,
+ ) as patched_compute_widget_id:
+ st.data_editor(data=[])
+
+ sig = inspect.signature(st.data_editor)
+ expected_sig = self.signature_to_expected_kwargs(sig)
+
+ # Make some changes to expected_sig unique to st.data_editor.
+ expected_sig["column_config_mapping"] = ANY
+ del expected_sig["hide_index"]
+ del expected_sig["column_config"]
+
+ patched_compute_widget_id.assert_called_with(ANY, **expected_sig)
+
+ # Double check that we get a DuplicateWidgetID error since the `disabled`
+ # argument shouldn't affect a widget's ID.
+ with self.assertRaises(errors.DuplicateWidgetID):
+ st.data_editor(data=[], disabled=True)
+
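These rewritten tests verify the call sites directly rather than only the symptom: compute_widget_id is a pure function of the keyword arguments each widget passes in, and display-only parameters such as disabled, label_visibility, format_func, and the callback kwargs are never passed, so changing them cannot reset a widget. A hedged sketch of the consequence, using the same call shape as the tests:

    # `disabled` is never an input to compute_widget_id, so two declarations
    # differing only in `disabled` collide -> DuplicateWidgetID at runtime.
    id_a = compute_widget_id("button", label="my_widget")
    id_b = compute_widget_id("button", label="my_widget")
    assert id_a == id_b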
@patch("streamlit.runtime.Runtime.exists", new=MagicMock(return_value=True))
class WidgetUserKeyTests(DeltaGeneratorTestCase):
diff --git a/lib/tests/streamlit/runtime/uploaded_file_manager_test.py b/lib/tests/streamlit/runtime/uploaded_file_manager_test.py
index 03fae8292449..71c920b8515f 100644
--- a/lib/tests/streamlit/runtime/uploaded_file_manager_test.py
+++ b/lib/tests/streamlit/runtime/uploaded_file_manager_test.py
@@ -16,131 +16,78 @@
import unittest
+from streamlit.runtime.memory_uploaded_file_manager import MemoryUploadedFileManager
from streamlit.runtime.stats import CacheStat
-from streamlit.runtime.uploaded_file_manager import UploadedFileManager, UploadedFileRec
+from streamlit.runtime.uploaded_file_manager import UploadedFileRec
from tests.exception_capturing_thread import call_on_threads
-FILE_1 = UploadedFileRec(id=0, name="file1", type="type", data=b"file1")
-FILE_2 = UploadedFileRec(id=0, name="file2", type="type", data=b"file222")
+FILE_1 = UploadedFileRec(file_id="url1", name="file1", type="type", data=b"file1")
+FILE_2 = UploadedFileRec(file_id="url2", name="file2", type="type", data=b"file222")
class UploadedFileManagerTest(unittest.TestCase):
def setUp(self):
- self.mgr = UploadedFileManager()
- self.filemgr_events = []
- self.mgr.on_files_updated.connect(self._on_files_updated)
-
- def _on_files_updated(self, file_list, **kwargs):
- self.filemgr_events.append(file_list)
+ self.mgr = MemoryUploadedFileManager("/mock/upload")
def test_added_file_id(self):
- """An added file should have a unique ID."""
- f1 = self.mgr.add_file("session", "widget", FILE_1)
- f2 = self.mgr.add_file("session", "widget", FILE_1)
- self.assertNotEqual(FILE_1.id, f1.id)
- self.assertNotEqual(f1.id, f2.id)
+ """Presigned file URL should have a unique ID."""
+ info1, info2 = self.mgr.get_upload_urls("session", ["name1", "name1"])
+ self.assertNotEqual(info1.file_id, info2.file_id)
- def test_added_file_properties(self):
+ def test_retrieve_added_file(self):
"""An added file should maintain all its source properties
except its ID."""
- added = self.mgr.add_file("session", "widget", FILE_1)
- self.assertNotEqual(added.id, FILE_1.id)
- self.assertEqual(added.name, FILE_1.name)
- self.assertEqual(added.type, FILE_1.type)
- self.assertEqual(added.data, FILE_1.data)
-
- def test_retrieve_added_file(self):
- """After adding a file to the mgr, we should be able to get it back."""
- self.assertEqual([], self.mgr.get_all_files("non-report", "non-widget"))
-
- file_1 = self.mgr.add_file("session", "widget", FILE_1)
- self.assertEqual([file_1], self.mgr.get_all_files("session", "widget"))
- self.assertEqual([file_1], self.mgr.get_files("session", "widget", [file_1.id]))
- self.assertEqual(len(self.filemgr_events), 1)
-
- # Add another file
- file_2 = self.mgr.add_file("session", "widget", FILE_2)
- self.assertEqual([file_1, file_2], self.mgr.get_all_files("session", "widget"))
- self.assertEqual([file_1], self.mgr.get_files("session", "widget", [file_1.id]))
- self.assertEqual([file_2], self.mgr.get_files("session", "widget", [file_2.id]))
- self.assertEqual(len(self.filemgr_events), 2)
+ self.mgr.add_file("session", FILE_1)
+ self.mgr.add_file("session", FILE_2)
+
+ file1_from_storage, *rest_files = self.mgr.get_files("session", ["url1"])
+ self.assertEqual(len(rest_files), 0)
+ self.assertEqual(file1_from_storage.file_id, FILE_1.file_id)
+ self.assertEqual(file1_from_storage.name, FILE_1.name)
+ self.assertEqual(file1_from_storage.type, FILE_1.type)
+ self.assertEqual(file1_from_storage.data, FILE_1.data)
+
+ file2_from_storage, *other_files = self.mgr.get_files("session", ["url2"])
+ self.assertEqual(len(other_files), 0)
+ self.assertEqual(file2_from_storage.file_id, FILE_2.file_id)
+ self.assertEqual(file2_from_storage.name, FILE_2.name)
+ self.assertEqual(file2_from_storage.type, FILE_2.type)
+ self.assertEqual(file2_from_storage.data, FILE_2.data)
def test_remove_file(self):
# This should not error.
- self.mgr.remove_files("non-report", "non-widget")
+ self.mgr.remove_file("non-session", "non-file-id")
- f1 = self.mgr.add_file("session", "widget", FILE_1)
- self.mgr.remove_file("session", "widget", f1.id)
- self.assertEqual([], self.mgr.get_all_files("session", "widget"))
+ self.mgr.add_file("session", FILE_1)
+ self.mgr.remove_file("session", FILE_1.file_id)
+ self.assertEqual([], self.mgr.get_files("session", [FILE_1.file_id]))
# Remove the file again. It doesn't exist, but this isn't an error.
- self.mgr.remove_file("session", "widget", f1.id)
- self.assertEqual([], self.mgr.get_all_files("session", "widget"))
-
- f1 = self.mgr.add_file("session", "widget", FILE_1)
- f2 = self.mgr.add_file("session", "widget", FILE_2)
- self.mgr.remove_file("session", "widget", f1.id)
- self.assertEqual([f2], self.mgr.get_all_files("session", "widget"))
-
- def test_remove_widget_files(self):
- # This should not error.
- self.mgr.remove_session_files("non-report")
-
- # Add two files with different session IDs, but the same widget ID.
- self.mgr.add_file("session1", "widget", FILE_1)
- f2 = self.mgr.add_file("session2", "widget", FILE_1)
-
- self.mgr.remove_files("session1", "widget")
- self.assertEqual([], self.mgr.get_all_files("session1", "widget"))
- self.assertEqual([f2], self.mgr.get_all_files("session2", "widget"))
+ self.mgr.remove_file("session", FILE_1.file_id)
+ self.assertEqual([], self.mgr.get_files("session", [FILE_1.file_id]))
+
+ self.mgr.add_file("session", FILE_1)
+ self.mgr.add_file("session", FILE_2)
+ self.mgr.remove_file("session", FILE_1.file_id)
+ self.assertEqual(
+ [FILE_2], self.mgr.get_files("session", [FILE_1.file_id, FILE_2.file_id])
+ )
def test_remove_session_files(self):
# This should not error.
self.mgr.remove_session_files("non-report")
# Add files to two different sessions.
- self.mgr.add_file("session1", "widget1", FILE_1)
- self.mgr.add_file("session1", "widget2", FILE_1)
- f3 = self.mgr.add_file("session2", "widget", FILE_1)
+ self.mgr.add_file("session1", FILE_1)
+ self.mgr.add_file("session1", FILE_2)
- self.mgr.remove_session_files("session1")
- self.assertEqual([], self.mgr.get_all_files("session1", "widget1"))
- self.assertEqual([], self.mgr.get_all_files("session1", "widget2"))
- self.assertEqual([f3], self.mgr.get_all_files("session2", "widget"))
-
- def test_remove_orphaned_files(self):
- """Test the remove_orphaned_files behavior"""
- f1 = self.mgr.add_file("session1", "widget1", FILE_1)
- f2 = self.mgr.add_file("session1", "widget1", FILE_1)
- f3 = self.mgr.add_file("session1", "widget1", FILE_1)
- self.assertEqual([f1, f2, f3], self.mgr.get_all_files("session1", "widget1"))
-
- # Nothing should be removed here (all files are active).
- self.mgr.remove_orphaned_files(
- "session1",
- "widget1",
- newest_file_id=f3.id,
- active_file_ids=[f1.id, f2.id, f3.id],
- )
- self.assertEqual([f1, f2, f3], self.mgr.get_all_files("session1", "widget1"))
+ self.mgr.add_file("session2", FILE_1)
- # Nothing should be removed here (no files are active, but they're all
- # "newer" than newest_file_id).
- self.mgr.remove_orphaned_files(
- "session1", "widget1", newest_file_id=f1.id - 1, active_file_ids=[]
- )
- self.assertEqual([f1, f2, f3], self.mgr.get_all_files("session1", "widget1"))
-
- # f2 should be removed here (it's not in the active file list)
- self.mgr.remove_orphaned_files(
- "session1", "widget1", newest_file_id=f3.id, active_file_ids=[f1.id, f3.id]
- )
- self.assertEqual([f1, f3], self.mgr.get_all_files("session1", "widget1"))
-
- # remove_orphaned_files on an untracked session/widget should not error
- self.mgr.remove_orphaned_files(
- "no_session", "no_widget", newest_file_id=0, active_file_ids=[]
+ self.mgr.remove_session_files("session1")
+ self.assertEqual(
+ [], self.mgr.get_files("session1", [FILE_1.file_id, FILE_2.file_id])
)
+ self.assertEqual([FILE_1], self.mgr.get_files("session2", [FILE_1.file_id]))
def test_cache_stats_provider(self):
"""Test CacheStatsProvider implementation."""
@@ -149,8 +96,8 @@ def test_cache_stats_provider(self):
self.assertEqual([], self.mgr.get_stats())
# Test manager with files
- self.mgr.add_file("session1", "widget1", FILE_1)
- self.mgr.add_file("session1", "widget2", FILE_2)
+ self.mgr.add_file("session1", FILE_1)
+ self.mgr.add_file("session1", FILE_2)
expected = [
CacheStat(
@@ -172,7 +119,7 @@ class UploadedFileManagerThreadingTest(unittest.TestCase):
NUM_THREADS = 50
def setUp(self) -> None:
- self.mgr = UploadedFileManager()
+ self.mgr = MemoryUploadedFileManager("/mock/upload")
def test_add_file(self):
"""`add_file` is thread-safe."""
@@ -181,22 +128,28 @@ def test_add_file(self):
def add_file(index: int) -> None:
file = UploadedFileRec(
- id=0, name=f"file_{index}", type="type", data=bytes(f"{index}", "utf-8")
+ file_id=f"id_{index}",
+ name=f"file_{index}",
+ type="type",
+ data=bytes(f"{index}", "utf-8"),
)
- added_files.append(self.mgr.add_file("session", f"widget_{index}", file))
+
+ self.mgr.add_file("session", file)
+ files_from_storage = self.mgr.get_files("session", [file.file_id])
+ added_files.extend(files_from_storage)
call_on_threads(add_file, num_threads=self.NUM_THREADS)
# Ensure all our files are present
for ii in range(self.NUM_THREADS):
- files = self.mgr.get_all_files("session", f"widget_{ii}")
+ files = self.mgr.get_files("session", [f"id_{ii}"])
self.assertEqual(1, len(files))
self.assertEqual(bytes(f"{ii}", "utf-8"), files[0].data)
# Ensure all files have unique IDs
file_ids = set()
- for file_list in self.mgr._files_by_id.values():
- file_ids.update(file.id for file in file_list)
+ for file_rec in self.mgr.file_storage["session"].values():
+ file_ids.add(file_rec.file_id)
self.assertEqual(self.NUM_THREADS, len(file_ids))
def test_remove_file(self):
@@ -204,79 +157,59 @@ def test_remove_file(self):
# Add a bunch of files to a single widget
file_ids = []
for ii in range(self.NUM_THREADS):
- file = UploadedFileRec(id=0, name=f"file_{ii}", type="type", data=b"123")
- file_ids.append(self.mgr.add_file("session", "widget", file).id)
+ file = UploadedFileRec(
+ file_id=f"id_{ii}",
+ name=f"file_{ii}",
+ type="type",
+ data=b"123",
+ )
+ self.mgr.add_file("session", file)
+ file_ids.append(file.file_id)
# Have each thread remove a single file
def remove_file(index: int) -> None:
file_id = file_ids[index]
# Ensure our file exists
- get_files_result = self.mgr.get_files("session", "widget", [file_id])
+ get_files_result = self.mgr.get_files("session", [file_id])
self.assertEqual(1, len(get_files_result))
- # Remove our file
- was_removed = self.mgr.remove_file("session", "widget", file_id)
- self.assertTrue(was_removed)
-
- # Ensure our file no longer exists
- get_files_result = self.mgr.get_files("session", "widget", [file_id])
+ # Remove our file and ensure our file no longer exists
+ self.mgr.remove_file("session", file_id)
+ get_files_result = self.mgr.get_files("session", [file_id])
self.assertEqual(0, len(get_files_result))
call_on_threads(remove_file, self.NUM_THREADS)
- self.assertEqual(0, len(self.mgr.get_all_files("session", "widget")))
+ self.assertEqual(0, len(self.mgr.file_storage["session"]))
def test_remove_session_files(self):
"""`remove_session_files` is thread-safe."""
# Add a bunch of files, each to a different session
file_ids = []
for ii in range(self.NUM_THREADS):
- file = UploadedFileRec(id=0, name=f"file_{ii}", type="type", data=b"123")
- file_ids.append(self.mgr.add_file(f"session_{ii}", "widget", file).id)
+ file = UploadedFileRec(
+ file_id=f"id_{ii}",
+ name=f"file_{ii}",
+ type="type",
+ data=b"123",
+ )
+ self.mgr.add_file(f"session_{ii}", file)
+ file_ids.append(file.file_id)
# Have each thread remove its session's file
def remove_session_files(index: int) -> None:
session_id = f"session_{index}"
# Our file should exist
- session_files = self.mgr.get_all_files(session_id, "widget")
+ session_files = self.mgr.get_files(session_id, [f"id_{index}"])
self.assertEqual(1, len(session_files))
- self.assertEqual(file_ids[index], session_files[0].id)
+ self.assertEqual(file_ids[index], session_files[0].file_id)
# Remove session files
self.mgr.remove_session_files(session_id)
# Our file should no longer exist
- session_files = self.mgr.get_all_files(session_id, "widget")
+ session_files = self.mgr.get_files(session_id, [f"id_{index}"])
self.assertEqual(0, len(session_files))
call_on_threads(remove_session_files, num_threads=self.NUM_THREADS)
-
- def test_remove_orphaned_files(self):
- """`remove_orphaned_files` is thread-safe."""
- # Add a bunch of "active" files to a single widget
- active_file_ids = []
- for ii in range(self.NUM_THREADS):
- file = UploadedFileRec(id=0, name=f"file_{ii}", type="type", data=b"123")
- active_file_ids.append(self.mgr.add_file("session", "widget", file).id)
-
- # Now add some "inactive" files to the same widget
- inactive_file_ids = []
- for ii in range(self.NUM_THREADS, self.NUM_THREADS + 50):
- file = UploadedFileRec(id=0, name=f"file_{ii}", type="type", data=b"123")
- inactive_file_ids.append(self.mgr.add_file("session", "widget", file).id)
-
- newest_file_id = inactive_file_ids[len(inactive_file_ids) - 1] + 1
-
- # Call `remove_orphaned_files` from each thread.
- # Our active_files should remain intact, and our orphans should be removed!
- def remove_orphans(_: int) -> None:
- self.mgr.remove_orphaned_files(
- "session", "widget", newest_file_id, active_file_ids
- )
- remaining_ids = [
- file.id for file in self.mgr.get_all_files("session", "widget")
- ]
- self.assertEqual(sorted(active_file_ids), sorted(remaining_ids))
-
- call_on_threads(remove_orphans, num_threads=self.NUM_THREADS)
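Taken together, the rewritten tests above describe the new manager's lifecycle: presign, upload, fetch, delete. A minimal usage sketch built only from the methods exercised in these tests:

```python
from streamlit.runtime.memory_uploaded_file_manager import MemoryUploadedFileManager
from streamlit.runtime.uploaded_file_manager import UploadedFileRec

mgr = MemoryUploadedFileManager("/_stcore/upload_file")

# The browser first requests presigned URLs; each carries a unique file_id.
(url_info,) = mgr.get_upload_urls("session", ["report.csv"])

# The upload handler then stores the bytes under that file_id.
rec = UploadedFileRec(
    file_id=url_info.file_id, name="report.csv", type="text/csv", data=b"a,b\n1,2\n"
)
mgr.add_file("session", rec)
assert mgr.get_files("session", [url_info.file_id]) == [rec]

# Session teardown drops everything the session uploaded.
mgr.remove_session_files("session")
```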
diff --git a/lib/tests/streamlit/script_run_context_test.py b/lib/tests/streamlit/script_run_context_test.py
index 431be8343b90..06767a55eeaa 100644
--- a/lib/tests/streamlit/script_run_context_test.py
+++ b/lib/tests/streamlit/script_run_context_test.py
@@ -16,9 +16,9 @@
from streamlit.errors import StreamlitAPIException
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
+from streamlit.runtime.memory_uploaded_file_manager import MemoryUploadedFileManager
from streamlit.runtime.scriptrunner import ScriptRunContext
from streamlit.runtime.state import SafeSessionState, SessionState
-from streamlit.runtime.uploaded_file_manager import UploadedFileManager
class ScriptRunContextTest(unittest.TestCase):
@@ -31,7 +31,7 @@ def test_set_page_config_immutable(self):
_enqueue=fake_enqueue,
query_string="",
session_state=SafeSessionState(SessionState()),
- uploaded_file_mgr=UploadedFileManager(),
+ uploaded_file_mgr=MemoryUploadedFileManager("mock/upload"),
page_script_hash="",
user_info={"email": "test@test.com"},
)
@@ -53,7 +53,7 @@ def test_set_page_config_first(self):
_enqueue=fake_enqueue,
query_string="",
session_state=SafeSessionState(SessionState()),
- uploaded_file_mgr=UploadedFileManager(),
+ uploaded_file_mgr=MemoryUploadedFileManager("/mock/upload"),
page_script_hash="",
user_info={"email": "test@test.com"},
)
@@ -79,7 +79,7 @@ def test_disallow_set_page_config_twice(self):
_enqueue=fake_enqueue,
query_string="",
session_state=SafeSessionState(SessionState()),
- uploaded_file_mgr=UploadedFileManager(),
+ uploaded_file_mgr=MemoryUploadedFileManager("/mock/upload"),
page_script_hash="",
user_info={"email": "test@test.com"},
)
@@ -104,7 +104,7 @@ def test_set_page_config_reset(self):
_enqueue=fake_enqueue,
query_string="",
session_state=SafeSessionState(SessionState()),
- uploaded_file_mgr=UploadedFileManager(),
+ uploaded_file_mgr=MemoryUploadedFileManager("/mock/upload"),
page_script_hash="",
user_info={"email": "test@test.com"},
)
diff --git a/lib/tests/streamlit/streamlit_test.py b/lib/tests/streamlit/streamlit_test.py
index 260a708fc869..6660e29b09d8 100644
--- a/lib/tests/streamlit/streamlit_test.py
+++ b/lib/tests/streamlit/streamlit_test.py
@@ -113,6 +113,7 @@ def test_public_api(self):
"snow",
"subheader",
"success",
+ "status",
"table",
"text",
"text_area",
@@ -120,6 +121,7 @@ def test_public_api(self):
"time_input",
"title",
"toast",
+ "toggle",
"vega_lite_chart",
"video",
"warning",
diff --git a/lib/tests/streamlit/type_util_test.py b/lib/tests/streamlit/type_util_test.py
index 6b2f74f51ab8..97657ee575ce 100644
--- a/lib/tests/streamlit/type_util_test.py
+++ b/lib/tests/streamlit/type_util_test.py
@@ -178,6 +178,83 @@ def test_convert_anything_to_df_supports_key_value_dicts(self):
df = convert_anything_to_df(data)
pd.testing.assert_frame_equal(df, pd.DataFrame.from_dict(data, orient="index"))
+ def test_convert_anything_to_df_passes_styler_through(self):
+ """Test that `convert_anything_to_df` correctly passes Stylers through."""
+ original_df = pd.DataFrame(
+ {
+ "integer": [1, 2, 3],
+ "float": [1.0, 2.1, 3.2],
+ "string": ["foo", "bar", None],
+ },
+ index=[1.0, "foo", 3],
+ )
+
+ original_styler = original_df.style.highlight_max(axis=0)
+
+ out = convert_anything_to_df(original_styler, allow_styler=True)
+ self.assertEqual(original_styler, out)
+ self.assertEqual(id(original_df), id(out.data))
+
+ def test_convert_anything_to_df_clones_stylers(self):
+ """Test that `convert_anything_to_df` correctly clones Stylers."""
+ original_df = pd.DataFrame(
+ {
+ "integer": [1, 2, 3],
+ "float": [1.0, 2.1, 3.2],
+ "string": ["foo", "bar", None],
+ },
+ index=[1.0, "foo", 3],
+ )
+
+ original_styler = original_df.style.highlight_max(axis=0)
+
+ out = convert_anything_to_df(
+ original_styler, allow_styler=True, ensure_copy=True
+ )
+ self.assertNotEqual(original_styler, out)
+ self.assertNotEqual(id(original_df), id(out.data))
+ pd.testing.assert_frame_equal(original_df, out.data)
+
+ def test_convert_anything_to_df_converts_stylers(self):
+ """Test that `convert_anything_to_df` correctly converts Stylers to DF, without cloning the
+ data.
+ """
+ original_df = pd.DataFrame(
+ {
+ "integer": [1, 2, 3],
+ "float": [1.0, 2.1, 3.2],
+ "string": ["foo", "bar", None],
+ },
+ index=[1.0, "foo", 3],
+ )
+
+ original_styler = original_df.style.highlight_max(axis=0)
+
+ out = convert_anything_to_df(original_styler, allow_styler=False)
+ self.assertNotEqual(id(original_styler), id(out))
+ self.assertEqual(id(original_df), id(out))
+ pd.testing.assert_frame_equal(original_df, out)
+
+ def test_convert_anything_to_df_converts_stylers_and_clones_data(self):
+ """Test that `convert_anything_to_df` correctly converts Stylers to DF, cloning the data."""
+ original_df = pd.DataFrame(
+ {
+ "integer": [1, 2, 3],
+ "float": [1.0, 2.1, 3.2],
+ "string": ["foo", "bar", None],
+ },
+ index=[1.0, "foo", 3],
+ )
+
+ original_styler = original_df.style.highlight_max(axis=0)
+
+ out = convert_anything_to_df(
+ original_styler, allow_styler=False, ensure_copy=True
+ )
+ self.assertNotEqual(id(original_styler), id(out))
+ self.assertNotEqual(id(original_df), id(out))
+ pd.testing.assert_frame_equal(original_df, out)
+
def test_convert_anything_to_df_calls_to_pandas_when_available(self):
class DataFrameIsh:
def to_pandas(self):
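The four Styler tests above reduce to a small decision matrix over `allow_styler` and `ensure_copy`. A compact sketch, assuming only the behavior asserted in these tests:

```python
import pandas as pd
from streamlit import type_util

styler = pd.DataFrame({"a": [1, 2, 3]}).style.highlight_max(axis=0)

# allow_styler=True: the Styler comes back (cloned iff ensure_copy=True).
passed_through = type_util.convert_anything_to_df(styler, allow_styler=True)
cloned = type_util.convert_anything_to_df(styler, allow_styler=True, ensure_copy=True)

# allow_styler=False: the underlying DataFrame comes back (copied iff ensure_copy=True).
plain = type_util.convert_anything_to_df(styler, allow_styler=False)
copied = type_util.convert_anything_to_df(styler, allow_styler=False, ensure_copy=True)

assert passed_through is styler
assert plain is styler.data
```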
diff --git a/lib/tests/streamlit/web/server/upload_file_request_handler_test.py b/lib/tests/streamlit/web/server/upload_file_request_handler_test.py
index 392987e649fd..5c8d66a443d0 100644
--- a/lib/tests/streamlit/web/server/upload_file_request_handler_test.py
+++ b/lib/tests/streamlit/web/server/upload_file_request_handler_test.py
@@ -22,11 +22,10 @@
import tornado.websocket
from streamlit.logger import get_logger
+from streamlit.runtime.memory_uploaded_file_manager import MemoryUploadedFileManager
from streamlit.runtime.uploaded_file_manager import UploadedFileManager
-from streamlit.web.server.upload_file_request_handler import (
- UPLOAD_FILE_ROUTE,
- UploadFileRequestHandler,
-)
+from streamlit.web.server.server import UPLOAD_FILE_ENDPOINT
+from streamlit.web.server.upload_file_request_handler import UploadFileRequestHandler
LOGGER = get_logger(__name__)
@@ -45,11 +44,11 @@ class UploadFileRequestHandlerTest(tornado.testing.AsyncHTTPTestCase):
"""Tests the /upload_file endpoint."""
def get_app(self):
- self.file_mgr = UploadedFileManager()
+ self.file_mgr = MemoryUploadedFileManager(upload_endpoint=UPLOAD_FILE_ENDPOINT)
return tornado.web.Application(
[
(
- UPLOAD_FILE_ROUTE,
+ f"{UPLOAD_FILE_ENDPOINT}/(?P[^/]+)/(?P[^/]+)",
UploadFileRequestHandler,
dict(
file_mgr=self.file_mgr,
@@ -59,17 +58,19 @@ def get_app(self):
]
)
- def _upload_files(self, params):
+ def _upload_files(self, files_body, session_id, file_id):
# We use requests.Request to construct our multipart/form-data request
# here, because they are absurdly fiddly to compose, and Tornado
# doesn't include a utility for building them. We then use self.fetch()
# to actually send the request to the test server.
req = requests.Request(
- method="POST", url=self.get_url("/_stcore/upload_file"), files=params
+ method="PUT",
+ url=self.get_url(f"{UPLOAD_FILE_ENDPOINT}/{session_id}/{file_id}"),
+ files=files_body,
).prepare()
return self.fetch(
- "/_stcore/upload_file",
+ req.url,
method=req.method,
headers=req.headers,
body=req.body,
@@ -78,19 +79,18 @@ def _upload_files(self, params):
def test_upload_one_file(self):
"""Uploading a file should populate our file_mgr."""
file = MockFile("filename", b"123")
- params = {
- file.name: file.data,
- "sessionId": (None, "mockSessionId"),
- "widgetId": (None, "mockWidgetId"),
- }
- response = self._upload_files(params)
- self.assertEqual(200, response.code, response.reason)
- file_id = int(response.body)
+ params = {file.name: file.data}
+ response = self._upload_files(
+ params, session_id="test_session_id", file_id=file.name
+ )
+
+ self.assertEqual(204, response.code, response.reason)
+
self.assertEqual(
- [(file_id, file.name, file.data)],
+ [(file.name, file.name, file.data)],
[
- (rec.id, rec.name, rec.data)
- for rec in self.file_mgr.get_all_files("mockSessionId", "mockWidgetId")
+ (rec.file_id, rec.name, rec.data)
+ for rec in self.file_mgr.get_files("test_session_id", [file.name])
],
)
@@ -99,38 +99,45 @@ def test_upload_multiple_files_error(self):
file_1 = MockFile("file1", b"123")
file_2 = MockFile("file2", b"456")
- params = {
+ files_body = {
file_1.name: file_1.data,
file_2.name: file_2.data,
- "sessionId": (None, "mockSessionId"),
- "widgetId": (None, "mockWidgetId"),
}
- response = self._upload_files(params)
+ response = self._upload_files(
+ files_body, session_id="some-session-id", file_id="some-file-id"
+ )
self.assertEqual(400, response.code)
self.assertIn("Expected 1 file, but got 2", response.reason)
- def test_upload_missing_params_error(self):
- """Missing params in the body should fail with 400 status."""
- params = {
+ def test_upload_missing_session_id_error(self):
+ """Missing session_id in the path should fail with 404 status."""
+ file_body = {
"image.png": ("image.png", b"1234"),
- "fileId": (None, "123"),
- "sessionId": (None, "mockSessionId"),
- # "widgetId": (None, 'mockWidgetId'),
}
- response = self._upload_files(params)
- self.assertEqual(400, response.code)
- self.assertIn("Missing 'widgetId'", response.reason)
+ response = self._upload_files(file_body, session_id="", file_id="file_id")
+ self.assertEqual(404, response.code)
+ self.assertIn("Not Found", response.reason)
+
+ def test_upload_missing_file_id_error(self):
+ """Missing file_id in the path should fail with 404 status."""
+ file_body = {
+ "image.png": ("image.png", b"1234"),
+ }
+
+ response = self._upload_files(file_body, session_id="session_id", file_id="")
+ self.assertEqual(404, response.code)
+ self.assertIn("Not Found", response.reason)
def test_upload_missing_file_error(self):
"""Missing file should fail with 400 status."""
- params = {
- # "image.png": ("image.png", b"1234"),
- "fileId": (None, "123"),
- "sessionId": (None, "mockSessionId"),
- "widgetId": (None, "mockWidgetId"),
+ file_body = {
+ "file1": (None, b"123"),
}
- response = self._upload_files(params)
+ response = self._upload_files(
+ file_body, session_id="sessionId", file_id="fileId"
+ )
+
self.assertEqual(400, response.code)
self.assertIn("Expected 1 file, but got 0", response.reason)
@@ -139,11 +146,11 @@ class UploadFileRequestHandlerInvalidSessionTest(tornado.testing.AsyncHTTPTestCa
"""Tests the /upload_file endpoint."""
def get_app(self):
- self.file_mgr = UploadedFileManager()
+ self.file_mgr = MemoryUploadedFileManager(upload_endpoint=UPLOAD_FILE_ENDPOINT)
return tornado.web.Application(
[
(
- UPLOAD_FILE_ROUTE,
+ f"{UPLOAD_FILE_ENDPOINT}/(?P[^/]+)/(?P[^/]+)",
UploadFileRequestHandler,
dict(
file_mgr=self.file_mgr,
@@ -153,17 +160,19 @@ def get_app(self):
]
)
- def _upload_files(self, params):
+ def _upload_files(self, files_body, session_id, file_id):
# We use requests.Request to construct our multipart/form-data request
# here, because they are absurdly fiddly to compose, and Tornado
# doesn't include a utility for building them. We then use self.fetch()
# to actually send the request to the test server.
req = requests.Request(
- method="POST", url=self.get_url("/_stcore/upload_file"), files=params
+ method="PUT",
+ url=self.get_url(f"{UPLOAD_FILE_ENDPOINT}/{session_id}/{file_id}"),
+ files=files_body,
).prepare()
return self.fetch(
- "/_stcore/upload_file",
+ req.url,
method=req.method,
headers=req.headers,
body=req.body,
@@ -172,14 +181,8 @@ def _upload_files(self, params):
def test_upload_one_file(self):
"""Upload should fail if the sessionId doesn't exist."""
file = MockFile("filename", b"123")
- params = {
- file.name: file.data,
- "sessionId": (None, "mockSessionId"),
- "widgetId": (None, "mockWidgetId"),
- }
- response = self._upload_files(params)
+ params = {file.name: file.data}
+ response = self._upload_files(params, session_id="sessionId", file_id="fileId")
self.assertEqual(400, response.code)
- self.assertIn("Invalid session_id: 'mockSessionId'", response.reason)
- self.assertEqual(
- self.file_mgr.get_all_files("mockSessionId", "mockWidgetId"), []
- )
+ self.assertIn("Invalid session_id: 'sessionId'", response.reason)
+ self.assertEqual(self.file_mgr.get_files("sessionId", ["fileId"]), [])
diff --git a/lib/tests/streamlit/write_test.py b/lib/tests/streamlit/write_test.py
index aba16dc530ca..4ff7c2d48d1f 100644
--- a/lib/tests/streamlit/write_test.py
+++ b/lib/tests/streamlit/write_test.py
@@ -30,7 +30,6 @@
from streamlit.error_util import handle_uncaught_app_exception
from streamlit.errors import StreamlitAPIException
from streamlit.runtime.state import SessionStateProxy
-from tests.testutil import should_skip_pyspark_tests
class StreamlitWriteTest(unittest.TestCase):
@@ -198,9 +197,6 @@ def test_snowpark_dataframe_write(self):
)
p.assert_called_once()
- @pytest.mark.skipif(
- should_skip_pyspark_tests(), reason="pyspark is incompatible with Python3.11"
- )
def test_pyspark_dataframe_write(self):
"""Test st.write with pyspark.sql.DataFrame."""
# Import package inside the test so the test suite still runs even if you don't
diff --git a/lib/tests/testutil.py b/lib/tests/testutil.py
index 8ac5d41aa316..9bc931b546e5 100644
--- a/lib/tests/testutil.py
+++ b/lib/tests/testutil.py
@@ -13,32 +13,22 @@
# limitations under the License.
"""Utility functions to use in our tests."""
+
import json
-import sys
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Dict
from unittest.mock import patch
from streamlit import config
+from streamlit.runtime.memory_uploaded_file_manager import MemoryUploadedFileManager
from streamlit.runtime.scriptrunner import ScriptRunContext
from streamlit.runtime.state import SafeSessionState, SessionState
-from streamlit.runtime.uploaded_file_manager import UploadedFileManager
from tests.constants import SNOWFLAKE_CREDENTIAL_FILE
if TYPE_CHECKING:
from snowflake.snowpark import Session
-def should_skip_pyspark_tests() -> bool:
- """Disable pyspark unit tests in Python 3.11.
- Pyspark is not compatible with Python 3.11 as of 2023.01.12, and results in test failures.
- (Please remove this when pyspark is compatible!)
- """
- # See: https://pyreadiness.org/3.11/
- # See: https://stackoverflow.com/questions/74579273/indexerror-tuple-index-out-of-range-when-creating-pyspark-dataframe
- return sys.version_info >= (3, 11, 0)
-
-
def should_skip_pydantic_tests() -> bool:
try:
import pydantic
@@ -55,7 +45,7 @@ def create_mock_script_run_ctx() -> ScriptRunContext:
_enqueue=lambda msg: None,
query_string="mock_query_string",
session_state=SafeSessionState(SessionState()),
- uploaded_file_mgr=UploadedFileManager(),
+ uploaded_file_mgr=MemoryUploadedFileManager("/mock/upload"),
page_script_hash="mock_page_script_hash",
user_info={"email": "mock@test.com"},
)
diff --git a/proto/streamlit/proto/BackMsg.proto b/proto/streamlit/proto/BackMsg.proto
index 51387f42854f..1090475a836b 100644
--- a/proto/streamlit/proto/BackMsg.proto
+++ b/proto/streamlit/proto/BackMsg.proto
@@ -17,6 +17,7 @@
syntax = "proto3";
import "streamlit/proto/ClientState.proto";
+import "streamlit/proto/Common.proto";
// A message from the browser to the server.
message BackMsg {
@@ -62,6 +63,10 @@ message BackMsg {
// runtime. This message is IGNORED unless the runtime is configured with
// global.developmentMode = True.
bool debug_shutdown_runtime = 15;
+
+ // Requests that the server generate URLs for getting/uploading/deleting
+ // files for the `st.file_uploader` widget
+ FileURLsRequest file_urls_request = 16;
}
// An ID used to associate this BackMsg with the corresponding ForwardMsgs
@@ -71,5 +76,5 @@ message BackMsg {
reserved 1, 2, 3, 4, 8, 9, 10;
- // Next: 16
+ // Next: 17
}
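A sketch of how a client populates the new field, using the generated Python bindings (IDs and names illustrative):

```python
from streamlit.proto.BackMsg_pb2 import BackMsg

msg = BackMsg()
msg.file_urls_request.request_id = "req-1"
msg.file_urls_request.session_id = "session-1"
msg.file_urls_request.file_names.extend(["report.csv", "notes.txt"])
```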
diff --git a/proto/streamlit/proto/Block.proto b/proto/streamlit/proto/Block.proto
index e5e464f55938..d40c73776008 100644
--- a/proto/streamlit/proto/Block.proto
+++ b/proto/streamlit/proto/Block.proto
@@ -44,7 +44,8 @@ message Block {
message Expandable {
string label = 1;
- bool expanded = 2;
+ optional bool expanded = 2;
+ string icon = 3;
}
message Form {
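Making `expanded` `optional` turns it into a tristate (force open, force closed, or leave unset), and `icon` supports status-style expanders. A sketch against the generated bindings; the icon value is a hypothetical illustration:

```python
from streamlit.proto.Block_pb2 import Block

block = Block()
block.expandable.label = "Running step 2..."
block.expandable.icon = "spinner"  # hypothetical value, for illustration only

# With proto3 `optional`, presence is now observable: unset means
# "keep the user's current expand/collapse state".
assert not block.expandable.HasField("expanded")
block.expandable.expanded = True  # explicitly force the container open
```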
diff --git a/proto/streamlit/proto/Checkbox.proto b/proto/streamlit/proto/Checkbox.proto
index a8372f077617..5b83216afd1c 100644
--- a/proto/streamlit/proto/Checkbox.proto
+++ b/proto/streamlit/proto/Checkbox.proto
@@ -19,6 +19,11 @@ syntax = "proto3";
import "streamlit/proto/LabelVisibilityMessage.proto";
message Checkbox {
+ enum StyleType {
+ DEFAULT = 0;
+ TOGGLE = 1;
+ }
+
string id = 1;
string label = 2;
bool default = 3;
@@ -28,4 +33,5 @@ message Checkbox {
bool set_value = 7;
bool disabled = 8;
LabelVisibilityMessage label_visibility = 9;
+ StyleType type = 10;
}
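`st.toggle` reuses the checkbox element; only the new `type` field distinguishes the two renderings. A sketch at the proto level (labels illustrative):

```python
from streamlit.proto.Checkbox_pb2 import Checkbox

checkbox = Checkbox(label="Accept terms")  # `type` defaults to DEFAULT
toggle = Checkbox(label="Dark mode")
toggle.type = Checkbox.StyleType.TOGGLE    # rendered as a toggle switch
```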
diff --git a/proto/streamlit/proto/Common.proto b/proto/streamlit/proto/Common.proto
index 05b4cd64d7e1..dacf0c314184 100644
--- a/proto/streamlit/proto/Common.proto
+++ b/proto/streamlit/proto/Common.proto
@@ -46,17 +46,48 @@ message StringTriggerValue {
optional string data = 1;
}
+// TODO(vdonato / kajarenc): Finalize the next two proto types. We currently
+// have enough information here to support pure OSS use cases, but we'll need to
+// coordinate with the SiS team to verify that we're passing enough file
+// metadata back for them.
+message FileURLsRequest {
+ string request_id = 1;
+ repeated string file_names = 2;
+ string session_id = 3;
+}
+
+message FileURLs {
+ string file_id = 1;
+ string upload_url = 2;
+ string delete_url = 3;
+}
+
+message FileURLsResponse {
+ string response_id = 1;
+ repeated FileURLs file_urls = 2;
+ string error_msg = 3;
+}
+
// Information on a file uploaded via the file_uploader widget.
message UploadedFileInfo {
+ // DEPRECATED.
sint64 id = 1;
string name = 2;
// The size of this file in bytes.
uint32 size = 3;
+
+ // ID that can be used to retrieve a file.
+ string file_id = 4;
+
+ // Metadata containing information about file_urls.
+ FileURLs file_urls = 5;
}
message FileUploaderState {
+ // DEPRECATED
sint64 max_file_id = 1;
+
repeated UploadedFileInfo uploaded_file_info = 2;
}
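A sketch of the request/response pairing these messages define, on the server side (IDs and URLs illustrative):

```python
from streamlit.proto.Common_pb2 import FileURLs, FileURLsRequest, FileURLsResponse

request = FileURLsRequest(
    request_id="req-1", session_id="session-1", file_names=["report.csv"]
)

# The server answers with one FileURLs entry per requested file name.
response = FileURLsResponse(response_id=request.request_id)
response.file_urls.append(
    FileURLs(
        file_id="file-abc",
        upload_url="/_stcore/upload_file/session-1/file-abc",
        delete_url="/_stcore/upload_file/session-1/file-abc",
    )
)
```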
diff --git a/proto/streamlit/proto/DeckGlJsonChart.proto b/proto/streamlit/proto/DeckGlJsonChart.proto
index 891cbe9ba781..d205da0da421 100644
--- a/proto/streamlit/proto/DeckGlJsonChart.proto
+++ b/proto/streamlit/proto/DeckGlJsonChart.proto
@@ -17,11 +17,14 @@
syntax = "proto3";
message DeckGlJsonChart {
- // The dataframe that will be used as the chart's main data source.
+ // The JSON of the pydeck object (https://deckgl.readthedocs.io/en/latest/deck.html)
string json = 1;
string tooltip = 2;
// If True, will overwrite the chart width spec to fit to container.
bool use_container_width = 4;
+
+ // The hash of the JSON so the frontend doesn't always have to parse the pydeck JSON object.
+ string id = 5;
}
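A sketch of how the `id` could be derived; the proto only says it is a hash of the JSON, so the digest choice here is an assumption:

```python
import hashlib
import json

spec = {"initialViewState": {"zoom": 3}, "layers": []}
json_payload = json.dumps(spec)

# Any stable digest works for frontend memoization; md5 is an assumption.
chart_id = hashlib.md5(json_payload.encode("utf-8")).hexdigest()
```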
diff --git a/proto/streamlit/proto/DownloadButton.proto b/proto/streamlit/proto/DownloadButton.proto
index 35a844e0fcad..436d5b47a13a 100644
--- a/proto/streamlit/proto/DownloadButton.proto
+++ b/proto/streamlit/proto/DownloadButton.proto
@@ -25,4 +25,5 @@ message DownloadButton {
string url = 6;
bool disabled = 7;
bool use_container_width = 8;
+ string type = 9;
}
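The new `type` mirrors the primary/secondary styling of `st.button`. A usage sketch (data and labels illustrative):

```python
import streamlit as st

st.download_button(
    "Download report",
    data=b"a,b\n1,2\n",
    file_name="report.csv",
    mime="text/csv",
    type="primary",  # assumed to default to "secondary", matching st.button
)
```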
diff --git a/proto/streamlit/proto/ForwardMsg.proto b/proto/streamlit/proto/ForwardMsg.proto
index 65a748b6adb5..36b27fab4fff 100644
--- a/proto/streamlit/proto/ForwardMsg.proto
+++ b/proto/streamlit/proto/ForwardMsg.proto
@@ -16,6 +16,7 @@
syntax = "proto3";
+import "streamlit/proto/Common.proto";
import "streamlit/proto/Delta.proto";
import "streamlit/proto/GitInfo.proto";
import "streamlit/proto/NewSession.proto";
@@ -65,6 +66,7 @@ message ForwardMsg {
// Other messages.
PageNotFound page_not_found = 15;
PagesChanged pages_changed = 16;
+ FileURLsResponse file_urls_response = 19;
// A reference to a ForwardMsg that has already been delivered.
// The client should substitute the message with the given hash
@@ -79,7 +81,7 @@ message ForwardMsg {
string debug_last_backmsg_id = 17;
reserved 7, 8;
- // Next: 19
+ // Next: 20
}
// ForwardMsgMetadata contains all data that does _not_ get hashed (or cached)
diff --git a/proto/streamlit/proto/Heading.proto b/proto/streamlit/proto/Heading.proto
index 9cf3e0b11b8a..ce56121bbd2a 100644
--- a/proto/streamlit/proto/Heading.proto
+++ b/proto/streamlit/proto/Heading.proto
@@ -25,4 +25,5 @@ message Heading {
string help = 4;
bool hide_anchor = 5;
+ string divider = 6;
}
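A usage sketch of the new `divider` field as surfaced by `st.header` and `st.subheader` (color names illustrative):

```python
import streamlit as st

st.header("Results", divider="rainbow")
st.subheader("Raw numbers", divider="gray")
```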
diff --git a/proto/streamlit/proto/Radio.proto b/proto/streamlit/proto/Radio.proto
index cdd0843c953b..d88e9ebe8eea 100644
--- a/proto/streamlit/proto/Radio.proto
+++ b/proto/streamlit/proto/Radio.proto
@@ -30,4 +30,5 @@ message Radio {
bool disabled = 9;
bool horizontal = 10;
LabelVisibilityMessage label_visibility = 11;
+ repeated string captions = 12;
}