diff --git a/clientlib/src/main/proto/yelp/nrtsearch/search.proto b/clientlib/src/main/proto/yelp/nrtsearch/search.proto index 705aca0e7..ee03bce97 100644 --- a/clientlib/src/main/proto/yelp/nrtsearch/search.proto +++ b/clientlib/src/main/proto/yelp/nrtsearch/search.proto @@ -1085,6 +1085,9 @@ message LoggingHits { string name = 1; //Optional logging parameters google.protobuf.Struct params = 2; + // number of hits to log. The number of final hits to be logged can be less than this number + // if a query has less hits. + int32 hitsToLog = 3; } // Specify how to highlight matched text in SearchRequest diff --git a/src/main/java/com/yelp/nrtsearch/server/luceneserver/SearchHandler.java b/src/main/java/com/yelp/nrtsearch/server/luceneserver/SearchHandler.java index 0d7f5ae38..1e5bc64ef 100644 --- a/src/main/java/com/yelp/nrtsearch/server/luceneserver/SearchHandler.java +++ b/src/main/java/com/yelp/nrtsearch/server/luceneserver/SearchHandler.java @@ -220,7 +220,14 @@ public SearchResponse handle(IndexState indexState, SearchRequest searchRequest) long t0 = System.nanoTime(); - hits = getHitsFromOffset(hits, searchContext.getStartHit(), searchContext.getTopHits()); + // hits to be logged also need to have their fields fetched + hits = + getHitsFromOffset( + hits, + searchContext.getStartHit(), + Math.max( + searchContext.getTopHits(), + searchContext.getHitsToLog() + searchContext.getStartHit())); // create Hit.Builder for each hit, and populate with lucene doc id and ranking info setResponseHits(searchContext, hits); @@ -228,6 +235,12 @@ public SearchResponse handle(IndexState indexState, SearchRequest searchRequest) // fill Hit.Builder with requested fields fetchFields(searchContext); + // if there were extra hits for the logging, the response size needs to be reduced to match + // the topHits + if (searchContext.getFetchTasks().getHitsLoggerFetchTask() != null) { + setResponseTopHits(searchContext); + } + SearchState.Builder searchState = SearchState.newBuilder(); searchContext.getResponseBuilder().setSearchState(searchState); searchState.setTimestamp(searchContext.getTimestampSec()); @@ -444,17 +457,17 @@ private void fetchFields(SearchContext searchContext) /** * Given all the top documents, produce a slice of the documents starting from a start offset and - * going up to the query needed maximum hits. There may be more top docs than the topHits limit, + * going up to the query needed maximum hits. There may be more top docs than the hitsCount limit, * if top docs sampling facets are used. * * @param hits all hits * @param startHit offset into top docs - * @param topHits maximum number of hits needed for search response + * @param hitsCount maximum number of hits needed for the query * @return slice of hits starting at given offset, or empty slice if there are less than startHit * docs */ - public static TopDocs getHitsFromOffset(TopDocs hits, int startHit, int topHits) { - int retrieveHits = Math.min(topHits, hits.scoreDocs.length); + public static TopDocs getHitsFromOffset(TopDocs hits, int startHit, int hitsCount) { + int retrieveHits = Math.min(hitsCount, hits.scoreDocs.length); if (startHit != 0 || retrieveHits != hits.scoreDocs.length) { // Slice: int count = Math.max(0, retrieveHits - startHit); @@ -467,6 +480,20 @@ public static TopDocs getHitsFromOffset(TopDocs hits, int startHit, int topHits) return hits; } + /** + * Reduce response size by removing any extra hits used for logging. Final search response should + * only return top hits. + * + * @param context search context + */ + private static void setResponseTopHits(SearchContext context) { + while (context.getResponseBuilder().getHitsCount() + > context.getTopHits() - context.getStartHit()) { + int hitLastIdx = context.getResponseBuilder().getHitsCount() - 1; + context.getResponseBuilder().removeHits(hitLastIdx); + } + } + /** * Add {@link com.yelp.nrtsearch.server.grpc.SearchResponse.Hit.Builder}s to the context {@link * SearchResponse.Builder} for each of the query hits. Populate the builders with the lucene doc diff --git a/src/main/java/com/yelp/nrtsearch/server/luceneserver/logging/HitsLoggerFetchTask.java b/src/main/java/com/yelp/nrtsearch/server/luceneserver/logging/HitsLoggerFetchTask.java index c3e674b17..9d2e4e135 100644 --- a/src/main/java/com/yelp/nrtsearch/server/luceneserver/logging/HitsLoggerFetchTask.java +++ b/src/main/java/com/yelp/nrtsearch/server/luceneserver/logging/HitsLoggerFetchTask.java @@ -30,10 +30,12 @@ public class HitsLoggerFetchTask implements FetchTask { private static final double TEN_TO_THE_POWER_SIX = Math.pow(10, 6); private final HitsLogger hitsLogger; + private final int hitsToLog; private final DoubleAdder timeTakenMs = new DoubleAdder(); public HitsLoggerFetchTask(LoggingHits loggingHits) { this.hitsLogger = HitsLoggerCreator.getInstance().createHitsLogger(loggingHits); + this.hitsToLog = loggingHits.getHitsToLog(); } /** @@ -46,7 +48,15 @@ public HitsLoggerFetchTask(LoggingHits loggingHits) { @Override public void processAllHits(SearchContext searchContext, List hits) { long startTime = System.nanoTime(); - hitsLogger.log(searchContext, hits); + + // hits list can contain extra hits that don't need to be logged, otherwise, pass all hits that + // can be logged + if (searchContext.getHitsToLog() < hits.size()) { + hitsLogger.log(searchContext, hits.subList(0, searchContext.getHitsToLog())); + } else { + hitsLogger.log(searchContext, hits); + } + timeTakenMs.add(((System.nanoTime() - startTime) / TEN_TO_THE_POWER_SIX)); } @@ -58,4 +68,13 @@ public void processAllHits(SearchContext searchContext, List 0 && facetSample > collectHits) { diff --git a/src/test/java/com/yelp/nrtsearch/server/luceneserver/logging/HitsLoggerTest.java b/src/test/java/com/yelp/nrtsearch/server/luceneserver/logging/HitsLoggerTest.java index 122b81368..d1abb6f12 100644 --- a/src/test/java/com/yelp/nrtsearch/server/luceneserver/logging/HitsLoggerTest.java +++ b/src/test/java/com/yelp/nrtsearch/server/luceneserver/logging/HitsLoggerTest.java @@ -60,34 +60,29 @@ protected FieldDefRequest getIndexDef(String name) throws IOException { @Override protected void initIndex(String name) throws Exception { List docs = new ArrayList<>(); - AddDocumentRequest request = - AddDocumentRequest.newBuilder() - .setIndexName(name) - .putFields( - "doc_id", AddDocumentRequest.MultiValuedField.newBuilder().addValue("1").build()) - .putFields( - "vendor_name", - AddDocumentRequest.MultiValuedField.newBuilder().addValue("first vendor").build()) - .putFields( - "long_field", - AddDocumentRequest.MultiValuedField.newBuilder().addValue("5").build()) - .build(); - docs.add(request); - request = - AddDocumentRequest.newBuilder() - .setIndexName(name) - .putFields( - "doc_id", AddDocumentRequest.MultiValuedField.newBuilder().addValue("2").build()) - .putFields( - "vendor_name", - AddDocumentRequest.MultiValuedField.newBuilder() - .addValue("second vendor review") - .build()) - .putFields( - "long_field", - AddDocumentRequest.MultiValuedField.newBuilder().addValue("10").build()) - .build(); - docs.add(request); + + for (int docNum = 1; docNum < 11; docNum++) { + AddDocumentRequest request = + AddDocumentRequest.newBuilder() + .setIndexName(name) + .putFields( + "doc_id", + AddDocumentRequest.MultiValuedField.newBuilder() + .addValue(String.valueOf(docNum)) + .build()) + .putFields( + "vendor_name", + AddDocumentRequest.MultiValuedField.newBuilder() + .addValue("vendor " + docNum) + .build()) + .putFields( + "long_field", + AddDocumentRequest.MultiValuedField.newBuilder() + .addValue(String.valueOf(2 + docNum)) + .build()) + .build(); + docs.add(request); + } addDocuments(docs.stream()); } @@ -106,7 +101,14 @@ public CustomHitsLogger(Map params) { @Override public void log(SearchContext context, List hits) { - HitsLoggerTest.logMessage = "LOGGED " + hits.toString(); + HitsLoggerTest.logMessage = "LOGGED "; + + for (SearchResponse.Hit.Builder hit : hits) { + HitsLoggerTest.logMessage += + "doc_id: " + + hit.getFieldsMap().get("doc_id").getFieldValueList().get(0).getTextValue() + + ", "; + } if (!params.isEmpty()) { HitsLoggerTest.logMessage += " " + params; @@ -123,7 +125,14 @@ public CustomHitsLogger2(Map params) { @Override public void log(SearchContext context, List hits) { - HitsLoggerTest.logMessage = "LOGGED_2 " + hits.toString(); + HitsLoggerTest.logMessage = "LOGGED_2 "; + + for (SearchResponse.Hit.Builder hit : hits) { + HitsLoggerTest.logMessage += + "doc_id: " + + hit.getFieldsMap().get("doc_id").getFieldValueList().get(0).getTextValue() + + ", "; + } if (!params.isEmpty()) { HitsLoggerTest.logMessage += " " + params; @@ -143,7 +152,7 @@ public Map> getHitsLoggers() { public void testCustomHitsLoggerWithParam() { SearchRequest request = SearchRequest.newBuilder() - .setTopHits(1) + .setTopHits(2) .setStartHit(0) .setIndexName(DEFAULT_TEST_INDEX) .addRetrieveFields("doc_id") @@ -158,6 +167,7 @@ public void testCustomHitsLoggerWithParam() { .setLoggingHits( LoggingHits.newBuilder() .setName("custom_logger") + .setHitsToLog(2) .setParams( Struct.newBuilder() .putFields( @@ -165,13 +175,69 @@ public void testCustomHitsLoggerWithParam() { .build()) .build(); SearchResponse response = getGrpcServer().getBlockingStub().search(request); - String expectedLogMessage = "LOGGED " + List.of(response.getHits(0)) + " {external_value=abc}"; + String expectedLogMessage = "LOGGED doc_id: 1, doc_id: 2, {external_value=abc}"; assertEquals(expectedLogMessage, HitsLoggerTest.logMessage); + assertEquals(10, response.getTotalHits().getValue()); + assertEquals(2, response.getHitsCount()); } @Test public void testCustomHitsLoggerWithoutParam() { + SearchRequest request = + SearchRequest.newBuilder() + .setTopHits(2) + .setStartHit(0) + .setIndexName(DEFAULT_TEST_INDEX) + .addRetrieveFields("doc_id") + .setQuery( + Query.newBuilder() + .setTermQuery( + TermQuery.newBuilder() + .setField("vendor_name") + .setTextValue("vendor") + .build()) + .build()) + .setLoggingHits( + LoggingHits.newBuilder().setName("custom_logger_2").setHitsToLog(2).build()) + .build(); + SearchResponse response = getGrpcServer().getBlockingStub().search(request); + String expectedLogMessage = "LOGGED_2 doc_id: 1, doc_id: 2, "; + + assertEquals(expectedLogMessage, HitsLoggerTest.logMessage); + assertEquals(10, response.getTotalHits().getValue()); + assertEquals(2, response.getHitsCount()); + } + + @Test + public void testResponseSizeReductionWithHitsToLogSameAsHitsCount() { + SearchRequest request = + SearchRequest.newBuilder() + .setTopHits(5) + .setStartHit(0) + .setIndexName(DEFAULT_TEST_INDEX) + .addRetrieveFields("doc_id") + .setQuery( + Query.newBuilder() + .setTermQuery( + TermQuery.newBuilder() + .setField("vendor_name") + .setTextValue("vendor") + .build()) + .build()) + .setLoggingHits( + LoggingHits.newBuilder().setName("custom_logger").setHitsToLog(5).build()) + .build(); + SearchResponse response = getGrpcServer().getBlockingStub().search(request); + String expectedLogMessage = "LOGGED doc_id: 1, doc_id: 2, doc_id: 3, doc_id: 4, doc_id: 5, "; + + assertEquals(expectedLogMessage, HitsLoggerTest.logMessage); + assertEquals(10, response.getTotalHits().getValue()); + assertEquals(5, response.getHitsCount()); + } + + @Test + public void testResponseSizeReductionWithHitsToLogGreaterThanHitsCount() { SearchRequest request = SearchRequest.newBuilder() .setTopHits(1) @@ -186,12 +252,186 @@ public void testCustomHitsLoggerWithoutParam() { .setTextValue("vendor") .build()) .build()) - .setLoggingHits(LoggingHits.newBuilder().setName("custom_logger_2").build()) + .setLoggingHits( + LoggingHits.newBuilder() + .setName("custom_logger") + .setHitsToLog(2) + .setParams( + Struct.newBuilder() + .putFields( + "external_value", Value.newBuilder().setStringValue("abc").build())) + .build()) + .build(); + SearchResponse response = getGrpcServer().getBlockingStub().search(request); + String expectedLogMessage = "LOGGED doc_id: 1, doc_id: 2, {external_value=abc}"; + + assertEquals(expectedLogMessage, HitsLoggerTest.logMessage); + assertEquals(10, response.getTotalHits().getValue()); + assertEquals(1, response.getHitsCount()); + } + + @Test + public void testResponseSizeReductionWithHitsToLogGreaterThanHitsCountAndTotalDocs() { + SearchRequest request = + SearchRequest.newBuilder() + .setTopHits(10) + .setStartHit(0) + .setIndexName(DEFAULT_TEST_INDEX) + .addRetrieveFields("doc_id") + .setQuery( + Query.newBuilder() + .setTermQuery( + TermQuery.newBuilder() + .setField("vendor_name") + .setTextValue("vendor") + .build()) + .build()) + .setLoggingHits( + LoggingHits.newBuilder().setName("custom_logger").setHitsToLog(15).build()) + .build(); + SearchResponse response = getGrpcServer().getBlockingStub().search(request); + String expectedLogMessage = + "LOGGED doc_id: 1, doc_id: 2, doc_id: 3, doc_id: 4, doc_id: 5, doc_id: 6, doc_id: 7, doc_id: 8, doc_id: 9, doc_id: 10, "; + + assertEquals(expectedLogMessage, HitsLoggerTest.logMessage); + assertEquals(10, response.getTotalHits().getValue()); + assertEquals(10, response.getHitsCount()); + } + + @Test + public void testResponseSizeReductionWithHitsToLogLessThanHitsCount() { + SearchRequest request = + SearchRequest.newBuilder() + .setTopHits(5) + .setStartHit(0) + .setIndexName(DEFAULT_TEST_INDEX) + .addRetrieveFields("doc_id") + .setQuery( + Query.newBuilder() + .setTermQuery( + TermQuery.newBuilder() + .setField("vendor_name") + .setTextValue("vendor") + .build()) + .build()) + .setLoggingHits( + LoggingHits.newBuilder().setName("custom_logger").setHitsToLog(3).build()) + .build(); + SearchResponse response = getGrpcServer().getBlockingStub().search(request); + String expectedLogMessage = "LOGGED doc_id: 1, doc_id: 2, doc_id: 3, "; + + assertEquals(expectedLogMessage, HitsLoggerTest.logMessage); + assertEquals(10, response.getTotalHits().getValue()); + assertEquals(5, response.getHitsCount()); + } + + @Test + public void testResponseSizeReductionWithStartHitAndHitsToLogSameAsHitsCount() { + SearchRequest request = + SearchRequest.newBuilder() + .setTopHits(10) + .setStartHit(5) + .setIndexName(DEFAULT_TEST_INDEX) + .addRetrieveFields("doc_id") + .setQuery( + Query.newBuilder() + .setTermQuery( + TermQuery.newBuilder() + .setField("vendor_name") + .setTextValue("vendor") + .build()) + .build()) + .setLoggingHits( + LoggingHits.newBuilder().setName("custom_logger").setHitsToLog(5).build()) + .build(); + SearchResponse response = getGrpcServer().getBlockingStub().search(request); + String expectedLogMessage = "LOGGED doc_id: 6, doc_id: 7, doc_id: 8, doc_id: 9, doc_id: 10, "; + + assertEquals(expectedLogMessage, HitsLoggerTest.logMessage); + assertEquals(10, response.getTotalHits().getValue()); + assertEquals(5, response.getHitsCount()); + } + + @Test + public void testResponseSizeReductionWithStartHitAndHitsToLogGreaterThanHitsCountAndTotalDocs() { + SearchRequest request = + SearchRequest.newBuilder() + .setTopHits(10) + .setStartHit(5) + .setIndexName(DEFAULT_TEST_INDEX) + .addRetrieveFields("doc_id") + .setQuery( + Query.newBuilder() + .setTermQuery( + TermQuery.newBuilder() + .setField("vendor_name") + .setTextValue("vendor") + .build()) + .build()) + .setLoggingHits( + LoggingHits.newBuilder().setName("custom_logger").setHitsToLog(6).build()) + .build(); + SearchResponse response = getGrpcServer().getBlockingStub().search(request); + String expectedLogMessage = "LOGGED doc_id: 6, doc_id: 7, doc_id: 8, doc_id: 9, doc_id: 10, "; + + assertEquals(expectedLogMessage, HitsLoggerTest.logMessage); + assertEquals(10, response.getTotalHits().getValue()); + assertEquals(5, response.getHitsCount()); + } + + @Test + public void testResponseSizeReductionWithStartHitAndHitsToLogGreaterThanHitsCount() { + SearchRequest request = + SearchRequest.newBuilder() + .setTopHits(9) + .setStartHit(4) + .setIndexName(DEFAULT_TEST_INDEX) + .addRetrieveFields("doc_id") + .setQuery( + Query.newBuilder() + .setTermQuery( + TermQuery.newBuilder() + .setField("vendor_name") + .setTextValue("vendor") + .build()) + .build()) + .setLoggingHits( + LoggingHits.newBuilder().setName("custom_logger").setHitsToLog(6).build()) + .build(); + SearchResponse response = getGrpcServer().getBlockingStub().search(request); + String expectedLogMessage = + "LOGGED doc_id: 5, doc_id: 6, doc_id: 7, doc_id: 8, doc_id: 9, doc_id: 10, "; + + assertEquals(expectedLogMessage, HitsLoggerTest.logMessage); + assertEquals(10, response.getTotalHits().getValue()); + assertEquals(5, response.getHitsCount()); + } + + @Test + public void testResponseSizeReductionWithStartHitAndHitsToLogLessThanHitsCount() { + SearchRequest request = + SearchRequest.newBuilder() + .setTopHits(10) + .setStartHit(5) + .setIndexName(DEFAULT_TEST_INDEX) + .addRetrieveFields("doc_id") + .setQuery( + Query.newBuilder() + .setTermQuery( + TermQuery.newBuilder() + .setField("vendor_name") + .setTextValue("vendor") + .build()) + .build()) + .setLoggingHits( + LoggingHits.newBuilder().setName("custom_logger").setHitsToLog(3).build()) .build(); SearchResponse response = getGrpcServer().getBlockingStub().search(request); - String expectedLogMessage = "LOGGED_2 " + List.of(response.getHits(0)); + String expectedLogMessage = "LOGGED doc_id: 6, doc_id: 7, doc_id: 8, "; assertEquals(expectedLogMessage, HitsLoggerTest.logMessage); + assertEquals(10, response.getTotalHits().getValue()); + assertEquals(5, response.getHitsCount()); } @Test diff --git a/src/test/java/com/yelp/nrtsearch/server/luceneserver/search/collectors/DocCollectorTest.java b/src/test/java/com/yelp/nrtsearch/server/luceneserver/search/collectors/DocCollectorTest.java index 3aba08b6e..f3ce40f42 100644 --- a/src/test/java/com/yelp/nrtsearch/server/luceneserver/search/collectors/DocCollectorTest.java +++ b/src/test/java/com/yelp/nrtsearch/server/luceneserver/search/collectors/DocCollectorTest.java @@ -20,6 +20,7 @@ import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.when; +import com.yelp.nrtsearch.server.grpc.LoggingHits; import com.yelp.nrtsearch.server.grpc.Rescorer; import com.yelp.nrtsearch.server.grpc.SearchRequest; import com.yelp.nrtsearch.server.grpc.SearchResponse.Hit.Builder; @@ -210,6 +211,25 @@ public void testNumHitsToCollect() { assertEquals(1000, docCollector.getNumHitsToCollect()); } + @Test + public void testNumHitsToCollectWithHitsToLog() { + SearchRequest.Builder builder = SearchRequest.newBuilder(); + builder.setTopHits(200); + builder.setLoggingHits(LoggingHits.newBuilder().setHitsToLog(300).build()); + TestDocCollector docCollector = new TestDocCollector(builder.build()); + assertEquals(300, docCollector.getNumHitsToCollect()); + } + + @Test + public void testNumHitsToCollectWithHitsToLogAndWindowSize() { + SearchRequest.Builder builder = SearchRequest.newBuilder(); + builder.setTopHits(200); + builder.addRescorers(Rescorer.newBuilder().setWindowSize(1000).build()); + builder.setLoggingHits(LoggingHits.newBuilder().setHitsToLog(300).build()); + TestDocCollector docCollector = new TestDocCollector(builder.build()); + assertEquals(1000, docCollector.getNumHitsToCollect()); + } + @Test public void testHasTerminateAfterWrapper() { SearchRequest request =