fix multiprocessing in torch dataProvider
pythonlessons committed Oct 4, 2023
1 parent 23ecf0c commit ee6f0b5
Showing 6 changed files with 19 additions and 7 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,7 @@
+## [1.1.4] - 2022-09-29
+### Changed
+- Improved `mltu.torch.dataProvider.DataProvider` to fall back to `multithreading` when `multiprocessing` does not work
+
 ## [1.1.3] - 2022-09-29
 ### Changed
 - Removed `Librosa` library dependency in requirements, now it is optional and required only with modules that use librosa
6 changes: 4 additions & 2 deletions Tutorials/10_wav2vec2_torch/requirements.txt
@@ -1,3 +1,5 @@
-torch==1.13.1+cu117
+torch>=1.13.1+cu117
 transformers==4.33.1
-onnx
+mltu==1.1.4
+onnx
+onnxruntime
1 change: 1 addition & 0 deletions Tutorials/10_wav2vec2_torch/test.py
@@ -39,5 +39,6 @@ def predict(self, audio: np.ndarray):
 
     accum_cer.append(cer)
     accum_wer.append(wer)
+    print(label)
 
     pbar.set_description(f"Average CER: {np.average(accum_cer):.4f}, Average WER: {np.average(accum_wer):.4f}")
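For context, the CER and WER values accumulated above are character and word error rates. Below is a minimal standalone sketch of how a character error rate can be computed; it is an illustration only, the tutorial presumably relies on mltu's own metric helpers, and the function name here is hypothetical.

import numpy as np

def character_error_rate(prediction: str, target: str) -> float:
    # Character error rate: Levenshtein edit distance divided by the reference length.
    distances = np.zeros((len(prediction) + 1, len(target) + 1), dtype=int)
    distances[:, 0] = np.arange(len(prediction) + 1)
    distances[0, :] = np.arange(len(target) + 1)
    for i in range(1, len(prediction) + 1):
        for j in range(1, len(target) + 1):
            cost = 0 if prediction[i - 1] == target[j - 1] else 1
            distances[i, j] = min(
                distances[i - 1, j] + 1,         # deletion
                distances[i, j - 1] + 1,         # insertion
                distances[i - 1, j - 1] + cost,  # substitution
            )
    return distances[len(prediction), len(target)] / max(len(target), 1)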
6 changes: 3 additions & 3 deletions Tutorials/10_wav2vec2_torch/train.py
@@ -65,15 +65,15 @@ def download_and_unzip(url, extract_to="Datasets", chunk_size=1024*1024):
     ],
     transformers=[
         LabelIndexer(vocab),
-        LabelPadding(max_word_length=configs.max_label_length, padding_value=len(vocab)),
     ],
     use_cache=False,
     batch_postprocessors=[
-        AudioPadding(max_audio_length=configs.max_audio_length, padding_value=0, use_on_batch=True)
+        AudioPadding(max_audio_length=configs.max_audio_length, padding_value=0, use_on_batch=True),
+        LabelPadding(padding_value=len(vocab), use_on_batch=True),
     ],
     use_multiprocessing=True,
     max_queue_size=10,
-    workers=64,
+    workers=configs.train_workers,
 )
 train_dataProvider, test_dataProvider = data_provider.split(split=0.9)
 
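The padding transformers above are now applied as batch_postprocessors with use_on_batch=True, meaning they pad whole batches rather than individual samples. Below is a rough sketch of what a batch-level audio padder of this kind might do, assuming batches arrive as lists of 1-D numpy arrays; the function is illustrative, not mltu's actual implementation.

import numpy as np

def pad_audio_batch(batch, max_audio_length: int, padding_value: float = 0.0) -> np.ndarray:
    # Pad (or trim) every audio clip in the batch to max_audio_length and stack into one array.
    padded = np.full((len(batch), max_audio_length), padding_value, dtype=np.float32)
    for i, audio in enumerate(batch):
        length = min(len(audio), max_audio_length)
        padded[i, :length] = np.asarray(audio[:length], dtype=np.float32)
    return padded

Label padding can be handled the same way, filling with len(vocab) as the diff above does.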
2 changes: 1 addition & 1 deletion mltu/__init__.py
@@ -1,4 +1,4 @@
__version__ = "1.1.3"
__version__ = "1.1.4"

from .annotations.images import Image
from .annotations.images import CVImage
7 changes: 6 additions & 1 deletion mltu/torch/dataProvider.py
@@ -181,7 +181,12 @@ def start_executor(self) -> None:
 
         if not hasattr(self, "_executor"):
             if self.use_multiprocessing:
-                self._executor = ProcessExecutor(self.process_data, self.workers)
+                try:
+                    self._executor = ProcessExecutor(self.process_data, self.workers)
+                except:
+                    self.use_multiprocessing = False
+                    self.logger.error("Failed to start multiprocessing, switching to multithreading")
+                    self._executor = ThreadExecutor(self.process_data, self.workers)
             else:
                 self._executor = ThreadExecutor(self.process_data, self.workers)
 
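This is the fallback described in the CHANGELOG entry: if the process-based executor cannot be started (for example, when the platform cannot spawn new processes or the data pipeline cannot be pickled), the provider logs an error and switches to threads. Below is a minimal standalone sketch of the same pattern using the standard library; ProcessPoolExecutor and ThreadPoolExecutor stand in for mltu's ProcessExecutor and ThreadExecutor, and make_executor is a hypothetical helper.

import logging
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor

logger = logging.getLogger(__name__)

def make_executor(workers: int, use_multiprocessing: bool = True):
    # Prefer a process pool, but fall back to a thread pool if it cannot be started.
    if use_multiprocessing:
        try:
            return ProcessPoolExecutor(max_workers=workers)
        except Exception:
            logger.error("Failed to start multiprocessing, switching to multithreading")
    return ThreadPoolExecutor(max_workers=workers)

Note that the committed code uses a bare except:, which also swallows KeyboardInterrupt and SystemExit; catching Exception (or the specific errors that starting processes can raise) is usually the safer choice.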