From 222b099e4aad9997831a624271f225eac2908602 Mon Sep 17 00:00:00 2001
From: Nat
Date: Fri, 27 Aug 2021 10:47:43 +0200
Subject: [PATCH] bugfix in load/save tokenizer parameters (json tuple keys)

---
 miditok/midi_tokenizer_base.py | 9 +++++++--
 setup.py                       | 2 +-
 tests/tests_utils.py           | 9 +++++++++
 3 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/miditok/midi_tokenizer_base.py b/miditok/midi_tokenizer_base.py
index ddee033c..bdf9745a 100644
--- a/miditok/midi_tokenizer_base.py
+++ b/miditok/midi_tokenizer_base.py
@@ -273,15 +273,18 @@ def save_params(self, out_dir: Union[str, Path, PurePath]):
         """ Saves the base parameters of this encoding in a txt file
         Useful to keep track of how a dataset has been tokenized / encoded
         It will also save the name of the class used, i.e. the encoding strategy
+        NOTE: as json cant save tuples as keys, the beat ranges are saved as strings
+        with the form startingBeat_endingBeat (underscore separating these two values)
 
         :param out_dir: output directory to save the file
         """
         Path(out_dir).mkdir(parents=True, exist_ok=True)
         with open(PurePath(out_dir, 'config').with_suffix(".txt"), 'w') as outfile:
             json.dump({'pitch_range': (self.pitch_range.start, self.pitch_range.stop),
-                       'beat_res': self.beat_res, 'nb_velocities': len(self.velocity_bins),
+                       'beat_res': {f'{k1}_{k2}': v for (k1, k2), v in self.beat_res.items()},
+                       'nb_velocities': len(self.velocity_bins),
                        'additional_tokens': self.additional_tokens,
-                       'encoding': self.__class__.__name__}, outfile)
+                       'encoding': self.__class__.__name__}, outfile, indent=4)
 
     def load_params(self, params: Union[str, Path, PurePath, Dict[str, Any]]):
         """ Load parameters and set the encoder attributes
@@ -296,6 +299,8 @@ def load_params(self, params: Union[str, Path, PurePath, Dict[str, Any]]):
             params['pitch_range'] = range(*params['pitch_range'])
 
         for key, value in params.items():
+            if key == 'beat_res':
+                value = {tuple(map(int, beat_range.split('_'))): res for beat_range, res in value.items()}
             setattr(self, key, value)
 
 
diff --git a/setup.py b/setup.py
index e34e589e..fb09adbf 100644
--- a/setup.py
+++ b/setup.py
@@ -8,7 +8,7 @@
     author='Nathan Fradet',
     url='https://github.com/Natooz/MidiTok',
     packages=find_packages(exclude=("tests",)),
-    version='0.1.5',
+    version='0.1.6',
     license='MIT',
     description='A convenient MIDI tokenizer for Deep Learning networks, with multiple encoding strategies',
     long_description=long_description,
diff --git a/tests/tests_utils.py b/tests/tests_utils.py
index 62a39acd..2c480acd 100644
--- a/tests/tests_utils.py
+++ b/tests/tests_utils.py
@@ -4,6 +4,7 @@
 
 from typing import List
 
+from miditok import REMIEncoding
 from miditoolkit import Instrument, Note
 
 
@@ -22,3 +23,11 @@ def strict_valid(expected_notes: List[Note], produced_notes: List[Note]):
             return False
         elif exp_note.velocity != prod_note.velocity:
             return False
+
+
+def save_and_load_params():
+    enc = REMIEncoding(beat_res={(0, 3): 5})
+    enc.save_params('')
+
+    enc2 = REMIEncoding()
+    enc2.load_params('config.txt')