Skip to content

Commit

Permalink
feat: Make caching function extensible for any standard
Browse files Browse the repository at this point in the history
  • Loading branch information
roquelopez committed Nov 20, 2024
1 parent bbaf1ba commit a125b0c
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 40 deletions.
5 changes: 2 additions & 3 deletions bdikit/models/contrastive_learning/cl_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
from typing import List, Dict, Tuple, Optional
from bdikit.config import get_device
import numpy as np
Expand All @@ -13,7 +12,7 @@
from sklearn.metrics.pairwise import cosine_similarity
from tqdm.auto import tqdm
from bdikit.download import get_cached_model_or_download
from bdikit.utils import check_gdc_cache, write_embeddings_to_cache
from bdikit.utils import check_embedding_cache, write_embeddings_to_cache
from bdikit.models import ColumnEmbedder


Expand Down Expand Up @@ -108,7 +107,7 @@ def _sample_to_15_rows(self, table: pd.DataFrame):

def _load_table_tokens(self, table: pd.DataFrame) -> List[np.ndarray]:

embedding_file, embeddings = check_gdc_cache(table, self.model_path)
embedding_file, embeddings = check_embedding_cache(table, self.model_path)

if embeddings != None:
print(f"Table features loaded for {len(table.columns)} columns")
Expand Down
58 changes: 21 additions & 37 deletions bdikit/utils.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,11 @@
import os
import hashlib
import pandas as pd
from os.path import join, dirname
from os.path import join, dirname, isfile
from bdikit.download import BDIKIT_EMBEDDINGS_CACHE_DIR
from bdikit.standards.standard_factory import Standards

GDC_TABLE_PATH = join(dirname(__file__), "./resource/gdc_table.csv")

__gdc_df = None
__gdc_hash = None


def hash_dataframe(df: pd.DataFrame) -> str:

hash_object = hashlib.sha256()

columns_string = ",".join(df.columns) + "\n"
Expand All @@ -27,50 +20,41 @@ def hash_dataframe(df: pd.DataFrame) -> str:

def write_embeddings_to_cache(embedding_file: str, embeddings: list):

os.makedirs(os.path.dirname(embedding_file), exist_ok=True)
os.makedirs(dirname(embedding_file), exist_ok=True)

with open(embedding_file, "w") as file:
for vec in embeddings:
file.write(",".join([str(val) for val in vec]) + "\n")


def load_gdc_data():
global __gdc_df, __gdc_hash
if __gdc_df is None or __gdc_hash is None:
standard = Standards.get_standard("gdc")
__gdc_df = standard.get_dataframe_rep()
__gdc_hash = hash_dataframe(__gdc_df)


def check_gdc_cache(table: pd.DataFrame, model_path: str):
global __gdc_df, __gdc_hash
load_gdc_data()

def check_embedding_cache(table: pd.DataFrame, model_path: str):
embedding_file = None
embeddings = None
table_hash = hash_dataframe(table)
model_name = model_path.split("/")[-1]
cache_model_path = join(BDIKIT_EMBEDDINGS_CACHE_DIR, model_name)
os.makedirs(cache_model_path, exist_ok=True)

df_hash_file = None
features = None
hash_list = {
f for f in os.listdir(cache_model_path) if isfile(join(cache_model_path, f))
}

# check if table for computing embedding is the same as the GDC table we have in resources
if table_hash == __gdc_hash:
model_name = model_path.split("/")[-1]
cache_model_path = join(BDIKIT_EMBEDDINGS_CACHE_DIR, model_name)
df_hash_file = join(cache_model_path, __gdc_hash)
embedding_file = join(cache_model_path, table_hash)

# Found file in cache
if os.path.isfile(df_hash_file):
# Check if table for computing embedding is the same as the tables we have in resources
if table_hash in hash_list:
if isfile(embedding_file):
try:
# Load embeddings from disk
with open(df_hash_file, "r") as file:
features = [
with open(embedding_file, "r") as file:
embeddings = [
[float(val) for val in vec.split(",")]
for vec in file.read().split("\n")
if vec.strip()
]
if len(features) != len(__gdc_df.columns):
features = None
raise ValueError("Mismatch in the number of features")

except Exception as e:
print(f"Error loading features from cache: {e}")
features = None
return df_hash_file, features
embeddings = None

return embedding_file, embeddings

0 comments on commit a125b0c

Please sign in to comment.