diff --git a/models/LesNet_labels.json b/models/LesNet_labels.json new file mode 100644 index 0000000..5fded78 --- /dev/null +++ b/models/LesNet_labels.json @@ -0,0 +1,30 @@ +[ + "acrochordon", + "actinic keratosis", + "AIMP", + "angiofibroma or fibrous papule", + "angiokeratoma", + "angioma", + "atypical melanocytic proliferation", + "atypical spitz tumor", + "basal cell carcinoma", + "benign", + "cafe-au-lait macule", + "clear cell acanthoma", + "dermatofibroma", + "lentigo NOS", + "lentigo simplex", + "lichenoid keratosis", + "malignant", + "melanoma", + "neurofibroma", + "nevus", + "pigmented benign keratosis", + "scar", + "sebaceous hyperplasia", + "seborrheic keratosis", + "solar lentigo", + "squamous cell carcinoma", + "vascular lesion", + "verruca" +] \ No newline at end of file diff --git a/skinvestigatorai/config/model.py b/skinvestigatorai/config/model.py index 020fc4e..8c1bd14 100644 --- a/skinvestigatorai/config/model.py +++ b/skinvestigatorai/config/model.py @@ -30,5 +30,5 @@ class ModelConfig(object): MAX_AUG_PER_IMAGE = 5000000 TRAIN_DIR = 'data/train' MODEL_TYPE = "KERAS" - MODEL_NAME = "LesNetM31.keras" + MODEL_NAME = "LesNet.keras" LABELS_NAME = "LesNet_labels.json" diff --git a/tests/test_inference_service.py b/tests/test_inference_service.py new file mode 100644 index 0000000..9e4579e --- /dev/null +++ b/tests/test_inference_service.py @@ -0,0 +1,90 @@ +from io import BytesIO +from unittest.mock import MagicMock, patch + +import numpy as np +from pyramid.response import Response +import pytest +from PIL import Image + +from skinvestigatorai.services.inference import Inference +from skinvestigatorai.services.model import SVModel + + +@pytest.fixture +def mock_svmodel(): + sv_model = MagicMock(SVModel) + sv_model.load_model.return_value = (MagicMock(), ["class1", "class2"]) + sv_model.preprocess_image_for_tflite = lambda x: x + return sv_model + + +@pytest.fixture +def inference(mock_svmodel): + with patch('skinvestigatorai.services.model.SVModel', return_value=mock_svmodel): + return Inference() + + +def create_mock_image(): + image = Image.new('RGB', (100, 100)) + img_byte_arr = BytesIO() + image.save(img_byte_arr, format='PNG') + img_byte_arr = BytesIO(img_byte_arr.getvalue()) + return img_byte_arr + + +def test_predict_success(inference): + mock_image = create_mock_image() + + inference.model.predict = MagicMock(return_value=np.array([[0.1, 0.9]])) + + result = inference.predict(mock_image) + + assert isinstance(result, dict) + assert 'prediction' in result + assert 'confidence' in result + + +def test_predict_failure(inference): + mock_image = create_mock_image() + + inference.model.predict = MagicMock(return_value=np.array([[0.3, 0.2]])) + + result = inference.predict(mock_image) + + assert isinstance(result, Response) + assert result.status_code == 400 + + +def test_is_image_similar(inference): + mock_image = np.random.rand(100, 100, 3) + + inference.dataset_embedding = np.random.rand(2048) + inference._predict_similar = MagicMock(return_value=np.random.rand(2048)) + + result = inference.is_image_similar(mock_image, threshold=0.5) + + assert result in [True, False] + + +def test__predict_similar_keras(inference): + mock_image = np.random.rand(100, 100, 3) + inference.model.predict = MagicMock(return_value=np.random.rand(1, 2048)) + + with patch('skinvestigatorai.config.model.ModelConfig.MODEL_TYPE', 'KERAS'): + result = inference._predict_similar(mock_image) + + assert result is not None + + +def test__predict_similar_tflite(inference): + mock_image = np.random.rand(100, 100, 3) + inference.model.get_input_details = MagicMock(return_value=[{'index': 0, 'dtype': np.float32}]) + inference.model.get_output_details = MagicMock(return_value=[{'index': 1}]) + inference.model.set_tensor = MagicMock() + inference.model.invoke = MagicMock() + inference.model.get_tensor = MagicMock(return_value=np.random.rand(1, 2048)) + + with patch('skinvestigatorai.config.model.ModelConfig.MODEL_TYPE', 'TFLITE'): + result = inference._predict_similar(mock_image) + + assert result is not None diff --git a/tests/test_model_service.py b/tests/test_model_service.py new file mode 100644 index 0000000..6a872b1 --- /dev/null +++ b/tests/test_model_service.py @@ -0,0 +1,69 @@ +from unittest.mock import patch, MagicMock + +import numpy as np +import pytest + +from skinvestigatorai.config.model import ModelConfig +from skinvestigatorai.services.model import SVModel + + +@pytest.fixture +def sv_model(): + return SVModel() + + +def test_create_feature_extractor_tflite(sv_model): + sv_model.model_type = 'TFLITE' + mock_model = MagicMock() + sv_model.model = mock_model + sv_model.create_feature_extractor() + assert sv_model.feature_extractor == mock_model + + +def test_create_feature_extractor_invalid_model_type(sv_model): + sv_model.model_type = 'INVALID' + with pytest.raises(ValueError, match="Unsupported model type. Please use 'KERAS' or 'TFLITE'."): + sv_model.create_feature_extractor() + + +def test_preprocess_image_for_tflite(sv_model): + img = np.random.rand(224, 224, 3).astype(np.float32) + processed_img = sv_model.preprocess_image_for_tflite(img) + assert processed_img.shape == (ModelConfig.IMG_SIZE[0], ModelConfig.IMG_SIZE[1], 3) + assert np.max(processed_img) <= 1.0 + assert np.min(processed_img) >= 0.0 + + +def test_evaluate_model(sv_model): + sv_model.model = MagicMock() + sv_model.model.evaluate.return_value = [0.5, 0.8, 0.7, 0.6] + test_datagen = MagicMock() + test_loss, test_acc, test_precision, test_recall = sv_model.evaluate_model(test_datagen) + assert test_loss == 0.5 + assert test_acc == 0.8 + assert test_precision == 0.7 + assert test_recall == 0.6 + + +@patch('tensorflow.summary.create_file_writer') +def test_run_experiments(mock_create_file_writer, sv_model): + sv_model.run_experiments = MagicMock() + train_ds = MagicMock() + val_ds = MagicMock() + sv_model.run_experiments(train_ds, val_ds) + sv_model.run_experiments.assert_called_once_with(train_ds, val_ds) + + +def test_save_model(sv_model): + sv_model.model = MagicMock() + with patch('builtins.open', MagicMock()): + with patch('tensorflow.keras.models.Model.save', MagicMock()): + sv_model.save_model() + sv_model.model.save.assert_called_once() + + +def test_load_model(sv_model): + with patch('os.path.exists', return_value=True): + with patch('tensorflow.keras.models.load_model', return_value=MagicMock()): + sv_model.load_model() + assert isinstance(sv_model.model, MagicMock)