Skip to content

Commit

Permalink
Merge pull request #5 from fedecosta/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
fedecosta authored Oct 20, 2023
2 parents 13ebd87 + 3a71833 commit b6ffdd0
Show file tree
Hide file tree
Showing 8 changed files with 1,409 additions and 300 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ lexicon*.txt
words.txt

# TODO uncomment
#/data/
#/data/
1 change: 0 additions & 1 deletion gruut/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,3 @@ def is_language_supported(lang: str) -> bool:
def get_supported_languages() -> typing.Set[str]:
    """Return a new set built from the entries of ``KNOWN_LANGS``."""
    return {*KNOWN_LANGS}

209 changes: 12 additions & 197 deletions gruut/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import sys
from enum import Enum
from pathlib import Path

import jsonlines

from gruut.const import KNOWN_LANGS
Expand Down Expand Up @@ -39,7 +40,6 @@ class StdinFormat(str, Enum):


def main():

"""Main entry point"""
if len(sys.argv) < 2:
# Print known languages and exit
Expand All @@ -64,29 +64,16 @@ def main():
args.model_prefix = "espeak"

# -------------------------------------------------------------------------

text_processor = TextProcessor(
default_lang=args.language, model_prefix=args.model_prefix,
)

if args.debug:
_LOGGER.debug(f"settings: {text_processor.settings}")

# lines definition
if args.input_csv_path:
with open(args.input_csv_path) as csvfile:
reader = csv.reader(csvfile, delimiter = args.input_csv_delimiter)
lines_ids = [row[0] for row in reader]
csvfile.close()
with open(args.input_csv_path) as csvfile:
reader = csv.reader(csvfile, delimiter = args.input_csv_delimiter)
lines = [row[1] for row in reader]
csvfile.close()

elif args.text:
_LOGGER.debug(text_processor.settings)

if args.text:
# Use arguments
lines = args.text

else:
# Use stdin
stdin_format = StdinFormat.LINES
Expand All @@ -105,59 +92,7 @@ def main():
if os.isatty(sys.stdin.fileno()):
print("Reading input from stdin...", file=sys.stderr)

# writer, input_text and output_sentences definitions
if args.output_csv_path:

# Clean output file if exists
with open(args.output_csv_path, 'w') as outcsvfile:
outcsvfile.close()

def input_text(lines):
    """Yield ``(text, text_id)`` pairs for every entry of ``lines``.

    The id is taken from ``lines_ids`` (defined in the enclosing scope) at
    the same position as the line, so ``lines_ids`` must be at least as long
    as ``lines``.
    """
    for index, text in enumerate(lines):
        yield (text, lines_ids[index])

def output_sentences(sentences, writer, text_data=None):
    """Convert each sentence dataclass to a dict and hand it to ``writer.write``.

    ``text_data`` is accepted so the signature matches the other output
    callbacks; it is not used here.
    """
    # map() is lazy, so conversion and writing stay interleaved per sentence.
    for record in map(dataclasses.asdict, sentences):
        writer.write(record)

def output_transcription(
    sentences,
    writer,
    text_data=None,
    word_begin_sep='[',
    word_end_sep=']',
    g2p_word_begin_sep='{',
    g2p_word_end_sep='}',
):
    """Write one CSV row ``[text_data, transcription]`` via ``writer.writerow``.

    The transcription is the space-joined phonemes of every word in every
    sentence, each word wrapped in a pair of separators:

    * ``word_begin_sep`` / ``word_end_sep`` when
      ``text_processor._is_word_in_lexicon`` reports the word as known;
    * ``g2p_word_begin_sep`` / ``g2p_word_end_sep`` otherwise.

    Args:
        sentences: Iterable of sentence dataclass instances whose dict form
            has a ``"words"`` list of dicts with ``"text"`` and ``"phonemes"``.
        writer: Object exposing ``writerow`` (a ``csv.writer`` at the call site).
        text_data: Value written in the first column of the row.
        word_begin_sep: Opening marker for words found in the lexicon.
        word_end_sep: Closing marker for words found in the lexicon.
        g2p_word_begin_sep: Opening marker for words not found in the lexicon.
        g2p_word_end_sep: Closing marker for words not found in the lexicon.
    """
    # NOTE(review): hoisted out of the per-word loop; assumes get_settings()
    # returns the same settings object for a fixed language -- confirm.
    settings = text_processor.get_settings(lang=args.language)

    chunks = []
    for sentence in sentences:
        for word_dict in dataclasses.asdict(sentence)["words"]:
            if text_processor._is_word_in_lexicon(word_dict["text"], settings):
                begin_sep, end_sep = word_begin_sep, word_end_sep
            else:
                begin_sep, end_sep = g2p_word_begin_sep, g2p_word_end_sep

            chunk = " ".join([begin_sep, *word_dict["phonemes"], end_sep]).strip()
            # Skip empty chunks so no doubled spaces appear, matching the
            # strip-after-every-word behaviour of the previous implementation.
            if chunk:
                chunks.append(chunk)

    writer.writerow([text_data, " ".join(chunks)])

