diff --git a/generate_thresholds.py b/generate_thresholds.py index 316695b..40c9197 100644 --- a/generate_thresholds.py +++ b/generate_thresholds.py @@ -102,7 +102,7 @@ def main(args): }) presence_absence = presence_absence.fillna(0) - print("...looping through taxa") + print("loading taxonomy...") output = [] taxa = pd.read_csv( args.taxonomy, @@ -117,6 +117,12 @@ def main(args): taxon_ids = taxon_ids[0:args.stop_after] resolution = args.h3_resolution area = h3.hex_area(resolution) + + # we want the taxon id to be the index since we'll be selecting on it + train_df_h3.reset_index(inplace=True) + train_df_h3.set_index("taxon_id", inplace=True) + + print("...looping through taxa") for taxon_id in tqdm(taxon_ids): try: class_of_interest = mtd.df.loc[taxon_id]["leaf_class_id"] @@ -130,7 +136,7 @@ def main(args): # make presence absence dataset target_spatial_grid_counts = \ - train_df_h3[train_df_h3.taxon_id == taxon_id].index.value_counts() + train_df_h3[train_df_h3.index == taxon_id].h3_04.value_counts() presences = gdfk.loc[target_spatial_grid_counts.index]["pred"] if len(presences) == 0: print("not present")