Skip to content

Commit

Permalink
handler 100%
Browse files Browse the repository at this point in the history
  • Loading branch information
dweinholz committed Jan 4, 2024
1 parent e2881db commit 4b7eb24
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 82 deletions.
158 changes: 76 additions & 82 deletions simple_vm_client/VirtualMachineHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,28 @@
"""
from __future__ import annotations

from typing import TYPE_CHECKING

from simple_vm_client.bibigrid_connector.bibigrid_connector import BibigridConnector
from simple_vm_client.forc_connector.forc_connector import ForcConnector
from simple_vm_client.openstack_connector.openstack_connector import OpenStackConnector
from simple_vm_client.util import thrift_converter
from simple_vm_client.util.logger import setup_custom_logger

from .ttypes import (
VM,
Backend,
ClusterInfo,
ClusterInstance,
CondaPackage,
Flavor,
Image,
PlaybookResult,
ResearchEnvironmentTemplate,
Snapshot,
Volume,
)
from .VirtualMachineService import Iface

if TYPE_CHECKING:
from ttypes import (
VM,
Backend,
ClusterInfo,
ClusterInstance,
CondaPackage,
Flavor,
Image,
PlaybookResult,
ResearchEnvironmentTemplate,
Snapshot,
Volume,
)

logger = setup_custom_logger(__name__)


Expand Down Expand Up @@ -144,15 +141,12 @@ def get_server(self, openstack_id: str) -> VM:
return server

def get_servers(self) -> list[VM]:

servers = openstack_servers = self.openstack_connector.get_servers()
servers_full=[]
servers_full = []

for server in servers:
servers_full.append(self.forc_connector.get_playbook_status(server=server))
serv = thrift_converter.os_to_thrift_servers(
openstack_servers=servers
)
serv = thrift_converter.os_to_thrift_servers(openstack_servers=servers)
return servers_full

def get_servers_by_ids(self, server_ids: list[str]) -> list[VM]:
Expand All @@ -179,12 +173,12 @@ def get_forc_url(self) -> str:
return self.forc_connector.get_forc_access_url()

def create_snapshot(
self,
openstack_id: str,
name: str,
username: str,
base_tags: list[str],
description: str,
self,
openstack_id: str,
name: str,
username: str,
base_tags: list[str],
description: str,
) -> str:
return self.openstack_connector.create_snapshot(
openstack_id=openstack_id,
Expand All @@ -198,7 +192,7 @@ def delete_image(self, image_id: str) -> None:
return self.openstack_connector.delete_image(image_id=image_id)

def create_volume(
self, volume_name: str, volume_storage: int, metadata: dict[str, str]
self, volume_name: str, volume_storage: int, metadata: dict[str, str]
) -> Volume:
return thrift_converter.os_to_thrift_volume(
openstack_volume=self.openstack_connector.create_volume(
Expand All @@ -209,7 +203,7 @@ def create_volume(
)

def create_volume_by_source_volume(
self, volume_name: str, metadata: dict[str, str], source_volume_id: str
self, volume_name: str, metadata: dict[str, str], source_volume_id: str
) -> Volume:
return thrift_converter.os_to_thrift_volume(
openstack_volume=self.openstack_connector.create_volume_by_source_volume(
Expand All @@ -220,7 +214,7 @@ def create_volume_by_source_volume(
)

def create_volume_by_volume_snap(
self, volume_name: str, metadata: dict[str, str], volume_snap_id: str
self, volume_name: str, metadata: dict[str, str], volume_snap_id: str
) -> Volume:
return thrift_converter.os_to_thrift_volume(
openstack_volume=self.openstack_connector.create_volume_by_volume_snap(
Expand All @@ -231,7 +225,7 @@ def create_volume_by_volume_snap(
)

def create_volume_snapshot(
self, volume_id: str, name: str, description: str
self, volume_id: str, name: str, description: str
) -> str:
return self.openstack_connector.create_volume_snapshot(
volume_id=volume_id, name=name, description=description
Expand All @@ -256,7 +250,7 @@ def delete_volume(self, volume_id: str) -> None:
return self.openstack_connector.delete_volume(volume_id=volume_id)

def attach_volume_to_server(
self, openstack_id: str, volume_id: str
self, openstack_id: str, volume_id: str
) -> dict[str, str]:
return self.openstack_connector.attach_volume_to_server(
openstack_id=openstack_id, volume_id=volume_id
Expand All @@ -266,7 +260,7 @@ def get_limits(self) -> dict[str, str]:
return self.openstack_connector.get_limits()

def create_backend(
self, owner: str, user_key_url: str, template: str, upstream_url: str
self, owner: str, user_key_url: str, template: str, upstream_url: str
) -> Backend:
return self.forc_connector.create_backend(
owner=owner,
Expand Down Expand Up @@ -312,12 +306,12 @@ def delete_security_group_rule(self, openstack_id):
)

def open_port_range_for_vm_in_project(
self,
range_start,
range_stop,
openstack_id,
ethertype: str = "IPv4",
protocol: str = "TCP",
self,
range_start,
range_stop,
openstack_id,
ethertype: str = "IPv4",
protocol: str = "TCP",
) -> str:
return self.openstack_connector.open_port_range_for_vm_in_project(
range_start=range_start,
Expand All @@ -331,17 +325,17 @@ def add_udp_security_group(self, server_id: str) -> None:
return self.openstack_connector.add_udp_security_group(server_id=server_id)

def start_server(
self,
flavor_name: str,
image_name: str,
public_key: str,
servername: str,
metadata: dict[str, str],
volume_ids_path_new: list[dict[str, str]],
volume_ids_path_attach: list[dict[str, str]],
additional_keys: list[str],
research_environment: str,
additional_security_group_ids: list[str],
self,
flavor_name: str,
image_name: str,
public_key: str,
servername: str,
metadata: dict[str, str],
volume_ids_path_new: list[dict[str, str]],
volume_ids_path_attach: list[dict[str, str]],
additional_keys: list[str],
research_environment: str,
additional_security_group_ids: list[str],
) -> str:
if research_environment:
research_environment_metadata = (
Expand All @@ -365,15 +359,15 @@ def start_server(
)

def start_server_with_custom_key(
self,
flavor_name: str,
image_name: str,
servername: str,
metadata: dict[str, str],
research_environment: str,
volume_ids_path_new: list[dict[str, str]],
volume_ids_path_attach: list[dict[str, str]],
additional_security_group_ids: list[str],
self,
flavor_name: str,
image_name: str,
servername: str,
metadata: dict[str, str],
research_environment: str,
volume_ids_path_new: list[dict[str, str]],
volume_ids_path_attach: list[dict[str, str]],
additional_security_group_ids: list[str],
) -> str:
if research_environment:
research_environment_metadata = (
Expand All @@ -399,14 +393,14 @@ def start_server_with_custom_key(
return openstack_id

def create_and_deploy_playbook(
self,
public_key: str,
openstack_id: str,
conda_packages: list[CondaPackage],
research_environment_template: str,
apt_packages: list[str],
create_only_backend: bool,
base_url: str = "",
self,
public_key: str,
openstack_id: str,
conda_packages: list[CondaPackage],
research_environment_template: str,
apt_packages: list[str],
create_only_backend: bool,
base_url: str = "",
) -> int:
port = int(
self.openstack_connector.get_vm_ports(openstack_id=openstack_id)["port"]
Expand Down Expand Up @@ -436,11 +430,11 @@ def get_cluster_status(self, cluster_id: str) -> dict[str, str]:
return self.bibigrid_connector.get_cluster_status(cluster_id=cluster_id)

def start_cluster(
self,
public_key: str,
master_instance: ClusterInstance,
worker_instances: list[ClusterInstance],
user: str,
self,
public_key: str,
master_instance: ClusterInstance,
worker_instances: list[ClusterInstance],
user: str,
) -> dict[str, str]:
return self.bibigrid_connector.start_cluster(
public_key=public_key,
Expand All @@ -453,16 +447,16 @@ def terminate_cluster(self, cluster_id: str) -> dict[str, str]:
return self.bibigrid_connector.terminate_cluster(cluster_id=cluster_id)

def add_cluster_machine(
self,
cluster_id: str,
cluster_user: str,
cluster_group_id: list[str],
image_name: str,
flavor_name: str,
name: str,
key_name: str,
batch_idx: int,
worker_idx: int,
self,
cluster_id: str,
cluster_user: str,
cluster_group_id: list[str],
image_name: str,
flavor_name: str,
name: str,
key_name: str,
batch_idx: int,
worker_idx: int,
) -> str:
return self.openstack_connector.add_cluster_machine(
cluster_id=cluster_id,
Expand Down
90 changes: 90 additions & 0 deletions simple_vm_client/test_virtualmachinehandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,36 @@ def test_add_udp_security_group(self):
server_id=OPENSTACK_ID
)

def test_start_server_with_res(self):
self.handler.forc_connector.get_metadata_by_research_environment.return_value = (
"res_metadata"
)

self.handler.start_server(
flavor_name=FLAVOR.name,
image_name=IMAGE.name,
public_key="pub",
servername=SERVER.name,
metadata=METADATA,
volume_ids_path_new=[],
volume_ids_path_attach=[],
additional_keys=[],
research_environment="de",
additional_security_group_ids=[],
)
self.handler.openstack_connector.start_server.assert_called_once_with(
flavor_name=FLAVOR.name,
image_name=IMAGE.name,
public_key="pub",
servername=SERVER.name,
metadata=METADATA,
volume_ids_path_new=[],
volume_ids_path_attach=[],
additional_keys=[],
research_environment_metadata="res_metadata",
additional_security_group_ids=[],
)

def test_start_server(self):
self.handler.start_server(
flavor_name=FLAVOR.name,
Expand Down Expand Up @@ -479,6 +509,38 @@ def test_start_server_with_custom_key(self):
openstack_id=SERVER.id, private_key="priv", name=SERVER.name
)

def test_start_server_with_custom_key_and_res(self):
self.handler.openstack_connector.start_server_with_playbook.return_value = (
SERVER.id,
"priv",
)
self.handler.forc_connector.get_metadata_by_research_environment.return_value = (
"res_metadata"
)
self.handler.start_server_with_custom_key(
flavor_name=FLAVOR.name,
image_name=IMAGE.name,
servername=SERVER.name,
metadata=METADATA,
volume_ids_path_new=[],
volume_ids_path_attach=[],
research_environment="de",
additional_security_group_ids=[],
)
self.handler.openstack_connector.start_server_with_playbook.assert_called_once_with(
flavor_name=FLAVOR.name,
image_name=IMAGE.name,
servername=SERVER.name,
metadata=METADATA,
volume_ids_path_new=[],
volume_ids_path_attach=[],
additional_security_group_ids=[],
research_environment_metadata="res_metadata",
)
self.handler.forc_connector.set_vm_wait_for_playbook.assert_called_once_with(
openstack_id=SERVER.id, private_key="priv", name=SERVER.name
)

def test_create_and_deploy_playbook(self):
self.handler.openstack_connector.get_vm_ports.return_value = {
"port": str(20),
Expand Down Expand Up @@ -577,3 +639,31 @@ def test_add_cluster_machine(self):
batch_idx=1,
worker_idx=1,
)

def test_keyboard_interrupt_handler_playbooks(self):
mock_stop_a = MagicMock()
mock_stop_b = MagicMock()
mock_stop_c = MagicMock()

self.handler.forc_connector._active_playbooks = {
"a": mock_stop_a,
"b": mock_stop_b,
"c": mock_stop_c,
}
self.handler.forc_connector.redis_connection.hget.side_effect = [
"a".encode("utf-8"),
"b".encode("utf-8"),
"c".encode("utf-8"),
]
with self.assertRaises(SystemExit):
self.handler.keyboard_interrupt_handler_playbooks()
for key in self.handler.forc_connector._active_playbooks.keys():
self.handler.openstack_connector.delete_keypair.assert_any_call(
key_name=key
)
self.handler.openstack_connector.delete_server.assert_any_call(
openstack_id=key
)
mock_stop_a.stop.assert_called_once()
mock_stop_b.stop.assert_called_once()
mock_stop_c.stop.assert_called_once()

0 comments on commit 4b7eb24

Please sign in to comment.