From ff5b7132c869c2b5eff587683bfcdfdabd77288c Mon Sep 17 00:00:00 2001
From: sylvain-morin
Date: Fri, 30 Aug 2024 18:11:17 +0200
Subject: [PATCH] Gemini evaluation

---
Reviewed revision: the blob ids on the "index" lines below are those of the
original submission and were not recomputed.

 gemini/evaluate_gemini_results.py | 261 ++++++++++++++++++++++++++++++
 gemini/requirements.txt           |  10 +
 gemini/test_gemini.py             |  24 +++
 gemini/test_observation.py        |  11 +
 4 files changed, 306 insertions(+)
 create mode 100644 gemini/evaluate_gemini_results.py
 create mode 100644 gemini/requirements.txt
 create mode 100644 gemini/test_gemini.py
 create mode 100644 gemini/test_observation.py

diff --git a/gemini/evaluate_gemini_results.py b/gemini/evaluate_gemini_results.py
new file mode 100644
index 0000000..7a370cb
--- /dev/null
+++ b/gemini/evaluate_gemini_results.py
@@ -0,0 +1,261 @@
+import os
+import hashlib
+import magic
+import time
+import json
+import pandas as pd
+import numpy as np
+import asyncio
+import aiohttp
+import aiofiles
+import aiofiles.os
+import re
+import traceback
+import requests
+from datetime import datetime
+from test_observation import TestObservation
+
+class GeminiEvalutation:
+    """Score Gemini taxon-name answers against the iNaturalist taxonomy.
+
+    Each ``test-results-YYYYMMDD-LABEL.csv`` file found in ``data_dir`` is
+    read, every Gemini answer is looked up through the iNaturalist taxa API
+    and the per-observation verdicts are exported to a CSV file.
+    """
+
+    TAXA_API_URL = "https://api.inaturalist.org/v2/taxa/"
+    # Fields requested for each candidate returned by the autocomplete endpoint.
+    MATCH_FIELDS = "id,name,is_active,current_synonymous_taxon_ids,rank"
+
+    def __init__(self, **args):
+        """Store the command line arguments (``data_dir``, ``label``, ``limit``)."""
+        self.cmd_args = args
+        # Day stamp (YYYYMMDD) embedded in the name of the exported CSV file.
+        self.start_timestamp = datetime.now().strftime("%Y%m%d")
+
+    def export_path(self, filename_addition=None, label=None):
+        """Return the path of the CSV file the results are exported to.
+
+        ``label`` defaults to the ``label`` command line argument and
+        ``filename_addition``, when given, is appended to the file name. The
+        file is placed in ``data_dir`` when that argument is set.
+        """
+        if label is None:
+            label = self.cmd_args["label"]
+        export_path = f"evaluation-results-{self.start_timestamp}-{label}"
+        if filename_addition:
+            export_path += f"-{filename_addition}"
+        export_path += ".csv"
+        if self.cmd_args["data_dir"]:
+            export_path = os.path.join(self.cmd_args["data_dir"], export_path)
+        return export_path
+
+    async def run_async(self):
+        """Evaluate and export every results file found in ``data_dir``."""
+        if not self.cmd_args["data_dir"]:
+            return
+        for file in sorted(os.listdir(self.cmd_args["data_dir"])):
+            # The dot is escaped so that only names containing ".csv" match.
+            if re.search(r"test-results-[0-9]{8}-(.*)\.csv", file) is None:
+                continue
+            path = os.path.join(self.cmd_args["data_dir"], file)
+            print(f"\nProcessing {file}")
+            await self.evaluate_observations_at_path(path)
+            # NOTE(review): every input file is exported to the same path, so
+            # with several input files only the last export survives.
+            self.display_and_save_results()
+
+    async def evaluate_observations_at_path(self, path):
+        """Evaluate the observations stored in the CSV file at ``path``.
+
+        Rows flagged with a Gemini error are skipped and at most ``limit``
+        observations are evaluated. The observations are kept in
+        ``self.test_observations``, keyed by observation id.
+        """
+        N_WORKERS = 5
+        self.limit = self.cmd_args["limit"] or 100
+        self.start_time = time.time()
+        self.queued_counter = 0
+        self.processed_counter = 0
+        self.test_observations = {}
+
+        async with aiohttp.ClientSession() as self.session:
+            self.queue = asyncio.Queue()
+            df = pd.read_csv(
+                path,
+                dtype={
+                    "observation_id": str,
+                    "taxon_id": int,
+                    "taxon_ancestry": str,
+                    "gemini_response": str,
+                    "gemini_error": str
+                }
+            )
+            # Drop the first column (presumably the index of an earlier export).
+            df = df.drop(df.columns[0], axis=1)
+            for index, observation in df.iterrows():
+                obs = TestObservation(observation.to_dict())
+                if obs.gemini_error == "True":
+                    continue
+                self.test_observations[obs.observation_id] = obs
+                self.queue.put_nowait(obs.observation_id)
+
+            # Workers are started once the queue is filled: a worker stops as
+            # soon as it finds the queue empty.
+            self.workers = [
+                asyncio.create_task(self.worker_task()) for _ in range(N_WORKERS)
+            ]
+            # processes the queue
+            await self.queue.join()
+            # stop the workers and wait for them before the session is closed
+            for worker in self.workers:
+                worker.cancel()
+            await asyncio.gather(*self.workers, return_exceptions=True)
+
+    async def worker_task(self):
+        """Evaluate queued observations until the queue is drained.
+
+        A failing observation is reported and does not stop the worker. Once
+        ``limit`` evaluations have been started, the remaining ids are only
+        removed from the queue.
+        """
+        while not self.queue.empty():
+            observation_id = await self.queue.get()
+            try:
+                # Started evaluations are counted, not completed ones: several
+                # workers wait on the API at once and would overshoot the limit.
+                if self.queued_counter >= self.limit:
+                    continue
+                self.queued_counter += 1
+                observation = self.test_observations[observation_id]
+                await self.test_observation_async(observation)
+                self.processed_counter += 1
+                self.report_progress()
+
+            except Exception as err:
+                print(f"\nObservation: {observation_id} failed")
+                print(traceback.format_exc())
+                print(err)
+
+            finally:
+                self.queue.task_done()
+
+    def string_raw_comparison(self, string1, string2):
+        """Return True when both strings are equal ignoring case and whitespace.
+
+        All whitespace is ignored, not only plain spaces, so that a trailing
+        newline in a Gemini answer does not prevent a match.
+        """
+        return "".join(string1.lower().split()) == "".join(string2.lower().split())
+
+    async def fetch_json(self, url, params=None):
+        """GET ``url`` with the shared session and return the decoded JSON body.
+
+        None is returned when the status is not 200. The request is awaited,
+        which lets the other workers run while this one waits for the API.
+        """
+        async with self.session.get(url, params=params) as response:
+            if response.status != 200:
+                return None
+            return await response.json()
+
+    async def test_observation_async(self, observation):
+        """Resolve the Gemini answer of ``observation`` and record the verdict.
+
+        The answer is reduced to letters, whitespace and hyphens, then searched
+        among active taxa and, when nothing matches, among inactive taxa. The
+        answer is right when the taxon found is the expected taxon, or is an
+        inactive taxon listing the expected taxon as a current synonym. All
+        results are stored as attributes of ``observation``.
+        """
+        answer = observation.gemini_response
+        if not isinstance(answer, str):
+            # An empty CSV cell is read as NaN (a float), not as a string.
+            answer = ""
+        kept_chars = "".join(
+            char for char in answer
+            if char.isalpha() or char.isspace() or char == "-"
+        )
+        # Collapse whitespace runs and drop leading/trailing whitespace.
+        observation.clean_gemini_name = " ".join(kept_chars.split())
+        print(observation.clean_gemini_name)
+
+        original_data = await self.fetch_json(
+            self.TAXA_API_URL + str(observation.taxon_id),
+            {"fields": "name,is_active,current_synonymous_taxon_ids"}
+        )
+        if original_data and original_data.get("results"):
+            original = original_data["results"][0]
+            observation.original_name = original["name"]
+            observation.original_is_active = original["is_active"]
+            observation.original_is_synonymous = original["current_synonymous_taxon_ids"] is not None
+
+        observation.evaluation_status = False
+
+        # Active taxa are searched first, inactive ones only when nothing matched.
+        # An empty answer cannot match a taxon, so nothing is searched for it.
+        searches = (True, False) if observation.clean_gemini_name else ()
+        for search_active in searches:
+            candidates = await self.fetch_json(
+                self.TAXA_API_URL + "autocomplete",
+                {
+                    "q": observation.clean_gemini_name,
+                    "is_active": "true" if search_active else "false",
+                    "fields": self.MATCH_FIELDS
+                }
+            )
+            if candidates is None:
+                scope = "active" if search_active else "inactive"
+                observation.evaluation_error = f"Error when calling iNat taxa suggest endpoint ({scope})"
+                return
+
+            for result in candidates.get("results") or []:
+                if not self.string_raw_comparison(observation.clean_gemini_name, result["name"]):
+                    continue
+                observation.evaluation_status = True
+                observation.evaluation_name = result["name"]
+                observation.evaluation_taxon_id = result["id"]
+                observation.evaluation_rank = result["rank"]
+                observation.evaluation_is_active = result["is_active"]
+                if not search_active:
+                    observation.evaluation_synonymous_taxon_ids = result["current_synonymous_taxon_ids"]
+                break
+            if observation.evaluation_status:
+                break
+
+        observation.matching_active = False
+        observation.matching_synonym = False
+
+        if observation.evaluation_status:
+            if observation.evaluation_is_active:
+                observation.matching_active = bool(observation.evaluation_taxon_id == observation.taxon_id)
+            else:
+                # The API may report no synonyms as null; treat that as an empty list.
+                synonym_ids = getattr(observation, "evaluation_synonymous_taxon_ids", None) or []
+                observation.matching_synonym = observation.taxon_id in synonym_ids
+
+        observation.matching = observation.matching_active or observation.matching_synonym
+        observation.matching_int = int(observation.matching)
+
+    def display_and_save_results(self):
+        """Export the observations of the last evaluated file to ``export_path()``."""
+        observations_export_path = self.export_path()
+        test_observations_data = [obs.to_dict() for obs in self.test_observations.values()]
+        test_observations_df = pd.DataFrame(test_observations_data)
+        test_observations_df.to_csv(observations_export_path)
+
+    def report_progress(self):
+        """Print throughput and the estimated remaining time every 10 observations."""
+        if self.processed_counter == 0 or self.processed_counter % 10 != 0:
+            return
+        elapsed = time.time() - self.start_time
+        if elapsed <= 0:
+            # Nothing to estimate (and no division by zero) before the clock advances.
+            return
+        rate = self.processed_counter / elapsed
+        remaining_time = round(max(self.limit - self.processed_counter, 0) / rate, 2)
+        print(
+            f"Processed {self.processed_counter} in {round(elapsed, 2)} sec \t"
+            f"{round(rate, 2)}/sec \t"
+            f"estimated {remaining_time} sec remaining\t"
+        )
diff --git a/gemini/requirements.txt b/gemini/requirements.txt
new file mode 100644
index 0000000..3c71813
--- /dev/null
+++ b/gemini/requirements.txt
@@ -0,0 +1,10 @@
+click
+PyYAML
+aiohttp
+aiofiles
+requests
+python-magic
+numpy==1.26.4;python_version>="3.11"
+numpy==1.23.5;python_version=="3.8"
+pandas==2.1.2;python_version>="3.11"
+pandas==2.0.3;python_version=="3.8"
\ No newline at end of file
diff --git a/gemini/test_gemini.py b/gemini/test_gemini.py
new file mode 100644
index 0000000..46e3771
--- /dev/null
+++ b/gemini/test_gemini.py
@@ -0,0 +1,24 @@
+import click
+import yaml
+import json
+import asyncio
+
+@click.command()
+@click.option("--data_dir", type=click.Path(), help="Path to test data CSVs directory.")
+@click.option("--label", required=True, type=str, help="Label used for output.")
+@click.option("--limit", type=int, show_default=True, default=100, help="Max number of observations to test.")
+def test(**args):
+    """Evaluate the Gemini results stored in the --data_dir directory."""
+    print("\nArguments:")
+    print(json.dumps(args, indent=4))
+
+    from evaluate_gemini_results import GeminiEvalutation
+    geminiEvalutation = GeminiEvalutation(**args)
+
+    asyncio.run(geminiEvalutation.run_async())
+
+    print("\nDone\n")
+
+
+if __name__ == "__main__":
+    test()
diff --git a/gemini/test_observation.py b/gemini/test_observation.py
new file mode 100644
index 0000000..8fac21e
--- /dev/null
+++ b/gemini/test_observation.py
@@ -0,0 +1,11 @@
+class TestObservation:
+    """Observation row whose columns are exposed as instance attributes."""
+
+    def __init__(self, row):
+        """Copy every key/value pair of the ``row`` mapping onto the instance."""
+        for key in row:
+            setattr(self, key, row[key])
+
+    def to_dict(self):
+        """Return the instance attributes as a dict (the live ``__dict__``)."""
+        return vars(self)
\ No newline at end of file