diff --git a/algorithmic_efficiency/workloads/wmt/bleu.py b/algorithmic_efficiency/workloads/wmt/bleu.py
index 1efc87381..dda6d102a 100644
--- a/algorithmic_efficiency/workloads/wmt/bleu.py
+++ b/algorithmic_efficiency/workloads/wmt/bleu.py
@@ -1,8 +1,20 @@
+"""
+To remove the dependency on sacrebleu, we reimplement the BLEU score computation in this file.
+Reference:
+https://github.com/mjpost/sacrebleu/blob/v1.3.1/sacrebleu.py.
+"""
+
+from collections import Counter
+from collections import namedtuple
 from itertools import zip_longest
-from typing import Sequence
+import logging
+import math
+import re
+import sys
+from typing import List, Sequence
+import unicodedata
 
 from absl import logging
-import sacrebleu
 import torch
 import torch.distributed as dist
 
@@ -10,10 +22,340 @@
 USE_PYTORCH_DDP, _, DEVICE, N_GPUS = pytorch_setup()
 
+NGRAM_ORDER = 4
+# The default floor value to use with `--smooth floor`
+SMOOTH_VALUE_DEFAULT = 0.0
+
+
+def my_log(num):
+  """
+  Floors the log function
+
+  :param num: the number
+  :return: log(num) floored to a very low number
+  """
+
+  if num == 0.0:
+    return -9999999999
+  return math.log(num)
+
+
+def tokenize_13a(line):
+  """
+  Tokenizes an input line using a relatively minimal tokenization that is
+  however equivalent to mteval-v13a, used by WMT.
+
+  :param line: a segment to tokenize
+  :return: the tokenized line
+  """
+
+  norm = line
+
+  # language-independent part:
+  norm = norm.replace('<skipped>', '')
+  norm = norm.replace('-\n', '')
+  norm = norm.replace('\n', ' ')
+  norm = norm.replace('&quot;', '"')
+  norm = norm.replace('&amp;', '&')
+  norm = norm.replace('&lt;', '<')
+  norm = norm.replace('&gt;', '>')
+
+  # language-dependent part (assuming Western languages):
+  norm = " {} ".format(norm)
+  norm = re.sub(r'([\{-\~\[-\` -\&\(-\+\:-\@\/])', ' \\1 ', norm)
+  norm = re.sub(r'([^0-9])([\.,])', '\\1 \\2 ',
+                norm)  # tokenize period and comma unless preceded by a digit
+  norm = re.sub(r'([\.,])([^0-9])', ' \\1 \\2',
+                norm)  # tokenize period and comma unless followed by a digit
+  norm = re.sub(r'([0-9])(-)', '\\1 \\2 ',
+                norm)  # tokenize dash when preceded by a digit
+  norm = re.sub(r'\s+', ' ', norm)  # one space only between words
+  norm = re.sub(r'^\s+', '', norm)  # no leading space
+  norm = re.sub(r'\s+$', '', norm)  # no trailing space
+
+  return norm
+
+
+class UnicodeRegex:
+  """Ad-hoc hack to recognize all punctuation and symbols,
+  without depending on https://pypi.python.org/pypi/regex/."""
+
+  def _property_chars(prefix):
+    return ''.join(
+        chr(x)
+        for x in range(sys.maxunicode)
+        if unicodedata.category(chr(x)).startswith(prefix))
+
+  punctuation = _property_chars('P')
+  nondigit_punct_re = re.compile(r'([^\d])([' + punctuation + r'])')
+  punct_nondigit_re = re.compile(r'([' + punctuation + r'])([^\d])')
+  symbol_re = re.compile('([' + _property_chars('S') + '])')
+
+
+def tokenize_v14_international(string):
+  r"""Tokenize a string following the official BLEU implementation.
+
+  See https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983
+  In our case, the input string is expected to be just one line
+  and no HTML entities de-escaping is needed.
+  So we just tokenize on punctuation and symbols,
+  except when a punctuation is preceded and followed by a digit
+  (e.g. a comma/dot as a thousand/decimal separator).
+
+  Note that a number (e.g., a year) followed by a dot at the end of a sentence
+  is NOT tokenized, i.e. the dot stays with the number because
+  `s/(\p{P})(\P{N})/ $1 $2/g` does not match this case (unless we add a space
+  after each sentence). However, this error is already in the original
+  mteval-v14.pl and we want to be consistent with it.
+  The error is not present in the non-international version,
+  which uses `$norm_text = " $norm_text "` (or `norm = " {} ".format(norm)` in Python).
+
+  :param string: the input string
+  :return: a list of tokens
+  """
+  string = UnicodeRegex.nondigit_punct_re.sub(r'\1 \2 ', string)
+  string = UnicodeRegex.punct_nondigit_re.sub(r' \1 \2', string)
+  string = UnicodeRegex.symbol_re.sub(r' \1 ', string)
+  return string.strip()
+
+
+def tokenize_zh(sentence):
+  """MIT License
+  Copyright (c) 2017 - Shujian Huang <huangsj@nju.edu.cn>
+
+  Permission is hereby granted, free of charge, to any person obtaining a copy
+  of this software and associated documentation files (the "Software"), to deal
+  in the Software without restriction, including without limitation the rights
+  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+  copies of the Software, and to permit persons to whom the Software is
+  furnished to do so, subject to the following conditions:
+
+  The above copyright notice and this permission notice shall be included in
+  all copies or substantial portions of the Software.
+
+  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+  SOFTWARE.
+
+  The tokenization of Chinese text in this script contains two steps:
+  separate each Chinese character (by utf-8 encoding);
+  tokenize the non-Chinese part (following the mteval script).
+  Author: Shujian Huang huangsj@nju.edu.cn
+
+  :param sentence: input sentence
+  :return: tokenized sentence
+  """
+
+  def is_chinese_char(uchar):
+    """
+    :param uchar: input char in unicode
+    :return: whether the input char is a Chinese character.
+    """
+    if uchar >= u'\u3400' and uchar <= u'\u4db5':
+      # CJK Unified Ideographs Extension A, release 3.0
+      return True
+    elif uchar >= u'\u4e00' and uchar <= u'\u9fa5':
+      # CJK Unified Ideographs, release 1.1
+      return True
+    elif uchar >= u'\u9fa6' and uchar <= u'\u9fbb':
+      # CJK Unified Ideographs, release 4.1
+      return True
+    elif uchar >= u'\uf900' and uchar <= u'\ufa2d':
+      # CJK Compatibility Ideographs, release 1.1
+      return True
+    elif uchar >= u'\ufa30' and uchar <= u'\ufa6a':
+      # CJK Compatibility Ideographs, release 3.2
+      return True
+    elif uchar >= u'\ufa70' and uchar <= u'\ufad9':
+      # CJK Compatibility Ideographs, release 4.1
+      return True
+    elif uchar >= u'\u20000' and uchar <= u'\u2a6d6':
+      # CJK Unified Ideographs Extension B, release 3.1
+      return True
+    elif uchar >= u'\u2f800' and uchar <= u'\u2fa1d':
+      # CJK Compatibility Supplement, release 3.1
+      return True
+    elif uchar >= u'\uff00' and uchar <= u'\uffef':
+      # Full width ASCII, full width of English punctuation,
+      # half width Katakana, half width kana, Korean alphabet
+      return True
+    elif uchar >= u'\u2e80' and uchar <= u'\u2eff':
+      # CJK Radicals Supplement
+      return True
+    elif uchar >= u'\u3000' and uchar <= u'\u303f':
+      # CJK punctuation mark
+      return True
+    elif uchar >= u'\u31c0' and uchar <= u'\u31ef':
+      # CJK stroke
+      return True
+    elif uchar >= u'\u2f00' and uchar <= u'\u2fdf':
+      # Kangxi Radicals
+      return True
+    elif uchar >= u'\u2ff0' and uchar <= u'\u2fff':
+      # Chinese character structure
+      return True
+    elif uchar >= u'\u3100' and uchar <= u'\u312f':
+      # Phonetic symbols
+      return True
+    elif uchar >= u'\u31a0' and uchar <= u'\u31bf':
+      # Phonetic symbols (Taiwanese and Hakka expansion)
+      return True
+    elif uchar >= u'\ufe10' and uchar <= u'\ufe1f':
+      return True
+    elif uchar >= u'\ufe30' and uchar <= u'\ufe4f':
+      return True
+    elif uchar >= u'\u2600' and uchar <= u'\u26ff':
+      return True
+    elif uchar >= u'\u2700' and uchar <= u'\u27bf':
+      return True
+    elif uchar >= u'\u3200' and uchar <= u'\u32ff':
+      return True
+    elif uchar >= u'\u3300' and uchar <= u'\u33ff':
+      return True
+
+    return False
+
+  sentence = sentence.strip()
+  sentence_in_chars = ""
+  for char in sentence:
+    if is_chinese_char(char):
+      sentence_in_chars += " "
+      sentence_in_chars += char
+      sentence_in_chars += " "
+    else:
+      sentence_in_chars += char
+  sentence = sentence_in_chars
+
+  # TODO: the code above could probably be replaced with the following line:
+  # import regex
+  # sentence = regex.sub(r'(\p{Han})', r' \1 ', sentence)
+
+  # tokenize punctuation
+  sentence = re.sub(r'([\{-\~\[-\` -\&\(-\+\:-\@\/])', r' \1 ', sentence)
+
+  # tokenize period and comma unless preceded by a digit
+  sentence = re.sub(r'([^0-9])([\.,])', r'\1 \2 ', sentence)
+
+  # tokenize period and comma unless followed by a digit
+  sentence = re.sub(r'([\.,])([^0-9])', r' \1 \2', sentence)
+
+  # tokenize dash when preceded by a digit
+  sentence = re.sub(r'([0-9])(-)', r'\1 \2 ', sentence)
+
+  # one space only between words
+  sentence = re.sub(r'\s+', r' ', sentence)
+
+  # no leading or trailing spaces
+  sentence = sentence.strip()
+
+  return sentence
+
+
+TOKENIZERS = {
+    '13a': tokenize_13a,
+    'intl': tokenize_v14_international,
+    'zh': tokenize_zh,
+    'none': lambda x: x,
+}
+DEFAULT_TOKENIZER = '13a'
+
+
+def extract_ngrams(line, min_order=1, max_order=NGRAM_ORDER) -> Counter:
+  """Extracts all the ngrams (1 <= n <= NGRAM_ORDER) from a sequence of tokens.
+
+  :param line: a segment containing a sequence of words
+  :param max_order: collect n-grams from 1<=n<=max
+  :return: a dictionary containing ngrams and counts
+  """
+
+  ngrams = Counter()
+  tokens = line.split()
+  for n in range(min_order, max_order + 1):
+    for i in range(0, len(tokens) - n + 1):
+      ngram = ' '.join(tokens[i:i + n])
+      ngrams[ngram] += 1
+
+  return ngrams
+
+
+def ref_stats(output, refs):
+  ngrams = Counter()
+  closest_diff = None
+  closest_len = None
+  for ref in refs:
+    tokens = ref.split()
+    reflen = len(tokens)
+    diff = abs(len(output.split()) - reflen)
+    if closest_diff is None or diff < closest_diff:
+      closest_diff = diff
+      closest_len = reflen
+    elif diff == closest_diff:
+      if reflen < closest_len:
+        closest_len = reflen
+
+    ngrams_ref = extract_ngrams(ref)
+    for ngram in ngrams_ref.keys():
+      ngrams[ngram] = max(ngrams[ngram], ngrams_ref[ngram])
+
+  return ngrams, closest_diff, closest_len
+
+
+BLEU = namedtuple('BLEU',
+                  'score, counts, totals, precisions, bp, sys_len, ref_len')
+
+
+def compute_bleu(correct: List[int],
+                 total: List[int],
+                 sys_len: int,
+                 ref_len: int,
+                 smooth_method='none',
+                 smooth_value=SMOOTH_VALUE_DEFAULT,
+                 use_effective_order=False) -> BLEU:
+  """Computes BLEU score from its sufficient statistics. Adds smoothing.
+
+  Smoothing methods (citing "A Systematic Comparison of Smoothing Techniques
+  for Sentence-Level BLEU", Boxing Chen and Colin Cherry,
+  WMT 2014: http://aclweb.org/anthology/W14-3346):
+
+  - exp: NIST smoothing method (Method 3)
+  - floor: Method 1
+  - add-k: Method 2 (generalizing Lin and Och, 2004)
+  - none: do nothing.
+
+  :param correct: List of counts of correct ngrams, 1 <= n <= NGRAM_ORDER
+  :param total: List of counts of total ngrams, 1 <= n <= NGRAM_ORDER
+  :param sys_len: The cumulative system length
+  :param ref_len: The cumulative reference length
+  :param smooth_method: The smoothing method to use
+  :param smooth_value: The smoothing value added, if smooth method 'floor' is used
+  :param use_effective_order: Use effective order.
+  :return: A BLEU object with the score (100-based) and other statistics.
+  """
+
+  precisions = [0 for x in range(NGRAM_ORDER)]
+
+  smooth_mteval = 1.
+  effective_order = NGRAM_ORDER
+  for n in range(NGRAM_ORDER):
+    if smooth_method == 'add-k' and n > 1:
+      correct[n] += smooth_value
+      total[n] += smooth_value
+    if total[n] == 0:
+      break
+
+    if use_effective_order:
+      effective_order = n + 1
+
+    if correct[n] == 0:
+      if smooth_method == 'exp':
+        smooth_mteval *= 2
+        precisions[n] = 100. / (smooth_mteval * total[n])
+      elif smooth_method == 'floor':
+        precisions[n] = 100. * smooth_value / total[n]
+    else:
+      precisions[n] = 100. * correct[n] / total[n]
+
+  # If the system guesses no i-grams, 1 <= i <= NGRAM_ORDER, the BLEU score is
+  # 0 (technically undefined). This is a problem for sentence-level BLEU or a
+  # corpus of short sentences, where systems will get no credit if sentence
+  # lengths fall under the NGRAM_ORDER threshold. This fix scales NGRAM_ORDER
+  # to the observed maximum order. It is only available through the API and
+  # is off by default.
+
+  brevity_penalty = 1.0
+  if sys_len < ref_len:
+    brevity_penalty = math.exp(1 - ref_len / sys_len) if sys_len > 0 else 0.0
+
+  bleu = brevity_penalty * math.exp(
+      sum(map(my_log, precisions[:effective_order])) / effective_order)
+
+  return BLEU._make(
+      [bleu, correct, total, precisions, brevity_penalty, sys_len, ref_len])
+
+
-# Modified (added sync for PyTorch DDP) from
-# https://github.com/mjpost/sacrebleu/blob/v1.3.1/sacrebleu.py.
-# Assumes that sacrebleu==1.3.1 is installed.
 def corpus_bleu(sys_stream: Sequence[str],
                 ref_streams: Sequence[str],
                 smooth_method: str = 'exp',
@@ -21,7 +363,7 @@ def corpus_bleu(sys_stream: Sequence[str],
                 force: bool = False,
                 lowercase: bool = False,
                 tokenize: str = '13a',
-                use_effective_order: bool = False) -> sacrebleu.BLEU:
+                use_effective_order: bool = False) -> BLEU:
   """Produces BLEU scores along with its sufficient statistics from a source against one or more references.
 
   :param sys_stream: The system stream (a sequence of segments).
@@ -44,8 +386,8 @@ def corpus_bleu(sys_stream: Sequence[str],
   sys_len = 0
   ref_len = 0
 
-  correct = [0 for _ in range(sacrebleu.NGRAM_ORDER)]
-  total = [0 for _ in range(sacrebleu.NGRAM_ORDER)]
+  correct = [0 for _ in range(NGRAM_ORDER)]
+  total = [0 for _ in range(NGRAM_ORDER)]
 
   # Look for already-tokenized sentences.
   tokenized_count = 0
@@ -70,14 +412,14 @@ def corpus_bleu(sys_stream: Sequence[str],
                         'or don\'t care, you can suppress this message with '
                         '\'--force\'.')
 
-    output, *refs = [sacrebleu.TOKENIZERS[tokenize](x.rstrip()) for x in lines]
+    output, *refs = [TOKENIZERS[tokenize](x.rstrip()) for x in lines]
 
-    ref_ngrams, _, closest_len = sacrebleu.ref_stats(output, refs)
+    ref_ngrams, _, closest_len = ref_stats(output, refs)
 
     sys_len += len(output.split())
     ref_len += closest_len
 
-    sys_ngrams = sacrebleu.extract_ngrams(output)
+    sys_ngrams = extract_ngrams(output)
     for ngram, sys_ngram in sys_ngrams.items():
       n = len(ngram.split())
       correct[n - 1] += min(sys_ngram, ref_ngrams.get(ngram, 0))
@@ -100,7 +442,7 @@ def corpus_bleu(sys_stream: Sequence[str],
     dist.all_reduce(total)
     total = total.cpu().numpy().tolist()
 
-  return sacrebleu.compute_bleu(
+  return compute_bleu(
       correct,
       total,
       sys_len,
diff --git a/setup.cfg b/setup.cfg
index 2d246b48b..8e37acb7a 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -102,7 +102,7 @@ librispeech_conformer =
 wmt =
     sentencepiece==0.2.0
     tensorflow-text==2.18.0
-    sacrebleu==1.3.1
+
 
 # Frameworks #
 # JAX Core
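A minimal usage sketch follows (not part of the patch): it assumes the module is importable as algorithmic_efficiency.workloads.wmt.bleu, and the sample sentences are made up for illustration. The corpus_bleu call signature is unchanged by this diff; only the return type moves from sacrebleu.BLEU to the local BLEU namedtuple, so existing call sites should keep working.

# Illustrative sketch only; module path and example data are assumptions.
from algorithmic_efficiency.workloads.wmt import bleu

hypotheses = ['The cat sat on the mat.', 'There is a cat on the mat.']
# One reference stream, parallel to the hypotheses.
references = [['The cat is on the mat.', 'There is a cat on the mat.']]

# Same call as before the change; the result is now the vendored BLEU namedtuple.
result = bleu.corpus_bleu(hypotheses, references, tokenize='13a')
print(result.score)       # corpus-level BLEU on the usual 0-100 scale
print(result.precisions)  # n-gram precisions for n = 1..4
print(result.bp)          # brevity penalty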