Skip to content

Commit

Permalink
Gemini evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
sylvain-morin committed Sep 10, 2024
1 parent ff5b713 commit e4da0d5
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 29 deletions.
63 changes: 36 additions & 27 deletions gemini/evaluate_gemini_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,11 @@ def __init__(self, **args):
currentDatetime = datetime.now()
self.start_timestamp = currentDatetime.strftime("%Y%m%d")

# NOTE(review): this appears to be the pre-change version of export_path; a
# definition with the same name and a different signature follows it in this diff.
# Builds "evaluation-results-<start_timestamp>-<label>[-<filename_addition>].csv"
# and places it under cmd_args["data_dir"] when that entry is set.
def export_path(self, filename_addition=None, label=None):
# No explicit label: fall back to the one stored in cmd_args.
if label is None:
label = self.cmd_args["label"]
export_path = f"evaluation-results-{self.start_timestamp}-{label}"
# Optional suffix appended before the ".csv" extension.
if filename_addition:
export_path += f"-{filename_addition}"
export_path += ".csv"
# Only prefix a directory when data_dir is truthy; otherwise the bare
# file name is returned.
if self.cmd_args["data_dir"]:
export_path = os.path.join(self.cmd_args["data_dir"], export_path)
return export_path

def export_path(self, path, file):
    """Return the path of the results CSV written for one input file.

    The export file is named ``<label>-<file>`` and is placed in
    ``cmd_args["data_dir"]`` when that entry is set. When ``data_dir`` is
    empty or ``None`` the bare file name is returned instead of failing in
    ``os.path.join``.

    Args:
        path: Full path of the input file. Currently unused; kept so the
            call signature stays unchanged.
        file: Base name of the input file being processed.

    Returns:
        The export path as a string.
    """
    label = self.cmd_args["label"]
    export_name = f"{label}-{file}"
    data_dir = self.cmd_args["data_dir"]
    if not data_dir:
        # os.path.join(None, ...) raises TypeError, so skip the directory.
        return export_name
    return os.path.join(data_dir, export_name)

async def run_async(self):
if self.cmd_args["data_dir"]:
for file in sorted(os.listdir(self.cmd_args["data_dir"])):
Expand All @@ -44,7 +38,8 @@ async def run_async(self):
path = os.path.join(self.cmd_args["data_dir"], file)
print(f"\nProcessing {file}")
await self.evaluate_observations_at_path(path)
self.display_and_save_results()
observations_export_path = self.export_path(path, file)
self.display_and_save_results(observations_export_path)

async def evaluate_observations_at_path(self, path):
N_WORKERS = 5
Expand Down Expand Up @@ -109,6 +104,33 @@ async def test_observation_async(self, observation):
observation.clean_gemini_name = ''.join(char for char in observation.gemini_response if char.isalpha() or char.isspace() or char == '-')
print(observation.clean_gemini_name)

self.evaluate_taxon_name(observation)

if not observation.evaluation_status:
clean_gemini_name = observation.gemini_response.replace("*", "")
gbif_response = requests.get("https://api.gbif.org/v1/parser/name?name="+clean_gemini_name)
if gbif_response.status_code == 200:
gbif_data = gbif_response.json()
if gbif_data[0]:
observation.clean_gemini_name = gbif_data[0]["canonicalName"]
observation.evaluate_using_gbif = True
print(observation.clean_gemini_name + " (from GBIF)")

self.evaluate_taxon_name(observation)

observation.matching_active = False
observation.matching_synonym = False

if observation.evaluation_status:
if observation.evaluation_is_active and observation.evaluation_taxon_id == observation.taxon_id:
observation.matching_active = True
elif not observation.evaluation_is_active and observation.taxon_id in observation.evaluation_synonymous_taxon_ids:
observation.matching_synonym = True

observation.matching = observation.matching_active or observation.matching_synonym
observation.matching_int = int(observation.matching)

def evaluate_taxon_name(self, observation):
original_taxa_url = self.TAXA_API_URL + str(observation.taxon_id) + "?fields=name,is_active,current_synonymous_taxon_ids"
original_response = requests.get(original_taxa_url)
if original_response.status_code == 200:
Expand Down Expand Up @@ -158,21 +180,8 @@ async def test_observation_async(self, observation):
observation.evaluation_is_active = result["is_active"]
observation.evaluation_synonymous_taxon_ids = result["current_synonymous_taxon_ids"]
break

observation.matching_active = False
observation.matching_synonym = False

if observation.evaluation_status:
if observation.evaluation_is_active and observation.evaluation_taxon_id == observation.taxon_id:
observation.matching_active = True
elif not observation.evaluation_is_active and observation.taxon_id in observation.evaluation_synonymous_taxon_ids:
observation.matching_synonym = True

observation.matching = observation.matching_active or observation.matching_synonym
observation.matching_int = int(observation.matching)

def display_and_save_results(self):
observations_export_path = self.export_path()