def output_json(sentences, writer, text_data=None):
    """Print every sentence to stdout as 4-space-indented JSON.

    ``writer`` and ``text_data`` are accepted so the signature matches the
    other output callbacks; neither is used.
    """
    import json

    for sentence in sentences:
        print(json.dumps(dataclasses.asdict(sentence), indent=4))

elif args.csv:
if args.csv:
writer = csv.writer(sys.stdout, delimiter=args.csv_delimiter)

def input_text(lines):
Expand All @@ -184,7 +119,7 @@ def output_sentences(sentences, writer, text_data=None):

row.append(args.phoneme_word_separator.join(phonemes))
writer.writerow(row)

else:
writer = jsonlines.Writer(sys.stdout, flush=True)

Expand All @@ -196,46 +131,8 @@ def output_sentences(sentences, writer, text_data=None):
for sentence in sentences:
sentence_dict = dataclasses.asdict(sentence)
writer.write(sentence_dict)

# TEST
def output_transcription(
    sentences,
    writer,
    text_data=None,
    word_begin_sep='[',
    word_end_sep=']',
    g2p_word_begin_sep='{',
    g2p_word_end_sep='}',
):
    """Write the bracketed phoneme transcription of ``sentences`` via ``writer.write``.

    The transcription is the space-joined phonemes of every word in every
    sentence, each word wrapped in a pair of separators:

    * ``word_begin_sep`` / ``word_end_sep`` when
      ``text_processor._is_word_in_lexicon`` reports the word as known;
    * ``g2p_word_begin_sep`` / ``g2p_word_end_sep`` otherwise.

    Args:
        sentences: Iterable of sentence dataclass instances whose dict form
            has a ``"words"`` list of dicts with ``"text"`` and ``"phonemes"``.
        writer: Object exposing ``write`` (a ``jsonlines.Writer`` at the call site).
        text_data: Unused; kept so the signature matches the other output
            callbacks.
        word_begin_sep: Opening marker for words found in the lexicon.
        word_end_sep: Closing marker for words found in the lexicon.
        g2p_word_begin_sep: Opening marker for words not found in the lexicon.
        g2p_word_end_sep: Closing marker for words not found in the lexicon.
    """
    # NOTE(review): hoisted out of the per-word loop; assumes get_settings()
    # returns the same settings object for a fixed language -- confirm.
    settings = text_processor.get_settings(lang=args.language)

    chunks = []
    for sentence in sentences:
        for word_dict in dataclasses.asdict(sentence)["words"]:
            if text_processor._is_word_in_lexicon(word_dict["text"], settings):
                begin_sep, end_sep = word_begin_sep, word_end_sep
            else:
                begin_sep, end_sep = g2p_word_begin_sep, g2p_word_end_sep

            chunk = " ".join([begin_sep, *word_dict["phonemes"], end_sep]).strip()
            # Skip empty chunks so no doubled spaces appear, matching the
            # strip-after-every-word behaviour of the previous implementation.
            if chunk:
                chunks.append(chunk)

    writer.write(" ".join(chunks))

def output_json(sentences, writer, text_data=None):
    """Print every sentence to stdout as 4-space-indented JSON.

    ``writer`` and ``text_data`` are accepted so the signature matches the
    other output callbacks; neither is used.
    """
    import json

    for sentence in sentences:
        print(json.dumps(dataclasses.asdict(sentence), indent=4))

# Transcription output

for text, text_data in input_text(lines):

# Lowercase explicitly here; the text does not appear to be lowercased earlier.
text = text.lower()

try:
graph, root = text_processor(
text,
Expand Down Expand Up @@ -268,32 +165,9 @@ def output_json(sentences, writer, text_data=None):
punctuations=(not args.no_punctuation),
)
)

if args.output_csv_path:
with open(args.output_csv_path, 'a') as outcsvfile:
writer = csv.writer(outcsvfile, delimiter = args.output_csv_delimiter)
output_transcription(
sentences,
writer,
text_data,
word_begin_sep = args.word_begin_sep,
word_end_sep = args.word_end_sep,
g2p_word_begin_sep = args.g2p_word_begin_sep,
g2p_word_end_sep = args.g2p_word_end_sep,
)
outcsvfile.close()
else:
output_transcription(
sentences,
writer,
text_data,
word_begin_sep = args.word_begin_sep,
word_end_sep = args.word_end_sep,
g2p_word_begin_sep = args.g2p_word_begin_sep,
g2p_word_end_sep = args.g2p_word_end_sep,
)



