diff --git a/geoservercloud/geoservercloud.py b/geoservercloud/geoservercloud.py index 7521de4..340334c 100644 --- a/geoservercloud/geoservercloud.py +++ b/geoservercloud/geoservercloud.py @@ -142,9 +142,7 @@ def unset_default_locale_for_service(self, workspace_name) -> tuple[str, int]: """ return self.set_default_locale_for_service(workspace_name, None) - def get_datastores( - self, workspace_name: str - ) -> tuple[list[dict[str, str]] | str, int]: + def get_datastores(self, workspace_name: str) -> tuple[list[str] | str, int]: """ Get all datastores for a given workspace """ diff --git a/geoservercloud/geoservercloudsync.py b/geoservercloud/geoservercloudsync.py index ff49430..45b5921 100644 --- a/geoservercloud/geoservercloudsync.py +++ b/geoservercloud/geoservercloudsync.py @@ -43,22 +43,109 @@ def __init__( self.dst_instance: RestService = RestService(dst_url, self.dst_auth) def copy_workspace( - self, workspace_name: str, include_styles: bool = False + self, workspace_name: str, deep_copy: bool = False ) -> tuple[str, int]: """ Copy a workspace from source to destination GeoServer instance, optionally including styles """ - workspace, status_code = self.src_instance.get_workspace(workspace_name) + workspace, code = self.src_instance.get_workspace(workspace_name) if isinstance(workspace, str): - return workspace, status_code - new_workspace, status_code = self.dst_instance.create_workspace(workspace) - if status_code >= 400: - return new_workspace, status_code - if include_styles: + return workspace, code + new_ws, new_ws_code = self.dst_instance.create_workspace(workspace) + if self.not_ok(code): + return new_ws, new_ws_code + if deep_copy: content, code = self.copy_styles(workspace_name) - if code >= 400: + if self.not_ok(code): return content, code - return new_workspace, status_code + content, code = self.copy_pg_datastores(workspace_name, deep_copy=True) + if self.not_ok(code): + return content, code + return new_ws, new_ws_code + + def copy_pg_datastores( + self, workspace_name: str, deep_copy: bool = False + ) -> tuple[str, int]: + """ + Copy all the datastores in given workspace. + If deep_copy is True, copy all feature types and the corresponding layers in each datastore + """ + datastores, code = self.src_instance.get_datastores(workspace_name) + if isinstance(datastores, str): + return datastores, code + for datastore_name in datastores.aslist(): + content, code = self.copy_pg_datastore( + workspace_name, datastore_name, deep_copy=deep_copy + ) + if self.not_ok(code): + return content, code + return content, code + + def copy_pg_datastore( + self, workspace_name: str, datastore_name: str, deep_copy: bool = False + ) -> tuple[str, int]: + """ + Copy a datastore from source to destination GeoServer instance + If deep_copy is True, copy all feature types and the corresponding layers + """ + datastore, status_code = self.src_instance.get_pg_datastore( + workspace_name, datastore_name + ) + if isinstance(datastore, str): + return datastore, status_code + new_ds, new_ds_code = self.dst_instance.create_pg_datastore( + workspace_name, datastore + ) + if deep_copy: + self.copy_feature_types(workspace_name, datastore_name, copy_layers=True) + return new_ds, new_ds_code + + def copy_feature_types( + self, workspace_name: str, datastore_name: str, copy_layers: bool = False + ) -> tuple[str, int]: + """ + Copy all feature types in a datastore from source to destination GeoServer instance + """ + feature_types, status_code = self.src_instance.get_feature_types( + workspace_name, datastore_name + ) + if isinstance(feature_types, str): + return feature_types, status_code + for feature_type in feature_types.aslist(): + content, code = self.copy_feature_type( + workspace_name, datastore_name, feature_type["name"] + ) + if self.not_ok(code): + return content, code + if copy_layers: + content, code = self.copy_layer(workspace_name, feature_type["name"]) + if self.not_ok(code): + return content, code + return content, code + + def copy_feature_type( + self, workspace_name: str, datastore_name: str, feature_type_name: str + ) -> tuple[str, int]: + """ + Copy a feature type from source to destination GeoServer instance + """ + feature_type, code = self.src_instance.get_feature_type( + workspace_name, datastore_name, feature_type_name + ) + if isinstance(feature_type, str): + return feature_type, code + return self.dst_instance.create_feature_type(feature_type) + + def copy_layer( + self, workspace_name: str, feature_type_name: str + ) -> tuple[str, int]: + """ + Copy a layer from source to destination GeoServer instance + """ + layer, code = self.src_instance.get_layer(workspace_name, feature_type_name) + if isinstance(layer, str): + return layer, code + return self.dst_instance.update_layer(layer, workspace_name) def copy_styles( self, workspace_name: str | None = None, include_images: bool = True @@ -68,14 +155,14 @@ def copy_styles( """ if include_images: content, code = self.copy_style_images(workspace_name) - if code >= 400: + if self.not_ok(code): return content, code styles, code = self.src_instance.get_styles(workspace_name) if isinstance(styles, str): return styles, code for style in styles.aslist(): content, code = self.copy_style(style, workspace_name) - if code >= 400: + if self.not_ok(code): return content, code return content, code @@ -124,7 +211,7 @@ def copy_resource( resource, code = self.src_instance.get_resource( resource_dir, resource_name, workspace_name ) - if code >= 400: + if self.not_ok(code): return resource.decode(), code return self.dst_instance.put_resource( path=resource_dir, @@ -133,3 +220,7 @@ def copy_resource( content_type=content_type, data=resource, ) + + @staticmethod + def not_ok(code: int) -> bool: + return code >= 400 diff --git a/geoservercloud/models/datastores.py b/geoservercloud/models/datastores.py index e49ba5a..5eaff59 100644 --- a/geoservercloud/models/datastores.py +++ b/geoservercloud/models/datastores.py @@ -2,18 +2,18 @@ class DataStores(ListModel): - def __init__(self, datastores: list[dict[str, str]] = []) -> None: - self._datastores: list[dict[str, str]] = datastores + def __init__(self, datastores: list[str] = []) -> None: + self._datastores: list[str] = datastores @classmethod def from_get_response_payload(cls, content: dict): - datastores: str | dict = content["dataStores"] + datastores: dict | str = content["dataStores"] if not datastores: return cls() - return cls(datastores["dataStore"]) # type: ignore + return cls([ds["name"] for ds in datastores["dataStore"]]) # type: ignore def __repr__(self) -> str: - return str(self._datastores) + return str([{"name": ds} for ds in self._datastores]) - def aslist(self) -> list[dict[str, str]]: + def aslist(self) -> list[str]: return self._datastores diff --git a/geoservercloud/services/restservice.py b/geoservercloud/services/restservice.py index ec4de21..49ad52c 100644 --- a/geoservercloud/services/restservice.py +++ b/geoservercloud/services/restservice.py @@ -383,6 +383,14 @@ def create_style( response = self.rest_client.put(resource_path, data=style, headers=headers) return response.content.decode(), response.status_code + def get_layer( + self, workspace_name: str, layer_name: str + ) -> tuple[Layer | str, int]: + response: Response = self.rest_client.get( + self.rest_endpoints.workspace_layer(workspace_name, layer_name) + ) + return self.deserialize_response(response, Layer) + def update_layer(self, layer: Layer, workspace_name: str) -> tuple[str, int]: response: Response = self.rest_client.put( self.rest_endpoints.workspace_layer(workspace_name, layer.name), diff --git a/tests/models/test_datastores.py b/tests/models/test_datastores.py index 2d756e4..dcb27a0 100644 --- a/tests/models/test_datastores.py +++ b/tests/models/test_datastores.py @@ -4,33 +4,28 @@ @fixture(scope="module") -def mock_datastore(): - return { - "name": "DataStore1", - "href": "http://example.com/ds1", - } - - -@fixture(scope="module") -def mock_response(mock_datastore): +def mock_response(): return { "dataStores": { - "dataStore": [mock_datastore], + "dataStore": [ + { + "name": "DataStore1", + "href": "http://example.com/ds1", + }, + { + "name": "DataStore2", + "href": "http://example.com/ds2", + }, + ], } } -def test_datastores_initialization(mock_datastore): - ds = DataStores([mock_datastore]) - - assert ds.aslist() == [mock_datastore] - - -def test_datastores_from_get_response_payload(mock_datastore, mock_response): +def test_datastores_from_get_response_payload(mock_response): ds = DataStores.from_get_response_payload(mock_response) - assert ds.aslist() == [mock_datastore] + assert ds.aslist() == ["DataStore1", "DataStore2"] def test_datastores_from_get_response_payload_empty(): @@ -41,9 +36,9 @@ def test_datastores_from_get_response_payload_empty(): assert ds.aslist() == [] -def test_datastores_repr(mock_datastore): - ds = DataStores([mock_datastore]) +def test_datastores_repr(): + ds = DataStores(["DataStore1", "DataStore2"]) - expected_repr = "[{'name': 'DataStore1', 'href': 'http://example.com/ds1'}]" + expected_repr = "[{'name': 'DataStore1'}, {'name': 'DataStore2'}]" assert repr(ds) == expected_repr diff --git a/tests/test_datastore.py b/tests/test_datastore.py index 8dd3f5f..d20f688 100644 --- a/tests/test_datastore.py +++ b/tests/test_datastore.py @@ -132,7 +132,7 @@ def test_get_datastores( ) datastores, status_code = geoserver.get_datastores(workspace_name=WORKSPACE) - assert datastores == datastores_get_response["dataStores"]["dataStore"] + assert datastores == ["test_store"] assert status_code == 200