Skip to content

Commit

Permalink
Pass data server config to any query/predict, auto-update metadata to…
Browse files Browse the repository at this point in the history
… have kind, signature, don't check uv version.
  • Loading branch information
fabioz committed Dec 20, 2024
1 parent 99a8d8b commit 4be97e9
Show file tree
Hide file tree
Showing 16 changed files with 359 additions and 132 deletions.
3 changes: 3 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
- Code lenses are now shown to run `@query` and `@predict` decorated methods.
- Action Server updated to `2.3.0`.
- Agent CLI updated to `0.2.2`
- Action input format automatically updated to `v3` (now have a `kind` field and an `actionSignature` field).
- If the signature changes between action launches, the input file metadata is automatically updated.
- Don't warn the user if the `uv` version is not the latest in `package.yaml`.

## New in 2.8.1 (2024-11-21)

Expand Down
3 changes: 3 additions & 0 deletions sema4ai/src/sema4ai_code/robo/collect_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class ExtractedActionInfo(TypedDict):
managed_params_json_schema: dict
action_name: str
action_relative_path: str
kind: str


def extract_info(
Expand All @@ -53,6 +54,7 @@ def extract_info(
for item in action_info_found:
action_name = item.get("name", "")
target_file = item.get("file", "")
kind = item.get("options", {}).get("kind", "action")
relative_path = ""
# Now, make relative to the action_package_yaml_directory
if target_file:
Expand Down Expand Up @@ -96,6 +98,7 @@ def extract_info(
"managed_params_json_schema": managed_params_schema,
"action_name": action_name,
"action_relative_path": relative_path,
"kind": kind,
}

action_name_to_extracted_info[item.get("name")] = full
Expand Down
50 changes: 45 additions & 5 deletions sema4ai/src/sema4ai_code/robo/collect_actions_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
from sema4ai_ls_core import uris
from sema4ai_ls_core.core_log import get_logger
from sema4ai_ls_core.lsp import RangeTypedDict
from sema4ai_ls_core.protocols import ActionInfoTypedDict, DatasourceInfoTypedDict
from sema4ai_ls_core.protocols import (
ActionInfoTypedDict,
ActionResult,
DatasourceInfoTypedDict,
IMonitor,
)

log = get_logger(__name__)

Expand Down Expand Up @@ -211,7 +216,12 @@ def _collect_datasources(
)


def _collect_actions_from_ast(ast: ast_module.AST) -> Iterator[dict]:
class _ActionInfo(TypedDict):
kind: str
node: ast_module.FunctionDef


def _collect_actions_from_ast(ast: ast_module.AST) -> Iterator[_ActionInfo]:
for _stack, node in _iter_nodes(ast, recursive=False):
if isinstance(node, ast_module.FunctionDef):
for decorator in node.decorator_list:
Expand Down Expand Up @@ -291,12 +301,12 @@ def iter_actions_and_datasources(
uri = uris.from_fs_path(str(f))

for node_info_action in _collect_actions_from_ast(ast):
ast_node = node_info_action["node"]
node_range = _get_ast_node_range(ast_node)
function_def_node = node_info_action["node"]
node_range = _get_ast_node_range(function_def_node)
yield ActionInfoTypedDict(
uri=uri,
range=node_range,
name=ast_node.name,
name=function_def_node.name,
kind=node_info_action["kind"],
)

Expand Down Expand Up @@ -344,3 +354,33 @@ def iter_actions_and_datasources(
log.error(
f"Unable to collect @action/@query/@predict/datasources from {f}. Error: {e}"
)


def get_action_signature(
action_relative_path: str,
action_package_yaml_directory: str,
action_name: str,
monitor: IMonitor,
) -> ActionResult[str]:
action_file_path = Path(action_package_yaml_directory) / action_relative_path
if not action_file_path.exists():
return ActionResult.make_failure(f"Action file not found: {action_file_path}")

action_contents_file = action_file_path.read_bytes()
try:
ast = ast_module.parse(action_contents_file, "<string>")
except Exception:
return ActionResult.make_failure(
f"Unable to parse action file: {action_file_path}"
)

for node_info_action in _collect_actions_from_ast(ast):
function_def_node = node_info_action["node"]
if function_def_node.name == action_name:
# Convert the function signature to a string
signature = ast_module.unparse(function_def_node.args)
return ActionResult.make_success(
f"{node_info_action['kind']}/args: {signature!r}"
)

return ActionResult.make_failure(f"Action not found: {action_name}")
28 changes: 28 additions & 0 deletions sema4ai/src/sema4ai_code/robocorp_language_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,34 @@ def _run_in_rcc_internal(self, params=RunInRccParamsDict) -> ActionResultDict:
return dict(success=False, message=str(e), result=None)
return ret.as_dict()

def m_get_action_signature(
self,
action_relative_path: str,
action_package_yaml_directory: str,
action_name: str,
) -> ActionResultDict[str]:
return require_monitor(
partial(
self._get_action_signature,
action_relative_path,
action_package_yaml_directory,
action_name,
)
)

def _get_action_signature(
self,
action_relative_path: str,
action_package_yaml_directory: str,
action_name: str,
monitor: IMonitor,
) -> ActionResultDict[str]:
from sema4ai_code.robo.collect_actions_ast import get_action_signature

return get_action_signature(
action_relative_path, action_package_yaml_directory, action_name, monitor
).as_dict()

def m_list_actions_full_and_slow(
self, action_package_uri: str = "", action_package_yaml_directory: str = ""
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def iter_conda_issues(self) -> Iterator[_DiagnosticsTypedDict]:
if sqlite_queries:
with sqlite_queries.db_cursors() as db_cursors:
for conda_dep in self._conda_deps.iter_deps_infos():
if conda_dep.name in ("python", "pip"):
if conda_dep.name in ("python", "pip", "uv"):
continue

if conda_dep.error_msg:
Expand Down
16 changes: 16 additions & 0 deletions sema4ai/tests/sema4ai_code_tests/robo/test_list_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,22 @@
from sema4ai_code.robo.collect_actions import get_metadata


@pytest.mark.parametrize("action_name", ["get_churn_data", "predict_churn"])
def test_get_action_signature(cases, data_regression, action_name):
from sema4ai_ls_core.jsonrpc.monitor import Monitor

from sema4ai_code.robo.collect_actions_ast import get_action_signature

action_package_path = Path(cases.get_path("action_package", must_exist=True))
action_relative_path = "data_actions.py"
monitor = Monitor()
result = get_action_signature(
action_relative_path, action_package_path, action_name, monitor
)
assert result.success, result.message
data_regression.check(result.result)


def test_list_actions_and_datasources_simple(cases, data_regression):
action_package_path = Path(cases.get_path("action_package", must_exist=True))

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
'query/args: ''datasource: FileChurnDataSource'''
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
'predict/args: ''datasource: Annotated[DataSource, ChurnPredictionDataSource | FileChurnDataSource],
limit: int=10'''
Original file line number Diff line number Diff line change
Expand Up @@ -182,4 +182,5 @@ my_action:
required:
- entry
type: object
kind: action
managed_params_json_schema: {}
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@ my_action:
required:
- model
type: object
kind: action
managed_params_json_schema: {}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ my_action:
json_schema:
properties: {}
type: object
kind: action
managed_params_json_schema:
google_secret:
provider: google
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ my_action:
- d
- model
type: object
kind: action
managed_params_json_schema:
secret:
type: Secret
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,5 @@ my_action:
required:
- entry
type: object
kind: action
managed_params_json_schema: {}
5 changes: 4 additions & 1 deletion sema4ai/vscode-client/src/files.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ export async function uriExists(uri: Uri) {
}
}

export async function readFromFile(targetFile: string) {
/**
* @returns undefined if the file does not exist, otherwise the file contents.
*/
export async function readFromFile(targetFile: string): Promise<string | undefined> {
if (!(await fileExists(targetFile))) {
return undefined;
}
Expand Down
Loading

0 comments on commit 4be97e9

Please sign in to comment.