Skip to content

Commit

Permalink
feat: add init_with_source_and_target
Browse files Browse the repository at this point in the history
  • Loading branch information
kod-kristoff committed May 8, 2024
1 parent da580db commit 2e9aa11
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 0 deletions.
29 changes: 29 additions & 0 deletions src/parallel_corpus/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ def init(s: str, *, manual: bool = False) -> Graph:
return init_from(token.tokenize(s), manual=manual)


def init_with_source_and_target(source: str, target: str, *, manual: bool = False) -> Graph:
    """Build a graph from a source text and a target text.

    Each text is split with ``token.tokenize`` and the two token lists are
    handed to ``init_from_source_and_target``.

    Args:
        source: Raw text for the source side.
        target: Raw text for the target side.
        manual: Forwarded unchanged to ``init_from_source_and_target``.

    Returns:
        The graph produced by ``init_from_source_and_target``.
    """
    source_words = token.tokenize(source)
    target_words = token.tokenize(target)
    return init_from_source_and_target(source=source_words, target=target_words, manual=manual)


def init_from(tokens: List[str], *, manual: bool = False) -> Graph:
return align(
Graph(
Expand All @@ -99,6 +105,25 @@ def init_from(tokens: List[str], *, manual: bool = False) -> Graph:
)


def init_from_source_and_target(
    source: List[str], target: List[str], *, manual: bool = False
) -> Graph:
    """Build a graph from already-tokenized source and target texts.

    The source strings are identified with the prefix ``"s"`` and the target
    strings with the prefix ``"t"``. Every resulting token first gets an edge
    of its own (source tokens first, then target tokens); the graph assembled
    from these is then passed through ``align``.

    Args:
        source: Token strings for the source side.
        target: Token strings for the target side.
        manual: Passed as the ``manual`` keyword to every ``edge`` call.

    Returns:
        The result of ``align`` on the assembled graph.
    """
    source_tokens = token.identify(source, "s")
    target_tokens = token.identify(target, "t")

    # One lazy pass over both sides, source before target, yielding a
    # single-token edge per token.
    every_token = itertools.chain(source_tokens, target_tokens)
    singleton_edges = (edge([tok.id], [], manual=manual) for tok in every_token)

    unaligned = Graph(
        source=source_tokens,
        target=target_tokens,
        edges=edge_record(singleton_edges),
    )
    return align(unaligned)


class TextLabels(TypedDict):
text: str
labels: List[str]
Expand Down Expand Up @@ -129,6 +154,10 @@ def modify(g: Graph, from_: int, to: int, text: str, side: Side = Side.target) -
return align(unaligned_modify(g, from_, to, text, side))


def set_source(g: Graph, text: str) -> Graph:
    """Apply ``unaligned_set_side`` to the source side of ``g`` and align the result.

    Args:
        g: The graph to start from.
        text: The text handed to ``unaligned_set_side`` for ``Side.source``.

    Returns:
        The result of ``align`` on the updated graph.
    """
    updated = unaligned_set_side(g, Side.source, text)
    return align(updated)


def set_target(g: Graph, text: str) -> Graph:
    """Apply ``unaligned_set_side`` to the target side of ``g`` and align the result.

    Args:
        g: The graph to start from.
        text: The text handed to ``unaligned_set_side`` for ``Side.target``.

    Returns:
        The result of ``align`` on the updated graph.
    """
    updated = unaligned_set_side(g, Side.target, text)
    return align(updated)

Expand Down
29 changes: 29 additions & 0 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,35 @@ def test_graph_init() -> None:
assert g.edges == edges


def test_init_from_source_and_target_1() -> None:
    """Identical source and target text must give the same graph as ``graph.init``."""
    actual = graph.init_with_source_and_target(source="apa", target="apa")
    expected = graph.init("apa")
    assert actual == expected


def test_init_from_source_and_target_2() -> None:
    """An extra source token ("bepa", s1) is expected in an edge of its own."""
    actual = graph.init_with_source_and_target(source="apa bepa", target="apa")

    source_tokens = token.identify(token.tokenize("apa bepa"), "s")
    target_tokens = token.identify(token.tokenize("apa"), "t")
    # "apa" is linked across both sides; "bepa" exists only in the source.
    linked_edge = graph.edge(["s0", "t0"], [])
    source_only_edge = graph.edge(["s1"], [])
    expected = graph.Graph(
        source=source_tokens,
        target=target_tokens,
        edges=graph.edge_record([linked_edge, source_only_edge]),
    )

    assert actual == expected


def test_init_from_source_and_target_3() -> None:
    """An extra leading target token ("bepa", t0) is expected in an edge of its own."""
    actual = graph.init_with_source_and_target(source="apa", target="bepa apa")

    source_tokens = token.identify(token.tokenize("apa"), "s")
    target_tokens = token.identify(token.tokenize("bepa apa"), "t")
    # "apa" is s0 in the source and t1 in the target; "bepa" exists only in the target.
    linked_edge = graph.edge(["s0", "t1"], [])
    target_only_edge = graph.edge(["t0"], [])
    expected = graph.Graph(
        source=source_tokens,
        target=target_tokens,
        edges=graph.edge_record([linked_edge, target_only_edge]),
    )

    assert actual == expected


def test_from_unaligned() -> None:
g = graph.from_unaligned(
SourceTarget(
Expand Down

0 comments on commit 2e9aa11

Please sign in to comment.