Skip to content

Commit

Permalink
indices
Browse files Browse the repository at this point in the history
  • Loading branch information
kapoorlab committed Aug 14, 2024
1 parent 406962b commit 87287a5
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
24 changes: 12 additions & 12 deletions src/napatrackmater/Trackvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2485,7 +2485,7 @@ def train_gbr_neural_net(

if isinstance(block_config, int):
block_config = (block_config,)
if npz_file is not None:
if npz_file is not None:
mitosis_inception = MitosisInception(
npz_file=npz_file,
num_classes=num_classes,
Expand Down Expand Up @@ -4040,28 +4040,26 @@ def inception_model_prediction(
dynamic_model=None,
shape_model=None,
num_samples=10,
device='cpu'
device="cpu",
):
sub_dataframe = dataframe[dataframe["Track ID"] == track_id]
sub_dataframe_dynamic = sub_dataframe[DYNAMIC_FEATURES].values
sub_dataframe_shape = sub_dataframe[SHAPE_FEATURES].values

total_duration = sub_dataframe["Track Duration"].max()

def sample_subarrays(data, num_samples, tracklet_length, total_duration):



max_start_index = total_duration - tracklet_length
start_indices = random.sample(range(max_start_index + 1), num_samples)

start_indices = random.sample(range(max_start_index), num_samples)
subarrays = []
for start_index in start_indices:
end_index = start_index + tracklet_length
if end_index <= total_duration:
sub_data = data[start_index:end_index, :]
if sub_data.shape[0] == tracklet_length:
subarrays.append(sub_data)

sub_data = data[start_index:end_index, :]
if sub_data.shape[0] == tracklet_length:
subarrays.append(sub_data)

return subarrays

Expand Down Expand Up @@ -4117,7 +4115,9 @@ def get_most_frequent_prediction(predictions):
return "UnClassified"


def save_cell_type_predictions(tracks_dataframe, cell_map, predictions, save_dir, channel):
def save_cell_type_predictions(
tracks_dataframe, cell_map, predictions, save_dir, channel
):

cell_type = {}
for value in cell_map.values():
Expand Down
4 changes: 2 additions & 2 deletions src/napatrackmater/_version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__version__ = version = "5.4.5"
__version_tuple__ = version_tuple = (5, 4, 5)
__version__ = version = "5.4.6"
__version_tuple__ = version_tuple = (5, 4, 6)

0 comments on commit 87287a5

Please sign in to comment.