Completed metrics implementation, train code in progress, removed TensorFlow legacy code
jason9693 committed Oct 19, 2019
1 parent d6d7e6a commit 0d7863b
Showing 7 changed files with 296 additions and 296 deletions.
1 change: 1 addition & 0 deletions custom/criterion.py
@@ -27,6 +27,7 @@ def __init__(self, weight: Optional[Any] = ..., ignore_index: int = ..., reducti
     def forward(self, input: Tensor, target: Tensor) -> Tensor:
         mask = target != par.pad_token
         not_masked_length = mask.to(torch.int).sum()
+        input = input.permute(0, -1, -2)
         _loss = super().forward(input, target)
         _loss *= mask
         return _loss.sum() / not_masked_length
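The single added line reorders the logits before they reach the parent loss: torch.nn.CrossEntropyLoss expects the class dimension in position 1, so a (batch, seq_len, vocab) output has to become (batch, vocab, seq_len). A minimal standalone sketch of the masked loss, with an assumed vocabulary size and pad token (the project's actual par values may differ):

```python
import torch

vocab, pad_token = 391, 388                # assumed stand-ins for par.vocab_size / par.pad_token
logits = torch.randn(2, 8, vocab)          # (batch, seq_len, vocab) model output
target = torch.randint(0, 388, (2, 8))
target[:, -3:] = pad_token                 # pretend the tail of each sequence is padding

loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
mask = target != pad_token
# permute(0, -1, -2): (batch, seq_len, vocab) -> (batch, vocab, seq_len)
per_token = loss_fn(logits.permute(0, -1, -2), target)   # -> (batch, seq_len)
loss = (per_token * mask).sum() / mask.to(torch.int).sum()
```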
23 changes: 14 additions & 9 deletions custom/metrics.py
@@ -1,35 +1,40 @@
 import torch
-from typing import List
+import torch.nn.functional as F
+
+from typing import Dict
 
 
 class _Metric(torch.nn.Module):
     def __init__(self):
         super().__init__()
 
     def forward(self, input: torch.Tensor, target: torch.Tensor):
-        pass
+        raise NotImplementedError()
 
 
-class CategoricalAccuracy(_Metric):
+class Accuracy(_Metric):
     def __init__(self):
         super().__init__()
 
     def forward(self, input: torch.Tensor, target: torch.Tensor):
-        pass
+        bool_acc = input == target
+        return bool_acc.sum() / bool_acc.numel()
 
 
-class Accuracy(_Metric):
+class CategoricalAccuracy(Accuracy):
     def __init__(self):
         super().__init__()
 
     def forward(self, input: torch.Tensor, target: torch.Tensor):
-        pass
+        categorical_input = input.argmax(-1)
+        return super().forward(categorical_input, target)
 
 
 class MetricsSet(_Metric):
-    def __init__(self, metrics: List[_Metric]):
+    def __init__(self, metric_dict: Dict):
         super().__init__()
-        self.metrics = metrics
+        self.metrics = metric_dict
 
     def forward(self, input: torch.Tensor, target: torch.Tensor):
-        return [metric(input, target) for metric in self.metrics]
+        # return [metric(input, target) for metric in self.metrics]
+        return {k: metric(input, target) for k, metric in self.metrics.items()}
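
With MetricsSet now keyed by name, results come back as a dict instead of a list. A hypothetical usage sketch (the shapes and the 391-token vocabulary are illustrative assumptions, not taken from the training code):

```python
import torch

# Assumes the classes from custom/metrics.py above are in scope.
metrics = MetricsSet({'accuracy': CategoricalAccuracy()})

logits = torch.randn(2, 8, 391)           # (batch, seq_len, vocab) model output
target = torch.randint(0, 391, (2, 8))    # integer class targets

print(metrics(logits, target))            # e.g. {'accuracy': tensor(...)}
```

CategoricalAccuracy first reduces the logits with argmax(-1), so it accepts the same (logits, target) pair as the loss; plain Accuracy compares already-discrete predictions element-wise.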
File renamed without changes.
22 changes: 10 additions & 12 deletions generate.py
@@ -35,18 +35,16 @@
 gen_summary_writer = tf.summary.create_file_writer(gen_log_dir)
 
 
-if mode == 'enc-dec':
-    print(">> generate with original seq2seq wise... beam size is {}".format(beam))
-    mt = MusicTransformer(
-        embedding_dim=256,
-        vocab_size=par.vocab_size,
-        num_layer=6,
-        max_seq=2048,
-        dropout=0.2,
-        debug=False, loader_path=load_path)
-else:
-    print(">> generate with decoder wise... beam size is {}".format(beam))
-    mt = MusicTransformerDecoder(loader_path=load_path)
+print(">> generate with original seq2seq wise... beam size is {}".format(beam))
+# mt = MusicTransformer(
+#     embedding_dim=256,
+#     vocab_size=par.vocab_size,
+#     num_layer=6,
+#     max_seq=2048,
+#     dropout=0.2,
+#     debug=False, loader_path=load_path)
+mt = torch.load(load_path)
+mt.eval()
 
 inputs = encode_midi('dataset/midi/BENABD10.mid')
 
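The enc-dec/decoder branch is gone: the script now deserializes the whole model with torch.load, which only works if the checkpoint was written with torch.save(model, path) rather than as a state_dict. A minimal sketch of that loading path (map_location and the no_grad guard are assumptions not shown in the diff):

```python
import torch

# Assumes `load_path` points at a whole-module checkpoint, i.e. torch.save(mt, load_path).
mt = torch.load(load_path, map_location='cpu')
mt.eval()                                  # disable dropout for generation

with torch.no_grad():
    inputs = encode_midi('dataset/midi/BENABD10.mid')   # encode_midi comes from the project
    # the actual generation call is project-specific and omitted here
```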
