Skip to content

Commit

Permalink
Fix bug in complex query detection
Browse files Browse the repository at this point in the history
  • Loading branch information
pradh committed May 22, 2024
1 parent 6d23dfb commit 1ba7088
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 5 deletions.
4 changes: 2 additions & 2 deletions server/lib/nl/detection/llm_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,15 +201,15 @@ def _is_multi_sv_delimited(d: Detection, places_mentioned: List[str],
# Find all sv sub-part indexes.
vidx_list = []
for p in multi_sv.parts:
vidx = query.find(p.query_part)
vidx = dutils.find_word_boundary(query, p.query_part)
if vidx == -1:
ctr.err('failed_fallback_svidxmissing', p.query_part)
return False
vidx_list.append(vidx)

for place in places_mentioned:
# Find place idx.
pidx = query.find(place)
pidx = dutils.find_word_boundary(query, place)
if pidx == -1:
ctr.err('failed_fallback_placeidxmissing', place)
return False
Expand Down
15 changes: 14 additions & 1 deletion server/lib/nl/detection/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

# TODO: rename to variable_utils.py

from typing import Dict, List
import re
from typing import List

from server.lib.fetch import property_values
from server.lib.nl.common import constants
Expand Down Expand Up @@ -296,3 +297,15 @@ def is_llm_detection(d: Detection) -> bool:
return d.detector in [
ActualDetectorType.LLM, ActualDetectorType.HybridLLMFull
]


# Find "needle" at word boundary in "haystack".
def find_word_boundary(haystack: str, needle: str):
# Create a regex pattern with word boundaries
pattern = r'\b' + re.escape(needle) + r'\b'
# Search for the pattern in the string
match = re.search(pattern, haystack)
# Return the start index if a match is found, otherwise -1
if match:
return match.start()
return -1
19 changes: 17 additions & 2 deletions server/tests/lib/nl/detection/llm_fallback_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@
from shared.lib import detected_variables as dvars


def _place():
def _place(place_name: str = 'california'):
return PlaceDetection(query_original='',
query_without_place_substr='foo bar',
query_places_mentioned=['california'],
query_places_mentioned=[place_name],
places_found=[Place('geoId/06', 'CA', 'State')],
main_place=None,
entities_found=[],
Expand Down Expand Up @@ -200,6 +200,21 @@ class TestLLMFallback(unittest.TestCase):
classifications=[]),
NeedLLM.Fully,
'info_fallback_place_within_multi_sv'),
(
# Regression test for incorrect detection of place within a query.
# NOTE: "us" appears in "greenhoUSe"
# Previously: fallback with counter "info_fallback_place_within_multi_sv"
Detection(
original_query=
'what are the sources of greenhouse gas emissions in the US',
cleaned_query=
'what are the sources of greenhouse gas emissions in the US',
places_detected=_place('us'),
svs_detected=_sv(['sources greenhouse gas', 'emissions'],
above_thres=True),
classifications=[]),
NeedLLM.No,
'info_fallback_multi_sv_no_delim'),
])
def test_main(self, heuristic, fallback, counter):
ctr = Counters()
Expand Down

0 comments on commit 1ba7088

Please sign in to comment.