Skip to content

Commit

Permalink
bugfix in save params for octuple/mumidi
Browse files Browse the repository at this point in the history
  • Loading branch information
Natooz committed Aug 29, 2021
1 parent 222b099 commit 0ef15d9
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 9 deletions.
12 changes: 6 additions & 6 deletions miditok/mumidi.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ class MuMIDIEncoding(MIDITokenizer):
https://arxiv.org/abs/2008.07703
:param pitch_range: range of used MIDI pitches
:param drum_pitch_range: range of used MIDI pitches for drums exclusively
:param beat_res: beat resolutions, with the form:
{(beat_x1, beat_x2): beat_res_1, (beat_x2, beat_x3): beat_res_2, ...}
The keys of the dict are tuples indicating a range of beats, ex 0 to 3 for the first bar
Expand All @@ -36,11 +35,11 @@ class MuMIDIEncoding(MIDITokenizer):
:param program_tokens: will add entries for MIDI programs in the dictionary, to use
in the case of multitrack generation for instance
:param params: can be a path to the parameter (json encoded) file or a dictionary
:param drum_pitch_range: range of used MIDI pitches for drums exclusively
"""
def __init__(self, pitch_range: range = PITCH_RANGE, drum_pitch_range: range = DRUM_PITCH_RANGE,
beat_res: Dict[Tuple[int, int], int] = BEAT_RES, nb_velocities: int = NB_VELOCITIES,
additional_tokens: Dict[str, bool] = ADDITIONAL_TOKENS, program_tokens: bool = PROGRAM_TOKENS,
params=None):
def __init__(self, pitch_range: range = PITCH_RANGE, beat_res: Dict[Tuple[int, int], int] = BEAT_RES,
nb_velocities: int = NB_VELOCITIES, additional_tokens: Dict[str, bool] = ADDITIONAL_TOKENS,
program_tokens: bool = PROGRAM_TOKENS, params=None, drum_pitch_range: range = DRUM_PITCH_RANGE):
self.drum_pitch_range = drum_pitch_range
# used in place of positional encoding
self.max_bar_embedding = 60 # this attribute might increase during encoding
Expand All @@ -58,7 +57,8 @@ def save_params(self, out_dir: Union[str, Path, PurePath]):
with open(PurePath(out_dir, 'config').with_suffix(".txt"), 'w') as outfile:
json.dump({'pitch_range': (self.pitch_range.start, self.pitch_range.stop),
'drum_pitch_range': (self.drum_pitch_range.start, self.drum_pitch_range.stop),
'beat_res': self.beat_res, 'nb_velocities': len(self.velocity_bins),
'beat_res': {f'{k1}_{k2}': v for (k1, k2), v in self.beat_res.items()},
'nb_velocities': len(self.velocity_bins),
'additional_tokens': self.additional_tokens, 'encoding': self.__class__.__name__,
'max_bar_embedding': self.max_bar_embedding},
outfile)
Expand Down
3 changes: 2 additions & 1 deletion miditok/octuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ def save_params(self, out_dir: Union[str, Path, PurePath]):
Path(out_dir).mkdir(parents=True, exist_ok=True)
with open(PurePath(out_dir, 'config').with_suffix(".txt"), 'w') as outfile:
json.dump({'pitch_range': (self.pitch_range.start, self.pitch_range.stop),
'beat_res': self.beat_res, 'nb_velocities': len(self.velocity_bins),
'beat_res': {f'{k1}_{k2}': v for (k1, k2), v in self.beat_res.items()},
'nb_velocities': len(self.velocity_bins),
'additional_tokens': self.additional_tokens, 'encoding': self.__class__.__name__,
'max_bar_embedding': self.max_bar_embedding},
outfile)
Expand Down
3 changes: 2 additions & 1 deletion miditok/octuple_mono.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def save_params(self, out_dir: Union[str, Path, PurePath]):
Path(out_dir).mkdir(parents=True, exist_ok=True)
with open(PurePath(out_dir, 'config').with_suffix(".txt"), 'w') as outfile:
json.dump({'pitch_range': (self.pitch_range.start, self.pitch_range.stop),
'beat_res': self.beat_res, 'nb_velocities': len(self.velocity_bins),
'beat_res': {f'{k1}_{k2}': v for (k1, k2), v in self.beat_res.items()},
'nb_velocities': len(self.velocity_bins),
'additional_tokens': self.additional_tokens, 'encoding': self.__class__.__name__,
'max_bar_embedding': self.max_bar_embedding},
outfile)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
author='Nathan Fradet',
url='https://github.com/Natooz/MidiTok',
packages=find_packages(exclude=("tests",)),
version='0.1.6',
version='0.1.7',
license='MIT',
description='A convenient MIDI tokenizer for Deep Learning networks, with multiple encoding strategies',
long_description=long_description,
Expand Down

0 comments on commit 0ef15d9

Please sign in to comment.