From 3f904733388ffaf0c954256ec407f0026299d157 Mon Sep 17 00:00:00 2001 From: dweinholz Date: Wed, 3 Jan 2024 16:02:53 +0100 Subject: [PATCH] added more tests openstack --- .../openstack_connector.py | 302 +++--- simple_vm_client/test_openstack_connector.py | 872 +++++++++++++++++- 2 files changed, 973 insertions(+), 201 deletions(-) diff --git a/simple_vm_client/openstack_connector/openstack_connector.py b/simple_vm_client/openstack_connector/openstack_connector.py index 67d2690..187479e 100644 --- a/simple_vm_client/openstack_connector/openstack_connector.py +++ b/simple_vm_client/openstack_connector/openstack_connector.py @@ -141,7 +141,7 @@ def load_env_config(self) -> None: sys.exit(1) self.USE_APPLICATION_CREDENTIALS = ( - os.environ.get("USE_APPLICATION_CREDENTIALS", "False").lower() == "true" + os.environ.get("USE_APPLICATION_CREDENTIALS", "False").lower() == "true" ) if self.USE_APPLICATION_CREDENTIALS: @@ -183,15 +183,15 @@ def load_env_config(self) -> None: self.PROJECT_DOMAIN_ID = os.environ["OS_PROJECT_DOMAIN_ID"] def create_server( - self, - name: str, - image_id: str, - flavor_id: str, - network_id: str, - userdata: str, - key_name: str, - metadata: dict[str, str], - security_groups: list[str], + self, + name: str, + image_id: str, + flavor_id: str, + network_id: str, + userdata: str, + key_name: str, + metadata: dict[str, str], + security_groups: list[str], ) -> Server: logger.info( f"Create Server:\n\tname: {name}\n\timage_id:{image_id}\n\tflavor_id:{flavor_id}\n\tmetadata:{metadata}" @@ -233,7 +233,7 @@ def delete_volume(self, volume_id: str) -> None: raise DefaultException(message=e.message) def create_volume_snapshot( - self, volume_id: str, name: str, description: str + self, volume_id: str, name: str, description: str ) -> str: try: logger.info(f"Create Snapshot for Volume {volume_id}") @@ -276,7 +276,7 @@ def delete_volume_snapshot(self, snapshot_id: str) -> None: raise DefaultException(message=e.message) def create_volume_by_source_volume( - self, volume_name: str, metadata: dict[str, str], source_volume_id: str + self, volume_name: str, metadata: dict[str, str], source_volume_id: str ) -> Volume: logger.info(f"Creating volume from source volume with id {source_volume_id}") try: @@ -292,7 +292,7 @@ def create_volume_by_source_volume( raise ResourceNotAvailableException(message=e.message) def create_volume_by_volume_snap( - self, volume_name: str, metadata: dict[str, str], volume_snap_id: str + self, volume_name: str, metadata: dict[str, str], volume_snap_id: str ) -> Volume: logger.info(f"Creating volume from volume snapshot with id {volume_snap_id}") try: @@ -330,7 +330,7 @@ def get_servers_by_ids(self, ids: list[str]) -> list[Server]: return servers def attach_volume_to_server( - self, openstack_id: str, volume_id: str + self, openstack_id: str, volume_id: str ) -> dict[str, str]: server = self.get_server(openstack_id=openstack_id) volume = self.get_volume(name_or_id=volume_id) @@ -371,7 +371,7 @@ def resize_volume(self, volume_id: str, size: int) -> None: raise DefaultException(message=str(e)) def create_volume( - self, volume_name: str, volume_storage: int, metadata: dict[str, str] + self, volume_name: str, volume_storage: int, metadata: dict[str, str] ) -> Volume: logger.info(f"Creating volume with {volume_storage} GB storage") try: @@ -488,9 +488,9 @@ def get_active_image_by_os_version(self, os_version: str, os_distro: str) -> Ima image_os_distro = metadata.get("os_distro", None) base_image_ref = metadata.get("base_image_ref", None) if ( - os_version == image_os_version - and image.status == "active" - and base_image_ref is None + os_version == image_os_version + and image.status == "active" + and base_image_ref is None ): if os_distro and os_distro == image_os_distro: return image @@ -502,11 +502,11 @@ def get_active_image_by_os_version(self, os_version: str, os_distro: str) -> Ima ) def get_image( - self, - name_or_id: str, - replace_inactive: bool = False, - ignore_not_active: bool = False, - ignore_not_found: bool = False, + self, + name_or_id: str, + replace_inactive: bool = False, + ignore_not_active: bool = False, + ignore_not_found: bool = False, ) -> Image: logger.info(f"Get Image {name_or_id}") @@ -530,12 +530,12 @@ def get_image( return image def create_snapshot( - self, - openstack_id: str, - name: str, - username: str, - base_tags: list[str], - description: str, + self, + openstack_id: str, + name: str, + username: str, + base_tags: list[str], + description: str, ) -> str: logger.info( f"Create Snapshot from Instance {openstack_id} with name {name} for {username}" @@ -647,9 +647,9 @@ def get_gateway_ip(self) -> dict[str, str]: return {"gateway_ip": self.GATEWAY_IP} def create_mount_init_script( - self, - new_volumes: list[dict[str, str]] = None, # type: ignore - attach_volumes: list[dict[str, str]] = None, # type: ignore + self, + new_volumes: list[dict[str, str]] = None, # type: ignore + attach_volumes: list[dict[str, str]] = None, # type: ignore ) -> str: logger.info(f"Create init script for volume ids:{new_volumes}") if not new_volumes and not attach_volumes: @@ -730,7 +730,7 @@ def delete_security_group_rule(self, openstack_id): ) def open_port_range_for_vm_in_project( - self, range_start, range_stop, openstack_id, ethertype="IPV4", protocol="TCP" + self, range_start, range_stop, openstack_id, ethertype="IPV4", protocol="TCP" ): server: Server = self.openstack_connection.get_server_by_id(id=openstack_id) if server is None: @@ -779,13 +779,13 @@ def open_port_range_for_vm_in_project( raise OpenStackConflictException(message=e.message) def create_security_group( - self, - name: str, - udp_port: int = None, # type: ignore - ssh: bool = True, - udp: bool = False, - description: str = "", - research_environment_metadata: ResearchEnvironmentMetadata = None, + self, + name: str, + udp_port: int = None, # type: ignore + ssh: bool = True, + udp: bool = False, + description: str = "", + research_environment_metadata: ResearchEnvironmentMetadata = None, ) -> SecurityGroup: logger.info(f"Create new security group {name}") sec: SecurityGroup = self.openstack_connection.get_security_group( @@ -843,6 +843,8 @@ def create_security_group( remote_group_id=self.GATEWAY_SECURITY_GROUP_ID, ) if research_environment_metadata: + logger.info(f"Add research env rule to security group {name}") + self.openstack_connection.network.create_security_group_rule( direction=research_environment_metadata.direction, protocol=research_environment_metadata.protocol, @@ -892,7 +894,7 @@ def is_security_group_in_use(self, security_group_id): return False def get_or_create_research_environment_security_group( - self, resenv_metadata: ResearchEnvironmentMetadata + self, resenv_metadata: ResearchEnvironmentMetadata ): if not resenv_metadata.needs_forc_support: return None @@ -913,7 +915,7 @@ def get_or_create_research_environment_security_group( ) new_security_group = self.openstack_connection.create_security_group( - name=resenv_metadata.securitygroup_name, description=resenv_metadata.name + name=resenv_metadata.securitygroup_name, description=resenv_metadata.description ) self.openstack_connection.network.create_security_group_rule( direction=resenv_metadata.direction, @@ -969,8 +971,7 @@ def get_or_create_project_security_group(self, project_name, project_id): def get_limits(self) -> dict[str, str]: logger.info("Get Limits") - limits = {} - limits.update(self.openstack_connection.get_compute_limits()) + limits = self.openstack_connection.get_compute_limits() limits.update(self.openstack_connection.get_volume_limits()["absolute"]) return { @@ -980,10 +981,10 @@ def get_limits(self) -> dict[str, str]: "current_used_cores": str(limits["total_cores_used"]), "current_used_vms": str(limits["total_instances_used"]), "current_used_ram": str(math.ceil(limits["total_ram_used"] / 1024)), - "volume_counter_limit": str(limits["maxTotalVolumes"]), - "volume_storage_limit": str(limits["maxTotalVolumeGigabytes"]), - "current_used_volumes": str(limits["totalVolumesUsed"]), - "current_used_volume_storage": str(limits["totalGigabytesUsed"]), + "volume_counter_limit": str(limits["max_total_volumes"]), + "volume_storage_limit": str(limits["max_total_volume_gigabytes"]), + "current_used_volumes": str(limits["total_volumes_used"]), + "current_used_volume_storage": str(limits["total_gigabytes_used"]), } def exist_server(self, name: str) -> bool: @@ -995,13 +996,7 @@ def exist_server(self, name: str) -> bool: def set_server_metadata(self, openstack_id: str, metadata) -> None: try: logger.info(f"Set Server Metadata: {openstack_id} --> {metadata}") - server: Server = self.openstack_connection.get_server_by_id(id=openstack_id) - if server is None: - logger.exception(f"Instance {openstack_id} not found") - raise ServerNotFoundException( - message=f"Instance {openstack_id} not found", - name_or_id=openstack_id, - ) + server: Server = self.get_server(openstack_id) self.openstack_connection.compute.set_server_metadata(server, metadata) except OpenStackCloudException as e: raise DefaultException( @@ -1019,17 +1014,7 @@ def get_server(self, openstack_id: str) -> Server: name_or_id=openstack_id, ) if server.vm_state == VmStates.ACTIVE.value: - fixed_ip = server.private_v4 - base_port = int(fixed_ip.split(".")[-1]) # noqa F841 - subnet_port = int(fixed_ip.split(".")[-2]) # noqa F841 - - x = sympy.symbols("x") - y = sympy.symbols("y") - ssh_port = int( - sympy.sympify(self.SSH_PORT_CALCULATION).evalf( - subs={x: base_port, y: subnet_port} - ) - ) + ssh_port, udp_port = self._calculate_vm_ports(server=server) if not self.netcat(host=self.GATEWAY_IP, port=ssh_port): server.task_state = VmTaskStates.CHECKING_SSH_CONNECTION.value @@ -1052,12 +1037,6 @@ def resume_server(self, openstack_id: str) -> None: logger.info(f"Resume Server {openstack_id}") try: server = self.get_server(openstack_id=openstack_id) - if server is None: - logger.exception(f"Instance {openstack_id} not found") - raise ServerNotFoundException( - message=f"Instance {openstack_id} not found", - name_or_id=openstack_id, - ) self.openstack_connection.compute.start_server(server) except ConflictException as e: @@ -1084,23 +1063,45 @@ def stop_server(self, openstack_id: str) -> None: logger.info(f"Stop Server {openstack_id}") server = self.get_server(openstack_id=openstack_id) try: - if server is None: - raise ServerNotFoundException( - message=f"Instance {openstack_id} not found", - name_or_id=openstack_id, - ) - self.openstack_connection.compute.stop_server(server) except ConflictException as e: logger.exception(f"Stop Server {openstack_id} failed!") raise OpenStackConflictException(message=e.message) + def _remove_security_groups_from_server(self, server: Server) -> None: + security_groups = server.security_groups + + if security_groups is not None: + for sg in security_groups: + sec = self.openstack_connection.get_security_group(name_or_id=sg["name"]) + logger.info(f"Remove security group {sec.id} from server {server.id}") + self.openstack_connection.compute.remove_security_group_from_server( + server=server, security_group=sec + ) + + if ( + sg["name"] != self.DEFAULT_SECURITY_GROUP_NAME + and ("bibigrid" not in sec.name or "master" not in server.name) + and not self.is_security_group_in_use(security_group_id=sec.id) + ): + logger.info(f"Delete security group {sec}") + + self.openstack_connection.delete_security_group(sec) + + def _validate_server_for_deletion(self, server: Server) -> None: + task_state = server.task_state + if task_state in [ + "image_snapshot", + "image_pending_upload", + "image_uploading", + ]: + raise ConflictException("task_state in image creating") + def delete_server(self, openstack_id: str) -> None: logger.info(f"Delete Server {openstack_id}") try: - server = self.get_server(openstack_id=openstack_id) - + server: Server = self.get_server(openstack_id=openstack_id) if not server: logger.error(f"Instance {openstack_id} not found") raise ServerNotFoundException( @@ -1108,29 +1109,8 @@ def delete_server(self, openstack_id: str) -> None: name_or_id=openstack_id, ) - task_state = server.get("task_state", None) - if task_state in [ - "image_snapshot", - "image_pending_upload", - "image_uploading", - ]: - raise ConflictException("task_state in image creating") - security_groups = server["security_groups"] - if security_groups is not None: - for sg in security_groups: - sec = self.openstack_connection.get_security_group( - name_or_id=sg["name"] - ) - logger.info(f"Delete security group {sec}") - self.openstack_connection.compute.remove_security_group_from_server( - server=server, security_group=sec - ) - if ( - sg["name"] != self.DEFAULT_SECURITY_GROUP_NAME - and ("bibigrid" not in sec.name or "master" not in server.name) - and not self.is_security_group_in_use(security_group_id=sec.id) - ): - self.openstack_connection.delete_security_group(sg) + self._validate_server_for_deletion(server=server) + self._remove_security_groups_from_server(server=server) self.openstack_connection.compute.delete_server(server.id, force=True) except ConflictException as e: @@ -1138,14 +1118,8 @@ def delete_server(self, openstack_id: str) -> None: raise OpenStackConflictException(message=e.message) - def get_vm_ports(self, openstack_id: str) -> dict[str, str]: - logger.info(f"Get IP and PORT for server {openstack_id}") - server = self.get_server(openstack_id=openstack_id) - if not server: - raise ServerNotFoundException( - message=f"Server {openstack_id} not found!", name_or_id=openstack_id - ) - fixed_ip = server["private_v4"] + def _calculate_vm_ports(self, server: Server): + fixed_ip = server.private_v4 base_port = int(fixed_ip.split(".")[-1]) # noqa F841 subnet_port = int(fixed_ip.split(".")[-2]) # noqa F841 @@ -1161,14 +1135,20 @@ def get_vm_ports(self, openstack_id: str) -> dict[str, str]: subs={x: base_port, y: subnet_port} ) ) + return ssh_port, udp_port + + def get_vm_ports(self, openstack_id: str) -> dict[str, str]: + logger.info(f"Get IP and PORT for server {openstack_id}") + server = self.get_server(openstack_id=openstack_id) + ssh_port, udp_port = self._calculate_vm_ports(server=server) return {"port": str(ssh_port), "udp": str(udp_port)} def create_userdata( - self, - volume_ids_path_new: list[dict[str, str]], - volume_ids_path_attach: list[dict[str, str]], - additional_keys: list[str], + self, + volume_ids_path_new: list[dict[str, str]], + volume_ids_path_attach: list[dict[str, str]], + additional_keys: list[str], ) -> str: unlock_ubuntu_user_script = "#!/bin/bash\npasswd -u ubuntu\n" unlock_ubuntu_user_script_encoded = encodeutils.safe_encode( @@ -1179,9 +1159,9 @@ def create_userdata( if additional_keys: add_key_script = self.create_add_keys_script(keys=additional_keys) init_script = ( - add_key_script - + encodeutils.safe_encode("\n".encode("utf-8")) - + init_script + add_key_script + + encodeutils.safe_encode("\n".encode("utf-8")) + + init_script ) if volume_ids_path_new or volume_ids_path_attach: mount_script = self.create_mount_init_script( @@ -1189,25 +1169,25 @@ def create_userdata( attach_volumes=volume_ids_path_attach, ) init_script = ( - init_script - + encodeutils.safe_encode("\n".encode("utf-8")) - + mount_script + init_script + + encodeutils.safe_encode("\n".encode("utf-8")) + + mount_script ) return init_script def start_server( - self, - flavor_name: str, - image_name: str, - servername: str, - metadata: dict[str, str], - public_key: str, - research_environment_metadata: Union[ResearchEnvironmentMetadata, None] = None, - volume_ids_path_new: Union[list[dict[str, str]], None] = None, - volume_ids_path_attach: Union[list[dict[str, str]], None] = None, - additional_keys: Union[list[str], None] = None, - additional_security_group_ids: Union[list[str], None] = None, + self, + flavor_name: str, + image_name: str, + servername: str, + metadata: dict[str, str], + public_key: str, + research_environment_metadata: Union[ResearchEnvironmentMetadata, None] = None, + volume_ids_path_new: Union[list[dict[str, str]], None] = None, + volume_ids_path_attach: Union[list[dict[str, str]], None] = None, + additional_keys: Union[list[str], None] = None, + additional_security_group_ids: Union[list[str], None] = None, ) -> str: logger.info(f"Start Server {servername}") @@ -1261,13 +1241,13 @@ def start_server( if key_name: self.delete_keypair(key_name=key_name) - logger.exception(f"Start Server {servername} error:{e}") + logger.exception(f"Start Server {servername} error") raise DefaultException(message=str(e)) def _get_volumes_machines_start( - self, - volume_ids_path_new: list[dict[str, str]] = None, - volume_ids_path_attach: list[dict[str, str]] = None, + self, + volume_ids_path_new: list[dict[str, str]] = None, + volume_ids_path_attach: list[dict[str, str]] = None, ) -> list[Volume]: volume_ids = [] volumes = [] @@ -1281,11 +1261,11 @@ def _get_volumes_machines_start( return volumes def _get_security_groups_starting_machine( - self, - additional_security_group_ids: Union[list[str], None] = None, - project_name: Union[str, None] = None, - project_id: Union[str, None] = None, - research_environment_metadata: Union[ResearchEnvironmentMetadata, None] = None, + self, + additional_security_group_ids: Union[list[str], None] = None, + project_name: Union[str, None] = None, + project_id: Union[str, None] = None, + research_environment_metadata: Union[ResearchEnvironmentMetadata, None] = None, ) -> list[str]: security_groups = self._get_default_security_groups() if research_environment_metadata: @@ -1310,16 +1290,16 @@ def _get_security_groups_starting_machine( return security_groups def start_server_with_playbook( - self, - flavor_name: str, - image_name: str, - servername: str, - metadata: dict[str, str], - research_environment_metadata: ResearchEnvironmentMetadata, - volume_ids_path_new: list[dict[str, str]] = None, # type: ignore - volume_ids_path_attach: list[dict[str, str]] = None, # type: ignore - additional_keys: list[str] = None, # type: ignore - additional_security_group_ids=None, # type: ignore + self, + flavor_name: str, + image_name: str, + servername: str, + metadata: dict[str, str], + research_environment_metadata: ResearchEnvironmentMetadata, + volume_ids_path_new: list[dict[str, str]] = None, # type: ignore + volume_ids_path_attach: list[dict[str, str]] = None, # type: ignore + additional_keys: list[str] = None, # type: ignore + additional_security_group_ids=None, # type: ignore ) -> tuple[str, str]: logger.info(f"Start Server {servername}") @@ -1429,16 +1409,16 @@ def add_udp_security_group(self, server_id): return def add_cluster_machine( - self, - cluster_id: str, - cluster_user: str, - cluster_group_id: list[str], - image_name: str, - flavor_name: str, - name: str, - key_name: str, - batch_idx: int, - worker_idx: int, + self, + cluster_id: str, + cluster_user: str, + cluster_group_id: list[str], + image_name: str, + flavor_name: str, + name: str, + key_name: str, + batch_idx: int, + worker_idx: int, ) -> str: logger.info(f"Add machine to {cluster_id}") image: Image = self.get_image(name_or_id=image_name, replace_inactive=True) diff --git a/simple_vm_client/test_openstack_connector.py b/simple_vm_client/test_openstack_connector.py index 89efed7..468d1d3 100644 --- a/simple_vm_client/test_openstack_connector.py +++ b/simple_vm_client/test_openstack_connector.py @@ -1,10 +1,12 @@ import os +import random import socket import tempfile import unittest from unittest import mock from unittest.mock import MagicMock, call, patch +import pytest from openstack.block_storage.v3 import volume from openstack.block_storage.v3.limits import Limit from openstack.block_storage.v3.volume import Volume @@ -19,7 +21,8 @@ from openstack.test import fakes from oslo_utils import encodeutils - +from simple_vm_client.forc_connector.template.template import ResearchEnvironmentMetadata +from simple_vm_client.util.state_enums import VmStates, VmTaskStates from .openstack_connector.openstack_connector import OpenStackConnector from .ttypes import ( DefaultException, @@ -27,9 +30,54 @@ OpenStackConflictException, ResourceNotAvailableException, SnapshotNotFoundException, - VolumeNotFoundException, + VolumeNotFoundException, ServerNotFoundException, +) + +METADATA_EXAMPLE_NO_FORC = ResearchEnvironmentMetadata( + template_name="example_template", + port="8080", + wiki_link="https://example.com/wiki", + description="Example template for testing", + title="Example Template", + community_driven=True, + logo_url="https://example.com/logo.png", + info_url="https://example.com/info", + securitygroup_name="example_group", + securitygroup_description="Example security group", + securitygroup_ssh=True, + direction="inbound", + protocol="tcp", + information_for_display="Some information", + needs_forc_support=False, + min_ram=2, + min_cores=1, + is_maintained=True, + forc_versions=["1.0.0", "2.0.0"], + incompatible_versions=["3.0.0"], ) +METADATA_EXAMPLE = ResearchEnvironmentMetadata( + template_name="example_template", + port="8080", + wiki_link="https://example.com/wiki", + description="Example template for testing", + title="Example Template", + community_driven=True, + logo_url="https://example.com/logo.png", + info_url="https://example.com/info", + securitygroup_name="example_group", + securitygroup_description="Example security group", + securitygroup_ssh=True, + direction="inbound", + protocol="tcp", + information_for_display="Some information", + needs_forc_support=True, + min_ram=2, + min_cores=1, + is_maintained=True, + forc_versions=["1.0.0", "2.0.0"], + incompatible_versions=["3.0.0"], +) EXPECTED_IMAGE = image_module.Image( id="image_id_2", status="active", @@ -63,15 +111,16 @@ ), INACTIVE_IMAGE, ] +PORT_CALCULATION = "30000 + x + y * 256" DEFAULT_SECURITY_GROUPS = ["defaultSimpleVM"] -CONFIG_DATA = """ +CONFIG_DATA = f""" openstack: gateway_ip: "192.168.1.1" network: "my_network" sub_network: "my_sub_network" cloud_site: "my_cloud_site" - ssh_port_calculation: 22 - udp_port_calculation: 12345 + ssh_port_calculation: {PORT_CALCULATION} + udp_port_calculation: {PORT_CALCULATION} gateway_security_group_id: "security_group_id" production: true forc: @@ -150,8 +199,8 @@ def test_load_config_yml(self): self.assertEqual(self.openstack_connector.SUB_NETWORK, "my_sub_network") self.assertTrue(self.openstack_connector.PRODUCTION) self.assertEqual(self.openstack_connector.CLOUD_SITE, "my_cloud_site") - self.assertEqual(self.openstack_connector.SSH_PORT_CALCULATION, 22) - self.assertEqual(self.openstack_connector.UDP_PORT_CALCULATION, 12345) + self.assertEqual(self.openstack_connector.SSH_PORT_CALCULATION, PORT_CALCULATION) + self.assertEqual(self.openstack_connector.UDP_PORT_CALCULATION, PORT_CALCULATION) self.assertEqual( self.openstack_connector.FORC_SECURITY_GROUP_ID, "forc_security_group_id" ) @@ -402,12 +451,18 @@ def test_replace_inactive_image(self): # Assert that the method returns the replacement image self.assertEqual(result, EXPECTED_IMAGE) - @unittest.skip("Currently not working") def test_get_limits(self): compute_limits = fakes.generate_fake_resource(limits.AbsoluteLimits) volume_limits = fakes.generate_fake_resource(Limit) - self.mock_openstack_connection.get_compute_limits.return_value = compute_limits - self.mock_openstack_connection.get_volume_limits.return_value = volume_limits + compute_copy = {} + for key in compute_limits.keys(): + compute_copy[key] = random.randint(0, 10000) + + absolute_volume = volume_limits["absolute"] + for key in absolute_volume.keys(): + volume_limits["absolute"][key] = random.randint(0, 10000) + self.openstack_connector.openstack_connection.get_compute_limits.return_value = compute_copy + self.openstack_connector.openstack_connection.get_volume_limits.return_value = volume_limits self.openstack_connector.get_limits() @patch("simple_vm_client.openstack_connector.openstack_connector.logger.info") @@ -493,7 +548,7 @@ def test_get_volume_exception(self, mock_logger_exception): # Call the get_volume method and expect a VolumeNotFoundException with self.assertRaises( - Exception + Exception ): # Replace Exception with the actual exception type self.openstack_connector.get_volume(name_or_id) @@ -528,21 +583,21 @@ def test_delete_volume(self, mock_logger_exception, mock_logger_info): # 2. ResourceNotFound, expect VolumeNotFoundException with self.assertRaises( - VolumeNotFoundException + VolumeNotFoundException ): # Replace Exception with the actual exception type self.openstack_connector.delete_volume(volume_id) mock_logger_exception.assert_called_with(f"No Volume with id {volume_id}") # 3. ConflictException, expect OpenStackCloudException with self.assertRaises( - OpenStackCloudException + OpenStackCloudException ): # Replace Exception with the actual exception type self.openstack_connector.delete_volume(volume_id) mock_logger_exception.assert_called_with(f"Delete volume: {volume_id}) failed!") # 4. OpenStackCloudException, expect DefaultException with self.assertRaises( - DefaultException + DefaultException ): # Replace Exception with the actual exception type self.openstack_connector.delete_volume(volume_id) @@ -642,14 +697,14 @@ def test_delete_volume_snapshot(self, mock_logger_exception, mock_logger_info): # 2. ResourceNotFound, expect SnapshotNotFoundException with self.assertRaises( - SnapshotNotFoundException + SnapshotNotFoundException ): # Replace Exception with the actual exception type self.openstack_connector.delete_volume_snapshot(snapshot_id) mock_logger_exception.assert_called_with(f"Snapshot not found: {snapshot_id}") # 3. ConflictException, expect OpenStackCloudException with self.assertRaises( - OpenStackCloudException + OpenStackCloudException ): # Replace Exception with the actual exception type self.openstack_connector.delete_volume_snapshot(snapshot_id) mock_logger_exception.assert_called_with( @@ -658,7 +713,7 @@ def test_delete_volume_snapshot(self, mock_logger_exception, mock_logger_info): # 4. OpenStackCloudException, expect DefaultException with self.assertRaises( - DefaultException + DefaultException ): # Replace Exception with the actual exception type self.openstack_connector.delete_volume_snapshot(snapshot_id) @@ -681,7 +736,7 @@ def test_get_servers(self, mock_logger_info): @patch("simple_vm_client.openstack_connector.openstack_connector.logger.exception") @patch("simple_vm_client.openstack_connector.openstack_connector.logger.info") def test_get_servers_by_ids( - self, mock_logger_info, mock_logger_exception, mock_logger_error + self, mock_logger_info, mock_logger_exception, mock_logger_error ): # Prepare test data server_ids = ["id1", "id2", "id3", "id4"] @@ -1198,12 +1253,12 @@ def test_get_public_images(self, mock_logger_info): @patch.object(OpenStackConnector, "create_server") @mock.patch("simple_vm_client.openstack_connector.openstack_connector.logger.info") def test_add_cluster_machine( - self, - mock_logger_info, - mock_create_server, - mock_get_network, - mock_get_flavor, - mock_get_image, + self, + mock_logger_info, + mock_create_server, + mock_get_network, + mock_get_flavor, + mock_get_image, ): # Arrange cluster_id = "123" @@ -1296,7 +1351,7 @@ def test_add_udp_security_group_existing_group(self): @patch.object(OpenStackConnector, "create_security_group") @mock.patch("simple_vm_client.openstack_connector.openstack_connector.logger.info") def test_add_udp_security_group_new_group( - self, mock_logger_info, mock_create_security_group, mock_get_vm_ports + self, mock_logger_info, mock_create_security_group, mock_get_vm_ports ): # Test when a new UDP security group needs to be created @@ -1379,11 +1434,11 @@ def test_add_udp_security_group_already_added(self, mock_logger_info): @patch.object(OpenStackConnector, "create_userdata") @patch.object(OpenStackConnector, "delete_keypair") def test_start_server_with_playbook( - self, - mock_delete_keypair, - mock_create_userdata, - mock_get_volumes, - mock_get_security_groups_starting_machine, + self, + mock_delete_keypair, + mock_create_userdata, + mock_get_volumes, + mock_get_security_groups_starting_machine, ): server = fakes.generate_fake_resource(Server) server_keypair = fakes.generate_fake_resource(keypair.Keypair) @@ -1485,12 +1540,12 @@ def test_start_server_with_playbook( @patch.object(OpenStackConnector, "delete_keypair") @patch("simple_vm_client.openstack_connector.openstack_connector.logger.exception") def test_start_server_with_playbook_exception( - self, - mock_logger_exception, - mock_delete_keypair, - mock_create_userdata, - mock_get_volumes, - mock_get_security_groups_starting_machine, + self, + mock_logger_exception, + mock_delete_keypair, + mock_create_userdata, + mock_get_volumes, + mock_get_security_groups_starting_machine, ): server = fakes.generate_fake_resource(Server) server_keypair = fakes.generate_fake_resource(keypair.Keypair) @@ -1553,10 +1608,10 @@ def test_start_server_with_playbook_exception( ) @patch.object(OpenStackConnector, "get_or_create_project_security_group") def test_get_security_groups_starting_machine( - self, - mock_get_project_sg, - mock_get_research_env_sg, - mock_get_default_security_groups, + self, + mock_get_project_sg, + mock_get_research_env_sg, + mock_get_default_security_groups, ): # Set up mocks fake_default_security_group = fakes.generate_fake_resource( @@ -1637,6 +1692,743 @@ def test_get_volumes_machines_start(self): expected_result = [fake_vol_1, fake_vol_2] self.assertEqual(result, expected_result) + @patch.object(OpenStackConnector, "_get_security_groups_starting_machine") + @patch.object(OpenStackConnector, "_get_volumes_machines_start") + @patch.object(OpenStackConnector, "create_userdata") + @patch.object(OpenStackConnector, "delete_keypair") + def test_start_server( + self, + mock_delete_keypair, + mock_create_userdata, + mock_get_volumes, + mock_get_security_groups_starting_machine, + ): + server = fakes.generate_fake_resource(Server) + server_keypair = fakes.generate_fake_resource(keypair.Keypair) + server_keypair.name = server.name + "_mock_project" + fake_image = fakes.generate_fake_resource(image.Image) + fake_image.status = "active" + fake_flavor = fakes.generate_fake_resource(flavor.Flavor) + fake_network = fakes.generate_fake_resource(Network) + + # Set up mocks + self.openstack_connector.openstack_connection.create_server.return_value = ( + server + ) + self.openstack_connector.openstack_connection.create_keypair.return_value = ( + server_keypair + ) + mock_get_security_groups_starting_machine.return_value = ["sg1", "sg2"] + self.openstack_connector.openstack_connection.get_image.return_value = ( + fake_image + ) + self.openstack_connector.openstack_connection.get_flavor.return_value = ( + fake_flavor + ) + self.openstack_connector.openstack_connection.network.find_network.return_value = ( + fake_network + ) + self.openstack_connector.openstack_connection.compute.find_keypair.return_value = server_keypair + self.openstack_connector.openstack_connection.compute.get_keypair.return_value = server_keypair + + mock_get_volumes.return_value = ["volume1", "volume2"] + mock_create_userdata.return_value = "userdata" + + # Set necessary input parameters + flavor_name = fake_flavor.name + image_name = fake_image.name + servername = server.name + metadata = {"project_name": "mock_project", "project_id": "mock_project_id"} + research_environment_metadata = MagicMock() + volume_ids_path_new = [ + {"openstack_id": "volume_id1"}, + {"openstack_id": "volume_id2"}, + ] + volume_ids_path_attach = [{"openstack_id": "volume_id3"}] + additional_keys = ["key1", "key2"] + additional_security_group_ids = ["sg3", "sg4"] + public_key = "public_key" + + # Call the method + result = self.openstack_connector.start_server( + flavor_name=flavor_name, + image_name=image_name, + servername=servername, + metadata=metadata, + research_environment_metadata=research_environment_metadata, + volume_ids_path_new=volume_ids_path_new, + volume_ids_path_attach=volume_ids_path_attach, + additional_keys=additional_keys, + additional_security_group_ids=additional_security_group_ids, + public_key=public_key + ) + + # Assertions + self.openstack_connector.openstack_connection.create_server.assert_called_once_with( + name=server.name, + image=fake_image.id, + flavor=fake_flavor.id, + network=[fake_network.id], + key_name=server_keypair.name, + meta=metadata, + volumes=["volume1", "volume2"], + userdata="userdata", + security_groups=["sg1", "sg2"], + ) + + mock_create_userdata.assert_called_once_with( + volume_ids_path_new=volume_ids_path_new, + volume_ids_path_attach=volume_ids_path_attach, + additional_keys=additional_keys, + ) + + mock_get_security_groups_starting_machine.assert_called_once_with( + additional_security_group_ids=additional_security_group_ids, + project_name="mock_project", + project_id="mock_project_id", + research_environment_metadata=research_environment_metadata, + ) + + self.openstack_connector.openstack_connection.create_keypair.assert_called_once_with( + name=server_keypair.name, public_key=public_key + ) + + mock_get_volumes.assert_called_once_with( + volume_ids_path_new=volume_ids_path_new, + volume_ids_path_attach=volume_ids_path_attach, + ) + + self.openstack_connector.openstack_connection.get_keypair.assert_called_once_with(name_or_id=server_keypair.name) + mock_delete_keypair.assert_any_call(key_name=server_keypair.name) + + # Check the result + self.assertEqual(result, server.id) + + @patch.object(OpenStackConnector, "_get_security_groups_starting_machine") + @patch.object(OpenStackConnector, "_get_volumes_machines_start") + @patch.object(OpenStackConnector, "create_userdata") + @patch.object(OpenStackConnector, "delete_keypair") + @patch("simple_vm_client.openstack_connector.openstack_connector.logger.exception") + def test_start_server_exception( + self, + mock_logger_exception, + mock_delete_keypair, + mock_create_userdata, + mock_get_volumes, + mock_get_security_groups_starting_machine, + ): + server = fakes.generate_fake_resource(Server) + server_keypair = fakes.generate_fake_resource(keypair.Keypair) + fake_image = fakes.generate_fake_resource(image.Image) + fake_image.status = "active" + fake_flavor = fakes.generate_fake_resource(flavor.Flavor) + fake_network = fakes.generate_fake_resource(Network) + public_key = "public_key" + + # Set up mocks + self.openstack_connector.openstack_connection.create_server.return_value = ( + server + ) + self.openstack_connector.openstack_connection.create_keypair.return_value = ( + server_keypair + ) + mock_get_security_groups_starting_machine.return_value = ["sg1", "sg2"] + self.openstack_connector.openstack_connection.get_image.return_value = ( + fake_image + ) + self.openstack_connector.openstack_connection.get_flavor.return_value = ( + fake_flavor + ) + self.openstack_connector.openstack_connection.network.find_network.return_value = ( + fake_network + ) + mock_get_volumes.side_effect = OpenStackCloudException("Unit Test Error") + flavor_name = fake_flavor.name + image_name = fake_image.name + servername = server.name + metadata = {"project_name": "mock_project", "project_id": "mock_project_id"} + research_environment_metadata = MagicMock() + volume_ids_path_new = [ + {"openstack_id": "volume_id1"}, + {"openstack_id": "volume_id2"}, + ] + volume_ids_path_attach = [{"openstack_id": "volume_id3"}] + additional_keys = ["key1", "key2"] + additional_security_group_ids = ["sg3", "sg4"] + + with self.assertRaises(DefaultException): + self.openstack_connector.start_server( + flavor_name=flavor_name, + image_name=image_name, + servername=servername, + metadata=metadata, + research_environment_metadata=research_environment_metadata, + volume_ids_path_new=volume_ids_path_new, + volume_ids_path_attach=volume_ids_path_attach, + additional_keys=additional_keys, + additional_security_group_ids=additional_security_group_ids, + public_key=public_key + ) + mock_logger_exception.assert_any_call( + (f"Start Server {servername} error") + ) + + @patch.object(OpenStackConnector, "create_add_keys_script") + @patch.object(OpenStackConnector, "create_mount_init_script") + def test_create_userdata( + self, mock_create_mount_init_script, + mock_create_add_keys_script + ): + # Set up mocks + mock_create_add_keys_script.return_value = b"mock_add_keys_script" + mock_create_mount_init_script.return_value = b"mock_mount_script" + + # Set necessary input parameters + volume_ids_path_new = [{"openstack_id": "volume_id_new"}] + volume_ids_path_attach = [{"openstack_id": "volume_id_attach"}] + additional_keys = ["key1", "key2"] + + # Call the method + result = self.openstack_connector.create_userdata( + volume_ids_path_new, volume_ids_path_attach, additional_keys + ) + + # Assertions + mock_create_add_keys_script.assert_called_once_with(keys=additional_keys) + mock_create_mount_init_script.assert_called_once_with( + new_volumes=volume_ids_path_new, attach_volumes=volume_ids_path_attach + ) + + # Check the result + expected_result = ( + b"mock_add_keys_script\n" + + b"#!/bin/bash\npasswd -u ubuntu\n" + + b"\nmock_mount_script" + ) + self.assertEqual(result, expected_result) + + @patch.object(OpenStackConnector, "get_server") + @patch("simple_vm_client.openstack_connector.openstack_connector.sympy.symbols") + @patch("simple_vm_client.openstack_connector.openstack_connector.sympy.sympify") + def test_get_vm_ports( + self, mock_sympify, mock_symbols, mock_get_server + ): + # Set up mocks + mock_server = fakes.generate_fake_resource(server.Server) + mock_server["private_v4"] = "192.168.1.2" + + mock_get_server.return_value = mock_server + mock_sympify.return_value.evalf.return_value = 30258 # Replace with expected values + mock_symbols.side_effect = ["x", "y"] + + # Call the method + result = self.openstack_connector.get_vm_ports(mock_server.id) + + # Assertions + mock_get_server.assert_called_once_with(openstack_id=mock_server.id) + mock_symbols.assert_any_call("x") + mock_symbols.assert_any_call("y") + + mock_sympify.assert_called_with(self.openstack_connector.SSH_PORT_CALCULATION) + mock_sympify.return_value.evalf.assert_called_with(subs={"x": 2, "y": 1}) + + # Check the result + expected_result = {"port": "30258", "udp": "30258"} # Replace with expected values + self.assertEqual(result, expected_result) + + @patch.object(OpenStackConnector, "get_server") + @patch.object(OpenStackConnector, "_validate_server_for_deletion") + @patch.object(OpenStackConnector, "_remove_security_groups_from_server") + def test_delete_server_successful(self, mock_remove_security_groups, mock_validate_server, mock_get_server): + # Arrange + mock_server = fakes.generate_fake_resource(server.Server) + + mock_get_server.return_value = mock_server + # Act + self.openstack_connector.delete_server(mock_server.id) + + # Assert + mock_get_server.assert_called_once_with(openstack_id=mock_server.id) + mock_validate_server.assert_called_once_with(server=mock_server) + mock_remove_security_groups.assert_called_once_with(server=mock_server) + self.openstack_connector.openstack_connection.compute.delete_server.assert_called_once_with(mock_server.id, force=True) + + @patch.object(OpenStackConnector, "get_server") + def test_delete_server_exception(self, mock_get_server): + # Arrange + mock_server = fakes.generate_fake_resource(server.Server) + # Mocking the necessary methods to raise a ConflictException + mock_get_server.side_effect = ConflictException("Conflict") + # Act + + # Act and Assert + with self.assertRaises(OpenStackConflictException): + self.openstack_connector.delete_server(mock_server.id) + + mock_get_server.assert_called_once_with(openstack_id=mock_server.id) + self.openstack_connector.openstack_connection.compute.delete_server.assert_not_called() + + @patch.object(OpenStackConnector, "get_server") + def test_delete_server_not_found_exception(self, mock_get_server): + # Arrange + # Mocking the necessary methods to raise a ConflictException + server_id = "not_found" + mock_get_server.return_value = None + # Act + + # Act and Assert + with self.assertRaises(ServerNotFoundException): + self.openstack_connector.delete_server(server_id) + + mock_get_server.assert_called_once_with(openstack_id=server_id) + self.openstack_connector.openstack_connection.compute.delete_server.assert_not_called() + + def test_validate_server_for_deletion(self): + # Arrange + server_mock = fakes.generate_fake_resource(server.Server) + + # Act + self.openstack_connector._validate_server_for_deletion(server_mock) + + # Assert + # No exceptions should be raised if the server is found + self.assertTrue(True) + + def test_validate_server_for_deletion_conflict_exception(self): + # Arrange + server_mock = fakes.generate_fake_resource(server.Server) + server_mock.task_state = "image_pending_upload" + with self.assertRaises(ConflictException): + # Act + self.openstack_connector._validate_server_for_deletion(server_mock) + + def test_remove_security_groups_from_server_no_security_groups(self): + # Arrange + server_mock = fakes.generate_fake_resource(server.Server) + server_mock.security_groups = None + + # Act + self.openstack_connector._remove_security_groups_from_server(server_mock) + + # Assert + # The method should not raise any exceptions if there are no security groups + self.assertTrue(True) + + @patch.object(OpenStackConnector, "is_security_group_in_use") + def test_remove_security_groups_from_server_with_security_groups(self, mock_is_security_group_in_use): + # Arrange + server_mock = fakes.generate_fake_resource(server.Server) + fake_groups = list(fakes.generate_fake_resources(security_group.SecurityGroup, count=4)) + fake_groups[2].name = "bibigrid-sec" + server_mock.security_groups = fake_groups + self.openstack_connector.openstack_connection.get_security_group.side_effect = fake_groups + mock_is_security_group_in_use.side_effect = [False, False, True, True] + + # Act + self.openstack_connector._remove_security_groups_from_server(server_mock) + + for group in fake_groups: + self.openstack_connector.openstack_connection.compute.remove_security_group_from_server.assert_any_call( + server=server_mock, security_group=group + ) + + with self.assertRaises(AssertionError): + self.openstack_connector.openstack_connection.delete_security_group.assert_any_call(fake_groups[2]) + self.openstack_connector.openstack_connection.delete_security_group.assert_any_call(fake_groups[3]) + + for group in fake_groups[:2]: + self.openstack_connector.openstack_connection.delete_security_group.assert_any_call(group) + + @patch.object(OpenStackConnector, "get_server") + def test_stop_server_success(self, mock_get_server): + # Arrange + server_mock = fakes.generate_fake_resource(server.Server) + mock_get_server.return_value = server_mock + # Act + self.openstack_connector.stop_server(openstack_id="some_openstack_id") + + # Assert + # Ensure the stop_server method is called with the correct server + self.openstack_connector.openstack_connection.compute.stop_server.assert_called_once_with( + server_mock + ) + + @patch.object(OpenStackConnector, "get_server") + @patch("simple_vm_client.openstack_connector.openstack_connector.logger.exception") + def test_stop_server_conflict_exception(self, mock_logger_exception, mock_get_server): + # Arrange + server_mock = fakes.generate_fake_resource(server.Server) + mock_get_server.return_value = server_mock + self.openstack_connector.openstack_connection.compute.stop_server.side_effect = ConflictException("Unit Test") + # Act + with self.assertRaises(OpenStackConflictException): + self.openstack_connector.stop_server(openstack_id="some_openstack_id") + mock_logger_exception.assert_called_once_with(f"Stop Server some_openstack_id failed!") + + @patch.object(OpenStackConnector, "get_server") + def test_reboot_server_success(self, mock_get_server): + # Arrange + server_mock = fakes.generate_fake_resource(server.Server) + mock_get_server.return_value = server_mock + # Act + self.openstack_connector.reboot_server(server_mock.id, "SOFT") + + # Assert + # Ensure the stop_server method is called with the correct server + self.openstack_connector.openstack_connection.compute.reboot_server.assert_called_once_with( + server_mock, "SOFT" + ) + + @patch.object(OpenStackConnector, "get_server") + @patch("simple_vm_client.openstack_connector.openstack_connector.logger.exception") + def test_reboot_server_conflict_exception(self, mock_logger_exception, mock_get_server): + + self.openstack_connector.openstack_connection.compute.reboot_server.side_effect = ConflictException("Unit Test") + # Act + with self.assertRaises(OpenStackConflictException): + self.openstack_connector.reboot_server("some_openstack_id", "SOFT") + mock_logger_exception.assert_called_once_with(f"Reboot Server some_openstack_id failed!") + + @patch.object(OpenStackConnector, "get_server") + def test_reboot_soft_server(self, mock_get_server): + # Arrange + server_mock = fakes.generate_fake_resource(server.Server) + mock_get_server.return_value = server_mock + self.openstack_connector.reboot_soft_server(server_mock.id) + self.openstack_connector.openstack_connection.compute.reboot_server.assert_called_once_with( + server_mock, "SOFT" + ) + + @patch.object(OpenStackConnector, "get_server") + def test_reboot_hard_server(self, mock_get_server): + # Arrange + server_mock = fakes.generate_fake_resource(server.Server) + mock_get_server.return_value = server_mock + self.openstack_connector.reboot_hard_server(server_mock.id) + self.openstack_connector.openstack_connection.compute.reboot_server.assert_called_once_with( + server_mock, "HARD" + ) + + @patch.object(OpenStackConnector, "get_server") + def test_resume_server_success(self, mock_get_server): + # Arrange + server_mock = fakes.generate_fake_resource(server.Server) + mock_get_server.return_value = server_mock + # Act + self.openstack_connector.resume_server(server_mock.id) + + # Assert + # Ensure the stop_server method is called with the correct server + self.openstack_connector.openstack_connection.compute.start_server.assert_called_once_with( + server_mock + ) + + @patch.object(OpenStackConnector, "get_server") + @patch("simple_vm_client.openstack_connector.openstack_connector.logger.exception") + def test_resume_server_conflict_exception(self, mock_logger_exception, mock_get_server): + + self.openstack_connector.openstack_connection.compute.start_server.side_effect = ConflictException("Unit Test") + # Act + with self.assertRaises(OpenStackConflictException): + self.openstack_connector.resume_server("some_openstack_id") + mock_logger_exception.assert_called_once_with(f"Resume Server some_openstack_id failed!") + + @patch.object(OpenStackConnector, "_calculate_vm_ports") + @patch.object(OpenStackConnector, "get_image") + @patch.object(OpenStackConnector, "get_flavor") + @patch.object(OpenStackConnector, "netcat") + def test_get_server(self, mock_netcat, + mock_get_flavor, mock_get_image, mock_calculate_ports): + # Arrange + openstack_id = "your_openstack_id" + server_mock = fakes.generate_fake_resource(server.Server) + server_mock.vm_state = VmStates.ACTIVE.value + image_mock = fakes.generate_fake_resource(image.Image) + server_mock.image = image_mock + flavor_mock = fakes.generate_fake_resource(flavor.Flavor) + server_mock.flavor = flavor_mock + + # Mocking the methods and attributes + self.openstack_connector.openstack_connection.get_server_by_id.return_value = server_mock + mock_get_image.return_value = image_mock + mock_get_flavor.return_value = flavor_mock + mock_calculate_ports.return_value = (30111, 30111) + mock_netcat.return_value = True # Assuming SSH connection is successful + + # Act + self.openstack_connector.get_server(openstack_id) + + # Assert + self.openstack_connector.openstack_connection.get_server_by_id.assert_called_once_with(id=openstack_id) + mock_calculate_ports.assert_called_once_with(server=server_mock) + mock_netcat.assert_called_once_with(host=self.openstack_connector.GATEWAY_IP, port=30111) + mock_get_image.assert_called_once_with( + name_or_id=image_mock.id, + ignore_not_active=True, + ignore_not_found=True, + ) + mock_get_flavor.assert_called_once_with(name_or_id=flavor_mock.id) + mock_netcat.return_value = False # Assuming SSH connection is successful + # Act + result_server = self.openstack_connector.get_server(openstack_id) + self.assertEqual(result_server.task_state, VmTaskStates.CHECKING_SSH_CONNECTION.value) + + def test_get_server_not_found(self): + self.openstack_connector.openstack_connection.get_server_by_id.return_value = None + + with self.assertRaises(ServerNotFoundException): + self.openstack_connector.get_server("someid") + + def test_get_server_openstack_exception(self): + self.openstack_connector.openstack_connection.get_server_by_id.side_effect = OpenStackCloudException("UNit Test") + + with self.assertRaises(DefaultException): + self.openstack_connector.get_server("someid") + + @patch.object(OpenStackConnector, "get_server") + def test_set_server_metadata_success(self, mock_get_server): + # Arrange + server_mock = fakes.generate_fake_resource(server.Server) + mock_get_server.return_value = server_mock + metadata = {"data": "123"} + # Act + self.openstack_connector.set_server_metadata(server_mock.id, metadata) + + # Assert + # Ensure the stop_server method is called with the correct server + self.openstack_connector.openstack_connection.compute.set_server_metadata.assert_called_once_with( + server_mock, metadata + ) + + @patch.object(OpenStackConnector, "get_server") + def test_set_server_metadata_exception(self, mock_get_server): + # Arrange + server_mock = fakes.generate_fake_resource(server.Server) + mock_get_server.return_value = server_mock + metadata = {"data": "123"} + self.openstack_connector.openstack_connection.compute.set_server_metadata.side_effect = OpenStackCloudException("Unit Tests") + # Act + with self.assertRaises(DefaultException): + self.openstack_connector.set_server_metadata(server_mock.id, metadata) + + @patch.object(OpenStackConnector, "get_server") + @patch("simple_vm_client.openstack_connector.openstack_connector.logger.exception") + def test_reboot_server_conflict_exception(self, mock_logger_exception, mock_get_server): + + self.openstack_connector.openstack_connection.compute.reboot_server.side_effect = ConflictException("Unit Test") + # Act + with self.assertRaises(OpenStackConflictException): + self.openstack_connector.reboot_server("some_openstack_id", "SOFT") + mock_logger_exception.assert_called_once_with(f"Reboot Server some_openstack_id failed!") + + def test_exist_server_true(self): + server_mock = fakes.generate_fake_resource(server.Server) + + self.openstack_connector.openstack_connection.compute.find_server.return_value = server_mock + + result = self.openstack_connector.exist_server(server_mock.name) + self.assertTrue(result) + + def test_exist_server_false(self): + server_mock = fakes.generate_fake_resource(server.Server) + + self.openstack_connector.openstack_connection.compute.find_server.return_value = None + + result = self.openstack_connector.exist_server(server_mock.name) + self.assertFalse(result) + + def test_get_or_create_project_security_group_exists(self): + # Mock the get_security_group method to simulate an existing security group + existing_security_group = fakes.generate_fake_resource(security_group.SecurityGroup) + + self.openstack_connector.openstack_connection.get_security_group.return_value = existing_security_group + + # Call the method + result = self.openstack_connector.get_or_create_project_security_group("project_name", "project_id") + + # Assertions + self.assertEqual(result, existing_security_group.id) + self.openstack_connector.openstack_connection.create_security_group.assert_not_called() + + def test_get_or_create_project_security_group_create_new(self): + # Mock the get_security_group method to simulate a non-existing security group + self.openstack_connector.openstack_connection.get_security_group.return_value = None + + # Mock the create_security_group method to simulate creating a new security group + new_security_group = fakes.generate_fake_resource(security_group.SecurityGroup) + self.openstack_connector.openstack_connection.create_security_group.return_value = new_security_group + + # Call the method + result = self.openstack_connector.get_or_create_project_security_group("project_name", "project_id") + + # Assertions + self.assertEqual(result, new_security_group.id) + self.openstack_connector.openstack_connection.create_security_group.assert_called_once() + + def test_get_or_create_vm_security_group_exist(self): + # Mock the get_security_group method to simulate an existing security group + existing_security_group = fakes.generate_fake_resource(security_group.SecurityGroup) + + self.openstack_connector.openstack_connection.get_security_group.return_value = existing_security_group + + # Call the method + result = self.openstack_connector.get_or_create_vm_security_group("server_id") + + # Assertions + self.assertEqual(result, existing_security_group.id) + self.openstack_connector.openstack_connection.create_security_group.assert_not_called() + + def test_get_or_create_vm_security_group_create_new(self): + # Mock the get_security_group method to simulate a non-existing security group + self.openstack_connector.openstack_connection.get_security_group.return_value = None + + # Mock the create_security_group method to simulate creating a new security group + new_security_group = fakes.generate_fake_resource(security_group.SecurityGroup) + self.openstack_connector.openstack_connection.create_security_group.return_value = new_security_group + + # Call the method + result = self.openstack_connector.get_or_create_vm_security_group("openstack_id") + + # Assertions + self.assertEqual(result, new_security_group.id) + self.openstack_connector.openstack_connection.create_security_group.assert_called_once() + + def test_get_or_create_research_environment_security_group_exist(self): + # Mock the get_security_group method to simulate an existing security group + existing_security_group = fakes.generate_fake_resource(security_group.SecurityGroup) + + self.openstack_connector.openstack_connection.get_security_group.return_value = existing_security_group + + # Call the method + result = self.openstack_connector.get_or_create_research_environment_security_group(resenv_metadata=METADATA_EXAMPLE) + + # Assertions + self.assertEqual(result, existing_security_group.id) + self.openstack_connector.openstack_connection.create_security_group.assert_not_called() + + def test_get_or_create_research_environment_no_forc_support(self): + # Mock the get_security_group method to simulate a non-existing security group + self.openstack_connector.get_or_create_research_environment_security_group(resenv_metadata=METADATA_EXAMPLE_NO_FORC) + self.openstack_connector.openstack_connection.get_security_group.assert_not_called() + + def test_get_or_create_research_environment_security_group_new(self): + # Mock the get_security_group method to simulate a non-existing security group + self.openstack_connector.openstack_connection.get_security_group.return_value = None + + # Mock the create_security_group method to simulate creating a new security group + new_security_group = fakes.generate_fake_resource(security_group.SecurityGroup) + self.openstack_connector.openstack_connection.create_security_group.return_value = new_security_group + + # Call the method + result = self.openstack_connector.get_or_create_research_environment_security_group(resenv_metadata=METADATA_EXAMPLE) + + # Assertions + self.assertEqual(result, new_security_group.id) + self.openstack_connector.openstack_connection.create_security_group.assert_called_once() + + def test_is_security_group_in_use_instances(self): + # Mock the compute.servers method to simulate instances using the security group + instances = [{"id": "instance_id", "name": "instance_name"}] + self.openstack_connector.openstack_connection.compute.servers = MagicMock(return_value=instances) + + # Call the method + result = self.openstack_connector.is_security_group_in_use("security_group_id") + + # Assertions + self.assertTrue(result) + + def test_is_security_group_in_use_ports(self): + # Mock the network.ports method to simulate ports associated with the security group + ports = [{"id": "port_id", "name": "port_name"}] + self.openstack_connector.openstack_connection.network.ports = MagicMock(return_value=ports) + + # Call the method + result = self.openstack_connector.is_security_group_in_use("security_group_id") + + # Assertions + self.assertTrue(result) + + def test_is_security_group_in_use_load_balancers(self): + # Mock the network.load_balancers method to simulate load balancers associated with the security group + load_balancers = [{"id": "lb_id", "name": "lb_name"}] + self.openstack_connector.openstack_connection.network.load_balancers = MagicMock(return_value=load_balancers) + + # Call the method + result = self.openstack_connector.is_security_group_in_use("security_group_id") + + # Assertions + self.assertTrue(result) + + def test_is_security_group_not_in_use(self): + # Mock both compute.servers and network.ports methods to simulate no usage of the security group + self.openstack_connector.openstack_connection.compute.servers = MagicMock(return_value=[]) + self.openstack_connector.openstack_connection.network.ports = MagicMock(return_value=[]) + self.openstack_connector.openstack_connection.network.load_balancers = MagicMock(return_value=[]) + + + # Call the method + result = self.openstack_connector.is_security_group_in_use("security_group_id") + + # Assertions + self.assertFalse(result) + + + def test_create_security_group(self): + # Mock the get_security_group method to simulate non-existing security group + self.openstack_connector.openstack_connection.get_security_group.return_value = None + + # Mock the create_security_group method to return a fake SecurityGroup + fake_sg = fakes.generate_fake_resource(security_group.SecurityGroup) + self.openstack_connector.openstack_connection.create_security_group.return_value =fake_sg + + # Call the method + result = self.openstack_connector.create_security_group( + name=fake_sg.name, + udp_port=1234, + ssh=True, + udp=True, + description=fake_sg.description, + research_environment_metadata=METADATA_EXAMPLE, + ) + + # Assertions + self.assertEqual(result, fake_sg) + self.openstack_connector.openstack_connection.create_security_group.assert_called_once_with(name=fake_sg.name, description=fake_sg.description) + self.openstack_connector.openstack_connection.create_security_group_rule.assert_any_call( + direction="ingress", + protocol="udp", + port_range_max=1234, + port_range_min=1234, + secgroup_name_or_id=fake_sg.id, + remote_group_id=self.openstack_connector.GATEWAY_SECURITY_GROUP_ID, + ) + self.openstack_connector.openstack_connection.create_security_group_rule.assert_any_call( + direction="ingress", + protocol="udp", + ethertype="IPv6", + port_range_max=1234, + port_range_min=1234, + secgroup_name_or_id=fake_sg.id, + remote_group_id=self.openstack_connector.GATEWAY_SECURITY_GROUP_ID, + ) + + self.openstack_connector.openstack_connection.create_security_group_rule.assert_any_call( + direction="ingress", + protocol="tcp", + port_range_max=22, + port_range_min=22, + secgroup_name_or_id=fake_sg.id, + remote_group_id=self.openstack_connector.GATEWAY_SECURITY_GROUP_ID, + ) + self.openstack_connector.openstack_connection.create_security_group_rule.assert_any_call( + direction="ingress", + protocol="tcp", + ethertype="IPv6", + port_range_max=22, + port_range_min=22, + secgroup_name_or_id=fake_sg.id, + remote_group_id=self.openstack_connector.GATEWAY_SECURITY_GROUP_ID, + ) if __name__ == "__main__": unittest.main()