Skip to content

Commit

Permalink
bugfix in load/save tokenizer parameters (json tuple keys)
Browse files Browse the repository at this point in the history
  • Loading branch information
Natooz committed Aug 27, 2021
1 parent ccd838c commit 222b099
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 3 deletions.
9 changes: 7 additions & 2 deletions miditok/midi_tokenizer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,15 +273,18 @@ def save_params(self, out_dir: Union[str, Path, PurePath]):
""" Saves the base parameters of this encoding in a txt file
Useful to keep track of how a dataset has been tokenized / encoded
It will also save the name of the class used, i.e. the encoding strategy
NOTE: as JSON can't save tuples as keys, the beat ranges are saved as strings
with the form startingBeat_endingBeat (underscore separating these two values)
:param out_dir: output directory to save the file
"""
Path(out_dir).mkdir(parents=True, exist_ok=True)
with open(PurePath(out_dir, 'config').with_suffix(".txt"), 'w') as outfile:
json.dump({'pitch_range': (self.pitch_range.start, self.pitch_range.stop),
'beat_res': self.beat_res, 'nb_velocities': len(self.velocity_bins),
'beat_res': {f'{k1}_{k2}': v for (k1, k2), v in self.beat_res.items()},
'nb_velocities': len(self.velocity_bins),
'additional_tokens': self.additional_tokens,
'encoding': self.__class__.__name__}, outfile)
'encoding': self.__class__.__name__}, outfile, indent=4)

def load_params(self, params: Union[str, Path, PurePath, Dict[str, Any]]):
""" Load parameters and set the encoder attributes
Expand All @@ -296,6 +299,8 @@ def load_params(self, params: Union[str, Path, PurePath, Dict[str, Any]]):
params['pitch_range'] = range(*params['pitch_range'])

for key, value in params.items():
if key == 'beat_res':
value = {tuple(map(int, beat_range.split('_'))): res for beat_range, res in value.items()}
setattr(self, key, value)


Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
author='Nathan Fradet',
url='https://github.com/Natooz/MidiTok',
packages=find_packages(exclude=("tests",)),
version='0.1.5',
version='0.1.6',
license='MIT',
description='A convenient MIDI tokenizer for Deep Learning networks, with multiple encoding strategies',
long_description=long_description,
Expand Down
9 changes: 9 additions & 0 deletions tests/tests_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from typing import List

from miditok import REMIEncoding
from miditoolkit import Instrument, Note


Expand All @@ -22,3 +23,11 @@ def strict_valid(expected_notes: List[Note], produced_notes: List[Note]):
return False
elif exp_note.velocity != prod_note.velocity:
return False


def save_and_load_params():
    """Check that tokenizer parameters survive a save / load round trip.

    An encoder is created with a beat resolution keyed by a tuple, its
    parameters are written with ``save_params`` and read back into a second
    encoder with ``load_params``. JSON cannot store tuple keys, so this
    guards the string <-> tuple conversion of the ``beat_res`` keys.

    The config file is written to a temporary directory so that the check
    leaves nothing behind in the current working directory.

    :raises AssertionError: if the reloaded ``beat_res`` differs from the
            one the first encoder was created with
    """
    # Local imports: the top-of-file import block of this module is not
    # fully visible here, so the names are brought into scope in place.
    import tempfile
    from os import path

    beat_res = {(0, 3): 5}
    enc = REMIEncoding(beat_res=beat_res)

    with tempfile.TemporaryDirectory() as tmp_dir:
        # save_params writes a file named 'config.txt' inside the directory
        enc.save_params(tmp_dir)

        enc2 = REMIEncoding()
        enc2.load_params(path.join(tmp_dir, 'config.txt'))

    # Without a check the round trip could silently yield string keys
    assert enc2.beat_res == beat_res, \
        f'beat_res not restored after save/load: {enc2.beat_res!r}'

0 comments on commit 222b099

Please sign in to comment.