Skip to content

Commit

Permalink
fix: Read all EventGroups from simulators rather than stopping at fir…
Browse files Browse the repository at this point in the history
…st page
  • Loading branch information
SanjayVas committed Nov 21, 2024
1 parent 1391cab commit 7b4e922
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,11 @@ import kotlin.math.roundToInt
import kotlin.random.Random
import kotlin.random.asJavaRandom
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.asFlow
import kotlinx.coroutines.flow.emitAll
import kotlinx.coroutines.flow.firstOrNull
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.withContext
Expand Down Expand Up @@ -109,6 +111,9 @@ import org.wfanet.measurement.api.v2alpha.updateEventGroupRequest
import org.wfanet.measurement.common.Health
import org.wfanet.measurement.common.ProtoReflection
import org.wfanet.measurement.common.SettableHealth
import org.wfanet.measurement.common.api.grpc.ResourceList
import org.wfanet.measurement.common.api.grpc.flattenConcat
import org.wfanet.measurement.common.api.grpc.listResources
import org.wfanet.measurement.common.asBufferedFlow
import org.wfanet.measurement.common.crypto.authorityKeyIdentifier
import org.wfanet.measurement.common.crypto.readCertificate
Expand Down Expand Up @@ -339,27 +344,29 @@ class EdpSimulator(
* Returns the first [EventGroup] for this `DataProvider` and [MeasurementConsumer] with
* [eventGroupReferenceId], or `null` if not found.
*/
@OptIn(ExperimentalCoroutinesApi::class) // For `flattenConcat`.
private suspend fun getEventGroupByReferenceId(eventGroupReferenceId: String): EventGroup? {
val response =
try {
eventGroupsStub.listEventGroups(
listEventGroupsRequest {
parent = edpData.name
filter =
ListEventGroupsRequestKt.filter { measurementConsumers += measurementConsumerName }
pageSize = Int.MAX_VALUE
return eventGroupsStub
.listResources { pageToken ->
val response =
try {
listEventGroups(
listEventGroupsRequest {
parent = edpData.name
filter =
ListEventGroupsRequestKt.filter {
measurementConsumers += measurementConsumerName
}
this.pageToken = pageToken
}
)
} catch (e: StatusException) {
throw Exception("Error listing EventGroups", e)
}
)
} catch (e: StatusException) {
throw Exception("Error listing EventGroups", e)
ResourceList(response.eventGroupsList, response.nextPageToken)
}

// TODO(@SanjayVas): Support filtering by reference ID so we don't need to handle multiple pages
// of EventGroups.
check(response.nextPageToken.isEmpty()) {
"Too many EventGroups for ${edpData.name} and $measurementConsumerName"
}
return response.eventGroupsList.find { it.eventGroupReferenceId == eventGroupReferenceId }
.flattenConcat()
.firstOrNull { it.eventGroupReferenceId == eventGroupReferenceId }
}

private suspend fun ensureMetadataDescriptor(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ kt_jvm_library(
"//src/main/kotlin/org/wfanet/measurement/api/v2alpha:packed_messages",
"//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key",
"//src/main/kotlin/org/wfanet/measurement/api/v2alpha/testing",
"//src/main/kotlin/org/wfanet/measurement/common/api/grpc:list_resources",
"//src/main/kotlin/org/wfanet/measurement/common/identity",
"//src/main/kotlin/org/wfanet/measurement/integration/common:configs",
"//src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha:api_key_authentication_server_interceptor",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ import kotlin.math.log2
import kotlin.math.max
import kotlin.math.sqrt
import kotlin.random.Random
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.filter
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.time.delay
import org.projectnessie.cel.Program
import org.wfanet.measurement.api.v2alpha.Certificate
Expand Down Expand Up @@ -92,6 +96,9 @@ import org.wfanet.measurement.api.v2alpha.unpack
import org.wfanet.measurement.api.withAuthenticationKey
import org.wfanet.measurement.common.ExponentialBackoff
import org.wfanet.measurement.common.OpenEndTimeRange
import org.wfanet.measurement.common.api.grpc.ResourceList
import org.wfanet.measurement.common.api.grpc.flattenConcat
import org.wfanet.measurement.common.api.grpc.listResources
import org.wfanet.measurement.common.coerceAtMost
import org.wfanet.measurement.common.crypto.Hashing
import org.wfanet.measurement.common.crypto.PrivateKeyHandle
Expand Down Expand Up @@ -812,11 +819,13 @@ class MeasurementConsumerSimulator(
maxDataProviders: Int = 20,
): MeasurementInfo {
val eventGroups: List<EventGroup> =
listEventGroups(measurementConsumer.name).filter {
it.eventGroupReferenceId.startsWith(
TestIdentifiers.SIMULATOR_EVENT_GROUP_REFERENCE_ID_PREFIX
)
}
listEventGroups(measurementConsumer.name)
.filter {
it.eventGroupReferenceId.startsWith(
TestIdentifiers.SIMULATOR_EVENT_GROUP_REFERENCE_ID_PREFIX
)
}
.toList()
check(eventGroups.isNotEmpty()) { "No event groups found for ${measurementConsumer.name}" }
val nonceHashes = mutableListOf<ByteString>()
val keyToDataProviderMap: Map<DataProviderKey, DataProvider> =
Expand Down Expand Up @@ -1255,16 +1264,25 @@ class MeasurementConsumerSimulator(
}
}

private suspend fun listEventGroups(measurementConsumer: String): List<EventGroup> {
val request = listEventGroupsRequest { parent = measurementConsumer }
try {
return eventGroupsClient
.withAuthenticationKey(measurementConsumerData.apiAuthenticationKey)
.listEventGroups(request)
.eventGroupsList
} catch (e: StatusException) {
throw Exception("Error listing event groups for MC $measurementConsumer", e)
}
@OptIn(ExperimentalCoroutinesApi::class) // For `flattenConcat`.
private fun listEventGroups(measurementConsumer: String): Flow<EventGroup> {
return eventGroupsClient
.withAuthenticationKey(measurementConsumerData.apiAuthenticationKey)
.listResources { pageToken ->
val response =
try {
listEventGroups(
listEventGroupsRequest {
parent = measurementConsumer
this.pageToken = pageToken
}
)
} catch (e: StatusException) {
throw Exception("Error listing event groups for MC $measurementConsumer", e)
}
ResourceList(response.eventGroupsList, response.nextPageToken)
}
.flattenConcat()
}

private fun extractDataProviderKey(eventGroupName: String): DataProviderKey {
Expand All @@ -1283,7 +1301,7 @@ class MeasurementConsumerSimulator(
}
}

private suspend fun buildRequisitionInfo(
private fun buildRequisitionInfo(
dataProvider: DataProvider,
eventGroups: List<EventGroup>,
measurementConsumer: MeasurementConsumer,
Expand Down

0 comments on commit 7b4e922

Please sign in to comment.