Skip to content

Commit

Permalink
tests: added test for TimeDependentModel with flat files, and expande…
Browse files Browse the repository at this point in the history
…d coverage to existing methods
  • Loading branch information
pabloitu committed Jul 28, 2024
1 parent 6b369b8 commit 2ad424f
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
lon, lat, M, time_string, depth, catalog_id, event_id
1.0,1.0,5.0,2020-01-01T01:01:01.0,10.0,1,1
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
lon, lat, M, time_string, depth, catalog_id, event_id
1.0,1.0,5.0,2020-01-02T01:01:01.0,10.0,1,1
116 changes: 115 additions & 1 deletion tests/unit/test_model.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import os.path
import tempfile
import json
from datetime import datetime

import csep.core.regions
import numpy.testing

from unittest import TestCase
from unittest.mock import patch
from unittest.mock import patch, MagicMock, mock_open
import filecmp

from floatcsep.model import TimeIndependentModel, TimeDependentModel
Expand Down Expand Up @@ -274,6 +275,85 @@ def test_create_forecast(self, mock_func):
model.create_forecast('2020-01-01_2021-01-01')
self.assertTrue(mock_func.called)

@patch('csep.load_catalog_forecast') # Mocking the load_catalog_forecast function
def test_get_forecast_single(self, mock_load_forecast):
# Arrange
model = self.init_model(name='mock', path='../artifacts/models/td_model')
tstring = '2020-01-01_2020-01-02' # Example time window string
model.stage([str2timewindow(tstring)])

print(model.path)
region = 'TestRegion'

# Mock the return value of load_catalog_forecast
mock_forecast = MagicMock()
mock_load_forecast.return_value = mock_forecast

# Act
result = model.get_forecast(tstring=tstring, region=region)

# Assert
mock_load_forecast.assert_called_once_with(model.path("forecasts", tstring),
region=region)
self.assertEqual(result, mock_forecast)

@patch('csep.load_catalog_forecast')
def test_get_forecast_multiple(self, mock_load_forecast):
# Arrange
model = self.init_model(name='mock', path='../artifacts/models/td_model')
tstrings = ['2020-01-01_2020-01-02',
'2020-01-02_2020-01-03'] # Example list of time window strings
region = 'TestRegion'
model.stage(str2timewindow(tstrings))
# Mock the return values of load_catalog_forecast for each forecast
mock_forecast1 = MagicMock()
mock_forecast2 = MagicMock()
mock_load_forecast.side_effect = [mock_forecast1, mock_forecast2]

# Act
result = model.get_forecast(tstring=tstrings, region=region)

# Assert
self.assertEqual(len(result), 2)
mock_load_forecast.assert_any_call(model.path("forecasts", tstrings[0]), region=region)
mock_load_forecast.assert_any_call(model.path("forecasts", tstrings[1]), region=region)
self.assertEqual(result[0], mock_forecast1)
self.assertEqual(result[1], mock_forecast2)

@patch('subprocess.run') # Mock subprocess.run
@patch('subprocess.Popen') # Mock subprocess.Popen
@patch('os.path.exists') # Mock os.path.exists
def test_build_model_creates_venv(self, mock_exists, mock_popen, mock_run):
# Arrange
model_path = '../artifacts/models/td_model'
model = self.init_model(name='TestModel', path=model_path, build='venv')
mock_exists.return_value = False # Simulate that the venv does not exist

# Act
model.build_model()

# Assert
mock_run.assert_called_once_with(["python", "-m", "venv", model.path("path") + "/venv"])
mock_popen.assert_called_once() # Ensure Popen was called to install dependencies
self.assertIn(f'cd {os.path.abspath(model_path)} && source', model.run_prefix)

@patch('subprocess.run')
@patch('subprocess.Popen')
@patch('os.path.exists')
def test_build_model_when_venv_exists(self, mock_exists, mock_popen, mock_run):
# Arrange
model_path = '../artifacts/models/td_model'
model = self.init_model(name='TestModel', path=model_path, build='venv')
mock_exists.return_value = True # Simulate that the venv already exists

# Act
model.build_model()

# Assert
mock_run.assert_not_called()
mock_popen.assert_not_called()
self.assertIn(f'cd {os.path.abspath(model_path)} && source', model.run_prefix)

def test_argprep(self):
model_path = os.path.join(self._dir, 'td_model')
with open(os.path.join(model_path, 'input', 'args.txt'), 'w') as args:
Expand Down Expand Up @@ -301,3 +381,37 @@ def test_argprep(self):
with open(os.path.join(model_path, 'input', 'args.txt'), 'r') as args:
self.assertEqual(args.readlines()[2],
f'n_sims = 200\n')

@patch('floatcsep.model.open', new_callable=mock_open, read_data='{"key": "value"}')
@patch('json.dump')
def test_argprep_json(self, mock_json_dump, mock_file):
model = self.init_model(name='TestModel', path=os.path.join(self._dir, 'td_model'))
model.path = MagicMock(return_value='path/to/model/args.json') # Mock path method
start = MagicMock()
end = MagicMock()
start.isoformat.return_value = '2023-01-01'
end.isoformat.return_value = '2023-01-31'

kwargs = {'key1': 'value1', 'key2': 'value2'}

# Act
model.prepare_args(start, end, **kwargs)

# Assert
# Check that the file was opened for reading
mock_file.assert_any_call('path/to/model/args.json', 'r')

# Check that the file was opened for writing
mock_file.assert_any_call('path/to/model/args.json', 'w')

# Check that the JSON data was updated correctly
expected_data = {
"key": "value",
"start_date": '2023-01-01',
"end_date": '2023-01-31',
"key1": 'value1',
"key2": 'value2'
}

# Check that json.dump was called with the expected data
mock_json_dump.assert_called_with(expected_data, mock_file(), indent=2)

0 comments on commit 2ad424f

Please sign in to comment.