diff --git a/src/main/java/com/yelp/nrtsearch/server/search/collectors/DocCollector.java b/src/main/java/com/yelp/nrtsearch/server/search/collectors/DocCollector.java index 4960a906e..57e3fff79 100644 --- a/src/main/java/com/yelp/nrtsearch/server/search/collectors/DocCollector.java +++ b/src/main/java/com/yelp/nrtsearch/server/search/collectors/DocCollector.java @@ -70,7 +70,7 @@ public static int computeNumHitsToCollect(SearchRequest request) { // determine how many hits to collect based on request, facets, rescore window and hits to log int collectHits = request.getTopHits(); if (request.hasLoggingHits()) { - collectHits = request.getLoggingHits().getHitsToLog(); + collectHits = Math.max(collectHits, request.getLoggingHits().getHitsToLog()); } for (Facet facet : request.getFacetsList()) { int facetSample = facet.getSampleTopDocs(); diff --git a/src/test/java/com/yelp/nrtsearch/server/logging/HitsLoggerTest.java b/src/test/java/com/yelp/nrtsearch/server/logging/HitsLoggerTest.java index c7cb35999..c22721731 100644 --- a/src/test/java/com/yelp/nrtsearch/server/logging/HitsLoggerTest.java +++ b/src/test/java/com/yelp/nrtsearch/server/logging/HitsLoggerTest.java @@ -59,34 +59,29 @@ protected FieldDefRequest getIndexDef(String name) throws IOException { @Override protected void initIndex(String name) throws Exception { List docs = new ArrayList<>(); - AddDocumentRequest request = - AddDocumentRequest.newBuilder() - .setIndexName(name) - .putFields( - "doc_id", AddDocumentRequest.MultiValuedField.newBuilder().addValue("1").build()) - .putFields( - "vendor_name", - AddDocumentRequest.MultiValuedField.newBuilder().addValue("first vendor").build()) - .putFields( - "long_field", - AddDocumentRequest.MultiValuedField.newBuilder().addValue("5").build()) - .build(); - docs.add(request); - request = - AddDocumentRequest.newBuilder() - .setIndexName(name) - .putFields( - "doc_id", AddDocumentRequest.MultiValuedField.newBuilder().addValue("2").build()) - .putFields( - "vendor_name", - AddDocumentRequest.MultiValuedField.newBuilder() - .addValue("second vendor review") - .build()) - .putFields( - "long_field", - AddDocumentRequest.MultiValuedField.newBuilder().addValue("10").build()) - .build(); - docs.add(request); + + for (int docNum = 0; docNum < 10; docNum++) { + AddDocumentRequest request = + AddDocumentRequest.newBuilder() + .setIndexName(name) + .putFields( + "doc_id", + AddDocumentRequest.MultiValuedField.newBuilder() + .addValue(String.valueOf(docNum)) + .build()) + .putFields( + "vendor_name", + AddDocumentRequest.MultiValuedField.newBuilder() + .addValue("vendor " + docNum) + .build()) + .putFields( + "long_field", + AddDocumentRequest.MultiValuedField.newBuilder() + .addValue(String.valueOf(2 + docNum)) + .build()) + .build(); + docs.add(request); + } addDocuments(docs.stream()); } @@ -179,10 +174,10 @@ public void testCustomHitsLoggerWithParam() { .build()) .build(); SearchResponse response = getGrpcServer().getBlockingStub().search(request); - String expectedLogMessage = "LOGGED doc_id: 1, doc_id: 2, {external_value=abc}"; + String expectedLogMessage = "LOGGED doc_id: 0, doc_id: 1, {external_value=abc}"; assertEquals(expectedLogMessage, HitsLoggerTest.logMessage); - assertEquals(2, response.getTotalHits().getValue()); + assertEquals(10, response.getTotalHits().getValue()); assertEquals(2, response.getHitsCount()); } @@ -206,10 +201,10 @@ public void testCustomHitsLoggerWithoutParam() { LoggingHits.newBuilder().setName("custom_logger_2").setHitsToLog(2).build()) .build(); SearchResponse response = getGrpcServer().getBlockingStub().search(request); - String expectedLogMessage = "LOGGED_2 doc_id: 1, doc_id: 2, "; + String expectedLogMessage = "LOGGED_2 doc_id: 0, doc_id: 1, "; assertEquals(expectedLogMessage, HitsLoggerTest.logMessage); - assertEquals(2, response.getTotalHits().getValue()); + assertEquals(10, response.getTotalHits().getValue()); assertEquals(2, response.getHitsCount()); } @@ -240,10 +235,37 @@ public void testHitsLoggerResponseSizeReduction() { .build()) .build(); SearchResponse response = getGrpcServer().getBlockingStub().search(request); - String expectedLogMessage = "LOGGED doc_id: 1, doc_id: 2, {external_value=abc}"; + String expectedLogMessage = "LOGGED doc_id: 0, doc_id: 1, {external_value=abc}"; assertEquals(expectedLogMessage, HitsLoggerTest.logMessage); - assertEquals(2, response.getTotalHits().getValue()); + assertEquals(10, response.getTotalHits().getValue()); assertEquals(1, response.getHitsCount()); } + + @Test + public void testHitsLoggerResponseSizeReductionWithStartHitGreaterThanZero() { + SearchRequest request = + SearchRequest.newBuilder() + .setTopHits(10) + .setStartHit(5) + .setIndexName(DEFAULT_TEST_INDEX) + .addRetrieveFields("doc_id") + .setQuery( + Query.newBuilder() + .setTermQuery( + TermQuery.newBuilder() + .setField("vendor_name") + .setTextValue("vendor") + .build()) + .build()) + .setLoggingHits( + LoggingHits.newBuilder().setName("custom_logger").setHitsToLog(5).build()) + .build(); + SearchResponse response = getGrpcServer().getBlockingStub().search(request); + String expectedLogMessage = "LOGGED doc_id: 5, doc_id: 6, doc_id: 7, doc_id: 8, doc_id: 9, "; + + assertEquals(expectedLogMessage, HitsLoggerTest.logMessage); + assertEquals(10, response.getTotalHits().getValue()); + assertEquals(5, response.getHitsCount()); + } }