diff --git a/stanza/models/constituency/parse_tree.py b/stanza/models/constituency/parse_tree.py index 372b3066cc..9dbaf6dc39 100644 --- a/stanza/models/constituency/parse_tree.py +++ b/stanza/models/constituency/parse_tree.py @@ -308,6 +308,8 @@ def get_unique_constituent_labels(trees): """ Walks over all of the trees and gets all of the unique constituent names from the trees """ + if isinstance(trees, Tree): + trees = [trees] constituents = Tree.get_constituent_counts(trees) return sorted(set(constituents.keys())) diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py index dbf3913be1..0202c562fe 100644 --- a/stanza/models/constituency/trainer.py +++ b/stanza/models/constituency/trainer.py @@ -466,7 +466,15 @@ def check_constituents(train_constituents, trees, treebank_name): constituents = parse_tree.Tree.get_unique_constituent_labels(trees) for con in constituents: if con not in train_constituents: - raise RuntimeError("Found constituent label {} in the {} set which don't exist in the train set".format(con, treebank_name)) + first_error = None + num_errors = 0 + for tree_idx, tree in enumerate(trees): + constituents = parse_tree.Tree.get_unique_constituent_labels(tree) + if con in constituents: + num_errors += 1 + if first_error is None: + first_error = tree_idx + raise RuntimeError("Found constituent label {} in the {} set which don't exist in the train set. This constituent label occured in {} trees, with the first tree index at {} counting from 1\nThe error tree (which may have POS tags changed from the retagger and may be missing functional tags or empty nodes) is:\n{:P}".format(con, treebank_name, num_errors, (first_error+1), trees[first_error])) def check_root_labels(root_labels, other_trees, treebank_name): """