output_sentences(sentences, writer, text_data)

except Exception as e:
_LOGGER.exception(text)

Expand All @@ -315,9 +189,7 @@ class TextProcessingError(Exception):

def get_args() -> argparse.Namespace:
"""Parse command-line arguments"""

parser = argparse.ArgumentParser(prog="gruut")

parser.add_argument(
"-l",
"--language",
Expand All @@ -326,11 +198,9 @@ def get_args() -> argparse.Namespace:
)

parser.add_argument("text", nargs="*", help="Text to tokenize (default: stdin)")

parser.add_argument(
"--ssml", action="store_true", help="Input text is SSML",
)

parser.add_argument(
"--stdin-format",
choices=[str(v.value) for v in StdinFormat],
Expand All @@ -344,61 +214,50 @@ def get_args() -> argparse.Namespace:
action="store_true",
help="Disable number replacement (1 -> one)",
)

parser.add_argument(
"--no-currency",
action="store_true",
help="Disable currency replacement ($1 -> one dollar)",
)

parser.add_argument(
"--no-dates",
action="store_true",
help="Disable date replacement (4/1/2021 -> April first twenty twenty one)",
)

parser.add_argument(
"--no-times",
action="store_true",
help="Disable time replacement (4:01pm -> four oh one P M)",
)

parser.add_argument(
"--no-pos", action="store_true", help="Disable part of speech tagger",
)

parser.add_argument(
"--no-lexicon", action="store_true", help="Disable phoneme lexicon database",
)

parser.add_argument(
"--no-g2p", action="store_true", help="Disable grapheme to phoneme guesser",
)

parser.add_argument(
"--no-punctuation",
action="store_true",
help="Don't output punctuations (quotes, brackets, etc.)",
)

parser.add_argument(
"--no-major-breaks",
action="store_true",
help="Don't output major breaks (periods, question marks, etc.)",
)

parser.add_argument(
"--no-minor-breaks",
action="store_true",
help="Don't output minor breaks (commas, semicolons, etc.)",
)

parser.add_argument(
"--no-post-process",
action="store_true",
help="Disable post-processing of sentences (e.g., liasons)",
)

parser.add_argument(
"--no-fail", action="store_true", help="Skip lines that result in errors",
)
Expand All @@ -409,80 +268,36 @@ def get_args() -> argparse.Namespace:
action="store_true",
help="Use eSpeak versions of lexicons (overrides --model-prefix)",
)

parser.add_argument(
"--model-prefix",
help="Sub-directory of gruut language data files with different lexicon, etc. (e.g., espeak)",
)

parser.add_argument(
"--csv", action="store_true", help="Input text is id|text (see --csv-delimiter)"
)

parser.add_argument(
"--input-csv-path", help="Input csv path",
)

parser.add_argument(
"--output-csv-path", help="Output csv path",
"--csv-delimiter", default="|", help="Delimiter for input text with --csv"
)

parser.add_argument(
"--input-csv-delimiter", default="|", help="Delimiter for input csv"
)

parser.add_argument(
"--output-csv-delimiter", default="|", help="Delimiter for output csv"
)

parser.add_argument(
"--sentence-separator",
default=". ",
help="String used to separate sentences in CSV output",
)

parser.add_argument(
"--word-separator",
default=" ",
help="String used to separate words in CSV output",
)

parser.add_argument(
"--phoneme-word-separator",
default="#",
help="String used to separate phonemes in CSV output",
)

parser.add_argument(
"--phoneme-separator",
default=" ",
help="String used to separate words in CSV output phonemes",
)

parser.add_argument(
"--word_begin_sep",
default="[",
help="String used to indicate the begining of words transcribed using the lexicon.",
)

parser.add_argument(
"--word_end_sep",
default="]",
help="String used to indicate the ending of words transcribed using the lexicon.",
)

parser.add_argument(
"--g2p_word_begin_sep",
default="{",
help="String used to indicate the begining of words transcribed using the g2p model.",
)

parser.add_argument(
"--g2p_word_end_sep",
default="}",
help="String used to indicate the ending of words transcribed using the g2p model.",
)

parser.add_argument(
"--debug", action="store_true", help="Print DEBUG messages to console"
)
Expand All @@ -494,4 +309,4 @@ def get_args() -> argparse.Namespace:


if __name__ == "__main__":
main()
main()
Loading

0 comments on commit b6ffdd0

Please sign in to comment.