-
+
diff --git a/src/ui/src/wds/WdsTextInput.vue b/src/ui/src/wds/WdsTextInput.vue
index bcbfc58c6..81712ee23 100644
--- a/src/ui/src/wds/WdsTextInput.vue
+++ b/src/ui/src/wds/WdsTextInput.vue
@@ -1,13 +1,35 @@
-
+
+ {{ leftIcon }}
+
+
+
diff --git a/src/writer/__init__.py b/src/writer/__init__.py
index b93a03f1d..1639e1974 100644
--- a/src/writer/__init__.py
+++ b/src/writer/__init__.py
@@ -5,6 +5,7 @@
from writer.core import (
BytesWrapper,
Config,
+ EditableDataframe,
FileWrapper,
Readable,
State,
diff --git a/src/writer/ai.py b/src/writer/ai.py
index ae9655de2..6ce94e142 100644
--- a/src/writer/ai.py
+++ b/src/writer/ai.py
@@ -1,6 +1,7 @@
import logging
from datetime import datetime
from typing import (
+ Dict,
Generator,
Iterable,
List,
@@ -31,6 +32,7 @@
)
from writerai.types import File as SDKFile
from writerai.types import Graph as SDKGraph
+from writerai.types.application_generate_content_params import Input
from writerai.types.chat_chat_params import Message as WriterAIMessage
from writer.core import get_app_process
@@ -60,7 +62,7 @@ class CreateOptions(APIOptions, total=False):
stop: Union[List[str], str, NotGiven]
temperature: Union[float, NotGiven]
top_p: Union[float, NotGiven]
-
+
class APIListOptions(APIOptions, total=False):
after: Union[str, NotGiven]
@@ -68,7 +70,6 @@ class APIListOptions(APIOptions, total=False):
limit: Union[int, NotGiven]
order: Union[Literal["asc", "desc"], NotGiven]
-
logger = logging.Logger(__name__)
@@ -926,6 +927,35 @@ def serialized_messages(self) -> List['Message']:
return serialized_messages
+class Apps:
+ def generate_content(self, application_id: str, input_dict: Dict[str, str] = {}, config: Optional[APIOptions] = None) -> str:
+ """
+ Generates output based on an existing AI Studio no-code application.
+
+ :param application_id: The id for the application, which can be obtained on AI Studio.
+ :param input_dict: Optional dictionary containing parameters for the generation call.
+ :return: The generated text.
+ :raises RuntimeError: If response data was not properly formatted to retrieve model text.
+ """
+
+ client = WriterAIManager.acquire_client()
+ config = config or {}
+ inputs = []
+
+ for k, v in input_dict.items():
+ inputs.append(Input({
+ "id": k,
+ "value": v if isinstance(v, list) else [v]
+ }))
+
+ response_data = client.applications.generate_content(application_id=application_id, inputs=inputs, **config)
+
+ text = response_data.suggestion
+ if text:
+ return text
+
+ raise RuntimeError(f"Failed to acquire proper response for completion from data: {response_data}")
+
def complete(initial_text: str, config: Optional['CreateOptions'] = None) -> str:
"""
Completes the input text using the given data and returns the first resulting text choice.
@@ -1009,3 +1039,5 @@ def init(token: Optional[str] = None):
:return: An instance of WriterAIManager.
"""
return WriterAIManager(token=token)
+
+apps = Apps()
\ No newline at end of file
diff --git a/src/writer/auth.py b/src/writer/auth.py
index bc768ce8f..646208543 100644
--- a/src/writer/auth.py
+++ b/src/writer/auth.py
@@ -1,5 +1,6 @@
import asyncio
import dataclasses
+import logging
import os.path
import time
from abc import ABCMeta, abstractmethod
@@ -16,6 +17,8 @@
from writer.serve import WriterFastAPI
from writer.ss_types import InitSessionRequestPayload
+logger = logging.getLogger('writer')
+
# Dictionary for storing failed attempts {ip_address: timestamp}
failed_attempts: Dict[str, float] = {}
@@ -181,11 +184,23 @@ def register(self,
callback: Optional[Callable[[Request, str, dict], None]] = None,
unauthorized_action: Optional[Callable[[Request, Unauthorized], Response]] = None
):
+
+ redirect_url = urljoin(self.host_url, self.callback_authorize)
+ host_url_path = urlpath(self.host_url)
+ callback_authorize_path = urljoin(host_url_path, self.callback_authorize)
+ asset_assets_path = urljoin(host_url_path, "assets")
+
+ logger.debug(f"[auth] oidc - url redirect: {redirect_url}")
+ logger.debug(f"[auth] oidc - endpoint authorize: {self.url_authorize}")
+ logger.debug(f"[auth] oidc - endpoint token: {self.url_oauthtoken}")
+ logger.debug(f"[auth] oidc - path: {host_url_path}")
+ logger.debug(f"[auth] oidc - authorize path: {callback_authorize_path}")
+ logger.debug(f"[auth] oidc - asset path: {asset_assets_path}")
self.authlib = OAuth2Session(
client_id=self.client_id,
client_secret=self.client_secret,
scope=self.scope.split(" "),
- redirect_uri=_urljoin(self.host_url, self.callback_authorize),
+ redirect_uri=redirect_url,
authorization_endpoint=self.url_authorize,
token_endpoint=self.url_oauthtoken,
)
@@ -195,10 +210,8 @@ def register(self,
@asgi_app.middleware("http")
async def oidc_middleware(request: Request, call_next):
session = request.cookies.get('session')
- host_url_path = _urlpath(self.host_url)
- full_callback_authorize = '/' + _urljoin(host_url_path, self.callback_authorize)
- full_assets = '/' + _urljoin(host_url_path, '/assets')
- if session is not None or request.url.path in [full_callback_authorize] or request.url.path.startswith(full_assets):
+
+ if session is not None or request.url.path in [callback_authorize_path] or request.url.path.startswith(asset_assets_path):
response: Response = await call_next(request)
return response
else:
@@ -206,11 +219,11 @@ async def oidc_middleware(request: Request, call_next):
response = RedirectResponse(url=url[0])
return response
- @asgi_app.get('/' + _urlstrip(self.callback_authorize))
+ @asgi_app.get('/' + urlstrip(self.callback_authorize))
async def route_callback(request: Request):
self.authlib.fetch_token(url=self.url_oauthtoken, authorization_response=str(request.url))
try:
- host_url_path = _urlpath(self.host_url)
+ host_url_path = urlpath(self.host_url)
response = RedirectResponse(url=host_url_path)
session_id = session_manager.generate_session_id()
@@ -300,44 +313,54 @@ def Auth0(client_id: str, client_secret: str, domain: str, host_url: str) -> Oid
url_oauthtoken=f"https://{domain}/oauth/token",
url_userinfo=f"https://{domain}/userinfo")
-def _urlpath(url: str):
+def urlpath(url: str):
"""
- >>> _urlpath("http://localhost/app1")
+ >>> urlpath("http://localhost/app1")
>>> "/app1"
+
+ >>> urlpath("http://localhost")
+ >>> "/"
"""
- return urlparse(url).path
+ path = urlparse(url).path
+ if len(path) == 0:
+ return "/"
+ else:
+ return path
-def _urljoin(*args):
+def urljoin(*args):
"""
- >>> _urljoin("http://localhost/app1", "edit")
+ >>> urljoin("http://localhost/app1", "edit")
>>> "http://localhost/app1/edit"
- >>> _urljoin("app1/", "edit")
+ >>> urljoin("app1/", "edit")
>>> "app1/edit"
- >>> _urljoin("app1", "edit")
+ >>> urljoin("app1", "edit")
>>> "app1/edit"
- >>> _urljoin("/app1/", "/edit")
- >>> "app1/edit"
+ >>> urljoin("/app1/", "/edit")
+ >>> "/app1/edit"
"""
+ root_part = args[0]
+ root_part_is_root_path = root_part.startswith('/') and len(root_part) > 1
+
url_strip_parts = []
for part in args:
if part:
- url_strip_parts.append(_urlstrip(part))
+ url_strip_parts.append(urlstrip(part))
- return '/'.join(url_strip_parts)
+ return '/'.join(url_strip_parts) if root_part_is_root_path is False else '/' + '/'.join(url_strip_parts)
-def _urlstrip(url_path: str):
+def urlstrip(url_path: str):
"""
- >>> _urlstrip("/app1/")
+ >>> urlstrip("/app1/")
>>> "app1"
- >>> _urlstrip("http://localhost/app1")
+ >>> urlstrip("http://localhost/app1")
>>> "http://localhost/app1"
- >>> _urlstrip("http://localhost/app1/")
+ >>> urlstrip("http://localhost/app1/")
>>> "http://localhost/app1"
"""
return url_path.strip('/')
diff --git a/src/writer/command_line.py b/src/writer/command_line.py
index a6b2d767a..ddd4acc26 100644
--- a/src/writer/command_line.py
+++ b/src/writer/command_line.py
@@ -1,133 +1,78 @@
-import argparse
-import getpass
import logging
import os
-import re
import shutil
import sys
-from typing import List, Optional, Union
-
-import writer.deploy
-import writer.serve
+from typing import Optional
+import click
+import writer.serve
+from writer.deploy import cloud
+
+CONTEXT_SETTINGS = {'help_option_names': ['-h', '--help']}
+@click.group(
+ context_settings=CONTEXT_SETTINGS,
+ help="Writer Framework CLI",
+)
+@click.version_option(None, '--version', '-v')
def main():
- parser = argparse.ArgumentParser(
- description="Run, edit or create a Writer Framework app.")
- parser.add_argument("command", choices=[
- "run", "edit", "create", "hello", "deploy", "undeploy", "deployment-logs"])
- parser.add_argument(
- "path", nargs="?", help="Path to the app's folder")
- parser.add_argument(
- "--port", help="The port on which to run the server.")
- parser.add_argument(
- "--api-key", help="The API key to use for deployment.")
- parser.add_argument(
- "--host", help="The host on which to run the server. Use 0.0.0.0 to share in your local network.")
- parser.add_argument(
- "--enable-remote-edit", help="Set this flag to allow non-local requests in edit mode.", action='store_true')
- parser.add_argument(
- "--enable-server-setup", help="Set this flag to enable server setup hook in edit mode.", action='store_true')
- parser.add_argument(
- "--template", help="The template to use when creating a new app.")
- parser.add_argument(
- "--env", nargs="*", help="Env variables for the deploy command in the format ENV_VAR=value.")
-
- args = parser.parse_args()
- command = args.command
- default_port = 3006 if command in ("edit", "hello") else 3005
- enable_remote_edit = args.enable_remote_edit
- enable_server_setup_hook = args.enable_server_setup
- template_name = args.template
-
- port = int(args.port) if args.port else default_port
- absolute_app_path = _get_absolute_app_path(
- args.path) if args.path else None
- host = args.host if args.host else None
- api_key = args.api_key if args.api_key else None
-
- _perform_checks(command, absolute_app_path, host, enable_remote_edit, api_key)
- api_key = _get_api_key(command, api_key)
- env = _validate_env_vars(args.env)
- _route(command, absolute_app_path, port, host, enable_remote_edit, enable_server_setup_hook, template_name, api_key, env)
-
-def _validate_env_vars(env: Union[List[str], None]) -> Union[List[str], None]:
- if env is None:
- return None
- for var in env:
- regex = r"^[a-zA-Z_]+[a-zA-Z0-9_]*=.*$"
- if not re.match(regex, var):
- logging.error(f"Invalid environment variable: {var}, please use the format ENV_VAR=value")
- sys.exit(1)
- return env
-
-def _get_api_key(command, api_key: Optional[str]) -> Optional[str]:
- if command in ("deploy", "undeploy", "deployment-logs") and api_key is None:
- env_key = os.getenv("WRITER_API_KEY", None)
- if env_key is not None and env_key != "":
- return env_key
- else:
- logging.info("An API key is required to deploy a Writer Framework app.")
- api_key = getpass.getpass(prompt='Enter your API key: ', stream=None)
- if api_key is None or api_key == "":
- logging.error("No API key provided. Exiting.")
- sys.exit(1)
- return api_key
- else:
- return api_key
-
-
-def _perform_checks(command: str, absolute_app_path: str, host: Optional[str], enable_remote_edit: Optional[bool], api_key: Optional[str] = None):
- is_path_folder = absolute_app_path is not None and os.path.isdir(absolute_app_path)
-
- if command in ("run", "edit", "deploy") and is_path_folder is False:
- logging.error("A path to a folder containing a Writer Framework app is required. For example: writer edit my_app")
- sys.exit(1)
-
- if command in ("create") and absolute_app_path is None:
- logging.error("A target folder is required to create a Writer Framework app. For example: writer create my_app")
- sys.exit(1)
-
- if command in ("edit", "hello") and host is not None:
- logging.warning("Writer Framework has been enabled in edit mode with a host argument\nThis is enabled for local development purposes (such as a local VM).\nDon't expose Builder to the Internet. We recommend using a SSH tunnel instead.")
-
- if command in ("edit", "hello") and enable_remote_edit is True:
- logging.warning("The remote edit flag is active. Builder will accept non-local requests. Please make sure the host is protected to avoid drive-by attacks.")
-
-
-def _route(
- command: str,
- absolute_app_path: str,
- port: int,
- host: Optional[str],
- enable_remote_edit: Optional[bool],
- enable_server_setup: Optional[bool],
- template_name: Optional[str],
- api_key: Optional[str] = None,
- env: Union[List[str], None] = None
-):
- if host is None:
- host = "127.0.0.1"
- if command in ("deploy"):
- writer.deploy.deploy(absolute_app_path, api_key, env=env)
- if command in ("undeploy"):
- writer.deploy.undeploy(api_key)
- if command in ("deployment-logs"):
- writer.deploy.runtime_logs(api_key)
- if command in ("edit"):
- writer.serve.serve(
- absolute_app_path, mode="edit", port=port, host=host,
- enable_remote_edit=enable_remote_edit, enable_server_setup=enable_server_setup)
- if command in ("run"):
- writer.serve.serve(
- absolute_app_path, mode="run", port=port, host=host, enable_server_setup=True)
- elif command in ("hello"):
- create_app("hello", template_name="hello", overwrite=True)
- writer.serve.serve("hello", mode="edit",
- port=port, host=host, enable_remote_edit=enable_remote_edit,
- enable_server_setup=False)
- elif command in ("create"):
- create_app(absolute_app_path, template_name=template_name)
+ pass
+
+@main.command()
+@click.option('--host', default="127.0.0.1", help="Host to run the app on")
+@click.option('--port', default=5000, help="Port to run the app on")
+@click.argument('path')
+def run(path, host, port):
+ """Run the app from PATH folder in run mode."""
+
+ abs_path = os.path.abspath(path)
+ if not os.path.isdir(abs_path):
+ raise click.ClickException("A path to a folder containing a Writer Framework app is required. For example: writer run my_app")
+
+ writer.serve.serve(
+ abs_path, mode="run", port=port, host=host, enable_server_setup=True)
+
+@main.command()
+@click.option('--host', default="127.0.0.1", help="Host to run the app on")
+@click.option('--port', default=5000, help="Port to run the app on")
+@click.option('--enable-remote-edit', help="Set this flag to allow non-local requests in edit mode.", is_flag=True)
+@click.option('--enable-server-setup', help="Set this flag to enable server setup hook in edit mode.", is_flag=True)
+@click.argument('path')
+def edit(path, port, host, enable_remote_edit, enable_server_setup):
+ """Run the app from PATH folder in edit mode."""
+
+ abs_path = os.path.abspath(path)
+ if not os.path.isdir(abs_path):
+ raise click.ClickException("A path to a folder containing a Writer Framework app is required. For example: writer edit my_app")
+
+ writer.serve.serve(
+ abs_path, mode="edit", port=port, host=host,
+ enable_remote_edit=enable_remote_edit, enable_server_setup=enable_server_setup)
+
+@main.command()
+@click.argument('path')
+@click.option('--template', help="The template to use when creating a new app.")
+def create(path, template):
+ """Create a new app in PATH folder."""
+
+ abs_path = os.path.abspath(path)
+ if os.path.isfile(abs_path):
+ raise click.ClickException("A target folder is required to create a Writer Framework app. For example: writer create my_app")
+
+ create_app(os.path.abspath(path), template_name=template)
+
+@main.command()
+@click.option('--host', default="127.0.0.1", help="Host to run the app on")
+@click.option('--port', default=5000, help="Port to run the app on")
+@click.option('--enable-remote-edit', help="Set this flag to allow non-local requests in edit mode.", is_flag=True)
+def hello(port, host, enable_remote_edit):
+ """Create and run an onboarding 'Hello' app."""
+ create_app("hello", template_name="hello", overwrite=True)
+ writer.serve.serve("hello", mode="edit",
+ port=port, host=host, enable_remote_edit=enable_remote_edit,
+ enable_server_setup=False)
+
+main.add_command(cloud)
def create_app(app_path: str, template_name: Optional[str], overwrite=False):
if template_name is None:
@@ -149,15 +94,5 @@ def create_app(app_path: str, template_name: Optional[str], overwrite=False):
shutil.copytree(template_path, app_path, dirs_exist_ok=True)
-
-def _get_absolute_app_path(app_path: str):
- is_path_absolute = os.path.isabs(app_path)
- if is_path_absolute:
- return app_path
- else:
- return os.path.join(os.getcwd(), app_path)
-
-
-
if __name__ == "__main__":
main()
diff --git a/src/writer/core.py b/src/writer/core.py
index 8cb2cf442..c565066f9 100644
--- a/src/writer/core.py
+++ b/src/writer/core.py
@@ -2,6 +2,7 @@
import base64
import contextlib
import copy
+import dataclasses
import datetime
import inspect
import io
@@ -14,6 +15,8 @@
import time
import traceback
import urllib.request
+from abc import ABCMeta
+from functools import wraps
from multiprocessing.process import BaseProcess
from types import ModuleType
from typing import (
@@ -35,9 +38,14 @@
cast,
)
+import pyarrow # type: ignore
+
from writer import core_ui
from writer.core_ui import Component
from writer.ss_types import (
+ DataframeRecordAdded,
+ DataframeRecordRemoved,
+ DataframeRecordUpdated,
InstancePath,
InstancePathItem,
Readable,
@@ -47,6 +55,9 @@
)
if TYPE_CHECKING:
+ import pandas
+ import polars
+
from writer.app_runner import AppProcess
@@ -65,6 +76,31 @@ def get_app_process() -> 'AppProcess':
raise RuntimeError( "Failed to retrieve the AppProcess: running in wrong context")
+
+def import_failure(rvalue: Any = None):
+ """
+    This decorator captures the failure to import a module and returns a value instead.
+
+ If the import of a module fails, the decorator returns the value given as a parameter.
+
+ >>> @import_failure(rvalue=False)
+ >>> def my_handler():
+ >>> import pandas
+ >>> return pandas.DataFrame()
+
+ :param rvalue: the value to return
+ """
+ def decorator(func):
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ try:
+ return func(*args, **kwargs)
+ except ImportError:
+ return rvalue
+ return wrapper
+ return decorator
+
+
class Config:
is_mail_enabled_for_log: bool = False
@@ -129,12 +165,10 @@ class StateSerialiserException(ValueError):
class StateSerialiser:
-
"""
Serialises user state values before sending them to the front end.
Provides JSON-compatible values, including data URLs for binary data.
"""
-
def serialise(self, v: Any) -> Union[Dict, List, str, bool, int, float, None]:
from writer.ai import Conversation
if isinstance(v, State):
@@ -153,6 +187,9 @@ def serialise(self, v: Any) -> Union[Dict, List, str, bool, int, float, None]:
return self._serialise_list_recursively(v)
if isinstance(v, (str, bool)):
return v
+ if isinstance(v, EditableDataframe):
+ table = v.pyarrow_table()
+ return self._serialise_pyarrow_table(table)
if v is None:
return v
@@ -242,6 +279,44 @@ def _serialise_pyarrow_table(self, table):
bw = BytesWrapper(buf, "application/vnd.apache.arrow.file")
return self.serialise(bw)
+class MutableValue:
+ """
+ MutableValue allows you to implement a value whose modification
+ will be followed by the state of Writer Framework and will trigger the refresh
+ of the user interface.
+
+ >>> class MyValue(MutableValue):
+ >>> def __init__(self, value):
+ >>> self.value = value
+ >>>
+ >>> def modify(self, new_value):
+ >>> self.value = new_value
+ >>> self.mutate()
+ """
+ def __init__(self):
+ self._mutated = False
+
+ def mutated(self) -> bool:
+ """
+ Returns whether the value has been mutated.
+ :return:
+ """
+ return self._mutated
+
+ def mutate(self) -> None:
+ """
+ Marks the value as mutated.
+ This will trigger the refresh of the user interface on the next round trip
+ :return:
+ """
+ self._mutated = True
+
+ def reset_mutation(self) -> None:
+ """
+ Resets the mutation flag to False.
+ :return:
+ """
+ self._mutated = False
class StateProxy:
@@ -349,8 +424,14 @@ def carry_mutation_flag(base_key, child_key):
try:
serialised_value = state_serialiser.serialise(value)
except BaseException:
- raise ValueError(
- f"""Couldn't serialise value of type "{ type(value) }" for key "{ key }".""")
+ raise ValueError(f"""Couldn't serialise value of type "{ type(value) }" for key "{ key }".""")
+ serialised_mutations[f"+{escaped_key}"] = serialised_value
+ elif isinstance(value, MutableValue) is True and value.mutated():
+ try:
+ serialised_value = state_serialiser.serialise(value)
+ value.reset_mutation()
+ except BaseException:
+ raise ValueError(f"""Couldn't serialise value of type "{ type(value) }" for key "{ key }".""")
serialised_mutations[f"+{escaped_key}"] = serialised_value
deleted_keys = \
@@ -1506,6 +1587,408 @@ def __set__(self, instance, value):
proxy = getattr(instance, self.objectName)
proxy[self.key] = value
+
+class DataframeRecordRemove:
+ pass
+
+
+class DataframeRecordProcessor():
+ """
+ This interface defines the signature of the methods to process the events of a
+ dataframe compatible with EditableDataframe.
+
+ A Dataframe can be any structure composed of tabular data.
+
+ This class defines the signature of the methods to be implemented.
+ """
+ __metaclass__ = ABCMeta
+
+ @staticmethod
+ def match(df: Any) -> bool:
+ """
+ This method checks if the dataframe is compatible with the processor.
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ def record(df: Any, record_index: int) -> dict:
+ """
+ This method read a record at the given line and get it back as dictionary
+
+ >>> edf = EditableDataframe(df)
+ >>> r = edf.record(1)
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ def record_add(df: Any, payload: DataframeRecordAdded) -> Any:
+ """
+ signature of the methods to be implemented to process wf-dataframe-add event
+
+ >>> edf = EditableDataframe(df)
+ >>> edf.record_add({"record": {"a": 1, "b": 2}})
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ def record_update(df: Any, payload: DataframeRecordUpdated) -> Any:
+ """
+ signature of the methods to be implemented to process wf-dataframe-update event
+
+ >>> edf = EditableDataframe(df)
+ >>> edf.record_update({"record_index": 12, "record": {"a": 1, "b": 2}})
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ def record_remove(df: Any, payload: DataframeRecordRemoved) -> Any:
+ """
+ signature of the methods to be implemented to process wf-dataframe-action event
+
+ >>> edf = EditableDataframe(df)
+ >>> edf.record_remove({"record_index": 12})
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ def pyarrow_table(df: Any) -> pyarrow.Table:
+ """
+ Serializes the dataframe into a pyarrow table
+ """
+ raise NotImplementedError
+
+
+class PandasRecordProcessor(DataframeRecordProcessor):
+ """
+ PandasRecordProcessor processes records from a pandas dataframe saved into an EditableDataframe
+
+ >>> df = pandas.DataFrame({"a": [1, 2], "b": [3, 4]})
+ >>> edf = EditableDataframe(df)
+ >>> edf.record_add({"a": 1, "b": 2})
+ """
+
+ @staticmethod
+ @import_failure(rvalue=False)
+ def match(df: Any) -> bool:
+ import pandas
+ return True if isinstance(df, pandas.DataFrame) else False
+
+ @staticmethod
+ def record(df: 'pandas.DataFrame', record_index: int) -> dict:
+ """
+
+ >>> edf = EditableDataframe(df)
+ >>> r = edf.record(1)
+ """
+ import pandas
+
+ record = df.iloc[record_index]
+ if not isinstance(df.index, pandas.RangeIndex):
+ index_list = df.index.tolist()
+ record_index_content = index_list[record_index]
+ if isinstance(record_index_content, tuple):
+ for i, n in enumerate(df.index.names):
+ record[n] = record_index_content[i]
+ else:
+ record[df.index.names[0]] = record_index_content
+
+ return dict(record)
+
+ @staticmethod
+ def record_add(df: 'pandas.DataFrame', payload: DataframeRecordAdded) -> 'pandas.DataFrame':
+ """
+ >>> edf = EditableDataframe(df)
+ >>> edf.record_add({"record": {"a": 1, "b": 2}})
+ """
+ import pandas
+
+ _assert_record_match_pandas_df(df, payload['record'])
+
+ record, index = _split_record_as_pandas_record_and_index(payload['record'], df.index.names)
+
+ if isinstance(df.index, pandas.RangeIndex):
+ new_df = pandas.DataFrame([record])
+ return pandas.concat([df, new_df], ignore_index=True)
+ else:
+ new_df = pandas.DataFrame([record], index=[index])
+ return pandas.concat([df, new_df])
+
+ @staticmethod
+ def record_update(df: 'pandas.DataFrame', payload: DataframeRecordUpdated) -> 'pandas.DataFrame':
+ """
+ >>> edf = EditableDataframe(df)
+ >>> edf.record_update({"record_index": 12, "record": {"a": 1, "b": 2}})
+ """
+ _assert_record_match_pandas_df(df, payload['record'])
+
+ record: dict
+ record, index = _split_record_as_pandas_record_and_index(payload['record'], df.index.names)
+
+ record_index = payload['record_index']
+ df.iloc[record_index] = record # type: ignore
+
+ index_list = df.index.tolist()
+ index_list[record_index] = index
+ df.index = index_list # type: ignore
+
+ return df
+
+ @staticmethod
+ def record_remove(df: 'pandas.DataFrame', payload: DataframeRecordRemoved) -> 'pandas.DataFrame':
+ """
+ >>> edf = EditableDataframe(df)
+ >>> edf.record_remove({"record_index": 12})
+ """
+ record_index: int = payload['record_index']
+ idx = df.index[record_index]
+ df = df.drop(idx)
+
+ return df
+
+ @staticmethod
+ def pyarrow_table(df: 'pandas.DataFrame') -> pyarrow.Table:
+ """
+ Serializes the dataframe into a pyarrow table
+ """
+ table = pyarrow.Table.from_pandas(df=df)
+ return table
+
+
+class PolarRecordProcessor(DataframeRecordProcessor):
+ """
+ PolarRecordProcessor processes records from a polar dataframe saved into an EditableDataframe
+
+ >>> df = polars.DataFrame({"a": [1, 2], "b": [3, 4]})
+ >>> edf = EditableDataframe(df)
+ >>> edf.record_add({"record": {"a": 1, "b": 2}})
+ """
+
+ @staticmethod
+ @import_failure(rvalue=False)
+ def match(df: Any) -> bool:
+ import polars
+ return True if isinstance(df, polars.DataFrame) else False
+
+ @staticmethod
+ def record(df: 'polars.DataFrame', record_index: int) -> dict:
+ """
+
+ >>> edf = EditableDataframe(df)
+ >>> r = edf.record(1)
+ """
+ record = {}
+ r = df[record_index]
+ for c in r.columns:
+ record[c] = df[record_index, c]
+
+ return record
+
+
+ @staticmethod
+ def record_add(df: 'polars.DataFrame', payload: DataframeRecordAdded) -> 'polars.DataFrame':
+ _assert_record_match_polar_df(df, payload['record'])
+
+ import polars
+ new_df = polars.DataFrame([payload['record']])
+ return polars.concat([df, new_df])
+
+ @staticmethod
+ def record_update(df: 'polars.DataFrame', payload: DataframeRecordUpdated) -> 'polars.DataFrame':
+ # This implementation works but is not optimal.
+ # I didn't find a better way to update a record in polars
+ #
+ # https://github.com/pola-rs/polars/issues/5973
+ _assert_record_match_polar_df(df, payload['record'])
+
+ record = payload['record']
+ record_index = payload['record_index']
+ for r in record:
+ df[record_index, r] = record[r]
+
+ return df
+
+ @staticmethod
+ def record_remove(df: 'polars.DataFrame', payload: DataframeRecordRemoved) -> 'polars.DataFrame':
+ import polars
+
+ record_index: int = payload['record_index']
+ df_filtered = polars.concat([df[:record_index], df[record_index + 1:]])
+ return df_filtered
+
+ @staticmethod
+ def pyarrow_table(df: 'polars.DataFrame') -> pyarrow.Table:
+ """
+ Serializes the dataframe into a pyarrow table
+ """
+ import pyarrow.interchange
+ table: pyarrow.Table = pyarrow.interchange.from_dataframe(df)
+ return table
+
+class RecordListRecordProcessor(DataframeRecordProcessor):
+ """
+    RecordListRecordProcessor processes records from a list of records saved into an EditableDataframe
+
+ >>> df = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
+ >>> edf = EditableDataframe(df)
+ >>> edf.record_add({"record": {"a": 1, "b": 2}})
+ """
+
+ @staticmethod
+ def match(df: Any) -> bool:
+ return True if isinstance(df, list) else False
+
+
+ @staticmethod
+ def record(df: List[Dict[str, Any]], record_index: int) -> dict:
+ """
+
+ >>> edf = EditableDataframe(df)
+ >>> r = edf.record(1)
+ """
+ r = df[record_index]
+ return copy.copy(r)
+
+ @staticmethod
+ def record_add(df: List[Dict[str, Any]], payload: DataframeRecordAdded) -> List[Dict[str, Any]]:
+ _assert_record_match_list_of_records(df, payload['record'])
+ df.append(payload['record'])
+ return df
+
+ @staticmethod
+ def record_update(df: List[Dict[str, Any]], payload: DataframeRecordUpdated) -> List[Dict[str, Any]]:
+ _assert_record_match_list_of_records(df, payload['record'])
+
+ record_index = payload['record_index']
+ record = payload['record']
+
+ df[record_index] = record
+ return df
+
+ @staticmethod
+ def record_remove(df: List[Dict[str, Any]], payload: DataframeRecordRemoved) -> List[Dict[str, Any]]:
+ del(df[payload['record_index']])
+ return df
+
+ @staticmethod
+ def pyarrow_table(df: List[Dict[str, Any]]) -> pyarrow.Table:
+ """
+ Serializes the dataframe into a pyarrow table
+ """
+ column_names = list(df[0].keys())
+ columns = {key: [record[key] for record in df] for key in column_names}
+
+ pyarrow_columns = {key: pyarrow.array(values) for key, values in columns.items()}
+ schema = pyarrow.schema([(key, pyarrow_columns[key].type) for key in pyarrow_columns])
+ table = pyarrow.Table.from_arrays(
+ [pyarrow_columns[key] for key in column_names],
+ schema=schema
+ )
+
+ return table
+
+class EditableDataframe(MutableValue):
+ """
+ Editable Dataframe makes it easier to process events from components
+ that modify a dataframe like the dataframe editor.
+
+ >>> initial_state = wf.init_state({
+ >>> "df": wf.EditableDataframe(df)
+ >>> })
+
+    Editable Dataframe is compatible with a pandas, polars or record list dataframe
+ """
+ processors = [PandasRecordProcessor, PolarRecordProcessor, RecordListRecordProcessor]
+
+ def __init__(self, df: Union['pandas.DataFrame', 'polars.DataFrame', List[dict]]):
+ super().__init__()
+ self._df = df
+ self.processor: Type[DataframeRecordProcessor]
+ for processor in self.processors:
+ if processor.match(self.df):
+ self.processor = processor
+ break
+
+ if self.processor is None:
+ raise ValueError("The dataframe must be a pandas, polar Dataframe or a list of record")
+
+ @property
+ def df(self) -> Union['pandas.DataFrame', 'polars.DataFrame', List[dict]]:
+ return self._df
+
+ @df.setter
+ def df(self, value: Union['pandas.DataFrame', 'polars.DataFrame', List[dict]]) -> None:
+ self._df = value
+ self.mutate()
+
+ def record_add(self, payload: DataframeRecordAdded) -> None:
+ """
+ Adds a record to the dataframe
+
+ >>> df = pandas.DataFrame({"a": [1, 2], "b": [3, 4]})
+ >>> edf = EditableDataframe(df)
+ >>> edf.record_add({"record": {"a": 1, "b": 2}})
+ """
+ assert self.processor is not None
+
+ self._df = self.processor.record_add(self.df, payload)
+ self.mutate()
+
+ def record_update(self, payload: DataframeRecordUpdated) -> None:
+ """
+ Updates a record in the dataframe
+
+ The record must be complete otherwise an error is raised (ValueError).
+        It must have a value for each index / column.
+
+ >>> df = pandas.DataFrame({"a": [1, 2], "b": [3, 4]})
+ >>> edf = EditableDataframe(df)
+ >>> edf.record_update({"record_index": 0, "record": {"a": 2, "b": 2}})
+ """
+ assert self.processor is not None
+
+ self._df = self.processor.record_update(self.df, payload)
+ self.mutate()
+
+ def record_remove(self, payload: DataframeRecordRemoved) -> None:
+ """
+ Removes a record from the dataframe
+
+ >>> df = pandas.DataFrame({"a": [1, 2], "b": [3, 4]})
+ >>> edf = EditableDataframe(df)
+ >>> edf.record_remove({"record_index": 0})
+ """
+ assert self.processor is not None
+
+ self._df = self.processor.record_remove(self.df, payload)
+ self.mutate()
+
+ def pyarrow_table(self) -> pyarrow.Table:
+ """
+ Serializes the dataframe into a pyarrow table
+
+ This mechanism is used for serializing data for transmission to the frontend.
+
+ >>> df = pandas.DataFrame({"a": [1, 2], "b": [3, 4]})
+ >>> edf = EditableDataframe(df)
+ >>> pa_table = edf.pyarrow_table()
+ """
+ assert self.processor is not None
+
+ pa_table = self.processor.pyarrow_table(self.df)
+ return pa_table
+
+ def record(self, record_index: int):
+ """
+ Retrieves a specific record in dictionary form.
+
+ :param record_index:
+ :return:
+ """
+ assert self.processor is not None
+
+ record = self.processor.record(self.df, record_index)
+ return record
+
S = TypeVar("S", bound=WriterState)
def new_initial_state(klass: Type[S], raw_state: dict) -> S:
@@ -1623,6 +2106,63 @@ async def _async_wrapper_internal(callable_handler: Callable, arg_values: List[A
result = await callable_handler(*arg_values)
return result
+def _assert_record_match_pandas_df(df: 'pandas.DataFrame', record: Dict[str, Any]) -> None:
+ """
+ Asserts that the record matches the dataframe columns & index
+
+ >>> _assert_record_match_pandas_df(pandas.DataFrame({"a": [1, 2], "b": [3, 4]}), {"a": 1, "b": 2})
+ """
+ import pandas
+
+ columns = set(list(df.columns.values) + df.index.names) if isinstance(df.index, pandas.RangeIndex) is False else set(df.columns.values)
+ columns_record = set(record.keys())
+ if columns != columns_record:
+ raise ValueError(f"Columns mismatch. Expected {columns}, got {columns_record}")
+
+def _assert_record_match_polar_df(df: 'polars.DataFrame', record: Dict[str, Any]) -> None:
+ """
+ Asserts that the record matches the columns of a polars dataframe
+
+ >>> _assert_record_match_polar_df(polars.DataFrame({"a": [1, 2], "b": [3, 4]}), {"a": 1, "b": 2})
+ """
+ columns = set(df.columns)
+ columns_record = set(record.keys())
+ if columns != columns_record:
+ raise ValueError(f"Columns mismatch. Expected {columns}, got {columns_record}")
+
+def _assert_record_match_list_of_records(df: List[Dict[str, Any]], record: Dict[str, Any]) -> None:
+ """
+ Asserts that the record matches the keys in the record list (it uses the first record to check)
+
+ >>> _assert_record_match_list_of_records([{"a": 1, "b": 2}, {"a": 3, "b": 4}], {"a": 1, "b": 2})
+ """
+ if len(df) == 0:
+ return
+
+ columns = set(df[0].keys())
+ columns_record = set(record.keys())
+ if columns != columns_record:
+ raise ValueError(f"Columns mismatch. Expected {columns}, got {columns_record}")
+
+
+def _split_record_as_pandas_record_and_index(param: dict, index_columns: list) -> Tuple[dict, tuple]:
+ """
+ Separates a record into the record part and the index part to be able to
+ create or update a row in a dataframe.
+
+ >>> record, index = _split_record_as_pandas_record_and_index({"a": 1, "b": 2}, ["a"])
+ >>> print(record) # {"b": 2}
+ >>> print(index) # (1,)
+ """
+ final_record = {}
+ final_index = []
+ for key, value in param.items():
+ if key in index_columns:
+ final_index.append(value)
+ else:
+ final_record[key] = value
+
+ return final_record, tuple(final_index)
state_serialiser = StateSerialiser()
initial_state = WriterState()
diff --git a/src/writer/deploy.py b/src/writer/deploy.py
index d09bdf654..dfba5ba57 100644
--- a/src/writer/deploy.py
+++ b/src/writer/deploy.py
@@ -1,29 +1,95 @@
import json
+import logging
import os
+import re
import sys
import tarfile
import tempfile
import time
from datetime import datetime, timedelta
-from typing import List
+from typing import List, Union
+import click
import dateutil.parser
import pytz
import requests
from gitignore_parser import parse_gitignore
-WRITER_DEPLOY_URL = os.getenv("WRITER_DEPLOY_URL", "https://api.writer.com/v1/deployment/apps")
-def deploy(path, token, env):
- check_app(token)
- tar = pack_project(path)
- upload_package(tar, token, env)
+@click.group()
+def cloud():
+ """A group of commands to deploy the app on writer cloud"""
+ pass
-def undeploy(token):
+@cloud.command()
+@click.option('--api-key',
+ default=lambda: os.environ.get("WRITER_API_KEY", None),
+ allow_from_autoenv=True,
+ show_envvar=True,
+ envvar='WRITER_API_KEY',
+ prompt="Enter your API key",
+ hide_input=True, help="Writer API key"
+)
+@click.option('--env', '-e', multiple=True, default=[], help="Environment variable to set in the deployed app, in the format ENV_VAR=value")
+@click.option('--force', '-f', default=False, is_flag=True, help="Ignores warnings and overwrites the app")
+@click.option('--verbose', '-v', default=False, is_flag=True, help="Enable verbose mode")
+@click.argument('path')
+def deploy(path, api_key, env, verbose, force):
+ """Deploy the app from PATH folder."""
+
+ deploy_url = os.getenv("WRITER_DEPLOY_URL", "https://api.writer.com/v1/deployment/apps")
+ sleep_interval = int(os.getenv("WRITER_DEPLOY_SLEEP_INTERVAL", '5'))
+
+ if not force:
+ check_app(deploy_url, api_key)
+
+ abs_path = os.path.abspath(path)
+ if not os.path.isdir(abs_path):
+ raise click.ClickException("A path to a folder containing a Writer Framework app is required. For example: writer cloud deploy my_app")
+
+ env = _validate_env_vars(env)
+ tar = pack_project(abs_path)
+ try:
+ upload_package(deploy_url, tar, api_key, env, verbose=verbose, sleep_interval=sleep_interval)
+ except requests.exceptions.HTTPError as e:
+ if e.response.status_code == 401:
+ unauthorized_error()
+ else:
+ on_error_print_and_raise(e.response, verbose=verbose)
+ except Exception as e:
+ print(e)
+ print("Error deploying app")
+ sys.exit(1)
+ finally:
+ tar.close()
+
+def _validate_env_vars(env: Union[List[str], None]) -> Union[List[str], None]:
+ if env is None:
+ return None
+ for var in env:
+ regex = r"^[a-zA-Z_]+[a-zA-Z0-9_]*=.*$"
+ if not re.match(regex, var):
+ logging.error(f"Invalid environment variable: {var}, please use the format ENV_VAR=value")
+ sys.exit(1)
+ return env
+
+@cloud.command()
+@click.option('--api-key',
+ default=lambda: os.environ.get("WRITER_API_KEY", None),
+ allow_from_autoenv=True,
+ show_envvar=True,
+ envvar='WRITER_API_KEY',
+ prompt="Enter your API key",
+ hide_input=True, help="Writer API key"
+)
+@click.option('--verbose', '-v', default=False, is_flag=True, help="Enable verbose mode")
+def undeploy(api_key, verbose):
+ """Stop the app, app would not be available anymore."""
try:
print("Undeploying app")
- with requests.delete(WRITER_DEPLOY_URL, headers={"Authorization": f"Bearer {token}"}) as resp:
- resp.raise_for_status()
+ deploy_url = os.getenv("WRITER_DEPLOY_URL", "https://api.writer.com/v1/deployment/apps")
+ with requests.delete(deploy_url, headers={"Authorization": f"Bearer {api_key}"}) as resp:
+ on_error_print_and_raise(resp, verbose=verbose)
print("App undeployed")
sys.exit(0)
except Exception as e:
@@ -31,30 +97,45 @@ def undeploy(token):
print(e)
sys.exit(1)
-def runtime_logs(token):
+@cloud.command()
+@click.option('--api-key',
+ default=lambda: os.environ.get("WRITER_API_KEY", None),
+ allow_from_autoenv=True,
+ show_envvar=True,
+ envvar='WRITER_API_KEY',
+ prompt="Enter your API key",
+ hide_input=True, help="Writer API key"
+)
+@click.option('--verbose', '-v', default=False, is_flag=True, help="Enable verbose mode")
+def logs(api_key, verbose):
+ """Fetch logs from the deployed app."""
+
+ deploy_url = os.getenv("WRITER_DEPLOY_URL", "https://api.writer.com/v1/deployment/apps")
+ sleep_interval = int(os.getenv("WRITER_DEPLOY_SLEEP_INTERVAL", '5'))
+
try:
build_time = datetime.now(pytz.timezone('UTC')) - timedelta(days=4)
start_time = build_time
while True:
prev_start = start_time
end_time = datetime.now(pytz.timezone('UTC'))
- data = get_logs(token, {
+ data = get_logs(deploy_url, api_key, {
"buildTime": build_time,
"startTime": start_time,
"endTime": end_time,
- })
+ }, verbose=verbose)
# order logs by date and print
logs = data['logs']
for log in logs:
start_time = start_time if start_time > log[0] else log[0]
if start_time == prev_start:
start_time = datetime.now(pytz.timezone('UTC'))
- time.sleep(5)
+ time.sleep(sleep_interval)
continue
for log in logs:
print(log[0], log[1])
print(start_time)
- time.sleep(1)
+ time.sleep(sleep_interval)
except Exception as e:
print(e)
sys.exit(1)
@@ -91,9 +172,8 @@ def match(file_path) -> bool: return False
return f
-
-def check_app(token):
- url = get_app_url(token)
+def check_app(deploy_url, token):
+ url = _get_app_url(deploy_url, token)
if url:
print("[WARNING] This token was already used to deploy a different app")
print(f"[WARNING] URL: {url}")
@@ -101,8 +181,8 @@ def check_app(token):
if input("[WARNING] Are you sure you want to overwrite? (y/N)").lower() != "y":
sys.exit(1)
-def get_app_url(token):
- with requests.get(WRITER_DEPLOY_URL, params={"lineLimit": 1}, headers={"Authorization": f"Bearer {token}"}) as resp:
+def _get_app_url(deploy_url: str, token: str) -> Union[str, None]:
+ with requests.get(deploy_url, params={"lineLimit": 1}, headers={"Authorization": f"Bearer {token}"}) as resp:
try:
resp.raise_for_status()
except Exception as e:
@@ -112,13 +192,9 @@ def get_app_url(token):
data = resp.json()
return data['status']['url']
-def get_logs(token, params):
- with requests.get(WRITER_DEPLOY_URL, params = params, headers={"Authorization": f"Bearer {token}"}) as resp:
- try:
- resp.raise_for_status()
- except Exception as e:
- print(resp.json())
- raise e
+def get_logs(deploy_url, token, params, verbose=False):
+ with requests.get(deploy_url, params = params, headers={"Authorization": f"Bearer {token}"}) as resp:
+ on_error_print_and_raise(resp, verbose=verbose)
data = resp.json()
logs = []
@@ -130,8 +206,8 @@ def get_logs(token, params):
logs.sort(key=lambda x: x[0])
return {"status": data["status"], "logs": logs}
-def check_service_status(token, build_id, build_time, start_time, end_time, last_status):
- data = get_logs(token, {
+def check_service_status(deploy_url, token, build_id, build_time, start_time, end_time, last_status):
+ data = get_logs(deploy_url, token, {
"buildId": build_id,
"buildTime": build_time,
"startTime": start_time,
@@ -156,53 +232,54 @@ def dictFromEnv(env: List[str]) -> dict:
return env_dict
-def upload_package(tar, token, env):
- try:
- print("Uploading package to deployment server")
- tar.seek(0)
- files = {'file': tar}
- start_time = datetime.now(pytz.timezone('UTC'))
- build_time = start_time
- with requests.post(
- url = WRITER_DEPLOY_URL,
- headers = {
- "Authorization": f"Bearer {token}",
- },
- files=files,
- data={"envs": json.dumps(dictFromEnv(env))}
- ) as resp:
- try:
- resp.raise_for_status()
- except Exception as e:
- print(resp.json())
- raise e
- data = resp.json()
- build_id = data["buildId"]
-
- print("Package uploaded. Building...")
- status = "WAITING"
- url = ""
- while status not in ["COMPLETED", "FAILED"] and datetime.now(pytz.timezone('UTC')) < build_time + timedelta(minutes=5):
- end_time = datetime.now(pytz.timezone('UTC'))
- status, url = check_service_status(token, build_id, build_time, start_time, end_time, status)
- time.sleep(5)
- start_time = end_time
- if status == "COMPLETED":
- print("Deployment successful")
- print(f"URL: {url}")
- sys.exit(0)
- else:
- time.sleep(5)
- check_service_status(token, build_id, build_time, start_time, datetime.now(pytz.timezone('UTC')), status)
- print("Deployment failed")
- sys.exit(1)
+def upload_package(deploy_url, tar, token, env, verbose=False, sleep_interval=5):
+ print("Uploading package to deployment server")
+ tar.seek(0)
+ files = {'file': tar}
+ start_time = datetime.now(pytz.timezone('UTC'))
+ build_time = start_time
- except Exception as e:
- print("Error uploading package")
- print(e)
+ with requests.post(
+ url = deploy_url,
+ headers = {
+ "Authorization": f"Bearer {token}",
+ },
+ files=files,
+ data={"envs": json.dumps(dictFromEnv(env))}
+ ) as resp:
+ on_error_print_and_raise(resp, verbose=verbose)
+ data = resp.json()
+ build_id = data["buildId"]
+
+ print("Package uploaded. Building...")
+ status = "WAITING"
+ url = ""
+ while status not in ["COMPLETED", "FAILED"] and datetime.now(pytz.timezone('UTC')) < build_time + timedelta(minutes=5):
+ end_time = datetime.now(pytz.timezone('UTC'))
+ status, url = check_service_status(deploy_url, token, build_id, build_time, start_time, end_time, status)
+ time.sleep(sleep_interval)
+ start_time = end_time
+
+ if status == "COMPLETED":
+ print("Deployment successful")
+ print(f"URL: {url}")
+ sys.exit(0)
+ else:
+ time.sleep(sleep_interval)
+ check_service_status(deploy_url, token, build_id, build_time, start_time, datetime.now(pytz.timezone('UTC')), status)
+ print("Deployment failed")
sys.exit(1)
- finally:
- tar.close()
+def on_error_print_and_raise(resp, verbose=False):
+ try:
+ resp.raise_for_status()
+ except Exception as e:
+ if verbose:
+ print(resp.json())
+ raise e
+
+def unauthorized_error():
+ print("Unauthorized. Please check your API key.")
+ sys.exit(1)
diff --git a/src/writer/ss_types.py b/src/writer/ss_types.py
index d82d2cc17..c7e65d9e8 100644
--- a/src/writer/ss_types.py
+++ b/src/writer/ss_types.py
@@ -161,10 +161,19 @@ class StateEnquiryResponse(AppProcessServerResponse):
payload: Optional[StateEnquiryResponsePayload]
-AppProcessServerResponsePacket = Tuple[int,
- Optional[str], AppProcessServerResponse]
+AppProcessServerResponsePacket = Tuple[int, Optional[str], AppProcessServerResponse]
+class DataframeRecordAdded(TypedDict):
+ record: Dict[str, Any]
+
+class DataframeRecordUpdated(TypedDict):
+ record_index: int
+ record: Dict[str, Any]
+
+class DataframeRecordRemoved(TypedDict):
+ record_index: int
+
class WriterEventResult(TypedDict):
ok: bool
result: Any
diff --git a/tests/backend/fixtures/cloud_deploy_fixtures.py b/tests/backend/fixtures/cloud_deploy_fixtures.py
new file mode 100644
index 000000000..bfe0b734b
--- /dev/null
+++ b/tests/backend/fixtures/cloud_deploy_fixtures.py
@@ -0,0 +1,132 @@
+import contextlib
+import json
+import re
+import threading
+import time
+from datetime import datetime, timedelta
+from typing import Annotated, Union
+
+import pytest
+import pytz
+import uvicorn
+from click.testing import CliRunner
+from fastapi import Body, Depends, FastAPI, File, Header, UploadFile
+from writer.command_line import main
+
+
+def create_app():
+ class State:
+ log_counter = 0
+ envs: Union[str, None] = None
+
+ state = State()
+ app = FastAPI()
+
+
+ @app.post("/deploy")
+ def deploy(
+ state: Annotated[State, Depends(lambda: state)],
+ authorization: Annotated[str, Header(description="The API key")],
+ file: UploadFile = File(...),
+ envs: Annotated[str, Body(description = 'JSON object of environment variables')] = "{}",
+ ):
+ state.envs = envs
+ return {"status": "ok", "buildId": "123"}
+
+
+ @app.get("/deploy")
+ def get_status(
+ state: Annotated[State, Depends(lambda: state)],
+ authorization: Annotated[str, Header(description="The API key")],
+ ):
+
+ def get_time(n):
+ return (datetime.now(pytz.timezone('UTC')) + timedelta(seconds=n)).isoformat()
+
+ state.log_counter += 1
+ if (authorization == "Bearer full"):
+ if state.log_counter == 1: # first call checks if the app exists
+ return {
+ "logs": [],
+ "status": {
+ "url": None,
+ "status": "PENDING",
+ }
+ }
+ if state.log_counter == 2:
+ return {
+ "logs": [
+ {"log": f"{get_time(-7)} stdout F {state.envs} "},
+ {"log": f"{get_time(-6)} stdout F "},
+ {"log": f"{get_time(-5)} stdout F "},
+ ],
+ "status": {
+ "url": None,
+ "status": "BUILDING",
+ }
+ }
+ if state.log_counter == 3:
+ return {
+ "logs": [
+ {"log": f"{get_time(-2)} stdout F "},
+ {"log": f"{get_time(-4)} stdout F "},
+ ],
+ "status": {
+ "url": "https://full.my-app.com",
+ "status": "COMPLETED",
+ }
+ }
+ if (authorization == "Bearer test"):
+ return {
+ "logs": [
+ {"log": f"20210813163223 stdout F {state.envs} "},
+ ],
+ "status": {
+ "url": "https://my-app.com",
+ "status": "COMPLETED",
+ }
+ }
+ return {
+ "logs": [],
+ "status": {
+ "url": None,
+ "status": "FAILED",
+ }
+ }
+
+ @app.delete("/deploy")
+ def undeploy(
+ authorization: Annotated[str, Header(description="The API key")],
+ ):
+ return {"status": "ok"}
+ return app
+
+
+class Server(uvicorn.Server):
+ def __init__(self):
+ config = uvicorn.Config(create_app(), host="127.0.0.1", port=8888, log_level="info")
+ super().__init__(config)
+ self.keep_running = True
+
+ def install_signal_handlers(self):
+ pass
+
+ @contextlib.contextmanager
+ def run_in_thread(self):
+ thread = threading.Thread(target=self.run)
+ thread.start()
+ try:
+ while not self.started:
+ time.sleep(1e-3)
+ yield
+ finally:
+ self.should_exit = True
+ thread.join()
+
+
+
+@contextlib.contextmanager
+def use_fake_cloud_deploy_server():
+ server = Server()
+ with server.run_in_thread():
+ yield server
diff --git a/tests/backend/test_ai.py b/tests/backend/test_ai.py
index b569f1a22..1d892e3b4 100644
--- a/tests/backend/test_ai.py
+++ b/tests/backend/test_ai.py
@@ -55,6 +55,7 @@
SDKFile,
SDKGraph,
WriterAIManager,
+ apps,
complete,
create_graph,
delete_file,
@@ -69,7 +70,13 @@
)
from writerai import Writer
from writerai._streaming import Stream
-from writerai.types import Chat, ChatStreamingData, Completion, StreamingData
+from writerai.types import (
+ ApplicationGenerateContentResponse,
+ Chat,
+ ChatStreamingData,
+ Completion,
+ StreamingData,
+)
# Decorator to mark tests as explicit, i.e. that they only to be run on direct demand
explicit = pytest.mark.explicit
@@ -78,6 +85,20 @@
test_complete_literal = "Completed text"
+@pytest.fixture
+def mock_app_content_generation():
+ with patch('writer.ai.WriterAIManager.acquire_client') as mock_acquire_client:
+ original_client = Writer(api_key="fake_token")
+ non_streaming_client = AsyncMock(original_client)
+ mock_acquire_client.return_value = non_streaming_client
+
+ non_streaming_client.applications.generate_content.return_value = ApplicationGenerateContentResponse(
+ suggestion=test_complete_literal
+ )
+
+ yield non_streaming_client
+
+
@pytest.fixture
def mock_non_streaming_client():
with patch('writer.ai.WriterAIManager.acquire_client') as mock_acquire_client:
@@ -377,6 +398,16 @@ def test_stream_complete(emulate_app_process, mock_streaming_client):
assert "".join(response_chunks) == "part1 part2"
+@pytest.mark.set_token("fake_token")
+def test_generate_content_from_app(emulate_app_process, mock_app_content_generation):
+ response = apps.generate_content("abc123", {
+ "Favorite animal": "Dog",
+ "Favorite color": "Purple"
+ })
+
+ assert response == test_complete_literal
+
+
@pytest.mark.set_token("fake_token")
def test_init_writer_ai_manager(emulate_app_process):
manager = init("fake_token")
@@ -603,3 +634,6 @@ def test_explicit_delete_file(emulate_app_process, created_files):
created_files.remove(uploaded_file)
assert response.deleted is True
+
+# For doing an explicit test of apps.generate_content() we need a no-code app that
+# nobody will touch. That is a challenge.
diff --git a/tests/backend/test_auth.py b/tests/backend/test_auth.py
index 7236d0cff..443ef11c7 100644
--- a/tests/backend/test_auth.py
+++ b/tests/backend/test_auth.py
@@ -1,6 +1,8 @@
import fastapi
import fastapi.testclient
+import pytest
import writer.serve
+from writer import auth
from tests.backend import test_basicauth_dir
@@ -35,3 +37,38 @@ def test_basicauth_authentication_module_disabled_when_server_setup_hook_is_disa
with fastapi.testclient.TestClient(asgi_app) as client:
res = client.get("/api/init")
assert res.status_code == 405
+
+ @pytest.mark.parametrize("path,expected_path", [
+ ("", "/"),
+ ("http://localhost", "/"),
+ ("http://localhost/", "/"),
+ ("http://localhost/any", "/any"),
+ ("http://localhost/any/", "/any/"),
+ ("/any/yolo", "/any/yolo")
+ ])
+ def test_url_path_scenarios(self, path: str, expected_path: str):
+ assert auth.urlpath(path) == expected_path
+
+ @pytest.mark.parametrize("path,expected_path", [
+ ("/", ""),
+ ("/yolo", "yolo"),
+ ("/yolo/", "yolo"),
+ ("http://localhost", "http://localhost"),
+ ("http://localhost/", "http://localhost"),
+ ("http://localhost/any", "http://localhost/any"),
+ ("http://localhost/any/", "http://localhost/any")
+ ])
+ def test_url_split_scenarios(self, path: str, expected_path: str):
+ assert auth.urlstrip(path) == expected_path
+
+ @pytest.mark.parametrize("path1,path2,expected_path", [
+ ("/", "any", "/any"),
+ ("", "any", "any"),
+ ("/yolo", "any", "/yolo/any"),
+ ("/yolo", "/any", "/yolo/any"),
+ ("http://localhost", "any", "http://localhost/any"),
+ ("http://localhost/", "/any", "http://localhost/any"),
+ ("http://localhost/yolo", "/any", "http://localhost/yolo/any"),
+ ])
+ def test_urljoin_scenarios(self, path1: str, path2, expected_path: str):
+ assert auth.urljoin(path1, path2) == expected_path
diff --git a/tests/backend/test_cli.py b/tests/backend/test_cli.py
new file mode 100644
index 000000000..9635e809f
--- /dev/null
+++ b/tests/backend/test_cli.py
@@ -0,0 +1,101 @@
+import ctypes
+import os
+import platform
+import subprocess
+import time
+
+import requests
+from click.testing import CliRunner
+from writer.command_line import main
+
+
+def test_version():
+ runner = CliRunner()
+ with runner.isolated_filesystem():
+ result = runner.invoke(main, ['-v'])
+ assert result.exit_code == 0
+ assert 'version' in result.output
+
+def test_create_default():
+ runner = CliRunner()
+ with runner.isolated_filesystem():
+ result = runner.invoke(main, ['create', './my_app'])
+ print(result.output)
+ assert result.exit_code == 0
+ assert os.path.exists('./my_app')
+ assert os.path.exists('./my_app/ui.json')
+ assert os.path.exists('./my_app/main.py')
+ with open('./my_app/pyproject.toml') as f:
+ content = f.read()
+ assert content.find('name = "writer-framework-default"') != -1
+
+def test_create_specific_template():
+ runner = CliRunner()
+ with runner.isolated_filesystem():
+ result = runner.invoke(main, ['create', './my_app', '--template', 'hello'])
+ print(result.output)
+ assert result.exit_code == 0
+ assert os.path.exists('./my_app')
+ assert os.path.exists('./my_app/ui.json')
+ assert os.path.exists('./my_app/main.py')
+ with open('./my_app/pyproject.toml') as f:
+ content = f.read()
+ assert content.find('name = "writer-framework-hello"') != -1
+
+
+def test_run():
+ runner = CliRunner()
+ p = None
+ try:
+ with runner.isolated_filesystem():
+ runner.invoke(main, ['create', './my_app', '--template', 'hello'])
+ p = subprocess.Popen(["writer", "run", "my_app", "--port", "5001"], shell=(platform.system() == 'Windows'))
+
+ retry = 0
+ success = False
+ while True:
+ try:
+ response = requests.get('http://127.0.0.1:5001')
+ if response.status_code == 200:
+ success = True
+ break
+ if response.status_code != 200:
+ raise Exception("Status code is not 200")
+ except Exception:
+ time.sleep(1)
+ retry += 1
+ if retry > 10:
+ break
+ assert success == True
+ finally:
+ if p is not None:
+ p.terminate()
+
+
+def test_edit():
+ runner = CliRunner()
+ p = None
+ try:
+ with runner.isolated_filesystem():
+ runner.invoke(main, ['create', './my_app', '--template', 'hello'])
+ p = subprocess.Popen(["writer", "edit", "my_app", "--port", "5002"], shell=(platform.system() == 'Windows'))
+
+ retry = 0
+ success = False
+ while True:
+ try:
+ response = requests.get('http://127.0.0.1:5002')
+ if response.status_code == 200:
+ success = True
+ break
+ if response.status_code != 200:
+ raise Exception("Status code is not 200")
+ except Exception:
+ retry += 1
+ time.sleep(1)
+ if retry > 10:
+ break
+ assert success == True
+ finally:
+ if p is not None:
+ p.terminate()
diff --git a/tests/backend/test_core.py b/tests/backend/test_core.py
index 7122c96a3..b3bf26293 100644
--- a/tests/backend/test_core.py
+++ b/tests/backend/test_core.py
@@ -6,8 +6,10 @@
import altair
import numpy as np
+import pandas
import pandas as pd
import plotly.express as px
+import polars
import polars as pl
import pyarrow as pa
import pytest
@@ -17,11 +19,13 @@
Evaluator,
EventDeserialiser,
FileWrapper,
+ MutableValue,
SessionManager,
State,
StateSerialiser,
StateSerialiserException,
WriterState,
+ import_failure,
)
from writer.core_ui import Component
from writer.ss_types import WriterEvent
@@ -192,6 +196,70 @@ def test_to_raw_state(self) -> None:
assert self.sp.to_raw_state() == raw_state_dict
assert self.sp_simple_dict.to_raw_state() == simple_dict
+ def test_mutable_value_should_raise_mutation(self) -> None:
+ """
+ Tests that a class that implements MutableValue can be used in a State and throw mutations.
+ """
+ class MyValue(MutableValue):
+
+ def __init__(self):
+ super().__init__()
+ self._value = 0
+
+ def set(self, value):
+ self._value = value
+ self.mutate()
+
+ def to_dict(self):
+ return {"a": self._value}
+
+ s = WriterState({
+ "value": MyValue()
+ })
+ # Reset the mutation after initialisation
+ s._state_proxy.get_mutations_as_dict()
+
+ # When
+ s["value"].set(2)
+ a = s._state_proxy.get_mutations_as_dict()
+
+ # Then
+ assert "+value" in a
+ assert a["+value"] == {"a": 2}
+
+ def test_mutable_value_should_reset_mutation_after_reading_get_mutations(self) -> None:
+ """
+ Tests that after reading the mutations, they are reset to zero
+ with a focus on the MutableValue.
+ """
+ class MyValue(MutableValue):
+
+ def __init__(self):
+ super().__init__()
+ self._value = 0
+
+ def set(self, value):
+ self._value = value
+ self.mutate()
+
+ def to_dict(self):
+ return {"a": self._value}
+
+ s = WriterState({
+ "value": MyValue()
+ })
+ # Reset the mutation after initialisation
+ s._state_proxy.get_mutations_as_dict()
+
+ # Then
+ s["value"].set(2)
+ s._state_proxy.get_mutations_as_dict()
+
+ # Mutation is read a second time
+ a = s._state_proxy.get_mutations_as_dict()
+
+ # Then
+ assert a == {}
class TestState:
@@ -991,3 +1059,370 @@ def session_verifier_2(headers: Dict[str, str]) -> None:
None
)
assert s_invalid is None
+
+class TestEditableDataframe:
+
+ def test_editable_dataframe_expose_pandas_dataframe_as_df_property(self) -> None:
+ df = pandas.DataFrame({
+ "name": ["Alice", "Bob", "Charlie"],
+ "age": [25, 30, 35]
+ })
+
+ edf = wf.EditableDataframe(df)
+ assert edf.df is not None
+ assert isinstance(edf.df, pandas.DataFrame)
+
+ def test_editable_dataframe_register_mutation_when_df_is_updated(self) -> None:
+ # Given
+ df = pandas.DataFrame({
+ "name": ["Alice", "Bob", "Charlie"],
+ "age": [25, 30, 35]
+ })
+
+ edf = wf.EditableDataframe(df)
+
+ # When
+ edf.df.loc[0, "age"] = 26
+ edf.df = edf.df
+
+ # Then
+ assert edf.mutated() is True
+
+ def test_editable_dataframe_should_read_record_as_dict_based_on_record_index(self) -> None:
+ df = pandas.DataFrame({
+ "name": ["Alice", "Bob", "Charlie"],
+ "age": [25, 30, 35]
+ })
+ edf = wf.EditableDataframe(df)
+
+ # When
+ r = edf.record(0)
+
+ # Then
+ assert r['name'] == 'Alice'
+ assert r['age'] == 25
+
+ def test_editable_dataframe_should_read_record_as_dict_based_on_record_index_when_dataframe_has_index(self) -> None:
+ df = pandas.DataFrame({
+ "name": ["Alice", "Bob", "Charlie"],
+ "age": [25, 30, 35]
+ })
+ df = df.set_index('name')
+
+ edf = wf.EditableDataframe(df)
+
+ # When
+ r = edf.record(0)
+
+ # Then
+ assert r['name'] == 'Alice'
+ assert r['age'] == 25
+
+ def test_editable_dataframe_should_read_record_as_dict_based_on_record_index_when_dataframe_has_multi_index(self) -> None:
+ df = pandas.DataFrame({
+ "name": ["Alice", "Bob", "Charlie"],
+ "age": [25, 30, 35],
+ "city": ["Paris", "London", "New York"]
+ })
+ df = df.set_index(['name', 'city'])
+
+ edf = wf.EditableDataframe(df)
+
+ # When
+ r = edf.record(0)
+
+ # Then
+ assert r['name'] == 'Alice'
+ assert r['age'] == 25
+ assert r['city'] == 'Paris'
+
+ def test_editable_dataframe_should_process_new_record_into_dataframe(self) -> None:
+ df = pandas.DataFrame({
+ "name": ["Alice", "Bob", "Charlie"],
+ "age": [25, 30, 35]
+ })
+
+ edf = wf.EditableDataframe(df)
+
+ # When
+ edf.record_add({"record": {"name": "David", "age": 40}})
+
+ # Then
+ assert len(edf.df) == 4
+ assert edf.df.index.tolist()[3] == 3
+
+ def test_editable_dataframe_should_process_new_record_into_dataframe_with_index(self) -> None:
+ df = pandas.DataFrame({
+ "name": ["Alice", "Bob", "Charlie"],
+ "age": [25, 30, 35]
+ })
+ df = df.set_index('name')
+
+ edf = wf.EditableDataframe(df)
+
+ # When
+ edf.record_add({"record": {"name": "David", "age": 40}})
+
+ # Then
+ assert len(edf.df) == 4
+
+ def test_editable_dataframe_should_process_new_record_into_dataframe_with_multiindex(self) -> None:
+ df = pandas.DataFrame({
+ "name": ["Alice", "Bob", "Charlie"],
+ "age": [25, 30, 35],
+ "city": ["Paris", "London", "New York"]
+ })
+ df = df.set_index(['name', 'city'])
+
+ edf = wf.EditableDataframe(df)
+
+ # When
+ edf.record_add({"record": {"name": "David", "age": 40, "city": "Berlin"}})
+
+ # Then
+ assert len(edf.df) == 4
+
+ def test_editable_dataframe_should_update_existing_record_as_dateframe_with_multiindex(self) -> None:
+ df = pandas.DataFrame({
+ "name": ["Alice", "Bob", "Charlie"],
+ "age": [25, 30, 35],
+ "city": ["Paris", "London", "New York"]
+ })
+
+ df = df.set_index(['name', 'city'])
+
+ edf = wf.EditableDataframe(df)
+
+ # When
+ edf.record_update({"record_index": 0, "record": {"name": "Alicia", "age": 25, "city": "Paris"}})
+
+ # Then
+ assert edf.df.iloc[0]['age'] == 25
+
+ def test_editable_dataframe_should_remove_existing_record_as_dateframe_with_multiindex(self) -> None:
+ df = pandas.DataFrame({
+ "name": ["Alice", "Bob", "Charlie"],
+ "age": [25, 30, 35],
+ "city": ["Paris", "London", "New York"]
+ })
+
+ df = df.set_index(['name', 'city'])
+
+ edf = wf.EditableDataframe(df)
+
+ # When
+ edf.record_remove({"record_index": 0})
+
+ # Then
+ assert len(edf.df) == 2
+
+ def test_editable_dataframe_should_serialize_pandas_dataframe_with_multiindex(self) -> None:
+ df = pandas.DataFrame({
+ "name": ["Alice", "Bob", "Charlie"],
+ "age": [25, 30, 35],
+ "city": ["Paris", "London", "New York"]
+ })
+ df = df.set_index(['name', 'city'])
+
+ edf = wf.EditableDataframe(df)
+
+ # When
+ table = edf.pyarrow_table()
+
+ # Then
+ assert len(table) == 3
+
+ def test_editable_dataframe_expose_polar_dataframe_in_df_property(self) -> None:
+ df = polars.DataFrame({
+ "name": ["Alice", "Bob", "Charlie"],
+ "age": [25, 30, 35]
+ })
+
+ edf = wf.EditableDataframe(df)
+ assert edf.df is not None
+ assert isinstance(edf.df, polars.DataFrame)
+
+ def test_editable_dataframe_should_read_record_from_polar_as_dict_based_on_record_index(self) -> None:
+ df = polars.DataFrame({
+ "name": ["Alice", "Bob", "Charlie"],
+ "age": [25, 30, 35]
+ })
+ edf = wf.EditableDataframe(df)
+
+ # When
+ r = edf.record(0)
+
+ # Then
+ assert r['name'] == 'Alice'
+ assert r['age'] == 25
+
+ def test_editable_dataframe_should_process_new_record_into_polar_dataframe(self) -> None:
+ df = polars.DataFrame({
+ "name": ["Alice", "Bob", "Charlie"],
+ "age": [25, 30, 35]
+ })
+
+ edf = wf.EditableDataframe(df)
+
+ # When
+ edf.record_add({"record": {"name": "David", "age": 40}})
+
+ # Then
+ assert len(edf.df) == 4
+
+ def test_editable_dataframe_should_update_existing_record_into_polar_dataframe(self) -> None:
+ df = polars.DataFrame({
+ "name": ["Alice", "Bob", "Charlie"],
+ "age": [25, 30, 35]
+ })
+
+ edf = wf.EditableDataframe(df)
+
+ # When
+ edf.record_update({"record_index": 0, "record": {"name": "Alicia", "age": 25}})
+
+ # Then
+ assert edf.df[0, "name"] == "Alicia"
+
+ def test_editable_dataframe_should_remove_existing_record_into_polar_dataframe(self) -> None:
+ df = polars.DataFrame({
+ "name": ["Alice", "Bob", "Charlie"],
+ "age": [25, 30, 35]
+ })
+
+ edf = wf.EditableDataframe(df)
+
+ # When
+ edf.record_remove({"record_index": 0})
+
+ # Then
+ assert len(edf.df) == 2
+
+ def test_editable_dataframe_should_serialize_polar_dataframe(self) -> None:
+ df = polars.DataFrame({
+ "name": ["Alice", "Bob", "Charlie"],
+ "age": [25, 30, 35],
+ "city": ["Paris", "London", "New York"]
+ })
+
+ edf = wf.EditableDataframe(df)
+
+ # When
+ table = edf.pyarrow_table()
+
+ # Then
+ assert len(table) == 3
+
+
+ def test_editable_dataframe_expose_list_of_records_in_df_property(self) -> None:
+ records = [
+ {"name": "Alice", "age": 25},
+ {"name": "Bob", "age": 30},
+ {"name": "Charlie", "age": 35}
+ ]
+
+ edf = wf.EditableDataframe(records)
+
+ assert edf.df is not None
+ assert isinstance(edf.df, list)
+
+ def test_editable_dataframe_should_read_record_from_list_of_record_as_dict_based_on_record_index(self) -> None:
+ records = [
+ {"name": "Alice", "age": 25},
+ {"name": "Bob", "age": 30},
+ {"name": "Charlie", "age": 35}
+ ]
+
+ edf = wf.EditableDataframe(records)
+
+ # When
+ r = edf.record(0)
+
+ # Then
+ assert r['name'] == 'Alice'
+ assert r['age'] == 25
+
+ def test_editable_dataframe_should_process_new_record_into_list_of_records(self) -> None:
+ records = [
+ {"name": "Alice", "age": 25},
+ {"name": "Bob", "age": 30},
+ {"name": "Charlie", "age": 35}
+ ]
+
+ edf = wf.EditableDataframe(records)
+
+ # When
+ edf.record_add({"record": {"name": "David", "age": 40}})
+
+ # Then
+ assert len(edf.df) == 4
+
+
+ def test_editable_dataframe_should_update_existing_record_into_list_of_record(self) -> None:
+ records = [
+ {"name": "Alice", "age": 25},
+ {"name": "Bob", "age": 30},
+ {"name": "Charlie", "age": 35}
+ ]
+
+ edf = wf.EditableDataframe(records)
+
+ # When
+ edf.record_update({"record_index": 0, "record": {"name": "Alicia", "age": 25}})
+
+ # Then
+ assert edf.df[0]['name'] == "Alicia"
+
+ def test_editable_dataframe_should_remove_existing_record_into_list_of_record(self) -> None:
+ records = [
+ {"name": "Alice", "age": 25},
+ {"name": "Bob", "age": 30},
+ {"name": "Charlie", "age": 35}
+ ]
+
+ edf = wf.EditableDataframe(records)
+
+ # When
+ edf.record_remove({"record_index": 0})
+
+ # Then
+ assert len(edf.df) == 2
+
+
+ def test_editable_dataframe_should_serialized_list_of_records_into_pyarrow_table(self) -> None:
+ records = [
+ {"name": "Alice", "age": 25},
+ {"name": "Bob", "age": 30},
+ {"name": "Charlie", "age": 35}
+ ]
+
+ edf = wf.EditableDataframe(records)
+
+ # When
+ table = edf.pyarrow_table()
+
+ # Then
+ assert len(table) == 3
+
+
+def test_import_failure_returns_expected_value_when_import_fails():
+ """
+ Test that an import failure returns the expected value
+ """
+ @import_failure(rvalue=False)
+ def myfunc():
+ import yop
+
+ assert myfunc() is False
+
+
+def test_import_failure_do_nothing_when_import_go_well():
+ """
+ Test that the import_failure decorator do nothing when the import is a success
+ """
+ @import_failure(rvalue=False)
+ def myfunc():
+ import math
+ return 2
+
+ assert myfunc() == 2
diff --git a/tests/backend/test_deploy.py b/tests/backend/test_deploy.py
new file mode 100644
index 000000000..637aa0f36
--- /dev/null
+++ b/tests/backend/test_deploy.py
@@ -0,0 +1,170 @@
+import json
+import re
+
+from click.testing import CliRunner
+from writer.command_line import main
+
+from backend.fixtures.cloud_deploy_fixtures import use_fake_cloud_deploy_server
+
+
+def _assert_warning(result, url = "https://my-app.com"):  # asserts the CLI printed a warning line mentioning the deployed URL
+    found = re.search(f".WARNING. URL: {url}", result.output)  # '.' wildcards stand in for the decoration around WARNING; note url is interpolated unescaped into the pattern ('.' in the URL matches any char)
+
+    assert found is not None
+
+
+def _assert_url(result, expectedUrl):  # asserts the "URL: ..." line in the CLI output carries the expected deployment URL
+    url = re.search("URL: (.*)$", result.output)  # NOTE(review): without re.MULTILINE, '$' anchors at end of the whole output — presumably the URL is the last line; confirm
+    assert url and url.group(1) == expectedUrl
+
+def _extract_envs(result):  # extracts the JSON-encoded env dict that the fake deploy server echoes back through the CLI output
+    content = re.search("(.*) ", result.output)  # NOTE(review): pattern looks truncated/garbled by extraction — confirm against the fake server's actual output format
+    assert content is not None
+    return json.loads(content.group(1))  # captured group must be valid JSON (dict of env name -> value)
+
+
+def test_deploy():  # happy path: create an app, deploy it against the fake server, confirm the warning prompt with 'y'
+    runner = CliRunner()
+    with runner.isolated_filesystem(), use_fake_cloud_deploy_server():  # temp cwd + local HTTP stub standing in for the Writer cloud
+        result = runner.invoke(main, ['create', './my_app'])
+        assert result.exit_code == 0
+        result = runner.invoke(main, ['cloud', 'deploy', './my_app'], env={
+            'WRITER_DEPLOY_URL': 'http://localhost:8888/deploy',  # point the CLI at the fake server
+            'WRITER_API_KEY': 'test',
+        }, input='y\n')  # 'y' answers the interactive deployment-warning prompt
+        print(result.output)  # surfaced by pytest on failure for debugging
+        assert result.exit_code == 0
+        _assert_warning(result)
+        _assert_url(result, 'https://my-app.com')
+
+def test_deploy_force_flag():  # --force must skip the interactive warning entirely (no prompt, no warning line)
+    runner = CliRunner()
+    with runner.isolated_filesystem(), use_fake_cloud_deploy_server():
+
+        result = runner.invoke(main, ['create', './my_app'])
+        assert result.exit_code == 0
+        result = runner.invoke(main, ['cloud', 'deploy', './my_app', '--force'], env={
+            'WRITER_DEPLOY_URL': 'http://localhost:8888/deploy',
+            'WRITER_API_KEY': 'test',
+        })  # note: no input= here — --force means no confirmation is requested
+        print(result.output)
+        assert result.exit_code == 0
+        found = re.search(".WARNING. URL: https://my-app.com", result.output)  # same pattern as _assert_warning
+        assert found is None  # forced deploy must not print the warning
+        _assert_url(result, 'https://my-app.com')
+
+def test_deploy_api_key_option():  # --api-key on the command line must take precedence over WRITER_API_KEY in the environment
+    runner = CliRunner()
+    with runner.isolated_filesystem(), use_fake_cloud_deploy_server():
+
+        result = runner.invoke(main, ['create', './my_app'])
+        assert result.exit_code == 0
+        result = runner.invoke(main, ['cloud', 'deploy', './my_app', '--api-key', 'test'], env={
+            'WRITER_DEPLOY_URL': 'http://localhost:8888/deploy',
+            'WRITER_API_KEY': 'fail',  # would be rejected by the fake server — proves the option wins
+        }, input='y\n')
+        print(result.output)
+        assert result.exit_code == 0  # succeeds only if the 'test' key from --api-key was used
+        _assert_warning(result)
+        _assert_url(result, 'https://my-app.com')
+
+def test_deploy_api_key_prompt():  # with no WRITER_API_KEY set, the CLI must prompt for the key on stdin
+    runner = CliRunner()
+    with runner.isolated_filesystem(), use_fake_cloud_deploy_server():
+
+        result = runner.invoke(main, ['create', './my_app'])
+        assert result.exit_code == 0
+        result = runner.invoke(main, ['cloud', 'deploy', './my_app'], env={
+            'WRITER_DEPLOY_URL': 'http://localhost:8888/deploy',
+        }, input='test\ny\n')  # first line answers the api-key prompt, second confirms the warning
+        print(result.output)
+        assert result.exit_code == 0
+        _assert_warning(result)
+        _assert_url(result, 'https://my-app.com')
+
+def test_deploy_warning():  # with no confirmation input, the warning prompt must abort the deploy
+    runner = CliRunner()
+    with runner.isolated_filesystem(), use_fake_cloud_deploy_server():
+
+        result = runner.invoke(main, ['create', './my_app'])
+        assert result.exit_code == 0
+        result = runner.invoke(main, ['cloud', 'deploy', './my_app'], env={
+            'WRITER_DEPLOY_URL': 'http://localhost:8888/deploy',
+            'WRITER_API_KEY': 'test',
+        })  # no input= : the confirmation prompt gets EOF and the command bails out
+        print(result.output)
+        assert result.exit_code == 1  # aborted deploy exits non-zero
+
+def test_deploy_env():  # -e NAME=value options must be forwarded to the deploy request as env vars
+    runner = CliRunner()
+    with runner.isolated_filesystem(), use_fake_cloud_deploy_server():
+
+        result = runner.invoke(main, ['create', './my_app'])
+        assert result.exit_code == 0
+        result = runner.invoke(main,
+            args = [
+                'cloud', 'deploy', './my_app',
+                '-e', 'ENV1=test', '-e', 'ENV2=other'  # repeated -e accumulates env vars
+            ],
+            env={
+                'WRITER_DEPLOY_URL': 'http://localhost:8888/deploy',
+                'WRITER_API_KEY': 'test',
+                'WRITER_DEPLOY_SLEEP_INTERVAL': '0'  # skip the polling delay to keep the test fast
+            },
+            input='y\n'
+        )
+        print(result.output)
+        assert result.exit_code == 0
+        envs = _extract_envs(result)  # env dict echoed back by the fake server
+        assert envs['ENV1'] == 'test'
+        assert envs['ENV2'] == 'other'
+        _assert_url(result, 'https://my-app.com')
+
+def test_deploy_full_flow():  # end-to-end: the 'full' api key makes the fake server stream build logs before returning the URL
+    runner = CliRunner()
+    with runner.isolated_filesystem(), use_fake_cloud_deploy_server():
+
+        result = runner.invoke(main, ['create', './my_app'])
+        assert result.exit_code == 0
+        result = runner.invoke(main,
+            args = [
+                'cloud', 'deploy', './my_app',
+                '-e', 'ENV1=test', '-e', 'ENV2=other'
+            ],
+            env={
+                'WRITER_DEPLOY_URL': 'http://localhost:8888/deploy',
+                'WRITER_API_KEY': 'full',  # 'full' key selects the fake server's full-flow (log-streaming) scenario
+                'WRITER_DEPLOY_SLEEP_INTERVAL': '0'  # no polling delay
+            },
+        )  # no input= : presumably the full flow skips the confirmation prompt — confirm against the fixture
+        print(result.output)
+        assert result.exit_code == 0
+        envs = _extract_envs(result)
+        assert envs['ENV1'] == 'test'
+        assert envs['ENV2'] == 'other'
+        _assert_url(result, 'https://full.my-app.com')  # full flow reports a distinct subdomain
+
+        logs = re.findall(" ", result.output)  # NOTE(review): pattern and expected values below are garbled to whitespace — reconstruct from the fake server's streamed log lines
+        assert logs[0] == " "
+        assert logs[1] == " "
+        assert logs[2] == " "
+        assert logs[3] == " "
+
+
+def test_undeploy():  # 'cloud undeploy' needs no app directory and reports success textually
+    runner = CliRunner()
+    with runner.isolated_filesystem(), use_fake_cloud_deploy_server():
+        result = runner.invoke(main,
+            args = [
+                'cloud', 'undeploy'  # no path argument: undeploy acts on the account's deployed app
+            ],
+            env={
+                'WRITER_DEPLOY_URL': 'http://localhost:8888/deploy',
+                'WRITER_API_KEY': 'full',
+                'WRITER_DEPLOY_SLEEP_INTERVAL': '0'
+            },
+        )
+        print(result.output)
+        assert re.search("App undeployed", result.output)  # success message from the CLI
+        assert result.exit_code == 0
+
diff --git a/tests/conftest.py b/tests/conftest.py
index f75f0b495..d457e68db 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -51,3 +51,15 @@ def _manage_launch_args(app_dir: str, app_command: Literal["run", "edit"], load:
finally:
ar.shut_down()
return _manage_launch_args
+
+@pytest.fixture(autouse=True)
+def build_app_provisionning():  # rebuilds src/writer/app_templates from apps/ before every test so 'writer create' finds its templates without a package build
+    import os
+    import shutil
+
+    root_dir = os.path.dirname(os.path.dirname(__file__))  # repository root (tests/ lives one level below it)
+
+    if os.path.isdir(os.path.join(root_dir, 'src/writer/app_templates')):
+        shutil.rmtree(os.path.join(root_dir, 'src/writer/app_templates'))  # drop any stale copy first: copytree fails if the target exists
+
+    shutil.copytree( os.path.join(root_dir, 'apps'), os.path.join(root_dir, 'src/writer/app_templates'))