diff --git a/src/main/kotlin/org/wfanet/measurement/api/v2alpha/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/api/v2alpha/BUILD.bazel index 45bc9a2576f..e53b6c87c1c 100644 --- a/src/main/kotlin/org/wfanet/measurement/api/v2alpha/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/api/v2alpha/BUILD.bazel @@ -52,7 +52,7 @@ kt_jvm_library( deps = [ "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:context_keys", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:measurement_principal", - "//src/main/kotlin/org/wfanet/measurement/common/api/grpc", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:akid_principal_server_interceptor", "//src/main/kotlin/org/wfanet/measurement/common/grpc:context", "//src/main/kotlin/org/wfanet/measurement/common/identity", "//src/main/proto/wfa/measurement/api/v2alpha:duchy_kt_jvm_proto", diff --git a/src/main/kotlin/org/wfanet/measurement/common/api/grpc/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/common/api/grpc/BUILD.bazel index 3a897c0ad81..4ffd63ed7ab 100644 --- a/src/main/kotlin/org/wfanet/measurement/common/api/grpc/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/common/api/grpc/BUILD.bazel @@ -3,7 +3,7 @@ load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") package(default_visibility = ["//visibility:public"]) kt_jvm_library( - name = "grpc", + name = "akid_principal_server_interceptor", srcs = ["AkidPrincipalServerInterceptor.kt"], deps = [ "//src/main/kotlin/org/wfanet/measurement/common/api:principal", @@ -14,3 +14,13 @@ kt_jvm_library( "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", ], ) + +kt_jvm_library( + name = "list_resources", + srcs = ["ListResources.kt"], + deps = [ + "@wfa_common_jvm//imports/java/com/google/protobuf", + "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", + "@wfa_rules_kotlin_jvm//imports/io/gprc/kotlin:stub", + ], +) diff --git a/src/main/kotlin/org/wfanet/measurement/common/api/grpc/ListResources.kt b/src/main/kotlin/org/wfanet/measurement/common/api/grpc/ListResources.kt new file mode 100644 index 00000000000..9e19f98043d --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/common/api/grpc/ListResources.kt @@ -0,0 +1,92 @@ +/* + * Copyright 2024 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.common.api.grpc + +import com.google.protobuf.Message +import io.grpc.kotlin.AbstractCoroutineStub +import kotlin.coroutines.coroutineContext +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.ensureActive +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.asFlow +import kotlinx.coroutines.flow.flattenConcat +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.map + +/** A [List] of resources from a paginated List method. */ +data class ResourceList( + val resources: List, + /** + * A token that can be sent on subsequent requests to retrieve the next page. If non-empty, there + * are no subsequent pages. + */ + val nextPageToken: String, +) : List by resources + +/** + * Lists resources from a paginated List method on this stub. + * + * @param pageToken page token for initial request + * @param list function which calls the appropriate List method on the stub + */ +fun > S.listResources( + pageToken: String = "", + list: suspend S.(pageToken: String) -> ResourceList, +): Flow> = + listResources(Int.MAX_VALUE, pageToken) { nextPageToken, _ -> list(nextPageToken) } + +/** + * Lists resources from a paginated List method on this stub. + * + * @param limit maximum number of resources to emit + * @param pageToken page token for initial request + * @param list function which calls the appropriate List method on the stub, returning no more than + * the specified remaining number of resources + */ +fun > S.listResources( + limit: Int, + pageToken: String = "", + list: suspend S.(pageToken: String, remaining: Int) -> ResourceList, +): Flow> { + require(limit > 0) { "limit must be positive" } + return flow { + var remaining: Int = limit + var nextPageToken = pageToken + + while (true) { + coroutineContext.ensureActive() + + val resourceList: ResourceList = list(nextPageToken, remaining) + require(resourceList.size <= remaining) { + "List call must ensure that limit is not exceeded. " + + "Returned ${resourceList.size} items when only $remaining were remaining" + } + emit(resourceList) + + remaining -= resourceList.size + nextPageToken = resourceList.nextPageToken + if (nextPageToken.isEmpty() || remaining == 0) { + break + } + } + } +} + +/** @see [flattenConcat] */ +@ExperimentalCoroutinesApi // Overloads experimental `flattenConcat` function. +fun Flow>.flattenConcat(): Flow = + map { it.asFlow() }.flattenConcat() diff --git a/src/main/kotlin/org/wfanet/measurement/integration/common/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/integration/common/BUILD.bazel index f47074b4ae3..8764c226d65 100644 --- a/src/main/kotlin/org/wfanet/measurement/integration/common/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/integration/common/BUILD.bazel @@ -319,6 +319,7 @@ kt_jvm_library( ], deps = [ ":in_process_cmms_components", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:list_resources", "//src/main/kotlin/org/wfanet/measurement/kingdom/batch:measurement_system_prober", "//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common/service:data_services", "@wfa_common_jvm//imports/java/com/google/common/truth", diff --git a/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessMeasurementSystemProberIntegrationTest.kt b/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessMeasurementSystemProberIntegrationTest.kt index 0a015ca80d3..d9a14f29152 100644 --- a/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessMeasurementSystemProberIntegrationTest.kt +++ b/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessMeasurementSystemProberIntegrationTest.kt @@ -22,6 +22,9 @@ import java.io.File import java.nio.file.Paths import java.time.Clock import java.time.Duration +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.toList import kotlinx.coroutines.runBlocking import org.junit.After import org.junit.Before @@ -30,13 +33,15 @@ import org.junit.Rule import org.junit.Test import org.wfanet.measurement.api.v2alpha.DataProvidersGrpcKt.DataProvidersCoroutineStub import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub -import org.wfanet.measurement.api.v2alpha.ListMeasurementsResponse import org.wfanet.measurement.api.v2alpha.Measurement import org.wfanet.measurement.api.v2alpha.MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineStub import org.wfanet.measurement.api.v2alpha.MeasurementsGrpcKt.MeasurementsCoroutineStub import org.wfanet.measurement.api.v2alpha.RequisitionsGrpcKt.RequisitionsCoroutineStub import org.wfanet.measurement.api.v2alpha.listMeasurementsRequest import org.wfanet.measurement.api.withAuthenticationKey +import org.wfanet.measurement.common.api.grpc.ResourceList +import org.wfanet.measurement.common.api.grpc.flattenConcat +import org.wfanet.measurement.common.api.grpc.listResources import org.wfanet.measurement.common.getRuntimePath import org.wfanet.measurement.common.identity.withPrincipalName import org.wfanet.measurement.common.testing.ProviderRule @@ -129,33 +134,32 @@ abstract class InProcessMeasurementSystemProberIntegrationTest( assertThat(measurements.size).isEqualTo(1) } + @OptIn(ExperimentalCoroutinesApi::class) // For `flattenConcat`. private suspend fun listMeasurements(): List { - var nextPageToken = "" val measurementConsumerData = inProcessCmmsComponents.getMeasurementConsumerData() - do { - val response: ListMeasurementsResponse = - try { - publicMeasurementsClient - .withAuthenticationKey(measurementConsumerData.apiAuthenticationKey) - .listMeasurements( - listMeasurementsRequest { - parent = measurementConsumerData.name - pageToken = nextPageToken - } - ) - } catch (e: StatusException) { - throw Exception( - "Unable to list measurements for measurement consumer ${measurementConsumerData.name}", - e, - ) + val measurementLists: Flow> = + publicMeasurementsClient + .withAuthenticationKey(measurementConsumerData.apiAuthenticationKey) + .listResources { pageToken -> + val response = + try { + listMeasurements( + listMeasurementsRequest { + parent = measurementConsumerData.name + this.pageToken = pageToken + } + ) + } catch (e: StatusException) { + throw Exception( + "Unable to list measurements for measurement consumer ${measurementConsumerData.name}", + e, + ) + } + ResourceList(response.measurementsList, response.nextPageToken) } - if (response.measurementsList.isNotEmpty()) { - return response.measurementsList - } - nextPageToken = response.nextPageToken - } while (nextPageToken.isNotEmpty()) - return emptyList() + + return measurementLists.flattenConcat().toList() } companion object { diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/batch/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/kingdom/batch/BUILD.bazel index 12d1d96644d..2d455bdfd32 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/batch/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/batch/BUILD.bazel @@ -47,6 +47,7 @@ kt_jvm_library( "//src/main/kotlin/org/wfanet/measurement/api:api_key_constants", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:packed_messages", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:list_resources", "//src/main/proto/wfa/measurement/api/v2alpha:data_provider_kt_jvm_proto", "//src/main/proto/wfa/measurement/api/v2alpha:data_providers_service_kt_jvm_grpc_proto", "//src/main/proto/wfa/measurement/api/v2alpha:event_group_kt_jvm_proto", diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/batch/MeasurementSystemProber.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/batch/MeasurementSystemProber.kt index b64a3a55842..b0e6baa0215 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/batch/MeasurementSystemProber.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/batch/MeasurementSystemProber.kt @@ -27,6 +27,11 @@ import java.security.SecureRandom import java.time.Clock import java.time.Duration import java.util.logging.Logger +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.single +import kotlinx.coroutines.flow.singleOrNull import org.wfanet.measurement.api.v2alpha.CanonicalRequisitionKey import org.wfanet.measurement.api.v2alpha.DataProvider import org.wfanet.measurement.api.v2alpha.DataProvidersGrpcKt @@ -60,6 +65,9 @@ import org.wfanet.measurement.api.v2alpha.requisitionSpec import org.wfanet.measurement.api.v2alpha.unpack import org.wfanet.measurement.api.withAuthenticationKey import org.wfanet.measurement.common.Instrumentation +import org.wfanet.measurement.common.api.grpc.ResourceList +import org.wfanet.measurement.common.api.grpc.flattenConcat +import org.wfanet.measurement.common.api.grpc.listResources import org.wfanet.measurement.common.crypto.Hashing import org.wfanet.measurement.common.crypto.SigningKeyHandle import org.wfanet.measurement.common.crypto.readCertificate @@ -208,33 +216,31 @@ class MeasurementSystemProber( throw Exception("Unable to get DataProvider with name $dataProviderName", e) } - // TODO(@roaminggypsy): Implement QA event group logic using simulatorEventGroupName - val listEventGroupsRequest = listEventGroupsRequest { - parent = measurementConsumerName - filter = ListEventGroupsRequestKt.filter { dataProviders += dataProviderName } - } - - val eventGroups: List = - try { - eventGroupsStub - .withAuthenticationKey(apiAuthenticationKey) - .listEventGroups(listEventGroupsRequest) - .eventGroupsList - .toList() - } catch (e: StatusException) { - throw Exception( - "Unable to get event groups associated with measurement consumer $measurementConsumerName and data provider $dataProviderName", - e, - ) - } - - if (eventGroups.size != 1) { - throw IllegalStateException( - "here should be exactly 1:1 mapping between a data provider and an event group, but data provider $dataProvider is related to ${eventGroups.size} event groups" - ) - } + val eventGroup: EventGroup = + eventGroupsStub + .withAuthenticationKey(apiAuthenticationKey) + .listResources(1) { pageToken, remaining -> + val request = listEventGroupsRequest { + parent = measurementConsumerName + filter = ListEventGroupsRequestKt.filter { dataProviders += dataProviderName } + this.pageToken = pageToken + pageSize = remaining + } + val response = + try { + listEventGroups(request) + } catch (e: StatusException) { + throw Exception( + "Unable to get event groups associated with measurement consumer $measurementConsumerName and data provider $dataProviderName", + e, + ) + } + ResourceList(response.eventGroupsList, response.nextPageToken) + } + .map { it.single() } + .single() - dataProviderNameToEventGroup[dataProviderName] = eventGroups[0] + dataProviderNameToEventGroup[dataProviderName] = eventGroup } return dataProviderNameToEventGroup } @@ -253,55 +259,47 @@ class MeasurementSystemProber( return clock.instant() >= nextMeasurementEarliestInstant } + @OptIn(ExperimentalCoroutinesApi::class) // For `flattenConcat`. private suspend fun getLastUpdatedMeasurement(): Measurement? { - var nextPageToken = "" - do { - val response: ListMeasurementsResponse = - try { - measurementsStub - .withAuthenticationKey(apiAuthenticationKey) - .listMeasurements( + val measurements: Flow> = + measurementsStub.withAuthenticationKey(apiAuthenticationKey).listResources(1) { + pageToken, + remaining -> + val response: ListMeasurementsResponse = + try { + listMeasurements( listMeasurementsRequest { parent = measurementConsumerName - this.pageSize = 1 - pageToken = nextPageToken + this.pageToken = pageToken + this.pageSize = remaining } ) - } catch (e: StatusException) { - throw Exception( - "Unable to list measurements for measurement consumer $measurementConsumerName", - e, - ) - } - if (response.measurementsList.isNotEmpty()) { - return response.measurementsList.single() + } catch (e: StatusException) { + throw Exception( + "Unable to list measurements for measurement consumer $measurementConsumerName", + e, + ) + } + ResourceList(response.measurementsList, response.nextPageToken) } - nextPageToken = response.nextPageToken - } while (nextPageToken.isNotEmpty()) - return null + + return measurements.flattenConcat().singleOrNull() } - private suspend fun getRequisitionsForMeasurement(measurementName: String): List { - var nextPageToken = "" - val requisitions = mutableListOf() - do { - val response: ListRequisitionsResponse = - try { - requisitionsStub - .withAuthenticationKey(apiAuthenticationKey) - .listRequisitions( - listRequisitionsRequest { - parent = measurementName - pageToken = nextPageToken - } - ) - } catch (e: StatusException) { - throw Exception("Unable to list requisitions for measurement $measurementName", e) - } - requisitions.addAll(response.requisitionsList) - nextPageToken = response.nextPageToken - } while (nextPageToken.isNotEmpty()) - return requisitions + @OptIn(ExperimentalCoroutinesApi::class) // For `flattenConcat`. + private fun getRequisitionsForMeasurement(measurementName: String): Flow { + return requisitionsStub + .withAuthenticationKey(apiAuthenticationKey) + .listResources { pageToken -> + val response: ListRequisitionsResponse = + try { + listRequisitions(listRequisitionsRequest { this.pageToken = pageToken }) + } catch (e: StatusException) { + throw Exception("Unable to list requisitions for measurement $measurementName", e) + } + ResourceList(response.requisitionsList, response.nextPageToken) + } + .flattenConcat() } private suspend fun getDataProviderEntry( @@ -360,10 +358,12 @@ class MeasurementSystemProber( private suspend fun updateLastTerminalRequisitionGauge(lastUpdatedMeasurement: Measurement) { val requisitions = getRequisitionsForMeasurement(lastUpdatedMeasurement.name) - for (requisition in requisitions) { + requisitions.collect { requisition -> if (requisition.state == Requisition.State.FULFILLED) { - val requisitionKey = CanonicalRequisitionKey.fromName(requisition.name) - require(requisitionKey != null) { "CanonicalRequisitionKey cannot be null" } + val requisitionKey = + requireNotNull(CanonicalRequisitionKey.fromName(requisition.name)) { + "Requisition name ${requisition.name} is invalid" + } val dataProviderName: String = requisitionKey.dataProviderId val attributes = Attributes.of(DATA_PROVIDER_ATTRIBUTE_KEY, dataProviderName) lastTerminalRequisitionTimeGauge.set( diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/BUILD.bazel index 445524ac557..3cdded1d308 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/BUILD.bazel @@ -132,7 +132,7 @@ kt_jvm_library( "context_keys", ":reporting_principal", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key", - "//src/main/kotlin/org/wfanet/measurement/common/api/grpc", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:akid_principal_server_interceptor", "//src/main/kotlin/org/wfanet/measurement/common/identity", "@wfa_common_jvm//imports/java/com/google/protobuf", "@wfa_common_jvm//imports/java/io/grpc:api", diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/BUILD.bazel index 22d6343b537..fc750cdc976 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/BUILD.bazel @@ -52,7 +52,7 @@ kt_jvm_library( "context_keys", ":reporting_principal", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key", - "//src/main/kotlin/org/wfanet/measurement/common/api/grpc", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:akid_principal_server_interceptor", "//src/main/kotlin/org/wfanet/measurement/common/identity", "@wfa_common_jvm//imports/java/com/google/protobuf", "@wfa_common_jvm//imports/java/io/grpc:api", @@ -120,6 +120,7 @@ kt_jvm_library( "//imports/java/org/projectnessie/cel", "//src/main/kotlin/org/wfanet/measurement/api:api_key_constants", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:packed_messages", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:list_resources", "//src/main/kotlin/org/wfanet/measurement/reporting/service/api:cel_env_provider", "//src/main/kotlin/org/wfanet/measurement/reporting/service/api:encryption_key_pair_store", "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha:principal_server_interceptor", diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsService.kt index 825e90c2bff..650a883fa68 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsService.kt @@ -20,10 +20,13 @@ import com.google.protobuf.DynamicMessage import com.google.protobuf.kotlin.unpack import io.grpc.Context import io.grpc.Deadline +import io.grpc.Deadline.Ticker import io.grpc.Status import io.grpc.StatusException import java.security.GeneralSecurityException import java.util.concurrent.TimeUnit +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.transformWhile import org.projectnessie.cel.common.types.Err import org.projectnessie.cel.common.types.ref.Val import org.wfanet.measurement.api.v2alpha.DataProviderKey @@ -31,9 +34,12 @@ import org.wfanet.measurement.api.v2alpha.EncryptionPublicKey import org.wfanet.measurement.api.v2alpha.EventGroup as CmmsEventGroup import org.wfanet.measurement.api.v2alpha.EventGroupKey as CmmsEventGroupKey import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub as CmmsEventGroupsCoroutineStub +import org.wfanet.measurement.api.v2alpha.ListEventGroupsResponse as CmmsListEventGroupsResponse import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey import org.wfanet.measurement.api.v2alpha.listEventGroupsRequest import org.wfanet.measurement.api.withAuthenticationKey +import org.wfanet.measurement.common.api.grpc.ResourceList +import org.wfanet.measurement.common.api.grpc.listResources import org.wfanet.measurement.common.crypto.PrivateKeyHandle import org.wfanet.measurement.common.grpc.grpcRequire import org.wfanet.measurement.common.grpc.grpcRequireNotNull @@ -52,6 +58,7 @@ class EventGroupsService( private val cmmsEventGroupsStub: CmmsEventGroupsCoroutineStub, private val encryptionKeyPairStore: EncryptionKeyPairStore, private val celEnvProvider: CelEnvProvider, + private val ticker: Ticker = Deadline.getSystemTicker(), ) : EventGroupsCoroutineImplBase() { override suspend fun listEventGroups(request: ListEventGroupsRequest): ListEventGroupsResponse { val parentKey = @@ -71,66 +78,79 @@ class EventGroupsService( } } + val deadline: Deadline = + Context.current().deadline + ?: Deadline.after(RPC_DEFAULT_DEADLINE_MILLIS, TimeUnit.MILLISECONDS, ticker) val apiAuthenticationKey: String = principal.config.apiKey grpcRequire(request.pageSize >= 0) { "page_size cannot be negative" } - val pageSize = - when { - request.pageSize < MIN_PAGE_SIZE -> DEFAULT_PAGE_SIZE - request.pageSize > MAX_PAGE_SIZE -> MAX_PAGE_SIZE - else -> request.pageSize - } + val limit = + if (request.pageSize > 0) request.pageSize.coerceAtMost(MAX_PAGE_SIZE) else DEFAULT_PAGE_SIZE + val parent = parentKey.toName() + val eventGroupLists: Flow> = + cmmsEventGroupsStub.withAuthenticationKey(apiAuthenticationKey).listResources( + limit, + request.pageToken, + ) { pageToken, remaining -> + val response: CmmsListEventGroupsResponse = + listEventGroups( + listEventGroupsRequest { + this.parent = parent + this.pageSize = remaining + this.pageToken = pageToken + } + ) - var nextPageToken = request.pageToken - val deadline = Context.current().deadline ?: Deadline.after(30, TimeUnit.SECONDS) - do { - val cmmsListEventGroupResponse = - try { - cmmsEventGroupsStub - .withAuthenticationKey(apiAuthenticationKey) - .listEventGroups( - listEventGroupsRequest { - parent = parentKey.toName() - this.pageSize = pageSize - pageToken = nextPageToken + val eventGroups: List = + response.eventGroupsList.map { + val cmmsMetadata: CmmsEventGroup.Metadata? = + if (it.hasEncryptedMetadata()) { + decryptMetadata(it, principal.resourceKey.toName()) + } else { + null } - ) - } catch (e: StatusException) { - throw when (e.status.code) { - Status.Code.DEADLINE_EXCEEDED -> Status.DEADLINE_EXCEEDED - Status.Code.CANCELLED -> Status.CANCELLED - else -> Status.UNKNOWN - } - .withCause(e) - .asRuntimeException() - } - val cmmsEventGroups = cmmsListEventGroupResponse.eventGroupsList - val eventGroups = - cmmsEventGroups.map { - val cmmsMetadata: CmmsEventGroup.Metadata? = - if (it.hasEncryptedMetadata()) { - decryptMetadata(it, principal.resourceKey.toName()) - } else { - null - } + it.toEventGroup(cmmsMetadata) + } - it.toEventGroup(cmmsMetadata) - } + ResourceList(filterEventGroups(eventGroups, request.filter), response.nextPageToken) + } - val filteredEventGroups = filterEventGroups(eventGroups, request.filter) - if (filteredEventGroups.size > 0) { - return listEventGroupsResponse { - this.eventGroups += filteredEventGroups - this.nextPageToken = cmmsListEventGroupResponse.nextPageToken + var hasResponse = false + return listEventGroupsResponse { + try { + eventGroupLists + .transformWhile { + emit(it) + deadline.timeRemaining(TimeUnit.MILLISECONDS) > RPC_DEADLINE_OVERHEAD_MILLIS + } + .collect { eventGroupList -> + this.eventGroups += eventGroupList + nextPageToken = eventGroupList.nextPageToken + hasResponse = true + } + } catch (e: StatusException) { + when (e.status.code) { + Status.Code.DEADLINE_EXCEEDED, + Status.Code.CANCELLED -> { + if (!hasResponse) { + // Only throw an error if we don't have any response yet. Otherwise, just return what + // we have so far. + throw Status.DEADLINE_EXCEEDED.withDescription( + "Timed out listing EventGroups from backend" + ) + .withCause(e) + .asRuntimeException() + } + } + else -> + throw Status.UNKNOWN.withDescription("Error listing EventGroups from backend") + .withCause(e) + .asRuntimeException() } - } else { - nextPageToken = cmmsListEventGroupResponse.nextPageToken } - } while (deadline.timeRemaining(TimeUnit.SECONDS) > 5) - - return listEventGroupsResponse { this.nextPageToken = nextPageToken } + } } private suspend fun filterEventGroups( @@ -259,8 +279,13 @@ class EventGroupsService( companion object { private const val METADATA_FIELD = "metadata.metadata" - private const val MIN_PAGE_SIZE = 1 private const val DEFAULT_PAGE_SIZE = 50 private const val MAX_PAGE_SIZE = 1000 + + /** Overhead to allow for RPC deadlines in milliseconds. */ + private const val RPC_DEADLINE_OVERHEAD_MILLIS = 100L + + /** Default RPC deadline in milliseconds. */ + private const val RPC_DEFAULT_DEADLINE_MILLIS = 30_000L } } diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsServiceTest.kt index befbdcedd0b..42df07ae256 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsServiceTest.kt @@ -19,11 +19,13 @@ package org.wfanet.measurement.reporting.service.api.v2alpha import com.google.common.truth.Truth.assertThat import com.google.common.truth.extensions.proto.ProtoTruth.assertThat import com.google.protobuf.Any +import io.grpc.Deadline import io.grpc.Status import io.grpc.StatusRuntimeException import java.nio.file.Path import java.nio.file.Paths import java.time.Duration +import java.util.concurrent.TimeUnit import kotlin.test.assertFailsWith import kotlinx.coroutines.runBlocking import org.junit.After @@ -110,6 +112,7 @@ class EventGroupsServiceTest { private lateinit var celEnvCacheProvider: CelEnvCacheProvider private lateinit var service: EventGroupsService + private val fakeTicker = SettableSystemTicker() @Before fun initService() { @@ -126,6 +129,7 @@ class EventGroupsServiceTest { EventGroupsCoroutineStub(grpcTestServerRule.channel), ENCRYPTION_KEY_PAIR_STORE, celEnvCacheProvider, + fakeTicker, ) } @@ -152,13 +156,13 @@ class EventGroupsServiceTest { whenever(publicKingdomEventGroupsMock.listEventGroups(any())) .thenReturn( listEventGroupsResponse { - nextPageToken = "1" eventGroups += cmmsEventGroup2 + nextPageToken = "1" } ) .thenReturn( listEventGroupsResponse { - eventGroups += listOf(CMMS_EVENT_GROUP, cmmsEventGroup2) + eventGroups += CMMS_EVENT_GROUP nextPageToken = "2" } ) @@ -170,6 +174,7 @@ class EventGroupsServiceTest { listEventGroupsRequest { parent = MEASUREMENT_CONSUMER_NAME filter = "metadata.metadata.publisher_id > 5" + pageSize = 1 } ) } @@ -200,7 +205,10 @@ class EventGroupsServiceTest { eventGroups += CMMS_EVENT_GROUP } ) - + .then { + // Advance time. + fakeTicker.setNanoTime(fakeTicker.nanoTime() + TimeUnit.SECONDS.toNanos(30)) + } val response = withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { runBlocking { @@ -443,7 +451,12 @@ class EventGroupsServiceTest { val response = withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { runBlocking { - service.listEventGroups(listEventGroupsRequest { parent = MEASUREMENT_CONSUMER_NAME }) + service.listEventGroups( + listEventGroupsRequest { + parent = MEASUREMENT_CONSUMER_NAME + pageSize = 2 + } + ) } } @@ -455,7 +468,7 @@ class EventGroupsServiceTest { .isEqualTo( cmmsListEventGroupsRequest { parent = MEASUREMENT_CONSUMER_NAME - pageSize = DEFAULT_PAGE_SIZE + pageSize = 2 } ) } @@ -813,4 +826,20 @@ class EventGroupsServiceTest { } } } + + /** + * Fake [Deadline.Ticker] implementation that allows time to be specified to override delegation + * to the system ticker. + */ + private class SettableSystemTicker : Deadline.Ticker() { + private var nanoTime: Long? = null + + fun setNanoTime(value: Long) { + nanoTime = value + } + + override fun nanoTime(): Long { + return this.nanoTime ?: Deadline.getSystemTicker().nanoTime() + } + } }