Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Error When Retrieving Server Information #28

Merged
merged 2 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion bundled/tool/type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,33 @@ class RunArtifactResponse(TypedDict):
author: Dict[str, str]
update: str
data: Dict[str, str]
metadata: Dict[str, Any]
metadata: Dict[str, Any]

class ZenmlStoreInfo(TypedDict):
id: str
version: str
debug: bool
deployment_type: str
database_type: str
secrets_store_type: str
auth_scheme: str
server_url: str
dashboard_url: str

class ZenmlStoreConfig(TypedDict):
type: str
url: str
api_token: Union[str, None]

class ZenmlServerInfoResp(TypedDict):
store_info: ZenmlStoreInfo
store_config: ZenmlStoreConfig

class ZenmlGlobalConfigResp(TypedDict):
user_id: str
user_email: str
analytics_opt_in: bool
version: str
active_stack_id: str
active_workspace_name: str
store: ZenmlStoreConfig
55 changes: 43 additions & 12 deletions bundled/tool/zenml_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@
# permissions and limitations under the License.
"""This module provides wrappers for ZenML configuration and operations."""

import json
import pathlib
from typing import Any, Tuple, Union
from type_hints import GraphResponse, ErrorResponse, RunStepResponse, RunArtifactResponse
from type_hints import GraphResponse, ErrorResponse, RunStepResponse, RunArtifactResponse, ZenmlServerInfoResp, ZenmlGlobalConfigResp
from zenml_grapher import Grapher


Expand Down Expand Up @@ -99,19 +98,32 @@ def set_store_configuration(self, remote_url: str, access_token: str):
)
self.gc.set_store(new_store_config)

def get_global_configuration(self) -> dict:
def get_global_configuration(self) -> ZenmlGlobalConfigResp:
"""Get the global configuration.

Returns:
dict: Global configuration.
"""
gc_dict = json.loads(self.gc.json(indent=2))
user_id = gc_dict.get("user_id", "")

if user_id and user_id.startswith("UUID('") and user_id.endswith("')"):
gc_dict["user_id"] = user_id[6:-2]
store_attr_name = (
"store_configuration" if hasattr(self.gc, "store_configuration") else "store"
)

return gc_dict
store_data = getattr(self.gc, store_attr_name)

return {
"user_id": str(self.gc.user_id),
"user_email": self.gc.user_email,
"analytics_opt_in": self.gc.analytics_opt_in,
"version": self.gc.version,
"active_stack_id": str(self.gc.active_stack_id),
"active_workspace_name": self.gc.active_workspace_name,
"store": {
"type": store_data.type,
"url": store_data.url,
"api_token": store_data.api_token if hasattr(store_data, "api_token") else None
}
}


class ZenServerWrapper:
Expand Down Expand Up @@ -167,19 +179,38 @@ def get_active_deployment(self):
"""Returns the function to get the active ZenML server deployment."""
return self.lazy_import("zenml.zen_server.utils", "get_active_deployment")

def get_server_info(self) -> dict:
def get_server_info(self) -> ZenmlServerInfoResp:
"""Fetches the ZenML server info.

Returns:
dict: Dictionary containing server info.
"""
store_info = json.loads(self.gc.zen_store.get_store_info().json(indent=2))
store_info = self.gc.zen_store.get_store_info()

# Handle both 'store' and 'store_configuration' depending on version
store_attr_name = (
"store_configuration" if hasattr(self.gc, "store_configuration") else "store"
)
store_config = json.loads(getattr(self.gc, store_attr_name).json(indent=2))
return {"storeInfo": store_info, "storeConfig": store_config}
store_config = getattr(self.gc, store_attr_name)

return {
"storeInfo": {
"id": str(store_info.id),
"version": store_info.version,
"debug": store_info.debug,
"deployment_type": store_info.deployment_type,
"database_type": store_info.database_type,
"secrets_store_type": store_info.secrets_store_type,
"auth_scheme": store_info.auth_scheme,
"server_url": store_info.server_url,
"dashboard_url": store_info.dashboard_url,
},
"storeConfig": {
"type": store_config.type,
"url": store_config.url,
"api_token": store_config.api_token if hasattr(store_config, "api_token") else None
}
}

def connect(self, args, **kwargs) -> dict:
"""Connects to a ZenML server.
Expand Down
Loading