Skip to content

Commit

Permalink
perf(broad-1/scoring): optimize check
Browse files (browse the repository at this point in the history)
  • Loading branch information
Caceresenzo committed Oct 28, 2024
1 parent 9b23813 commit b945e97
Showing 1 changed file with 23 additions and 26 deletions.
49 changes (23 additions & 26 deletions): competitions/broad-1/scoring/scoring.py
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
import contextlib
import datetime
import gc
import os
import typing

Expand All @@ -15,7 +16,8 @@
def log(action: str):
global _LOG_DEPTH

print(datetime.datetime.now(), " " * _LOG_DEPTH, action)
start = datetime.datetime.now()
print(start, " " * _LOG_DEPTH, action)

try:
_LOG_DEPTH += 1
Expand All @@ -24,6 +26,11 @@ def log(action: str):
finally:
_LOG_DEPTH -= 1

gc.collect()

end = datetime.datetime.now()
print(end, " " * _LOG_DEPTH, action, "took", end - start)


class ParticipantVisibleError(Exception):
"""Custom exception for errors related to participant visibility."""
Expand All @@ -41,61 +48,51 @@ def check(

if difference:
raise ParticipantVisibleError(f"Missing or extra columns: {', '.join(difference)}")


prediction.set_index("sample", drop=True, inplace=True)
with log("Check for missing samples"):
difference = set(prediction['sample'].unique()) ^ set(target_names)
difference = set(prediction.index.unique()) ^ set(target_names)

if difference:
raise ParticipantVisibleError(f"Missing or extra samples: {', '.join(difference)}")

with log("Filter predictions by samples once to avoid filtering in the loop"):
group_by_sample = {
sample: group
for sample, group in prediction.groupby('sample')
}
for target_name in target_names:
with log(f"Filter prediction at target -> {target_name}"):
prediction_slice = prediction[prediction.index == target_name]

for target in target_names:
log(f"Loop through each target -> {target}")

sdata = _read_zarr(data_directory_path, target)

with log("Get predictions for the current sample"):
prediction = group_by_sample.get(target)

if prediction is None:
raise ParticipantVisibleError(f"No predictions for gene {target.name}.")
sdata = _read_zarr(data_directory_path, target_name)

with log("Extract unique cell IDs where the group is either 'test' or 'validation'"):
cell_ids = set(sdata['cell_id-group'].obs.query("group == 'test' or group == 'validation'")['cell_id'])
gene_names = set(sdata['anucleus'].var.index)

with log("Check for NaN values in predictions"):
if prediction.isnull().values.any():
if prediction_slice.isnull().values.any():
raise ParticipantVisibleError("Predictions contain NaN values, which are not allowed.")

with log("Check that all genes are present in predictions"):
missing = set(prediction['gene']) - gene_names
missing = set(prediction_slice['gene']) - gene_names

if missing:
raise ParticipantVisibleError(f"The following genes are missing in predictions: {', '.join(list(missing)[-10:])}.")

with log("Check that all cell IDs are present in predictions"):
missing = set(prediction['cell_id']) - cell_ids
missing = set(prediction_slice['cell_id']) - cell_ids

if missing:
raise ParticipantVisibleError(f"The following cell IDs are missing in predictions: {', '.join(list(missing)[-10:])}.")
raise ParticipantVisibleError(f"The following cell IDs are missing in predictions: {', '.join(list(map(str, missing))[-10:])}.")

with log("Check data types in the 'prediction' column"):
if not pandas.api.types.is_numeric_dtype(prediction['prediction']):
if not pandas.api.types.is_numeric_dtype(prediction_slice['prediction']):
raise ParticipantVisibleError("The 'prediction' column should only contain numeric values.")

with log("Ensure all prediction values are positive"):
if (prediction['prediction'] < 0).any():
if (prediction_slice['prediction'] < 0).any():
raise ParticipantVisibleError("Prediction values should be positive.")

with log("Verify the size of predictions matches expectations"):
expected = len(cell_ids) * len(gene_names)
got = len(prediction)
got = len(prediction_slice)

if expected != got:
raise ParticipantVisibleError(f"Predictions should have {expected} rows but has {got}.")
Expand Down Expand Up @@ -182,6 +179,6 @@ def _read_zarr(
zar_data = os.path.join(data_directory_path, f"{target_name}.zarr")

with log("Read the Zarr data"):
sdata = spatialdata.read_zarr(zar_data)
sdata = spatialdata.read_zarr(zar_data, selection=("tables", ))

return sdata

0 comments on commit b945e97

Please sign in to comment.