diff --git a/tests/unit/test_layer.py b/tests/unit/test_layer.py index 12773995..8b985c13 100644 --- a/tests/unit/test_layer.py +++ b/tests/unit/test_layer.py @@ -811,6 +811,36 @@ def test_refresh_routes_not_implemented(self, openai_encoder, routes, index_cls) ): route_layer._refresh_routes() + def test_update_threshold(self, openai_encoder, routes, index_cls): + index = init_index(index_cls) + route_layer = RouteLayer(encoder=openai_encoder, routes=routes, index=index) + route_name = "Route 1" + new_threshold = 0.8 + route_layer.update(name=route_name, threshold=new_threshold) + updated_route = route_layer.get(route_name) + assert updated_route.score_threshold == new_threshold, f"Expected threshold to be updated to {new_threshold}, but got {updated_route.score_threshold}" + + def test_update_non_existent_route(self, openai_encoder, routes, index_cls): + index = init_index(index_cls) + route_layer = RouteLayer(encoder=openai_encoder, routes=routes, index=index) + non_existent_route = "Non-existent Route" + with pytest.raises(ValueError, match=f"Route '{non_existent_route}' not found. Nothing updated."): + route_layer.update(name=non_existent_route, threshold=0.7) + + def test_update_without_parameters(self, openai_encoder, routes, index_cls): + index = init_index(index_cls) + route_layer = RouteLayer(encoder=openai_encoder, routes=routes, index=index) + with pytest.raises(ValueError, match="At least one of 'threshold' or 'utterances' must be provided."): + route_layer.update(name="Route 1") + + def test_update_utterances_not_implemented(self, openai_encoder, routes, index_cls): + index = init_index(index_cls) + route_layer = RouteLayer(encoder=openai_encoder, routes=routes, index=index) + with pytest.raises(NotImplementedError, match="The update method cannot be used for updating utterances yet."): + route_layer.update(name="Route 1", utterances=["New utterance"]) + + + class TestLayerFit: def test_eval(self, openai_encoder, routes, test_data): @@ -948,3 +978,4 @@ def test_semantic_classify_multiple_routes_with_different_aggregation( elif agg == "max": assert classification == "Route 3" assert score == [0.1, 1.0] +