def display_and_save_results(self, observations_export_path):
test_observations_data = [obs.to_dict() for obs in self.test_observations.values()]
test_observations_df = pd.DataFrame(test_observations_data)
test_observations_df.to_csv(observations_export_path)
Expand Down
61 changes: 61 additions & 0 deletions gemini/reconcile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import yaml
import json
import re
import os
import click
import pandas as pd

def reconcile(gemini_file_path, vision_file_path, result_file_path,
              method="combined", inferrer_name=2.15):
    """Join a Gemini results CSV with vision results and write the comparison.

    Rows of the vision CSV are first restricted to the requested ``method``
    and ``inferrer_name``. Their ``matching_index`` is then left-joined onto
    the Gemini rows (``observation_id`` == ``uuid``), so every Gemini row is
    kept; rows without a vision counterpart get an empty ``matching_index``.

    Two columns are added before the result is written:
    ``is_matching_index_zero`` (True when ``matching_index`` equals 0) and
    ``is_matching_index_zero_int`` (the same flag as 1/0).

    Args:
        gemini_file_path: CSV with at least an ``observation_id`` column.
        vision_file_path: CSV with ``uuid``, ``method``, ``inferrer_name``
            and ``matching_index`` columns.
        result_file_path: Destination of the merged CSV (written without
            the index column).
        method: Value of the ``method`` column to keep. Defaults to the
            previously hard-coded ``"combined"``.
        inferrer_name: Value of the ``inferrer_name`` column to keep.
            Defaults to the previously hard-coded ``2.15``.
    """
    print("Reconcile:")
    print(gemini_file_path)
    print(vision_file_path)
    print(result_file_path)

    gemini_file = pd.read_csv(gemini_file_path)
    vision_file = pd.read_csv(vision_file_path)

    # Keep only the vision rows for the requested method AND inferrer, and
    # only the two columns needed for the join.
    selected_rows = (
        (vision_file['method'] == method)
        & (vision_file['inferrer_name'] == inferrer_name)
    )
    vision_file_filtered = vision_file.loc[selected_rows, ['uuid', 'matching_index']]

    # Left join: Gemini rows with no vision match are preserved (NaN index).
    merged_df = pd.merge(
        gemini_file,
        vision_file_filtered,
        left_on='observation_id',
        right_on='uuid',
        how='left',
    )

    # 'uuid' duplicates 'observation_id' after the join.
    merged_df = merged_df.drop(columns='uuid')

    # NaN == 0 is False, so unmatched rows are flagged as non-matching.
    merged_df['is_matching_index_zero'] = merged_df['matching_index'] == 0
    merged_df['is_matching_index_zero_int'] = merged_df['is_matching_index_zero'].astype(int)

    merged_df.to_csv(result_file_path, index=False)

@click.command()
@click.option("--data_dir", type=click.Path(), help="Path to test data CSVs directory.")
@click.option("--label", required=True, type=str, help="Label used for output.")
def test(**args):
    """Reconcile labelled evaluation CSVs with their observation CSVs.

    Every file in the data directory named
    "<label>-test-results-<8 digits>-<group>-*.csv" is paired with each
    "test-results-<8 digits>-<group>-*-observations.csv" file of the same
    group, and the pair is written to "comparison-<evaluation file>".
    """
    print("\nArguments:")
    print(json.dumps(args, indent=4))

    label = args["label"]
    # os.listdir(None) already scans the current directory; make that explicit
    # so the os.path.join calls below do not fail on None.
    folder_path = args["data_dir"] or os.curdir

    output_prefix = "comparison-"
    # The label is user input: escape it so regex metacharacters match literally.
    eval_pattern = re.compile(
        rf"{re.escape(label)}-test-results-[0-9]{{8}}-([a-zA-Z]+)-.*\.csv"
    )

    # Snapshot the listing so files written during this run are not picked up.
    filenames = sorted(os.listdir(folder_path))

    for file in filenames:
        # Skip comparison files produced by an earlier run: they match the
        # unanchored evaluation pattern but already carry a matching_index
        # column, which makes reconcile() fail.
        if file.startswith(output_prefix) and eval_pattern.search(file[len(output_prefix):]):
            continue

        eval_filename_match = eval_pattern.search(file)
        if not eval_filename_match:
            continue

        # group is restricted to [a-zA-Z]+ by the pattern, so it needs no escaping.
        group = eval_filename_match.group(1)
        observation_pattern = re.compile(
            rf"test-results-[0-9]{{8}}-{group}-.*-observations\.csv"
        )
        for observation_file in filenames:
            if observation_pattern.search(observation_file):
                reconcile(
                    os.path.join(folder_path, file),
                    os.path.join(folder_path, observation_file),
                    os.path.join(folder_path, output_prefix + file),
                )


if __name__ == "__main__":
test()
4 changes: 2 additions & 2 deletions gemini/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ requests
python-magic
numpy==1.26.4;python_version>="3.11"
numpy==1.23.5;python_version=="3.8"
pandas==2.1.2;python_version>="3.11"
pandas==2.0.3;python_version=="3.8"
pandas==2.0.3;python_version=="3.8"
pandas==2.1.2;python_version>="3.11"

0 comments on commit e4da0d5

Please sign in to comment.