diff --git a/src/main/kotlin/org/wfanet/measurement/api/v2alpha/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/api/v2alpha/BUILD.bazel index 45bc9a2576f..e53b6c87c1c 100644 --- a/src/main/kotlin/org/wfanet/measurement/api/v2alpha/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/api/v2alpha/BUILD.bazel @@ -52,7 +52,7 @@ kt_jvm_library( deps = [ "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:context_keys", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:measurement_principal", - "//src/main/kotlin/org/wfanet/measurement/common/api/grpc", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:akid_principal_server_interceptor", "//src/main/kotlin/org/wfanet/measurement/common/grpc:context", "//src/main/kotlin/org/wfanet/measurement/common/identity", "//src/main/proto/wfa/measurement/api/v2alpha:duchy_kt_jvm_proto", diff --git a/src/main/kotlin/org/wfanet/measurement/common/api/grpc/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/common/api/grpc/BUILD.bazel index 3a897c0ad81..4ffd63ed7ab 100644 --- a/src/main/kotlin/org/wfanet/measurement/common/api/grpc/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/common/api/grpc/BUILD.bazel @@ -3,7 +3,7 @@ load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") package(default_visibility = ["//visibility:public"]) kt_jvm_library( - name = "grpc", + name = "akid_principal_server_interceptor", srcs = ["AkidPrincipalServerInterceptor.kt"], deps = [ "//src/main/kotlin/org/wfanet/measurement/common/api:principal", @@ -14,3 +14,13 @@ kt_jvm_library( "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", ], ) + +kt_jvm_library( + name = "list_resources", + srcs = ["ListResources.kt"], + deps = [ + "@wfa_common_jvm//imports/java/com/google/protobuf", + "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", + "@wfa_rules_kotlin_jvm//imports/io/gprc/kotlin:stub", + ], +) diff --git a/src/main/kotlin/org/wfanet/measurement/common/api/grpc/ListResources.kt b/src/main/kotlin/org/wfanet/measurement/common/api/grpc/ListResources.kt new file mode 100644 index 00000000000..c6aa2fb13f6 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/common/api/grpc/ListResources.kt @@ -0,0 +1,85 @@ +/* + * Copyright 2024 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.common.api.grpc + +import com.google.protobuf.Message +import io.grpc.kotlin.AbstractCoroutineStub +import kotlin.coroutines.coroutineContext +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.ensureActive +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.asFlow +import kotlinx.coroutines.flow.flattenConcat +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.map + +data class ResourceList(val resources: List, val nextPageToken: String) : + List by resources + +/** + * Lists resources from this stub, handling pagination. + * + * @param pageToken page token for initial request + * @param list function which calls the appropriate List method on the stub + */ +fun > S.listResources( + pageToken: String = "", + list: suspend S.(pageToken: String) -> ResourceList, +): Flow> = + listResources(Int.MAX_VALUE, pageToken) { nextPageToken, _ -> list(nextPageToken) } + +/** + * Lists resources from this stub, handling pagination. + * + * @param limit maximum number of resources to emit + * @param pageToken page token for initial request + * @param list function which calls the appropriate List method on the stub, returning no more than + * the specified remaining number of resources + */ +fun > S.listResources( + limit: Int, + pageToken: String = "", + list: suspend S.(pageToken: String, remaining: Int) -> ResourceList, +): Flow> { + require(limit > 0) { "limit must be positive" } + return flow { + var remaining: Int = limit + var nextPageToken = pageToken + + while (true) { + coroutineContext.ensureActive() + + val resourceList: ResourceList = list(nextPageToken, remaining) + require(resourceList.size <= remaining) { + "List call must ensure that limit is not exceeded. " + + "Returned ${resourceList.size} items when only $remaining were remaining" + } + emit(resourceList) + + remaining -= resourceList.size + nextPageToken = resourceList.nextPageToken + if (nextPageToken.isEmpty() || remaining == 0) { + break + } + } + } +} + +/** @see [flattenConcat] */ +@ExperimentalCoroutinesApi // Overloads experimental `flattenConcat` function. +fun Flow>.flattenConcat(): Flow = + map { it.asFlow() }.flattenConcat() diff --git a/src/main/kotlin/org/wfanet/measurement/integration/common/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/integration/common/BUILD.bazel index 981e656f91e..2361f2b15bf 100644 --- a/src/main/kotlin/org/wfanet/measurement/integration/common/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/integration/common/BUILD.bazel @@ -280,6 +280,7 @@ kt_jvm_library( ], deps = [ ":in_process_cmms_components", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:list_resources", "//src/main/kotlin/org/wfanet/measurement/kingdom/batch:measurement_system_prober", "//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common/service:data_services", "@wfa_common_jvm//imports/java/com/google/common/truth", diff --git a/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessMeasurementSystemProberIntegrationTest.kt b/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessMeasurementSystemProberIntegrationTest.kt index 0a015ca80d3..bb6b17e4e1e 100644 --- a/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessMeasurementSystemProberIntegrationTest.kt +++ b/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessMeasurementSystemProberIntegrationTest.kt @@ -22,6 +22,8 @@ import java.io.File import java.nio.file.Paths import java.time.Clock import java.time.Duration +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.flow.toList import kotlinx.coroutines.runBlocking import org.junit.After import org.junit.Before @@ -30,13 +32,15 @@ import org.junit.Rule import org.junit.Test import org.wfanet.measurement.api.v2alpha.DataProvidersGrpcKt.DataProvidersCoroutineStub import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub -import org.wfanet.measurement.api.v2alpha.ListMeasurementsResponse import org.wfanet.measurement.api.v2alpha.Measurement import org.wfanet.measurement.api.v2alpha.MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineStub import org.wfanet.measurement.api.v2alpha.MeasurementsGrpcKt.MeasurementsCoroutineStub import org.wfanet.measurement.api.v2alpha.RequisitionsGrpcKt.RequisitionsCoroutineStub import org.wfanet.measurement.api.v2alpha.listMeasurementsRequest import org.wfanet.measurement.api.withAuthenticationKey +import org.wfanet.measurement.common.api.grpc.ResourceList +import org.wfanet.measurement.common.api.grpc.flattenConcat +import org.wfanet.measurement.common.api.grpc.listResources import org.wfanet.measurement.common.getRuntimePath import org.wfanet.measurement.common.identity.withPrincipalName import org.wfanet.measurement.common.testing.ProviderRule @@ -129,33 +133,32 @@ abstract class InProcessMeasurementSystemProberIntegrationTest( assertThat(measurements.size).isEqualTo(1) } + @OptIn(ExperimentalCoroutinesApi::class) // For `flattenConcat`. private suspend fun listMeasurements(): List { - var nextPageToken = "" val measurementConsumerData = inProcessCmmsComponents.getMeasurementConsumerData() - do { - val response: ListMeasurementsResponse = - try { - publicMeasurementsClient - .withAuthenticationKey(measurementConsumerData.apiAuthenticationKey) - .listMeasurements( + val measurementLists = + publicMeasurementsClient + .withAuthenticationKey(measurementConsumerData.apiAuthenticationKey) + .listResources { pageToken -> + val response = + listMeasurements( listMeasurementsRequest { parent = measurementConsumerData.name - pageToken = nextPageToken + this.pageToken = pageToken } ) - } catch (e: StatusException) { - throw Exception( - "Unable to list measurements for measurement consumer ${measurementConsumerData.name}", - e, - ) + ResourceList(response.measurementsList, response.nextPageToken) } - if (response.measurementsList.isNotEmpty()) { - return response.measurementsList - } - nextPageToken = response.nextPageToken - } while (nextPageToken.isNotEmpty()) - return emptyList() + + return try { + measurementLists.flattenConcat().toList() + } catch (e: StatusException) { + throw Exception( + "Unable to list measurements for measurement consumer ${measurementConsumerData.name}", + e, + ) + } } companion object { diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/batch/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/kingdom/batch/BUILD.bazel index 12d1d96644d..2d455bdfd32 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/batch/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/batch/BUILD.bazel @@ -47,6 +47,7 @@ kt_jvm_library( "//src/main/kotlin/org/wfanet/measurement/api:api_key_constants", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:packed_messages", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:list_resources", "//src/main/proto/wfa/measurement/api/v2alpha:data_provider_kt_jvm_proto", "//src/main/proto/wfa/measurement/api/v2alpha:data_providers_service_kt_jvm_grpc_proto", "//src/main/proto/wfa/measurement/api/v2alpha:event_group_kt_jvm_proto", diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/batch/MeasurementSystemProber.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/batch/MeasurementSystemProber.kt index b64a3a55842..2aa6a1e413b 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/batch/MeasurementSystemProber.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/batch/MeasurementSystemProber.kt @@ -27,6 +27,10 @@ import java.security.SecureRandom import java.time.Clock import java.time.Duration import java.util.logging.Logger +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.catch +import kotlinx.coroutines.flow.singleOrNull import org.wfanet.measurement.api.v2alpha.CanonicalRequisitionKey import org.wfanet.measurement.api.v2alpha.DataProvider import org.wfanet.measurement.api.v2alpha.DataProvidersGrpcKt @@ -60,6 +64,9 @@ import org.wfanet.measurement.api.v2alpha.requisitionSpec import org.wfanet.measurement.api.v2alpha.unpack import org.wfanet.measurement.api.withAuthenticationKey import org.wfanet.measurement.common.Instrumentation +import org.wfanet.measurement.common.api.grpc.ResourceList +import org.wfanet.measurement.common.api.grpc.flattenConcat +import org.wfanet.measurement.common.api.grpc.listResources import org.wfanet.measurement.common.crypto.Hashing import org.wfanet.measurement.common.crypto.SigningKeyHandle import org.wfanet.measurement.common.crypto.readCertificate @@ -253,55 +260,50 @@ class MeasurementSystemProber( return clock.instant() >= nextMeasurementEarliestInstant } + @OptIn(ExperimentalCoroutinesApi::class) // For `flattenConcat`. private suspend fun getLastUpdatedMeasurement(): Measurement? { - var nextPageToken = "" - do { - val response: ListMeasurementsResponse = - try { - measurementsStub - .withAuthenticationKey(apiAuthenticationKey) - .listMeasurements( - listMeasurementsRequest { - parent = measurementConsumerName - this.pageSize = 1 - pageToken = nextPageToken - } - ) - } catch (e: StatusException) { - throw Exception( - "Unable to list measurements for measurement consumer $measurementConsumerName", - e, + val measurements: Flow> = + measurementsStub.withAuthenticationKey(apiAuthenticationKey).listResources(1) { + pageToken, + remaining -> + val response: ListMeasurementsResponse = + listMeasurements( + listMeasurementsRequest { + parent = measurementConsumerName + this.pageToken = pageToken + this.pageSize = remaining + } ) - } - if (response.measurementsList.isNotEmpty()) { - return response.measurementsList.single() + ResourceList(response.measurementsList, response.nextPageToken) } - nextPageToken = response.nextPageToken - } while (nextPageToken.isNotEmpty()) - return null + + return try { + measurements.flattenConcat().singleOrNull() + } catch (e: StatusException) { + throw Exception( + "Unable to list measurements for measurement consumer $measurementConsumerName", + e, + ) + } } - private suspend fun getRequisitionsForMeasurement(measurementName: String): List { - var nextPageToken = "" - val requisitions = mutableListOf() - do { - val response: ListRequisitionsResponse = - try { - requisitionsStub - .withAuthenticationKey(apiAuthenticationKey) - .listRequisitions( - listRequisitionsRequest { - parent = measurementName - pageToken = nextPageToken - } - ) - } catch (e: StatusException) { + @OptIn(ExperimentalCoroutinesApi::class) // For `flattenConcat`. + private fun getRequisitionsForMeasurement(measurementName: String): Flow { + return requisitionsStub + .withAuthenticationKey(apiAuthenticationKey) + .listResources { pageToken -> + val response: ListRequisitionsResponse = + listRequisitions(listRequisitionsRequest { this.pageToken = pageToken }) + ResourceList(response.requisitionsList, response.nextPageToken) + } + .catch { e -> + if (e is StatusException) { throw Exception("Unable to list requisitions for measurement $measurementName", e) + } else { + throw e } - requisitions.addAll(response.requisitionsList) - nextPageToken = response.nextPageToken - } while (nextPageToken.isNotEmpty()) - return requisitions + } + .flattenConcat() } private suspend fun getDataProviderEntry( @@ -360,10 +362,12 @@ class MeasurementSystemProber( private suspend fun updateLastTerminalRequisitionGauge(lastUpdatedMeasurement: Measurement) { val requisitions = getRequisitionsForMeasurement(lastUpdatedMeasurement.name) - for (requisition in requisitions) { + requisitions.collect { requisition -> if (requisition.state == Requisition.State.FULFILLED) { - val requisitionKey = CanonicalRequisitionKey.fromName(requisition.name) - require(requisitionKey != null) { "CanonicalRequisitionKey cannot be null" } + val requisitionKey = + requireNotNull(CanonicalRequisitionKey.fromName(requisition.name)) { + "Requisition name ${requisition.name} is invalid" + } val dataProviderName: String = requisitionKey.dataProviderId val attributes = Attributes.of(DATA_PROVIDER_ATTRIBUTE_KEY, dataProviderName) lastTerminalRequisitionTimeGauge.set( diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/BUILD.bazel index 445524ac557..3cdded1d308 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/BUILD.bazel @@ -132,7 +132,7 @@ kt_jvm_library( "context_keys", ":reporting_principal", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key", - "//src/main/kotlin/org/wfanet/measurement/common/api/grpc", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:akid_principal_server_interceptor", "//src/main/kotlin/org/wfanet/measurement/common/identity", "@wfa_common_jvm//imports/java/com/google/protobuf", "@wfa_common_jvm//imports/java/io/grpc:api", diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/BUILD.bazel index 22d6343b537..fc750cdc976 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/BUILD.bazel @@ -52,7 +52,7 @@ kt_jvm_library( "context_keys", ":reporting_principal", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key", - "//src/main/kotlin/org/wfanet/measurement/common/api/grpc", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:akid_principal_server_interceptor", "//src/main/kotlin/org/wfanet/measurement/common/identity", "@wfa_common_jvm//imports/java/com/google/protobuf", "@wfa_common_jvm//imports/java/io/grpc:api", @@ -120,6 +120,7 @@ kt_jvm_library( "//imports/java/org/projectnessie/cel", "//src/main/kotlin/org/wfanet/measurement/api:api_key_constants", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:packed_messages", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:list_resources", "//src/main/kotlin/org/wfanet/measurement/reporting/service/api:cel_env_provider", "//src/main/kotlin/org/wfanet/measurement/reporting/service/api:encryption_key_pair_store", "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha:principal_server_interceptor", diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsService.kt index 825e90c2bff..d957eb37c86 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsService.kt @@ -20,10 +20,13 @@ import com.google.protobuf.DynamicMessage import com.google.protobuf.kotlin.unpack import io.grpc.Context import io.grpc.Deadline +import io.grpc.Deadline.Ticker import io.grpc.Status import io.grpc.StatusException import java.security.GeneralSecurityException import java.util.concurrent.TimeUnit +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.transformWhile import org.projectnessie.cel.common.types.Err import org.projectnessie.cel.common.types.ref.Val import org.wfanet.measurement.api.v2alpha.DataProviderKey @@ -31,9 +34,12 @@ import org.wfanet.measurement.api.v2alpha.EncryptionPublicKey import org.wfanet.measurement.api.v2alpha.EventGroup as CmmsEventGroup import org.wfanet.measurement.api.v2alpha.EventGroupKey as CmmsEventGroupKey import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub as CmmsEventGroupsCoroutineStub +import org.wfanet.measurement.api.v2alpha.ListEventGroupsResponse as CmmsListEventGroupsResponse import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey import org.wfanet.measurement.api.v2alpha.listEventGroupsRequest import org.wfanet.measurement.api.withAuthenticationKey +import org.wfanet.measurement.common.api.grpc.ResourceList +import org.wfanet.measurement.common.api.grpc.listResources import org.wfanet.measurement.common.crypto.PrivateKeyHandle import org.wfanet.measurement.common.grpc.grpcRequire import org.wfanet.measurement.common.grpc.grpcRequireNotNull @@ -52,6 +58,7 @@ class EventGroupsService( private val cmmsEventGroupsStub: CmmsEventGroupsCoroutineStub, private val encryptionKeyPairStore: EncryptionKeyPairStore, private val celEnvProvider: CelEnvProvider, + private val ticker: Ticker = Deadline.getSystemTicker(), ) : EventGroupsCoroutineImplBase() { override suspend fun listEventGroups(request: ListEventGroupsRequest): ListEventGroupsResponse { val parentKey = @@ -71,66 +78,63 @@ class EventGroupsService( } } + val deadline: Deadline = + Context.current().deadline + ?: Deadline.after(RPC_DEFAULT_DEADLINE_MILLIS, TimeUnit.MILLISECONDS, ticker) val apiAuthenticationKey: String = principal.config.apiKey grpcRequire(request.pageSize >= 0) { "page_size cannot be negative" } - val pageSize = - when { - request.pageSize < MIN_PAGE_SIZE -> DEFAULT_PAGE_SIZE - request.pageSize > MAX_PAGE_SIZE -> MAX_PAGE_SIZE - else -> request.pageSize - } + val limit = + if (request.pageSize > 0) request.pageSize.coerceAtMost(MAX_PAGE_SIZE) else DEFAULT_PAGE_SIZE + val parent = parentKey.toName() + val eventGroupLists: Flow> = + cmmsEventGroupsStub.withAuthenticationKey(apiAuthenticationKey).listResources( + limit, + request.pageToken, + ) { pageToken, remaining -> + val response: CmmsListEventGroupsResponse = + listEventGroups( + listEventGroupsRequest { + this.parent = parent + this.pageSize = remaining + this.pageToken = pageToken + } + ) - var nextPageToken = request.pageToken - val deadline = Context.current().deadline ?: Deadline.after(30, TimeUnit.SECONDS) - do { - val cmmsListEventGroupResponse = - try { - cmmsEventGroupsStub - .withAuthenticationKey(apiAuthenticationKey) - .listEventGroups( - listEventGroupsRequest { - parent = parentKey.toName() - this.pageSize = pageSize - pageToken = nextPageToken + val eventGroups: List = + response.eventGroupsList.map { + val cmmsMetadata: CmmsEventGroup.Metadata? = + if (it.hasEncryptedMetadata()) { + decryptMetadata(it, principal.resourceKey.toName()) + } else { + null } - ) - } catch (e: StatusException) { - throw when (e.status.code) { - Status.Code.DEADLINE_EXCEEDED -> Status.DEADLINE_EXCEEDED - Status.Code.CANCELLED -> Status.CANCELLED - else -> Status.UNKNOWN - } - .withCause(e) - .asRuntimeException() - } - val cmmsEventGroups = cmmsListEventGroupResponse.eventGroupsList - val eventGroups = - cmmsEventGroups.map { - val cmmsMetadata: CmmsEventGroup.Metadata? = - if (it.hasEncryptedMetadata()) { - decryptMetadata(it, principal.resourceKey.toName()) - } else { - null - } + it.toEventGroup(cmmsMetadata) + } - it.toEventGroup(cmmsMetadata) - } + ResourceList(filterEventGroups(eventGroups, request.filter), response.nextPageToken) + } - val filteredEventGroups = filterEventGroups(eventGroups, request.filter) - if (filteredEventGroups.size > 0) { - return listEventGroupsResponse { - this.eventGroups += filteredEventGroups - this.nextPageToken = cmmsListEventGroupResponse.nextPageToken + return listEventGroupsResponse { + try { + eventGroupLists + .transformWhile { + emit(it) + deadline.timeRemaining(TimeUnit.MILLISECONDS) > RPC_DEADLINE_OVERHEAD_MILLIS + } + .collect { eventGroupList -> + this.eventGroups += eventGroupList + nextPageToken = eventGroupList.nextPageToken + } + } catch (e: StatusException) { + when (e.status.code) { + Status.Code.DEADLINE_EXCEEDED -> {} + else -> Status.UNKNOWN.withCause(e).asRuntimeException() } - } else { - nextPageToken = cmmsListEventGroupResponse.nextPageToken } - } while (deadline.timeRemaining(TimeUnit.SECONDS) > 5) - - return listEventGroupsResponse { this.nextPageToken = nextPageToken } + } } private suspend fun filterEventGroups( @@ -259,8 +263,13 @@ class EventGroupsService( companion object { private const val METADATA_FIELD = "metadata.metadata" - private const val MIN_PAGE_SIZE = 1 private const val DEFAULT_PAGE_SIZE = 50 private const val MAX_PAGE_SIZE = 1000 + + /** Overhead to allow for RPC deadlines in milliseconds. */ + private const val RPC_DEADLINE_OVERHEAD_MILLIS = 100L + + /** Default RPC deadline in milliseconds. */ + private const val RPC_DEFAULT_DEADLINE_MILLIS = 30_000L } } diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsServiceTest.kt index 913f7987731..410177e1953 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsServiceTest.kt @@ -19,13 +19,16 @@ package org.wfanet.measurement.reporting.service.api.v2alpha import com.google.common.truth.Truth.assertThat import com.google.common.truth.extensions.proto.ProtoTruth.assertThat import com.google.protobuf.Any +import io.grpc.Deadline import io.grpc.Status import io.grpc.StatusRuntimeException import java.nio.file.Path import java.nio.file.Paths import java.time.Duration +import java.util.concurrent.TimeUnit import kotlin.test.assertFailsWith import kotlinx.coroutines.runBlocking +import org.junit.After import org.junit.Before import org.junit.Rule import org.junit.Test @@ -105,11 +108,13 @@ class EventGroupsServiceTest { addService(publicKingdomEventGroupMetadataDescriptorsMock) } + private lateinit var celEnvCacheProvider: CelEnvCacheProvider private lateinit var service: EventGroupsService + private val fakeTicker = SettableSystemTicker() @Before fun initService() { - val celEnvCacheProvider = + celEnvCacheProvider = CelEnvCacheProvider( EventGroupMetadataDescriptorsCoroutineStub(grpcTestServerRule.channel), EventGroup.getDescriptor(), @@ -122,9 +127,15 @@ class EventGroupsServiceTest { EventGroupsCoroutineStub(grpcTestServerRule.channel), ENCRYPTION_KEY_PAIR_STORE, celEnvCacheProvider, + fakeTicker, ) } + @After + fun closeCelEnvCacheProvider() { + celEnvCacheProvider.close() + } + @Test fun `listEventGroups returns events groups after multiple calls to kingdom`() = runBlocking { val testMessage = testMetadataMessage { publisherId = 5 } @@ -143,13 +154,13 @@ class EventGroupsServiceTest { whenever(publicKingdomEventGroupsMock.listEventGroups(any())) .thenReturn( listEventGroupsResponse { - nextPageToken = "1" eventGroups += cmmsEventGroup2 + nextPageToken = "1" } ) .thenReturn( listEventGroupsResponse { - eventGroups += listOf(CMMS_EVENT_GROUP, cmmsEventGroup2) + eventGroups += CMMS_EVENT_GROUP nextPageToken = "2" } ) @@ -161,6 +172,7 @@ class EventGroupsServiceTest { listEventGroupsRequest { parent = MEASUREMENT_CONSUMER_NAME filter = "metadata.metadata.publisher_id > 5" + pageSize = 1 } ) } @@ -191,7 +203,10 @@ class EventGroupsServiceTest { eventGroups += CMMS_EVENT_GROUP } ) - + .then { + // Advance time. + fakeTicker.setNanoTime(fakeTicker.nanoTime() + TimeUnit.SECONDS.toNanos(30)) + } val response = withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { runBlocking { @@ -364,7 +379,12 @@ class EventGroupsServiceTest { val response = withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { runBlocking { - service.listEventGroups(listEventGroupsRequest { parent = MEASUREMENT_CONSUMER_NAME }) + service.listEventGroups( + listEventGroupsRequest { + parent = MEASUREMENT_CONSUMER_NAME + pageSize = 2 + } + ) } } @@ -376,7 +396,7 @@ class EventGroupsServiceTest { .isEqualTo( cmmsListEventGroupsRequest { parent = MEASUREMENT_CONSUMER_NAME - pageSize = DEFAULT_PAGE_SIZE + pageSize = 2 } ) } @@ -593,14 +613,6 @@ class EventGroupsServiceTest { @Test fun `listEventGroups throws FAILED_PRECONDITION when store doesn't have private key`() { - val celEnvCacheProvider = - CelEnvCacheProvider( - EventGroupMetadataDescriptorsCoroutineStub(grpcTestServerRule.channel), - EventGroup.getDescriptor(), - Duration.ofSeconds(5), - emptyList(), - ) - service = EventGroupsService( EventGroupsCoroutineStub(grpcTestServerRule.channel), @@ -742,4 +754,20 @@ class EventGroupsServiceTest { } } } + + /** + * Fake [Deadline.Ticker] implementation that allows time to be specified to override delegation + * to the system ticker. + */ + private class SettableSystemTicker : Deadline.Ticker() { + private var nanoTime: Long? = null + + fun setNanoTime(value: Long) { + nanoTime = value + } + + override fun nanoTime(): Long { + return this.nanoTime ?: Deadline.getSystemTicker().nanoTime() + } + } }