Skip to content

Commit

Permalink
Refactor zenml_wrappers.py for compatibility with ZenML v0.55.0+, sto…
Browse files Browse the repository at this point in the history
…re/store_configuration and set_store/set_store_configuration changes
  • Loading branch information
marwan37 committed Mar 25, 2024
1 parent 88f3343 commit ee537fd
Showing 1 changed file with 24 additions and 4 deletions.
28 changes: 24 additions & 4 deletions bundled/tool/zenml_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,16 @@ def set_store_configuration(self, remote_url: str, access_token: str):
new_store_config = self.RestZenStoreConfiguration(
type="rest", url=remote_url, api_token=access_token, verify_ssl=True
)

# Method name changed in 0.55.4 - 0.56.1
if hasattr(self.gc, "set_store_configuration"):
self.gc.set_store_configuration(new_store_config)
elif hasattr(self.gc, "set_store"): # Old method name
self.gc.set_store(new_store_config)
else:
raise AttributeError(
"GlobalConfiguration object does not have a method to set store configuration."
)
self.gc.set_store(new_store_config)

def get_global_configuration(self) -> dict:
Expand Down Expand Up @@ -164,8 +174,13 @@ def get_server_info(self) -> dict:
dict: Dictionary containing server info.
"""
store_info = json.loads(self.gc.zen_store.get_store_info().json(indent=2))
##THIS CHANGED from 0.55.2 to 0.55.5 store -> store_configuration
store_config = json.loads(self.gc.store.json(indent=2))
# Handle both 'store' and 'store_configuration' depending on version
store_attr_name = (
"store_configuration"
if hasattr(self.gc, "store_configuration")
else "store"
)
store_config = json.loads(getattr(self.gc, store_attr_name).json(indent=2))
return {"storeInfo": store_info, "storeConfig": store_config}

def connect(self, args, **kwargs) -> dict:
Expand Down Expand Up @@ -201,8 +216,13 @@ def disconnect(self, args) -> dict:
dict: Dictionary containing the result of the operation.
"""
try:
##THIS CHANGED from 0.55.2 to 0.55.5 store -> store_configuration
url = self.gc.store.url
# Adjust for changes from 'store' to 'store_configuration'
store_attr_name = (
"store_configuration"
if hasattr(self.gc, "store_configuration")
else "store"
)
url = getattr(self.gc, store_attr_name).url
store_type = self.BaseZenStore.get_store_type(url)

# pylint: disable=not-callable
Expand Down

0 comments on commit ee537fd

Please sign in to comment.