Skip to content

Commit

Permalink
Fix file tagging to be better adapted to the current system. Does not fix everything.
Browse files Browse the repository at this point in the history

Squashed commit of the following:

commit 4969a4b
Author: Thibault Clérice <[email protected]>
Date:   Wed Apr 20 13:25:26 2022 +0200

    Fixed tagging probably

commit 3858463
Author: Thibault Clérice <[email protected]>
Date:   Wed Apr 20 11:19:03 2022 +0200

    WIP
  • Loading branch information
PonteIneptique committed Apr 20, 2022
1 parent 0ea9830 commit d7f4fe1
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 53 deletions.
47 changes: 14 additions & 33 deletions boudams/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,45 +408,27 @@ def test(test_path, models, batch_size, device, debug, workers: int, avg: str):
def tag(model, filename, device="cpu", batch_size=64):
""" Tag all [FILENAME] using [MODEL]"""
print("Loading the model.")
model = BoudamsTagger.load(model)
model = BoudamsTagger.load(model, device=device)
model.eval()
model.to(device)
print("Model loaded.")
remove_line = True
spaces = re.compile(r"\s+")
apos = re.compile(r"['’]")
for file in tqdm.tqdm(filename):
out_name = file.name.replace(".txt", ".tokenized.txt")
content = file.read() # Could definitely be done a better way...
if remove_line:
content = spaces.sub("", content)
if model.vocabulary.mode.name == "simple-space":
content = re.sub(r"\s+", "", content)
elif model.vocabulary.mode.NormalizeSpace:
content = re.sub(r"\s+", " ", content)
file.close()
# Now, extract apostrophes, remove them, and reinject them
apos_positions = [
i
for i in range(len(content))
if content[i] in ["'", "’"]
]
content = apos.sub("", content)

with open(out_name, "w") as out_io:
out = ''
for tokenized_string in model.annotate_text(content, batch_size=batch_size, device=device):
out = out + tokenized_string+" "

# Reinject apostrophes
#out = 'Sainz Tiebauz fu nez en l evesché de Troies ; ses peres ot non Ernous et sa mere, Gile et furent fra'
true_index = 0
for i in range(len(out) + len(apos_positions)):
if true_index in apos_positions:
out = out[:i] + "'" + out[i:]
true_index = true_index + 1
else:
if not out[i] == ' ':
true_index = true_index + 1

for tokenized_string in model.annotate_text(
content,
batch_size=batch_size,
device=device
):
out = out + tokenized_string + "\n"
out_io.write(out)
# print("--- File " + file.name + " has been tokenized")
print("--- File " + file.name + " has been tokenized")


@cli.command("tag-check")
Expand All @@ -458,11 +440,10 @@ def tag_check(config_model, content, device="cpu", batch_size=64):
""" Tag all [FILENAME] using [MODEL]"""
for model in config_model:
click.echo(f"Loading the model {model}.")
boudams = BoudamsTagger.load(model)
boudams = BoudamsTagger.load(model, device=device)
boudams.eval()
boudams.to(device)
click.echo(f"\t[X] Model loaded")
click.echo("\n".join(boudams.annotate_text(content, splitter="([\.!\?]+)", batch_size=batch_size, device=device)))
click.echo("\n".join(boudams.annotate_text(content, splitter=r"([\.!\?]+)", batch_size=batch_size, device=device)))


@cli.command("graph")
Expand Down
4 changes: 2 additions & 2 deletions boudams/modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class SimpleSpaceMode:
NormalizeSpace: bool = True

def __init__(self, masks: Dict[str, int] = None):
self.name = "Default"
self.name = "simple-space"
self.masks_to_index: Dict[str, int] = masks or {
DEFAULT_PAD_TOKEN: 0,
DEFAULT_MASK_TOKEN: 1,
Expand Down Expand Up @@ -139,7 +139,7 @@ def computer_wer(self, confusion_matrix):

class AdvancedSpaceMode(SimpleSpaceMode):
def __init__(self, masks: Dict[str, int] = None):
self.name = "Default"
self.name = "advanced-space"
self.masks_to_index: Dict[str, int] = masks or {
DEFAULT_PAD_TOKEN: 0,
DEFAULT_MASK_TOKEN: 1,
Expand Down
56 changes: 38 additions & 18 deletions boudams/tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,30 +449,50 @@ def annotate(self, texts: List[str], batch_size=32, device: str = "cpu"):
for index in range(len(translations)):
yield "".join(translations[order.index(index)])

def annotate_text(self, string, splitter=r"([⁊\W\d]+)", batch_size=32, device: str = "cpu"):
splitter = re.compile(splitter)
splits = splitter.split(string)

tempList = splits + [""] * 2
strings = ["".join(tempList[n:n + 2]) for n in range(0, len(splits), 2)]
strings = list(filter(lambda x: x.strip(), strings))
@staticmethod
def _apply_max_size(tokens: str, size: int):
# Use finditer when applied to things with spaces ?
# [(m.start(0), m.end(0)) for m in re.finditer(pattern, string)] ?
current = []
for tok in re.split(r"(\s+)", tokens):
if not tok:
continue
current.append(tok)
string_size = len("".join(current))
if string_size > size:
yield "".join(current[:-1])
current = current[-1:]
elif string_size == size:
yield "".join(current)
current = []
if current:
yield "".join(current)

def annotate_text(self, single_sentence, splitter: Optional[str] = None, batch_size=32, device: str = "cpu", rolling=True):
if splitter is None:
# ToDo: Mode specific splitter ?
splitter = r"([\.!\?]+)"

splitter = re.compile(splitter)
sentences = [tok for tok in splitter.split(single_sentence) if tok.strip()]

if self._maximum_sentence_size:
# This is currently quite limitating.
# If the end token is ending with a W and not a WB, there is no way to "correct it"
# We'd need a rolling system: cut in the middle of maximum sentence size ?
treated = []
max_size = self._maximum_sentence_size
for string in strings:
if len(string) > max_size:
treated.extend([
"".join(string[n:n + max_size])
for n in range(0, len(string), max_size)
])
for single_sentence in sentences:
if len(single_sentence) > max_size:
treated.extend(self._apply_max_size(single_sentence, max_size))
else:
treated.append(string)
strings = treated
yield from self.annotate(strings, batch_size=batch_size, device=device)
treated.append(single_sentence)
sentences = treated

yield from self.annotate(sentences, batch_size=batch_size, device=device)

@classmethod
def load(cls, fpath="./model.boudams_model"):
def load(cls, fpath="./model.boudams_model", device=None):
with tarfile.open(utils.ensure_ext(fpath, 'boudams_model'), 'r') as tar:
settings = json.loads(utils.get_gzip_from_tar(tar, 'settings.json.zip'))

Expand All @@ -487,7 +507,7 @@ def load(cls, fpath="./model.boudams_model"):
tar.extract('state_dict.pt', path=tmppath)
dictpath = os.path.join(tmppath, 'state_dict.pt')
# Strict false for predict (nll_weight is removed)
obj.load_state_dict(torch.load(dictpath), strict=False)
obj.load_state_dict(torch.load(dictpath, map_location=device), strict=False)

obj.eval()

Expand Down

0 comments on commit d7f4fe1

Please sign in to comment.