Skip to content

Commit

Permalink
Merge pull request #366 from Limmen/apt-game
Browse files Browse the repository at this point in the history
apt_game unit tests
  • Loading branch information
Limmen authored Jun 4, 2024
2 parents 8855621 + e9d6494 commit 6c709e8
Show file tree
Hide file tree
Showing 8 changed files with 904 additions and 20 deletions.
Empty file.
67 changes: 67 additions & 0 deletions simulation-system/libs/gym-csle-apt-game/apt_game.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
2 1 2 2 5 45 6 1
0 0
1 0
CONTINUE
STOP
CONTINUE
STOP
o_0
o_1
o_2
o_3
o_4
0 1
0 1
0 1
0 0 0 0 0 0.542067363614722
0 0 0 1 0 0.25296476968687015
0 0 0 2 0 0.12901203254030388
0 0 0 3 0 0.05805541464313671
0 0 0 4 0 0.01790041951496717
0 0 1 0 0 0.4336538908917776
0 0 1 1 0 0.20237181574949614
0 0 1 2 0 0.10320962603224311
0 0 1 3 0 0.046444331714509374
0 0 1 4 0 0.014320335611973737
0 0 1 0 1 0.029787234042553196
0 0 1 1 1 0.03220241518113858
0 0 1 2 1 0.035780461312376215
0 0 1 3 1 0.042094660367501424
0 0 1 4 1 0.0601352290964306
0 1 0 0 0 0.542067363614722
0 1 0 1 0 0.25296476968687015
0 1 0 2 0 0.12901203254030388
0 1 0 3 0 0.05805541464313671
0 1 0 4 0 0.01790041951496717
0 1 1 0 0 0.542067363614722
0 1 1 1 0 0.25296476968687015
0 1 1 2 0 0.12901203254030388
0 1 1 3 0 0.05805541464313671
0 1 1 4 0 0.01790041951496717
1 0 0 0 1 0.14893617021276598
1 0 0 1 1 0.1610120759056929
1 0 0 2 1 0.17890230656188105
1 0 0 3 1 0.2104733018375071
1 0 0 4 1 0.300676145482153
1 0 1 0 1 0.14893617021276598
1 0 1 1 1 0.1610120759056929
1 0 1 2 1 0.17890230656188105
1 0 1 3 1 0.2104733018375071
1 0 1 4 1 0.300676145482153
1 1 0 0 0 0.542067363614722
1 1 0 1 0 0.25296476968687015
1 1 0 2 0 0.12901203254030388
1 1 0 3 0 0.05805541464313671
1 1 0 4 0 0.01790041951496717
1 1 1 0 0 0.542067363614722
1 1 1 1 0 0.25296476968687015
1 1 1 2 0 0.12901203254030388
1 1 1 3 0 0.05805541464313671
1 1 1 4 0 0.01790041951496717
0 1 0 -1.0
0 1 1 -1.0
1 0 0 -1.0
1 0 1 -1.0
1 1 0 1.0
1 1 1 1.0
0 1 0
Original file line number Diff line number Diff line change
Expand Up @@ -235,25 +235,6 @@ def bayes_filter(s_prime: int, o: int, a1: int, b: npt.NDArray[np.float_], pi2:
assert round(b_prime_s_prime, 2) <= 1
return b_prime_s_prime

@staticmethod
def p_o_given_b_a1_a2(o: int, b: List[float], a1: int, a2: int, config: AptGameConfig) -> float:
"""
Computes P[o|a,b]
:param o: the observation
:param b: the belief point
:param a1: the action of player 1
:param a2: the action of player 2
:param config: the game config
:return: the probability of observing o when taking action a in belief point b
"""
prob = 0
for s in config.S:
for s_prime in config.S:
prob += b[s] * config.T[a1][a2][s][s_prime] * config.Z[a1][a2][s_prime][o]
assert prob < 1
return prob

@staticmethod
def next_belief(o: int, a1: int, b: npt.NDArray[np.float_], pi2: npt.NDArray[Any],
config: AptGameConfig, a2: int = 0, s: int = 0) -> npt.NDArray[np.float_]:
Expand Down
234 changes: 234 additions & 0 deletions simulation-system/libs/gym-csle-apt-game/tests/test_apt_game_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
from gym_csle_apt_game.envs.apt_game_env import AptGameEnv
from gym_csle_apt_game.dao.apt_game_config import AptGameConfig
from gym_csle_apt_game.util.apt_game_util import AptGameUtil
from gym_csle_apt_game.dao.apt_game_state import AptGameState
from unittest.mock import patch, MagicMock
import csle_common.constants.constants as constants
from csle_common.dao.simulation_config.simulation_trace import SimulationTrace
import pytest
import numpy as np


class TestAptGameEnvSuite:
"""
Test suite for apt_game_env.py
"""

@pytest.fixture(autouse=True)
def setup_env(self) -> None:
"""
Sets up the configuration of the apt game
:return: None
"""
env_name = "test_env"
N = 3
p_a = 0.5
T = AptGameUtil.transition_tensor(N, p_a)
O = AptGameUtil.observation_space(100)
Z = AptGameUtil.observation_tensor(100, N)
C = AptGameUtil.cost_tensor(N)
S = AptGameUtil.state_space(N=3)
A1 = AptGameUtil.defender_actions()
A2 = AptGameUtil.attacker_actions()
b1 = AptGameUtil.b1(N=3)
save_dir = "save_directory"
checkpoint_traces_freq = 100
gamma = 0.9
self.config = AptGameConfig(
env_name,
T,
O,
Z,
C,
S,
A1,
A2,
b1,
N,
p_a,
save_dir,
checkpoint_traces_freq,
gamma,
)
self.env = AptGameEnv(self.config)

def test_init_(self) -> None:
"""
Tests the initializing function
:return: None
"""
assert self.env.config == self.config
assert self.env.state == AptGameState(b1=self.config.b1)
assert (
self.env.attacker_observation_space
== self.config.attacker_observation_space()
)
assert (
self.env.defender_observation_space
== self.config.defender_observation_space()
)
assert self.env.attacker_action_space == self.config.attacker_action_space()
assert self.env.defender_action_space == self.config.defender_action_space()
assert self.env.action_space == self.config.defender_action_space()
assert self.env.observation_space == self.config.defender_observation_space()
assert isinstance(self.env.traces, list)
assert isinstance(self.env.trace, SimulationTrace)
assert self.env.trace.simulation_env == self.config.env_name

def test_step(self) -> None:
"""
Tests the function for taking a step in the environment by executing the given action
:return: None
"""
initial_state = self.env.state.s
action_profile = (
0,
(np.array([[0.5, 0.5], [0.5, 0.5], [0.4, 0.6], [0.5, 0.5]]), 1),
)
obs, reward, terminated, truncated, info = self.env.step(action_profile)
assert obs[0].all() == self.env.state.defender_observation().all()
assert obs[1] == self.env.state.attacker_observation()
assert isinstance(terminated, bool) # type: ignore
assert isinstance(truncated, bool) # type: ignore
assert reward == (
self.config.C[action_profile[0]][initial_state],
-self.config.C[action_profile[0]][initial_state],
)
assert isinstance(info, dict) # type: ignore

def test_mean(self) -> None:
"""
Tests the utility function for getting the mean of a vector
:return: None
"""
test_cases = [
([], 0), # Test case for an empty vector
([5], 0), # Test case for a vector with a single element
([0.2, 0.3, 0.5], 1.3), # Test case for a vector with multiple elements
]
for prob_vector, expected_mean in test_cases:
result = AptGameEnv(self.config).mean(prob_vector)
assert result == expected_mean

def test_info(self) -> None:
"""
Tests the function for adding the cumulative reward and episode length to the info dict
:return: None
"""
info = {} # type: ignore
assert isinstance(self.env._info(info), dict) # type: ignore

def test_reset(self) -> None:
"""
Tests the function for reseting the environment state
:return: None
"""
self.env.trace.attacker_rewards = [1, 2, 3]
self.env.trace.attacker_observations = ["obs1", "obs2"]
self.env.trace.defender_observations = ["obs1", "obs2"]
initial_trace_count = len(self.env.traces)
AptGameState.reset = lambda self: None
initial_obs, info = self.env.reset(seed=10, soft=False, options=None)
assert len(self.env.traces) == initial_trace_count + 1
assert isinstance(self.env.trace, SimulationTrace)
assert self.env.trace.simulation_env == self.config.env_name
assert self.env.trace.attacker_observations == [initial_obs[1]]
assert self.env.trace.defender_observations == [initial_obs[0]]
assert info == {}

def test_render(self) -> None:
"""
Tests the function of rendering the environment
:return: None
"""
with pytest.raises(NotImplementedError):
self.env.render()

def test_is_defense_action_legal(self) -> None:
"""
Tests the function of checking whether a defender action in the environment is legal or not
:return: None
"""
assert self.env.is_defense_action_legal(1)

def test_is_attack_action_legal(self) -> None:
"""
Tests the function of checking whether an attacker action in the environment is legal or not
:return: None
"""
assert self.env.is_attack_action_legal(1)

def test_get_traces(self) -> None:
"""
Tests the function of getting the list of simulation traces
:return: None
"""
assert self.env.get_traces() == self.env.traces

def test_reset_traces(self) -> None:
"""
Tests the function of resetting the list of traces
:return: None
"""
self.env.traces = ["trace1", "trace2"]
self.env.reset_traces()
assert self.env.traces == []

def test_checkpoint_traces(self) -> None:
"""
Tests the function of checkpointing agent traces
:return: None
"""

fixed_timestamp = 123
with patch("time.time", return_value=fixed_timestamp):
with patch(
"csle_common.dao.simulation_config.simulation_trace.SimulationTrace.save_traces"
) as mock_save_traces:
self.env.traces = ["trace1", "trace2"]
self.env._AptGameEnv__checkpoint_traces()
mock_save_traces.assert_called_once_with(
traces_save_dir=constants.LOGGING.DEFAULT_LOG_DIR,
traces=self.env.traces,
traces_file=f"taus{fixed_timestamp}.json",
)

def test_set_model(self) -> None:
"""
Tests the function of setting the model
:return: None
"""
mock_model = MagicMock()
self.env.set_model(mock_model)
assert self.env.model == mock_model

def test_set_state(self) -> None:
"""
Tests the function of setting the state
:return: None
"""
self.env.state = MagicMock()
mock_state = MagicMock(spec=AptGameState)
self.env.set_state(mock_state)
assert self.env.state == mock_state

state_int = 5
self.env.set_state(state_int)
assert self.env.state.s == state_int

with pytest.raises(ValueError):
self.env.set_state([1, 2, 3]) # type: ignore
Loading

0 comments on commit 6c709e8

Please sign in to comment.