Skip to content

Commit

Permalink
Add sentence align option to locate passage
Browse files Browse the repository at this point in the history
  • Loading branch information
ProbablyFaiz committed Jun 15, 2024
1 parent c47758d commit 5e26ff0
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 24 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dependencies = [
"pandas",
"pyspark",
"unidecode",
"legal-segmenter @ git+https://github.com/lexeme-dev/legal-segmenter#egg=main",
]

[project.optional-dependencies]
Expand Down
4 changes: 1 addition & 3 deletions rl/llm/engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,7 @@ def generate(self, prompt: ChatInput) -> InferenceOutput:
},
system_instruction=system_message,
)
chat_session = model.start_chat(
history=prev_messages,
)
chat_session = model.start_chat(history=prev_messages)
# Can't include the last message in the history, because
# that would make too much sense!
response = chat_session.send_message(last_message)
Expand Down
68 changes: 47 additions & 21 deletions rl/utils/locate_passage.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

import unidecode

from rl.utils import LOGGER

_TRIM_EDGE_NEWLINES_LIMIT = 20


Expand All @@ -28,7 +30,16 @@ def get_best_ngram_match(
ngram_size: int = 3,
max_passage_size: int = 512,
stride_size: int = None,
line_align: bool = True,
sentence_align: bool = False,
) -> tuple[int, int] | None:
if line_align and sentence_align:
LOGGER.warning(
"Both line_align and sentence_align are set to True. "
"This is not recommended, and you have only yourself to "
"blame if strange results ensue."
)

if not passage or not document:
return None
# if there's an exact match, return it
Expand Down Expand Up @@ -70,28 +81,43 @@ def get_best_ngram_match(
new_fragment.start += new_start
new_fragment.end += new_start

# If there is a newline within the first or last _TRIM_EDGE_NEWLINES_LIMIT
# (currently 20) characters of the span, let's chop that edge off. We do this
# because it mitigates slight imprecision in the resultant span, and avoids
# leading to a bounding box that sweeps in extra lines.
start, end = new_fragment.start, new_fragment.end
trim_limit = _TRIM_EDGE_NEWLINES_LIMIT
if "\n" in document[start : start + trim_limit]:
# We use rfind here because we want to get the last newline before the
# trim limit, not the first newline after the trim limit.
# start = document[start : start + trim_limit].rfind("\n") + start + 1
start = document.rfind("\n", start, start + trim_limit) + 1
if "\n" in document[end - trim_limit : end]:
# First, get the index of the closest newline to the start of this range:
closest_newline = document.find("\n", end - trim_limit, end)
next_newline = document.find("\n", closest_newline + 1)
if next_newline == -1:
next_newline = len(document)
# If there is a short line at the end, it's probably the end of the span. Let's keep it.
if 1 <= next_newline - closest_newline <= trim_limit:
end = next_newline
else:
end = document[end - trim_limit : end].find("\n") + (end - trim_limit)
if line_align:
# If there is a newline within the first or last _TRIM_EDGE_NEWLINES_LIMIT
# (currently 20) characters of the span, let's chop that edge off. We do this
# because it mitigates slight imprecision in the resultant span, and avoids
# leading to a bounding box that sweeps in extra lines.
trim_limit = _TRIM_EDGE_NEWLINES_LIMIT
if "\n" in document[start : start + trim_limit]:
# We use rfind here because we want to get the last newline before the
# trim limit, not the first newline after the trim limit.
# start = document[start : start + trim_limit].rfind("\n") + start + 1
start = document.rfind("\n", start, start + trim_limit) + 1
if "\n" in document[end - trim_limit : end]:
# First, get the index of the closest newline to the start of this range:
closest_newline = document.find("\n", end - trim_limit, end)
next_newline = document.find("\n", closest_newline + 1)
if next_newline == -1:
next_newline = len(document)
# If there is a short line at the end, it's probably the end of the span. Let's keep it.
if 1 <= next_newline - closest_newline <= trim_limit:
end = next_newline
else:
end = document[end - trim_limit : end].find("\n") + (end - trim_limit)
if sentence_align:
from legal_segmenter.segmenter import Segmenter

segmenter = Segmenter()
sentences = [
sent
for para in segmenter.segment(
document[max(0, start - 512) : min(len(document), end + 512)],
include_metadata=True,
)
for sent in para["sentences"]
]
start = min(sentences, key=lambda s: abs(s.start - start)).start
end = max(sentences, key=lambda s: abs(s.end - end)).end

return start, end

Expand Down

0 comments on commit 5e26ff0

Please sign in to comment.