From 747f99a0f3eb86e46d62db8ed08e7171b8752a3d Mon Sep 17 00:00:00 2001 From: "Marc P. Rostock" Date: Wed, 28 Apr 2021 17:21:55 +0200 Subject: [PATCH] Revert Accuracy Hack in CTC LitModel. Fix for lab3 only, ajusted everywhere for consistency --- lab3/text_recognizer/lit_models/ctc.py | 4 ++++ lab4/text_recognizer/lit_models/ctc.py | 4 ++++ lab5/text_recognizer/lit_models/ctc.py | 4 ++++ lab7/text_recognizer/lit_models/ctc.py | 4 ++++ lab8/text_recognizer/lit_models/ctc.py | 4 ++++ lab9/text_recognizer/lit_models/ctc.py | 4 ++++ 6 files changed, 24 insertions(+) diff --git a/lab3/text_recognizer/lit_models/ctc.py b/lab3/text_recognizer/lit_models/ctc.py index 9c761c9..ac01773 100644 --- a/lab3/text_recognizer/lit_models/ctc.py +++ b/lab3/text_recognizer/lit_models/ctc.py @@ -1,5 +1,6 @@ import argparse import itertools +import pytorch_lightning as pl import torch from .base import BaseLitModel @@ -46,6 +47,9 @@ def __init__(self, model, args: argparse.Namespace = None): self.loss_fn = torch.nn.CTCLoss(zero_infinity=True) # https://pytorch.org/docs/stable/generated/torch.nn.CTCLoss.html + self.val_acc = pl.metrics.Accuracy() + self.test_acc = pl.metrics.Accuracy() + ignore_tokens = [start_index, end_index, self.padding_index] self.val_cer = CharacterErrorRate(ignore_tokens) self.test_cer = CharacterErrorRate(ignore_tokens) diff --git a/lab4/text_recognizer/lit_models/ctc.py b/lab4/text_recognizer/lit_models/ctc.py index 9c761c9..ac01773 100644 --- a/lab4/text_recognizer/lit_models/ctc.py +++ b/lab4/text_recognizer/lit_models/ctc.py @@ -1,5 +1,6 @@ import argparse import itertools +import pytorch_lightning as pl import torch from .base import BaseLitModel @@ -46,6 +47,9 @@ def __init__(self, model, args: argparse.Namespace = None): self.loss_fn = torch.nn.CTCLoss(zero_infinity=True) # https://pytorch.org/docs/stable/generated/torch.nn.CTCLoss.html + self.val_acc = pl.metrics.Accuracy() + self.test_acc = pl.metrics.Accuracy() + ignore_tokens = [start_index, end_index, self.padding_index] self.val_cer = CharacterErrorRate(ignore_tokens) self.test_cer = CharacterErrorRate(ignore_tokens) diff --git a/lab5/text_recognizer/lit_models/ctc.py b/lab5/text_recognizer/lit_models/ctc.py index 9c761c9..ac01773 100644 --- a/lab5/text_recognizer/lit_models/ctc.py +++ b/lab5/text_recognizer/lit_models/ctc.py @@ -1,5 +1,6 @@ import argparse import itertools +import pytorch_lightning as pl import torch from .base import BaseLitModel @@ -46,6 +47,9 @@ def __init__(self, model, args: argparse.Namespace = None): self.loss_fn = torch.nn.CTCLoss(zero_infinity=True) # https://pytorch.org/docs/stable/generated/torch.nn.CTCLoss.html + self.val_acc = pl.metrics.Accuracy() + self.test_acc = pl.metrics.Accuracy() + ignore_tokens = [start_index, end_index, self.padding_index] self.val_cer = CharacterErrorRate(ignore_tokens) self.test_cer = CharacterErrorRate(ignore_tokens) diff --git a/lab7/text_recognizer/lit_models/ctc.py b/lab7/text_recognizer/lit_models/ctc.py index 9c761c9..ac01773 100644 --- a/lab7/text_recognizer/lit_models/ctc.py +++ b/lab7/text_recognizer/lit_models/ctc.py @@ -1,5 +1,6 @@ import argparse import itertools +import pytorch_lightning as pl import torch from .base import BaseLitModel @@ -46,6 +47,9 @@ def __init__(self, model, args: argparse.Namespace = None): self.loss_fn = torch.nn.CTCLoss(zero_infinity=True) # https://pytorch.org/docs/stable/generated/torch.nn.CTCLoss.html + self.val_acc = pl.metrics.Accuracy() + self.test_acc = pl.metrics.Accuracy() + ignore_tokens = [start_index, end_index, self.padding_index] self.val_cer = CharacterErrorRate(ignore_tokens) self.test_cer = CharacterErrorRate(ignore_tokens) diff --git a/lab8/text_recognizer/lit_models/ctc.py b/lab8/text_recognizer/lit_models/ctc.py index 9c761c9..ac01773 100644 --- a/lab8/text_recognizer/lit_models/ctc.py +++ b/lab8/text_recognizer/lit_models/ctc.py @@ -1,5 +1,6 @@ import argparse import itertools +import pytorch_lightning as pl import torch from .base import BaseLitModel @@ -46,6 +47,9 @@ def __init__(self, model, args: argparse.Namespace = None): self.loss_fn = torch.nn.CTCLoss(zero_infinity=True) # https://pytorch.org/docs/stable/generated/torch.nn.CTCLoss.html + self.val_acc = pl.metrics.Accuracy() + self.test_acc = pl.metrics.Accuracy() + ignore_tokens = [start_index, end_index, self.padding_index] self.val_cer = CharacterErrorRate(ignore_tokens) self.test_cer = CharacterErrorRate(ignore_tokens) diff --git a/lab9/text_recognizer/lit_models/ctc.py b/lab9/text_recognizer/lit_models/ctc.py index 9c761c9..ac01773 100644 --- a/lab9/text_recognizer/lit_models/ctc.py +++ b/lab9/text_recognizer/lit_models/ctc.py @@ -1,5 +1,6 @@ import argparse import itertools +import pytorch_lightning as pl import torch from .base import BaseLitModel @@ -46,6 +47,9 @@ def __init__(self, model, args: argparse.Namespace = None): self.loss_fn = torch.nn.CTCLoss(zero_infinity=True) # https://pytorch.org/docs/stable/generated/torch.nn.CTCLoss.html + self.val_acc = pl.metrics.Accuracy() + self.test_acc = pl.metrics.Accuracy() + ignore_tokens = [start_index, end_index, self.padding_index] self.val_cer = CharacterErrorRate(ignore_tokens) self.test_cer = CharacterErrorRate(ignore_tokens)