From 7b4e9220c88e41a373478eeca1cbbfae6e80008a Mon Sep 17 00:00:00 2001 From: Sanjay Vasandani Date: Thu, 21 Nov 2024 15:52:38 -0800 Subject: [PATCH] fix: Read all EventGroups from simulators rather than stopping at first page --- .../loadtest/dataprovider/EdpSimulator.kt | 43 +++++++++------- .../loadtest/measurementconsumer/BUILD.bazel | 1 + .../MeasurementConsumerSimulator.kt | 50 +++++++++++++------ 3 files changed, 60 insertions(+), 34 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulator.kt b/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulator.kt index 7a29bd5f9d3..4737bcf56bd 100644 --- a/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulator.kt +++ b/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulator.kt @@ -36,9 +36,11 @@ import kotlin.math.roundToInt import kotlin.random.Random import kotlin.random.asJavaRandom import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.asFlow import kotlinx.coroutines.flow.emitAll +import kotlinx.coroutines.flow.firstOrNull import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.map import kotlinx.coroutines.withContext @@ -109,6 +111,9 @@ import org.wfanet.measurement.api.v2alpha.updateEventGroupRequest import org.wfanet.measurement.common.Health import org.wfanet.measurement.common.ProtoReflection import org.wfanet.measurement.common.SettableHealth +import org.wfanet.measurement.common.api.grpc.ResourceList +import org.wfanet.measurement.common.api.grpc.flattenConcat +import org.wfanet.measurement.common.api.grpc.listResources import org.wfanet.measurement.common.asBufferedFlow import org.wfanet.measurement.common.crypto.authorityKeyIdentifier import org.wfanet.measurement.common.crypto.readCertificate @@ -339,27 +344,29 @@ class EdpSimulator( * Returns the first [EventGroup] for this `DataProvider` and [MeasurementConsumer] with * [eventGroupReferenceId], or `null` if not found. */ + @OptIn(ExperimentalCoroutinesApi::class) // For `flattenConcat`. private suspend fun getEventGroupByReferenceId(eventGroupReferenceId: String): EventGroup? { - val response = - try { - eventGroupsStub.listEventGroups( - listEventGroupsRequest { - parent = edpData.name - filter = - ListEventGroupsRequestKt.filter { measurementConsumers += measurementConsumerName } - pageSize = Int.MAX_VALUE + return eventGroupsStub + .listResources { pageToken -> + val response = + try { + listEventGroups( + listEventGroupsRequest { + parent = edpData.name + filter = + ListEventGroupsRequestKt.filter { + measurementConsumers += measurementConsumerName + } + this.pageToken = pageToken + } + ) + } catch (e: StatusException) { + throw Exception("Error listing EventGroups", e) } - ) - } catch (e: StatusException) { - throw Exception("Error listing EventGroups", e) + ResourceList(response.eventGroupsList, response.nextPageToken) } - - // TODO(@SanjayVas): Support filtering by reference ID so we don't need to handle multiple pages - // of EventGroups. - check(response.nextPageToken.isEmpty()) { - "Too many EventGroups for ${edpData.name} and $measurementConsumerName" - } - return response.eventGroupsList.find { it.eventGroupReferenceId == eventGroupReferenceId } + .flattenConcat() + .firstOrNull { it.eventGroupReferenceId == eventGroupReferenceId } } private suspend fun ensureMetadataDescriptor( diff --git a/src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer/BUILD.bazel index a6b9b428e95..1451a300dae 100644 --- a/src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer/BUILD.bazel @@ -28,6 +28,7 @@ kt_jvm_library( "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:packed_messages", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha/testing", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:list_resources", "//src/main/kotlin/org/wfanet/measurement/common/identity", "//src/main/kotlin/org/wfanet/measurement/integration/common:configs", "//src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha:api_key_authentication_server_interceptor", diff --git a/src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer/MeasurementConsumerSimulator.kt b/src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer/MeasurementConsumerSimulator.kt index bbdf684ba45..f008d4ef94a 100644 --- a/src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer/MeasurementConsumerSimulator.kt +++ b/src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer/MeasurementConsumerSimulator.kt @@ -31,6 +31,10 @@ import kotlin.math.log2 import kotlin.math.max import kotlin.math.sqrt import kotlin.random.Random +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.filter +import kotlinx.coroutines.flow.toList import kotlinx.coroutines.time.delay import org.projectnessie.cel.Program import org.wfanet.measurement.api.v2alpha.Certificate @@ -92,6 +96,9 @@ import org.wfanet.measurement.api.v2alpha.unpack import org.wfanet.measurement.api.withAuthenticationKey import org.wfanet.measurement.common.ExponentialBackoff import org.wfanet.measurement.common.OpenEndTimeRange +import org.wfanet.measurement.common.api.grpc.ResourceList +import org.wfanet.measurement.common.api.grpc.flattenConcat +import org.wfanet.measurement.common.api.grpc.listResources import org.wfanet.measurement.common.coerceAtMost import org.wfanet.measurement.common.crypto.Hashing import org.wfanet.measurement.common.crypto.PrivateKeyHandle @@ -812,11 +819,13 @@ class MeasurementConsumerSimulator( maxDataProviders: Int = 20, ): MeasurementInfo { val eventGroups: List = - listEventGroups(measurementConsumer.name).filter { - it.eventGroupReferenceId.startsWith( - TestIdentifiers.SIMULATOR_EVENT_GROUP_REFERENCE_ID_PREFIX - ) - } + listEventGroups(measurementConsumer.name) + .filter { + it.eventGroupReferenceId.startsWith( + TestIdentifiers.SIMULATOR_EVENT_GROUP_REFERENCE_ID_PREFIX + ) + } + .toList() check(eventGroups.isNotEmpty()) { "No event groups found for ${measurementConsumer.name}" } val nonceHashes = mutableListOf() val keyToDataProviderMap: Map = @@ -1255,16 +1264,25 @@ class MeasurementConsumerSimulator( } } - private suspend fun listEventGroups(measurementConsumer: String): List { - val request = listEventGroupsRequest { parent = measurementConsumer } - try { - return eventGroupsClient - .withAuthenticationKey(measurementConsumerData.apiAuthenticationKey) - .listEventGroups(request) - .eventGroupsList - } catch (e: StatusException) { - throw Exception("Error listing event groups for MC $measurementConsumer", e) - } + @OptIn(ExperimentalCoroutinesApi::class) // For `flattenConcat`. + private fun listEventGroups(measurementConsumer: String): Flow { + return eventGroupsClient + .withAuthenticationKey(measurementConsumerData.apiAuthenticationKey) + .listResources { pageToken -> + val response = + try { + listEventGroups( + listEventGroupsRequest { + parent = measurementConsumer + this.pageToken = pageToken + } + ) + } catch (e: StatusException) { + throw Exception("Error listing event groups for MC $measurementConsumer", e) + } + ResourceList(response.eventGroupsList, response.nextPageToken) + } + .flattenConcat() } private fun extractDataProviderKey(eventGroupName: String): DataProviderKey { @@ -1283,7 +1301,7 @@ class MeasurementConsumerSimulator( } } - private suspend fun buildRequisitionInfo( + private fun buildRequisitionInfo( dataProvider: DataProvider, eventGroups: List, measurementConsumer: MeasurementConsumer,