Skip to content

Commit

Permalink
refactor(search): Support searching multiple entities in search() as in scroll() (#8461)
Browse files Browse the repository at this point in the history

Co-authored-by: Indy Prentice <[email protected]>
  • Loading branch information
iprentic and Indy Prentice authored Jul 24, 2023
1 parent c0dbea8 commit 27392f9
Show file tree
Hide file tree
Showing 17 changed files with 125 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ public SearchResult search(@Nonnull String entity, @Nonnull String input,
@Nullable SearchFlags searchFlags)
throws RemoteInvocationException {

return ValidationUtils.validateSearchResult(_entitySearchService.search(entity, input, newFilter(requestFilters),
return ValidationUtils.validateSearchResult(_entitySearchService.search(List.of(entity), input, newFilter(requestFilters),
null, start, count, searchFlags), _entityService);
}

Expand Down Expand Up @@ -329,7 +329,7 @@ public SearchResult search(
@Nullable SearchFlags searchFlags)
throws RemoteInvocationException {
return ValidationUtils.validateSearchResult(
_entitySearchService.search(entity, input, filter, sortCriterion, start, count, searchFlags), _entityService);
_entitySearchService.search(List.of(entity), input, filter, sortCriterion, start, count, searchFlags), _entityService);
}

@Nonnull
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public Map<String, Long> docCountPerEntity(@Nonnull List<String> entityNames) {
* Gets a list of documents that match given search request. The results are aggregated and filters are applied to the
* search hits and not the aggregation results.
*
* @param entityName name of the entity
* @param entityNames names of the entities
* @param input the search input text
* @param postFilters the request map with fields and values as filters to be applied to search hits
* @param sortCriterion {@link SortCriterion} to be applied to search results
Expand All @@ -54,10 +54,10 @@ public Map<String, Long> docCountPerEntity(@Nonnull List<String> entityNames) {
* @return a {@link SearchResult} that contains a list of matched documents and related search result metadata
*/
@Nonnull
public SearchResult search(@Nonnull String entityName, @Nonnull String input, @Nullable Filter postFilters,
public SearchResult search(@Nonnull List<String> entityNames, @Nonnull String input, @Nullable Filter postFilters,
@Nullable SortCriterion sortCriterion, int from, int size, @Nullable SearchFlags searchFlags) {
SearchResult result =
_cachingEntitySearchService.search(entityName, input, postFilters, sortCriterion, from, size, searchFlags, null);
_cachingEntitySearchService.search(entityNames, input, postFilters, sortCriterion, from, size, searchFlags, null);

try {
return result.copy().setEntities(new SearchEntityArray(_searchRanker.rank(result.getEntities())));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ private Map<String, SearchResult> getSearchResultsForEachEntity(@Nonnull List<St
// Query the entity search service for all entities asynchronously
try (Timer.Context ignored = MetricUtils.timer(this.getClass(), "searchEntities").time()) {
searchResults = ConcurrencyUtils.transformAndCollectAsync(entities, entity -> new Pair<>(entity,
_cachingEntitySearchService.search(entity, input, postFilters, sortCriterion, queryFrom, querySize, searchFlags, facets)))
_cachingEntitySearchService.search(List.of(entity), input, postFilters, sortCriterion, queryFrom, querySize, searchFlags, facets)))
.stream()
.collect(Collectors.toMap(Pair::getKey, Pair::getValue));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,15 @@ public class CachingEntitySearchService {
* @return a {@link SearchResult} containing the requested batch of search results
*/
public SearchResult search(
@Nonnull String entityName,
@Nonnull List<String> entityNames,
@Nonnull String query,
@Nullable Filter filters,
@Nullable SortCriterion sortCriterion,
int from,
int size,
@Nullable SearchFlags flags,
@Nullable List<String> facets) {
return getCachedSearchResults(entityName, query, filters, sortCriterion, from, size, flags, facets);
return getCachedSearchResults(entityNames, query, filters, sortCriterion, from, size, flags, facets);
}

/**
Expand Down Expand Up @@ -141,7 +141,7 @@ public ScrollResult scroll(
* This lets us have batches that return a variable number of results (we have no idea which batch the "from" "size" page corresponds to)
*/
public SearchResult getCachedSearchResults(
@Nonnull String entityName,
@Nonnull List<String> entityNames,
@Nonnull String query,
@Nullable Filter filters,
@Nullable SortCriterion sortCriterion,
Expand All @@ -152,9 +152,9 @@ public SearchResult getCachedSearchResults(
return new CacheableSearcher<>(
cacheManager.getCache(ENTITY_SEARCH_SERVICE_SEARCH_CACHE_NAME),
batchSize,
querySize -> getRawSearchResults(entityName, query, filters, sortCriterion, querySize.getFrom(),
querySize -> getRawSearchResults(entityNames, query, filters, sortCriterion, querySize.getFrom(),
querySize.getSize(), flags, facets),
querySize -> Sextet.with(entityName, query, filters != null ? toJsonString(filters) : null,
querySize -> Sextet.with(entityNames, query, filters != null ? toJsonString(filters) : null,
sortCriterion != null ? toJsonString(sortCriterion) : null, facets, querySize), flags, enableCache).getSearchResults(from, size);
}

Expand Down Expand Up @@ -272,15 +272,15 @@ public ScrollResult getCachedScrollResults(
* Executes the expensive search query using the {@link EntitySearchService}
*/
private SearchResult getRawSearchResults(
final String entityName,
final List<String> entityNames,
final String input,
final Filter filters,
final SortCriterion sortCriterion,
final int start,
final int count,
@Nullable final SearchFlags searchFlags,
@Nullable final List<String> facets) {
return entitySearchService.search(entityName, input, filters, sortCriterion, start, count, searchFlags, facets);
return entitySearchService.search(entityNames, input, filters, sortCriterion, start, count, searchFlags, facets);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,18 +106,18 @@ public void appendRunId(@Nonnull String entityName, @Nonnull Urn urn, @Nullable

@Nonnull
@Override
public SearchResult search(@Nonnull String entityName, @Nonnull String input, @Nullable Filter postFilters,
public SearchResult search(@Nonnull List<String> entityNames, @Nonnull String input, @Nullable Filter postFilters,
@Nullable SortCriterion sortCriterion, int from, int size, @Nullable SearchFlags searchFlags) {
return search(entityName, input, postFilters, sortCriterion, from, size, searchFlags, null);
return search(entityNames, input, postFilters, sortCriterion, from, size, searchFlags, null);
}

@Nonnull
public SearchResult search(@Nonnull String entityName, @Nonnull String input, @Nullable Filter postFilters,
public SearchResult search(@Nonnull List<String> entityNames, @Nonnull String input, @Nullable Filter postFilters,
@Nullable SortCriterion sortCriterion, int from, int size, @Nullable SearchFlags searchFlags, @Nullable List<String> facets) {
log.debug(String.format(
"Searching FullText Search documents entityName: %s, input: %s, postFilters: %s, sortCriterion: %s, from: %s, size: %s",
entityName, input, postFilters, sortCriterion, from, size));
return esSearchDAO.search(entityName, input, postFilters, sortCriterion, from, size, searchFlags, facets);
entityNames, input, postFilters, sortCriterion, from, size));
return esSearchDAO.search(entityNames, input, postFilters, sortCriterion, from, size, searchFlags, facets);
}

@Nonnull
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public long docCount(@Nonnull String entityName) {

@Nonnull
@WithSpan
private SearchResult executeAndExtract(@Nonnull EntitySpec entitySpec, @Nonnull SearchRequest searchRequest,
private SearchResult executeAndExtract(@Nonnull List<EntitySpec> entitySpec, @Nonnull SearchRequest searchRequest,
@Nullable Filter filter, int from, int size) {
long id = System.currentTimeMillis();
try (Timer.Context ignored = MetricUtils.timer(this.getClass(), "executeAndExtract_search").time()) {
Expand Down Expand Up @@ -181,20 +181,22 @@ private ScrollResult executeAndExtract(@Nonnull List<EntitySpec> entitySpecs, @N
* @return a {@link SearchResult} that contains a list of matched documents and related search result metadata
*/
@Nonnull
public SearchResult search(@Nonnull String entityName, @Nonnull String input, @Nullable Filter postFilters,
public SearchResult search(@Nonnull List<String> entityNames, @Nonnull String input, @Nullable Filter postFilters,
@Nullable SortCriterion sortCriterion, int from, int size, @Nullable SearchFlags searchFlags, @Nullable List<String> facets) {
final String finalInput = input.isEmpty() ? "*" : input;
Timer.Context searchRequestTimer = MetricUtils.timer(this.getClass(), "searchRequest").time();
EntitySpec entitySpec = entityRegistry.getEntitySpec(entityName);
List<EntitySpec> entitySpecs = entityNames.stream().map(entityRegistry::getEntitySpec).collect(Collectors.toList());
Filter transformedFilters = transformFilterForEntities(postFilters, indexConvention);
// Step 1: construct the query
final SearchRequest searchRequest = SearchRequestHandler
.getBuilder(entitySpec, searchConfiguration, customSearchConfiguration)
.getBuilder(entitySpecs, searchConfiguration, customSearchConfiguration)
.getSearchRequest(finalInput, transformedFilters, sortCriterion, from, size, searchFlags, facets);
searchRequest.indices(indexConvention.getIndexName(entitySpec));
searchRequest.indices(entityNames.stream()
.map(indexConvention::getEntityIndexName)
.toArray(String[]::new));
searchRequestTimer.stop();
// Step 2: execute the query and extract results, validated against document model as well
return executeAndExtract(entitySpec, searchRequest, transformedFilters, from, size);
return executeAndExtract(entitySpecs, searchRequest, transformedFilters, from, size);
}

/**
Expand All @@ -217,7 +219,7 @@ public SearchResult filter(@Nonnull String entityName, @Nullable Filter filters,
.getFilterRequest(transformedFilters, sortCriterion, from, size);

searchRequest.indices(indexConvention.getIndexName(entitySpec));
return executeAndExtract(entitySpec, searchRequest, transformedFilters, from, size);
return executeAndExtract(List.of(entitySpec), searchRequest, transformedFilters, from, size);
}

/**
Expand Down
15 changes: 12 additions & 3 deletions metadata-io/src/test/java/com/linkedin/metadata/ESTestUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,24 @@ private ESTestUtils() {
.collect(Collectors.toList());
}

public static SearchResult search(SearchService searchService, String query) {
return search(searchService, query, null);
public static SearchResult searchAcrossEntities(SearchService searchService, String query) {
return searchAcrossEntities(searchService, query, null);
}

public static SearchResult search(SearchService searchService, String query, @Nullable List<String> facets) {
public static SearchResult searchAcrossEntities(SearchService searchService, String query, @Nullable List<String> facets) {
return searchService.searchAcrossEntities(SEARCHABLE_ENTITIES, query, null, null, 0,
100, new SearchFlags().setFulltext(true).setSkipCache(true), facets);
}

public static SearchResult search(SearchService searchService, String query) {
return search(searchService, SEARCHABLE_ENTITIES, query);
}

public static SearchResult search(SearchService searchService, List<String> entities, String query) {
return searchService.search(entities, query, null, null, 0, 100,
new SearchFlags().setFulltext(true).setSkipCache(true));
}

public static ScrollResult scroll(SearchService searchService, String query, int batchSize, @Nullable String scrollId) {
return searchService.scrollAcrossEntities(SEARCHABLE_ENTITIES, query, null, null,
scrollId, "3m", batchSize, new SearchFlags().setFulltext(true).setSkipCache(true));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.linkedin.metadata.search.elasticsearch.update.ESWriteDAO;
import com.linkedin.metadata.utils.elasticsearch.IndexConvention;
import com.linkedin.metadata.utils.elasticsearch.IndexConventionImpl;
import java.util.List;
import org.elasticsearch.client.RestHighLevelClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Import;
Expand Down Expand Up @@ -93,7 +94,7 @@ private ElasticSearchService buildService() {

@Test
public void testElasticSearchServiceStructuredQuery() throws Exception {
SearchResult searchResult = _elasticSearchService.search(ENTITY_NAME, "test", null, null, 0, 10, new SearchFlags().setFulltext(false));
SearchResult searchResult = _elasticSearchService.search(List.of(ENTITY_NAME), "test", null, null, 0, 10, new SearchFlags().setFulltext(false));
assertEquals(searchResult.getNumEntities().intValue(), 0);
BrowseResult browseResult = _elasticSearchService.browse(ENTITY_NAME, "", null, 0, 10);
assertEquals(browseResult.getMetadata().getTotalNumEntities().longValue(), 0);
Expand All @@ -110,10 +111,10 @@ public void testElasticSearchServiceStructuredQuery() throws Exception {
_elasticSearchService.upsertDocument(ENTITY_NAME, document.toString(), urn.toString());
syncAfterWrite(_bulkProcessor);

searchResult = _elasticSearchService.search(ENTITY_NAME, "test", null, null, 0, 10, new SearchFlags().setFulltext(false));
searchResult = _elasticSearchService.search(List.of(ENTITY_NAME), "test", null, null, 0, 10, new SearchFlags().setFulltext(false));
assertEquals(searchResult.getNumEntities().intValue(), 1);
assertEquals(searchResult.getEntities().get(0).getEntity(), urn);
searchResult = _elasticSearchService.search(ENTITY_NAME, "foreignKey:Node", null, null, 0, 10, new SearchFlags().setFulltext(false));
searchResult = _elasticSearchService.search(List.of(ENTITY_NAME), "foreignKey:Node", null, null, 0, 10, new SearchFlags().setFulltext(false));
assertEquals(searchResult.getNumEntities().intValue(), 1);
assertEquals(searchResult.getEntities().get(0).getEntity(), urn);
browseResult = _elasticSearchService.browse(ENTITY_NAME, "", null, 0, 10);
Expand All @@ -135,7 +136,7 @@ public void testElasticSearchServiceStructuredQuery() throws Exception {
_elasticSearchService.upsertDocument(ENTITY_NAME, document2.toString(), urn2.toString());
syncAfterWrite(_bulkProcessor);

searchResult = _elasticSearchService.search(ENTITY_NAME, "test2", null, null, 0, 10, new SearchFlags().setFulltext(false));
searchResult = _elasticSearchService.search(List.of(ENTITY_NAME), "test2", null, null, 0, 10, new SearchFlags().setFulltext(false));
assertEquals(searchResult.getNumEntities().intValue(), 1);
assertEquals(searchResult.getEntities().get(0).getEntity(), urn2);
browseResult = _elasticSearchService.browse(ENTITY_NAME, "", null, 0, 10);
Expand All @@ -152,7 +153,7 @@ public void testElasticSearchServiceStructuredQuery() throws Exception {
_elasticSearchService.deleteDocument(ENTITY_NAME, urn.toString());
_elasticSearchService.deleteDocument(ENTITY_NAME, urn2.toString());
syncAfterWrite(_bulkProcessor);
searchResult = _elasticSearchService.search(ENTITY_NAME, "test2", null, null, 0, 10, new SearchFlags().setFulltext(false));
searchResult = _elasticSearchService.search(List.of(ENTITY_NAME), "test2", null, null, 0, 10, new SearchFlags().setFulltext(false));
assertEquals(searchResult.getNumEntities().intValue(), 0);
browseResult = _elasticSearchService.browse(ENTITY_NAME, "", null, 0, 10);
assertEquals(browseResult.getMetadata().getTotalNumEntities().longValue(), 0);
Expand All @@ -162,7 +163,7 @@ public void testElasticSearchServiceStructuredQuery() throws Exception {

@Test
public void testElasticSearchServiceFulltext() throws Exception {
SearchResult searchResult = _elasticSearchService.search(ENTITY_NAME, "test", null, null, 0, 10, new SearchFlags().setFulltext(true));
SearchResult searchResult = _elasticSearchService.search(List.of(ENTITY_NAME), "test", null, null, 0, 10, new SearchFlags().setFulltext(true));
assertEquals(searchResult.getNumEntities().intValue(), 0);

Urn urn = new TestEntityUrn("test", "urn1", "VALUE_1");
Expand All @@ -175,7 +176,7 @@ public void testElasticSearchServiceFulltext() throws Exception {
_elasticSearchService.upsertDocument(ENTITY_NAME, document.toString(), urn.toString());
syncAfterWrite(_bulkProcessor);

searchResult = _elasticSearchService.search(ENTITY_NAME, "test", null, null, 0, 10, new SearchFlags().setFulltext(true));
searchResult = _elasticSearchService.search(List.of(ENTITY_NAME), "test", null, null, 0, 10, new SearchFlags().setFulltext(true));
assertEquals(searchResult.getNumEntities().intValue(), 1);
assertEquals(searchResult.getEntities().get(0).getEntity(), urn);

Expand All @@ -192,7 +193,7 @@ public void testElasticSearchServiceFulltext() throws Exception {
_elasticSearchService.upsertDocument(ENTITY_NAME, document2.toString(), urn2.toString());
syncAfterWrite(_bulkProcessor);

searchResult = _elasticSearchService.search(ENTITY_NAME, "test2", null, null, 0, 10, new SearchFlags().setFulltext(true));
searchResult = _elasticSearchService.search(List.of(ENTITY_NAME), "test2", null, null, 0, 10, new SearchFlags().setFulltext(true));
assertEquals(searchResult.getNumEntities().intValue(), 1);
assertEquals(searchResult.getEntities().get(0).getEntity(), urn2);

Expand All @@ -203,7 +204,7 @@ public void testElasticSearchServiceFulltext() throws Exception {
_elasticSearchService.deleteDocument(ENTITY_NAME, urn.toString());
_elasticSearchService.deleteDocument(ENTITY_NAME, urn2.toString());
syncAfterWrite(_bulkProcessor);
searchResult = _elasticSearchService.search(ENTITY_NAME, "test2", null, null, 0, 10, new SearchFlags().setFulltext(true));
searchResult = _elasticSearchService.search(List.of(ENTITY_NAME), "test2", null, null, 0, 10, new SearchFlags().setFulltext(true));
assertEquals(searchResult.getNumEntities().intValue(), 0);

assertEquals(_elasticSearchService.docCount(ENTITY_NAME), 0);
Expand Down
Loading

0 comments on commit 27392f9

Please sign in to comment.