From e4da0d579e4cdba172073c7b18c12f56e5cbeb0e Mon Sep 17 00:00:00 2001 From: sylvain-morin Date: Tue, 10 Sep 2024 11:46:15 +0200 Subject: [PATCH] Gemini evaluation --- gemini/evaluate_gemini_results.py | 63 ++++++++++++++++++------------- gemini/reconcile.py | 61 ++++++++++++++++++++++++++++++ gemini/requirements.txt | 4 +- 3 files changed, 99 insertions(+), 29 deletions(-) create mode 100644 gemini/reconcile.py diff --git a/gemini/evaluate_gemini_results.py b/gemini/evaluate_gemini_results.py index 7a370cb..e0fd421 100644 --- a/gemini/evaluate_gemini_results.py +++ b/gemini/evaluate_gemini_results.py @@ -24,17 +24,11 @@ def __init__(self, **args): currentDatetime = datetime.now() self.start_timestamp = currentDatetime.strftime("%Y%m%d") - def export_path(self, filename_addition=None, label=None): - if label is None: - label = self.cmd_args["label"] - export_path = f"evaluation-results-{self.start_timestamp}-{label}" - if filename_addition: - export_path += f"-{filename_addition}" - export_path += ".csv" - if self.cmd_args["data_dir"]: - export_path = os.path.join(self.cmd_args["data_dir"], export_path) - return export_path - + def export_path(self, path, file): + label = self.cmd_args["label"] + export_path = f"{label}-{file}" + return os.path.join(self.cmd_args["data_dir"], export_path) + async def run_async(self): if self.cmd_args["data_dir"]: for file in sorted(os.listdir(self.cmd_args["data_dir"])): @@ -44,7 +38,8 @@ async def run_async(self): path = os.path.join(self.cmd_args["data_dir"], file) print(f"\nProcessing {file}") await self.evaluate_observations_at_path(path) - self.display_and_save_results() + observations_export_path = self.export_path(path, file) + self.display_and_save_results(observations_export_path) async def evaluate_observations_at_path(self, path): N_WORKERS = 5 @@ -109,6 +104,33 @@ async def test_observation_async(self, observation): observation.clean_gemini_name = ''.join(char for char in observation.gemini_response if char.isalpha() or char.isspace() or char == '-') print(observation.clean_gemini_name) + self.evaluate_taxon_name(observation) + + if not observation.evaluation_status: + clean_gemini_name = observation.gemini_response.replace("*", "") + gbif_response = requests.get("https://api.gbif.org/v1/parser/name?name="+clean_gemini_name) + if gbif_response.status_code == 200: + gbif_data = gbif_response.json() + if gbif_data[0]: + observation.clean_gemini_name = gbif_data[0]["canonicalName"] + observation.evaluate_using_gbif = True + print(observation.clean_gemini_name + " (from GBIF)") + + self.evaluate_taxon_name(observation) + + observation.matching_active = False + observation.matching_synonym = False + + if observation.evaluation_status: + if observation.evaluation_is_active and observation.evaluation_taxon_id == observation.taxon_id: + observation.matching_active = True + elif not observation.evaluation_is_active and observation.taxon_id in observation.evaluation_synonymous_taxon_ids: + observation.matching_synonym = True + + observation.matching = observation.matching_active or observation.matching_synonym + observation.matching_int = int(observation.matching) + + def evaluate_taxon_name(self, observation): original_taxa_url = self.TAXA_API_URL + str(observation.taxon_id) + "?fields=name,is_active,current_synonymous_taxon_ids" original_response = requests.get(original_taxa_url) if original_response.status_code == 200: @@ -158,21 +180,8 @@ async def test_observation_async(self, observation): observation.evaluation_is_active = result["is_active"] observation.evaluation_synonymous_taxon_ids = result["current_synonymous_taxon_ids"] break - - observation.matching_active = False - observation.matching_synonym = False - - if observation.evaluation_status: - if observation.evaluation_is_active and observation.evaluation_taxon_id == observation.taxon_id: - observation.matching_active = True - elif not observation.evaluation_is_active and observation.taxon_id in observation.evaluation_synonymous_taxon_ids: - observation.matching_synonym = True - - observation.matching = observation.matching_active or observation.matching_synonym - observation.matching_int = int(observation.matching) - - def display_and_save_results(self): - observations_export_path = self.export_path() + + def display_and_save_results(self, observations_export_path): test_observations_data = [obs.to_dict() for obs in self.test_observations.values()] test_observations_df = pd.DataFrame(test_observations_data) test_observations_df.to_csv(observations_export_path) diff --git a/gemini/reconcile.py b/gemini/reconcile.py new file mode 100644 index 0000000..2882525 --- /dev/null +++ b/gemini/reconcile.py @@ -0,0 +1,61 @@ +import yaml +import json +import re +import os +import click +import pandas as pd + +def reconcile(gemini_file_path, vision_file_path, result_file_path): + print("Reconcile:") + print(gemini_file_path) + print(vision_file_path) + print(result_file_path) + + gemini_file = pd.read_csv(gemini_file_path) + vision_file = pd.read_csv(vision_file_path) + + # Filter vision_file where method is "combined" + vision_file_filtered = vision_file[(vision_file['method'] == 'combined') & (vision_file['inferrer_name'] == 2.15)] + + # Merge gemini_file with vision_file_filtered on observation_id == uuid + merged_df = pd.merge(gemini_file, vision_file_filtered[['uuid', 'matching_index']], + left_on='observation_id', right_on='uuid', + how='left') + + # Drop the redundant 'uuid' column if necessary + merged_df.drop('uuid', axis=1, inplace=True) + + # Add a new column 'is_matching_index_zero' which is True if matching_index == 0, otherwise False + merged_df['is_matching_index_zero'] = merged_df['matching_index'] == 0 + + # Add another column 'is_matching_index_zero_int' which is 1 if True, 0 if False + merged_df['is_matching_index_zero_int'] = merged_df['is_matching_index_zero'].astype(int) + + # Save the result to a new CSV file + merged_df.to_csv(result_file_path, index=False) + +@click.command() +@click.option("--data_dir", type=click.Path(), help="Path to test data CSVs directory.") +@click.option("--label", required=True, type=str, help="Label used for output.") +def test(**args): + print("\nArguments:") + print(json.dumps(args, indent=4)) + + label = args["label"] + folder_path = args["data_dir"] + + for file in sorted(os.listdir(folder_path)): + eval_filename_match = re.search(rf"{label}-test-results-[0-9]{{8}}-([a-zA-Z]+)-.*\.csv", file) + if eval_filename_match: + group = eval_filename_match.group(1) + for observation_file in sorted(os.listdir(folder_path)): + if re.search(f"test-results-[0-9]{{8}}-{group}-.*-observations\.csv", observation_file): + reconcile( + os.path.join(folder_path, file), + os.path.join(folder_path, observation_file), + os.path.join(folder_path, "comparison-"+file) + ) + + +if __name__ == "__main__": + test() diff --git a/gemini/requirements.txt b/gemini/requirements.txt index 3c71813..a32a905 100644 --- a/gemini/requirements.txt +++ b/gemini/requirements.txt @@ -6,5 +6,5 @@ requests python-magic numpy==1.26.4;python_version>="3.11" numpy==1.23.5;python_version=="3.8" -pandas==2.1.2;python_version>="3.11" -pandas==2.0.3;python_version=="3.8" \ No newline at end of file +pandas==2.0.3;python_version=="3.8" +pandas==2.1.2;python_version>="3.11" \ No newline at end of file