diff --git a/tests/test_predict.py b/tests/test_predict.py index 995ee5c..2c88dd3 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -75,7 +75,7 @@ def test_train_predict(tmp_path, cfg_train, cfg_predict): HydraConfig().set_config(cfg_predict) _, object_dict = predict(cfg_predict) - predicted_entities = [list(doc.entities.predictions) for doc in object_dict["documents"]] + predicted_entities = [list(doc.labeled_spans.predictions) for doc in object_dict["documents"]] num_predicted_entities = sum([len(preds) for preds in predicted_entities]) assert num_predicted_entities > 0