Commit

Merge pull request #1 from mhburrell/mark_optimize
Mark optimize
mhburrell authored May 8, 2024
2 parents 42e1f02 + 0278d60 commit 70b8a4a
Showing 3 changed files with 72 additions and 9 deletions.
35 changes: 33 additions & 2 deletions src/spikeinterface/core/job_tools.py
@@ -8,6 +8,7 @@
 import platform
 import os
 import warnings
+import pickle

 import sys
 import contextlib
@@ -386,12 +387,18 @@ def run(self):
         else:
             n_jobs = min(self.n_jobs, len(all_chunks))

+        init_args = (self.func, self.init_func, self.init_args, self.max_threads_per_process)
+        import pickle
+        # dump this to a pickle file
+        with open('init_args.pkl', 'wb') as f:
+            pickle.dump(init_args, f)
+
         # parallel
         with ProcessPoolExecutor(
             max_workers=n_jobs,
-            initializer=worker_initializer,
+            initializer=worker_initializer_read_pickle,
             mp_context=mp.get_context(self.mp_context),
-            initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process),
+            initargs=('init_args.pkl',),
         ) as executor:
             results = executor.map(function_wrapper, all_chunks)
@@ -414,6 +421,29 @@ def run(self):
 global _worker_ctx
 global _func

+import pickle
+from threadpoolctl import threadpool_limits
+
+def worker_initializer_read_pickle(in_pickle):
+    # Load the configuration from a pickle file.
+    try:
+        with open(in_pickle, 'rb') as f:
+            func, init_func, init_args, max_threads_per_process = pickle.load(f)
+    except (FileNotFoundError, pickle.UnpicklingError) as e:
+        raise Exception(f"Failed to load initialization data: {e}")
+
+    # Initialize global context for the worker.
+    global _worker_ctx, _func
+    if max_threads_per_process is None:
+        _worker_ctx = init_func(*init_args)  # Unpack init_args correctly here.
+    else:
+        # Limit the number of threads for this process if specified.
+        with threadpool_limits(limits=max_threads_per_process):
+            _worker_ctx = init_func(*init_args)  # Unpack init_args correctly here.
+
+    # Store additional configuration in the context.
+    _worker_ctx["max_threads_per_process"] = max_threads_per_process
+    _func = func  # Set the function to be used by the worker.
+
 def worker_initializer(func, init_func, init_args, max_threads_per_process):
     global _worker_ctx
@@ -493,3 +523,4 @@ def get_poolexecutor(n_jobs):
         return MockPoolExecutor
     else:
         return ProcessPoolExecutor
+
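The net effect of the job_tools.py change: worker processes no longer receive the (potentially large, slow-to-serialize) init arguments through initargs; each worker gets only a short file path and unpickles the payload itself. Below is a minimal standalone sketch of the same pattern, not part of the commit: it uses a temporary file rather than the hardcoded 'init_args.pkl', and the names (_read_pickle_initializer, _work) are illustrative.

    import pickle
    import tempfile
    from concurrent.futures import ProcessPoolExecutor

    _worker_cfg = None  # per-process global, filled in once by the initializer


    def _read_pickle_initializer(pickle_path):
        # Each worker loads its configuration from disk instead of receiving
        # it through initargs, so only the path string is pickled per worker.
        global _worker_cfg
        with open(pickle_path, "rb") as f:
            _worker_cfg = pickle.load(f)


    def _work(x):
        return _worker_cfg["scale"] * x


    if __name__ == "__main__":
        # Dump the (potentially large) init payload once...
        with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as f:
            pickle.dump({"scale": 10}, f)
            payload_path = f.name

        # ...and hand every worker just the file path.
        with ProcessPoolExecutor(
            max_workers=2,
            initializer=_read_pickle_initializer,
            initargs=(payload_path,),
        ) as executor:
            print(list(executor.map(_work, range(5))))  # [0, 10, 20, 30, 40]

One trade-off worth noting: the commit writes to a fixed 'init_args.pkl' in the current working directory, so two executors launched from the same directory would race on the same file. The temporary file above sidesteps that, but it is an assumption of this sketch, not what the commit does.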
44 changes: 38 additions & 6 deletions src/spikeinterface/core/recording_tools.py
@@ -7,8 +7,10 @@
 import gc
 import mmap
 import tqdm
-
-
+# if OS is windows, import win32file
+if os.name == 'nt':
+    import win32file
+    import win32api
 import numpy as np

 from .core_tools import add_suffix, make_shared_array
@@ -67,6 +69,28 @@ def _init_binary_worker(recording, file_path_dict, dtype, byte_offest, cast_unsigned):
     return worker_ctx


+
+def allocate_file(file_path, file_size_bytes):
+    # Open the file
+    if isinstance(file_path, Path):
+        file_path = str(file_path)
+
+    handle = win32file.CreateFile(
+        file_path,
+        win32file.GENERIC_WRITE,
+        0,
+        None,
+        win32file.CREATE_ALWAYS,
+        0,
+        None
+    )
+
+    # Move the file pointer and set the end of the file
+    win32file.SetFilePointer(handle, file_size_bytes, win32file.FILE_BEGIN)
+    win32file.SetEndOfFile(handle)
+    win32file.CloseHandle(handle)
+
+
 def write_binary_recording(
     recording,
     file_paths,
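allocate_file extends the file to its final size by moving the file pointer and calling SetEndOfFile, which reserves the full range up front rather than writing zeros through the file. A quick way to exercise either allocation branch (a sketch, not part of the commit; it assumes allocate_file from the hunk above is in scope, and the file name and size are placeholders):

    import os
    from pathlib import Path

    target = Path("scratch_alloc.bin")  # hypothetical scratch file
    size_bytes = 1024 * 1024            # 1 MiB, arbitrary

    if os.name == 'nt':
        allocate_file(target, size_bytes)    # pywin32 path, as in the commit
    else:
        with open(target, "wb+") as f:       # portable fallback, as in the commit
            f.truncate(size_bytes)

    assert os.path.getsize(target) == size_bytes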
@@ -122,10 +146,18 @@ def write_binary_recording(
         num_frames = recording.get_num_frames(segment_index=segment_index)
         data_size_bytes = dtype_size_bytes * num_frames * num_channels
         file_size_bytes = data_size_bytes + byte_offset
-
-        file = open(file_path, "wb+")
-        file.truncate(file_size_bytes)
-        file.close()
+
+        if os.name == 'nt':
+            allocate_file(file_path, file_size_bytes)
+        else:
+            file = open(file_path, "wb+")
+            file.truncate(file_size_bytes)
+            file.close()
+
+        #file = open(file_path, "wb+")
+        #file.truncate(file_size_bytes)
+        #file.close()
+        #allocate_file(file_path, file_size_bytes)
         assert Path(file_path).is_file()

         # use executor (loop or workers)
2 changes: 1 addition & 1 deletion src/spikeinterface/extractors/cbin_ibl.py
@@ -188,7 +188,7 @@ def extract_stream_info(meta_file, meta):
             v = meta["imroTbl"][c].split(" ")[index_imroTbl]
             per_channel_gain[c] = 1.0 / float(v)
         gain_factor = float(meta["imAiRangeMax"]) / 512
-        channel_gains = gain_factor * per_channel_gain * 1e6
+        channel_gains = gain_factor * per_channel_gain * 1e3
     elif meta["imDatPrb_type"] in ("21", "24", "2003", "2004", "2013", "2014"):
         # This work with NP 2.0 case with different metadata versions
         # https://github.com/billkarsh/SpikeGLX/blob/15ec8898e17829f9f08c226bf04f46281f106e5f/Markdown/Metadata_30.md#imec
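The only functional change in this file is the scaling constant, 1e6 to 1e3. A quick numeric check of what that does, using typical NP 1.0 values (imAiRangeMax = 0.6 V and an imro gain of 500; both assumed here for illustration, not taken from the commit):

    im_ai_range_max = 0.6          # volts; typical NP 1.0 imAiRangeMax (assumed)
    per_channel_gain = 1.0 / 500   # inverse of an assumed imro gain of 500
    gain_factor = im_ai_range_max / 512

    print(gain_factor * per_channel_gain * 1e6)  # ~2.344    (raw units -> microvolts)
    print(gain_factor * per_channel_gain * 1e3)  # ~0.002344 (raw units -> millivolts)

In other words, the committed change makes the reported channel gains three orders of magnitude smaller than before.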
