From 1391cab0c5eebcd292cfe03bf75808c2370db7f1 Mon Sep 17 00:00:00 2001 From: Sanjay Vasandani Date: Thu, 21 Nov 2024 15:52:37 -0800 Subject: [PATCH 1/3] refactor: Extract listResources utility function for handling pagination This also fixes an issue where EventGroupsServiceTest was waiting for a real 30s deadline to pass. --- .../measurement/api/v2alpha/BUILD.bazel | 2 +- .../measurement/common/api/grpc/BUILD.bazel | 12 +- .../common/api/grpc/ListResources.kt | 92 +++++++++++ .../integration/common/BUILD.bazel | 1 + ...sMeasurementSystemProberIntegrationTest.kt | 52 +++--- .../measurement/kingdom/batch/BUILD.bazel | 1 + .../kingdom/batch/MeasurementSystemProber.kt | 151 ++++++++---------- .../reporting/service/api/v1alpha/BUILD.bazel | 2 +- .../reporting/service/api/v2alpha/BUILD.bazel | 3 +- .../service/api/v2alpha/EventGroupsService.kt | 125 +++++++++------ .../api/v2alpha/EventGroupsServiceTest.kt | 39 ++++- 11 files changed, 316 insertions(+), 164 deletions(-) create mode 100644 src/main/kotlin/org/wfanet/measurement/common/api/grpc/ListResources.kt diff --git a/src/main/kotlin/org/wfanet/measurement/api/v2alpha/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/api/v2alpha/BUILD.bazel index 45bc9a2576f..e53b6c87c1c 100644 --- a/src/main/kotlin/org/wfanet/measurement/api/v2alpha/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/api/v2alpha/BUILD.bazel @@ -52,7 +52,7 @@ kt_jvm_library( deps = [ "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:context_keys", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:measurement_principal", - "//src/main/kotlin/org/wfanet/measurement/common/api/grpc", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:akid_principal_server_interceptor", "//src/main/kotlin/org/wfanet/measurement/common/grpc:context", "//src/main/kotlin/org/wfanet/measurement/common/identity", "//src/main/proto/wfa/measurement/api/v2alpha:duchy_kt_jvm_proto", diff --git a/src/main/kotlin/org/wfanet/measurement/common/api/grpc/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/common/api/grpc/BUILD.bazel index 3a897c0ad81..4ffd63ed7ab 100644 --- a/src/main/kotlin/org/wfanet/measurement/common/api/grpc/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/common/api/grpc/BUILD.bazel @@ -3,7 +3,7 @@ load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") package(default_visibility = ["//visibility:public"]) kt_jvm_library( - name = "grpc", + name = "akid_principal_server_interceptor", srcs = ["AkidPrincipalServerInterceptor.kt"], deps = [ "//src/main/kotlin/org/wfanet/measurement/common/api:principal", @@ -14,3 +14,13 @@ kt_jvm_library( "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", ], ) + +kt_jvm_library( + name = "list_resources", + srcs = ["ListResources.kt"], + deps = [ + "@wfa_common_jvm//imports/java/com/google/protobuf", + "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", + "@wfa_rules_kotlin_jvm//imports/io/gprc/kotlin:stub", + ], +) diff --git a/src/main/kotlin/org/wfanet/measurement/common/api/grpc/ListResources.kt b/src/main/kotlin/org/wfanet/measurement/common/api/grpc/ListResources.kt new file mode 100644 index 00000000000..ebd768b2412 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/common/api/grpc/ListResources.kt @@ -0,0 +1,92 @@ +/* + * Copyright 2024 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.common.api.grpc + +import com.google.protobuf.Message +import io.grpc.kotlin.AbstractCoroutineStub +import kotlin.coroutines.coroutineContext +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.ensureActive +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.asFlow +import kotlinx.coroutines.flow.flattenConcat +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.map + +/** A [List] of resources from a paginated List method. */ +data class ResourceList( + val resources: List, + /** + * A token that can be sent on subsequent requests to retrieve the next page. If empty, there are + * no subsequent pages. + */ + val nextPageToken: String, +) : List by resources + +/** + * Lists resources from a paginated List method on this stub. + * + * @param pageToken page token for initial request + * @param list function which calls the appropriate List method on the stub + */ +fun > S.listResources( + pageToken: String = "", + list: suspend S.(pageToken: String) -> ResourceList, +): Flow> = + listResources(Int.MAX_VALUE, pageToken) { nextPageToken, _ -> list(nextPageToken) } + +/** + * Lists resources from a paginated List method on this stub. + * + * @param limit maximum number of resources to emit + * @param pageToken page token for initial request + * @param list function which calls the appropriate List method on the stub, returning no more than + * the specified remaining number of resources + */ +fun > S.listResources( + limit: Int, + pageToken: String = "", + list: suspend S.(pageToken: String, remaining: Int) -> ResourceList, +): Flow> { + require(limit > 0) { "limit must be positive" } + return flow { + var remaining: Int = limit + var nextPageToken = pageToken + + while (true) { + coroutineContext.ensureActive() + + val resourceList: ResourceList = list(nextPageToken, remaining) + require(resourceList.size <= remaining) { + "List call must ensure that limit is not exceeded. " + + "Returned ${resourceList.size} items when only $remaining were remaining" + } + emit(resourceList) + + remaining -= resourceList.size + nextPageToken = resourceList.nextPageToken + if (nextPageToken.isEmpty() || remaining == 0) { + break + } + } + } +} + +/** @see [flattenConcat] */ +@ExperimentalCoroutinesApi // Overloads experimental `flattenConcat` function. +fun Flow>.flattenConcat(): Flow = + map { it.asFlow() }.flattenConcat() diff --git a/src/main/kotlin/org/wfanet/measurement/integration/common/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/integration/common/BUILD.bazel index f47074b4ae3..8764c226d65 100644 --- a/src/main/kotlin/org/wfanet/measurement/integration/common/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/integration/common/BUILD.bazel @@ -319,6 +319,7 @@ kt_jvm_library( ], deps = [ ":in_process_cmms_components", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:list_resources", "//src/main/kotlin/org/wfanet/measurement/kingdom/batch:measurement_system_prober", "//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common/service:data_services", "@wfa_common_jvm//imports/java/com/google/common/truth", diff --git a/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessMeasurementSystemProberIntegrationTest.kt b/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessMeasurementSystemProberIntegrationTest.kt index 0a015ca80d3..d9a14f29152 100644 --- a/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessMeasurementSystemProberIntegrationTest.kt +++ b/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessMeasurementSystemProberIntegrationTest.kt @@ -22,6 +22,9 @@ import java.io.File import java.nio.file.Paths import java.time.Clock import java.time.Duration +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.toList import kotlinx.coroutines.runBlocking import org.junit.After import org.junit.Before @@ -30,13 +33,15 @@ import org.junit.Rule import org.junit.Test import org.wfanet.measurement.api.v2alpha.DataProvidersGrpcKt.DataProvidersCoroutineStub import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub -import org.wfanet.measurement.api.v2alpha.ListMeasurementsResponse import org.wfanet.measurement.api.v2alpha.Measurement import org.wfanet.measurement.api.v2alpha.MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineStub import org.wfanet.measurement.api.v2alpha.MeasurementsGrpcKt.MeasurementsCoroutineStub import org.wfanet.measurement.api.v2alpha.RequisitionsGrpcKt.RequisitionsCoroutineStub import org.wfanet.measurement.api.v2alpha.listMeasurementsRequest import org.wfanet.measurement.api.withAuthenticationKey +import org.wfanet.measurement.common.api.grpc.ResourceList +import org.wfanet.measurement.common.api.grpc.flattenConcat +import org.wfanet.measurement.common.api.grpc.listResources import org.wfanet.measurement.common.getRuntimePath import org.wfanet.measurement.common.identity.withPrincipalName import org.wfanet.measurement.common.testing.ProviderRule @@ -129,33 +134,32 @@ abstract class InProcessMeasurementSystemProberIntegrationTest( assertThat(measurements.size).isEqualTo(1) } + @OptIn(ExperimentalCoroutinesApi::class) // For `flattenConcat`. private suspend fun listMeasurements(): List { - var nextPageToken = "" val measurementConsumerData = inProcessCmmsComponents.getMeasurementConsumerData() - do { - val response: ListMeasurementsResponse = - try { - publicMeasurementsClient - .withAuthenticationKey(measurementConsumerData.apiAuthenticationKey) - .listMeasurements( - listMeasurementsRequest { - parent = measurementConsumerData.name - pageToken = nextPageToken - } - ) - } catch (e: StatusException) { - throw Exception( - "Unable to list measurements for measurement consumer ${measurementConsumerData.name}", - e, - ) + val measurementLists: Flow> = + publicMeasurementsClient + .withAuthenticationKey(measurementConsumerData.apiAuthenticationKey) + .listResources { pageToken -> + val response = + try { + listMeasurements( + listMeasurementsRequest { + parent = measurementConsumerData.name + this.pageToken = pageToken + } + ) + } catch (e: StatusException) { + throw Exception( + "Unable to list measurements for measurement consumer ${measurementConsumerData.name}", + e, + ) + } + ResourceList(response.measurementsList, response.nextPageToken) } - if (response.measurementsList.isNotEmpty()) { - return response.measurementsList - } - nextPageToken = response.nextPageToken - } while (nextPageToken.isNotEmpty()) - return emptyList() + + return measurementLists.flattenConcat().toList() } companion object { diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/batch/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/kingdom/batch/BUILD.bazel index 12d1d96644d..2d455bdfd32 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/batch/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/batch/BUILD.bazel @@ -47,6 +47,7 @@ kt_jvm_library( "//src/main/kotlin/org/wfanet/measurement/api:api_key_constants", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:packed_messages", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:list_resources", "//src/main/proto/wfa/measurement/api/v2alpha:data_provider_kt_jvm_proto", "//src/main/proto/wfa/measurement/api/v2alpha:data_providers_service_kt_jvm_grpc_proto", "//src/main/proto/wfa/measurement/api/v2alpha:event_group_kt_jvm_proto", diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/batch/MeasurementSystemProber.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/batch/MeasurementSystemProber.kt index b64a3a55842..e7c95bdd42c 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/batch/MeasurementSystemProber.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/batch/MeasurementSystemProber.kt @@ -27,8 +27,12 @@ import java.security.SecureRandom import java.time.Clock import java.time.Duration import java.util.logging.Logger +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.single +import kotlinx.coroutines.flow.singleOrNull import org.wfanet.measurement.api.v2alpha.CanonicalRequisitionKey -import org.wfanet.measurement.api.v2alpha.DataProvider import org.wfanet.measurement.api.v2alpha.DataProvidersGrpcKt import org.wfanet.measurement.api.v2alpha.EventGroup import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt @@ -60,6 +64,9 @@ import org.wfanet.measurement.api.v2alpha.requisitionSpec import org.wfanet.measurement.api.v2alpha.unpack import org.wfanet.measurement.api.withAuthenticationKey import org.wfanet.measurement.common.Instrumentation +import org.wfanet.measurement.common.api.grpc.ResourceList +import org.wfanet.measurement.common.api.grpc.flattenConcat +import org.wfanet.measurement.common.api.grpc.listResources import org.wfanet.measurement.common.crypto.Hashing import org.wfanet.measurement.common.crypto.SigningKeyHandle import org.wfanet.measurement.common.crypto.readCertificate @@ -198,43 +205,31 @@ class MeasurementSystemProber( private suspend fun buildDataProviderNameToEventGroup(): Map { val dataProviderNameToEventGroup = mutableMapOf() for (dataProviderName in dataProviderNames) { - val getDataProviderRequest = getDataProviderRequest { name = dataProviderName } - val dataProvider: DataProvider = - try { - dataProvidersStub - .withAuthenticationKey(apiAuthenticationKey) - .getDataProvider(getDataProviderRequest) - } catch (e: StatusException) { - throw Exception("Unable to get DataProvider with name $dataProviderName", e) - } - - // TODO(@roaminggypsy): Implement QA event group logic using simulatorEventGroupName - val listEventGroupsRequest = listEventGroupsRequest { - parent = measurementConsumerName - filter = ListEventGroupsRequestKt.filter { dataProviders += dataProviderName } - } - - val eventGroups: List = - try { - eventGroupsStub - .withAuthenticationKey(apiAuthenticationKey) - .listEventGroups(listEventGroupsRequest) - .eventGroupsList - .toList() - } catch (e: StatusException) { - throw Exception( - "Unable to get event groups associated with measurement consumer $measurementConsumerName and data provider $dataProviderName", - e, - ) - } - - if (eventGroups.size != 1) { - throw IllegalStateException( - "here should be exactly 1:1 mapping between a data provider and an event group, but data provider $dataProvider is related to ${eventGroups.size} event groups" - ) - } + val eventGroup: EventGroup = + eventGroupsStub + .withAuthenticationKey(apiAuthenticationKey) + .listResources(1) { pageToken, remaining -> + val request = listEventGroupsRequest { + parent = measurementConsumerName + filter = ListEventGroupsRequestKt.filter { dataProviders += dataProviderName } + this.pageToken = pageToken + pageSize = remaining + } + val response = + try { + listEventGroups(request) + } catch (e: StatusException) { + throw Exception( + "Unable to get event groups associated with measurement consumer $measurementConsumerName and data provider $dataProviderName", + e, + ) + } + ResourceList(response.eventGroupsList, response.nextPageToken) + } + .map { it.single() } + .single() - dataProviderNameToEventGroup[dataProviderName] = eventGroups[0] + dataProviderNameToEventGroup[dataProviderName] = eventGroup } return dataProviderNameToEventGroup } @@ -253,55 +248,47 @@ class MeasurementSystemProber( return clock.instant() >= nextMeasurementEarliestInstant } + @OptIn(ExperimentalCoroutinesApi::class) // For `flattenConcat`. private suspend fun getLastUpdatedMeasurement(): Measurement? { - var nextPageToken = "" - do { - val response: ListMeasurementsResponse = - try { - measurementsStub - .withAuthenticationKey(apiAuthenticationKey) - .listMeasurements( + val measurements: Flow> = + measurementsStub.withAuthenticationKey(apiAuthenticationKey).listResources(1) { + pageToken, + remaining -> + val response: ListMeasurementsResponse = + try { + listMeasurements( listMeasurementsRequest { parent = measurementConsumerName - this.pageSize = 1 - pageToken = nextPageToken + this.pageToken = pageToken + this.pageSize = remaining } ) - } catch (e: StatusException) { - throw Exception( - "Unable to list measurements for measurement consumer $measurementConsumerName", - e, - ) - } - if (response.measurementsList.isNotEmpty()) { - return response.measurementsList.single() + } catch (e: StatusException) { + throw Exception( + "Unable to list measurements for measurement consumer $measurementConsumerName", + e, + ) + } + ResourceList(response.measurementsList, response.nextPageToken) } - nextPageToken = response.nextPageToken - } while (nextPageToken.isNotEmpty()) - return null + + return measurements.flattenConcat().singleOrNull() } - private suspend fun getRequisitionsForMeasurement(measurementName: String): List { - var nextPageToken = "" - val requisitions = mutableListOf() - do { - val response: ListRequisitionsResponse = - try { - requisitionsStub - .withAuthenticationKey(apiAuthenticationKey) - .listRequisitions( - listRequisitionsRequest { - parent = measurementName - pageToken = nextPageToken - } - ) - } catch (e: StatusException) { - throw Exception("Unable to list requisitions for measurement $measurementName", e) - } - requisitions.addAll(response.requisitionsList) - nextPageToken = response.nextPageToken - } while (nextPageToken.isNotEmpty()) - return requisitions + @OptIn(ExperimentalCoroutinesApi::class) // For `flattenConcat`. + private fun getRequisitionsForMeasurement(measurementName: String): Flow { + return requisitionsStub + .withAuthenticationKey(apiAuthenticationKey) + .listResources { pageToken -> + val response: ListRequisitionsResponse = + try { + listRequisitions(listRequisitionsRequest { this.pageToken = pageToken }) + } catch (e: StatusException) { + throw Exception("Unable to list requisitions for measurement $measurementName", e) + } + ResourceList(response.requisitionsList, response.nextPageToken) + } + .flattenConcat() } private suspend fun getDataProviderEntry( @@ -360,10 +347,12 @@ class MeasurementSystemProber( private suspend fun updateLastTerminalRequisitionGauge(lastUpdatedMeasurement: Measurement) { val requisitions = getRequisitionsForMeasurement(lastUpdatedMeasurement.name) - for (requisition in requisitions) { + requisitions.collect { requisition -> if (requisition.state == Requisition.State.FULFILLED) { - val requisitionKey = CanonicalRequisitionKey.fromName(requisition.name) - require(requisitionKey != null) { "CanonicalRequisitionKey cannot be null" } + val requisitionKey = + requireNotNull(CanonicalRequisitionKey.fromName(requisition.name)) { + "Requisition name ${requisition.name} is invalid" + } val dataProviderName: String = requisitionKey.dataProviderId val attributes = Attributes.of(DATA_PROVIDER_ATTRIBUTE_KEY, dataProviderName) lastTerminalRequisitionTimeGauge.set( diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/BUILD.bazel index 445524ac557..3cdded1d308 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/BUILD.bazel @@ -132,7 +132,7 @@ kt_jvm_library( "context_keys", ":reporting_principal", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key", - "//src/main/kotlin/org/wfanet/measurement/common/api/grpc", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:akid_principal_server_interceptor", "//src/main/kotlin/org/wfanet/measurement/common/identity", "@wfa_common_jvm//imports/java/com/google/protobuf", "@wfa_common_jvm//imports/java/io/grpc:api", diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/BUILD.bazel index 22d6343b537..fc750cdc976 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/BUILD.bazel @@ -52,7 +52,7 @@ kt_jvm_library( "context_keys", ":reporting_principal", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key", - "//src/main/kotlin/org/wfanet/measurement/common/api/grpc", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:akid_principal_server_interceptor", "//src/main/kotlin/org/wfanet/measurement/common/identity", "@wfa_common_jvm//imports/java/com/google/protobuf", "@wfa_common_jvm//imports/java/io/grpc:api", @@ -120,6 +120,7 @@ kt_jvm_library( "//imports/java/org/projectnessie/cel", "//src/main/kotlin/org/wfanet/measurement/api:api_key_constants", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:packed_messages", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:list_resources", "//src/main/kotlin/org/wfanet/measurement/reporting/service/api:cel_env_provider", "//src/main/kotlin/org/wfanet/measurement/reporting/service/api:encryption_key_pair_store", "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha:principal_server_interceptor", diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsService.kt index 825e90c2bff..650a883fa68 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsService.kt @@ -20,10 +20,13 @@ import com.google.protobuf.DynamicMessage import com.google.protobuf.kotlin.unpack import io.grpc.Context import io.grpc.Deadline +import io.grpc.Deadline.Ticker import io.grpc.Status import io.grpc.StatusException import java.security.GeneralSecurityException import java.util.concurrent.TimeUnit +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.transformWhile import org.projectnessie.cel.common.types.Err import org.projectnessie.cel.common.types.ref.Val import org.wfanet.measurement.api.v2alpha.DataProviderKey @@ -31,9 +34,12 @@ import org.wfanet.measurement.api.v2alpha.EncryptionPublicKey import org.wfanet.measurement.api.v2alpha.EventGroup as CmmsEventGroup import org.wfanet.measurement.api.v2alpha.EventGroupKey as CmmsEventGroupKey import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub as CmmsEventGroupsCoroutineStub +import org.wfanet.measurement.api.v2alpha.ListEventGroupsResponse as CmmsListEventGroupsResponse import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey import org.wfanet.measurement.api.v2alpha.listEventGroupsRequest import org.wfanet.measurement.api.withAuthenticationKey +import org.wfanet.measurement.common.api.grpc.ResourceList +import org.wfanet.measurement.common.api.grpc.listResources import org.wfanet.measurement.common.crypto.PrivateKeyHandle import org.wfanet.measurement.common.grpc.grpcRequire import org.wfanet.measurement.common.grpc.grpcRequireNotNull @@ -52,6 +58,7 @@ class EventGroupsService( private val cmmsEventGroupsStub: CmmsEventGroupsCoroutineStub, private val encryptionKeyPairStore: EncryptionKeyPairStore, private val celEnvProvider: CelEnvProvider, + private val ticker: Ticker = Deadline.getSystemTicker(), ) : EventGroupsCoroutineImplBase() { override suspend fun listEventGroups(request: ListEventGroupsRequest): ListEventGroupsResponse { val parentKey = @@ -71,66 +78,79 @@ class EventGroupsService( } } + val deadline: Deadline = + Context.current().deadline + ?: Deadline.after(RPC_DEFAULT_DEADLINE_MILLIS, TimeUnit.MILLISECONDS, ticker) val apiAuthenticationKey: String = principal.config.apiKey grpcRequire(request.pageSize >= 0) { "page_size cannot be negative" } - val pageSize = - when { - request.pageSize < MIN_PAGE_SIZE -> DEFAULT_PAGE_SIZE - request.pageSize > MAX_PAGE_SIZE -> MAX_PAGE_SIZE - else -> request.pageSize - } + val limit = + if (request.pageSize > 0) request.pageSize.coerceAtMost(MAX_PAGE_SIZE) else DEFAULT_PAGE_SIZE + val parent = parentKey.toName() + val eventGroupLists: Flow> = + cmmsEventGroupsStub.withAuthenticationKey(apiAuthenticationKey).listResources( + limit, + request.pageToken, + ) { pageToken, remaining -> + val response: CmmsListEventGroupsResponse = + listEventGroups( + listEventGroupsRequest { + this.parent = parent + this.pageSize = remaining + this.pageToken = pageToken + } + ) - var nextPageToken = request.pageToken - val deadline = Context.current().deadline ?: Deadline.after(30, TimeUnit.SECONDS) - do { - val cmmsListEventGroupResponse = - try { - cmmsEventGroupsStub - .withAuthenticationKey(apiAuthenticationKey) - .listEventGroups( - listEventGroupsRequest { - parent = parentKey.toName() - this.pageSize = pageSize - pageToken = nextPageToken + val eventGroups: List = + response.eventGroupsList.map { + val cmmsMetadata: CmmsEventGroup.Metadata? = + if (it.hasEncryptedMetadata()) { + decryptMetadata(it, principal.resourceKey.toName()) + } else { + null } - ) - } catch (e: StatusException) { - throw when (e.status.code) { - Status.Code.DEADLINE_EXCEEDED -> Status.DEADLINE_EXCEEDED - Status.Code.CANCELLED -> Status.CANCELLED - else -> Status.UNKNOWN - } - .withCause(e) - .asRuntimeException() - } - val cmmsEventGroups = cmmsListEventGroupResponse.eventGroupsList - val eventGroups = - cmmsEventGroups.map { - val cmmsMetadata: CmmsEventGroup.Metadata? = - if (it.hasEncryptedMetadata()) { - decryptMetadata(it, principal.resourceKey.toName()) - } else { - null - } + it.toEventGroup(cmmsMetadata) + } - it.toEventGroup(cmmsMetadata) - } + ResourceList(filterEventGroups(eventGroups, request.filter), response.nextPageToken) + } - val filteredEventGroups = filterEventGroups(eventGroups, request.filter) - if (filteredEventGroups.size > 0) { - return listEventGroupsResponse { - this.eventGroups += filteredEventGroups - this.nextPageToken = cmmsListEventGroupResponse.nextPageToken + var hasResponse = false + return listEventGroupsResponse { + try { + eventGroupLists + .transformWhile { + emit(it) + deadline.timeRemaining(TimeUnit.MILLISECONDS) > RPC_DEADLINE_OVERHEAD_MILLIS + } + .collect { eventGroupList -> + this.eventGroups += eventGroupList + nextPageToken = eventGroupList.nextPageToken + hasResponse = true + } + } catch (e: StatusException) { + when (e.status.code) { + Status.Code.DEADLINE_EXCEEDED, + Status.Code.CANCELLED -> { + if (!hasResponse) { + // Only throw an error if we don't have any response yet. Otherwise, just return what + // we have so far. + throw Status.DEADLINE_EXCEEDED.withDescription( + "Timed out listing EventGroups from backend" + ) + .withCause(e) + .asRuntimeException() + } + } + else -> + throw Status.UNKNOWN.withDescription("Error listing EventGroups from backend") + .withCause(e) + .asRuntimeException() } - } else { - nextPageToken = cmmsListEventGroupResponse.nextPageToken } - } while (deadline.timeRemaining(TimeUnit.SECONDS) > 5) - - return listEventGroupsResponse { this.nextPageToken = nextPageToken } + } } private suspend fun filterEventGroups( @@ -259,8 +279,13 @@ class EventGroupsService( companion object { private const val METADATA_FIELD = "metadata.metadata" - private const val MIN_PAGE_SIZE = 1 private const val DEFAULT_PAGE_SIZE = 50 private const val MAX_PAGE_SIZE = 1000 + + /** Overhead to allow for RPC deadlines in milliseconds. */ + private const val RPC_DEADLINE_OVERHEAD_MILLIS = 100L + + /** Default RPC deadline in milliseconds. */ + private const val RPC_DEFAULT_DEADLINE_MILLIS = 30_000L } } diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsServiceTest.kt index befbdcedd0b..42df07ae256 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsServiceTest.kt @@ -19,11 +19,13 @@ package org.wfanet.measurement.reporting.service.api.v2alpha import com.google.common.truth.Truth.assertThat import com.google.common.truth.extensions.proto.ProtoTruth.assertThat import com.google.protobuf.Any +import io.grpc.Deadline import io.grpc.Status import io.grpc.StatusRuntimeException import java.nio.file.Path import java.nio.file.Paths import java.time.Duration +import java.util.concurrent.TimeUnit import kotlin.test.assertFailsWith import kotlinx.coroutines.runBlocking import org.junit.After @@ -110,6 +112,7 @@ class EventGroupsServiceTest { private lateinit var celEnvCacheProvider: CelEnvCacheProvider private lateinit var service: EventGroupsService + private val fakeTicker = SettableSystemTicker() @Before fun initService() { @@ -126,6 +129,7 @@ class EventGroupsServiceTest { EventGroupsCoroutineStub(grpcTestServerRule.channel), ENCRYPTION_KEY_PAIR_STORE, celEnvCacheProvider, + fakeTicker, ) } @@ -152,13 +156,13 @@ class EventGroupsServiceTest { whenever(publicKingdomEventGroupsMock.listEventGroups(any())) .thenReturn( listEventGroupsResponse { - nextPageToken = "1" eventGroups += cmmsEventGroup2 + nextPageToken = "1" } ) .thenReturn( listEventGroupsResponse { - eventGroups += listOf(CMMS_EVENT_GROUP, cmmsEventGroup2) + eventGroups += CMMS_EVENT_GROUP nextPageToken = "2" } ) @@ -170,6 +174,7 @@ class EventGroupsServiceTest { listEventGroupsRequest { parent = MEASUREMENT_CONSUMER_NAME filter = "metadata.metadata.publisher_id > 5" + pageSize = 1 } ) } @@ -200,7 +205,10 @@ class EventGroupsServiceTest { eventGroups += CMMS_EVENT_GROUP } ) - + .then { + // Advance time. + fakeTicker.setNanoTime(fakeTicker.nanoTime() + TimeUnit.SECONDS.toNanos(30)) + } val response = withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { runBlocking { @@ -443,7 +451,12 @@ class EventGroupsServiceTest { val response = withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { runBlocking { - service.listEventGroups(listEventGroupsRequest { parent = MEASUREMENT_CONSUMER_NAME }) + service.listEventGroups( + listEventGroupsRequest { + parent = MEASUREMENT_CONSUMER_NAME + pageSize = 2 + } + ) } } @@ -455,7 +468,7 @@ class EventGroupsServiceTest { .isEqualTo( cmmsListEventGroupsRequest { parent = MEASUREMENT_CONSUMER_NAME - pageSize = DEFAULT_PAGE_SIZE + pageSize = 2 } ) } @@ -813,4 +826,20 @@ class EventGroupsServiceTest { } } } + + /** + * Fake [Deadline.Ticker] implementation that allows time to be specified to override delegation + * to the system ticker. + */ + private class SettableSystemTicker : Deadline.Ticker() { + private var nanoTime: Long? = null + + fun setNanoTime(value: Long) { + nanoTime = value + } + + override fun nanoTime(): Long { + return this.nanoTime ?: Deadline.getSystemTicker().nanoTime() + } + } } From 22c5268dc5334f7a3a2185177d354d7026b9d9b4 Mon Sep 17 00:00:00 2001 From: Sanjay Vasandani Date: Fri, 22 Nov 2024 09:24:45 -0800 Subject: [PATCH 2/3] fix: Read all EventGroups from simulators rather than stopping at first page (#1927) The MC and EDP simulators were only reading the first page of results from ListEventGroups. This means that in environments with more EventGroups, not all test EventGroups would be read. This is exacerbated by #1916, as the default page size is now smaller. --- .../loadtest/dataprovider/EdpSimulator.kt | 43 +++++++++------- .../loadtest/measurementconsumer/BUILD.bazel | 1 + .../MeasurementConsumerSimulator.kt | 50 +++++++++++++------ 3 files changed, 60 insertions(+), 34 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulator.kt b/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulator.kt index 7a29bd5f9d3..4737bcf56bd 100644 --- a/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulator.kt +++ b/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulator.kt @@ -36,9 +36,11 @@ import kotlin.math.roundToInt import kotlin.random.Random import kotlin.random.asJavaRandom import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.asFlow import kotlinx.coroutines.flow.emitAll +import kotlinx.coroutines.flow.firstOrNull import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.map import kotlinx.coroutines.withContext @@ -109,6 +111,9 @@ import org.wfanet.measurement.api.v2alpha.updateEventGroupRequest import org.wfanet.measurement.common.Health import org.wfanet.measurement.common.ProtoReflection import org.wfanet.measurement.common.SettableHealth +import org.wfanet.measurement.common.api.grpc.ResourceList +import org.wfanet.measurement.common.api.grpc.flattenConcat +import org.wfanet.measurement.common.api.grpc.listResources import org.wfanet.measurement.common.asBufferedFlow import org.wfanet.measurement.common.crypto.authorityKeyIdentifier import org.wfanet.measurement.common.crypto.readCertificate @@ -339,27 +344,29 @@ class EdpSimulator( * Returns the first [EventGroup] for this `DataProvider` and [MeasurementConsumer] with * [eventGroupReferenceId], or `null` if not found. */ + @OptIn(ExperimentalCoroutinesApi::class) // For `flattenConcat`. private suspend fun getEventGroupByReferenceId(eventGroupReferenceId: String): EventGroup? { - val response = - try { - eventGroupsStub.listEventGroups( - listEventGroupsRequest { - parent = edpData.name - filter = - ListEventGroupsRequestKt.filter { measurementConsumers += measurementConsumerName } - pageSize = Int.MAX_VALUE + return eventGroupsStub + .listResources { pageToken -> + val response = + try { + listEventGroups( + listEventGroupsRequest { + parent = edpData.name + filter = + ListEventGroupsRequestKt.filter { + measurementConsumers += measurementConsumerName + } + this.pageToken = pageToken + } + ) + } catch (e: StatusException) { + throw Exception("Error listing EventGroups", e) } - ) - } catch (e: StatusException) { - throw Exception("Error listing EventGroups", e) + ResourceList(response.eventGroupsList, response.nextPageToken) } - - // TODO(@SanjayVas): Support filtering by reference ID so we don't need to handle multiple pages - // of EventGroups. - check(response.nextPageToken.isEmpty()) { - "Too many EventGroups for ${edpData.name} and $measurementConsumerName" - } - return response.eventGroupsList.find { it.eventGroupReferenceId == eventGroupReferenceId } + .flattenConcat() + .firstOrNull { it.eventGroupReferenceId == eventGroupReferenceId } } private suspend fun ensureMetadataDescriptor( diff --git a/src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer/BUILD.bazel index a6b9b428e95..1451a300dae 100644 --- a/src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer/BUILD.bazel @@ -28,6 +28,7 @@ kt_jvm_library( "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:packed_messages", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha/testing", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:list_resources", "//src/main/kotlin/org/wfanet/measurement/common/identity", "//src/main/kotlin/org/wfanet/measurement/integration/common:configs", "//src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha:api_key_authentication_server_interceptor", diff --git a/src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer/MeasurementConsumerSimulator.kt b/src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer/MeasurementConsumerSimulator.kt index bbdf684ba45..f008d4ef94a 100644 --- a/src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer/MeasurementConsumerSimulator.kt +++ b/src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer/MeasurementConsumerSimulator.kt @@ -31,6 +31,10 @@ import kotlin.math.log2 import kotlin.math.max import kotlin.math.sqrt import kotlin.random.Random +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.filter +import kotlinx.coroutines.flow.toList import kotlinx.coroutines.time.delay import org.projectnessie.cel.Program import org.wfanet.measurement.api.v2alpha.Certificate @@ -92,6 +96,9 @@ import org.wfanet.measurement.api.v2alpha.unpack import org.wfanet.measurement.api.withAuthenticationKey import org.wfanet.measurement.common.ExponentialBackoff import org.wfanet.measurement.common.OpenEndTimeRange +import org.wfanet.measurement.common.api.grpc.ResourceList +import org.wfanet.measurement.common.api.grpc.flattenConcat +import org.wfanet.measurement.common.api.grpc.listResources import org.wfanet.measurement.common.coerceAtMost import org.wfanet.measurement.common.crypto.Hashing import org.wfanet.measurement.common.crypto.PrivateKeyHandle @@ -812,11 +819,13 @@ class MeasurementConsumerSimulator( maxDataProviders: Int = 20, ): MeasurementInfo { val eventGroups: List = - listEventGroups(measurementConsumer.name).filter { - it.eventGroupReferenceId.startsWith( - TestIdentifiers.SIMULATOR_EVENT_GROUP_REFERENCE_ID_PREFIX - ) - } + listEventGroups(measurementConsumer.name) + .filter { + it.eventGroupReferenceId.startsWith( + TestIdentifiers.SIMULATOR_EVENT_GROUP_REFERENCE_ID_PREFIX + ) + } + .toList() check(eventGroups.isNotEmpty()) { "No event groups found for ${measurementConsumer.name}" } val nonceHashes = mutableListOf() val keyToDataProviderMap: Map = @@ -1255,16 +1264,25 @@ class MeasurementConsumerSimulator( } } - private suspend fun listEventGroups(measurementConsumer: String): List { - val request = listEventGroupsRequest { parent = measurementConsumer } - try { - return eventGroupsClient - .withAuthenticationKey(measurementConsumerData.apiAuthenticationKey) - .listEventGroups(request) - .eventGroupsList - } catch (e: StatusException) { - throw Exception("Error listing event groups for MC $measurementConsumer", e) - } + @OptIn(ExperimentalCoroutinesApi::class) // For `flattenConcat`. + private fun listEventGroups(measurementConsumer: String): Flow { + return eventGroupsClient + .withAuthenticationKey(measurementConsumerData.apiAuthenticationKey) + .listResources { pageToken -> + val response = + try { + listEventGroups( + listEventGroupsRequest { + parent = measurementConsumer + this.pageToken = pageToken + } + ) + } catch (e: StatusException) { + throw Exception("Error listing event groups for MC $measurementConsumer", e) + } + ResourceList(response.eventGroupsList, response.nextPageToken) + } + .flattenConcat() } private fun extractDataProviderKey(eventGroupName: String): DataProviderKey { @@ -1283,7 +1301,7 @@ class MeasurementConsumerSimulator( } } - private suspend fun buildRequisitionInfo( + private fun buildRequisitionInfo( dataProvider: DataProvider, eventGroups: List, measurementConsumer: MeasurementConsumer, From c479098df00fc3550435921095cf3d70166254e6 Mon Sep 17 00:00:00 2001 From: renjiez Date: Fri, 22 Nov 2024 17:45:07 +0000 Subject: [PATCH 3/3] Update BUILD.bazel --- .../org/wfanet/measurement/loadtest/dataprovider/BUILD.bazel | 1 + 1 file changed, 1 insertion(+) diff --git a/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/BUILD.bazel index 64479522bc0..99d3e013189 100644 --- a/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/BUILD.bazel @@ -123,6 +123,7 @@ kt_jvm_library( "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:packed_messages", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key", "//src/main/kotlin/org/wfanet/measurement/common:health", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:list_resources", "//src/main/kotlin/org/wfanet/measurement/dataprovider:requisition_fulfiller", "//src/main/kotlin/org/wfanet/measurement/eventdataprovider/noiser", "//src/main/kotlin/org/wfanet/measurement/eventdataprovider/privacybudgetmanagement:privacy_budget_manager",