Skip to content

Commit

Permalink
fixes
Browse files (browse the repository at this point in the history)
  • Loading branch information
Philippa A. (Pippa) Richter committed Jan 15, 2025
1 parent 77a3c10 commit e200a7b
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions scripts/reformat-gtdb-embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,11 @@ def parse_description(description:str):
match = re.search(pattern, description)
return {col:match.group(i + 1) for i, col in enumerate(columns)}

df = FASTAFile(f'../data/gtdb_subset_proteins/{file_name}').to_df(parse_description=False)
df = FASTAFile(path).to_df(parse_description=False)
df = pd.concat([df, pd.DataFrame([parse_description(d) for d in df.description], index=df.index)], axis=1)
df['prefix'] = prefix
df['genome_id'] = genome_id
df = df.sort_values(by='ID') # So the entries align with the loaded embeddings.
return df


Expand All @@ -76,17 +77,21 @@ def load_embeddings(genome_id:str, prefix:str, dir_:str=None, file_name_format:s

genome_metadata_df = pd.read_csv(args.genome_metadata_path, index_col=0)
genome_metadata_columns = genome_metadata_df.columns
genome_ids = genome_metadata_df.index.values

pbar = tqdm(total=len(genome_metadata_df), desc='Reading genome data...')
for row in genome_metadata_df.itertuples():
genome_id, prefix = row.Index, row.prefix
proteins_df = load_proteins(genome_id, prefix, dir_=args.proteins_dir, file_name_format=args.proteins_file_name_format)
embeddings_df = load_embeddings(genome_id, prefix, dir_=args.embeddings_dir, file_name_format=args.embeddings_file_name_format)

for genome_id, row in zip(genome_ids, genome_metadata_df.to_dict(orient='records')):
try:
proteins_df = load_proteins(genome_id, row['prefix'], dir_=args.proteins_dir, file_name_format=args.proteins_file_name_format)
except FileNotFoundError:
print(f'Skipping genome {genome_id}, could not find proteins file in {args.proteins_dir}.')
pbar.update(1)
continue

embeddings_df = load_embeddings(genome_id, row['prefix'], dir_=args.embeddings_dir, file_name_format=args.embeddings_file_name_format)
# Remove sequences which exceed the maximum length specification.
length_filter = proteins_df.seq.apply(len) < 2000
embeddings_df = embeddings_df[length_filter]
proteins_df = proteins_df[length_filter]
proteins_df = proteins_df[proteins_df.seq.apply(len) < 2000]
embeddings_df = embeddings_df[embeddings_df.index.isin(proteins_df.ID)]

assert np.all(np.equal(embeddings_df.index, proteins_df.ID)), 'The IDs in the proteins and embeddings DataFrames do not match.'
embeddings_df.index = proteins_df.index
Expand All @@ -95,15 +100,13 @@ def load_embeddings(genome_id:str, prefix:str, dir_:str=None, file_name_format:s
# Create a metadata DataFrame with the genome and protein metadata.
metadata_df = proteins_df.copy()
for col in genome_metadata_columns:
metadata_df[col] = getattr(row, col)
metadata_df[col] = row[col]

store = pd.HDFStore(os.path.join(args.output_dir, f'{genome_id}.h5'), 'w')
store.put('metadata', metadata_df)
store.put('plm', embeddings_df)
store.close()

pbar.update(1)




0 comments on commit e200a7b

Please sign in to comment.