diff --git a/.github/workflows/api-lint.yml b/.github/workflows/api-lint.yml index d15800a6ffa..08405ba3d7b 100644 --- a/.github/workflows/api-lint.yml +++ b/.github/workflows/api-lint.yml @@ -59,3 +59,4 @@ jobs: - run: tools/api-lint wfa/measurement/system - run: tools/api-lint wfa/measurement/securecomputation + - run: tools/api-lint wfa/measurement/access diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index ea039a28bc9..93528038444 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -67,6 +67,8 @@ jobs: build --define worker2_id=worker2 build --define worker2_public_api_target=worker2.example.com:8443 build --define mc_name=measurementConsumers/foo + build --define mc_api_key=foo + build --define mc_cert_name=measurementConsumers/foo/certificates/bar build --define edp1_name=dataProviders/foo1 build --define edp1_cert_name=dataProviders/foo1/certificates/bar1 build --define edp2_name=dataProviders/foo2 diff --git a/.github/workflows/configure-reporting.yml b/.github/workflows/configure-reporting.yml deleted file mode 100644 index 1dd99809b09..00000000000 --- a/.github/workflows/configure-reporting.yml +++ /dev/null @@ -1,184 +0,0 @@ -# Copyright 2023 The Cross-Media Measurement Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -name: "Configure Reporting" - -on: - workflow_call: - inputs: - environment: - type: string - required: true - image-tag: - description: "Tag of container images" - type: string - required: true - apply: - description: "Apply the new configuration" - type: boolean - required: true - workflow_dispatch: - inputs: - environment: - description: "GitHub-managed environment" - required: true - type: choice - options: - - dev - - qa - - head - image-tag: - description: "Tag of container images" - type: string - required: true - apply: - description: "Apply the new configuration" - type: boolean - default: false - -permissions: - id-token: write - -env: - KUSTOMIZATION_PATH: "k8s/reporting" - -jobs: - update-reporting: - runs-on: ubuntu-22.04 - environment: ${{ inputs.environment }} - steps: - - uses: actions/checkout@v4 - - # Authenticate to Google Cloud. This will export some environment - # variables, including GCLOUD_PROJECT. - - name: Authenticate to Google Cloud - uses: google-github-actions/auth@v2 - with: - workload_identity_provider: ${{ vars.WORKLOAD_IDENTITY_PROVIDER }} - service_account: ${{ vars.GKE_CONFIG_SERVICE_ACCOUNT }} - - - name: Write auth.bazelrc - env: - BUILDBUDDY_API_KEY: ${{ secrets.BUILDBUDDY_API_KEY }} - run: | - cat << EOF > auth.bazelrc - build --remote_header=x-buildbuddy-api-key=$BUILDBUDDY_API_KEY - EOF - - - name: Write ~/.bazelrc - env: - IMAGE_TAG: ${{ inputs.image-tag }} - POSTGRES_INSTANCE: ${{ vars.POSTGRES_INSTANCE }} - GCLOUD_REGION: ${{ vars.GCLOUD_REGION }} - KINGDOM_PUBLIC_API_TARGET: ${{ vars.KINGDOM_PUBLIC_API_TARGET }} - run: | - cat << EOF > ~/.bazelrc - common --config=ci - build --remote_download_outputs=toplevel # Need build output. - common --config=ghcr - build --define "image_tag=$IMAGE_TAG" - build --define "google_cloud_project=$GCLOUD_PROJECT" - build --define "postgres_instance=$POSTGRES_INSTANCE" - build --define "postgres_region=$GCLOUD_REGION" - build --define "kingdom_public_api_target=$KINGDOM_PUBLIC_API_TARGET" - build --define reporting_public_api_address_name=reporting-v1alpha - EOF - - - - name: Export BAZEL_BIN - run: echo "BAZEL_BIN=$(bazelisk info bazel-bin)" >> $GITHUB_ENV - - - name: Get GKE cluster credentials - uses: google-github-actions/get-gke-credentials@v2 - with: - cluster_name: reporting - location: ${{ vars.GCLOUD_ZONE }} - - - name: Configure metrics - uses: ./.github/actions/configure-metrics - if: ${{ inputs.apply }} - - - name: Generate archives - run: > - bazelisk build - //src/main/k8s/dev:reporting.tar - //src/main/k8s/testing/secretfiles:archive - - - name: Make Kustomization dir - run: mkdir -p "$KUSTOMIZATION_PATH" - - - name: Extract Kustomization archive - run: > - tar -xf "$BAZEL_BIN/src/main/k8s/dev/reporting.tar" - -C "$KUSTOMIZATION_PATH" - - - name: Extract secret files archive - run: > - tar -xf "$BAZEL_BIN/src/main/k8s/testing/secretfiles/archive.tar" - -C "$KUSTOMIZATION_PATH/src/main/k8s/dev/reporting_secrets" - - # Write files from configuration variables. Since it appears that GitHub - # configuration variables use DOS (CRLF) line endings, we convert these to - # Unix (LF) line endings. - - - name: Write AKID to principal map - env: - AKID_TO_PRINCIPAL_MAP: ${{ vars.AKID_TO_PRINCIPAL_MAP }} - run: > - echo "$AKID_TO_PRINCIPAL_MAP" | sed $'s/\r$//' > - "$KUSTOMIZATION_PATH/src/main/k8s/dev/reporting_config_files/authority_key_identifier_to_principal_map.textproto" - - - name: Write encryption key-pair config - env: - ENCRYPTION_KEY_PAIR_CONFIG: ${{ vars.ENCRYPTION_KEY_PAIR_CONFIG }} - run: > - echo "$ENCRYPTION_KEY_PAIR_CONFIG" | sed $'s/\r$//' > - "$KUSTOMIZATION_PATH/src/main/k8s/dev/reporting_config_files/encryption_key_pair_config.textproto" - - - name: Copy measurement spec config - run: > - cp src/main/k8s/testing/secretfiles/measurement_spec_config.textproto - "$KUSTOMIZATION_PATH/src/main/k8s/dev/reporting_config_files/" - - - name: Write measurement consumer config - env: - MEASUREMENT_CONSUMER_CONFIG: ${{ secrets.MEASUREMENT_CONSUMER_CONFIG }} - run: > - echo "$MEASUREMENT_CONSUMER_CONFIG" | sed $'s/\r$//' > - "$KUSTOMIZATION_PATH/src/main/k8s/dev/reporting_secrets/measurement_consumer_config.textproto" - - - name: Copy secret generator - run: > - cp src/main/k8s/testing/secretfiles/reporting_secrets_kustomization.yaml - "$KUSTOMIZATION_PATH/src/main/k8s/dev/reporting_secrets/kustomization.yaml" - - - name: Export KUSTOMIZE_PATH - run: echo "KUSTOMIZE_PATH=$KUSTOMIZATION_PATH/src/main/k8s/dev/reporting" >> $GITHUB_ENV - - # Run kubectl diff, treating the command as succeeded even if the exit - # code is 1 as kubectl uses this code to indicate there's a diff. - - name: kubectl diff - id: kubectl-diff - run: kubectl diff -k "$KUSTOMIZE_PATH" || (( $? == 1 )) - - - name: kubectl apply - if: ${{ inputs.apply }} - run: kubectl apply -k "$KUSTOMIZE_PATH" - - - name: Wait for rollout - if: ${{ inputs.apply }} - run: | - for deployment in $(kubectl get deployments -o name); do - kubectl rollout status "$deployment" --timeout=5m - done diff --git a/.github/workflows/run-k8s-tests.yml b/.github/workflows/run-k8s-tests.yml index 0f6ccea3c86..5906348104b 100644 --- a/.github/workflows/run-k8s-tests.yml +++ b/.github/workflows/run-k8s-tests.yml @@ -52,6 +52,7 @@ jobs: MC_NAME: ${{ vars.MC_NAME }} MC_API_KEY: ${{ secrets.MC_API_KEY }} GCLOUD_PROJECT: ${{ vars.GCLOUD_PROJECT }} + REPORTING_PUBLIC_API_TARGET: ${{ vars.REPORTING_PUBLIC_API_TARGET }} run: | cat << EOF > ~/.bazelrc common --config=ci @@ -59,6 +60,7 @@ jobs: build --define mc_name=$MC_NAME build --define mc_api_key=$MC_API_KEY build --define google_cloud_project=$GCLOUD_PROJECT + build --define reporting_public_api_target=$REPORTING_PUBLIC_API_TARGET test --test_output=streamed test --test_timeout=3600 EOF diff --git a/.github/workflows/scan-images.yml b/.github/workflows/scan-images.yml index 158513221f1..ee6d925037f 100644 --- a/.github/workflows/scan-images.yml +++ b/.github/workflows/scan-images.yml @@ -66,16 +66,13 @@ jobs: - kingdom/system-api - duchy/herald - duchy/spanner-computations - - reporting/postgres-data-server - kingdom/exchanges-deletion - - reporting/v1alpha-public-api - - reporting/postgres-update-schema - - reporting/v2/v2alpha-public-api - - reporting/v2/postgres-update-schema - panel-exchange/gcloud-example-daemon - panel-exchange/aws-example-daemon - - simulator/synthetic-generator-edp + - reporting/v2/v2alpha-public-api + - reporting/v2/postgres-update-schema - reporting/v2/postgres-internal-server + - simulator/synthetic-generator-edp - duchy/postgres-update-schema - duchy/gcloud-postgres-update-schema - duchy/postgres-internal-server diff --git a/.github/workflows/update-cmms.yml b/.github/workflows/update-cmms.yml index c6611ca358c..a08130e5c49 100644 --- a/.github/workflows/update-cmms.yml +++ b/.github/workflows/update-cmms.yml @@ -124,15 +124,6 @@ jobs: # Update the Reporting system. # # This isn't technically part of the CMMS, but we do it here for simplicity. - update-reporting: - uses: ./.github/workflows/configure-reporting.yml - secrets: inherit - needs: [publish-images, terraform] - with: - image-tag: ${{ needs.publish-images.outputs.image-tag }} - environment: ${{ inputs.environment }} - apply: ${{ inputs.apply }} - update-reporting-v2: uses: ./.github/workflows/configure-reporting-v2.yml secrets: inherit diff --git a/docs/gke/correctness-test.md b/docs/gke/correctness-test.md index 5b738737809..055def71ac0 100644 --- a/docs/gke/correctness-test.md +++ b/docs/gke/correctness-test.md @@ -80,7 +80,8 @@ bazel test //src/test/kotlin/org/wfanet/measurement/integration/k8s:SyntheticGen --test_output=streamed \ --define=kingdom_public_api_target=v2alpha.kingdom.dev.halo-cmm.org:8443 \ --define=mc_name=measurementConsumers/Rcn7fKd25C8 \ ---define=mc_api_key=W9q4zad246g +--define=mc_api_key=W9q4zad246g \ +--define=reporting_public_api_target=v2alpha.reporting.dev.halo-cmm.org:8443 ``` The time the test takes depends on the size of the data set. With the default diff --git a/docs/gke/metrics-deployment.md b/docs/gke/metrics-deployment.md index d0a8550439c..497e1a03994 100644 --- a/docs/gke/metrics-deployment.md +++ b/docs/gke/metrics-deployment.md @@ -30,7 +30,7 @@ free to use whichever you prefer. Deploy a Halo component. See the related guides: [Create Kingdom Cluster](kingdom-deployment.md), [Create Duchy Cluster](duchy-deployment.md), or -[Create Reporting Cluster](reporting-server-deployment.md). +[Create Reporting Cluster](reporting-server-v2-deployment.md). ## Google Cloud APIs diff --git a/docs/gke/reporting-server-deployment.md b/docs/gke/reporting-server-deployment.md deleted file mode 100644 index bb9259e2800..00000000000 --- a/docs/gke/reporting-server-deployment.md +++ /dev/null @@ -1,394 +0,0 @@ -# Halo Reporting Server Deployment on GKE - -## Important Note - -This is the deployment guide for the old V1. For the new V2, see -[Reporting V2](reporting-v2-server-deployment.md). - -## Background - -The configuration for the [`dev` environment](../../src/main/k8s/dev) can be -used as the basis for deploying CMMS components using Google Kubernetes Engine -(GKE) on another Google Cloud project. - -Many operations can be done either via the gcloud CLI or the Google Cloud web -console. This guide picks whichever is most convenient for that operation. Feel -free to use whichever you prefer. - -### What are we creating/deploying? - -- 1 Cloud SQL managed PostgreSQL database -- 1 GKE cluster - - 1 Kubernetes secret - - `certs-and-configs` - - 1 Kubernetes configmap - - `config-files` - - 2 Kubernetes services - - `postgres-reporting-data-server` (Cluster IP) - - `v1alpha-public-api-server` (External load balancer) - - 2 Kubernetes deployments - - `postgres-reporting-data-server-deployment` - - `v1alpha-public-api-server-deployment` - - 3 Kubernetes network policies - - `internal-data-server-network-policy` - - `public-api-server-network-policy` - - `default-deny-ingress-and-egress` - -## Before you start - -See [Machine Setup](machine-setup.md). - -### Managed PostgreSQL Quick Start - -If you don't have a managed PostgreSQL instance in your project, you can create -one in the -[Cloud Console](https://console.cloud.google.com/sql/instances/create;engine=PostgreSQL). -For the purposes of this guide, we assume the instance ID is `dev-postgres`. - -Make sure that the instance has the `cloudsql.iam_authentication` flag set to -`On`. Set the machine type and storage based on your expected usage. - -## Create the database - -The Reporting server expects its own database within your PostgreSQL instance. -You can create one with the `gcloud` CLI. For example, a database named -`reporting` in the `dev-postgres` instance. - -```shell -gcloud sql databases create reporting --instance=dev-postgres -``` - -## Build and push the container images - -If you aren't using pre-built release images, you can build the images yourself -from source and push them to a container registry. For example, if you're using -the [Google Container Registry](https://cloud.google.com/container-registry), -you would specify `gcr.io` as your container registry and your Cloud project -name as your image repository prefix. - -Assuming a project named `halo-cmm-dev` and an image tag `build-0001`, run the -following to build and push the images: - -```shell -bazel run -c opt //src/main/docker:push_all_reporting_gke_images \ - --define container_registry=gcr.io \ - --define image_repo_prefix=halo-cmm-dev --define image_tag=build-0001 -``` - -Tip: If you're using [Hybrid Development](../building.md#hybrid-development) for -containerized builds, replace `bazel build` with `tools/bazel-container build` -and `bazel run` with `tools/bazel-container-run`. - -## Create resources for the cluster - -See [GKE Cluster Configuration](cluster-config.md) for background. - -### IAM Service Accounts - -We'll want to -[create a least privilege service account](https://cloud.google.com/kubernetes-engine/docs/how-to/hardening-your-cluster#use_least_privilege_sa) -that our cluster will run under. Follow the steps in the linked guide to do -this. - -We'll additionally want to create a service account that we'll use to allow the -internal API server to access the database. See -[Granting Cloud SQL database access](cluster-config.md#granting-cloud-sql-instance-access) -for how to make sure this service account has the appropriate role. - -### KMS key for secret encryption - -Follow the steps in -[Create a Cloud KMS key](https://cloud.google.com/kubernetes-engine/docs/how-to/encrypting-secrets#creating-key) -to create a KMS key and grant permission to the GKE service agent to use it. - -Let's assume we've created a key named `k8s-secret` in a key ring named -`test-key-ring` in the `us-central1` region under the `halo-cmm-dev` project. -The resource name would be the following: -`projects/halo-cmm-dev/locations/us-central1/keyRings/test-key-ring/cryptoKeys/k8s-secret`. -We'll use this when creating the cluster. - -Tip: For convenience, there is a "Copy resource name" action on the key in the -Cloud console. - -## Create the cluster - -See [GKE Cluster Configuration](cluster-config.md) for tips on cluster creation -parameters, or follow the quick start instructions below. - -After creating the cluster, we can configure `kubectl` to be able to access it - -```shell -gcloud container clusters get-credentials reporting -``` - -### Add Metrics to the cluster - -See [Metrics Deployment](metrics-deployment.md). - -### Quick start - -Supposing you want to create a cluster named `reporting` for the Reporting -server, running under the `gke-cluster` service account in the `halo-cmm-dev` -project, the command would be - -```shell -gcloud container clusters create reporting \ - --enable-network-policy --workload-pool=halo-cmm-dev.svc.id.goog \ - --service-account="gke-cluster@halo-cmm-dev.iam.gserviceaccount.com" \ - --database-encryption-key=projects/halo-cmm-dev/locations/us-central1/keyRings/test-key-ring/cryptoKeys/k8s-secret \ - --num-nodes=3 --enable-autoscaling --min-nodes=2 --max-nodes=4 \ - --machine-type=e2-small -``` - -Adjust the number of nodes and machine type according to your expected usage. -The cluster version should be no older than `1.24.0` in order to support -built-in gRPC health probe. - -## Create the K8s ServiceAccount - -In order to use the IAM service account that we created earlier from our -cluster, we need to create a K8s ServiceAccount and give it access to that IAM -service account. - -For example, to create a K8s ServiceAccount named `internal-reporting-server`, -run - -```shell -kubectl create serviceaccount internal-reporting-server -``` - -Supposing the IAM service account you created in a previous step is named -`reporting-internal` within the `halo-cmm-dev` project. You'll need to allow the -K8s service account to impersonate it - -```shell -gcloud iam service-accounts add-iam-policy-binding \ - reporting-internal@halo-cmm-dev.iam.gserviceaccount.com \ - --role roles/iam.workloadIdentityUser \ - --member "serviceAccount:halo-cmm-dev.svc.id.goog[default/internal-reporting-server]" -``` - -Finally, add an annotation to link the K8s service account to the IAM service -account: - -```shell -kubectl annotate serviceaccount internal-reporting-server \ - iam.gke.io/gcp-service-account=reporting-internal@halo-cmm-dev.iam.gserviceaccount.com -``` - -## Generate the K8s Kustomization - -Populating a cluster is generally done by applying a K8s Kustomization. You can -use the `dev` configuration as a base to get started. The Kustomization is -generated using Bazel rules from files written in [CUE](https://cuelang.org/). - -To generate the `dev` Kustomization, run the following (substituting your own -values): - -```shell -bazel build //src/main/k8s/dev:reporting.tar \ - --define reporting_public_api_address_name=reporting-v1alpha \ - --define google_cloud_project=halo-cmm-dev \ - --define postgres_instance=dev-postgres \ - --define postgres_region=us-central1 \ - --define kingdom_public_api_target=v2alpha.kingdom.dev.halo-cmm.org:8443 \ - --define container_registry=gcr.io \ - --define image_repo_prefix=halo-kingdom-demo --define image_tag=build-0001 -``` - -Extract the generated archive to some directory. - -You can customize this generated object configuration with your own settings -such as the number of replicas per deployment, the memory and CPU requirements -of each container, and the JVM options of each container. - -## Customize the K8s secrets - -We use K8s secrets to hold sensitive information, such as private keys. - -### Certificates and signing keys - -First, prepare all the files we want to include in the Kubernetes secret. The -`dev` configuration assumes the files have the following names: - -1. `all_root_certs.pem` - - This makes up the trusted root CA store. It's the concatenation of the root - CA certificates for all the entities that the Reporting server interacts - with, including: - - * All Measurement Consumers - * Any entity which produces Measurement results (e.g. the Aggregator Duchy - and Data Providers) - * The Kingdom - * The Reporting server itself (for internal traffic) - - Supposing your root certs are all in a single folder and end with - `_root.pem`, you can concatenate them all with a simple shell command: - - ```shell - cat *_root.pem > all_root_certs.pem - ``` - - Note: This assumes that all your root certificate PEM files end in newline. - -1. `reporting_tls.pem` - - The Reporting server's TLS certificate. - -1. `reporting_tls.key` - - The private key for the Reporting server's TLS certificate. - -In addition, you'll need to include the encryption and signing private keys for -the Measurement Consumers that this Reporting server instance needs to act on -behalf of. The encryption keys are assumed to be in Tink's binary keyset format. -The signing private keys are assumed to be DER-encoded unencrypted PKCS #8. - -#### Testing keys - -There are some [testing keys](../../src/main/k8s/testing/secretfiles) within the -repository. These can be used to create the above secret for testing, but **must -not** be used for production environments as doing so would be highly insecure. - -Generate the archive: - -```shell -bazel build //src/main/k8s/testing/secretfiles:archive -``` - -Extract the generated archive to the `src/main/k8s/dev/reporting_secrets/` path -within the Kustomization directory. - -### Measurement Consumer config - -Contents: - -1. `measurement_consumer_config.textproto` - - [`MeasurementConsumerConfig`](../../src/main/proto/wfa/measurement/config/reporting/measurement_consumer_config.proto) - protobuf message in text format. - -### Generator - -Place the above files into the `src/main/k8s/dev/reporting_secrets/` path within -the Kustomization directory. - -Create a `kustomization.yaml` file in that path with the following content, -substituting the names of your own keys: - -```yaml -secretGenerator: -- name: signing - files: - - all_root_certs.pem - - reporting_tls.key - - reporting_tls.pem - - mc_enc_public.tink - - mc_enc_private.tink - - mc_cs_private.der -- name: mc-config - files: - - measurement_consumer_config.textproto -``` - -## Customize the K8s ConfigMap - -Configuration that may frequently change is stored in a K8s configMap. The `dev` -configuration uses one named `config-files`. - -* `authority_key_identifier_to_principal_map.textproto` - * [`AuthorityKeyToPrincipalMap`](../../src/main/proto/wfa/measurement/config/authority_key_to_principal_map.proto) -* `encryption_key_pair_config.textproto` - * [`EncryptionKeyPairConfig`](../../src/main/proto/wfa/measurement/config/reporting/encryption_key_pair_config.proto) -* `measurement_spec_config.textproto` - * [`MeasurementSpecConfig`](../../src/main/proto/wfa/measurement/config/reporting/measurement_spec_config.proto) -* `known_event_group_metadata_type_set.pb` - * Protobuf `FileDescriptorSet` containing known `EventGroup` metadata - types. - -Place these files into the `src/main/k8s/dev/reporting_config_files/` path -within the Kustomization directory. - -## Apply the K8s Kustomization - -Within the Kustomization directory, run - -```shell -kubectl apply -k src/main/k8s/dev/reporting -``` - -Now all components should be successfully deployed to your GKE cluster. You can -verify by running - -```shell -kubectl get deployments -``` - -and - -```shell -kubectl get services -``` - -You should see something like the following: - -``` -NAME READY UP-TO-DATE AVAILABLE AGE -postgres-reporting-data-server-deployment 1/1 1 1 254d -reporting-public-api-v1alpha-server-deployment 1/1 1 1 9m2s -``` - -``` -NAME TYPE CLUSTER-IP EXTERNAL-IP PORT(S) AGE -kubernetes ClusterIP 10.16.32.1 443/TCP 260d -postgres-reporting-data-server ClusterIP 10.16.39.47 8443/TCP 254d -reporting-public-api-v1alpha-server LoadBalancer 10.16.32.255 34.135.79.68 8443:30104/TCP 8m45s -``` - -## Appendix - -### Troubleshooting - -* `notAuthorized` error - - You see an error that looks something like this: - - ``` - { - "code": 403, - "errors": [ - { - "domain": "global", - "message": "The client is not authorized to make this request.", - "reason": "notAuthorized" - } - ], - "message": "The client is not authorized to make this request." - } - ``` - - Make sure that your Cloud SQL instance has the `cloudsql.iam_authentication` - flag set to `On`, and that you've followed all the steps for using Workload - Identity and IAM Authentication. See the - [Workload Identity](cluster-config.md#workload-identity) section in the - Cluster Configuration doc. - - If you believe you have everything configured correctly, try deleting and - recreating the IAM service account for DB access. Apparently there's a - glitch with Cloud SQL that this sometimes resolves. - -### Manual testing via CLI - -The -[`Reporting`](../../src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools) -CLI tool can be used for manual testing as well as examples of how to call the -API. - -### HTTP/REST - -The public API is a set of gRPC services following the -[API Improvement Proposals](https://google.aip.dev/). These can be exposed as an -HTTP REST API using the -[gRPC Gateway](https://github.com/grpc-ecosystem/grpc-gateway). The service -definitions include the appropriate protobuf annotations for this purpose. diff --git a/docs/gke/reporting-v2-server-deployment.md b/docs/gke/reporting-v2-server-deployment.md index 5e5d0056e3a..df8aeeb0c0f 100644 --- a/docs/gke/reporting-v2-server-deployment.md +++ b/docs/gke/reporting-v2-server-deployment.md @@ -1,10 +1,5 @@ # Halo Reporting V2 Server Deployment on GKE -## Important Note - -This is the deployment guide for the new V2. For the old V1, see -[Reporting V1](reporting-server-deployment.md). - ## Background The configuration for the [`dev` environment](../../src/main/k8s/dev) can be diff --git a/src/main/docker/images.bzl b/src/main/docker/images.bzl index 0fa56f1e971..08125217c30 100644 --- a/src/main/docker/images.bzl +++ b/src/main/docker/images.bzl @@ -232,40 +232,6 @@ LOCAL_IMAGES = [ ), ] -REPORTING_COMMON_IMAGES = [ - struct( - name = "reporting_v1alpha_public_api_server_image", - image = "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server:v1alpha_public_api_server_image", - repository = _PREFIX + "/reporting/v1alpha-public-api", - ), -] - -REPORTING_LOCAL_IMAGES = [ - struct( - name = "reporting_data_server_image", - image = "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/server:postgres_reporting_data_server_image", - repository = _PREFIX + "/reporting/local-postgres-internal", - ), - struct( - name = "reporting_postgres_update_schema_image", - image = "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/tools:update_schema_image", - repository = _PREFIX + "/reporting/local-postgres-update-schema", - ), -] - -REPORTING_GKE_IMAGES = [ - struct( - name = "gcloud_reporting_data_server_image", - image = "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/gcloud/postgres/server:gcloud_postgres_reporting_data_server_image", - repository = _PREFIX + "/reporting/postgres-data-server", - ), - struct( - name = "gcloud_reporting_postgres_update_schema_image", - image = "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/gcloud/postgres/tools:update_schema_image", - repository = _PREFIX + "/reporting/postgres-update-schema", - ), -] - REPORTING_V2_COMMON_IMAGES = [ struct( name = "reporting_v2alpha_public_api_server_image", @@ -305,12 +271,12 @@ REPORTING_V2_GKE_IMAGES = [ ), ] -ALL_GKE_IMAGES = COMMON_IMAGES + GKE_IMAGES + REPORTING_COMMON_IMAGES + REPORTING_GKE_IMAGES + REPORTING_V2_COMMON_IMAGES + REPORTING_V2_GKE_IMAGES +ALL_GKE_IMAGES = COMMON_IMAGES + GKE_IMAGES + REPORTING_V2_COMMON_IMAGES + REPORTING_V2_GKE_IMAGES -ALL_LOCAL_IMAGES = COMMON_IMAGES + LOCAL_IMAGES + REPORTING_COMMON_IMAGES + REPORTING_LOCAL_IMAGES + REPORTING_V2_COMMON_IMAGES + REPORTING_V2_LOCAL_IMAGES +ALL_LOCAL_IMAGES = COMMON_IMAGES + LOCAL_IMAGES + REPORTING_V2_COMMON_IMAGES + REPORTING_V2_LOCAL_IMAGES -ALL_IMAGES = COMMON_IMAGES + LOCAL_IMAGES + GKE_IMAGES + REPORTING_COMMON_IMAGES + REPORTING_LOCAL_IMAGES + REPORTING_GKE_IMAGES + REPORTING_V2_COMMON_IMAGES + REPORTING_V2_LOCAL_IMAGES + REPORTING_V2_GKE_IMAGES + EKS_IMAGES +ALL_IMAGES = COMMON_IMAGES + LOCAL_IMAGES + GKE_IMAGES + REPORTING_V2_COMMON_IMAGES + REPORTING_V2_LOCAL_IMAGES + REPORTING_V2_GKE_IMAGES + EKS_IMAGES -ALL_REPORTING_GKE_IMAGES = REPORTING_COMMON_IMAGES + REPORTING_GKE_IMAGES + REPORTING_V2_COMMON_IMAGES + REPORTING_V2_GKE_IMAGES +ALL_REPORTING_GKE_IMAGES = REPORTING_V2_COMMON_IMAGES + REPORTING_V2_GKE_IMAGES ALL_EKS_IMAGES = COMMON_IMAGES + EKS_IMAGES diff --git a/src/main/k8s/BUILD.bazel b/src/main/k8s/BUILD.bazel index e815499a116..1e930ced64a 100644 --- a/src/main/k8s/BUILD.bazel +++ b/src/main/k8s/BUILD.bazel @@ -78,16 +78,6 @@ cue_library( srcs = ["postgres.cue"], ) -cue_library( - name = "reporting", - srcs = ["reporting.cue"], - deps = [ - ":base", - ":config", - ":postgres", - ], -) - cue_library( name = "reporting_v2", srcs = ["reporting_v2.cue"], diff --git a/src/main/k8s/dev/BUILD.bazel b/src/main/k8s/dev/BUILD.bazel index 432063212be..67136a60f35 100644 --- a/src/main/k8s/dev/BUILD.bazel +++ b/src/main/k8s/dev/BUILD.bazel @@ -419,58 +419,6 @@ kustomization_dir( ], ) -cue_dump( - name = "reporting_gke", - srcs = ["reporting_gke.cue"], - cue_tags = { - "secret_name": SIGNING_SECRET_NAME, - "mc_config_secret_name": MC_CONFIG_SECRET_NAME, - "container_registry": IMAGE_REPOSITORY_SETTINGS.container_registry, - "image_repo_prefix": IMAGE_REPOSITORY_SETTINGS.repository_prefix, - "image_tag": IMAGE_REPOSITORY_SETTINGS.image_tag, - "public_api_address_name": REPORTING_K8S_SETTINGS.public_api_address_name, - "google_cloud_project": GCLOUD_SETTINGS.project, - "postgres_instance": GCLOUD_SETTINGS.postgres_instance, - "postgres_region": GCLOUD_SETTINGS.postgres_region, - "kingdom_public_api_target": KINGDOM_K8S_SETTINGS.public_api_target, - }, - tags = ["manual"], - deps = [ - ":base_gke", - ":config_gke", - "//src/main/k8s:reporting", - ], -) - -kustomization_dir( - name = "reporting_config_files", - testonly = True, - srcs = [ - "reporting_config_files_kustomization.yaml", - "//src/main/k8s/testing/secretfiles:known_event_group_metadata_type_set", - ], - renames = {"reporting_config_files_kustomization.yaml": "kustomization.yaml"}, -) - -kustomization_dir( - name = "reporting_secrets", -) - -kustomization_dir( - name = "reporting", - testonly = True, - srcs = [ - "resource_requirements.yaml", - ":reporting_gke", - ], - generate_kustomization = True, - tags = ["manual"], - deps = [ - ":reporting_config_files", - ":reporting_secrets", - ], -) - cue_dump( name = "reporting_v2_gke", srcs = ["reporting_v2_gke.cue"], diff --git a/src/main/k8s/dev/reporting_gke.cue b/src/main/k8s/dev/reporting_gke.cue deleted file mode 100644 index 7af25490016..00000000000 --- a/src/main/k8s/dev/reporting_gke.cue +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package k8s - -_reportingSecretName: string @tag("secret_name") -_reportingMcConfigSecretName: string @tag("mc_config_secret_name") -_publicApiAddressName: string @tag("public_api_address_name") - -#KingdomApiTarget: #GrpcTarget & { - target: string @tag("kingdom_public_api_target") -} - -// Name of K8s service account for the internal API server. -#InternalServerServiceAccount: "internal-reporting-server" - -#InternalServerResourceRequirements: #ResourceRequirements & { - requests: { - cpu: "100m" - } -} -#PublicServerResourceRequirements: ResourceRequirements=#ResourceRequirements & { - requests: { - cpu: "25m" - memory: "256Mi" - } - limits: { - memory: ResourceRequirements.requests.memory - } -} - -objectSets: [ - defaultNetworkPolicies, - reporting.serviceAccounts, - reporting.configMaps, - reporting.deployments, - reporting.services, - reporting.networkPolicies, -] - -reporting: #Reporting & { - _secretName: _reportingSecretName - _mcConfigSecretName: _reportingMcConfigSecretName - _kingdomApiTarget: #KingdomApiTarget - _internalApiTarget: certificateHost: "localhost" - - _postgresConfig: { - iamUserLocal: "reporting-internal" - database: "reporting" - } - - _verboseGrpcServerLogging: true - - serviceAccounts: { - "\(#InternalServerServiceAccount)": #WorkloadIdentityServiceAccount & { - _iamServiceAccountName: "reporting-internal" - } - } - - configMaps: "java": #JavaConfigMap - - deployments: { - "postgres-reporting-data-server": { - _container: resources: #InternalServerResourceRequirements - spec: template: spec: #ServiceAccountPodSpec & { - serviceAccountName: #InternalServerServiceAccount - } - } - "reporting-public-api-v1alpha-server": { - _container: resources: #PublicServerResourceRequirements - } - } - - services: { - "reporting-public-api-v1alpha-server": _ipAddressName: _publicApiAddressName - } -} diff --git a/src/main/k8s/local/BUILD.bazel b/src/main/k8s/local/BUILD.bazel index 552219510fe..9d9c7cb5612 100644 --- a/src/main/k8s/local/BUILD.bazel +++ b/src/main/k8s/local/BUILD.bazel @@ -200,25 +200,6 @@ cue_dump( ], ) -cue_dump( - name = "reporting", - srcs = ["reporting.cue"], - cue_tags = { - "secret_name": SECRET_NAME, - "db_secret_name": DB_SECRET_NAME, - "mc_config_secret_name": MC_CONFIG_SECRET_NAME, - "container_registry": IMAGE_REPOSITORY_SETTINGS.container_registry, - "image_repo_prefix": IMAGE_REPOSITORY_SETTINGS.repository_prefix, - "image_tag": IMAGE_REPOSITORY_SETTINGS.image_tag, - }, - tags = ["manual"], - deps = [ - ":config_cue", - "//src/main/k8s:postgres", - "//src/main/k8s:reporting", - ], -) - cue_dump( name = "reporting_v2", srcs = ["reporting_v2.cue"], @@ -292,7 +273,6 @@ kustomization_dir( "empty_encryption_key_pair_config.textproto", "//src/main/k8s/testing/data:synthetic_generation_specs_small", "//src/main/k8s/testing/secretfiles:known_event_group_metadata_type_set", - "//src/main/k8s/testing/secretfiles:measurement_spec_config.textproto", "//src/main/k8s/testing/secretfiles:metric_spec_config.textproto", ], renames = { @@ -309,7 +289,6 @@ kustomization_dir( ":encryption_key_pair_config.textproto", "//src/main/k8s/testing/data:synthetic_generation_specs_small", "//src/main/k8s/testing/secretfiles:known_event_group_metadata_type_set", - "//src/main/k8s/testing/secretfiles:measurement_spec_config.textproto", "//src/main/k8s/testing/secretfiles:metric_spec_config.textproto", ], renames = { @@ -384,26 +363,6 @@ kustomization_dir( ], ) -kustomization_dir( - name = "cmms_with_reporting", - srcs = [ - ":duchies", - ":edp_simulators", - ":emulators", - ":kingdom", - ":postgres_database", - ":reporting", - ], - generate_kustomization = True, - tags = ["manual"], - deps = [ - ":config_files", - ":db_creds", - ":mc_config", - "//src/main/k8s/testing/secretfiles:kustomization", - ], -) - kustomization_dir( name = "cmms_with_reporting_v2", srcs = [ diff --git a/src/main/k8s/local/README.md b/src/main/k8s/local/README.md index d1d6a20c609..b930fa1e7fa 100644 --- a/src/main/k8s/local/README.md +++ b/src/main/k8s/local/README.md @@ -111,7 +111,7 @@ kubectl port-forward --address=localhost services/v2alpha-public-api-server 8443 kubectl port-forward --address=localhost services/gcp-kingdom-data-server 9443:8443 ``` -Then run the tool, outputting to some directory (e.g. `/tmp/resource-setup`, +Then run the tool, outputting to some directory (e.g. `/tmp/cmms/resource-setup`, make sure this directory has been created): ```shell @@ -119,7 +119,7 @@ src/main/k8s/testing/resource_setup.sh \ --kingdom-public-api-target=localhost:8443 \ --kingdom-internal-api-target=localhost:9443 \ --bazel-config-name=halo-local \ - --output-dir=/tmp/resource-setup + --output-dir=/tmp/cmms/resource-setup ``` Note: You can stop the port forwarding at this point. Future steps involve @@ -198,8 +198,8 @@ This is an alternate version of the section above. This assumes you've already done the Initial Setup and have the output from the `ResourceSetup` tool. Use the command in the above section to build the tar archive, swapping the -target with `//src/main/k8s/local:cmms_with_reporting.tar`. Extract this archive -to some directory (e.g. `/tmp/cmms`). +target with `//src/main/k8s/local:cmms_with_reporting_v2.tar`. Extract this +archive to some directory (e.g. `/tmp/cmms`). Copy the `authority_key_identifier_to_principal_map.textproto` output from the `ResourceSetup` tool to the `src/main/k8s/local/config_files` path within this @@ -209,12 +209,9 @@ You can then apply the Kustomization from the directory where you extracted the archive: ```shell -kubectl apply -k src/main/k8s/local/cmms_with_reporting/ +kubectl apply -k src/main/k8s/local/cmms_with_reporting_v2/ ``` -To deploy Reporting V2, swap out `cmms_with_reporting` with -`cmms_with_reporting_v2`. - To use the Reporting CLI tool, you need to forward the port again. For example: ```shell @@ -228,17 +225,18 @@ kubectl port-forward --address=localhost services/reporting-v2alpha-public-api-s The code is instrumented with OpenTelemetry. To enable metrics collection in the cluster, we use the [OpenTelemetry Operator for Kubernetes](https://github.com/open-telemetry/opentelemetry-operator). -This depends on [cert-manager](https://github.com/cert-manager/cert-manager) so -we install that first: +This depends on [cert-manager](https://github.com/cert-manager/cert-manager), so +we install that first. Make sure their versions are the same as versions listed in +[install-otel-operator/action.yml](/.github/actions/install-otel-operator/action.yml#L4): ```shell -kubectl apply -f https://github.com/cert-manager/cert-manager/releases/download/v1.13.2/cert-manager.yaml +kubectl apply -f https://github.com/cert-manager/cert-manager/releases/download/v1.14.5/cert-manager.yaml ``` Once that is installed, then install the Operator: ```shell -kubectl apply -f https://github.com/open-telemetry/opentelemetry-operator/releases/download/v0.88.0/opentelemetry-operator.yaml +kubectl apply -f https://github.com/open-telemetry/opentelemetry-operator/releases/download/v0.99.0/opentelemetry-operator.yaml ``` ### Deploy OpenTelemetry Resources @@ -316,6 +314,24 @@ bazel test //src/test/kotlin/org/wfanet/measurement/integration/k8s:SyntheticGen --define=mc_api_key=W9q4zad246g ``` +## Debugging Tips + +To quickly test changes on a deployed cluster, build and push container +images first with a new image_tag, and then edit corresponding resource's +`image` field in its configuration with `kubectl edit` + +```shell +tools/bazel-container-run //src/main/docker:push_all_local_images \ + --define container_registry=registry.dev.svc.cluster.local:5001 \ + --define image_repo_prefix=halo --define image_tag= +``` + +an example: kubectl edit cronjob/measurement-system-prober-cronjob: + +```yaml + image: localhost:5001/halo/kingdom/measurement-system-prober: +``` + ## Old Guide This has instructions that may be outdated. @@ -420,3 +436,25 @@ do this using `minikube ssh` or using the Note: If you are using rootless Docker or rootless Podman, you may need to also enable IP forwarding on your host machine. See https://github.com/kubernetes/minikube/issues/18667 + +### Delete Existing Kind Cluster + +Only execute this part if you want to delete existing local cluster before creating +a new one. Let's say the cluster is named `kind`. + +Delete the KIND cluster: + +```shell +kind delete cluster --name=kind +``` + +Delete the local registry container: + +```shell +docker rm -f kind-registry +``` + +Prune Docker resources (optional, recommended): +```shell +docker system prune -a +``` diff --git a/src/main/k8s/local/config_files_kustomization.yaml b/src/main/k8s/local/config_files_kustomization.yaml index 56e23573e98..71c188ba7fb 100644 --- a/src/main/k8s/local/config_files_kustomization.yaml +++ b/src/main/k8s/local/config_files_kustomization.yaml @@ -17,7 +17,6 @@ configMapGenerator: files: - authority_key_identifier_to_principal_map.textproto - encryption_key_pair_config.textproto - - measurement_spec_config.textproto - metric_spec_config.textproto - synthetic_population_spec_small.textproto - synthetic_event_group_spec_small_1.textproto diff --git a/src/main/k8s/local/reporting.cue b/src/main/k8s/local/reporting.cue deleted file mode 100644 index 7d7857b5229..00000000000 --- a/src/main/k8s/local/reporting.cue +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package k8s - -_reportingSecretName: string @tag("secret_name") -_reportingDbSecretName: string @tag("db_secret_name") -_reportingMcConfigSecretName: string @tag("mc_config_secret_name") - -objectSets: [ for objectSet in reporting {objectSet}] - -reporting: #Reporting & { - _secretName: _reportingSecretName - _mcConfigSecretName: _reportingMcConfigSecretName - _imageSuffixes: { - "update-reporting-schema": "reporting/local-postgres-update-schema" - "postgres-reporting-data-server": "reporting/local-postgres-internal" - } - - _postgresConfig: { - serviceName: "postgres" - password: "$(POSTGRES_PASSWORD)" - user: "$(POSTGRES_USER)" - } - _kingdomApiTarget: { - serviceName: "v2alpha-public-api-server" - certificateHost: "localhost" - } - _internalApiTarget: { - certificateHost: "localhost" - } - _verboseGrpcServerLogging: true - _verboseGrpcClientLogging: true - - let EnvVars = #EnvVarMap & { - "POSTGRES_USER": { - valueFrom: - secretKeyRef: { - name: _reportingDbSecretName - key: "username" - } - } - "POSTGRES_PASSWORD": { - valueFrom: - secretKeyRef: { - name: _reportingDbSecretName - key: "password" - } - } - } - - deployments: { - "postgres-reporting-data-server": { - _container: _envVars: EnvVars - _updateSchemaContainer: _envVars: EnvVars - } - "reporting-public-api-v1alpha-server": { - spec: template: spec: { - _dependencies: [ - "postgres-reporting-data-server", - "v2alpha-public-api-server", // Kingdom public API server. - ] - } - } - } -} diff --git a/src/main/k8s/local/testing/BUILD.bazel b/src/main/k8s/local/testing/BUILD.bazel index 51b188a2f9d..c92878a7261 100644 --- a/src/main/k8s/local/testing/BUILD.bazel +++ b/src/main/k8s/local/testing/BUILD.bazel @@ -63,8 +63,25 @@ kustomization_dir( "config_files_kustomization.yaml", "//src/main/k8s/testing/data:synthetic_generation_specs_small", "//src/main/k8s/testing/secretfiles:known_event_group_metadata_type_set", + "//src/main/k8s/testing/secretfiles:metric_spec_config.textproto", ], renames = {"config_files_kustomization.yaml": "kustomization.yaml"}, + tags = ["manual"], +) + +kustomization_dir( + name = "config_files_for_panel_match", + srcs = [ + "config_files_kustomization.yaml", + "empty_encryption_key_pair_config.textproto", + "//src/main/k8s/testing/data:synthetic_generation_specs_small", + "//src/main/k8s/testing/secretfiles:known_event_group_metadata_type_set", + "//src/main/k8s/testing/secretfiles:metric_spec_config.textproto", + ], + renames = { + "config_files_kustomization.yaml": "kustomization.yaml", + "empty_encryption_key_pair_config.textproto": "encryption_key_pair_config.textproto", + }, ) kustomization_dir( @@ -75,6 +92,15 @@ kustomization_dir( }, ) +kustomization_dir( + name = "mc_config", + srcs = ["mc_config_kustomization.yaml"], + renames = { + "mc_config_kustomization.yaml": "kustomization.yaml", + }, + tags = ["manual"], +) + kustomization_dir( name = "cmms", srcs = [ @@ -83,12 +109,14 @@ kustomization_dir( "//src/main/k8s/local:emulators", "//src/main/k8s/local:kingdom", "//src/main/k8s/local:postgres_database", + "//src/main/k8s/local:reporting_v2", ], generate_kustomization = True, tags = ["manual"], deps = [ ":config_files", ":db_creds", + ":mc_config", "//src/main/k8s/testing/secretfiles:kustomization", ], ) @@ -103,7 +131,7 @@ kustomization_dir( generate_kustomization = True, tags = ["manual"], deps = [ - ":config_files", + ":config_files_for_panel_match", "//src/main/k8s/testing/secretfiles:kustomization", ], ) diff --git a/src/main/k8s/local/testing/config_files_kustomization.yaml b/src/main/k8s/local/testing/config_files_kustomization.yaml index d882167e307..71c188ba7fb 100644 --- a/src/main/k8s/local/testing/config_files_kustomization.yaml +++ b/src/main/k8s/local/testing/config_files_kustomization.yaml @@ -16,6 +16,8 @@ configMapGenerator: - name: config-files files: - authority_key_identifier_to_principal_map.textproto + - encryption_key_pair_config.textproto + - metric_spec_config.textproto - synthetic_population_spec_small.textproto - synthetic_event_group_spec_small_1.textproto - synthetic_event_group_spec_small_2.textproto diff --git a/src/main/k8s/local/testing/empty_encryption_key_pair_config.textproto b/src/main/k8s/local/testing/empty_encryption_key_pair_config.textproto new file mode 100644 index 00000000000..51c20dc96c6 --- /dev/null +++ b/src/main/k8s/local/testing/empty_encryption_key_pair_config.textproto @@ -0,0 +1,2 @@ +# proto-file: wfa/measurement/config/reporting/encryption_key_pair_config.proto +# proto-message: EncryptionKeyPairConfig diff --git a/src/main/k8s/dev/reporting_config_files_kustomization.yaml b/src/main/k8s/local/testing/mc_config_kustomization.yaml similarity index 67% rename from src/main/k8s/dev/reporting_config_files_kustomization.yaml rename to src/main/k8s/local/testing/mc_config_kustomization.yaml index a15c366fe23..8e329fcff4d 100644 --- a/src/main/k8s/dev/reporting_config_files_kustomization.yaml +++ b/src/main/k8s/local/testing/mc_config_kustomization.yaml @@ -1,4 +1,4 @@ -# Copyright 2022 The Cross-Media Measurement Authors +# Copyright 2024 The Cross-Media Measurement Authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -configMapGenerator: -- name: config-files +secretGenerator: +- name: mc-config files: - - authority_key_identifier_to_principal_map.textproto - - encryption_key_pair_config.textproto - - measurement_spec_config.textproto - - known_event_group_metadata_type_set.pb + - measurement_consumer_config.textproto diff --git a/src/main/k8s/reporting.cue b/src/main/k8s/reporting.cue deleted file mode 100644 index 8cae8d4b4a4..00000000000 --- a/src/main/k8s/reporting.cue +++ /dev/null @@ -1,176 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package k8s - -#Reporting: Reporting={ - _verboseGrpcServerLogging: bool | *false - _verboseGrpcClientLogging: bool | *false - - _postgresConfig: #PostgresConfig - - _internalApiTarget: #GrpcTarget & { - serviceName: "postgres-reporting-data-server" - targetOption: "--internal-api-target" - certificateHostOption: "--internal-api-cert-host" - } - _kingdomApiTarget: #GrpcTarget & { - targetOption: "--kingdom-api-target" - certificateHostOption: "--kingdom-api-cert-host" - } - - _imageSuffixes: [_=string]: string - _imageSuffixes: { - "update-reporting-schema": string | *"reporting/postgres-update-schema" - "postgres-reporting-data-server": string | *"reporting/postgres-data-server" - "reporting-public-api-v1alpha-server": string | *"reporting/v1alpha-public-api" - } - _imageConfigs: [_=string]: #ImageConfig - _imageConfigs: { - for name, suffix in _imageSuffixes { - "\(name)": {repoSuffix: suffix} - } - } - _images: { - for name, config in _imageConfigs { - "\(name)": config.image - } - } - _secretName: string - _mcConfigSecretName: string - - _tlsArgs: [ - "--tls-cert-file=/var/run/secrets/files/reporting_tls.pem", - "--tls-key-file=/var/run/secrets/files/reporting_tls.key", - ] - _reportingCertCollectionFileFlag: "--cert-collection-file=/var/run/secrets/files/all_root_certs.pem" - _akidToPrincipalMapFileFlag: "--authority-key-identifier-to-principal-map-file=/etc/\(#AppName)/config-files/authority_key_identifier_to_principal_map.textproto" - _measurementConsumerConfigFileFlag: "--measurement-consumer-config-file=/var/run/secrets/files/config/mc/measurement_consumer_config.textproto" - _signingPrivateKeyStoreDirFlag: "--signing-private-key-store-dir=/var/run/secrets/files" - _encryptionKeyPairDirFlag: "--key-pair-dir=/var/run/secrets/files" - _encryptionKeyPairConfigFileFlag: "--key-pair-config-file=/etc/\(#AppName)/config-files/encryption_key_pair_config.textproto" - _measurementSpecConfigFileFlag: "--measurement-spec-config-file=/etc/\(#AppName)/config-files/measurement_spec_config.textproto" - _knownEventGroupMetadataTypeFlag: "--known-event-group-metadata-type=/etc/\(#AppName)/config-files/known_event_group_metadata_type_set.pb" - _debugVerboseGrpcClientLoggingFlag: "--debug-verbose-grpc-client-logging=\(_verboseGrpcClientLogging)" - _debugVerboseGrpcServerLoggingFlag: "--debug-verbose-grpc-server-logging=\(_verboseGrpcServerLogging)" - - services: [Name=_]: #GrpcService & { - metadata: { - _component: "reporting" - name: Name - } - } - services: { - "postgres-reporting-data-server": {} - "reporting-public-api-v1alpha-server": #ExternalService - } - - deployments: [Name=_]: #ServerDeployment & { - _name: Name - _secretName: Reporting._secretName - _system: "reporting" - _container: { - image: _images[_name] - } - } - deployments: { - "postgres-reporting-data-server": { - _container: args: [ - _reportingCertCollectionFileFlag, - _debugVerboseGrpcServerLoggingFlag, - "--port=8443", - "--health-port=8080", - ] + _postgresConfig.flags + _tlsArgs - - _updateSchemaContainer: Container=#Container & { - image: _images[Container.name] - args: _postgresConfig.flags - imagePullPolicy?: _container.imagePullPolicy - } - - spec: template: spec: _initContainers: { - "update-reporting-schema": _updateSchemaContainer - } - } - - "reporting-public-api-v1alpha-server": { - _container: args: [ - _debugVerboseGrpcClientLoggingFlag, - _debugVerboseGrpcServerLoggingFlag, - _reportingCertCollectionFileFlag, - _akidToPrincipalMapFileFlag, - _measurementConsumerConfigFileFlag, - _signingPrivateKeyStoreDirFlag, - _encryptionKeyPairDirFlag, - _encryptionKeyPairConfigFileFlag, - _measurementSpecConfigFileFlag, - _knownEventGroupMetadataTypeFlag, - "--port=8443", - "--health-port=8080", - "--event-group-metadata-descriptor-cache-duration=1h", - ] + _tlsArgs + _internalApiTarget.args + _kingdomApiTarget.args - - spec: template: spec: { - _mounts: { - "mc-config": { - volume: secret: secretName: Reporting._mcConfigSecretName - volumeMount: mountPath: "/var/run/secrets/files/config/mc/" - } - "config-files": #ConfigMapMount - } - _dependencies: _ | *["postgres-reporting-data-server"] - } - } - } - - networkPolicies: [Name=_]: #NetworkPolicy & { - _name: Name - } - - networkPolicies: { - "internal-reporting-data-server": { - _app_label: "postgres-reporting-data-server-app" - _sourceMatchLabels: [ - "reporting-public-api-v1alpha-server-app", - ] - _egresses: { - // Needs to call out to Postgres server. - any: {} - } - } - "public-reporting-api-server": { - _app_label: "reporting-public-api-v1alpha-server-app" - _destinationMatchLabels: ["postgres-reporting-data-server-app"] - _ingresses: { - gRpc: { - ports: [{ - port: #GrpcPort - }] - } - } - _egresses: { - // Needs to call out to Kingdom. - any: {} - } - } - } - - configMaps: [Name=string]: #ConfigMap & { - metadata: name: Name - } - - serviceAccounts: [Name=string]: #ServiceAccount & { - metadata: name: Name - } -} diff --git a/src/main/k8s/testing/secretfiles/BUILD.bazel b/src/main/k8s/testing/secretfiles/BUILD.bazel index 2ad33a4952b..611511c00bc 100644 --- a/src/main/k8s/testing/secretfiles/BUILD.bazel +++ b/src/main/k8s/testing/secretfiles/BUILD.bazel @@ -19,7 +19,6 @@ package( ) exports_files([ - "measurement_spec_config.textproto", "metric_spec_config.textproto", ]) diff --git a/src/main/k8s/testing/secretfiles/measurement_spec_config.textproto b/src/main/k8s/testing/secretfiles/measurement_spec_config.textproto deleted file mode 100644 index 43bf7b2c38b..00000000000 --- a/src/main/k8s/testing/secretfiles/measurement_spec_config.textproto +++ /dev/null @@ -1,82 +0,0 @@ -# proto-file: wfa/measurement/config/reporting/measurement_spec_config.proto -# proto-message: MeasurementSpecConfig -reach_single_data_provider { - privacy_params { - epsilon: 0.000207 - delta: 1e-15 - } - vid_sampling_interval { - fixed_start { - start: 0f - width: 1f - } - } -} -reach { - privacy_params { - epsilon: 0.0007444 - delta: 1e-15 - } - vid_sampling_interval { - random_start { - width: 256 - num_vid_buckets: 300 - } - } -} -reach_and_frequency_single_data_provider { - reach_privacy_params { - epsilon: 0.000207 - delta: 1e-15 - } - frequency_privacy_params { - epsilon: 0.004728 - delta: 1e-15 - } - vid_sampling_interval { - fixed_start { - start: 0f - width: 1f - } - } -} -reach_and_frequency { - reach_privacy_params { - epsilon: 0.0007444 - delta: 1e-15 - } - frequency_privacy_params { - epsilon: 0.014638 - delta: 1e-15 - } - vid_sampling_interval { - random_start { - width: 256 - num_vid_buckets: 300 - } - } -} -impression { - privacy_params { - epsilon: 0.003592 - delta: 1e-15 - } - vid_sampling_interval { - fixed_start { - start: 0f - width: 1f - } - } -} -duration { - privacy_params { - epsilon: 0.007418 - delta: 1e-15 - } - vid_sampling_interval { - fixed_start { - start: 0f - width: 1f - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/api/v2alpha/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/api/v2alpha/BUILD.bazel index 45bc9a2576f..e53b6c87c1c 100644 --- a/src/main/kotlin/org/wfanet/measurement/api/v2alpha/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/api/v2alpha/BUILD.bazel @@ -52,7 +52,7 @@ kt_jvm_library( deps = [ "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:context_keys", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:measurement_principal", - "//src/main/kotlin/org/wfanet/measurement/common/api/grpc", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:akid_principal_server_interceptor", "//src/main/kotlin/org/wfanet/measurement/common/grpc:context", "//src/main/kotlin/org/wfanet/measurement/common/identity", "//src/main/proto/wfa/measurement/api/v2alpha:duchy_kt_jvm_proto", diff --git a/src/main/kotlin/org/wfanet/measurement/common/api/grpc/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/common/api/grpc/BUILD.bazel index 3a897c0ad81..4ffd63ed7ab 100644 --- a/src/main/kotlin/org/wfanet/measurement/common/api/grpc/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/common/api/grpc/BUILD.bazel @@ -3,7 +3,7 @@ load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") package(default_visibility = ["//visibility:public"]) kt_jvm_library( - name = "grpc", + name = "akid_principal_server_interceptor", srcs = ["AkidPrincipalServerInterceptor.kt"], deps = [ "//src/main/kotlin/org/wfanet/measurement/common/api:principal", @@ -14,3 +14,13 @@ kt_jvm_library( "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", ], ) + +kt_jvm_library( + name = "list_resources", + srcs = ["ListResources.kt"], + deps = [ + "@wfa_common_jvm//imports/java/com/google/protobuf", + "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", + "@wfa_rules_kotlin_jvm//imports/io/gprc/kotlin:stub", + ], +) diff --git a/src/main/kotlin/org/wfanet/measurement/common/api/grpc/ListResources.kt b/src/main/kotlin/org/wfanet/measurement/common/api/grpc/ListResources.kt new file mode 100644 index 00000000000..ebd768b2412 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/common/api/grpc/ListResources.kt @@ -0,0 +1,92 @@ +/* + * Copyright 2024 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.common.api.grpc + +import com.google.protobuf.Message +import io.grpc.kotlin.AbstractCoroutineStub +import kotlin.coroutines.coroutineContext +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.ensureActive +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.asFlow +import kotlinx.coroutines.flow.flattenConcat +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.map + +/** A [List] of resources from a paginated List method. */ +data class ResourceList( + val resources: List, + /** + * A token that can be sent on subsequent requests to retrieve the next page. If empty, there are + * no subsequent pages. + */ + val nextPageToken: String, +) : List by resources + +/** + * Lists resources from a paginated List method on this stub. + * + * @param pageToken page token for initial request + * @param list function which calls the appropriate List method on the stub + */ +fun > S.listResources( + pageToken: String = "", + list: suspend S.(pageToken: String) -> ResourceList, +): Flow> = + listResources(Int.MAX_VALUE, pageToken) { nextPageToken, _ -> list(nextPageToken) } + +/** + * Lists resources from a paginated List method on this stub. + * + * @param limit maximum number of resources to emit + * @param pageToken page token for initial request + * @param list function which calls the appropriate List method on the stub, returning no more than + * the specified remaining number of resources + */ +fun > S.listResources( + limit: Int, + pageToken: String = "", + list: suspend S.(pageToken: String, remaining: Int) -> ResourceList, +): Flow> { + require(limit > 0) { "limit must be positive" } + return flow { + var remaining: Int = limit + var nextPageToken = pageToken + + while (true) { + coroutineContext.ensureActive() + + val resourceList: ResourceList = list(nextPageToken, remaining) + require(resourceList.size <= remaining) { + "List call must ensure that limit is not exceeded. " + + "Returned ${resourceList.size} items when only $remaining were remaining" + } + emit(resourceList) + + remaining -= resourceList.size + nextPageToken = resourceList.nextPageToken + if (nextPageToken.isEmpty() || remaining == 0) { + break + } + } + } +} + +/** @see [flattenConcat] */ +@ExperimentalCoroutinesApi // Overloads experimental `flattenConcat` function. +fun Flow>.flattenConcat(): Flow = + map { it.asFlow() }.flattenConcat() diff --git a/src/main/kotlin/org/wfanet/measurement/dataprovider/RequisitionFulfiller.kt b/src/main/kotlin/org/wfanet/measurement/dataprovider/RequisitionFulfiller.kt index b529a5470db..6b65168cdb8 100644 --- a/src/main/kotlin/org/wfanet/measurement/dataprovider/RequisitionFulfiller.kt +++ b/src/main/kotlin/org/wfanet/measurement/dataprovider/RequisitionFulfiller.kt @@ -30,7 +30,6 @@ import org.wfanet.measurement.api.v2alpha.EncryptedMessage import org.wfanet.measurement.api.v2alpha.EncryptionPublicKey import org.wfanet.measurement.api.v2alpha.ListRequisitionsRequestKt.filter import org.wfanet.measurement.api.v2alpha.Measurement -import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey import org.wfanet.measurement.api.v2alpha.MeasurementSpec import org.wfanet.measurement.api.v2alpha.Requisition import org.wfanet.measurement.api.v2alpha.RequisitionKt.refusal @@ -75,11 +74,7 @@ abstract class RequisitionFulfiller( private val requisitionsStub: RequisitionsCoroutineStub, val throttler: Throttler, private val trustedCertificates: Map, - protected val measurementConsumerName: String, ) { - protected val measurementConsumerKey = - checkNotNull(MeasurementConsumerKey.fromName(measurementConsumerName)) - protected data class Specifications( val measurementSpec: MeasurementSpec, val requisitionSpec: RequisitionSpec, diff --git a/src/main/kotlin/org/wfanet/measurement/integration/common/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/integration/common/BUILD.bazel index 981e656f91e..8764c226d65 100644 --- a/src/main/kotlin/org/wfanet/measurement/integration/common/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/integration/common/BUILD.bazel @@ -27,6 +27,11 @@ java_library( "//src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha:exchanges_service", "//src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha:measurement_consumers_service", "//src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha:measurements_service", + "//src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha:model_lines_service", + "//src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha:model_releases_service", + "//src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha:model_rollouts_service", + "//src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha:model_suites_service", + "//src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha:populations_service", "//src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha:public_keys_service", "//src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha:requisitions_service", ], @@ -98,8 +103,10 @@ kt_jvm_library( "//src/main/k8s/testing/secretfiles:all_tink_keysets", ], deps = [ + "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key", "//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common:duchy_ids", "//src/main/kotlin/org/wfanet/measurement/loadtest/resourcesetup:resource_setup", + "//src/main/kotlin/org/wfanet/measurement/populationdataprovider:population_requisition_fulfiller", "//src/main/proto/wfa/measurement/config:duchy_cert_config_kt_jvm_proto", "//src/main/proto/wfa/measurement/internal/duchy/config:protocols_setup_config_kt_jvm_proto", "//src/main/proto/wfa/measurement/internal/kingdom:duchy_id_config_kt_jvm_proto", @@ -193,6 +200,36 @@ kt_jvm_library( ], ) +kt_jvm_library( + name = "in_process_population_requisition_fulfiller", + srcs = [ + "InProcessPopulationRequisitionFulfiller.kt", + ], + visibility = [ + "//src/main/kotlin/org/wfanet/measurement/integration:__subpackages__", + ], + deps = [ + ":all_kingdom_services", + ":configs", + ":synthetic_generation_specs", + "//src/main/kotlin/org/wfanet/measurement/api/v2alpha/testing", + "//src/main/kotlin/org/wfanet/measurement/common/identity", + "//src/main/kotlin/org/wfanet/measurement/eventdataprovider/privacybudgetmanagement/testing", + "//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common/service:data_services", + "//src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha:api_key_authentication_server_interceptor", + "//src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha:model_releases_service", + "//src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha:model_rollouts_service", + "//src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider:edp_simulator", + "//src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider:population_spec_converter", + "//src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider:synthetic_generator_event_query", + "//src/main/proto/wfa/measurement/api/v2alpha:event_annotations_kt_jvm_proto", + "@wfa_common_jvm//imports/java/io/grpc:core", + "@wfa_common_jvm//imports/java/org/junit", + "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/throttler", + ], +) + kt_jvm_library( name = "in_process_cmms_components", srcs = [ @@ -206,6 +243,7 @@ kt_jvm_library( ":in_process_duchy", ":in_process_edp_simulator", ":in_process_kingdom", + ":in_process_population_requisition_fulfiller", "//src/main/kotlin/org/wfanet/measurement/common/identity", "//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common:duchy_ids", "//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common/service:data_services", @@ -229,6 +267,7 @@ kt_jvm_library( "//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common/service:data_services", "//src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer:simulator", "//src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer:synthetic_generator_event_query", + "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha:metadata_principal_server_interceptor", "@wfa_common_jvm//imports/java/org/junit", "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/testing", @@ -280,6 +319,7 @@ kt_jvm_library( ], deps = [ ":in_process_cmms_components", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:list_resources", "//src/main/kotlin/org/wfanet/measurement/kingdom/batch:measurement_system_prober", "//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common/service:data_services", "@wfa_common_jvm//imports/java/com/google/common/truth", diff --git a/src/main/kotlin/org/wfanet/measurement/integration/common/Configs.kt b/src/main/kotlin/org/wfanet/measurement/integration/common/Configs.kt index dec0aadb3e6..7e2500c1d61 100644 --- a/src/main/kotlin/org/wfanet/measurement/integration/common/Configs.kt +++ b/src/main/kotlin/org/wfanet/measurement/integration/common/Configs.kt @@ -116,6 +116,8 @@ val ALL_EDP_WITHOUT_HMSS_CAPABILITIES_DISPLAY_NAMES = listOf("edp2") val ALL_EDP_DISPLAY_NAMES = ALL_EDP_WITH_HMSS_CAPABILITIES_DISPLAY_NAMES + ALL_EDP_WITHOUT_HMSS_CAPABILITIES_DISPLAY_NAMES +const val PDP_DISPLAY_NAME = "pdp1" + const val DUCHY_MILL_PARALLELISM = 3 const val MC_DISPLAY_NAME = "mc" diff --git a/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessCmmsComponents.kt b/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessCmmsComponents.kt index 09c48449221..f24b5b58ebe 100644 --- a/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessCmmsComponents.kt +++ b/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessCmmsComponents.kt @@ -17,6 +17,7 @@ package org.wfanet.measurement.integration.common import com.google.protobuf.ByteString +import com.google.protobuf.TypeRegistry import java.security.cert.X509Certificate import kotlinx.coroutines.runBlocking import org.junit.rules.TestRule @@ -28,8 +29,13 @@ import org.wfanet.measurement.api.v2alpha.DataProviderCertificateKey import org.wfanet.measurement.api.v2alpha.DataProviderKey import org.wfanet.measurement.api.v2alpha.EventGroup import org.wfanet.measurement.api.v2alpha.MeasurementConsumersGrpcKt +import org.wfanet.measurement.api.v2alpha.ModelProviderKey +import org.wfanet.measurement.api.v2alpha.PopulationKey import org.wfanet.measurement.api.v2alpha.event_group_metadata.testing.SyntheticEventGroupSpec import org.wfanet.measurement.api.v2alpha.event_group_metadata.testing.SyntheticPopulationSpec +import org.wfanet.measurement.api.v2alpha.event_templates.testing.Dummy +import org.wfanet.measurement.api.v2alpha.event_templates.testing.Person +import org.wfanet.measurement.api.v2alpha.event_templates.testing.TestEvent import org.wfanet.measurement.common.crypto.subjectKeyIdentifier import org.wfanet.measurement.common.crypto.tink.TinkPrivateKeyHandle import org.wfanet.measurement.common.identity.DuchyInfo @@ -37,18 +43,23 @@ import org.wfanet.measurement.common.identity.externalIdToApiId import org.wfanet.measurement.common.testing.ProviderRule import org.wfanet.measurement.common.testing.chainRulesSequentially import org.wfanet.measurement.config.DuchyCertConfig +import org.wfanet.measurement.dataprovider.DataProviderData import org.wfanet.measurement.kingdom.deploy.common.DuchyIds import org.wfanet.measurement.kingdom.deploy.common.HmssProtocolConfig import org.wfanet.measurement.kingdom.deploy.common.Llv2ProtocolConfig import org.wfanet.measurement.kingdom.deploy.common.RoLlv2ProtocolConfig import org.wfanet.measurement.kingdom.deploy.common.service.DataServices +import org.wfanet.measurement.kingdom.service.api.v2alpha.toPopulation +import org.wfanet.measurement.loadtest.dataprovider.toPopulationSpec import org.wfanet.measurement.loadtest.measurementconsumer.MeasurementConsumerData +import org.wfanet.measurement.loadtest.measurementconsumer.PopulationData import org.wfanet.measurement.loadtest.resourcesetup.DuchyCert import org.wfanet.measurement.loadtest.resourcesetup.EntityContent import org.wfanet.measurement.loadtest.resourcesetup.ResourceSetup import org.wfanet.measurement.loadtest.resourcesetup.Resources import org.wfanet.measurement.loadtest.resourcesetup.ResourcesKt.ResourceKt import org.wfanet.measurement.loadtest.resourcesetup.ResourcesKt.resource +import org.wfanet.measurement.populationdataprovider.PopulationInfo import org.wfanet.measurement.system.v1alpha.ComputationLogEntriesGrpcKt.ComputationLogEntriesCoroutineStub class InProcessCmmsComponents( @@ -107,6 +118,27 @@ class InProcessCmmsComponents( } } + private val populationRequisitionFulfiller: InProcessPopulationRequisitionFulfiller by lazy { + InProcessPopulationRequisitionFulfiller( + pdpData = + DataProviderData( + populationDataProviderResource.name, + PDP_DISPLAY_NAME, + loadEncryptionPrivateKey("${PDP_DISPLAY_NAME}_enc_private.tink"), + loadSigningKey("${PDP_DISPLAY_NAME}_cs_cert.der", "${PDP_DISPLAY_NAME}_cs_private.der"), + DataProviderCertificateKey.fromName( + populationDataProviderResource.dataProvider.certificate + )!!, + ), + populationDataProviderResource.name, + mapOf(populationKey to populationInfo), + typeRegistry, + kingdom.publicApiChannel, + TRUSTED_CERTIFICATES, + verboseGrpcLogging = false, + ) + } + val ruleChain: TestRule by lazy { chainRulesSequentially( kingdomDataServicesRule, @@ -126,17 +158,30 @@ class InProcessCmmsComponents( ApiKeysGrpcKt.ApiKeysCoroutineStub(kingdom.publicApiChannel) } + val modelProviderResourceName: String + get() = _modelProviderResourceName + + val typeRegistry: TypeRegistry + get() = _typeRegistry + private lateinit var mcResourceName: String private lateinit var apiAuthenticationKey: String private lateinit var edpDisplayNameToResourceMap: Map private lateinit var duchyCertMap: Map private lateinit var eventGroups: List + private lateinit var populationDataProviderResource: Resources.Resource + private lateinit var populationKey: PopulationKey + private lateinit var populationInfo: PopulationInfo + private lateinit var _typeRegistry: TypeRegistry + private lateinit var _modelProviderResourceName: String private suspend fun createAllResources() { val resourceSetup = ResourceSetup( internalAccountsClient = kingdom.internalAccountsClient, internalDataProvidersClient = kingdom.internalDataProvidersClient, + internalModelProvidersClient = kingdom.internalModelProvidersClient, + internalPopulationsClient = kingdom.internalPopulationsClient, accountsClient = publicAccountsClient, apiKeysClient = publicApiKeysClient, internalCertificatesClient = kingdom.internalCertificatesClient, @@ -169,6 +214,9 @@ class InProcessCmmsComponents( ResourceKt.dataProvider { certificate = externalDataProviderCertificateKeyName } } } + + createPopulationResources(resourceSetup) + // Create all duchy certificates. duchyCertMap = ALL_DUCHY_NAMES.associateWith { @@ -178,6 +226,37 @@ class InProcessCmmsComponents( } } + private suspend fun createPopulationResources(resourceSetup: ResourceSetup) { + val internalDataProvider = + resourceSetup.createInternalDataProvider(createEntityContent(PDP_DISPLAY_NAME)) + val externalDataProviderId = externalIdToApiId(internalDataProvider.externalDataProviderId) + val externalCertificateId = + externalIdToApiId(internalDataProvider.certificate.externalCertificateId) + val externalDataProviderResourceName = DataProviderKey(externalDataProviderId).toName() + val externalDataProviderCertificateKeyName = + DataProviderCertificateKey(externalDataProviderId, externalCertificateId).toName() + populationDataProviderResource = resource { + name = externalDataProviderResourceName + dataProvider = + ResourceKt.dataProvider { certificate = externalDataProviderCertificateKeyName } + } + + val internalModelProvider = resourceSetup.createInternalModelProvider() + _modelProviderResourceName = + ModelProviderKey(externalIdToApiId(internalModelProvider.externalModelProviderId)).toName() + + val population = resourceSetup.createInternalPopulation(internalDataProvider) + populationKey = PopulationKey.fromName(population.toPopulation().name)!! + populationInfo = + PopulationInfo( + SyntheticGenerationSpecs.SYNTHETIC_POPULATION_SPEC_LARGE.toPopulationSpec(), + TestEvent.getDescriptor(), + ) + + _typeRegistry = + TypeRegistry.newBuilder().add(listOf(Person.getDescriptor(), Dummy.getDescriptor())).build() + } + fun getMeasurementConsumerData(): MeasurementConsumerData { return MeasurementConsumerData( mcResourceName, @@ -187,6 +266,14 @@ class InProcessCmmsComponents( ) } + fun getPopulationData(): PopulationData { + return PopulationData( + populationDataProviderName = populationDataProviderResource.name, + populationInfo = populationInfo, + populationKey = populationKey, + ) + } + fun getDataProviderResourceNames(): List { return edpDisplayNameToResourceMap.values.map { it.name } } @@ -204,6 +291,8 @@ class InProcessCmmsComponents( } edpSimulators.forEach { it.start() } edpSimulators.forEach { it.waitUntilHealthy() } + + populationRequisitionFulfiller.start() } fun stopEdpSimulators() = runBlocking { edpSimulators.forEach { it.stop() } } @@ -215,6 +304,10 @@ class InProcessCmmsComponents( } } + fun stopPopulationRequisitionFulfillerDaemon() = runBlocking { + populationRequisitionFulfiller.stop() + } + override fun apply(statement: Statement, description: Description): Statement { return ruleChain.apply(statement, description) } @@ -224,7 +317,6 @@ class InProcessCmmsComponents( val MC_ENTITY_CONTENT: EntityContent = createEntityContent(MC_DISPLAY_NAME) val MC_ENCRYPTION_PRIVATE_KEY: TinkPrivateKeyHandle = loadEncryptionPrivateKey("${MC_DISPLAY_NAME}_enc_private.tink") - val TRUSTED_CERTIFICATES: Map = loadTestCertCollection("all_root_certs.pem").associateBy { checkNotNull(it.subjectKeyIdentifier) diff --git a/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessKingdom.kt b/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessKingdom.kt index fe10fc193b4..e5c56e28e77 100644 --- a/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessKingdom.kt +++ b/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessKingdom.kt @@ -38,6 +38,11 @@ import org.wfanet.measurement.internal.kingdom.ExchangesGrpcKt.ExchangesCoroutin import org.wfanet.measurement.internal.kingdom.MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineStub as InternalMeasurementConsumersCoroutineStub import org.wfanet.measurement.internal.kingdom.MeasurementLogEntriesGrpcKt.MeasurementLogEntriesCoroutineStub as InternalMeasurementLogEntriesCoroutineStub import org.wfanet.measurement.internal.kingdom.MeasurementsGrpcKt.MeasurementsCoroutineStub as InternalMeasurementsCoroutineStub +import org.wfanet.measurement.internal.kingdom.ModelLinesGrpcKt.ModelLinesCoroutineStub as InternalModelLinesCoroutineStub +import org.wfanet.measurement.internal.kingdom.ModelReleasesGrpcKt.ModelReleasesCoroutineStub as InternalModelReleasesCoroutineStub +import org.wfanet.measurement.internal.kingdom.ModelRolloutsGrpcKt.ModelRolloutsCoroutineStub as InternalModelRolloutsCoroutineStub +import org.wfanet.measurement.internal.kingdom.ModelSuitesGrpcKt.ModelSuitesCoroutineStub as InternalModelSuitesCoroutineStub +import org.wfanet.measurement.internal.kingdom.PopulationsGrpcKt.PopulationsCoroutineStub as InternalPopulationsCoroutineStub import org.wfanet.measurement.internal.kingdom.PublicKeysGrpcKt.PublicKeysCoroutineStub as InternalPublicKeysCoroutineStub import org.wfanet.measurement.internal.kingdom.RecurringExchangesGrpcKt.RecurringExchangesCoroutineStub as InternalRecurringExchangesCoroutineStub import org.wfanet.measurement.internal.kingdom.RequisitionsGrpcKt.RequisitionsCoroutineStub as InternalRequisitionsCoroutineStub @@ -54,6 +59,11 @@ import org.wfanet.measurement.kingdom.service.api.v2alpha.ExchangeStepsService import org.wfanet.measurement.kingdom.service.api.v2alpha.ExchangesService import org.wfanet.measurement.kingdom.service.api.v2alpha.MeasurementConsumersService import org.wfanet.measurement.kingdom.service.api.v2alpha.MeasurementsService +import org.wfanet.measurement.kingdom.service.api.v2alpha.ModelLinesService +import org.wfanet.measurement.kingdom.service.api.v2alpha.ModelReleasesService +import org.wfanet.measurement.kingdom.service.api.v2alpha.ModelRolloutsService +import org.wfanet.measurement.kingdom.service.api.v2alpha.ModelSuitesService +import org.wfanet.measurement.kingdom.service.api.v2alpha.PopulationsService import org.wfanet.measurement.kingdom.service.api.v2alpha.PublicKeysService import org.wfanet.measurement.kingdom.service.api.v2alpha.RequisitionsService import org.wfanet.measurement.kingdom.service.api.v2alpha.withAccountAuthenticationServerInterceptor @@ -112,6 +122,7 @@ class InProcessKingdom( private val internalRecurringExchangesClient by lazy { InternalRecurringExchangesCoroutineStub(internalApiChannel) } + private val internalDataServer = GrpcTestServerRule( logAllRequests = verboseGrpcLogging, @@ -165,6 +176,21 @@ class InProcessKingdom( RequisitionsService(internalRequisitionsClient) .withMetadataPrincipalIdentities() .withApiKeyAuthenticationServerInterceptor(internalApiKeysClient), + ModelRolloutsService(internalModelRolloutsClient) + .withMetadataPrincipalIdentities() + .withApiKeyAuthenticationServerInterceptor(internalApiKeysClient), + ModelReleasesService(internalModelReleasesClient) + .withMetadataPrincipalIdentities() + .withApiKeyAuthenticationServerInterceptor(internalApiKeysClient), + ModelSuitesService(internalModelSuitesClient) + .withMetadataPrincipalIdentities() + .withApiKeyAuthenticationServerInterceptor(internalApiKeysClient), + ModelLinesService(internalModelLinesClient) + .withMetadataPrincipalIdentities() + .withApiKeyAuthenticationServerInterceptor(internalApiKeysClient), + PopulationsService(internalPopulationsClient) + .withMetadataPrincipalIdentities() + .withApiKeyAuthenticationServerInterceptor(internalApiKeysClient), AccountsService(internalAccountsClient, redirectUri) .withAccountAuthenticationServerInterceptor(internalAccountsClient, redirectUri), MeasurementConsumersService(internalMeasurementConsumersClient) @@ -203,6 +229,36 @@ class InProcessKingdom( /** Provides access to Duchy Certificate creation without having multiple Duchy clients. */ val internalCertificatesClient by lazy { InternalCertificatesCoroutineStub(internalApiChannel) } + /** Provides access to ModelProvider creation. */ + val internalModelProvidersClient by lazy { + org.wfanet.measurement.internal.kingdom.ModelProvidersGrpcKt.ModelProvidersCoroutineStub( + internalApiChannel + ) + } + + /** Provides access to ModelRollout creation. */ + private val internalModelRolloutsClient by lazy { + InternalModelRolloutsCoroutineStub(internalApiChannel) + } + + /** Provides access to ModelRelease creation. */ + private val internalModelReleasesClient by lazy { + InternalModelReleasesCoroutineStub(internalApiChannel) + } + + /** Provides access to ModelSuite creation. */ + private val internalModelSuitesClient by lazy { + InternalModelSuitesCoroutineStub(internalApiChannel) + } + + /** Provides access to ModelLine creation. */ + private val internalModelLinesClient by lazy { + InternalModelLinesCoroutineStub(internalApiChannel) + } + + /** Provides access to Population creation. */ + val internalPopulationsClient by lazy { InternalPopulationsCoroutineStub(internalApiChannel) } + override fun apply(statement: Statement, description: Description): Statement { return chainRulesSequentially(internalDataServer, systemApiServer, publicApiServer) .apply(statement, description) diff --git a/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessLifeOfAMeasurementIntegrationTest.kt b/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessLifeOfAMeasurementIntegrationTest.kt index c8867408d4d..bf08b9e2d59 100644 --- a/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessLifeOfAMeasurementIntegrationTest.kt +++ b/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessLifeOfAMeasurementIntegrationTest.kt @@ -14,6 +14,8 @@ package org.wfanet.measurement.integration.common +import com.google.type.date +import java.time.Instant import kotlinx.coroutines.runBlocking import org.junit.After import org.junit.Before @@ -26,10 +28,27 @@ import org.wfanet.measurement.api.v2alpha.DataProvidersGrpcKt.DataProvidersCorou import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub import org.wfanet.measurement.api.v2alpha.MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineStub import org.wfanet.measurement.api.v2alpha.MeasurementsGrpcKt.MeasurementsCoroutineStub +import org.wfanet.measurement.api.v2alpha.ModelLine +import org.wfanet.measurement.api.v2alpha.ModelLinesGrpcKt.ModelLinesCoroutineStub +import org.wfanet.measurement.api.v2alpha.ModelReleasesGrpcKt.ModelReleasesCoroutineStub +import org.wfanet.measurement.api.v2alpha.ModelRolloutsGrpcKt.ModelRolloutsCoroutineStub +import org.wfanet.measurement.api.v2alpha.ModelSuitesGrpcKt.ModelSuitesCoroutineStub import org.wfanet.measurement.api.v2alpha.ProtocolConfig.NoiseMechanism import org.wfanet.measurement.api.v2alpha.RequisitionsGrpcKt.RequisitionsCoroutineStub +import org.wfanet.measurement.api.v2alpha.createModelLineRequest +import org.wfanet.measurement.api.v2alpha.createModelReleaseRequest +import org.wfanet.measurement.api.v2alpha.createModelRolloutRequest +import org.wfanet.measurement.api.v2alpha.createModelSuiteRequest +import org.wfanet.measurement.api.v2alpha.dateInterval import org.wfanet.measurement.api.v2alpha.differentialPrivacyParams +import org.wfanet.measurement.api.v2alpha.event_templates.testing.Person +import org.wfanet.measurement.api.v2alpha.modelLine +import org.wfanet.measurement.api.v2alpha.modelRelease +import org.wfanet.measurement.api.v2alpha.modelRollout +import org.wfanet.measurement.api.v2alpha.modelSuite +import org.wfanet.measurement.common.identity.withPrincipalName import org.wfanet.measurement.common.testing.ProviderRule +import org.wfanet.measurement.common.toProtoTime import org.wfanet.measurement.kingdom.deploy.common.service.DataServices import org.wfanet.measurement.loadtest.measurementconsumer.MeasurementConsumerData import org.wfanet.measurement.loadtest.measurementconsumer.MeasurementConsumerSimulator @@ -73,6 +92,26 @@ abstract class InProcessLifeOfAMeasurementIntegrationTest( RequisitionsCoroutineStub(inProcessCmmsComponents.kingdom.publicApiChannel) } + private val publicModelSuitesClient by lazy { + ModelSuitesCoroutineStub(inProcessCmmsComponents.kingdom.publicApiChannel) + .withPrincipalName(inProcessCmmsComponents.modelProviderResourceName) + } + + private val publicModelLinesClient by lazy { + ModelLinesCoroutineStub(inProcessCmmsComponents.kingdom.publicApiChannel) + .withPrincipalName(inProcessCmmsComponents.modelProviderResourceName) + } + + private val publicModelReleasesClient by lazy { + ModelReleasesCoroutineStub(inProcessCmmsComponents.kingdom.publicApiChannel) + .withPrincipalName(inProcessCmmsComponents.modelProviderResourceName) + } + + private val publicModelRolloutClient by lazy { + ModelRolloutsCoroutineStub(inProcessCmmsComponents.kingdom.publicApiChannel) + .withPrincipalName(inProcessCmmsComponents.modelProviderResourceName) + } + @Before fun startDaemons() { inProcessCmmsComponents.startDaemons() @@ -116,6 +155,11 @@ abstract class InProcessLifeOfAMeasurementIntegrationTest( inProcessCmmsComponents.stopDuchyDaemons() } + @After + fun stopPopulationRequisitionFulfillerDaemon() { + inProcessCmmsComponents.stopPopulationRequisitionFulfillerDaemon() + } + @Test fun `create a Llv2 RF measurement and check the result is equal to the expected result`() = runBlocking { @@ -209,6 +253,59 @@ abstract class InProcessLifeOfAMeasurementIntegrationTest( // TODO(@renjiez): Add Multi-round test given the same input to verify correctness. + @Test + fun `create a population measurement`() = runBlocking { + val modelSuite = + publicModelSuitesClient.createModelSuite( + createModelSuiteRequest { + parent = inProcessCmmsComponents.modelProviderResourceName + modelSuite = modelSuite { displayName = MODEL_SUITE_DISPLAY_NAME } + } + ) + val modelLine = + publicModelLinesClient.createModelLine( + createModelLineRequest { + parent = modelSuite.name + modelLine = modelLine { + activeStartTime = MODEL_LINE_ACTIVE_START_TIME + type = ModelLine.Type.PROD + } + } + ) + + val populationData = inProcessCmmsComponents.getPopulationData() + + val modelRelease = + publicModelReleasesClient.createModelRelease( + createModelReleaseRequest { + parent = modelSuite.name + modelRelease = modelRelease { population = populationData.populationKey.toName() } + } + ) + + publicModelRolloutClient.createModelRollout( + createModelRolloutRequest { + parent = modelLine.name + modelRollout = modelRollout { + this.modelRelease = modelRelease.name + gradualRolloutPeriod = dateInterval { + startDate = ROLLOUT_PERIOD_START_DATE + endDate = ROLLOUT_PERIOD_END_DATE + } + } + } + ) + + // Use frontend simulator to create a population measurement + mcSimulator.testPopulation( + "1234", + populationData, + modelLine.name, + DEFAULT_POPULATION_FILTER_EXPRESSION, + inProcessCmmsComponents.typeRegistry, + ) + } + companion object { // Epsilon can vary from 0.0001 to 1.0, delta = 1e-15 is a realistic value. // Set epsilon higher without exceeding privacy budget so the noise is smaller in the @@ -218,6 +315,25 @@ abstract class InProcessLifeOfAMeasurementIntegrationTest( delta = 1e-15 } + private const val MODEL_SUITE_DISPLAY_NAME = "ModelSuite1" + + private val MODEL_LINE_ACTIVE_START_TIME = Instant.now().plusSeconds(2000L).toProtoTime() + + private const val DEFAULT_POPULATION_FILTER_EXPRESSION = + "person.age_group == ${Person.AgeGroup.YEARS_18_TO_34_VALUE}" + + private val ROLLOUT_PERIOD_START_DATE = date { + year = 2025 + month = 1 + day = 2 + } + + private val ROLLOUT_PERIOD_END_DATE = date { + year = 2025 + month = 1 + day = 3 + } + @BeforeClass @JvmStatic fun initConfig() { diff --git a/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessMeasurementSystemProberIntegrationTest.kt b/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessMeasurementSystemProberIntegrationTest.kt index 0a015ca80d3..d9a14f29152 100644 --- a/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessMeasurementSystemProberIntegrationTest.kt +++ b/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessMeasurementSystemProberIntegrationTest.kt @@ -22,6 +22,9 @@ import java.io.File import java.nio.file.Paths import java.time.Clock import java.time.Duration +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.toList import kotlinx.coroutines.runBlocking import org.junit.After import org.junit.Before @@ -30,13 +33,15 @@ import org.junit.Rule import org.junit.Test import org.wfanet.measurement.api.v2alpha.DataProvidersGrpcKt.DataProvidersCoroutineStub import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub -import org.wfanet.measurement.api.v2alpha.ListMeasurementsResponse import org.wfanet.measurement.api.v2alpha.Measurement import org.wfanet.measurement.api.v2alpha.MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineStub import org.wfanet.measurement.api.v2alpha.MeasurementsGrpcKt.MeasurementsCoroutineStub import org.wfanet.measurement.api.v2alpha.RequisitionsGrpcKt.RequisitionsCoroutineStub import org.wfanet.measurement.api.v2alpha.listMeasurementsRequest import org.wfanet.measurement.api.withAuthenticationKey +import org.wfanet.measurement.common.api.grpc.ResourceList +import org.wfanet.measurement.common.api.grpc.flattenConcat +import org.wfanet.measurement.common.api.grpc.listResources import org.wfanet.measurement.common.getRuntimePath import org.wfanet.measurement.common.identity.withPrincipalName import org.wfanet.measurement.common.testing.ProviderRule @@ -129,33 +134,32 @@ abstract class InProcessMeasurementSystemProberIntegrationTest( assertThat(measurements.size).isEqualTo(1) } + @OptIn(ExperimentalCoroutinesApi::class) // For `flattenConcat`. private suspend fun listMeasurements(): List { - var nextPageToken = "" val measurementConsumerData = inProcessCmmsComponents.getMeasurementConsumerData() - do { - val response: ListMeasurementsResponse = - try { - publicMeasurementsClient - .withAuthenticationKey(measurementConsumerData.apiAuthenticationKey) - .listMeasurements( - listMeasurementsRequest { - parent = measurementConsumerData.name - pageToken = nextPageToken - } - ) - } catch (e: StatusException) { - throw Exception( - "Unable to list measurements for measurement consumer ${measurementConsumerData.name}", - e, - ) + val measurementLists: Flow> = + publicMeasurementsClient + .withAuthenticationKey(measurementConsumerData.apiAuthenticationKey) + .listResources { pageToken -> + val response = + try { + listMeasurements( + listMeasurementsRequest { + parent = measurementConsumerData.name + this.pageToken = pageToken + } + ) + } catch (e: StatusException) { + throw Exception( + "Unable to list measurements for measurement consumer ${measurementConsumerData.name}", + e, + ) + } + ResourceList(response.measurementsList, response.nextPageToken) } - if (response.measurementsList.isNotEmpty()) { - return response.measurementsList - } - nextPageToken = response.nextPageToken - } while (nextPageToken.isNotEmpty()) - return emptyList() + + return measurementLists.flattenConcat().toList() } companion object { diff --git a/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessPopulationRequisitionFulfiller.kt b/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessPopulationRequisitionFulfiller.kt new file mode 100644 index 00000000000..ffc6b666623 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessPopulationRequisitionFulfiller.kt @@ -0,0 +1,115 @@ +// Copyright 2024 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.wfanet.measurement.integration.common + +import com.google.protobuf.ByteString +import com.google.protobuf.TypeRegistry +import io.grpc.Channel +import java.security.cert.X509Certificate +import java.time.Clock +import java.time.Duration +import java.util.logging.Level +import java.util.logging.Logger +import kotlin.coroutines.CoroutineContext +import kotlinx.coroutines.CoroutineExceptionHandler +import kotlinx.coroutines.CoroutineName +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job +import kotlinx.coroutines.cancelAndJoin +import kotlinx.coroutines.launch +import org.junit.rules.TestRule +import org.junit.runner.Description +import org.junit.runners.model.Statement +import org.wfanet.measurement.api.v2alpha.CertificatesGrpcKt +import org.wfanet.measurement.api.v2alpha.ModelReleasesGrpcKt +import org.wfanet.measurement.api.v2alpha.ModelRolloutsGrpcKt +import org.wfanet.measurement.api.v2alpha.PopulationKey +import org.wfanet.measurement.api.v2alpha.RequisitionsGrpcKt +import org.wfanet.measurement.common.identity.withPrincipalName +import org.wfanet.measurement.common.throttler.MinimumIntervalThrottler +import org.wfanet.measurement.dataprovider.DataProviderData +import org.wfanet.measurement.populationdataprovider.PopulationInfo +import org.wfanet.measurement.populationdataprovider.PopulationRequisitionFulfiller + +class InProcessPopulationRequisitionFulfiller( + val pdpData: DataProviderData, + val resourceName: String, + private val populationInfoMap: Map, + val typeRegistry: TypeRegistry, + private val kingdomPublicApiChannel: Channel, + private val trustedCertificates: Map, + val verboseGrpcLogging: Boolean = true, + daemonContext: CoroutineContext = Dispatchers.Default, +) : TestRule { + private val daemonScope = CoroutineScope(daemonContext) + private lateinit var populationRequisitionFulfillerJob: Job + + private val modelRolloutsClient by lazy { + ModelRolloutsGrpcKt.ModelRolloutsCoroutineStub(kingdomPublicApiChannel) + .withPrincipalName(resourceName) + } + + private val modelReleasesClient by lazy { + ModelReleasesGrpcKt.ModelReleasesCoroutineStub(kingdomPublicApiChannel) + .withPrincipalName(resourceName) + } + + private val certificatesClient by lazy { + CertificatesGrpcKt.CertificatesCoroutineStub(kingdomPublicApiChannel) + .withPrincipalName(resourceName) + } + + private val requisitionsClient by lazy { + RequisitionsGrpcKt.RequisitionsCoroutineStub(kingdomPublicApiChannel) + .withPrincipalName(resourceName) + } + + fun start() { + populationRequisitionFulfillerJob = + daemonScope.launch( + CoroutineName("Population Requisition Fulfiller Daemon") + + CoroutineExceptionHandler { _, e -> + logger.log(Level.SEVERE, e) { "Error in Population Requisition Fulfiller Daemon" } + } + ) { + val populationRequisitionFulfiller = + PopulationRequisitionFulfiller( + pdpData, + certificatesClient, + requisitionsClient, + MinimumIntervalThrottler(Clock.systemUTC(), Duration.ofMillis(1000)), + trustedCertificates, + modelRolloutsClient, + modelReleasesClient, + populationInfoMap, + typeRegistry, + ) + populationRequisitionFulfiller.run() + } + } + + suspend fun stop() { + populationRequisitionFulfillerJob.cancelAndJoin() + } + + override fun apply(statement: Statement, description: Description): Statement { + return statement + } + + companion object { + private val logger: Logger = Logger.getLogger(this::class.java.name) + } +} diff --git a/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/BUILD.bazel deleted file mode 100644 index 5be39730803..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/BUILD.bazel +++ /dev/null @@ -1,63 +0,0 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") - -package( - default_testonly = True, - default_visibility = [ - "//src/test/kotlin/org/wfanet/measurement/integration:__subpackages__", - ], -) - -kt_jvm_library( - name = "in_process_life_of_a_report_integration_test", - srcs = [ - "InProcessLifeOfAReportIntegrationTest.kt", - ], - data = [ - "//src/main/k8s/testing/secretfiles:secret_files", - ], - deps = [ - ":in_process_reporting_server", - "//src/main/kotlin/org/wfanet/measurement/integration/common/reporting/identity:reporting_principal_identity", - "//src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/testing:fake_measurements_service", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server:reporting_data_server", - "//src/main/proto/wfa/measurement/config/reporting:encryption_key_pair_config_kt_jvm_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:event_groups_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:reporting_sets_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:reports_service_kt_jvm_grpc_proto", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/identity", - "@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/duchy", - ], -) - -kt_jvm_library( - name = "in_process_reporting_server", - srcs = [ - "InProcessReportingServer.kt", - ], - data = [ - "//src/main/k8s/testing/secretfiles:secret_files", - ], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/integration/common/reporting/identity:metadata_principal_interceptor", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server:reporting_data_server", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api:cel_env_provider", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:event_groups_service", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:reporting_sets_service", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:reports_service", - "//src/main/proto/wfa/measurement/api/v2alpha:certificates_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:data_providers_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:event_group_metadata_descriptors_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:event_groups_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:measurement_consumers_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:measurements_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/config/reporting:encryption_key_pair_config_kt_jvm_proto", - "//src/main/proto/wfa/measurement/config/reporting:measurement_spec_config_kt_jvm_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:event_groups_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:reporting_sets_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:reports_service_kt_jvm_grpc_proto", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto/testing", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc/testing", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/testing", - ], -) diff --git a/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/InProcessLifeOfAReportIntegrationTest.kt b/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/InProcessLifeOfAReportIntegrationTest.kt deleted file mode 100644 index 936f8dfad6f..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/InProcessLifeOfAReportIntegrationTest.kt +++ /dev/null @@ -1,434 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.integration.common.reporting - -import com.google.common.truth.Truth.assertThat -import com.google.common.truth.extensions.proto.ProtoTruth.assertThat -import com.google.protobuf.ByteString -import com.google.protobuf.DescriptorProtos -import com.google.protobuf.duration -import com.google.protobuf.kotlin.toByteString -import com.google.protobuf.timestamp -import java.io.File -import java.nio.file.Paths -import java.time.Clock -import kotlinx.coroutines.launch -import kotlinx.coroutines.runBlocking -import org.junit.Rule -import org.junit.Test -import org.junit.rules.TestRule -import org.mockito.kotlin.any -import org.wfanet.measurement.api.v2alpha.Certificate -import org.wfanet.measurement.api.v2alpha.CertificatesGrpcKt -import org.wfanet.measurement.api.v2alpha.DataProvidersGrpcKt -import org.wfanet.measurement.api.v2alpha.EncryptionPublicKey -import org.wfanet.measurement.api.v2alpha.EventGroup -import org.wfanet.measurement.api.v2alpha.EventGroupKt -import org.wfanet.measurement.api.v2alpha.EventGroupMetadataDescriptor -import org.wfanet.measurement.api.v2alpha.EventGroupMetadataDescriptorsGrpcKt -import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt -import org.wfanet.measurement.api.v2alpha.MeasurementConsumer -import org.wfanet.measurement.api.v2alpha.MeasurementConsumersGrpcKt -import org.wfanet.measurement.api.v2alpha.MeasurementsGrpcKt -import org.wfanet.measurement.api.v2alpha.certificate -import org.wfanet.measurement.api.v2alpha.dataProvider -import org.wfanet.measurement.api.v2alpha.eventGroup -import org.wfanet.measurement.api.v2alpha.eventGroupMetadataDescriptor -import org.wfanet.measurement.api.v2alpha.listEventGroupMetadataDescriptorsResponse -import org.wfanet.measurement.api.v2alpha.listEventGroupsResponse -import org.wfanet.measurement.api.v2alpha.measurementConsumer -import org.wfanet.measurement.common.crypto.SigningKeyHandle -import org.wfanet.measurement.common.crypto.readCertificate -import org.wfanet.measurement.common.crypto.readCertificateCollection -import org.wfanet.measurement.common.crypto.subjectKeyIdentifier -import org.wfanet.measurement.common.crypto.testing.loadSigningKey -import org.wfanet.measurement.common.crypto.tink.loadPublicKey -import org.wfanet.measurement.common.getRuntimePath -import org.wfanet.measurement.common.grpc.testing.GrpcTestServerRule -import org.wfanet.measurement.common.grpc.testing.mockService -import org.wfanet.measurement.common.identity.RandomIdGenerator -import org.wfanet.measurement.common.pack -import org.wfanet.measurement.common.readByteString -import org.wfanet.measurement.common.testing.chainRulesSequentially -import org.wfanet.measurement.config.reporting.EncryptionKeyPairConfig -import org.wfanet.measurement.config.reporting.EncryptionKeyPairConfigKt.keyPair -import org.wfanet.measurement.config.reporting.EncryptionKeyPairConfigKt.principalKeyPairs -import org.wfanet.measurement.config.reporting.encryptionKeyPairConfig -import org.wfanet.measurement.config.reporting.measurementConsumerConfig -import org.wfanet.measurement.consent.client.common.encryptMessage -import org.wfanet.measurement.consent.client.common.toEncryptionPublicKey -import org.wfanet.measurement.consent.client.common.toPublicKeyHandle -import org.wfanet.measurement.consent.client.measurementconsumer.signEncryptionPublicKey -import org.wfanet.measurement.integration.common.reporting.identity.withPrincipalName -import org.wfanet.measurement.kingdom.service.api.v2alpha.testing.FakeMeasurementsService -import org.wfanet.measurement.reporting.deploy.common.server.ReportingDataServer -import org.wfanet.measurement.reporting.v1alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub -import org.wfanet.measurement.reporting.v1alpha.ListEventGroupsResponse -import org.wfanet.measurement.reporting.v1alpha.ListReportingSetsResponse -import org.wfanet.measurement.reporting.v1alpha.ListReportsResponse -import org.wfanet.measurement.reporting.v1alpha.Metric -import org.wfanet.measurement.reporting.v1alpha.MetricKt -import org.wfanet.measurement.reporting.v1alpha.Report -import org.wfanet.measurement.reporting.v1alpha.ReportKt -import org.wfanet.measurement.reporting.v1alpha.ReportingSet -import org.wfanet.measurement.reporting.v1alpha.ReportingSetsGrpcKt.ReportingSetsCoroutineStub -import org.wfanet.measurement.reporting.v1alpha.ReportsGrpcKt.ReportsCoroutineStub -import org.wfanet.measurement.reporting.v1alpha.copy -import org.wfanet.measurement.reporting.v1alpha.createReportRequest -import org.wfanet.measurement.reporting.v1alpha.createReportingSetRequest -import org.wfanet.measurement.reporting.v1alpha.getReportRequest -import org.wfanet.measurement.reporting.v1alpha.listEventGroupsRequest -import org.wfanet.measurement.reporting.v1alpha.listReportingSetsRequest -import org.wfanet.measurement.reporting.v1alpha.listReportsRequest -import org.wfanet.measurement.reporting.v1alpha.metric -import org.wfanet.measurement.reporting.v1alpha.periodicTimeInterval -import org.wfanet.measurement.reporting.v1alpha.report -import org.wfanet.measurement.reporting.v1alpha.reportingSet - -private const val NUM_SET_OPERATIONS = 50 - -/** - * Test that everything is wired up properly. - * - * This is abstract so that different implementations of dependencies can all run the same tests - * easily. - */ -abstract class InProcessLifeOfAReportIntegrationTest { - abstract val reportingServerDataServices: ReportingDataServer.Services - - private val publicKingdomCertificatesMock: CertificatesGrpcKt.CertificatesCoroutineImplBase = - mockService { - onBlocking { getCertificate(any()) }.thenReturn(CERTIFICATE) - } - private val publicKingdomDataProvidersMock: DataProvidersGrpcKt.DataProvidersCoroutineImplBase = - mockService { - onBlocking { getDataProvider(any()) }.thenReturn(DATA_PROVIDER) - } - private val publicKingdomEventGroupsMock: EventGroupsGrpcKt.EventGroupsCoroutineImplBase = - mockService { - onBlocking { listEventGroups(any()) } - .thenReturn(listEventGroupsResponse { eventGroups += EVENT_GROUP }) - } - private val publicKingdomEventGroupMetadataDescriptorsMock: - EventGroupMetadataDescriptorsGrpcKt.EventGroupMetadataDescriptorsCoroutineImplBase = - mockService { - onBlocking { listEventGroupMetadataDescriptors(any()) } - .thenReturn( - listEventGroupMetadataDescriptorsResponse { - eventGroupMetadataDescriptors += EVENT_GROUP_METADATA_DESCRIPTOR - } - ) - } - private val publicKingdomMeasurementConsumersMock: - MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineImplBase = - mockService { - onBlocking { getMeasurementConsumer(any()) }.thenReturn(MEASUREMENT_CONSUMER) - } - private val publicFakeKingdomMeasurementsService: - MeasurementsGrpcKt.MeasurementsCoroutineImplBase = - FakeMeasurementsService( - RandomIdGenerator(Clock.systemUTC()), - EDP_SIGNING_KEY_HANDLE, - DATA_PROVIDER_CERTIFICATE_NAME, - ) - - private val publicKingdomServer = GrpcTestServerRule { - addService(publicKingdomCertificatesMock) - addService(publicKingdomDataProvidersMock) - addService(publicKingdomEventGroupsMock) - addService(publicKingdomEventGroupMetadataDescriptorsMock) - addService(publicKingdomMeasurementConsumersMock) - addService(publicFakeKingdomMeasurementsService) - } - - private val encryptionKeyPairConfig: EncryptionKeyPairConfig = encryptionKeyPairConfig { - principalKeyPairs += principalKeyPairs { - principal = MEASUREMENT_CONSUMER_NAME - keyPairs += keyPair { - publicKeyFile = "mc_enc_public.tink" - privateKeyFile = "mc_enc_private.tink" - } - } - } - - private val measurementConsumerConfig = measurementConsumerConfig { - apiKey = API_KEY - signingCertificateName = MEASUREMENT_CONSUMER_CERTIFICATE_NAME - signingPrivateKeyPath = MC_SIGNING_PRIVATE_KEY_PATH - } - - private val reportingServer: InProcessReportingServer by lazy { - InProcessReportingServer( - reportingServerDataServices, - publicKingdomServer.channel, - encryptionKeyPairConfig, - SECRETS_DIR, - measurementConsumerConfig, - TRUSTED_CERTIFICATES, - verboseGrpcLogging = false, - ) - } - - @get:Rule - val ruleChain: TestRule by lazy { chainRulesSequentially(publicKingdomServer, reportingServer) } - - private val publicEventGroupsClient by lazy { - EventGroupsCoroutineStub(reportingServer.publicApiChannel) - } - private val publicReportingSetsClient by lazy { - ReportingSetsCoroutineStub(reportingServer.publicApiChannel) - } - private val publicReportsClient by lazy { ReportsCoroutineStub(reportingServer.publicApiChannel) } - - @Test - fun `create Report and get the result successfully`() = runBlocking { - createReportingSet("1", MEASUREMENT_CONSUMER_NAME) - createReportingSet("2", MEASUREMENT_CONSUMER_NAME) - createReportingSet("3", MEASUREMENT_CONSUMER_NAME) - - val createdReport = createReport("1234", MEASUREMENT_CONSUMER_NAME) - val reports = listReports(MEASUREMENT_CONSUMER_NAME) - assertThat(reports.reportsList).hasSize(1) - val completedReport = getReport(createdReport.name, createdReport.measurementConsumer) - assertThat(assertThat(completedReport.state).isEqualTo(Report.State.SUCCEEDED)) - val reportResult = computeReportResult(completedReport) - // each measurement has a result of 100.0 and there are two time intervals - assertThat(reportResult).isEqualTo(200.0 * NUM_SET_OPERATIONS) - } - - @Test - fun `create multiple Reports concurrently successfully`() = runBlocking { - createReportingSet("1", MEASUREMENT_CONSUMER_NAME) - createReportingSet("2", MEASUREMENT_CONSUMER_NAME) - createReportingSet("3", MEASUREMENT_CONSUMER_NAME) - launch { createReport("5", MEASUREMENT_CONSUMER_NAME, true) } - for (i in 1..4) { - launch { createReport("$i", MEASUREMENT_CONSUMER_NAME) } - } - } - - private fun computeReportResult(completedReport: Report): Double { - var sum = 0.0 - completedReport.result.scalarTable.columnsList.forEach { column -> - column.setOperationsList.forEach { sum += it } - } - return sum - } - - private suspend fun listEventGroups(measurementConsumerName: String): ListEventGroupsResponse { - return publicEventGroupsClient - .withPrincipalName(measurementConsumerName) - .listEventGroups( - listEventGroupsRequest { parent = "$measurementConsumerName/dataProviders/-" } - ) - } - - private suspend fun createReportingSet( - runId: String, - measurementConsumerName: String, - ): ReportingSet { - val eventGroupsList = listEventGroups(measurementConsumerName).eventGroupsList - return publicReportingSetsClient - .withPrincipalName(measurementConsumerName) - .createReportingSet( - createReportingSetRequest { - parent = measurementConsumerName - reportingSet = reportingSet { - displayName = "reporting-set-$runId" - eventGroups += eventGroupsList.map { it.name } - } - } - ) - } - - private suspend fun listReportingSets( - measurementConsumerName: String - ): ListReportingSetsResponse { - return publicReportingSetsClient - .withPrincipalName(measurementConsumerName) - .listReportingSets(listReportingSetsRequest { parent = measurementConsumerName }) - } - - private suspend fun createReport( - runId: String, - measurementConsumerName: String, - cumulative: Boolean = false, - ): Report { - val eventGroupsList = listEventGroups(measurementConsumerName).eventGroupsList - val reportingSets = listReportingSets(measurementConsumerName).reportingSetsList - assertThat(reportingSets.size).isAtLeast(3) - val createReportRequest = createReportRequest { - parent = measurementConsumerName - report = report { - measurementConsumer = measurementConsumerName - reportIdempotencyKey = runId - eventGroupUniverse = - ReportKt.eventGroupUniverse { - eventGroupsList.forEach { - eventGroupEntries += ReportKt.EventGroupUniverseKt.eventGroupEntry { key = it.name } - } - } - periodicTimeInterval = periodicTimeInterval { - startTime = timestamp { seconds = 100 } - increment = duration { seconds = 5 } - intervalCount = 2 - } - metrics += metric { - this.cumulative = cumulative - impressionCount = MetricKt.impressionCountParams { maximumFrequencyPerUser = 5 } - val setOperation = - MetricKt.namedSetOperation { - uniqueName = "set-operation" - setOperation = - MetricKt.setOperation { - type = Metric.SetOperation.Type.UNION - lhs = - MetricKt.SetOperationKt.operand { - operation = - MetricKt.setOperation { - type = Metric.SetOperation.Type.UNION - lhs = - MetricKt.SetOperationKt.operand { reportingSet = reportingSets[0].name } - rhs = - MetricKt.SetOperationKt.operand { reportingSet = reportingSets[1].name } - } - } - rhs = MetricKt.SetOperationKt.operand { reportingSet = reportingSets[2].name } - } - } - - for (i in 1..NUM_SET_OPERATIONS) { - setOperations += setOperation.copy { uniqueName = "$uniqueName-$i" } - } - } - } - } - - val report = - publicReportsClient - .withPrincipalName(measurementConsumerName) - .createReport(createReportRequest) - - // Verify concurrent operations process the metrics without skipping set operations. - assertThat(report.metricsList) - .ignoringRepeatedFieldOrder() - .containsExactlyElementsIn(createReportRequest.report.metricsList) - return report - } - - private suspend fun getReport(reportName: String, principalName: String): Report { - return publicReportsClient - .withPrincipalName(principalName) - .getReport(getReportRequest { name = reportName }) - } - - private suspend fun listReports(measurementConsumerName: String): ListReportsResponse { - return publicReportsClient - .withPrincipalName(measurementConsumerName) - .listReports(listReportsRequest { parent = measurementConsumerName }) - } - - companion object { - private val SECRETS_DIR: File = - getRuntimePath( - Paths.get("wfa_measurement_system", "src", "main", "k8s", "testing", "secretfiles") - )!! - .toFile() - - private val TRUSTED_CERTIFICATES = - readCertificateCollection(SECRETS_DIR.resolve("all_root_certs.pem")).associateBy { - it.subjectKeyIdentifier!! - } - - private val MC_CERTIFICATE_DER: ByteString = - SECRETS_DIR.resolve("mc_cs_cert.der").readByteString() - private val MC_SIGNING_KEY_HANDLE: SigningKeyHandle = - loadSigningKey( - SECRETS_DIR.resolve("mc_cs_cert.der"), - SECRETS_DIR.resolve("mc_cs_private.der"), - ) - private val MC_ENCRYPTION_PUBLIC_KEY: EncryptionPublicKey = - loadPublicKey(SECRETS_DIR.resolve("mc_enc_public.tink")).toEncryptionPublicKey() - private const val MC_SIGNING_PRIVATE_KEY_PATH = "mc_cs_private.der" - private const val API_KEY = "AAAAAAAAAHs" - const val MEASUREMENT_CONSUMER_NAME = "measurementConsumers/AAAAAAAAAHs" - private const val MEASUREMENT_CONSUMER_CERTIFICATE_NAME = - "$MEASUREMENT_CONSUMER_NAME/certificates/AAAAAAAAAHs" - - private const val DATA_PROVIDER_NAME = "dataProviders/AAAAAAAAAHs" - private const val DATA_PROVIDER_CERTIFICATE_NAME = - "$DATA_PROVIDER_NAME/certificates/AAAAAAAAAHs" - private val EDP_CERTIFICATE_DER: ByteString = - readCertificate(SECRETS_DIR.resolve("edp1_cs_cert.der").readByteString()) - .encoded - .toByteString() - private val EDP_SIGNING_KEY_HANDLE: SigningKeyHandle = - loadSigningKey( - SECRETS_DIR.resolve("edp1_cs_cert.der"), - SECRETS_DIR.resolve("edp1_cs_private.der"), - ) - - private val CERTIFICATE: Certificate = certificate { - name = DATA_PROVIDER_CERTIFICATE_NAME - x509Der = EDP_CERTIFICATE_DER - } - - private val DATA_PROVIDER = dataProvider { - name = DATA_PROVIDER_NAME - certificate = DATA_PROVIDER_CERTIFICATE_NAME - certificateDer = EDP_CERTIFICATE_DER - publicKey = - signEncryptionPublicKey( - loadPublicKey(SECRETS_DIR.resolve("edp1_enc_public.tink")).toEncryptionPublicKey(), - EDP_SIGNING_KEY_HANDLE, - ) - } - - private val MEASUREMENT_CONSUMER: MeasurementConsumer = measurementConsumer { - name = MEASUREMENT_CONSUMER_NAME - certificate = MEASUREMENT_CONSUMER_CERTIFICATE_NAME - certificateDer = MC_CERTIFICATE_DER - publicKey = signEncryptionPublicKey(MC_ENCRYPTION_PUBLIC_KEY, MC_SIGNING_KEY_HANDLE) - } - - private const val EVENT_GROUP_NAME = "$DATA_PROVIDER_NAME/eventGroups/AAAAAAAAAHs" - private val VID_MODEL_LINES = listOf("model1", "model2") - private val EVENT_TEMPLATE_TYPES = listOf("type1", "type2") - private val EVENT_TEMPLATES = - EVENT_TEMPLATE_TYPES.map { type -> EventGroupKt.eventTemplate { this.type = type } } - - private val EVENT_GROUP: EventGroup = eventGroup { - name = EVENT_GROUP_NAME - measurementConsumer = MEASUREMENT_CONSUMER_NAME - eventGroupReferenceId = "aaa" - measurementConsumerPublicKey = MEASUREMENT_CONSUMER.publicKey.message - vidModelLines.addAll(VID_MODEL_LINES) - eventTemplates.addAll(EVENT_TEMPLATES) - encryptedMetadata = - MC_ENCRYPTION_PUBLIC_KEY.toPublicKeyHandle() - .encryptMessage(EventGroup.Metadata.getDefaultInstance().pack()) - } - - private const val EVENT_GROUP_METADATA_DESCRIPTOR_NAME = - "$DATA_PROVIDER_NAME/eventGroupMetadataDescriptors/AAAAAAAAAHs" - private val FILE_DESCRIPTOR_SET = DescriptorProtos.FileDescriptorSet.getDefaultInstance() - - private val EVENT_GROUP_METADATA_DESCRIPTOR: EventGroupMetadataDescriptor = - eventGroupMetadataDescriptor { - name = EVENT_GROUP_METADATA_DESCRIPTOR_NAME - descriptorSet = FILE_DESCRIPTOR_SET - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/InProcessReportingServer.kt b/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/InProcessReportingServer.kt deleted file mode 100644 index 483a04abc22..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/InProcessReportingServer.kt +++ /dev/null @@ -1,280 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.integration.common.reporting - -import com.google.protobuf.ByteString -import io.grpc.Channel -import java.io.File -import java.security.SecureRandom -import java.security.cert.X509Certificate -import java.time.Duration -import java.util.logging.Logger -import kotlin.random.asKotlinRandom -import org.junit.rules.TestRule -import org.junit.runner.Description -import org.junit.runners.model.Statement -import org.wfanet.measurement.api.v2alpha.CertificatesGrpcKt.CertificatesCoroutineStub as PublicKingdomCertificatesCoroutineStub -import org.wfanet.measurement.api.v2alpha.DataProvidersGrpcKt.DataProvidersCoroutineStub as PublicKingdomDataProvidersCoroutineStub -import org.wfanet.measurement.api.v2alpha.EventGroupMetadataDescriptorsGrpcKt.EventGroupMetadataDescriptorsCoroutineStub as PublicKingdomEventGroupMetadataDescriptorsCoroutineStub -import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub as PublicKingdomEventGroupsCoroutineStub -import org.wfanet.measurement.api.v2alpha.MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineStub as PublicKingdomMeasurementConsumersCoroutineStub -import org.wfanet.measurement.api.v2alpha.MeasurementsGrpcKt.MeasurementsCoroutineStub as PublicKingdomMeasurementsCoroutineStub -import org.wfanet.measurement.api.withAuthenticationKey -import org.wfanet.measurement.common.crypto.tink.loadPrivateKey -import org.wfanet.measurement.common.grpc.testing.GrpcTestServerRule -import org.wfanet.measurement.common.grpc.withVerboseLogging -import org.wfanet.measurement.common.readByteString -import org.wfanet.measurement.common.testing.chainRulesSequentially -import org.wfanet.measurement.config.reporting.EncryptionKeyPairConfig -import org.wfanet.measurement.config.reporting.MeasurementConsumerConfig -import org.wfanet.measurement.config.reporting.MeasurementSpecConfigKt -import org.wfanet.measurement.config.reporting.measurementSpecConfig -import org.wfanet.measurement.integration.common.reporting.identity.withMetadataPrincipalIdentities -import org.wfanet.measurement.internal.reporting.MeasurementsGrpcKt.MeasurementsCoroutineStub as InternalMeasurementsCoroutineStub -import org.wfanet.measurement.internal.reporting.ReportingSetsGrpcKt.ReportingSetsCoroutineStub as InternalReportingSetsCoroutineStub -import org.wfanet.measurement.internal.reporting.ReportsGrpcKt.ReportsCoroutineStub as InternalReportsCoroutineStub -import org.wfanet.measurement.reporting.deploy.common.server.ReportingDataServer -import org.wfanet.measurement.reporting.deploy.common.server.ReportingDataServer.Companion.toList -import org.wfanet.measurement.reporting.service.api.CelEnvCacheProvider -import org.wfanet.measurement.reporting.service.api.InMemoryEncryptionKeyPairStore -import org.wfanet.measurement.reporting.service.api.v1alpha.EventGroupsService -import org.wfanet.measurement.reporting.service.api.v1alpha.ReportingSetsService -import org.wfanet.measurement.reporting.service.api.v1alpha.ReportsService -import org.wfanet.measurement.reporting.v1alpha.EventGroup - -/** TestRule that starts and stops all Reporting Server gRPC services. */ -class InProcessReportingServer( - private val reportingServerDataServices: ReportingDataServer.Services, - private val publicKingdomChannel: Channel, - private val encryptionKeyPairConfig: EncryptionKeyPairConfig, - private val signingPrivateKeyDir: File, - private val measurementConsumerConfig: MeasurementConsumerConfig, - trustedCertificates: Map, - private val verboseGrpcLogging: Boolean = true, -) : TestRule { - private val publicKingdomMeasurementConsumersClient by lazy { - PublicKingdomMeasurementConsumersCoroutineStub(publicKingdomChannel) - } - private val publicKingdomMeasurementsClient by lazy { - PublicKingdomMeasurementsCoroutineStub(publicKingdomChannel) - } - private val publicKingdomCertificatesClient by lazy { - PublicKingdomCertificatesCoroutineStub(publicKingdomChannel) - } - private val publicKingdomDataProvidersClient by lazy { - PublicKingdomDataProvidersCoroutineStub(publicKingdomChannel) - } - private val publicKingdomEventGroupsClient by lazy { - PublicKingdomEventGroupsCoroutineStub(publicKingdomChannel) - } - private val publicKingdomEventGroupMetadataDescriptorsClient by lazy { - PublicKingdomEventGroupMetadataDescriptorsCoroutineStub(publicKingdomChannel) - } - - private val internalApiChannel by lazy { internalDataServer.channel } - private val internalMeasurementsClient by lazy { - InternalMeasurementsCoroutineStub(internalApiChannel) - } - private val internalReportingSetsClient by lazy { - InternalReportingSetsCoroutineStub(internalApiChannel) - } - private val internalReportsClient by lazy { InternalReportsCoroutineStub(internalApiChannel) } - - private val internalDataServer = - GrpcTestServerRule(logAllRequests = verboseGrpcLogging) { - logger.info("Building Reporting Server's internal Data services") - reportingServerDataServices.toList().forEach { - addService(it.withVerboseLogging(verboseGrpcLogging)) - } - } - - private val publicApiServer = - GrpcTestServerRule(logAllRequests = verboseGrpcLogging) { - logger.info("Building Reporting Server's public API services") - - val encryptionKeyPairStore = - InMemoryEncryptionKeyPairStore( - encryptionKeyPairConfig.principalKeyPairsList.associateBy( - { it.principal }, - { - it.keyPairsList.map { keyPair -> - Pair( - signingPrivateKeyDir.resolve(keyPair.publicKeyFile).readByteString(), - loadPrivateKey(signingPrivateKeyDir.resolve(keyPair.privateKeyFile)), - ) - } - }, - ) - ) - - val celEnvCacheProvider = - CelEnvCacheProvider( - publicKingdomEventGroupMetadataDescriptorsClient.withAuthenticationKey( - measurementConsumerConfig.apiKey - ), - EventGroup.getDescriptor(), - Duration.ofSeconds(5), - knownMetadataTypes = emptyList(), - ) - - listOf( - EventGroupsService( - publicKingdomEventGroupsClient, - encryptionKeyPairStore, - celEnvCacheProvider, - ) - .withMetadataPrincipalIdentities(measurementConsumerConfig), - ReportingSetsService(internalReportingSetsClient) - .withMetadataPrincipalIdentities(measurementConsumerConfig), - ReportsService( - internalReportsClient, - internalReportingSetsClient, - internalMeasurementsClient, - publicKingdomDataProvidersClient, - publicKingdomMeasurementConsumersClient, - publicKingdomMeasurementsClient, - publicKingdomCertificatesClient, - encryptionKeyPairStore, - SecureRandom().asKotlinRandom(), - signingPrivateKeyDir, - trustedCertificates, - MEASUREMENT_SPEC_CONFIG, - ) - .withMetadataPrincipalIdentities(measurementConsumerConfig), - ) - .forEach { addService(it.withVerboseLogging(verboseGrpcLogging)) } - } - - /** Provides a gRPC channel to the Reporting Server's public API. */ - val publicApiChannel: Channel - get() = publicApiServer.channel - - override fun apply(statement: Statement, description: Description): Statement { - return chainRulesSequentially(internalDataServer, publicApiServer).apply(statement, description) - } - - companion object { - private val logger: Logger = Logger.getLogger(this::class.java.name) - - private val MEASUREMENT_SPEC_CONFIG = measurementSpecConfig { - reachSingleDataProvider = - MeasurementSpecConfigKt.reachSingleDataProvider { - privacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.000207 - delta = 1e-15 - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - fixedStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.fixedStart { - start = 0f - width = 1f - } - } - } - reach = - MeasurementSpecConfigKt.reach { - privacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.0007444 - delta = 1e-15 - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - randomStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.randomStart { - width = 256 - numVidBuckets = 300 - } - } - } - reachAndFrequencySingleDataProvider = - MeasurementSpecConfigKt.reachAndFrequencySingleDataProvider { - reachPrivacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.000207 - delta = 1e-15 - } - frequencyPrivacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.004728 - delta = 1e-15 - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - fixedStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.fixedStart { - start = 0f - width = 1f - } - } - } - reachAndFrequency = - MeasurementSpecConfigKt.reachAndFrequency { - reachPrivacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.0007444 - delta = 1e-15 - } - frequencyPrivacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.014638 - delta = 1e-15 - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - randomStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.randomStart { - width = 256 - numVidBuckets = 300 - } - } - } - impression = - MeasurementSpecConfigKt.impression { - privacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.003592 - delta = 1e-15 - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - fixedStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.fixedStart { - start = 0f - width = 1f - } - } - } - duration = - MeasurementSpecConfigKt.duration { - privacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.007418 - delta = 1e-15 - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - fixedStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.fixedStart { - start = 0f - width = 1f - } - } - } - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/identity/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/identity/BUILD.bazel deleted file mode 100644 index a1e471db718..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/identity/BUILD.bazel +++ /dev/null @@ -1,34 +0,0 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") - -package( - default_testonly = True, - default_visibility = [ - "//src/main/kotlin/org/wfanet/measurement/integration/common/reporting:__subpackages__", - ], -) - -kt_jvm_library( - name = "metadata_principal_interceptor", - srcs = ["MetadataPrincipalServerInterceptor.kt"], - deps = [ - ":reporting_principal_identity", - "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:context_keys", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:context_keys", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:principal_server_interceptor", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:reporting_principal", - "//src/main/proto/wfa/measurement/config:duchy_cert_config_kt_jvm_proto", - "@wfa_common_jvm//imports/java/io/grpc:api", - "@wfa_common_jvm//imports/java/io/grpc/stub", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", - ], -) - -kt_jvm_library( - name = "reporting_principal_identity", - srcs = ["ReportingPrincipalIdentity.kt"], - deps = [ - "@wfa_common_jvm//imports/java/io/grpc:api", - "@wfa_common_jvm//imports/java/io/grpc/stub", - ], -) diff --git a/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/identity/MetadataPrincipalServerInterceptor.kt b/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/identity/MetadataPrincipalServerInterceptor.kt deleted file mode 100644 index 629ddcbbe96..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/identity/MetadataPrincipalServerInterceptor.kt +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.integration.common.reporting.identity - -import io.grpc.BindableService -import io.grpc.Context -import io.grpc.Contexts -import io.grpc.Metadata -import io.grpc.ServerCall -import io.grpc.ServerCallHandler -import io.grpc.ServerInterceptor -import io.grpc.ServerInterceptors -import io.grpc.ServerServiceDefinition -import io.grpc.Status -import org.wfanet.measurement.config.reporting.MeasurementConsumerConfig -import org.wfanet.measurement.reporting.service.api.v1alpha.ContextKeys -import org.wfanet.measurement.reporting.service.api.v1alpha.ReportingPrincipal -import org.wfanet.measurement.reporting.service.api.v1alpha.principalFromCurrentContext -import org.wfanet.measurement.reporting.service.api.v1alpha.withPrincipal - -/** - * Extracts a name from the gRPC [Metadata] and creates a [ReportingPrincipal] to add to the gRPC - * [Context]. - * - * To install, wrap a service with: - * ``` - * yourService.withMetadataPrincipalIdentities() - * ``` - * - * The principal can be accessed within gRPC services via [principalFromCurrentContext]. - * - * This expects the Metadata to have a key "reporting_principal" associated with a value equal to - * the resource name of the principal. The recommended way to set this is to use [withPrincipalName] - * on a stub. - */ -class MetadataPrincipalServerInterceptor(private val config: MeasurementConsumerConfig) : - ServerInterceptor { - override fun interceptCall( - call: ServerCall, - headers: Metadata, - next: ServerCallHandler, - ): ServerCall.Listener { - if (ContextKeys.PRINCIPAL_CONTEXT_KEY.get() != null) { - return Contexts.interceptCall(Context.current(), call, headers, next) - } - - val principalName = headers[REPORTING_PRINCIPAL_NAME_METADATA_KEY] - if (principalName == null) { - call.close( - Status.UNAUTHENTICATED.withDescription("$REPORTING_PRINCIPAL_NAME_METADATA_KEY not found"), - Metadata(), - ) - return object : ServerCall.Listener() {} - } - val principal = ReportingPrincipal.fromConfigs(principalName, config) - if (principal == null) { - call.close(Status.UNAUTHENTICATED.withDescription("No valid Principal found"), Metadata()) - return object : ServerCall.Listener() {} - } - val context = Context.current().withPrincipal(principal) - return Contexts.interceptCall(context, call, headers, next) - } -} - -/** Installs [MetadataPrincipalServerInterceptor] on the service. */ -fun BindableService.withMetadataPrincipalIdentities( - config: MeasurementConsumerConfig -): ServerServiceDefinition = - ServerInterceptors.interceptForward(this, MetadataPrincipalServerInterceptor(config)) diff --git a/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/identity/ReportingPrincipalIdentity.kt b/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/identity/ReportingPrincipalIdentity.kt deleted file mode 100644 index ab42b0aa409..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/integration/common/reporting/identity/ReportingPrincipalIdentity.kt +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.integration.common.reporting.identity - -import io.grpc.Metadata -import io.grpc.stub.AbstractStub -import io.grpc.stub.MetadataUtils - -private const val KEY_NAME = "reporting_principal" -val REPORTING_PRINCIPAL_NAME_METADATA_KEY: Metadata.Key = - Metadata.Key.of(KEY_NAME, Metadata.ASCII_STRING_MARSHALLER) - -/** - * Sets metadata key "reporting_principal" on all outgoing requests. On the server side, use - * [MetadataPrincipalServerInterceptor]. Note that this should only be used in in-process tests - * where mTLS isn't used. - */ -fun > T.withPrincipalName(name: String): T { - val extraHeaders = Metadata() - extraHeaders.put(REPORTING_PRINCIPAL_NAME_METADATA_KEY, name) - return withInterceptors(MetadataUtils.newAttachHeadersInterceptor(extraHeaders)) -} diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/batch/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/kingdom/batch/BUILD.bazel index 12d1d96644d..2d455bdfd32 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/batch/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/batch/BUILD.bazel @@ -47,6 +47,7 @@ kt_jvm_library( "//src/main/kotlin/org/wfanet/measurement/api:api_key_constants", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:packed_messages", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:list_resources", "//src/main/proto/wfa/measurement/api/v2alpha:data_provider_kt_jvm_proto", "//src/main/proto/wfa/measurement/api/v2alpha:data_providers_service_kt_jvm_grpc_proto", "//src/main/proto/wfa/measurement/api/v2alpha:event_group_kt_jvm_proto", diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/batch/MeasurementSystemProber.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/batch/MeasurementSystemProber.kt index b64a3a55842..e7c95bdd42c 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/batch/MeasurementSystemProber.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/batch/MeasurementSystemProber.kt @@ -27,8 +27,12 @@ import java.security.SecureRandom import java.time.Clock import java.time.Duration import java.util.logging.Logger +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.single +import kotlinx.coroutines.flow.singleOrNull import org.wfanet.measurement.api.v2alpha.CanonicalRequisitionKey -import org.wfanet.measurement.api.v2alpha.DataProvider import org.wfanet.measurement.api.v2alpha.DataProvidersGrpcKt import org.wfanet.measurement.api.v2alpha.EventGroup import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt @@ -60,6 +64,9 @@ import org.wfanet.measurement.api.v2alpha.requisitionSpec import org.wfanet.measurement.api.v2alpha.unpack import org.wfanet.measurement.api.withAuthenticationKey import org.wfanet.measurement.common.Instrumentation +import org.wfanet.measurement.common.api.grpc.ResourceList +import org.wfanet.measurement.common.api.grpc.flattenConcat +import org.wfanet.measurement.common.api.grpc.listResources import org.wfanet.measurement.common.crypto.Hashing import org.wfanet.measurement.common.crypto.SigningKeyHandle import org.wfanet.measurement.common.crypto.readCertificate @@ -198,43 +205,31 @@ class MeasurementSystemProber( private suspend fun buildDataProviderNameToEventGroup(): Map { val dataProviderNameToEventGroup = mutableMapOf() for (dataProviderName in dataProviderNames) { - val getDataProviderRequest = getDataProviderRequest { name = dataProviderName } - val dataProvider: DataProvider = - try { - dataProvidersStub - .withAuthenticationKey(apiAuthenticationKey) - .getDataProvider(getDataProviderRequest) - } catch (e: StatusException) { - throw Exception("Unable to get DataProvider with name $dataProviderName", e) - } - - // TODO(@roaminggypsy): Implement QA event group logic using simulatorEventGroupName - val listEventGroupsRequest = listEventGroupsRequest { - parent = measurementConsumerName - filter = ListEventGroupsRequestKt.filter { dataProviders += dataProviderName } - } - - val eventGroups: List = - try { - eventGroupsStub - .withAuthenticationKey(apiAuthenticationKey) - .listEventGroups(listEventGroupsRequest) - .eventGroupsList - .toList() - } catch (e: StatusException) { - throw Exception( - "Unable to get event groups associated with measurement consumer $measurementConsumerName and data provider $dataProviderName", - e, - ) - } - - if (eventGroups.size != 1) { - throw IllegalStateException( - "here should be exactly 1:1 mapping between a data provider and an event group, but data provider $dataProvider is related to ${eventGroups.size} event groups" - ) - } + val eventGroup: EventGroup = + eventGroupsStub + .withAuthenticationKey(apiAuthenticationKey) + .listResources(1) { pageToken, remaining -> + val request = listEventGroupsRequest { + parent = measurementConsumerName + filter = ListEventGroupsRequestKt.filter { dataProviders += dataProviderName } + this.pageToken = pageToken + pageSize = remaining + } + val response = + try { + listEventGroups(request) + } catch (e: StatusException) { + throw Exception( + "Unable to get event groups associated with measurement consumer $measurementConsumerName and data provider $dataProviderName", + e, + ) + } + ResourceList(response.eventGroupsList, response.nextPageToken) + } + .map { it.single() } + .single() - dataProviderNameToEventGroup[dataProviderName] = eventGroups[0] + dataProviderNameToEventGroup[dataProviderName] = eventGroup } return dataProviderNameToEventGroup } @@ -253,55 +248,47 @@ class MeasurementSystemProber( return clock.instant() >= nextMeasurementEarliestInstant } + @OptIn(ExperimentalCoroutinesApi::class) // For `flattenConcat`. private suspend fun getLastUpdatedMeasurement(): Measurement? { - var nextPageToken = "" - do { - val response: ListMeasurementsResponse = - try { - measurementsStub - .withAuthenticationKey(apiAuthenticationKey) - .listMeasurements( + val measurements: Flow> = + measurementsStub.withAuthenticationKey(apiAuthenticationKey).listResources(1) { + pageToken, + remaining -> + val response: ListMeasurementsResponse = + try { + listMeasurements( listMeasurementsRequest { parent = measurementConsumerName - this.pageSize = 1 - pageToken = nextPageToken + this.pageToken = pageToken + this.pageSize = remaining } ) - } catch (e: StatusException) { - throw Exception( - "Unable to list measurements for measurement consumer $measurementConsumerName", - e, - ) - } - if (response.measurementsList.isNotEmpty()) { - return response.measurementsList.single() + } catch (e: StatusException) { + throw Exception( + "Unable to list measurements for measurement consumer $measurementConsumerName", + e, + ) + } + ResourceList(response.measurementsList, response.nextPageToken) } - nextPageToken = response.nextPageToken - } while (nextPageToken.isNotEmpty()) - return null + + return measurements.flattenConcat().singleOrNull() } - private suspend fun getRequisitionsForMeasurement(measurementName: String): List { - var nextPageToken = "" - val requisitions = mutableListOf() - do { - val response: ListRequisitionsResponse = - try { - requisitionsStub - .withAuthenticationKey(apiAuthenticationKey) - .listRequisitions( - listRequisitionsRequest { - parent = measurementName - pageToken = nextPageToken - } - ) - } catch (e: StatusException) { - throw Exception("Unable to list requisitions for measurement $measurementName", e) - } - requisitions.addAll(response.requisitionsList) - nextPageToken = response.nextPageToken - } while (nextPageToken.isNotEmpty()) - return requisitions + @OptIn(ExperimentalCoroutinesApi::class) // For `flattenConcat`. + private fun getRequisitionsForMeasurement(measurementName: String): Flow { + return requisitionsStub + .withAuthenticationKey(apiAuthenticationKey) + .listResources { pageToken -> + val response: ListRequisitionsResponse = + try { + listRequisitions(listRequisitionsRequest { this.pageToken = pageToken }) + } catch (e: StatusException) { + throw Exception("Unable to list requisitions for measurement $measurementName", e) + } + ResourceList(response.requisitionsList, response.nextPageToken) + } + .flattenConcat() } private suspend fun getDataProviderEntry( @@ -360,10 +347,12 @@ class MeasurementSystemProber( private suspend fun updateLastTerminalRequisitionGauge(lastUpdatedMeasurement: Measurement) { val requisitions = getRequisitionsForMeasurement(lastUpdatedMeasurement.name) - for (requisition in requisitions) { + requisitions.collect { requisition -> if (requisition.state == Requisition.State.FULFILLED) { - val requisitionKey = CanonicalRequisitionKey.fromName(requisition.name) - require(requisitionKey != null) { "CanonicalRequisitionKey cannot be null" } + val requisitionKey = + requireNotNull(CanonicalRequisitionKey.fromName(requisition.name)) { + "Requisition name ${requisition.name} is invalid" + } val dataProviderName: String = requisitionKey.dataProviderId val attributes = Attributes.of(DATA_PROVIDER_ATTRIBUTE_KEY, dataProviderName) lastTerminalRequisitionTimeGauge.set( diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/BUILD.bazel index 2865c224525..15694208ce0 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/BUILD.bazel @@ -8,6 +8,7 @@ package(default_visibility = [ "//src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/testing:__pkg__", "//src/main/kotlin/org/wfanet/measurement/loadtest/panelmatch:__pkg__", "//src/main/kotlin/org/wfanet/measurement/loadtest/panelmatchresourcesetup:__pkg__", + "//src/main/kotlin/org/wfanet/measurement/loadtest/reporting:__pkg__", "//src/main/kotlin/org/wfanet/measurement/loadtest/resourcesetup:__pkg__", "//src/test/kotlin/org/wfanet/measurement/integration/common:__pkg__", ]) diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ProtoConversions.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ProtoConversions.kt index f413bc1ddcc..6baa900c9b7 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ProtoConversions.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ProtoConversions.kt @@ -736,9 +736,10 @@ fun InternalModelRollout.toModelRollout(): ModelRollout { source.rolloutPeriodEndTime.toInstant().atZone(ZoneOffset.UTC).toLocalDate().toProtoDate() } } - - rolloutFreezeDate = - source.rolloutFreezeTime.toInstant().atZone(ZoneOffset.UTC).toLocalDate().toProtoDate() + if (source.hasRolloutFreezeTime()) { + rolloutFreezeDate = + source.rolloutFreezeTime.toInstant().atZone(ZoneOffset.UTC).toLocalDate().toProtoDate() + } if (source.externalPreviousModelRolloutId != 0L) { previousModelRollout = ModelRolloutKey( @@ -808,12 +809,15 @@ fun ModelRollout.toInternal( } } - rolloutFreezeTime = - publicModelRollout.rolloutFreezeDate - .toLocalDate() - .atStartOfDay() - .toInstant(ZoneOffset.UTC) - .toProtoTime() + if (publicModelRollout.hasRolloutFreezeDate()) { + rolloutFreezeTime = + publicModelRollout.rolloutFreezeDate + .toLocalDate() + .atStartOfDay() + .toInstant(ZoneOffset.UTC) + .toProtoTime() + } + externalModelReleaseId = apiIdToExternalId(modelReleaseKey.modelReleaseId) } } diff --git a/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/BUILD.bazel index e75c9f0749c..99d3e013189 100644 --- a/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/BUILD.bazel @@ -103,6 +103,11 @@ kt_jvm_library( kt_jvm_library( name = "measurement_results", srcs = ["MeasurementResults.kt"], + deps = [ + "//imports/java/org/projectnessie/cel", + "//src/main/kotlin/org/wfanet/measurement/populationdataprovider:population_requisition_fulfiller", + "@wfa_common_jvm//imports/java/com/google/protobuf", + ], ) kt_jvm_library( @@ -118,6 +123,7 @@ kt_jvm_library( "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:packed_messages", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key", "//src/main/kotlin/org/wfanet/measurement/common:health", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:list_resources", "//src/main/kotlin/org/wfanet/measurement/dataprovider:requisition_fulfiller", "//src/main/kotlin/org/wfanet/measurement/eventdataprovider/noiser", "//src/main/kotlin/org/wfanet/measurement/eventdataprovider/privacybudgetmanagement:privacy_budget_manager", diff --git a/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulator.kt b/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulator.kt index ff1959fd678..4737bcf56bd 100644 --- a/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulator.kt +++ b/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulator.kt @@ -36,9 +36,11 @@ import kotlin.math.roundToInt import kotlin.random.Random import kotlin.random.asJavaRandom import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.asFlow import kotlinx.coroutines.flow.emitAll +import kotlinx.coroutines.flow.firstOrNull import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.map import kotlinx.coroutines.withContext @@ -74,6 +76,7 @@ import org.wfanet.measurement.api.v2alpha.FulfillRequisitionRequestKt.header import org.wfanet.measurement.api.v2alpha.ListEventGroupsRequestKt import org.wfanet.measurement.api.v2alpha.Measurement import org.wfanet.measurement.api.v2alpha.MeasurementConsumer +import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey import org.wfanet.measurement.api.v2alpha.MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineStub import org.wfanet.measurement.api.v2alpha.MeasurementKey import org.wfanet.measurement.api.v2alpha.MeasurementKt @@ -108,6 +111,9 @@ import org.wfanet.measurement.api.v2alpha.updateEventGroupRequest import org.wfanet.measurement.common.Health import org.wfanet.measurement.common.ProtoReflection import org.wfanet.measurement.common.SettableHealth +import org.wfanet.measurement.common.api.grpc.ResourceList +import org.wfanet.measurement.common.api.grpc.flattenConcat +import org.wfanet.measurement.common.api.grpc.listResources import org.wfanet.measurement.common.asBufferedFlow import org.wfanet.measurement.common.crypto.authorityKeyIdentifier import org.wfanet.measurement.common.crypto.readCertificate @@ -143,7 +149,7 @@ import org.wfanet.measurement.loadtest.dataprovider.MeasurementResults.computeIm /** A simulator handling EDP businesses. */ class EdpSimulator( private val edpData: DataProviderData, - measurementConsumerName: String, + private val measurementConsumerName: String, private val measurementConsumersStub: MeasurementConsumersCoroutineStub, certificatesStub: CertificatesCoroutineStub, private val dataProvidersStub: DataProvidersCoroutineStub, @@ -176,18 +182,11 @@ class EdpSimulator( private val health: SettableHealth = SettableHealth(), private val blockingCoroutineContext: @BlockingExecutor CoroutineContext = Dispatchers.IO, ) : - RequisitionFulfiller( - edpData, - certificatesStub, - requisitionsStub, - throttler, - trustedCertificates, - measurementConsumerName, - ), + RequisitionFulfiller(edpData, certificatesStub, requisitionsStub, throttler, trustedCertificates), Health by health { - val eventGroupReferenceIdPrefix = getEventGroupReferenceIdPrefix(edpData.displayName) + private val eventGroupReferenceIdPrefix = getEventGroupReferenceIdPrefix(edpData.displayName) - val supportedProtocols = buildSet { + private val supportedProtocols = buildSet { add(ProtocolConfig.Protocol.ProtocolCase.LIQUID_LEGIONS_V2) add(ProtocolConfig.Protocol.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2) if (hmssVidIndexMap != null) { @@ -195,6 +194,9 @@ class EdpSimulator( } } + private val measurementConsumerKey = + checkNotNull(MeasurementConsumerKey.fromName(measurementConsumerName)) + /** A sequence of operations done in the simulator. */ override suspend fun run() { dataProvidersStub.replaceDataAvailabilityInterval( @@ -342,27 +344,29 @@ class EdpSimulator( * Returns the first [EventGroup] for this `DataProvider` and [MeasurementConsumer] with * [eventGroupReferenceId], or `null` if not found. */ + @OptIn(ExperimentalCoroutinesApi::class) // For `flattenConcat`. private suspend fun getEventGroupByReferenceId(eventGroupReferenceId: String): EventGroup? { - val response = - try { - eventGroupsStub.listEventGroups( - listEventGroupsRequest { - parent = edpData.name - filter = - ListEventGroupsRequestKt.filter { measurementConsumers += measurementConsumerName } - pageSize = Int.MAX_VALUE + return eventGroupsStub + .listResources { pageToken -> + val response = + try { + listEventGroups( + listEventGroupsRequest { + parent = edpData.name + filter = + ListEventGroupsRequestKt.filter { + measurementConsumers += measurementConsumerName + } + this.pageToken = pageToken + } + ) + } catch (e: StatusException) { + throw Exception("Error listing EventGroups", e) } - ) - } catch (e: StatusException) { - throw Exception("Error listing EventGroups", e) + ResourceList(response.eventGroupsList, response.nextPageToken) } - - // TODO(@SanjayVas): Support filtering by reference ID so we don't need to handle multiple pages - // of EventGroups. - check(response.nextPageToken.isEmpty()) { - "Too many EventGroups for ${edpData.name} and $measurementConsumerName" - } - return response.eventGroupsList.find { it.eventGroupReferenceId == eventGroupReferenceId } + .flattenConcat() + .firstOrNull { it.eventGroupReferenceId == eventGroupReferenceId } } private suspend fun ensureMetadataDescriptor( diff --git a/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/MeasurementResults.kt b/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/MeasurementResults.kt index 9dd95d4e00c..8d43abe1981 100644 --- a/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/MeasurementResults.kt +++ b/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/MeasurementResults.kt @@ -16,6 +16,11 @@ package org.wfanet.measurement.loadtest.dataprovider +import com.google.protobuf.TypeRegistry +import org.projectnessie.cel.Program +import org.wfanet.measurement.populationdataprovider.PopulationInfo +import org.wfanet.measurement.populationdataprovider.PopulationRequisitionFulfiller + /** Utilities for computing Measurement results. */ object MeasurementResults { data class ReachAndFrequency(val reach: Int, val relativeFrequencyDistribution: Map) @@ -56,4 +61,13 @@ object MeasurementResults { // Cap each count at `maxFrequency`. return eventsPerVid.values.sumOf { count -> count.coerceAtMost(maxFrequency).toLong() } } + + /** Computes population using the "deterministic count" methodology. */ + fun computePopulation( + populationInfo: PopulationInfo, + program: Program, + typeRegistry: TypeRegistry, + ): Long { + return PopulationRequisitionFulfiller.computePopulation(populationInfo, program, typeRegistry) + } } diff --git a/src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer/BUILD.bazel index 79d14467157..1451a300dae 100644 --- a/src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer/BUILD.bazel @@ -28,7 +28,9 @@ kt_jvm_library( "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:packed_messages", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha/testing", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:list_resources", "//src/main/kotlin/org/wfanet/measurement/common/identity", + "//src/main/kotlin/org/wfanet/measurement/integration/common:configs", "//src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha:api_key_authentication_server_interceptor", "//src/main/kotlin/org/wfanet/measurement/loadtest/common:sample_vids", "//src/main/kotlin/org/wfanet/measurement/loadtest/config:test_identifiers", @@ -50,6 +52,7 @@ kt_jvm_library( "//src/main/proto/wfa/measurement/api/v2alpha:measurements_service_kt_jvm_grpc_proto", "//src/main/proto/wfa/measurement/api/v2alpha:requisition_spec_kt_jvm_proto", "//src/main/proto/wfa/measurement/api/v2alpha:requisitions_service_kt_jvm_grpc_proto", + "//src/main/proto/wfa/measurement/api/v2alpha/event_templates/testing:bad_templates_kt_jvm_proto", "//src/main/proto/wfa/measurement/api/v2alpha/event_templates/testing:test_event_kt_jvm_proto", "@wfa_common_jvm//imports/java/com/google/common/truth/extensions/proto", "@wfa_common_jvm//imports/java/org/apache/commons:math3", diff --git a/src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer/MeasurementConsumerSimulator.kt b/src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer/MeasurementConsumerSimulator.kt index 7d8c1d49364..f008d4ef94a 100644 --- a/src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer/MeasurementConsumerSimulator.kt +++ b/src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer/MeasurementConsumerSimulator.kt @@ -18,6 +18,7 @@ import com.google.common.truth.Truth.assertThat import com.google.protobuf.Any as ProtoAny import com.google.protobuf.ByteString import com.google.protobuf.Message +import com.google.protobuf.TypeRegistry import com.google.protobuf.util.Durations import io.grpc.StatusException import java.security.SignatureException @@ -30,7 +31,12 @@ import kotlin.math.log2 import kotlin.math.max import kotlin.math.sqrt import kotlin.random.Random +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.filter +import kotlinx.coroutines.flow.toList import kotlinx.coroutines.time.delay +import org.projectnessie.cel.Program import org.wfanet.measurement.api.v2alpha.Certificate import org.wfanet.measurement.api.v2alpha.CertificatesGrpcKt.CertificatesCoroutineStub import org.wfanet.measurement.api.v2alpha.CustomDirectMethodologyKt @@ -40,6 +46,7 @@ import org.wfanet.measurement.api.v2alpha.DataProviderKey import org.wfanet.measurement.api.v2alpha.DataProviderKt import org.wfanet.measurement.api.v2alpha.DataProvidersGrpcKt.DataProvidersCoroutineStub import org.wfanet.measurement.api.v2alpha.DifferentialPrivacyParams +import org.wfanet.measurement.api.v2alpha.EventAnnotationsProto import org.wfanet.measurement.api.v2alpha.EventGroup import org.wfanet.measurement.api.v2alpha.EventGroupKey import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub @@ -63,12 +70,14 @@ import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt.impression import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt.reachAndFrequency import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt.vidSamplingInterval import org.wfanet.measurement.api.v2alpha.MeasurementsGrpcKt.MeasurementsCoroutineStub +import org.wfanet.measurement.api.v2alpha.PopulationKey import org.wfanet.measurement.api.v2alpha.ProtocolConfig import org.wfanet.measurement.api.v2alpha.ProtocolConfig.NoiseMechanism import org.wfanet.measurement.api.v2alpha.RequisitionSpec import org.wfanet.measurement.api.v2alpha.RequisitionSpecKt import org.wfanet.measurement.api.v2alpha.RequisitionSpecKt.eventFilter import org.wfanet.measurement.api.v2alpha.RequisitionSpecKt.eventGroupEntry +import org.wfanet.measurement.api.v2alpha.RequisitionSpecKt.population import org.wfanet.measurement.api.v2alpha.SignedMessage import org.wfanet.measurement.api.v2alpha.copy import org.wfanet.measurement.api.v2alpha.createMeasurementRequest @@ -87,6 +96,9 @@ import org.wfanet.measurement.api.v2alpha.unpack import org.wfanet.measurement.api.withAuthenticationKey import org.wfanet.measurement.common.ExponentialBackoff import org.wfanet.measurement.common.OpenEndTimeRange +import org.wfanet.measurement.common.api.grpc.ResourceList +import org.wfanet.measurement.common.api.grpc.flattenConcat +import org.wfanet.measurement.common.api.grpc.listResources import org.wfanet.measurement.common.coerceAtMost import org.wfanet.measurement.common.crypto.Hashing import org.wfanet.measurement.common.crypto.PrivateKeyHandle @@ -100,11 +112,13 @@ import org.wfanet.measurement.consent.client.measurementconsumer.encryptRequisit import org.wfanet.measurement.consent.client.measurementconsumer.signMeasurementSpec import org.wfanet.measurement.consent.client.measurementconsumer.signRequisitionSpec import org.wfanet.measurement.consent.client.measurementconsumer.verifyResult +import org.wfanet.measurement.eventdataprovider.eventfiltration.EventFilters import org.wfanet.measurement.eventdataprovider.noiser.DpParams as NoiserDpParams import org.wfanet.measurement.loadtest.common.sampleVids import org.wfanet.measurement.loadtest.config.TestIdentifiers import org.wfanet.measurement.loadtest.dataprovider.EventQuery import org.wfanet.measurement.loadtest.dataprovider.MeasurementResults +import org.wfanet.measurement.loadtest.dataprovider.MeasurementResults.computePopulation import org.wfanet.measurement.measurementconsumer.stats.DeterministicMethodology import org.wfanet.measurement.measurementconsumer.stats.FrequencyMeasurementParams import org.wfanet.measurement.measurementconsumer.stats.FrequencyMeasurementVarianceParams @@ -118,6 +132,7 @@ import org.wfanet.measurement.measurementconsumer.stats.ReachMeasurementParams import org.wfanet.measurement.measurementconsumer.stats.ReachMeasurementVarianceParams import org.wfanet.measurement.measurementconsumer.stats.VariancesImpl import org.wfanet.measurement.measurementconsumer.stats.VidSamplingInterval as StatsVidSamplingInterval +import org.wfanet.measurement.populationdataprovider.PopulationInfo data class MeasurementConsumerData( // The MC's public API resource name @@ -130,6 +145,12 @@ data class MeasurementConsumerData( val apiAuthenticationKey: String, ) +data class PopulationData( + val populationDataProviderName: String, + val populationInfo: PopulationInfo, + val populationKey: PopulationKey, +) + /** Simulator for MeasurementConsumer operations on the CMMS public API. */ class MeasurementConsumerSimulator( private val measurementConsumerData: MeasurementConsumerData, @@ -150,6 +171,8 @@ class MeasurementConsumerSimulator( /** Cache of resource name to [Certificate]. */ private val certificateCache = mutableMapOf() + private lateinit var populationModelLineName: String + data class RequisitionInfo( val dataProviderEntry: DataProviderEntry, val requisitionSpec: RequisitionSpec, @@ -162,6 +185,12 @@ class MeasurementConsumerSimulator( val requisitions: List, ) + data class PopulationMeasurementInfo( + val populationInfo: PopulationInfo, + val typeRegistry: TypeRegistry, + val measurementInfo: MeasurementInfo, + ) + private data class MeasurementComputationInfo( val methodology: Methodology, val noiseMechanism: NoiseMechanism, @@ -604,6 +633,43 @@ class MeasurementConsumerSimulator( logger.info("Duration result is equal to the expected result") } + /** A sequence of operations done in the simulator involving a population measurement. */ + suspend fun testPopulation( + runId: String, + populationData: PopulationData, + modelLineName: String, + populationFilterExpression: String, + typeRegistry: TypeRegistry, + ) { + logger.info { "Creating population Measurement..." } + // Create a new measurement on behalf of the measurement consumer. + val measurementConsumer = getMeasurementConsumer(measurementConsumerData.name) + populationModelLineName = modelLineName + val populationMeasurementInfo: PopulationMeasurementInfo = + createPopulationMeasurement( + measurementConsumer, + runId, + populationData, + populationFilterExpression, + typeRegistry, + ::newPopulationMeasurementSpec, + ) + + val measurementName = populationMeasurementInfo.measurementInfo.measurement.name + logger.info { "Created population Measurement $measurementName" } + + // Get the CMMS computed result and compare it with the expected result. + val populationResult: Result = pollForResult { getPopulationResult(measurementName) } + logger.info("Got population result from Kingdom: $populationResult") + + val expectedResult = getExpectedPopulationResult(populationMeasurementInfo) + logger.info("Expected result: $expectedResult") + + assertThat(populationResult.population.value).isEqualTo(expectedResult.population.value) + + logger.info("Population result is equal to the expected result") + } + /** Computes the tolerance values of an impression [Result] for testing. */ private fun computeImpressionVariance( result: Result, @@ -753,18 +819,20 @@ class MeasurementConsumerSimulator( maxDataProviders: Int = 20, ): MeasurementInfo { val eventGroups: List = - listEventGroups(measurementConsumer.name).filter { - it.eventGroupReferenceId.startsWith( - TestIdentifiers.SIMULATOR_EVENT_GROUP_REFERENCE_ID_PREFIX - ) - } + listEventGroups(measurementConsumer.name) + .filter { + it.eventGroupReferenceId.startsWith( + TestIdentifiers.SIMULATOR_EVENT_GROUP_REFERENCE_ID_PREFIX + ) + } + .toList() check(eventGroups.isNotEmpty()) { "No event groups found for ${measurementConsumer.name}" } val nonceHashes = mutableListOf() val keyToDataProviderMap: Map = eventGroups .groupBy { extractDataProviderKey(it.name) } .entries - .associate { it.key to getDataProvider(it.key) } + .associate { it.key to getDataProvider(it.key.toName()) } val requisitions: List = eventGroups @@ -787,6 +855,45 @@ class MeasurementConsumerSimulator( } val measurementSpec = newMeasurementSpec(measurementConsumer.publicKey.message, nonceHashes) + + return createMeasurementInfo(measurementConsumer, measurementSpec, requisitions, runId) + } + + private suspend fun createPopulationMeasurement( + measurementConsumer: MeasurementConsumer, + runId: String, + populationData: PopulationData, + populationFilterExpression: String, + typeRegistry: TypeRegistry, + newMeasurementSpec: + (packedMeasurementPublicKey: ProtoAny, nonceHashes: List) -> MeasurementSpec, + ): PopulationMeasurementInfo { + val nonce = Random.Default.nextLong() + val nonceHashes = mutableListOf() + nonceHashes.add(Hashing.hashSha256(nonce)) + val populationDataProvider = getDataProvider(populationData.populationDataProviderName) + val requisitions = + listOf( + buildPopulationMeasurementRequisitionInfo( + populationDataProvider, + measurementConsumer, + populationFilterExpression, + nonce, + ) + ) + val measurementSpec = newMeasurementSpec(measurementConsumer.publicKey.message, nonceHashes) + + val measurementInfo = + createMeasurementInfo(measurementConsumer, measurementSpec, requisitions, runId) + return PopulationMeasurementInfo(populationData.populationInfo, typeRegistry, measurementInfo) + } + + private suspend fun createMeasurementInfo( + measurementConsumer: MeasurementConsumer, + measurementSpec: MeasurementSpec, + requisitions: List, + runId: String, + ): MeasurementInfo { val request = createMeasurementRequest { parent = measurementConsumer.name measurement = measurement { @@ -851,6 +958,20 @@ class MeasurementConsumerSimulator( return result } + /** Gets the result of a [Measurement] if it is succeeded. */ + private suspend fun getPopulationResult(measurementName: String): Result? { + val measurement = checkNotFailed(getMeasurement(measurementName)) + if (measurement.state != Measurement.State.SUCCEEDED) { + return null + } + + val resultOutput = measurement.resultsList[0] + val result = parseAndVerifyResult(resultOutput) + assertThat(result.hasPopulation()).isTrue() + + return result + } + /** Gets [Measurement] with logging state. */ private suspend fun getMeasurement(measurementName: String): Measurement { val measurement: Measurement = @@ -931,7 +1052,7 @@ class MeasurementConsumerSimulator( getExpectedReachAndFrequencyResult(measurementInfo) MeasurementSpec.MeasurementTypeCase.IMPRESSION -> error("Should not be reached.") MeasurementSpec.MeasurementTypeCase.DURATION -> getExpectedDurationResult() - MeasurementSpec.MeasurementTypeCase.POPULATION -> getExpectedPopulationResult() + MeasurementSpec.MeasurementTypeCase.POPULATION -> error("Should not be reached.") MeasurementSpec.MeasurementTypeCase.MEASUREMENTTYPE_NOT_SET -> error("measurement_type not set") } @@ -958,8 +1079,47 @@ class MeasurementConsumerSimulator( } } - private fun getExpectedPopulationResult(): Result { - TODO("Not yet implemented") + private fun getExpectedPopulationResult( + populationMeasurementInfo: PopulationMeasurementInfo + ): Result { + val measurementInfo = populationMeasurementInfo.measurementInfo + val requisition = measurementInfo.requisitions[0] + val requisitionSpec = requisition.requisitionSpec + val requisitionFilterExpression = requisitionSpec.population.filter.expression + + val operativeFields = + populationMeasurementInfo.populationInfo.eventMessageDescriptor.fields + .flatMap { templateField -> + templateField.messageType.fields.map { templateFieldDescriptor -> + if ( + templateFieldDescriptor.options + .getExtension(EventAnnotationsProto.templateField) + .populationAttribute + ) { + "${templateField.name}.${templateFieldDescriptor.name}" + } else null + } + } + .filterNotNull() + .toSet() + val eventMessageDescriptor = populationMeasurementInfo.populationInfo.eventMessageDescriptor + val program: Program = + EventFilters.compileProgram( + eventMessageDescriptor, + requisitionFilterExpression, + operativeFields, + ) + return result { + population = + MeasurementKt.ResultKt.population { + value = + computePopulation( + populationMeasurementInfo.populationInfo, + program, + populationMeasurementInfo.typeRegistry, + ) + } + } } private fun getExpectedReachResult(measurementInfo: MeasurementInfo): Result { @@ -1092,25 +1252,45 @@ class MeasurementConsumerSimulator( } } - private suspend fun listEventGroups(measurementConsumer: String): List { - val request = listEventGroupsRequest { parent = measurementConsumer } - try { - return eventGroupsClient - .withAuthenticationKey(measurementConsumerData.apiAuthenticationKey) - .listEventGroups(request) - .eventGroupsList - } catch (e: StatusException) { - throw Exception("Error listing event groups for MC $measurementConsumer", e) + private fun newPopulationMeasurementSpec( + packedMeasurementPublicKey: ProtoAny, + nonceHashes: List, + ): MeasurementSpec { + return measurementSpec { + measurementPublicKey = packedMeasurementPublicKey + population = MeasurementSpecKt.population {} + this.nonceHashes += nonceHashes + modelLine = populationModelLineName } } + @OptIn(ExperimentalCoroutinesApi::class) // For `flattenConcat`. + private fun listEventGroups(measurementConsumer: String): Flow { + return eventGroupsClient + .withAuthenticationKey(measurementConsumerData.apiAuthenticationKey) + .listResources { pageToken -> + val response = + try { + listEventGroups( + listEventGroupsRequest { + parent = measurementConsumer + this.pageToken = pageToken + } + ) + } catch (e: StatusException) { + throw Exception("Error listing event groups for MC $measurementConsumer", e) + } + ResourceList(response.eventGroupsList, response.nextPageToken) + } + .flattenConcat() + } + private fun extractDataProviderKey(eventGroupName: String): DataProviderKey { val eventGroupKey = EventGroupKey.fromName(eventGroupName) ?: error("Invalid eventGroup name.") return eventGroupKey.parentKey } - private suspend fun getDataProvider(key: DataProviderKey): DataProvider { - val name = key.toName() + private suspend fun getDataProvider(name: String): DataProvider { val request = GetDataProviderRequest.newBuilder().also { it.name = name }.build() try { return dataProvidersClient @@ -1121,7 +1301,7 @@ class MeasurementConsumerSimulator( } } - private suspend fun buildRequisitionInfo( + private fun buildRequisitionInfo( dataProvider: DataProvider, eventGroups: List, measurementConsumer: MeasurementConsumer, @@ -1152,6 +1332,25 @@ class MeasurementConsumerSimulator( return RequisitionInfo(dataProviderEntry, requisitionSpec, eventGroups) } + private fun buildPopulationMeasurementRequisitionInfo( + dataProvider: DataProvider, + measurementConsumer: MeasurementConsumer, + populationFilterExpression: String, + nonce: Long, + ): RequisitionInfo { + val requisitionSpec = requisitionSpec { + population = population { filter = eventFilter { expression = populationFilterExpression } } + measurementPublicKey = measurementConsumer.publicKey.message + this.nonce = nonce + } + val signedRequisitionSpec = + signRequisitionSpec(requisitionSpec, measurementConsumerData.signingKey) + val dataProviderEntry = + dataProvider.toDataProviderEntry(signedRequisitionSpec, Hashing.hashSha256(nonce)) + + return RequisitionInfo(dataProviderEntry, requisitionSpec, listOf()) + } + private fun DataProvider.toDataProviderEntry( signedRequisitionSpec: SignedMessage, nonceHash: ByteString, diff --git a/src/main/kotlin/org/wfanet/measurement/loadtest/reporting/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/loadtest/reporting/BUILD.bazel new file mode 100644 index 00000000000..962beda9818 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/loadtest/reporting/BUILD.bazel @@ -0,0 +1,26 @@ +load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") + +package( + default_testonly = True, + default_visibility = [ + "//src/main/kotlin/org/wfanet/measurement/integration:__subpackages__", + "//src/main/kotlin/org/wfanet/measurement/loadtest:__subpackages__", + "//src/test/kotlin/org/wfanet/measurement/integration:__subpackages__", + "//src/test/kotlin/org/wfanet/measurement/loadtest:__subpackages__", + ], +) + +kt_jvm_library( + name = "simulator", + srcs = ["ReportingUserSimulator.kt"], + deps = [ + "//src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha:data_providers_service", + "//src/main/kotlin/org/wfanet/measurement/loadtest/config:test_identifiers", + "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha:event_groups_service", + "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha:metric_calculation_specs_service", + "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha:reporting_sets_service", + "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha:reports_service", + "//src/main/proto/wfa/measurement/api/v2alpha/event_templates/testing:test_event_kt_jvm_proto", + "@wfa_common_jvm//imports/java/com/google/common/truth/extensions/proto", + ], +) diff --git a/src/main/kotlin/org/wfanet/measurement/loadtest/reporting/ReportingUserSimulator.kt b/src/main/kotlin/org/wfanet/measurement/loadtest/reporting/ReportingUserSimulator.kt new file mode 100644 index 00000000000..99a80a5dfff --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/loadtest/reporting/ReportingUserSimulator.kt @@ -0,0 +1,250 @@ +/* + * Copyright 2024 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.loadtest.reporting + +import com.google.common.truth.Truth.assertThat +import com.google.type.DayOfWeek +import com.google.type.date +import com.google.type.dateTime +import com.google.type.timeZone +import io.grpc.StatusException +import java.time.Duration +import java.util.logging.Logger +import kotlinx.coroutines.time.delay +import org.wfanet.measurement.api.v2alpha.DataProvider +import org.wfanet.measurement.api.v2alpha.DataProvidersGrpcKt +import org.wfanet.measurement.api.v2alpha.getDataProviderRequest +import org.wfanet.measurement.common.ExponentialBackoff +import org.wfanet.measurement.common.coerceAtMost +import org.wfanet.measurement.loadtest.config.TestIdentifiers +import org.wfanet.measurement.reporting.v2alpha.EventGroup +import org.wfanet.measurement.reporting.v2alpha.EventGroupsGrpcKt +import org.wfanet.measurement.reporting.v2alpha.ListEventGroupsResponse +import org.wfanet.measurement.reporting.v2alpha.MetricCalculationSpec +import org.wfanet.measurement.reporting.v2alpha.MetricCalculationSpecKt +import org.wfanet.measurement.reporting.v2alpha.MetricCalculationSpecsGrpcKt +import org.wfanet.measurement.reporting.v2alpha.MetricSpecKt +import org.wfanet.measurement.reporting.v2alpha.Report +import org.wfanet.measurement.reporting.v2alpha.ReportKt +import org.wfanet.measurement.reporting.v2alpha.ReportingSet +import org.wfanet.measurement.reporting.v2alpha.ReportingSetKt +import org.wfanet.measurement.reporting.v2alpha.ReportingSetsGrpcKt +import org.wfanet.measurement.reporting.v2alpha.ReportsGrpcKt +import org.wfanet.measurement.reporting.v2alpha.createMetricCalculationSpecRequest +import org.wfanet.measurement.reporting.v2alpha.createReportRequest +import org.wfanet.measurement.reporting.v2alpha.createReportingSetRequest +import org.wfanet.measurement.reporting.v2alpha.getReportRequest +import org.wfanet.measurement.reporting.v2alpha.listEventGroupsRequest +import org.wfanet.measurement.reporting.v2alpha.metricCalculationSpec +import org.wfanet.measurement.reporting.v2alpha.metricSpec +import org.wfanet.measurement.reporting.v2alpha.report +import org.wfanet.measurement.reporting.v2alpha.reportingSet + +/** Simulator for Reporting operations on the Reporting public API. */ +class ReportingUserSimulator( + private val measurementConsumerName: String, + private val dataProvidersClient: DataProvidersGrpcKt.DataProvidersCoroutineStub, + private val eventGroupsClient: EventGroupsGrpcKt.EventGroupsCoroutineStub, + private val reportingSetsClient: ReportingSetsGrpcKt.ReportingSetsCoroutineStub, + private val metricCalculationSpecsClient: + MetricCalculationSpecsGrpcKt.MetricCalculationSpecsCoroutineStub, + private val reportsClient: ReportsGrpcKt.ReportsCoroutineStub, + private val initialResultPollingDelay: Duration = Duration.ofSeconds(1), + private val maximumResultPollingDelay: Duration = Duration.ofMinutes(1), +) { + suspend fun testCreateReport(runId: String) { + logger.info("Creating report...") + + val eventGroup = + listEventGroups() + .filter { + it.eventGroupReferenceId.startsWith( + TestIdentifiers.SIMULATOR_EVENT_GROUP_REFERENCE_ID_PREFIX + ) + } + .firstOrNull { + getDataProvider(it.cmmsDataProvider).capabilities.honestMajorityShareShuffleSupported + } ?: listEventGroups().first() + val createdPrimitiveReportingSet = createPrimitiveReportingSet(eventGroup) + val createdMetricCalculationSpec = createMetricCalculationSpec() + + val report = report { + reportingMetricEntries += + ReportKt.reportingMetricEntry { + key = createdPrimitiveReportingSet.name + value = + ReportKt.reportingMetricCalculationSpec { + metricCalculationSpecs += createdMetricCalculationSpec.name + } + } + reportingInterval = + ReportKt.reportingInterval { + reportStart = dateTime { + year = 2024 + month = 1 + day = 3 + timeZone = timeZone { id = "America/Los_Angeles" } + } + reportEnd = date { + year = 2024 + month = 1 + day = 4 + } + } + } + + val createdReport = + try { + reportsClient.createReport( + createReportRequest { + parent = measurementConsumerName + this.report = report + reportId = "a-$runId" + } + ) + } catch (e: StatusException) { + throw Exception("Error creating Report", e) + } + + val completedReport = pollForCompletedReport(createdReport.name) + + assertThat(completedReport.state).isEqualTo(Report.State.SUCCEEDED) + logger.info("Report creation succeeded") + } + + private suspend fun listEventGroups(): List { + try { + return buildList { + var response: ListEventGroupsResponse = ListEventGroupsResponse.getDefaultInstance() + do { + response = + eventGroupsClient.listEventGroups( + listEventGroupsRequest { + parent = measurementConsumerName + pageToken = response.nextPageToken + } + ) + addAll(response.eventGroupsList) + } while (response.nextPageToken.isNotEmpty()) + } + } catch (e: StatusException) { + throw Exception("Error listing EventGroups", e) + } + } + + private suspend fun getDataProvider(dataProviderName: String): DataProvider { + try { + return dataProvidersClient.getDataProvider(getDataProviderRequest { name = dataProviderName }) + } catch (e: StatusException) { + throw Exception("Error getting DataProvider $dataProviderName", e) + } + } + + private suspend fun createPrimitiveReportingSet(eventGroup: EventGroup): ReportingSet { + val primitiveReportingSet = reportingSet { + primitive = ReportingSetKt.primitive { cmmsEventGroups += eventGroup.cmmsEventGroup } + } + + try { + return reportingSetsClient.createReportingSet( + createReportingSetRequest { + parent = measurementConsumerName + reportingSet = primitiveReportingSet + reportingSetId = "a-123" + } + ) + } catch (e: StatusException) { + throw Exception("Error creating ReportingSet", e) + } + } + + private suspend fun createMetricCalculationSpec(): MetricCalculationSpec { + try { + return metricCalculationSpecsClient.createMetricCalculationSpec( + createMetricCalculationSpecRequest { + parent = measurementConsumerName + metricCalculationSpecId = "a-123" + metricCalculationSpec = metricCalculationSpec { + displayName = "union reach" + metricSpecs += metricSpec { + reach = + MetricSpecKt.reachParams { + singleDataProviderParams = + MetricSpecKt.samplingAndPrivacyParams { + privacyParams = MetricSpecKt.differentialPrivacyParams {} + } + multipleDataProviderParams = + MetricSpecKt.samplingAndPrivacyParams { + privacyParams = MetricSpecKt.differentialPrivacyParams {} + } + } + } + metricFrequencySpec = + MetricCalculationSpecKt.metricFrequencySpec { + weekly = + MetricCalculationSpecKt.MetricFrequencySpecKt.weekly { + dayOfWeek = DayOfWeek.WEDNESDAY + } + } + trailingWindow = + MetricCalculationSpecKt.trailingWindow { + count = 1 + increment = MetricCalculationSpec.TrailingWindow.Increment.WEEK + } + } + } + ) + } catch (e: StatusException) { + throw Exception("Error creating MetricCalculationSpec", e) + } + } + + private suspend fun pollForCompletedReport(reportName: String): Report { + val backoff = + ExponentialBackoff(initialDelay = initialResultPollingDelay, randomnessFactor = 0.0) + var attempt = 1 + while (true) { + val retrievedReport = + try { + reportsClient.getReport(getReportRequest { name = reportName }) + } catch (e: StatusException) { + throw Exception("Error getting Report", e) + } + + @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. + when (retrievedReport.state) { + Report.State.SUCCEEDED, + Report.State.FAILED -> return retrievedReport + Report.State.RUNNING, + Report.State.UNRECOGNIZED, + Report.State.STATE_UNSPECIFIED -> { + val resultPollingDelay = + backoff.durationForAttempt(attempt).coerceAtMost(maximumResultPollingDelay) + logger.info { + "Report not completed yet. Waiting for ${resultPollingDelay.seconds} seconds." + } + delay(resultPollingDelay) + attempt++ + } + } + } + } + + companion object { + private val logger: Logger = Logger.getLogger(this::class.java.name) + } +} diff --git a/src/main/kotlin/org/wfanet/measurement/loadtest/resourcesetup/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/loadtest/resourcesetup/BUILD.bazel index f4a374643f5..185e2f3da8c 100644 --- a/src/main/kotlin/org/wfanet/measurement/loadtest/resourcesetup/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/loadtest/resourcesetup/BUILD.bazel @@ -25,6 +25,7 @@ kt_jvm_library( "//src/main/kotlin/org/wfanet/measurement/common/identity:principal_identity", "//src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha:account_authentication_server_interceptor", "//src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha:certificates", + "//src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha:model_suites_service", "//src/main/kotlin/org/wfanet/measurement/loadtest/common:output", "//src/main/proto/wfa/measurement/api/v2alpha:account_kt_jvm_proto", "//src/main/proto/wfa/measurement/api/v2alpha:accounts_service_kt_jvm_grpc_proto", @@ -36,12 +37,17 @@ kt_jvm_library( "//src/main/proto/wfa/measurement/api/v2alpha:measurement_consumer_kt_jvm_proto", "//src/main/proto/wfa/measurement/api/v2alpha:measurement_consumers_service_kt_jvm_grpc_proto", "//src/main/proto/wfa/measurement/config:authority_key_to_principal_map_kt_jvm_proto", + "//src/main/proto/wfa/measurement/config/reporting:encryption_key_pair_config_kt_jvm_proto", + "//src/main/proto/wfa/measurement/config/reporting:measurement_consumer_config_kt_jvm_proto", "//src/main/proto/wfa/measurement/internal/kingdom:account_kt_jvm_proto", "//src/main/proto/wfa/measurement/internal/kingdom:accounts_service_kt_jvm_grpc_proto", "//src/main/proto/wfa/measurement/internal/kingdom:certificate_kt_jvm_proto", "//src/main/proto/wfa/measurement/internal/kingdom:certificates_service_kt_jvm_grpc_proto", "//src/main/proto/wfa/measurement/internal/kingdom:data_provider_kt_jvm_proto", "//src/main/proto/wfa/measurement/internal/kingdom:data_providers_service_kt_jvm_grpc_proto", + "//src/main/proto/wfa/measurement/internal/kingdom:model_providers_service_kt_jvm_grpc_proto", + "//src/main/proto/wfa/measurement/internal/kingdom:population_kt_jvm_proto", + "//src/main/proto/wfa/measurement/internal/kingdom:populations_service_kt_jvm_grpc_proto", "//src/main/proto/wfa/measurement/loadtest/resourcesetup:resources_kt_jvm_proto", "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto:signatures", "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto/tink", diff --git a/src/main/kotlin/org/wfanet/measurement/loadtest/resourcesetup/ResourceSetup.kt b/src/main/kotlin/org/wfanet/measurement/loadtest/resourcesetup/ResourceSetup.kt index 4bae0c8d911..6bbc96d00f6 100644 --- a/src/main/kotlin/org/wfanet/measurement/loadtest/resourcesetup/ResourceSetup.kt +++ b/src/main/kotlin/org/wfanet/measurement/loadtest/resourcesetup/ResourceSetup.kt @@ -55,17 +55,29 @@ import org.wfanet.measurement.common.crypto.tink.SelfIssuedIdTokens.generateIdTo import org.wfanet.measurement.common.identity.externalIdToApiId import org.wfanet.measurement.config.AuthorityKeyToPrincipalMapKt import org.wfanet.measurement.config.authorityKeyToPrincipalMap +import org.wfanet.measurement.config.reporting.EncryptionKeyPairConfigKt +import org.wfanet.measurement.config.reporting.encryptionKeyPairConfig +import org.wfanet.measurement.config.reporting.measurementConsumerConfig +import org.wfanet.measurement.config.reporting.measurementConsumerConfigs import org.wfanet.measurement.consent.client.measurementconsumer.signEncryptionPublicKey import org.wfanet.measurement.internal.kingdom.Account as InternalAccount import org.wfanet.measurement.internal.kingdom.AccountsGrpcKt import org.wfanet.measurement.internal.kingdom.CertificatesGrpcKt import org.wfanet.measurement.internal.kingdom.DataProvider as InternalDataProvider import org.wfanet.measurement.internal.kingdom.DataProvidersGrpcKt +import org.wfanet.measurement.internal.kingdom.ModelProvider as InternalModelProvider +import org.wfanet.measurement.internal.kingdom.ModelProvidersGrpcKt +import org.wfanet.measurement.internal.kingdom.Population as InternalPopulation +import org.wfanet.measurement.internal.kingdom.PopulationKt +import org.wfanet.measurement.internal.kingdom.PopulationsGrpcKt import org.wfanet.measurement.internal.kingdom.account as internalAccount import org.wfanet.measurement.internal.kingdom.certificate as internalCertificate import org.wfanet.measurement.internal.kingdom.createMeasurementConsumerCreationTokenRequest import org.wfanet.measurement.internal.kingdom.dataProvider as internalDataProvider import org.wfanet.measurement.internal.kingdom.dataProviderDetails +import org.wfanet.measurement.internal.kingdom.eventTemplate +import org.wfanet.measurement.internal.kingdom.modelProvider as internalModelProvider +import org.wfanet.measurement.internal.kingdom.population as internalPopulation import org.wfanet.measurement.kingdom.service.api.v2alpha.fillCertificateFromDer import org.wfanet.measurement.kingdom.service.api.v2alpha.parseCertificateDer import org.wfanet.measurement.loadtest.common.ConsoleOutput @@ -97,6 +109,9 @@ class ResourceSetup( private val requiredDuchies: List, private val bazelConfigName: String = DEFAULT_BAZEL_CONFIG_NAME, private val outputDir: File? = null, + private val internalModelProvidersClient: ModelProvidersGrpcKt.ModelProvidersCoroutineStub? = + null, + private val internalPopulationsClient: PopulationsGrpcKt.PopulationsCoroutineStub? = null, ) { data class MeasurementConsumerAndKey( val measurementConsumer: MeasurementConsumer, @@ -213,6 +228,47 @@ class ResourceSetup( TextFormat.printer().print(akidMap, writer) } + val measurementConsumerConfig = measurementConsumerConfigs { + for (resource in resources) { + when (resource.resourceCase) { + Resources.Resource.ResourceCase.MEASUREMENT_CONSUMER -> + configs.put( + resource.name, + measurementConsumerConfig { + apiKey = resource.measurementConsumer.apiKey + signingCertificateName = resource.measurementConsumer.certificate + signingPrivateKeyPath = MEASUREMENT_CONSUMER_SIGNING_PRIVATE_KEY_PATH + }, + ) + else -> continue + } + } + } + output.resolve(MEASUREMENT_CONSUMER_CONFIG_FILE).writer().use { writer -> + TextFormat.printer().print(measurementConsumerConfig, writer) + } + + val encryptionKeyPairConfig = encryptionKeyPairConfig { + for (resource in resources) { + when (resource.resourceCase) { + Resources.Resource.ResourceCase.MEASUREMENT_CONSUMER -> + principalKeyPairs += + EncryptionKeyPairConfigKt.principalKeyPairs { + principal = resource.name + keyPairs += + EncryptionKeyPairConfigKt.keyPair { + publicKeyFile = MEASUREMENT_CONSUMER_ENCRYPTION_PUBLIC_KEY_PATH + privateKeyFile = MEASUREMENT_CONSUMER_ENCRYPTION_PRIVATE_KEY_PATH + } + } + else -> continue + } + } + } + output.resolve(ENCRYPTION_KEY_PAIR_CONFIG_FILE).writer().use { writer -> + TextFormat.printer().print(encryptionKeyPairConfig, writer) + } + val configName = bazelConfigName output.resolve(BAZEL_RC_FILE).writer().use { writer -> for (resource in resources) { @@ -270,6 +326,28 @@ class ResourceSetup( } } + /** Create an internal modelProvider. */ + suspend fun createInternalModelProvider(): InternalModelProvider { + require(internalModelProvidersClient != null) + return try { + internalModelProvidersClient.createModelProvider(internalModelProvider {}) + } catch (e: StatusException) { + throw Exception("Error creating ModelProvider", e) + } + } + + suspend fun createInternalPopulation(dataProvider: InternalDataProvider): InternalPopulation { + require(internalPopulationsClient != null) + return internalPopulationsClient.createPopulation( + internalPopulation { + externalDataProviderId = dataProvider.externalDataProviderId + description = "DESCRIPTION" + populationBlob = PopulationKt.populationBlob { modelBlobUri = "BLOB_URI" } + eventTemplate = eventTemplate { fullyQualifiedType = "TYPE" } + } + ) + } + suspend fun createAccountWithRetries(): InternalAccount { // The initial account is created via the Kingdom Internal API by the Kingdom operator. // This is our first attempt to contact the Kingdom. If it fails, we will retry it. @@ -413,6 +491,11 @@ class ResourceSetup( const val RESOURCES_OUTPUT_FILE = "resources.textproto" const val AKID_PRINCIPAL_MAP_FILE = "authority_key_identifier_to_principal_map.textproto" const val BAZEL_RC_FILE = "resource-setup.bazelrc" + const val MEASUREMENT_CONSUMER_CONFIG_FILE = "measurement_consumer_config.textproto" + const val ENCRYPTION_KEY_PAIR_CONFIG_FILE = "encryption_key_pair_config.textproto" + const val MEASUREMENT_CONSUMER_SIGNING_PRIVATE_KEY_PATH = "mc_cs_private.der" + const val MEASUREMENT_CONSUMER_ENCRYPTION_PUBLIC_KEY_PATH = "mc_enc_public.tink" + const val MEASUREMENT_CONSUMER_ENCRYPTION_PRIVATE_KEY_PATH = "mc_enc_private.tink" private val logger: Logger = Logger.getLogger(this::class.java.name) } diff --git a/src/main/kotlin/org/wfanet/measurement/populationdataprovider/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/populationdataprovider/BUILD.bazel index ed2a4108553..08ec27c1275 100644 --- a/src/main/kotlin/org/wfanet/measurement/populationdataprovider/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/populationdataprovider/BUILD.bazel @@ -2,6 +2,7 @@ load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") package( default_visibility = [ + "//src/main/kotlin/org/wfanet/measurement/integration/common:__subpackages__", "//src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider:__subpackages__", "//src/test/kotlin/org/wfanet/measurement/loadtest/dataprovider:__subpackages__", ], @@ -20,3 +21,26 @@ kt_jvm_library( "//src/main/proto/wfa/measurement/api/v2alpha:population_spec_kt_jvm_proto", ], ) + +kt_jvm_library( + name = "population_requisition_fulfiller_daemon", + srcs = ["PopulationRequisitionFulfillerDaemon.kt"], + deps = [ + ":population_requisition_fulfiller", + "//src/main/proto/wfa/measurement/api/v2alpha:certificates_service_kt_jvm_grpc_proto", + "//src/main/proto/wfa/measurement/api/v2alpha:data_providers_service_kt_jvm_grpc_proto", + "//src/main/proto/wfa/measurement/api/v2alpha:event_group_metadata_descriptors_service_kt_jvm_grpc_proto", + "//src/main/proto/wfa/measurement/api/v2alpha:event_groups_service_kt_jvm_grpc_proto", + "//src/main/proto/wfa/measurement/api/v2alpha:measurement_consumers_service_kt_jvm_grpc_proto", + "//src/main/proto/wfa/measurement/api/v2alpha:requisition_fulfillment_service_kt_jvm_grpc_proto", + "//src/main/proto/wfa/measurement/api/v2alpha:requisitions_service_kt_jvm_grpc_proto", + "@wfa_common_jvm//imports/java/io/grpc:api", + "@wfa_common_jvm//imports/java/picocli", + "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto:signing_certs", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto/tink", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/throttler", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/storage:client", + ], +) diff --git a/src/main/kotlin/org/wfanet/measurement/populationdataprovider/PopulationRequisitionFulfiller.kt b/src/main/kotlin/org/wfanet/measurement/populationdataprovider/PopulationRequisitionFulfiller.kt index dedd30a5128..82491b0c522 100644 --- a/src/main/kotlin/org/wfanet/measurement/populationdataprovider/PopulationRequisitionFulfiller.kt +++ b/src/main/kotlin/org/wfanet/measurement/populationdataprovider/PopulationRequisitionFulfiller.kt @@ -16,6 +16,7 @@ package org.wfanet.measurement.populationdataprovider import com.google.protobuf.Any import com.google.protobuf.ByteString +import com.google.protobuf.Descriptors import com.google.protobuf.Descriptors.Descriptor import com.google.protobuf.Descriptors.FieldDescriptor import com.google.protobuf.DynamicMessage @@ -30,7 +31,6 @@ import org.wfanet.measurement.api.v2alpha.CertificatesGrpcKt.CertificatesCorouti import org.wfanet.measurement.api.v2alpha.DeterministicCount import org.wfanet.measurement.api.v2alpha.EventAnnotationsProto import org.wfanet.measurement.api.v2alpha.Measurement -import org.wfanet.measurement.api.v2alpha.MeasurementKey import org.wfanet.measurement.api.v2alpha.MeasurementKt import org.wfanet.measurement.api.v2alpha.MeasurementSpec import org.wfanet.measurement.api.v2alpha.ModelRelease @@ -74,7 +74,6 @@ class PopulationRequisitionFulfiller( requisitionsStub: RequisitionsCoroutineStub, throttler: Throttler, trustedCertificates: Map, - measurementConsumerName: String, private val modelRolloutsStub: ModelRolloutsCoroutineStub, private val modelReleasesStub: ModelReleasesCoroutineStub, private val populationInfoMap: Map, @@ -86,7 +85,6 @@ class PopulationRequisitionFulfiller( requisitionsStub, throttler, trustedCertificates, - measurementConsumerName, ) { /** A sequence of operations done in the simulator. */ @@ -96,11 +94,7 @@ class PopulationRequisitionFulfiller( /** Executes the requisition fulfillment workflow. */ override suspend fun executeRequisitionFulfillingWorkflow() { logger.info("Executing requisitionFulfillingWorkflow...") - val requisitions = - getRequisitions().filter { - checkNotNull(MeasurementKey.fromName(it.measurement)).measurementConsumerId == - measurementConsumerKey.measurementConsumerId - } + val requisitions = getRequisitions() if (requisitions.isEmpty()) { logger.fine("No unfulfilled requisition. Polling again later...") @@ -133,17 +127,7 @@ class PopulationRequisitionFulfiller( logger.log(Level.INFO, "MeasurementSpec:\n$measurementSpec") logger.log(Level.INFO, "RequisitionSpec:\n$requisitionSpec") - val modelRelease: ModelRelease = getModelRelease(measurementSpec) - - val populationId: PopulationKey = - requireNotNull(PopulationKey.fromName(modelRelease.population)) { - throw InvalidSpecException( - "Measurement spec model line does not contain a valid Population for the model release of its latest model rollout." - ) - } - - val populationInfo: PopulationInfo = populationInfoMap.getValue(populationId) - populationInfo.eventMessageDescriptor.fields + val populationInfo: PopulationInfo = getPopulationInfo(measurementSpec) PopulationSpecValidator.validateVidRangesList(populationInfo.populationSpec).getOrThrow() val requisitionFilterExpression = requisitionSpec.population.filter.expression @@ -166,6 +150,19 @@ class PopulationRequisitionFulfiller( } } + private suspend fun getPopulationInfo(measurementSpec: MeasurementSpec): PopulationInfo { + val modelRelease: ModelRelease = getModelRelease(measurementSpec) + + val populationId: PopulationKey = + requireNotNull(PopulationKey.fromName(modelRelease.population)) { + throw InvalidSpecException( + "Measurement spec model line does not contain a valid Population for the model release of its latest model rollout." + ) + } + + return populationInfoMap.getValue(populationId) + } + /** * Returns the [ModelRelease] associated with the latest `ModelRollout` that is connected to the * `ModelLine` provided in the MeasurementSpec` @@ -235,26 +232,12 @@ class PopulationRequisitionFulfiller( operativeFields, ) - // Filters populationBucketsList through a CEL program and sums the result. - val populationSum = - populationInfo.populationSpec.subpopulationsList.sumOf { - val attributesList = it.attributesList - val vidRanges = it.vidRangesList - val shouldSumPopulation = - isValidAttributesList(attributesList, populationInfo, program, typeRegistry) - if (shouldSumPopulation) { - vidRanges.sumOf { jt -> jt.size() } - } else { - 0L - } - } - // Create measurement result with sum of valid populations. val measurementResult: Measurement.Result = MeasurementKt.result { population = MeasurementKt.ResultKt.population { - value = populationSum + value = computePopulation(populationInfo, program, typeRegistry) deterministicCount = DeterministicCount.getDefaultInstance() } } @@ -286,62 +269,94 @@ class PopulationRequisitionFulfiller( .toSet() } - /** - * Returns a [Boolean] representing whether the attributes in the list are 1) the correct type and - * 2) pass a check against the filter expression after being run through a CEL program. - */ - private fun isValidAttributesList( - attributeList: List, - populationInfo: PopulationInfo, - program: Program, - typeRegistry: TypeRegistry, - ): Boolean { - val eventMessageDescriptor: Descriptor = populationInfo.eventMessageDescriptor - - // Event message that will be passed to CEL program - val eventMessage: DynamicMessage.Builder = DynamicMessage.newBuilder(eventMessageDescriptor) - - // Populate event message that will be used in the program if attribute is valid - attributeList.forEach { attribute -> - val attributeDescriptor: Descriptor = typeRegistry.getDescriptorForTypeUrl(attribute.typeUrl) - val requiredAttributes: List = - attributeDescriptor.fields.filter { - it.options.getExtension(EventAnnotationsProto.templateField).populationAttribute + companion object { + /** + * Computes population using the "deterministic count" methodology by filtering the + * populationBucketsList through a CEL program and summing the result. + */ + fun computePopulation( + populationInfo: PopulationInfo, + program: Program, + typeRegistry: TypeRegistry, + ): Long { + return populationInfo.populationSpec.subpopulationsList.sumOf { + val attributesList = it.attributesList + val vidRanges = it.vidRangesList + val shouldSumPopulation = + isValidAttributesList( + attributesList, + populationInfo.eventMessageDescriptor, + program, + typeRegistry, + ) + if (shouldSumPopulation) { + vidRanges.sumOf { jt -> jt.size() } + } else { + 0L } + } + } - // Create the attribute message of the type specified in attribute descriptor using the type - // registry. - val descriptor: Descriptor = typeRegistry.getDescriptorForTypeUrl(attribute.typeUrl) - val attributeMessage: DynamicMessage = DynamicMessage.parseFrom(descriptor, attribute.value) + /** + * Returns a [Boolean] representing whether the attributes in the list are 1) the correct type + * and + * 2) pass a check against the filter expression after being run through a CEL program. + */ + private fun isValidAttributesList( + attributeList: List, + eventMessageDescriptor: Descriptor, + program: Program, + typeRegistry: TypeRegistry, + ): Boolean { + // Event message that will be passed to CEL program + val eventMessage: DynamicMessage.Builder = DynamicMessage.newBuilder(eventMessageDescriptor) + + // Populate event message that will be used in the program if attribute is valid + attributeList.forEach { attribute -> + val attributeDescriptor: Descriptor = + typeRegistry.getDescriptorForTypeUrl(attribute.typeUrl) + val requiredAttributes: List = + attributeDescriptor.fields.filter { + it.options.getExtension(EventAnnotationsProto.templateField).populationAttribute + } - // If the attribute type is not a field in the event message, it is not valid. - val isAttributeFieldInEvent = - eventMessageDescriptor.fields.any { - it.messageType.name === attributeMessage.descriptorForType.name - } - require(isAttributeFieldInEvent) { - throw InvalidSpecException( - "Subpopulation attribute is not a field in the event descriptor." - ) - } + // Create the attribute message of the type specified in attribute descriptor using the type + // registry. + val descriptor: Descriptors.Descriptor = + typeRegistry.getDescriptorForTypeUrl(attribute.typeUrl) + val attributeMessage: DynamicMessage = DynamicMessage.parseFrom(descriptor, attribute.value) - // If the population_attribute option in the attribute message is set to true, we do not allow - // the value to be unspecified. - val isValidAttribute = attributeMessage.allFields.keys.containsAll(requiredAttributes) - require(isValidAttribute) { - throw InvalidSpecException("Subpopulation population attribute cannot be unspecified.") - } + // If the attribute type is not a field in the event message, it is not valid. + val isAttributeFieldInEvent = + eventMessageDescriptor.fields.any { + it.messageType.name === attributeMessage.descriptorForType.name + } + require(isAttributeFieldInEvent) { + throw RequisitionFulfiller.InvalidSpecException( + "Subpopulation attribute is not a field in the event descriptor." + ) + } - // Find corresponding field descriptor for this attribute. - val fieldDescriptor: FieldDescriptor = - eventMessageDescriptor.fields.first { eventField -> - eventField.messageType.name === attributeDescriptor.name + // If the population_attribute option in the attribute message is set to true, we do not + // allow + // the value to be unspecified. + val isValidAttribute = attributeMessage.allFields.keys.containsAll(requiredAttributes) + require(isValidAttribute) { + throw RequisitionFulfiller.InvalidSpecException( + "Subpopulation population attribute cannot be unspecified." + ) } - // Set field in event message with typed attribute message. - eventMessage.setField(fieldDescriptor, attributeMessage) - } + // Find corresponding field descriptor for this attribute. + val fieldDescriptor: Descriptors.FieldDescriptor = + eventMessageDescriptor.fields.first { eventField -> + eventField.messageType.name === attributeDescriptor.name + } - return EventFilters.matches(eventMessage.build(), program) + // Set field in event message with typed attribute message. + eventMessage.setField(fieldDescriptor, attributeMessage) + } + return EventFilters.matches(eventMessage.build(), program) + } } } diff --git a/src/main/kotlin/org/wfanet/measurement/populationdataprovider/PopulationRequisitionFulfillerDaemon.kt b/src/main/kotlin/org/wfanet/measurement/populationdataprovider/PopulationRequisitionFulfillerDaemon.kt new file mode 100644 index 00000000000..51212aecfce --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/populationdataprovider/PopulationRequisitionFulfillerDaemon.kt @@ -0,0 +1,241 @@ +// Copyright 2024 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.wfanet.measurement.populationdataprovider + +import com.google.protobuf.DescriptorProtos +import com.google.protobuf.TypeRegistry +import java.io.File +import java.security.cert.X509Certificate +import java.time.Clock +import java.time.Duration +import kotlinx.coroutines.runBlocking +import org.wfanet.measurement.api.v2alpha.CertificatesGrpcKt.CertificatesCoroutineStub +import org.wfanet.measurement.api.v2alpha.DataProviderCertificateKey +import org.wfanet.measurement.api.v2alpha.ModelReleasesGrpcKt.ModelReleasesCoroutineStub +import org.wfanet.measurement.api.v2alpha.ModelRolloutsGrpcKt.ModelRolloutsCoroutineStub +import org.wfanet.measurement.api.v2alpha.PopulationKey +import org.wfanet.measurement.api.v2alpha.PopulationSpec +import org.wfanet.measurement.api.v2alpha.RequisitionsGrpcKt.RequisitionsCoroutineStub +import org.wfanet.measurement.common.commandLineMain +import org.wfanet.measurement.common.crypto.SigningCerts +import org.wfanet.measurement.common.crypto.SigningKeyHandle +import org.wfanet.measurement.common.crypto.readCertificate +import org.wfanet.measurement.common.crypto.readPrivateKey +import org.wfanet.measurement.common.crypto.tink.loadPrivateKey +import org.wfanet.measurement.common.grpc.TlsFlags +import org.wfanet.measurement.common.grpc.buildMutualTlsChannel +import org.wfanet.measurement.common.grpc.grpcRequireNotNull +import org.wfanet.measurement.common.parseTextProto +import org.wfanet.measurement.common.readByteString +import org.wfanet.measurement.common.throttler.MinimumIntervalThrottler +import org.wfanet.measurement.dataprovider.DataProviderData +import picocli.CommandLine +import picocli.CommandLine.Command +import picocli.CommandLine.Option + +@Command( + name = "PopulationRequisitionFulfillerDaemon", + mixinStandardHelpOptions = true, + showDefaultValues = true, + description = ["Fulfills a population requisition."], +) +class PopulationRequisitionFulfillerDaemon : Runnable { + @Option( + names = ["--kingdom-system-api-target"], + description = ["gRPC target (authority) of the Kingdom system API server"], + required = true, + ) + private lateinit var target: String + + @Option( + names = ["--kingdom-system-api-cert-host"], + description = + [ + "Expected hostname (DNS-ID) in the Kingdom system API server's TLS certificate.", + "This overrides derivation of the TLS DNS-ID from --kingdom-system-api-target.", + ], + required = false, + ) + private var certHost: String? = null + + /** The PdpSimulator pod's own tls certificates. */ + @CommandLine.Mixin private lateinit var tlsFlags: TlsFlags + + @Option( + names = ["--data-provider-resource-name"], + description = ["The public API resource name of this data provider."], + required = true, + ) + private lateinit var dataProviderResourceName: String + + @Option( + names = ["--data-provider-certificate-resource-name"], + description = ["The public API resource name for data provider consent signaling."], + required = true, + ) + private lateinit var dataProviderCertificateResourceName: String + + @Option( + names = ["--data-provider-display-name"], + description = ["The display name of this data provider."], + required = true, + ) + private lateinit var dataProviderDisplayName: String + + @Option( + names = ["--data-provider-encryption-private-keyset"], + description = ["The PDP's encryption private Tink Keyset."], + required = true, + ) + private lateinit var pdpEncryptionPrivateKeyset: File + + @Option( + names = ["--data-provider-consent-signaling-private-key-der-file"], + description = ["The PDP's consent signaling private key (DER format) file."], + required = true, + ) + private lateinit var pdpCsPrivateKeyDerFile: File + + @Option( + names = ["--data-provider-consent-signaling-certificate-der-file"], + description = ["The PDP's consent signaling private key (DER format) file."], + required = true, + ) + private lateinit var pdpCsCertificateDerFile: File + + @Option( + names = ["--throttler-minimum-interval"], + description = ["Minimum throttle interval"], + defaultValue = "2s", + ) + private lateinit var throttlerMinimumInterval: Duration + + @Option( + names = ["--event-message-descriptor-set"], + description = + [ + "Serialized FileDescriptorSet for the event message and its dependencies. This can be specified multiple times" + ], + required = false, + ) + private lateinit var eventMessageDescriptorSetFiles: List + + @CommandLine.ArgGroup(exclusive = false, multiplicity = "1..*") + private lateinit var populationKeyAndInfo: List + + private class PopulationKeyAndInfo { + @Option(names = ["--population-resource-name"], required = true) + lateinit var populationResourceName: String + private set + + @CommandLine.ArgGroup(exclusive = false, multiplicity = "1") + lateinit var populationInfo: PopulationInfo + private set + } + + private class PopulationInfo { + @Option(names = ["--population-spec"], required = true) + lateinit var populationSpecFile: File + private set + + @Option(names = ["--event-message-type-url"], required = true) + lateinit var eventMessageTypeUrl: String + private set + } + + override fun run() { + val certificate: X509Certificate = + pdpCsCertificateDerFile.inputStream().use { input -> readCertificate(input) } + val signingKeyHandle = + SigningKeyHandle( + certificate, + readPrivateKey(pdpCsPrivateKeyDerFile.readByteString(), certificate.publicKey.algorithm), + ) + val certificateKey = DataProviderCertificateKey.fromName(dataProviderCertificateResourceName)!! + val pdpData = + DataProviderData( + dataProviderResourceName, + dataProviderDisplayName, + loadPrivateKey(pdpEncryptionPrivateKeyset), + signingKeyHandle, + certificateKey, + ) + + val clientCerts = + SigningCerts.fromPemFiles( + certificateFile = tlsFlags.certFile, + privateKeyFile = tlsFlags.privateKeyFile, + trustedCertCollectionFile = tlsFlags.certCollectionFile, + ) + + val channelTarget: String = target + val channelCertHost: String? = certHost + + val publicApiChannel = buildMutualTlsChannel(channelTarget, clientCerts, channelCertHost) + + val certificatesStub = CertificatesCoroutineStub(publicApiChannel) + val requisitionsStub = RequisitionsCoroutineStub(publicApiChannel) + val modelRolloutsStub = ModelRolloutsCoroutineStub(publicApiChannel) + val modelReleasesStub = ModelReleasesCoroutineStub(publicApiChannel) + + val throttler = MinimumIntervalThrottler(Clock.systemUTC(), throttlerMinimumInterval) + + val typeRegistry = buildTypeRegistry() + + val populationInfoMap = buildPopulationInfoMap(typeRegistry) + + val populationRequisitionFulfiller = + PopulationRequisitionFulfiller( + pdpData, + certificatesStub, + requisitionsStub, + throttler, + clientCerts.trustedCertificates, + modelRolloutsStub, + modelReleasesStub, + populationInfoMap, + typeRegistry, + ) + + runBlocking { populationRequisitionFulfiller.run() } + } + + private fun buildTypeRegistry(): TypeRegistry { + val builder = TypeRegistry.newBuilder() + if (::eventMessageDescriptorSetFiles.isInitialized) { + val fileDescriptorSets: List = + eventMessageDescriptorSetFiles.map { + parseTextProto(it, DescriptorProtos.FileDescriptorSet.getDefaultInstance()) + } + val descriptors = fileDescriptorSets.map { it.descriptorForType } + builder.add(descriptors) + } + return builder.build() + } + + private fun buildPopulationInfoMap( + typeRegistry: TypeRegistry + ): Map { + return populationKeyAndInfo.associate { + grpcRequireNotNull(PopulationKey.fromName(it.populationResourceName)) to + PopulationInfo( + parseTextProto(it.populationInfo.populationSpecFile, PopulationSpec.getDefaultInstance()), + typeRegistry.getDescriptorForTypeUrl(it.populationInfo.eventMessageTypeUrl), + ) + } + } +} + +fun main(args: Array) = commandLineMain(PopulationRequisitionFulfillerDaemon(), args) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/BUILD.bazel deleted file mode 100644 index c576b711d5d..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/BUILD.bazel +++ /dev/null @@ -1,34 +0,0 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") - -package( - default_visibility = [ - "//src/main/kotlin/org/wfanet/measurement/reporting:__subpackages__", - "//src/test/kotlin/org/wfanet/measurement/reporting:__subpackages__", - ], -) - -kt_jvm_library( - name = "encryption_key_pair_map", - srcs = ["EncryptionKeyPairMap.kt"], - deps = [ - "//src/main/proto/wfa/measurement/config/reporting:encryption_key_pair_config_kt_jvm_proto", - "@wfa_common_jvm//imports/java/picocli", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto/tink", - ], -) - -kt_jvm_library( - name = "flags", - srcs = ["InternalApiFlags.kt"], - deps = [ - "@wfa_common_jvm//imports/java/picocli", - ], -) - -kt_jvm_library( - name = "kingdom_flags", - srcs = ["KingdomApiFlags.kt"], - deps = [ - "@wfa_common_jvm//imports/java/picocli", - ], -) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/EncryptionKeyPairMap.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/EncryptionKeyPairMap.kt deleted file mode 100644 index dcd3135ee9c..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/EncryptionKeyPairMap.kt +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.common - -import com.google.protobuf.ByteString -import java.io.File -import org.wfanet.measurement.common.crypto.PrivateKeyHandle -import org.wfanet.measurement.common.crypto.tink.loadPrivateKey -import org.wfanet.measurement.common.parseTextProto -import org.wfanet.measurement.common.readByteString -import org.wfanet.measurement.config.reporting.encryptionKeyPairConfig -import picocli.CommandLine.Option - -class EncryptionKeyPairMap { - @Option( - names = ["--key-pair-dir"], - description = ["Path to the directory of MeasurementConsumer's encryption keys"], - ) - private lateinit var keyFilesDirectory: File - - @Option( - names = ["--key-pair-config-file"], - description = ["Path to the textproto file of EncryptionKeyPairConfig that contains key pairs"], - required = true, - ) - private lateinit var keyPairConfigFile: File - - private fun loadKeyPairs(): Map>> { - val keyPairConfig = - parseTextProto(keyPairConfigFile, encryptionKeyPairConfig {}).principalKeyPairsList - return keyPairConfig.associate { config -> - val keyPairs = - config.keyPairsList.map { keyPair -> - val publicKey = keyFilesDirectory.resolve(keyPair.publicKeyFile).readByteString() - val privateKey = loadPrivateKey(keyFilesDirectory.resolve(keyPair.privateKeyFile)) - publicKey to privateKey - } - checkNotNull(config.principal) to keyPairs - } - } - - val keyPairs: Map>> by lazy { loadKeyPairs() } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/InternalApiFlags.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/InternalApiFlags.kt deleted file mode 100644 index 122b61bc0a2..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/InternalApiFlags.kt +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.common - -import picocli.CommandLine - -class InternalApiFlags { - @set:CommandLine.Option( - names = ["--internal-api-target"], - description = ["gRPC target (authority) of the Reporting internal API server"], - required = true, - ) - lateinit var target: String - - @CommandLine.Option( - names = ["--internal-api-cert-host"], - description = - [ - "Expected hostname (DNS-ID) in the Reporting internal API server's TLS certificate.", - "This overrides derivation of the TLS DNS-ID from --internal-api-target.", - ], - required = false, - ) - var certHost: String? = null - private set -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/KingdomApiFlags.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/KingdomApiFlags.kt deleted file mode 100644 index 608a72797c7..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/KingdomApiFlags.kt +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.common - -import picocli.CommandLine - -class KingdomApiFlags { - @set:CommandLine.Option( - names = ["--kingdom-api-target"], - description = ["gRPC target (authority) of the Kingdom public API server"], - required = true, - ) - lateinit var target: String - - @CommandLine.Option( - names = ["--kingdom-api-cert-host"], - description = - [ - "Expected hostname (DNS-ID) in the Kingdom public API server's TLS certificate.", - "This overrides derivation of the TLS DNS-ID from --kingdom-api-target.", - ], - required = false, - ) - var certHost: String? = null - private set -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/BUILD.bazel deleted file mode 100644 index a98176a11ef..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/BUILD.bazel +++ /dev/null @@ -1,78 +0,0 @@ -load("@rules_java//java:defs.bzl", "java_binary") -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") -load("//src/main/docker:macros.bzl", "java_image") - -kt_jvm_library( - name = "reporting_data_server", - srcs = ["ReportingDataServer.kt"], - visibility = [ - "//src/main/kotlin/org/wfanet/measurement/integration/common/reporting:__subpackages__", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy:__subpackages__", - ], - deps = [ - "//src/main/proto/wfa/measurement/internal/reporting:measurements_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/internal/reporting:reporting_sets_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/internal/reporting:reports_service_kt_jvm_grpc_proto", - "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", - ], -) - -kt_jvm_library( - name = "reporting_api_server_flags", - srcs = ["ReportingApiServerFlags.kt"], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/common:flags", - "@wfa_common_jvm//imports/java/picocli", - ], -) - -kt_jvm_library( - name = "v1alpha_public_api_server", - srcs = ["V1AlphaPublicApiServer.kt"], - runtime_deps = ["@wfa_common_jvm//imports/java/io/grpc/netty"], - deps = [ - ":reporting_api_server_flags", - "//src/main/kotlin/org/wfanet/measurement/common/api:memoizing_principal_lookup", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/common:encryption_key_pair_map", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/common:kingdom_flags", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/config:measurement_spec_config_validator", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api:cel_env_provider", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api:encryption_key_pair_store", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:akid_principal_lookup", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:event_groups_service", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:reporting_sets_service", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:reports_service", - "//src/main/proto/wfa/measurement/api/v2alpha:certificates_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:data_providers_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:event_group_metadata_descriptors_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:event_groups_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:measurement_consumers_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:measurements_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/config/reporting:measurement_consumer_config_kt_jvm_proto", - "//src/main/proto/wfa/measurement/config/reporting:measurement_spec_config_kt_jvm_proto", - "//src/main/proto/wfa/measurement/internal/reporting:measurements_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/internal/reporting:reporting_sets_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/internal/reporting:reports_service_kt_jvm_grpc_proto", - "@wfa_common_jvm//imports/java/io/grpc:api", - "@wfa_common_jvm//imports/java/picocli", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", - ], -) - -java_binary( - name = "V1AlphaPublicApiServer", - main_class = "org.wfanet.measurement.reporting.deploy.common.server.V1AlphaPublicApiServerKt", - runtime_deps = [ - ":v1alpha_public_api_server", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/logging", - ], -) - -java_image( - name = "v1alpha_public_api_server_image", - binary = ":V1AlphaPublicApiServer", - main_class = "org.wfanet.measurement.reporting.deploy.common.server.V1AlphaPublicApiServerKt", - visibility = ["//src:docker_image_deployment"], -) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/ReportingApiServerFlags.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/ReportingApiServerFlags.kt deleted file mode 100644 index fc3960af21a..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/ReportingApiServerFlags.kt +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.common.server - -import java.time.Duration -import kotlin.properties.Delegates -import org.wfanet.measurement.reporting.deploy.common.InternalApiFlags -import picocli.CommandLine - -class ReportingApiServerFlags { - @CommandLine.Mixin - lateinit var internalApiFlags: InternalApiFlags - private set - - @set:CommandLine.Option( - names = ["--debug-verbose-grpc-client-logging"], - description = ["Enables full gRPC request and response logging for outgoing gRPCs"], - defaultValue = "false", - ) - var debugVerboseGrpcClientLogging by Delegates.notNull() - private set - - @CommandLine.Option( - names = ["--event-group-metadata-descriptor-cache-duration"], - description = - [ - "How long the event group metadata descriptors are cached for before refreshing in format 1d1h1m1s1ms1ns" - ], - defaultValue = "1h", - required = false, - ) - lateinit var eventGroupMetadataDescriptorCacheDuration: Duration - private set -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/ReportingDataServer.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/ReportingDataServer.kt deleted file mode 100644 index 9f3b903ff41..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/ReportingDataServer.kt +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.common.server - -import io.grpc.BindableService -import kotlin.reflect.full.declaredMemberProperties -import kotlinx.coroutines.runInterruptible -import org.wfanet.measurement.common.grpc.CommonServer -import org.wfanet.measurement.internal.reporting.MeasurementsGrpcKt -import org.wfanet.measurement.internal.reporting.ReportingSetsGrpcKt -import org.wfanet.measurement.internal.reporting.ReportsGrpcKt -import picocli.CommandLine - -abstract class ReportingDataServer : Runnable { - data class Services( - val measurementsService: MeasurementsGrpcKt.MeasurementsCoroutineImplBase, - val reportingSetsService: ReportingSetsGrpcKt.ReportingSetsCoroutineImplBase, - val reportsService: ReportsGrpcKt.ReportsCoroutineImplBase, - ) - - @CommandLine.Mixin private lateinit var serverFlags: CommonServer.Flags - - protected suspend fun run(services: Services) { - val server = CommonServer.fromFlags(serverFlags, this::class.simpleName!!, services.toList()) - - runInterruptible { server.start().blockUntilShutdown() } - } - - companion object { - fun Services.toList(): List { - return Services::class.declaredMemberProperties.map { it.get(this) as BindableService } - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/V1AlphaPublicApiServer.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/V1AlphaPublicApiServer.kt deleted file mode 100644 index 2b1b64bdb11..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/V1AlphaPublicApiServer.kt +++ /dev/null @@ -1,220 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.common.server - -import com.google.protobuf.ByteString -import com.google.protobuf.DescriptorProtos -import com.google.protobuf.Descriptors -import io.grpc.Channel -import io.grpc.ServerServiceDefinition -import java.io.File -import java.security.SecureRandom -import kotlin.random.asKotlinRandom -import kotlinx.coroutines.Dispatchers -import org.wfanet.measurement.api.v2alpha.CertificatesGrpcKt.CertificatesCoroutineStub as KingdomCertificatesCoroutineStub -import org.wfanet.measurement.api.v2alpha.DataProvidersGrpcKt.DataProvidersCoroutineStub as KingdomDataProvidersCoroutineStub -import org.wfanet.measurement.api.v2alpha.EventGroupMetadataDescriptorsGrpcKt.EventGroupMetadataDescriptorsCoroutineStub as KingdomEventGroupMetadataDescriptorsCoroutineStub -import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub as KingdomEventGroupsCoroutineStub -import org.wfanet.measurement.api.v2alpha.MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineStub as KingdomMeasurementConsumersCoroutineStub -import org.wfanet.measurement.api.v2alpha.MeasurementsGrpcKt.MeasurementsCoroutineStub as KingdomMeasurementsCoroutineStub -import org.wfanet.measurement.api.withAuthenticationKey -import org.wfanet.measurement.common.ProtoReflection -import org.wfanet.measurement.common.api.PrincipalLookup -import org.wfanet.measurement.common.api.memoizing -import org.wfanet.measurement.common.commandLineMain -import org.wfanet.measurement.common.crypto.SigningCerts -import org.wfanet.measurement.common.grpc.CommonServer -import org.wfanet.measurement.common.grpc.buildMutualTlsChannel -import org.wfanet.measurement.common.grpc.withVerboseLogging -import org.wfanet.measurement.common.parseTextProto -import org.wfanet.measurement.config.reporting.MeasurementConsumerConfigs -import org.wfanet.measurement.config.reporting.MeasurementSpecConfig -import org.wfanet.measurement.internal.reporting.MeasurementsGrpcKt.MeasurementsCoroutineStub as InternalMeasurementsCoroutineStub -import org.wfanet.measurement.internal.reporting.ReportingSetsGrpcKt.ReportingSetsCoroutineStub as InternalReportingSetsCoroutineStub -import org.wfanet.measurement.internal.reporting.ReportsGrpcKt.ReportsCoroutineStub as InternalReportsCoroutineStub -import org.wfanet.measurement.reporting.deploy.common.EncryptionKeyPairMap -import org.wfanet.measurement.reporting.deploy.common.KingdomApiFlags -import org.wfanet.measurement.reporting.deploy.config.validate -import org.wfanet.measurement.reporting.service.api.CelEnvCacheProvider -import org.wfanet.measurement.reporting.service.api.InMemoryEncryptionKeyPairStore -import org.wfanet.measurement.reporting.service.api.v1alpha.AkidPrincipalLookup -import org.wfanet.measurement.reporting.service.api.v1alpha.EventGroupsService -import org.wfanet.measurement.reporting.service.api.v1alpha.ReportingPrincipal -import org.wfanet.measurement.reporting.service.api.v1alpha.ReportingSetsService -import org.wfanet.measurement.reporting.service.api.v1alpha.ReportsService -import org.wfanet.measurement.reporting.service.api.v1alpha.withPrincipalsFromX509AuthorityKeyIdentifiers -import org.wfanet.measurement.reporting.v1alpha.EventGroup -import picocli.CommandLine - -private const val SERVER_NAME = "V1AlphaPublicApiServer" - -@CommandLine.Command( - name = SERVER_NAME, - description = ["Server daemon for Reporting v1alpha public API services."], - mixinStandardHelpOptions = true, - showDefaultValues = true, -) -private fun run( - @CommandLine.Mixin reportingApiServerFlags: ReportingApiServerFlags, - @CommandLine.Mixin kingdomApiFlags: KingdomApiFlags, - @CommandLine.Mixin commonServerFlags: CommonServer.Flags, - @CommandLine.Mixin v1AlphaFlags: V1AlphaFlags, - @CommandLine.Mixin encryptionKeyPairMap: EncryptionKeyPairMap, -) { - val clientCerts = - SigningCerts.fromPemFiles( - certificateFile = commonServerFlags.tlsFlags.certFile, - privateKeyFile = commonServerFlags.tlsFlags.privateKeyFile, - trustedCertCollectionFile = commonServerFlags.tlsFlags.certCollectionFile, - ) - val channel: Channel = - buildMutualTlsChannel( - reportingApiServerFlags.internalApiFlags.target, - clientCerts, - reportingApiServerFlags.internalApiFlags.certHost, - ) - .withVerboseLogging(reportingApiServerFlags.debugVerboseGrpcClientLogging) - - val kingdomChannel: Channel = - buildMutualTlsChannel( - target = kingdomApiFlags.target, - clientCerts = clientCerts, - hostName = kingdomApiFlags.certHost, - ) - .withVerboseLogging(reportingApiServerFlags.debugVerboseGrpcClientLogging) - - val principalLookup: PrincipalLookup = - AkidPrincipalLookup( - v1AlphaFlags.authorityKeyIdentifierToPrincipalMapFile, - v1AlphaFlags.measurementConsumerConfigFile, - ) - .memoizing() - - val measurementConsumerConfigs = - parseTextProto( - v1AlphaFlags.measurementConsumerConfigFile, - MeasurementConsumerConfigs.getDefaultInstance(), - ) - - val apiKey = measurementConsumerConfigs.configsMap.values.first().apiKey - val celEnvCacheProvider = - CelEnvCacheProvider( - KingdomEventGroupMetadataDescriptorsCoroutineStub(kingdomChannel) - .withAuthenticationKey(apiKey), - EventGroup.getDescriptor(), - reportingApiServerFlags.eventGroupMetadataDescriptorCacheDuration, - v1AlphaFlags.knownEventGroupMetadataTypes, - Dispatchers.Default, - ) - - val measurementSpecConfig = - parseTextProto( - v1AlphaFlags.measurementSpecConfigFile, - MeasurementSpecConfig.getDefaultInstance(), - ) - - try { - measurementSpecConfig.validate() - } catch (e: IllegalStateException) { - throw IllegalArgumentException("Invalid MeasurementSpeConfig", e) - } - - val services: List = - listOf( - EventGroupsService( - KingdomEventGroupsCoroutineStub(kingdomChannel), - InMemoryEncryptionKeyPairStore(encryptionKeyPairMap.keyPairs), - celEnvCacheProvider, - ) - .withPrincipalsFromX509AuthorityKeyIdentifiers(principalLookup), - ReportingSetsService(InternalReportingSetsCoroutineStub(channel)) - .withPrincipalsFromX509AuthorityKeyIdentifiers(principalLookup), - ReportsService( - InternalReportsCoroutineStub(channel), - InternalReportingSetsCoroutineStub(channel), - InternalMeasurementsCoroutineStub(channel), - KingdomDataProvidersCoroutineStub(kingdomChannel), - KingdomMeasurementConsumersCoroutineStub(kingdomChannel), - KingdomMeasurementsCoroutineStub(kingdomChannel), - KingdomCertificatesCoroutineStub(kingdomChannel), - InMemoryEncryptionKeyPairStore(encryptionKeyPairMap.keyPairs), - SecureRandom().asKotlinRandom(), - v1AlphaFlags.signingPrivateKeyStoreDir, - commonServerFlags.tlsFlags.signingCerts.trustedCertificates, - measurementSpecConfig, - ) - .withPrincipalsFromX509AuthorityKeyIdentifiers(principalLookup), - ) - CommonServer.fromFlags(commonServerFlags, SERVER_NAME, services).start().blockUntilShutdown() -} - -fun main(args: Array) = commandLineMain(::run, args) - -/** Flags specific to the V1Alpha API version. */ -private class V1AlphaFlags { - @CommandLine.Option( - names = ["--authority-key-identifier-to-principal-map-file"], - description = ["File path to a AuthorityKeyToPrincipalMap textproto"], - required = true, - ) - lateinit var authorityKeyIdentifierToPrincipalMapFile: File - private set - - @CommandLine.Option( - names = ["--measurement-consumer-config-file"], - description = ["File path to a MeasurementConsumerConfig textproto"], - required = true, - ) - lateinit var measurementConsumerConfigFile: File - private set - - @CommandLine.Option( - names = ["--measurement-spec-config-file"], - description = ["File path to a MeasurementSpecConfig textproto"], - required = true, - ) - lateinit var measurementSpecConfigFile: File - private set - - @CommandLine.Option( - names = ["--signing-private-key-store-dir"], - description = ["File path to the signing private key store directory"], - required = true, - ) - lateinit var signingPrivateKeyStoreDir: File - private set - - @CommandLine.Option( - names = ["--known-event-group-metadata-type"], - description = - [ - "File path to FileDescriptorSet containing known EventGroup metadata types.", - "This is in addition to standard protobuf well-known types.", - "Can be specified multiple times.", - ], - required = false, - defaultValue = "", - ) - private fun setKnownEventGroupMetadataTypes(fileDescriptorSetFiles: List) { - val fileDescriptorSets = - fileDescriptorSetFiles.map { file -> - file.inputStream().use { input -> DescriptorProtos.FileDescriptorSet.parseFrom(input) } - } - knownEventGroupMetadataTypes = ProtoReflection.buildFileDescriptors(fileDescriptorSets) - } - - lateinit var knownEventGroupMetadataTypes: List - private set -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/postgres/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/postgres/BUILD.bazel deleted file mode 100644 index bc5bae4cdce..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/postgres/BUILD.bazel +++ /dev/null @@ -1,20 +0,0 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") - -package( - default_visibility = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/gcloud/postgres/server:__subpackages__", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/server:__subpackages__", - "//src/test/kotlin/org/wfanet/measurement/integration/deploy/common/postgres:__pkg__", - ], -) - -kt_jvm_library( - name = "services", - srcs = ["PostgresServices.kt"], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server:reporting_data_server", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres:services", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/identity", - ], -) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/postgres/PostgresServices.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/postgres/PostgresServices.kt deleted file mode 100644 index 93b52a7c012..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/postgres/PostgresServices.kt +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.common.server.postgres - -import org.wfanet.measurement.common.db.r2dbc.DatabaseClient -import org.wfanet.measurement.common.identity.IdGenerator -import org.wfanet.measurement.reporting.deploy.common.server.ReportingDataServer.Services -import org.wfanet.measurement.reporting.deploy.postgres.PostgresMeasurementsService -import org.wfanet.measurement.reporting.deploy.postgres.PostgresReportingSetsService -import org.wfanet.measurement.reporting.deploy.postgres.PostgresReportsService - -object PostgresServices { - @JvmStatic - fun create(idGenerator: IdGenerator, client: DatabaseClient): Services { - return Services( - PostgresMeasurementsService(idGenerator, client), - PostgresReportingSetsService(idGenerator, client), - PostgresReportsService(idGenerator, client), - ) - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/config/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/config/BUILD.bazel deleted file mode 100644 index 37ae63fa26b..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/config/BUILD.bazel +++ /dev/null @@ -1,13 +0,0 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") - -kt_jvm_library( - name = "measurement_spec_config_validator", - srcs = ["MeasurementSpecConfigValidator.kt"], - visibility = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/common:__subpackages__", - "//src/test/kotlin/org/wfanet/measurement/reporting/deploy/config:__subpackages__", - ], - deps = [ - "//src/main/proto/wfa/measurement/config/reporting:measurement_spec_config_kt_jvm_proto", - ], -) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/config/MeasurementSpecConfigValidator.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/config/MeasurementSpecConfigValidator.kt deleted file mode 100644 index 8858d48f807..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/config/MeasurementSpecConfigValidator.kt +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Copyright 2023 The Cross-Media Measurement Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.wfanet.measurement.reporting.deploy.config - -import org.wfanet.measurement.config.reporting.MeasurementSpecConfig - -/** - * Validates a [MeasurementSpecConfig] - * - * @throws [IllegalStateException] if the [MeasurementSpecConfig] is invalid. - */ -fun MeasurementSpecConfig.validate() { - check(this.reachSingleDataProvider.privacyParams.isValid()) { - "reach_single_data_provider privacy_params is invalid." - } - check(this.reachSingleDataProvider.vidSamplingInterval.isValid()) { - "reach_single_data_provider vid_sampling_interval is invalid." - } - - check(this.reach.privacyParams.isValid()) { "reach privacy_params is invalid." } - check(this.reach.vidSamplingInterval.isValid()) { "reach vid_sampling_interval is invalid." } - - check(this.reachAndFrequencySingleDataProvider.reachPrivacyParams.isValid()) { - "reach_and_frequency_single_data_provider reach_privacy_params is invalid." - } - check(this.reachAndFrequencySingleDataProvider.frequencyPrivacyParams.isValid()) { - "reach_and_frequency_single_data_provider frequency_privacy_params is invalid." - } - check(this.reachAndFrequencySingleDataProvider.vidSamplingInterval.isValid()) { - "reach_and_frequency_single_data_provider vid_sampling_interval is invalid." - } - - check(this.reachAndFrequency.reachPrivacyParams.isValid()) { - "reach_and_frequency reach_privacy_params is invalid." - } - check(this.reachAndFrequency.frequencyPrivacyParams.isValid()) { - "reach_and_frequency frequency_privacy_params is invalid." - } - check(this.reachAndFrequency.vidSamplingInterval.isValid()) { - "reach_and_frequency vid_sampling_interval is invalid." - } - - check(this.impression.privacyParams.isValid()) { "impression privacy_params is invalid." } - check(this.impression.vidSamplingInterval.isValid()) { - "impression vid_sampling_interval is invalid." - } - - check(this.duration.privacyParams.isValid()) { "duration privacy_params is invalid." } - check(this.duration.vidSamplingInterval.isValid()) { - "duration vid_sampling_interval is invalid." - } -} - -private fun MeasurementSpecConfig.DifferentialPrivacyParams.isValid(): Boolean { - return (this.epsilon > 0 && this.delta >= 0) -} - -private fun MeasurementSpecConfig.VidSamplingInterval.isValid(): Boolean { - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") - return when (this.startCase) { - MeasurementSpecConfig.VidSamplingInterval.StartCase.FIXED_START -> this.fixedStart.isValid() - MeasurementSpecConfig.VidSamplingInterval.StartCase.RANDOM_START -> this.randomStart.isValid() - MeasurementSpecConfig.VidSamplingInterval.StartCase.START_NOT_SET -> true - } -} - -private fun MeasurementSpecConfig.VidSamplingInterval.FixedStart.isValid(): Boolean { - if (this.start < 0.0) { - return false - } - - if (this.width <= 0.0) { - return false - } - - if (this.start + this.width > 1.0) { - return false - } - - return true -} - -private fun MeasurementSpecConfig.VidSamplingInterval.RandomStart.isValid(): Boolean { - if (this.numVidBuckets <= 0) { - return false - } - - if (this.width <= 0 || this.width > this.numVidBuckets) { - return false - } - - return true -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/gcloud/postgres/server/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/gcloud/postgres/server/BUILD.bazel deleted file mode 100644 index a90ab6868a3..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/gcloud/postgres/server/BUILD.bazel +++ /dev/null @@ -1,35 +0,0 @@ -load("@rules_java//java:defs.bzl", "java_binary") -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") -load("//src/main/docker:macros.bzl", "java_image") - -kt_jvm_library( - name = "gcloud_postgres_reporting_data_server", - srcs = ["GCloudPostgresReportingDataServer.kt"], - runtime_deps = ["@wfa_common_jvm//imports/java/com/google/cloud/sql/postgres:r2dbc"], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server:reporting_data_server", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/postgres:services", - "@wfa_common_jvm//imports/java/io/r2dbc", - "@wfa_common_jvm//imports/java/picocli", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/postgres:factories", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/postgres:flags", - ], -) - -java_binary( - name = "GCloudPostgresReportingDataServer", - main_class = "org.wfanet.measurement.reporting.deploy.gcloud.postgres.server.GCloudPostgresReportingDataServerKt", - runtime_deps = [ - ":gcloud_postgres_reporting_data_server", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/logging", - ], -) - -java_image( - name = "gcloud_postgres_reporting_data_server_image", - binary = ":GCloudPostgresReportingDataServer", - main_class = "org.wfanet.measurement.reporting.deploy.gcloud.postgres.server.GCloudPostgresReportingDataServerKt", - visibility = ["//src:docker_image_deployment"], -) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/gcloud/postgres/server/GCloudPostgresReportingDataServer.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/gcloud/postgres/server/GCloudPostgresReportingDataServer.kt deleted file mode 100644 index a59917f794e..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/gcloud/postgres/server/GCloudPostgresReportingDataServer.kt +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.gcloud.postgres.server - -import java.time.Clock -import kotlinx.coroutines.runBlocking -import org.wfanet.measurement.common.commandLineMain -import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresDatabaseClient -import org.wfanet.measurement.common.identity.RandomIdGenerator -import org.wfanet.measurement.gcloud.postgres.PostgresConnectionFactories -import org.wfanet.measurement.gcloud.postgres.PostgresFlags as GCloudPostgresFlags -import org.wfanet.measurement.reporting.deploy.common.server.ReportingDataServer -import org.wfanet.measurement.reporting.deploy.common.server.postgres.PostgresServices -import picocli.CommandLine - -/** Implementation of [ReportingDataServer] using Google Cloud Postgres. */ -@CommandLine.Command( - name = "GCloudPostgresReportingDataServer", - description = ["Start the internal Reporting data-layer services in a single blocking server."], - mixinStandardHelpOptions = true, - showDefaultValues = true, -) -class GCloudPostgresReportingDataServer : ReportingDataServer() { - @CommandLine.Mixin private lateinit var gCloudPostgresFlags: GCloudPostgresFlags - - override fun run() = runBlocking { - val clock = Clock.systemUTC() - val idGenerator = RandomIdGenerator(clock) - - val factory = PostgresConnectionFactories.buildConnectionFactory(gCloudPostgresFlags) - val client = PostgresDatabaseClient.fromConnectionFactory(factory) - - run(PostgresServices.create(idGenerator, client)) - } -} - -fun main(args: Array) = commandLineMain(GCloudPostgresReportingDataServer(), args) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/gcloud/postgres/tools/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/gcloud/postgres/tools/BUILD.bazel deleted file mode 100644 index 613680b7c64..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/gcloud/postgres/tools/BUILD.bazel +++ /dev/null @@ -1,21 +0,0 @@ -load("@rules_java//java:defs.bzl", "java_binary") -load("//src/main/docker:macros.bzl", "java_image") - -java_binary( - name = "UpdateSchema", - args = ["--changelog=reporting/postgres/changelog.yaml"], - main_class = "org.wfanet.measurement.gcloud.postgres.tools.UpdateSchema", - resources = ["//src/main/resources/reporting/postgres"], - runtime_deps = [ - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/logging", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/postgres/tools:update_schema", - ], -) - -java_image( - name = "update_schema_image", - args = ["--changelog=reporting/postgres/changelog.yaml"], - binary = ":UpdateSchema", - main_class = "org.wfanet.measurement.gcloud.postgres.tools.UpdateSchema", - visibility = ["//src:docker_image_deployment"], -) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/BUILD.bazel deleted file mode 100644 index 49aa526e488..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/BUILD.bazel +++ /dev/null @@ -1,22 +0,0 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") - -kt_jvm_library( - name = "services", - srcs = glob(["*Service.kt"]), - visibility = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/postgres:__subpackages__", - "//src/test/kotlin/org/wfanet/measurement/reporting/deploy/postgres:__pkg__", - ], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers", - "//src/main/proto/wfa/measurement/internal/reporting:measurements_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/internal/reporting:reporting_sets_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/internal/reporting:reports_service_kt_jvm_grpc_proto", - "@wfa_common_jvm//imports/java/io/grpc:api", - "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres", - ], -) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresMeasurementsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresMeasurementsService.kt deleted file mode 100644 index b1d02bfc0bc..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresMeasurementsService.kt +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres - -import io.grpc.Status -import org.wfanet.measurement.common.db.r2dbc.DatabaseClient -import org.wfanet.measurement.common.db.r2dbc.postgres.SerializableErrors -import org.wfanet.measurement.common.identity.IdGenerator -import org.wfanet.measurement.internal.reporting.BatchCreateMeasurementsRequest -import org.wfanet.measurement.internal.reporting.BatchCreateMeasurementsResponse -import org.wfanet.measurement.internal.reporting.GetMeasurementRequest -import org.wfanet.measurement.internal.reporting.Measurement -import org.wfanet.measurement.internal.reporting.MeasurementsGrpcKt.MeasurementsCoroutineImplBase -import org.wfanet.measurement.internal.reporting.SetMeasurementFailureRequest -import org.wfanet.measurement.internal.reporting.SetMeasurementResultRequest -import org.wfanet.measurement.internal.reporting.batchCreateMeasurementsResponse -import org.wfanet.measurement.reporting.deploy.postgres.readers.MeasurementReader -import org.wfanet.measurement.reporting.deploy.postgres.writers.CreateMeasurements -import org.wfanet.measurement.reporting.deploy.postgres.writers.SetMeasurementFailure -import org.wfanet.measurement.reporting.deploy.postgres.writers.SetMeasurementResult -import org.wfanet.measurement.reporting.service.internal.MeasurementNotFoundException -import org.wfanet.measurement.reporting.service.internal.MeasurementStateInvalidException - -class PostgresMeasurementsService( - private val idGenerator: IdGenerator, - private val client: DatabaseClient, -) : MeasurementsCoroutineImplBase() { - override suspend fun batchCreateMeasurements( - request: BatchCreateMeasurementsRequest - ): BatchCreateMeasurementsResponse { - return batchCreateMeasurementsResponse { - measurements += CreateMeasurements(request.measurementsList).execute(client, idGenerator) - } - } - - override suspend fun getMeasurement(request: GetMeasurementRequest): Measurement { - val measurementResult = - SerializableErrors.retrying { - MeasurementReader() - .readMeasurementByReferenceIds( - client.singleUse(), - request.measurementConsumerReferenceId, - request.measurementReferenceId, - ) - } - ?: throw MeasurementNotFoundException() - .asStatusRuntimeException(Status.Code.NOT_FOUND, "Measurement not found.") - - return measurementResult.measurement - } - - override suspend fun setMeasurementResult(request: SetMeasurementResultRequest): Measurement { - return try { - SetMeasurementResult(request).execute(client, idGenerator) - } catch (e: MeasurementNotFoundException) { - throw e.asStatusRuntimeException(Status.Code.NOT_FOUND, "Measurement not found.") - } catch (e: MeasurementStateInvalidException) { - throw e.asStatusRuntimeException( - Status.Code.FAILED_PRECONDITION, - "Measurement has already been updated.", - ) - } - } - - override suspend fun setMeasurementFailure(request: SetMeasurementFailureRequest): Measurement { - return try { - SetMeasurementFailure(request).execute(client, idGenerator) - } catch (e: MeasurementNotFoundException) { - throw e.asStatusRuntimeException(Status.Code.NOT_FOUND, "Measurement not found.") - } catch (e: MeasurementStateInvalidException) { - throw e.asStatusRuntimeException( - Status.Code.FAILED_PRECONDITION, - "Measurement has already been updated.", - ) - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresReportingSetsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresReportingSetsService.kt deleted file mode 100644 index c429fb046e3..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresReportingSetsService.kt +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres - -import io.grpc.Status -import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.map -import org.wfanet.measurement.common.db.r2dbc.DatabaseClient -import org.wfanet.measurement.common.db.r2dbc.postgres.SerializableErrors -import org.wfanet.measurement.common.db.r2dbc.postgres.SerializableErrors.withSerializableErrorRetries -import org.wfanet.measurement.common.identity.ExternalId -import org.wfanet.measurement.common.identity.IdGenerator -import org.wfanet.measurement.internal.reporting.BatchGetReportingSetRequest -import org.wfanet.measurement.internal.reporting.GetReportingSetRequest -import org.wfanet.measurement.internal.reporting.ReportingSet -import org.wfanet.measurement.internal.reporting.ReportingSetsGrpcKt.ReportingSetsCoroutineImplBase -import org.wfanet.measurement.internal.reporting.StreamReportingSetsRequest -import org.wfanet.measurement.reporting.deploy.postgres.readers.ReportingSetReader -import org.wfanet.measurement.reporting.deploy.postgres.writers.CreateReportingSet -import org.wfanet.measurement.reporting.service.internal.ReportingSetAlreadyExistsException -import org.wfanet.measurement.reporting.service.internal.ReportingSetNotFoundException - -class PostgresReportingSetsService( - private val idGenerator: IdGenerator, - private val client: DatabaseClient, -) : ReportingSetsCoroutineImplBase() { - override suspend fun createReportingSet(request: ReportingSet): ReportingSet { - return try { - CreateReportingSet(request).execute(client, idGenerator) - } catch (e: ReportingSetAlreadyExistsException) { - throw e.asStatusRuntimeException( - Status.Code.ALREADY_EXISTS, - "IDs generated for Reporting Set already exist", - ) - } - } - - override suspend fun getReportingSet(request: GetReportingSetRequest): ReportingSet { - return try { - SerializableErrors.retrying { - ReportingSetReader() - .readReportingSetByExternalId( - client.singleUse(), - request.measurementConsumerReferenceId, - ExternalId(request.externalReportingSetId), - ) - .reportingSet - } - } catch (e: ReportingSetNotFoundException) { - throw e.asStatusRuntimeException(Status.Code.NOT_FOUND, "Reporting Set not found") - } - } - - override fun streamReportingSets(request: StreamReportingSetsRequest): Flow { - return ReportingSetReader() - .listReportingSets(client, request.filter, request.limit) - .map { result -> result.reportingSet } - .withSerializableErrorRetries() - } - - override fun batchGetReportingSet(request: BatchGetReportingSetRequest): Flow { - return ReportingSetReader() - .getReportingSetsByExternalIds( - client, - request.measurementConsumerReferenceId, - request.externalReportingSetIdsList, - ) - .map { result -> result.reportingSet } - .withSerializableErrorRetries() - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresReportsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresReportsService.kt deleted file mode 100644 index 0ee2a9ae64a..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresReportsService.kt +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres - -import io.grpc.Status -import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.map -import org.wfanet.measurement.common.db.r2dbc.DatabaseClient -import org.wfanet.measurement.common.db.r2dbc.postgres.SerializableErrors -import org.wfanet.measurement.common.db.r2dbc.postgres.SerializableErrors.withSerializableErrorRetries -import org.wfanet.measurement.common.identity.IdGenerator -import org.wfanet.measurement.internal.reporting.CreateReportRequest -import org.wfanet.measurement.internal.reporting.GetReportByIdempotencyKeyRequest -import org.wfanet.measurement.internal.reporting.GetReportRequest -import org.wfanet.measurement.internal.reporting.Report -import org.wfanet.measurement.internal.reporting.ReportsGrpcKt.ReportsCoroutineImplBase -import org.wfanet.measurement.internal.reporting.StreamReportsRequest -import org.wfanet.measurement.reporting.deploy.postgres.readers.ReportReader -import org.wfanet.measurement.reporting.deploy.postgres.writers.CreateReport -import org.wfanet.measurement.reporting.service.internal.MeasurementCalculationTimeIntervalNotFoundException -import org.wfanet.measurement.reporting.service.internal.ReportNotFoundException -import org.wfanet.measurement.reporting.service.internal.ReportingSetNotFoundException - -class PostgresReportsService( - private val idGenerator: IdGenerator, - private val client: DatabaseClient, -) : ReportsCoroutineImplBase() { - override suspend fun createReport(request: CreateReportRequest): Report { - try { - return CreateReport(request).execute(client, idGenerator) - } catch (e: ReportingSetNotFoundException) { - throw e.asStatusRuntimeException(Status.Code.NOT_FOUND, "Reporting Set not found") - } catch (e: MeasurementCalculationTimeIntervalNotFoundException) { - throw e.asStatusRuntimeException( - Status.Code.INVALID_ARGUMENT, - "Measurement Calculation Time Interval not found in Report", - ) - } - } - - override suspend fun getReport(request: GetReportRequest): Report { - try { - return SerializableErrors.retrying { - ReportReader() - .getReportByExternalId( - client.singleUse(), - request.measurementConsumerReferenceId, - request.externalReportId, - ) - .report - } - } catch (e: ReportNotFoundException) { - throw e.asStatusRuntimeException(Status.Code.NOT_FOUND, "Report not found") - } - } - - override suspend fun getReportByIdempotencyKey( - request: GetReportByIdempotencyKeyRequest - ): Report { - try { - return SerializableErrors.retrying { - ReportReader() - .getReportByIdempotencyKey( - client.singleUse(), - request.measurementConsumerReferenceId, - request.reportIdempotencyKey, - ) - .report - } - } catch (e: ReportNotFoundException) { - throw e.asStatusRuntimeException(Status.Code.NOT_FOUND, "Report not found") - } - } - - override fun streamReports(request: StreamReportsRequest): Flow { - return ReportReader() - .listReports(client, request.filter, request.limit) - .map { result -> result.report } - .withSerializableErrorRetries() - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/BUILD.bazel deleted file mode 100644 index 53bcb7330d3..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/BUILD.bazel +++ /dev/null @@ -1,26 +0,0 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") - -package(default_visibility = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres:__subpackages__", -]) - -kt_jvm_library( - name = "readers", - srcs = glob(["*.kt"]), - deps = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/service/internal:internal_exception", - "//src/main/proto/wfa/measurement/internal/reporting:measurement_kt_jvm_proto", - "//src/main/proto/wfa/measurement/internal/reporting:measurements_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/internal/reporting:metric_kt_jvm_proto", - "//src/main/proto/wfa/measurement/internal/reporting:report_kt_jvm_proto", - "//src/main/proto/wfa/measurement/internal/reporting:reporting_set_kt_jvm_proto", - "//src/main/proto/wfa/measurement/internal/reporting:reporting_sets_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/internal/reporting:reports_service_kt_jvm_grpc_proto", - "@wfa_common_jvm//imports/java/com/google/gson", - "@wfa_common_jvm//imports/java/io/r2dbc", - "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/identity", - ], -) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/MeasurementReader.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/MeasurementReader.kt deleted file mode 100644 index 21e9218f31f..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/MeasurementReader.kt +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres.readers - -import kotlinx.coroutines.flow.firstOrNull -import org.wfanet.measurement.common.db.r2dbc.ReadContext -import org.wfanet.measurement.common.db.r2dbc.ResultRow -import org.wfanet.measurement.common.db.r2dbc.boundStatement -import org.wfanet.measurement.internal.reporting.Measurement -import org.wfanet.measurement.internal.reporting.measurement - -class MeasurementReader { - data class Result(val measurement: Measurement) - - private val baseSql: String = - """ - SELECT - MeasurementConsumerReferenceId, - MeasurementReferenceId, - State, - Failure, - Result - FROM - Measurements - """ - - fun translate(row: ResultRow): Result = Result(buildMeasurement(row)) - - /** - * Reads a Measurement using reference IDs. - * - * @return null when the Measurement is not found. - */ - suspend fun readMeasurementByReferenceIds( - readContext: ReadContext, - measurementConsumerReferenceId: String, - measurementReferenceId: String, - ): Result? { - val builder = - boundStatement( - baseSql + - """ - WHERE MeasurementConsumerReferenceId = $1 - AND MeasurementReferenceId = $2 - """ - ) { - bind("$1", measurementConsumerReferenceId) - bind("$2", measurementReferenceId) - } - - return readContext.executeQuery(builder).consume(::translate).firstOrNull() - } - - private fun buildMeasurement(row: ResultRow): Measurement { - return measurement { - measurementConsumerReferenceId = row["MeasurementConsumerReferenceId"] - measurementReferenceId = row["MeasurementReferenceId"] - state = Measurement.State.forNumber(row["State"]) - val failure: Measurement.Failure? = - row.getProtoMessage("Failure", Measurement.Failure.parser()) - if (failure != null) { - this.failure = failure - } - val result: Measurement.Result? = row.getProtoMessage("Result", Measurement.Result.parser()) - if (result != null) { - this.result = result - } - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/MeasurementResultsReader.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/MeasurementResultsReader.kt deleted file mode 100644 index ae63b9c7232..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/MeasurementResultsReader.kt +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres.readers - -import kotlinx.coroutines.flow.Flow -import org.wfanet.measurement.common.db.r2dbc.ReadContext -import org.wfanet.measurement.common.db.r2dbc.ResultRow -import org.wfanet.measurement.common.db.r2dbc.boundStatement -import org.wfanet.measurement.common.identity.InternalId -import org.wfanet.measurement.internal.reporting.Measurement - -class MeasurementResultsReader { - data class Result( - val reportId: InternalId, - val measurementReferenceId: String, - val state: Measurement.State, - val result: Measurement.Result?, - ) - - private val sql = - """ - SELECT - ReportId, - MeasurementReferenceId, - Measurements.State, - Result - FROM - ( - SELECT - MeasurementConsumerReferenceId, - ReportId - FROM - ReportMeasurements - WHERE - MeasurementConsumerReferenceId = $1 AND MeasurementReferenceId = $2 - ) AS Reports - JOIN ReportMeasurements USING (MeasurementConsumerReferenceId, ReportId) - JOIN Measurements USING (MeasurementConsumerReferenceId, MeasurementReferenceId) - """ - - fun translate(row: ResultRow): Result = - Result( - reportId = row["ReportId"], - measurementReferenceId = row["MeasurementReferenceId"], - state = Measurement.State.forNumber(row["State"]), - result = row.getProtoMessage("Result", Measurement.Result.parser()), - ) - - suspend fun listMeasurementsForReportsByMeasurementReferenceId( - readContext: ReadContext, - measurementConsumerReferenceId: String, - measurementReferenceId: String, - ): Flow { - val statement = - boundStatement(sql) { - bind("$1", measurementConsumerReferenceId) - bind("$2", measurementReferenceId) - } - - return readContext.executeQuery(statement).consume(::translate) - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/ReportReader.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/ReportReader.kt deleted file mode 100644 index 127a5d04cd6..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/ReportReader.kt +++ /dev/null @@ -1,527 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres.readers - -import com.google.gson.JsonObject -import com.google.gson.JsonParser -import com.google.protobuf.duration -import com.google.protobuf.timestamp -import java.time.Instant -import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.emitAll -import kotlinx.coroutines.flow.flow -import kotlinx.coroutines.flow.singleOrNull -import org.wfanet.measurement.common.base64MimeDecode -import org.wfanet.measurement.common.db.r2dbc.DatabaseClient -import org.wfanet.measurement.common.db.r2dbc.ReadContext -import org.wfanet.measurement.common.db.r2dbc.ResultRow -import org.wfanet.measurement.common.db.r2dbc.boundStatement -import org.wfanet.measurement.common.identity.InternalId -import org.wfanet.measurement.common.toProtoTime -import org.wfanet.measurement.internal.reporting.Measurement -import org.wfanet.measurement.internal.reporting.Metric -import org.wfanet.measurement.internal.reporting.Metric.SetOperation -import org.wfanet.measurement.internal.reporting.MetricKt -import org.wfanet.measurement.internal.reporting.MetricKt.namedSetOperation -import org.wfanet.measurement.internal.reporting.MetricKt.setOperation -import org.wfanet.measurement.internal.reporting.PeriodicTimeInterval -import org.wfanet.measurement.internal.reporting.Report -import org.wfanet.measurement.internal.reporting.StreamReportsRequest -import org.wfanet.measurement.internal.reporting.TimeIntervals -import org.wfanet.measurement.internal.reporting.measurement -import org.wfanet.measurement.internal.reporting.metric -import org.wfanet.measurement.internal.reporting.periodicTimeInterval -import org.wfanet.measurement.internal.reporting.report -import org.wfanet.measurement.internal.reporting.timeInterval -import org.wfanet.measurement.internal.reporting.timeIntervals -import org.wfanet.measurement.reporting.service.internal.ReportNotFoundException - -class ReportReader { - data class Result( - val measurementConsumerReferenceId: String, - val reportId: InternalId, - val report: Report, - ) - - private val baseSql: String = - """ - SELECT - MeasurementConsumerReferenceId, - ReportId, - ExternalReportId, - State, - ReportDetails, - ReportIdempotencyKey, - CreateTime, - ARRAY( - SELECT - json_build_object( - 'timeIntervalId', TimeIntervalId, - 'startSeconds', StartSeconds, - 'startNanos', StartNanos, - 'endSeconds', EndSeconds, - 'endNanos', EndNanos - ) - FROM TimeIntervals - WHERE TimeIntervals.MeasurementConsumerReferenceId = Reports.MeasurementConsumerReferenceId - AND TimeIntervals.ReportId = Reports.ReportId - ) AS TimeIntervals, - ARRAY( - SELECT - json_build_object( - 'measurementReferenceId', MeasurementReferenceId, - 'state', Measurements.State, - 'failure', encode(Measurements.failure, 'base64'), - 'result', encode(Measurements.result, 'base64') - ) - FROM ( - SELECT - MeasurementConsumerReferenceId, - MeasurementReferenceId - FROM ReportMeasurements - WHERE ReportMeasurements.ReportId = Reports.ReportId - ) AS ReportMeasurements - JOIN Measurements USING(MeasurementConsumerReferenceId, MeasurementReferenceId) - ) AS Measurements, - IntervalCount, - StartSeconds, - StartNanos, - IncrementSeconds, - IncrementNanos, - ( - SELECT array_agg( - json_build_object( - 'metricId', MetricId, - 'metricDetails', MetricDetails, - 'namedSetOperations', NamedSetOperations, - 'setOperations', SetOperations - ) - ) - FROM ( - SELECT - MetricId, - encode(MetricDetails, 'base64') AS MetricDetails, - ( - SELECT json_agg( - json_build_object( - 'displayName', DisplayName, - 'setOperationId', SetOperationId, - 'measurementCalculations', MeasurementCalculations - ) - ) - FROM ( - SELECT - DisplayName, - SetOperationId, - ( - SELECT json_agg( - json_build_object( - 'timeInterval', TimeInterval, - 'weightedMeasurements', WeightedMeasurements - ) - ) - FROM ( - SELECT - ( - SELECT json_build_object( - 'startSeconds', StartSeconds, - 'startNanos', StartNanos, - 'endSeconds', EndSeconds, - 'endNanos', EndNanos - ) - FROM TimeIntervals - WHERE MeasurementCalculations.MeasurementConsumerReferenceId = TimeIntervals.MeasurementConsumerReferenceId - AND MeasurementCalculations.ReportId = TimeIntervals.ReportId - AND MeasurementCalculations.TimeIntervalId = TimeIntervals.TimeIntervalId - ) AS TimeInterval, - ( - SELECT json_agg( - json_build_object( - 'measurementReferenceId', MeasurementReferenceId, - 'coefficient', Coefficient - ) - ) - FROM WeightedMeasurements - Where WeightedMeasurements.MeasurementConsumerReferenceId = MeasurementCalculations.MeasurementConsumerReferenceId - AND WeightedMeasurements.ReportId = MeasurementCalculations.ReportId - AND WeightedMeasurements.MetricId = MeasurementCalculations.MetricId - AND WeightedMeasurements.NamedSetOperationId = MeasurementCalculations.NamedSetOperationId - AND WeightedMeasurements.MeasurementCalculationId = MeasurementCalculations.MeasurementCalculationId - ) AS WeightedMeasurements - FROM MeasurementCalculations - Where MeasurementCalculations.MeasurementConsumerReferenceId = NamedSetOperations.MeasurementConsumerReferenceId - AND MeasurementCalculations.ReportId = NamedSetOperations.ReportId - AND MeasurementCalculations.MetricId = NamedSetOperations.MetricId - AND MeasurementCalculations.NamedSetOperationId = NamedSetOperations.NamedSetOperationId - ) MeasurementCalculations - ) AS MeasurementCalculations - FROM NamedSetOperations - WHERE NamedSetOperations.MeasurementConsumerReferenceId = Metrics.MeasurementConsumerReferenceId - AND NamedSetOperations.ReportId = Metrics.ReportId - AND NamedSetOperations.MetricId = Metrics.MetricId - ) AS NamedSetOperations) AS NamedSetOperations, - ( - SELECT json_agg( - json_build_object( - 'type', Type, - 'setOperationId', SetOperationId, - 'leftHandSetOperationId', LeftHandSetOperationId, - 'rightHandSetOperationId', RightHandSetOperationId, - 'leftHandReportingSetId', - ( - SELECT ExternalReportingSetId - FROM ReportingSets - WHERE SetOperations.MeasurementConsumerReferenceId = ReportingSets.MeasurementConsumerReferenceId - AND ReportingSetId = LeftHandReportingSetId - ), - 'rightHandReportingSetId', - ( - SELECT ExternalReportingSetId - FROM ReportingSets - WHERE SetOperations.MeasurementConsumerReferenceId = ReportingSets.MeasurementConsumerReferenceId - AND ReportingSetId = RightHandReportingSetId - ) - ) - ) - FROM SetOperations - Where SetOperations.MeasurementConsumerReferenceId = Metrics.MeasurementConsumerReferenceId - AND SetOperations.ReportId = Metrics.ReportId - AND SetOperations.MetricId = Metrics.MetricId - ) AS SetOperations - FROM Metrics - WHERE Metrics.MeasurementConsumerReferenceId = Reports.MeasurementConsumerReferenceId - AND Metrics.ReportId = Reports.ReportId - ) AS Metrics - ) AS Metrics - FROM Reports - LEFT JOIN PeriodicTimeIntervals USING(MeasurementConsumerReferenceId, ReportId) - """ - - fun translate(row: ResultRow): Result = - Result(row["MeasurementConsumerReferenceId"], row["ReportId"], buildReport(row)) - - /** - * Gets the report by external report id. - * - * @throws [ReportNotFoundException] - */ - suspend fun getReportByExternalId( - readContext: ReadContext, - measurementConsumerReferenceId: String, - externalReportId: Long, - ): Result { - val statement = - boundStatement( - (baseSql + - """ - WHERE MeasurementConsumerReferenceId = $1 - AND ExternalReportId = $2 - """) - ) { - bind("$1", measurementConsumerReferenceId) - bind("$2", externalReportId) - } - - return readContext.executeQuery(statement).consume(::translate).singleOrNull() - ?: throw ReportNotFoundException() - } - - /** - * Gets the report by report id. - * - * @throws [ReportNotFoundException] - */ - suspend fun getReportById( - readContext: ReadContext, - measurementConsumerReferenceId: String, - reportId: Long, - ): Result { - val statement = - boundStatement( - (baseSql + - """ - WHERE MeasurementConsumerReferenceId = $1 - AND ReportId = $2 - """) - ) { - bind("$1", measurementConsumerReferenceId) - bind("$2", reportId) - } - - return readContext.executeQuery(statement).consume(::translate).singleOrNull() - ?: throw ReportNotFoundException() - } - - /** - * Gets the report by report idempotency key. - * - * @throws [ReportNotFoundException] - */ - suspend fun getReportByIdempotencyKey( - readContext: ReadContext, - measurementConsumerReferenceId: String, - reportIdempotencyKey: String, - ): Result { - val statement = - boundStatement( - (baseSql + - """ - WHERE MeasurementConsumerReferenceId = $1 - AND ReportIdempotencyKey = $2 - """) - ) { - bind("$1", measurementConsumerReferenceId) - bind("$2", reportIdempotencyKey) - } - - return readContext.executeQuery(statement).consume(::translate).singleOrNull() - ?: throw ReportNotFoundException() - } - - fun listReports( - client: DatabaseClient, - filter: StreamReportsRequest.Filter, - limit: Int = 0, - ): Flow { - val statement = - boundStatement( - baseSql + - """ - WHERE MeasurementConsumerReferenceId = $1 - AND ExternalReportId > $2 - ORDER BY ExternalReportId ASC - LIMIT $3 - """ - ) { - bind("$1", filter.measurementConsumerReferenceId) - bind("$2", filter.externalReportIdAfter) - if (limit > 0) { - bind("$3", limit) - } else { - bind("$3", 50) - } - } - - return flow { - val readContext = client.readTransaction() - try { - emitAll(readContext.executeQuery(statement).consume(::translate)) - } finally { - readContext.close() - } - } - } - - private fun buildReport(row: ResultRow): Report { - return report { - measurementConsumerReferenceId = row["MeasurementConsumerReferenceId"] - externalReportId = row["ExternalReportId"] - state = Report.State.forNumber(row["State"]) - details = row.getProtoMessage("ReportDetails", Report.Details.parser()) - reportIdempotencyKey = row["ReportIdempotencyKey"] - createTime = row.get("CreateTime").toProtoTime() - val intervalCount: Int? = row["IntervalCount"] - if (intervalCount != null) { - this.periodicTimeInterval = buildPeriodicTimeInterval(row) - } else { - timeIntervals = buildTimeIntervals(row["TimeIntervals"]) - } - metrics += buildMetrics(measurementConsumerReferenceId, row["Metrics"]) - measurements.putAll(buildMeasurements(measurementConsumerReferenceId, row["Measurements"])) - } - } - - private fun buildPeriodicTimeInterval(row: ResultRow): PeriodicTimeInterval { - return periodicTimeInterval { - startTime = timestamp { - seconds = row["StartSeconds"] - nanos = row["StartNanos"] - } - increment = duration { - seconds = row["IncrementSeconds"] - nanos = row["IncrementNanos"] - } - intervalCount = row["IntervalCount"] - } - } - - private fun buildTimeIntervals(timeIntervalsArr: Array): TimeIntervals { - return timeIntervals { - timeIntervalsArr.forEach { - val timeIntervalObject = JsonParser.parseString(it).asJsonObject - timeIntervals += timeInterval { - startTime = timestamp { - seconds = timeIntervalObject.getAsJsonPrimitive("startSeconds").asLong - nanos = timeIntervalObject.getAsJsonPrimitive("startNanos").asInt - } - endTime = timestamp { - seconds = timeIntervalObject.getAsJsonPrimitive("endSeconds").asLong - nanos = timeIntervalObject.getAsJsonPrimitive("endNanos").asInt - } - } - } - } - } - - private fun buildMeasurements( - measurementConsumerReferenceId: String, - measurementsArr: Array, - ): Map { - return measurementsArr - .map { JsonParser.parseString(it).asJsonObject } - .associateBy( - { it.getAsJsonPrimitive("measurementReferenceId").asString }, - { - measurement { - this.measurementConsumerReferenceId = measurementConsumerReferenceId - measurementReferenceId = it.getAsJsonPrimitive("measurementReferenceId").asString - state = Measurement.State.forNumber(it.getAsJsonPrimitive("state").asInt) - if (!it.get("failure").isJsonNull) { - failure = - Measurement.Failure.parseFrom(it.getAsJsonPrimitive("failure").base64MimeDecode()) - } - if (!it.get("result").isJsonNull) { - result = - Measurement.Result.parseFrom(it.getAsJsonPrimitive("result").base64MimeDecode()) - } - } - }, - ) - } - - private fun buildMetrics( - measurementConsumerReferenceId: String, - metricsArr: Array, - ): Collection { - val metricsList = ArrayList(metricsArr.size) - metricsArr.forEach { - val metricObject = JsonParser.parseString(it).asJsonObject - metricsList.add( - metric { - details = - Metric.Details.parseFrom( - metricObject.getAsJsonPrimitive("metricDetails").base64MimeDecode() - ) - - val setOperationsArr = metricObject.getAsJsonArray("setOperations") - val setOperationsMap = mutableMapOf() - setOperationsArr.forEach { setOperationElement -> - val setOperationObject = setOperationElement.asJsonObject - setOperationsMap[setOperationObject.getAsJsonPrimitive("setOperationId").asLong] = - setOperationObject - } - - val namedSetOperationsArr = metricObject.getAsJsonArray("namedSetOperations") - namedSetOperationsArr.forEach { namedSetOperationElement -> - val namedSetOperationObject = namedSetOperationElement.asJsonObject - namedSetOperations += namedSetOperation { - displayName = namedSetOperationObject.getAsJsonPrimitive("displayName").asString - val setOperationId = - namedSetOperationObject.getAsJsonPrimitive("setOperationId").asLong - setOperation = - buildSetOperation(measurementConsumerReferenceId, setOperationId, setOperationsMap) - val measurementCalculationsArr = - namedSetOperationObject.getAsJsonArray("measurementCalculations") - measurementCalculationsArr.forEach { measurementCalculationElement -> - val measurementCalculationObject = measurementCalculationElement.asJsonObject - measurementCalculations += - MetricKt.measurementCalculation { - val timeIntervalObject = - measurementCalculationObject.getAsJsonObject("timeInterval") - timeInterval = timeInterval { - startTime = timestamp { - seconds = timeIntervalObject.getAsJsonPrimitive("startSeconds").asLong - nanos = timeIntervalObject.getAsJsonPrimitive("startNanos").asInt - } - endTime = timestamp { - seconds = timeIntervalObject.getAsJsonPrimitive("endSeconds").asLong - nanos = timeIntervalObject.getAsJsonPrimitive("endNanos").asInt - } - } - val weightedMeasurementsArr = - measurementCalculationObject.getAsJsonArray("weightedMeasurements") - weightedMeasurementsArr.forEach { weightedMeasurementElement -> - val weightedMeasurementObject = weightedMeasurementElement.asJsonObject - weightedMeasurements += - MetricKt.MeasurementCalculationKt.weightedMeasurement { - measurementReferenceId = - weightedMeasurementObject - .getAsJsonPrimitive("measurementReferenceId") - .asString - coefficient = - weightedMeasurementObject.getAsJsonPrimitive("coefficient").asInt - } - } - } - } - } - } - } - ) - } - return metricsList - } - - private fun buildSetOperation( - measurementConsumerReferenceId: String, - setOperationId: Long, - setOperationMap: Map, - ): SetOperation { - val setOperationObject = setOperationMap[setOperationId] - return setOperation { - type = SetOperation.Type.forNumber(setOperationObject!!.getAsJsonPrimitive("type").asInt) - lhs = - MetricKt.SetOperationKt.operand { - if (setOperationObject.get("leftHandSetOperationId").isJsonNull) { - if (!setOperationObject.get("leftHandReportingSetId").isJsonNull) { - reportingSetId = - MetricKt.SetOperationKt.reportingSetKey { - this.measurementConsumerReferenceId = measurementConsumerReferenceId - externalReportingSetId = - setOperationObject.getAsJsonPrimitive("leftHandReportingSetId").asLong - } - } - } else { - operation = - buildSetOperation( - measurementConsumerReferenceId, - setOperationObject.getAsJsonPrimitive("leftHandSetOperationId").asLong, - setOperationMap, - ) - } - } - rhs = - MetricKt.SetOperationKt.operand { - if (setOperationObject.get("rightHandSetOperationId").isJsonNull) { - if (!setOperationObject.get("rightHandReportingSetId").isJsonNull) { - reportingSetId = - MetricKt.SetOperationKt.reportingSetKey { - this.measurementConsumerReferenceId = measurementConsumerReferenceId - externalReportingSetId = - setOperationObject.getAsJsonPrimitive("rightHandReportingSetId").asLong - } - } - } else { - operation = - buildSetOperation( - measurementConsumerReferenceId, - setOperationObject.getAsJsonPrimitive("rightHandSetOperationId").asLong, - setOperationMap, - ) - } - } - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/ReportingSetReader.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/ReportingSetReader.kt deleted file mode 100644 index 8ded14f23f3..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers/ReportingSetReader.kt +++ /dev/null @@ -1,202 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres.readers - -import com.google.gson.JsonObject -import com.google.gson.JsonParser -import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.emitAll -import kotlinx.coroutines.flow.emptyFlow -import kotlinx.coroutines.flow.firstOrNull -import kotlinx.coroutines.flow.flow -import org.wfanet.measurement.common.db.r2dbc.DatabaseClient -import org.wfanet.measurement.common.db.r2dbc.ReadContext -import org.wfanet.measurement.common.db.r2dbc.ResultRow -import org.wfanet.measurement.common.db.r2dbc.boundStatement -import org.wfanet.measurement.common.identity.ExternalId -import org.wfanet.measurement.common.identity.InternalId -import org.wfanet.measurement.internal.reporting.ReportingSet -import org.wfanet.measurement.internal.reporting.ReportingSet.EventGroupKey -import org.wfanet.measurement.internal.reporting.ReportingSetKt -import org.wfanet.measurement.internal.reporting.StreamReportingSetsRequest -import org.wfanet.measurement.internal.reporting.reportingSet -import org.wfanet.measurement.reporting.service.internal.ReportingInternalException -import org.wfanet.measurement.reporting.service.internal.ReportingSetNotFoundException - -class ReportingSetReader { - data class Result( - val measurementConsumerReferenceId: String, - val reportingSetId: InternalId, - var reportingSet: ReportingSet, - ) - - private val baseSql: String = - """ - SELECT - MeasurementConsumerReferenceId, - ReportingSetId, - ExternalReportingSetId, - Filter, - DisplayName, - ( - SELECT ARRAY( - SELECT - json_build_object( - 'measurementConsumerReferenceId', MeasurementConsumerReferenceId, - 'dataProviderReferenceId', DataProviderReferenceId, - 'eventGroupReferenceId', EventGroupReferenceId - ) - FROM ReportingSetEventGroups - WHERE ReportingSetEventGroups.MeasurementConsumerReferenceId = ReportingSets.MeasurementConsumerReferenceId - AND ReportingSetEventGroups.ReportingSetId = ReportingSets.ReportingSetId - ) - ) AS EventGroups - FROM - ReportingSets - """ - - fun translate(row: ResultRow): Result = - Result(row["MeasurementConsumerReferenceId"], row["ReportingSetId"], buildReportingSet(row)) - - /** - * Reads a Reporting Set using external ID. - * - * Throws a subclass of [ReportingInternalException]. - * - * @throws [ReportingSetNotFoundException] Reporting Set not found. - */ - suspend fun readReportingSetByExternalId( - readContext: ReadContext, - measurementConsumerReferenceId: String, - externalReportingSetId: ExternalId, - ): Result { - val statement = - boundStatement( - baseSql + - """ - WHERE MeasurementConsumerReferenceId = $1 - AND ExternalReportingSetId = $2 - """ - ) { - bind("$1", measurementConsumerReferenceId) - bind("$2", externalReportingSetId) - } - - return readContext.executeQuery(statement).consume(::translate).firstOrNull() - ?: throw ReportingSetNotFoundException() - } - - fun listReportingSets( - client: DatabaseClient, - filter: StreamReportingSetsRequest.Filter, - limit: Int = 0, - ): Flow { - val statement = - boundStatement( - baseSql + - """ - WHERE MeasurementConsumerReferenceId = $1 - AND ExternalReportingSetId > $2 - ORDER BY ExternalReportingSetId ASC - LIMIT $3 - """ - ) { - bind("$1", filter.measurementConsumerReferenceId) - bind("$2", filter.externalReportingSetIdAfter) - if (limit > 0) { - bind("$3", limit) - } else { - bind("$3", 50) - } - } - - return flow { - val readContext = client.readTransaction() - try { - emitAll(readContext.executeQuery(statement).consume(::translate)) - } finally { - readContext.close() - } - } - } - - fun getReportingSetsByExternalIds( - client: DatabaseClient, - measurementConsumerReferenceId: String, - externalReportingSetIds: List, - ): Flow { - val sql = - StringBuilder( - baseSql + - """ - WHERE MeasurementConsumerReferenceId = $1 - AND ExternalReportingSetId IN - """ - ) - - if (externalReportingSetIds.isEmpty()) { - return emptyFlow() - } - - var i = 2 - val bindingMap = mutableMapOf() - val inList = - externalReportingSetIds.joinToString(separator = ",", prefix = "(", postfix = ")") { - val index = "$$i" - bindingMap[it] = "$$i" - i++ - index - } - sql.append(inList) - - val statement = - boundStatement(sql.toString()) { - bind("$1", measurementConsumerReferenceId) - - externalReportingSetIds.forEach { bind(bindingMap.getValue(it), it) } - } - - return flow { - val readContext = client.readTransaction() - try { - emitAll(readContext.executeQuery(statement).consume(::translate)) - } finally { - readContext.close() - } - } - } - - private fun buildReportingSet(row: ResultRow): ReportingSet { - return reportingSet { - measurementConsumerReferenceId = row["MeasurementConsumerReferenceId"] - externalReportingSetId = row["ExternalReportingSetId"] - filter = row["Filter"] - displayName = row["DisplayName"] - val eventGroupsArr = row.get>("EventGroups") - eventGroupsArr.forEach { - eventGroupKeys += buildEventGroupKey(JsonParser.parseString(it).asJsonObject) - } - } - } - - private fun buildEventGroupKey(eventGroup: JsonObject): EventGroupKey { - return ReportingSetKt.eventGroupKey { - measurementConsumerReferenceId = - eventGroup.getAsJsonPrimitive("measurementConsumerReferenceId").asString - dataProviderReferenceId = eventGroup.getAsJsonPrimitive("dataProviderReferenceId").asString - eventGroupReferenceId = eventGroup.getAsJsonPrimitive("eventGroupReferenceId").asString - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/server/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/server/BUILD.bazel deleted file mode 100644 index fddb2884cc9..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/server/BUILD.bazel +++ /dev/null @@ -1,29 +0,0 @@ -load("@rules_java//java:defs.bzl", "java_binary") -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") -load("//src/main/docker:macros.bzl", "java_image") - -kt_jvm_library( - name = "postgres_reporting_data_server", - srcs = ["PostgresReportingDataServer.kt"], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server:reporting_data_server", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/postgres:services", - "@wfa_common_jvm//imports/java/picocli", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/postgres:flags", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres", - ], -) - -java_binary( - name = "PostgresReportingDataServer", - main_class = "org.wfanet.measurement.reporting.deploy.postgres.server.PostgresReportingDataServerKt", - runtime_deps = [":postgres_reporting_data_server"], -) - -java_image( - name = "postgres_reporting_data_server_image", - binary = ":PostgresReportingDataServer", - main_class = "org.wfanet.measurement.reporting.deploy.postgres.server.PostgresReportingDataServerKt", - visibility = ["//src:docker_image_deployment"], -) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/server/PostgresReportingDataServer.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/server/PostgresReportingDataServer.kt deleted file mode 100644 index b84b5ad9acd..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/server/PostgresReportingDataServer.kt +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres.server - -import java.time.Clock -import kotlinx.coroutines.runBlocking -import org.wfanet.measurement.common.commandLineMain -import org.wfanet.measurement.common.db.postgres.PostgresFlags -import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresDatabaseClient -import org.wfanet.measurement.common.identity.RandomIdGenerator -import org.wfanet.measurement.reporting.deploy.common.server.ReportingDataServer -import org.wfanet.measurement.reporting.deploy.common.server.postgres.PostgresServices -import picocli.CommandLine - -/** Implementation of [ReportingDataServer] using Postgres. */ -@CommandLine.Command( - name = "PostgresReportingDataServer", - description = ["Start the internal Reporting data-layer services in a single blocking server."], - mixinStandardHelpOptions = true, - showDefaultValues = true, -) -class PostgresReportingDataServer : ReportingDataServer() { - @CommandLine.Mixin private lateinit var postgresFlags: PostgresFlags - - override fun run() = runBlocking { - val clock = Clock.systemUTC() - val idGenerator = RandomIdGenerator(clock) - - val client = PostgresDatabaseClient.fromFlags(postgresFlags) - - run(PostgresServices.create(idGenerator, client)) - } -} - -fun main(args: Array) = commandLineMain(PostgresReportingDataServer(), args) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/testing/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/testing/BUILD.bazel deleted file mode 100644 index 45fba8267c3..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/testing/BUILD.bazel +++ /dev/null @@ -1,18 +0,0 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") - -package( - default_testonly = True, - default_visibility = [ - "//src/test/kotlin/org/wfanet/measurement/integration/deploy/common/postgres:__subpackages__", - "//src/test/kotlin/org/wfanet/measurement/reporting/deploy/postgres:__subpackages__", - ], -) - -kt_jvm_library( - name = "testing", - srcs = glob(["*.kt"]), - resources = ["//src/main/resources/reporting/postgres"], - deps = [ - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", - ], -) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/testing/Schemata.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/testing/Schemata.kt deleted file mode 100644 index bf9233ecdda..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/testing/Schemata.kt +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres.testing - -import java.nio.file.Path -import org.wfanet.measurement.common.getJarResourcePath - -object Schemata { - private const val REPORTING_POSTGRES_RESOURCE_PREFIX = "reporting/postgres" - - private fun getResourcePath(fileName: String): Path { - val resourceName = "$REPORTING_POSTGRES_RESOURCE_PREFIX/$fileName" - val classLoader: ClassLoader = Thread.currentThread().contextClassLoader - return requireNotNull(classLoader.getJarResourcePath(resourceName)) { - "Resource $resourceName not found" - } - } - - val REPORTING_CHANGELOG_PATH: Path = getResourcePath("changelog.yaml") -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/tools/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/tools/BUILD.bazel deleted file mode 100644 index 31c1f110702..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/tools/BUILD.bazel +++ /dev/null @@ -1,20 +0,0 @@ -load("@rules_java//java:defs.bzl", "java_binary") -load("//src/main/docker:macros.bzl", "java_image") - -java_binary( - name = "UpdateSchema", - args = ["--changelog=reporting/postgres/changelog.yaml"], - main_class = "org.wfanet.measurement.common.db.postgres.tools.UpdateSchema", - resources = ["//src/main/resources/reporting/postgres"], - runtime_deps = [ - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/postgres/tools:update_schema", - ], -) - -java_image( - name = "update_schema_image", - args = ["--changelog=reporting/postgres/changelog.yaml"], - binary = ":UpdateSchema", - main_class = "org.wfanet.measurement.common.db.postgres.tools.UpdateSchema", - visibility = ["//src:docker_image_deployment"], -) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/BUILD.bazel deleted file mode 100644 index 6073e762463..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/BUILD.bazel +++ /dev/null @@ -1,27 +0,0 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") - -package(default_visibility = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres:__subpackages__", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres:__subpackages__", -]) - -kt_jvm_library( - name = "writers", - srcs = glob(["*.kt"]), - deps = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/readers", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/internal:internal_exception", - "//src/main/proto/wfa/measurement/internal/reporting:measurement_kt_jvm_proto", - "//src/main/proto/wfa/measurement/internal/reporting:metric_kt_jvm_proto", - "//src/main/proto/wfa/measurement/internal/reporting:report_kt_jvm_proto", - "//src/main/proto/wfa/measurement/internal/reporting:reporting_set_kt_jvm_proto", - "@wfa_common_jvm//imports/java/com/google/protobuf", - "@wfa_common_jvm//imports/java/io/r2dbc", - "@wfa_common_jvm//imports/java/org/postgresql:r2dbc", - "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/identity", - ], -) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/CreateMeasurements.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/CreateMeasurements.kt deleted file mode 100644 index 65dd19af31b..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/CreateMeasurements.kt +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres.writers - -import org.wfanet.measurement.common.db.r2dbc.boundStatement -import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter -import org.wfanet.measurement.internal.reporting.Measurement -import org.wfanet.measurement.internal.reporting.copy -import org.wfanet.measurement.reporting.service.internal.MeasurementAlreadyExistsException - -/** - * Inserts Measurements into the database. - * - * Throws the following on [execute]: - * * [MeasurementAlreadyExistsException] Measurement already exists - */ -class CreateMeasurements(private val measurements: Iterable) : - PostgresWriter>() { - override suspend fun TransactionScope.runTransaction(): Iterable { - transactionContext.run { - for (measurement in measurements) { - val builder = - boundStatement( - """ - INSERT INTO Measurements (MeasurementConsumerReferenceId, MeasurementReferenceId, State) - VALUES ($1, $2, $3) - ON CONFLICT DO NOTHING - """ - ) { - bind("$1", measurement.measurementConsumerReferenceId) - bind("$2", measurement.measurementReferenceId) - bind("$3", Measurement.State.PENDING_VALUE) - } - - executeStatement(builder) - } - } - - return measurements.map { it.copy { state = Measurement.State.PENDING } } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/CreateReport.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/CreateReport.kt deleted file mode 100644 index fd358b20f9a..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/CreateReport.kt +++ /dev/null @@ -1,482 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres.writers - -import com.google.protobuf.util.Timestamps -import org.wfanet.measurement.common.db.r2dbc.boundStatement -import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter -import org.wfanet.measurement.common.identity.ExternalId -import org.wfanet.measurement.internal.reporting.CreateReportRequest -import org.wfanet.measurement.internal.reporting.CreateReportRequest.MeasurementKey -import org.wfanet.measurement.internal.reporting.Measurement -import org.wfanet.measurement.internal.reporting.Metric -import org.wfanet.measurement.internal.reporting.Metric.MeasurementCalculation.WeightedMeasurement -import org.wfanet.measurement.internal.reporting.Metric.NamedSetOperation -import org.wfanet.measurement.internal.reporting.Metric.SetOperation -import org.wfanet.measurement.internal.reporting.PeriodicTimeInterval -import org.wfanet.measurement.internal.reporting.Report -import org.wfanet.measurement.internal.reporting.TimeInterval -import org.wfanet.measurement.internal.reporting.copy -import org.wfanet.measurement.internal.reporting.timeInterval -import org.wfanet.measurement.reporting.deploy.postgres.readers.ReportingSetReader -import org.wfanet.measurement.reporting.service.internal.MeasurementCalculationTimeIntervalNotFoundException -import org.wfanet.measurement.reporting.service.internal.ReportingSetNotFoundException - -/** - * Inserts a Report into the database. - * - * Throws the following on [execute]: - * * [ReportingSetNotFoundException] ReportingSet not found - * * [MeasurementCalculationTimeIntervalNotFoundException] MeasurementCalculation TimeInterval not - * found. - */ -class CreateReport(private val request: CreateReportRequest) : PostgresWriter() { - override suspend fun TransactionScope.runTransaction(): Report { - val report = request.report - val internalReportId = idGenerator.generateInternalId().value - val externalReportId = idGenerator.generateExternalId().value - - val timeIntervals: List - val isPeriodic = report.timeCase == Report.TimeCase.PERIODIC_TIME_INTERVAL - if (isPeriodic) { - timeIntervals = ArrayList(report.periodicTimeInterval.intervalCount) - timeIntervals.add( - timeInterval { - startTime = report.periodicTimeInterval.startTime - endTime = Timestamps.add(startTime, report.periodicTimeInterval.increment) - } - ) - for (i in 1 until report.periodicTimeInterval.intervalCount) { - timeIntervals.add( - timeInterval { - startTime = timeIntervals[i - 1].endTime - endTime = Timestamps.add(startTime, report.periodicTimeInterval.increment) - } - ) - } - } else { - timeIntervals = report.timeIntervals.timeIntervalsList - } - - val statement = - boundStatement( - """ - INSERT INTO Reports (MeasurementConsumerReferenceId, ReportId, ExternalReportId, State, ReportDetails, ReportIdempotencyKey, CreateTime) - VALUES ($1, $2, $3, $4, $5, $6, now() at time zone 'utc') - """ - ) { - bind("$1", report.measurementConsumerReferenceId) - bind("$2", internalReportId) - bind("$3", externalReportId) - bind("$4", Report.State.RUNNING_VALUE) - bind("$5", report.details) - bind("$6", report.reportIdempotencyKey) - } - - transactionContext.run { - executeStatement(statement) - if (isPeriodic) { - insertPeriodicTimeInterval( - report.measurementConsumerReferenceId, - internalReportId, - report.periodicTimeInterval, - ) - } - insertMeasurements(request.measurementsList) - val timeIntervalMap = - insertTimeIntervals(report.measurementConsumerReferenceId, internalReportId, timeIntervals) - - report.metricsList.forEach { - insertMetric(report.measurementConsumerReferenceId, internalReportId, timeIntervalMap, it) - } - insertReportMeasurements(request.measurementsList, internalReportId) - } - - return report.copy { - this.externalReportId = externalReportId - state = Report.State.RUNNING - } - } - - private suspend fun TransactionScope.insertMeasurements( - measurements: Collection - ) { - val sql = - StringBuilder( - """ - INSERT INTO Measurements (MeasurementConsumerReferenceId, MeasurementReferenceId, State) - VALUES ($1, $2, $3) - """ - ) - val numParameters = 3 - var firstParam = 1 - for (i in 1 until measurements.size) { - firstParam += numParameters - sql.append(",($$firstParam, $${firstParam + 1}, $${firstParam + 2})") - } - sql.append("ON CONFLICT DO NOTHING") - - firstParam = 1 - val statement = - boundStatement(sql.toString()) { - measurements.forEach { - bind("$$firstParam", it.measurementConsumerReferenceId) - bind("$${firstParam + 1}", it.measurementReferenceId) - bind("$${firstParam + 2}", Measurement.State.PENDING_VALUE) - firstParam += numParameters - } - } - transactionContext.executeStatement(statement) - } - - private suspend fun TransactionScope.insertReportMeasurements( - measurements: Collection, - reportId: Long, - ) { - val sql = - StringBuilder( - """ - INSERT INTO ReportMeasurements (MeasurementConsumerReferenceId, MeasurementReferenceId, ReportId) - VALUES ($1, $2, $3) - """ - ) - val numParameters = 3 - var firstParam = 1 - for (i in 1 until measurements.size) { - firstParam += numParameters - sql.append(",($$firstParam, $${firstParam + 1}, $${firstParam + 2})") - } - - firstParam = 1 - val statement = - boundStatement(sql.toString()) { - measurements.forEach { - bind("$$firstParam", it.measurementConsumerReferenceId) - bind("$${firstParam + 1}", it.measurementReferenceId) - bind("$${firstParam + 2}", reportId) - firstParam += numParameters - } - } - transactionContext.executeStatement(statement) - } - - private suspend fun TransactionScope.insertPeriodicTimeInterval( - measurementConsumerReferenceId: String, - reportId: Long, - periodicTimeInterval: PeriodicTimeInterval, - ) { - val statement = - boundStatement( - """ - INSERT INTO PeriodicTimeIntervals (MeasurementConsumerReferenceId, ReportId, StartSeconds, StartNanos, IncrementSeconds, IncrementNanos, IntervalCount) - VALUES ($1, $2, $3, $4, $5, $6, $7) - """ - ) { - bind("$1", measurementConsumerReferenceId) - bind("$2", reportId) - bind("$3", periodicTimeInterval.startTime.seconds) - bind("$4", periodicTimeInterval.startTime.nanos) - bind("$5", periodicTimeInterval.increment.seconds) - bind("$6", periodicTimeInterval.increment.nanos) - bind("$7", periodicTimeInterval.intervalCount) - } - - transactionContext.executeStatement(statement) - } - - private suspend fun TransactionScope.insertTimeIntervals( - measurementConsumerReferenceId: String, - reportId: Long, - timeIntervals: Collection, - ): Map { - val sql = - StringBuilder( - """ - INSERT INTO TimeIntervals (MeasurementConsumerReferenceId, ReportId, TimeIntervalId, StartSeconds, StartNanos, EndSeconds, EndNanos) - VALUES ($1, $2, $3, $4, $5, $6, $7) - """ - ) - val numParameters = 7 - var firstParam = 1 - for (i in 1 until timeIntervals.size) { - firstParam += numParameters - sql.append( - """,($$firstParam, $${firstParam + 1}, $${firstParam + 2}, $${firstParam + 3}, - $${firstParam + 4}, $${firstParam + 5}, $${firstParam + 6}) - """ - ) - } - - val timeIntervalMap = mutableMapOf() - firstParam = 1 - val statement = - boundStatement(sql.toString()) { - timeIntervals.forEach { - val timeIntervalId = idGenerator.generateInternalId().value - timeIntervalMap[it] = timeIntervalId - bind("$$firstParam", measurementConsumerReferenceId) - bind("$${firstParam + 1}", reportId) - bind("$${firstParam + 2}", timeIntervalId) - bind("$${firstParam + 3}", it.startTime.seconds) - bind("$${firstParam + 4}", it.startTime.nanos) - bind("$${firstParam + 5}", it.endTime.seconds) - bind("$${firstParam + 6}", it.endTime.nanos) - firstParam += numParameters - } - } - transactionContext.executeStatement(statement) - return timeIntervalMap - } - - private suspend fun TransactionScope.insertMetric( - measurementConsumerReferenceId: String, - reportId: Long, - timeIntervalMap: Map, - metric: Metric, - ) { - val metricId = idGenerator.generateInternalId().value - - val statement = - boundStatement( - """ - INSERT INTO Metrics (MeasurementConsumerReferenceId, ReportId, MetricId, MetricDetails) - VALUES ($1, $2, $3, $4) - """ - ) { - bind("$1", measurementConsumerReferenceId) - bind("$2", reportId) - bind("$3", metricId) - bind("$4", metric.details) - } - - transactionContext.executeStatement(statement) - - metric.namedSetOperationsList.forEach { - insertNamedSetOperation( - measurementConsumerReferenceId, - reportId, - metricId, - timeIntervalMap, - it, - ) - } - } - - private suspend fun TransactionScope.insertNamedSetOperation( - measurementConsumerReferenceId: String, - reportId: Long, - metricId: Long, - timeIntervalMap: Map, - namedSetOperation: NamedSetOperation, - ) { - val namedSetOperationId = idGenerator.generateInternalId().value - val setOperationId = - insertSetOperation( - measurementConsumerReferenceId, - reportId, - metricId, - namedSetOperation.setOperation, - ) - - val statement = - boundStatement( - """ - INSERT INTO NamedSetOperations(MeasurementConsumerReferenceId, ReportId, MetricId, NamedSetOperationId, DisplayName, SetOperationId) - VALUES ($1, $2, $3, $4, $5, $6) - """ - ) { - bind("$1", measurementConsumerReferenceId) - bind("$2", reportId) - bind("$3", metricId) - bind("$4", namedSetOperationId) - bind("$5", namedSetOperation.displayName) - bind("$6", setOperationId) - } - - transactionContext.executeStatement(statement) - - namedSetOperation.measurementCalculationsList.forEach { - insertMeasurementCalculations( - measurementConsumerReferenceId, - reportId, - metricId, - namedSetOperationId, - timeIntervalMap[it.timeInterval] - ?: throw MeasurementCalculationTimeIntervalNotFoundException(), - it, - ) - } - } - - private suspend fun TransactionScope.insertMeasurementCalculations( - measurementConsumerReferenceId: String, - reportId: Long, - metricId: Long, - namedSetOperationId: Long, - timeIntervalId: Long, - measurementCalculation: Metric.MeasurementCalculation, - ) { - val measurementCalculationId = idGenerator.generateInternalId().value - val statement = - boundStatement( - """ - INSERT INTO MeasurementCalculations(MeasurementConsumerReferenceId, ReportId, MetricId, NamedSetOperationId, MeasurementCalculationId, TimeIntervalId) - VALUES ($1, $2, $3, $4, $5, $6) - """ - ) { - bind("$1", measurementConsumerReferenceId) - bind("$2", reportId) - bind("$3", metricId) - bind("$4", namedSetOperationId) - bind("$5", measurementCalculationId) - bind("$6", timeIntervalId) - } - - transactionContext.executeStatement(statement) - insertWeightedMeasurements( - measurementConsumerReferenceId, - reportId, - metricId, - namedSetOperationId, - measurementCalculationId, - measurementCalculation.weightedMeasurementsList, - ) - } - - private suspend fun TransactionScope.insertWeightedMeasurements( - measurementConsumerReferenceId: String, - reportId: Long, - metricId: Long, - namedSetOperationId: Long, - measurementCalculationId: Long, - weightedMeasurements: Collection, - ) { - transactionContext.run { - weightedMeasurements.forEach { - val weightedMeasurementId = idGenerator.generateInternalId().value - val statement = - boundStatement( - """ - INSERT INTO WeightedMeasurements(MeasurementConsumerReferenceId, ReportId, MetricId, NamedSetOperationId, MeasurementCalculationId, WeightedMeasurementId, MeasurementReferenceId, Coefficient) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) - """ - ) { - bind("$1", measurementConsumerReferenceId) - bind("$2", reportId) - bind("$3", metricId) - bind("$4", namedSetOperationId) - bind("$5", measurementCalculationId) - bind("$6", weightedMeasurementId) - bind("$7", it.measurementReferenceId) - bind("$8", it.coefficient) - } - executeStatement(statement) - } - } - } - - private suspend fun TransactionScope.insertSetOperation( - measurementConsumerReferenceId: String, - reportId: Long, - metricId: Long, - setOperation: SetOperation, - ): Long { - val setOperationId = idGenerator.generateInternalId().value - val lhsReportingSetId: Long? - val rhsReportingSetId: Long? - val lhsSetOperationId: Long? - val rhsSetOperationId: Long? - - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") - when (setOperation.lhs.operandCase) { - SetOperation.Operand.OperandCase.REPORTINGSETID -> { - lhsSetOperationId = null - val reportingSetResult = - ReportingSetReader() - .readReportingSetByExternalId( - transactionContext, - measurementConsumerReferenceId, - ExternalId(setOperation.lhs.reportingSetId.externalReportingSetId), - ) - lhsReportingSetId = reportingSetResult.reportingSetId.value - } - SetOperation.Operand.OperandCase.OPERATION -> { - lhsSetOperationId = - insertSetOperation( - measurementConsumerReferenceId, - reportId, - metricId, - setOperation.lhs.operation, - ) - lhsReportingSetId = null - } - SetOperation.Operand.OperandCase.OPERAND_NOT_SET -> { - lhsSetOperationId = null - lhsReportingSetId = null - } - } - - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") - when (setOperation.rhs.operandCase) { - SetOperation.Operand.OperandCase.REPORTINGSETID -> { - rhsSetOperationId = null - val reportingSetResult = - ReportingSetReader() - .readReportingSetByExternalId( - transactionContext, - measurementConsumerReferenceId, - ExternalId(setOperation.rhs.reportingSetId.externalReportingSetId), - ) - rhsReportingSetId = reportingSetResult.reportingSetId.value - } - SetOperation.Operand.OperandCase.OPERATION -> { - rhsSetOperationId = - insertSetOperation( - measurementConsumerReferenceId, - reportId, - metricId, - setOperation.rhs.operation, - ) - rhsReportingSetId = null - } - SetOperation.Operand.OperandCase.OPERAND_NOT_SET -> { - rhsSetOperationId = null - rhsReportingSetId = null - } - } - - val statement = - boundStatement( - """ - INSERT INTO SetOperations(MeasurementConsumerReferenceId, ReportId, MetricId, SetOperationId, Type, LeftHandSetOperationId, LeftHandReportingSetId, RightHandSetOperationId, RightHandReportingSetId) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) - """ - ) { - bind("$1", measurementConsumerReferenceId) - bind("$2", reportId) - bind("$3", metricId) - bind("$4", setOperationId) - bind("$5", setOperation.typeValue) - bind("$6", lhsSetOperationId) - bind("$7", lhsReportingSetId) - bind("$8", rhsSetOperationId) - bind("$9", rhsReportingSetId) - } - - transactionContext.executeStatement(statement) - - return setOperationId - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/CreateReportingSet.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/CreateReportingSet.kt deleted file mode 100644 index 12e734bc410..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/CreateReportingSet.kt +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres.writers - -import io.r2dbc.spi.R2dbcDataIntegrityViolationException -import kotlinx.coroutines.coroutineScope -import org.wfanet.measurement.common.db.r2dbc.boundStatement -import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter -import org.wfanet.measurement.common.identity.ExternalId -import org.wfanet.measurement.common.identity.InternalId -import org.wfanet.measurement.internal.reporting.ReportingSet -import org.wfanet.measurement.internal.reporting.copy -import org.wfanet.measurement.reporting.service.internal.ReportingSetAlreadyExistsException - -/** - * Inserts a Reporting Set into the database. - * - * Throws the following on [execute]: - * * [ReportingSetAlreadyExistsException] ReportingSet already exists - */ -class CreateReportingSet(private val request: ReportingSet) : PostgresWriter() { - override suspend fun TransactionScope.runTransaction(): ReportingSet { - val internalReportingSetId = idGenerator.generateInternalId() - val externalReportingSetId = idGenerator.generateExternalId() - - insertReportingSet(internalReportingSetId, externalReportingSetId) - coroutineScope { - for (i in 0 until request.eventGroupKeysList.size) { - insertReportingSetEventGroup(request.eventGroupKeysList[i], internalReportingSetId) - } - } - - return request.copy { this.externalReportingSetId = externalReportingSetId.value } - } - - private suspend fun TransactionScope.insertReportingSet( - internalReportingSetId: InternalId, - externalReportingSetId: ExternalId, - ) { - val statement = - boundStatement( - """ - INSERT INTO ReportingSets (MeasurementConsumerReferenceId, ReportingSetId, ExternalReportingSetId, Filter, DisplayName) - VALUES ($1, $2, $3, $4, $5) - """ - ) { - bind("$1", request.measurementConsumerReferenceId) - bind("$2", internalReportingSetId) - bind("$3", externalReportingSetId) - bind("$4", request.filter) - bind("$5", request.displayName) - } - - try { - transactionContext.executeStatement(statement) - } catch (e: R2dbcDataIntegrityViolationException) { - throw ReportingSetAlreadyExistsException() - } - } - - private suspend fun TransactionScope.insertReportingSetEventGroup( - eventGroupKey: ReportingSet.EventGroupKey, - reportingSetId: InternalId, - ) { - val statement = - boundStatement( - """ - INSERT INTO ReportingSetEventGroups (MeasurementConsumerReferenceId, DataProviderReferenceId, EventGroupReferenceId, ReportingSetId) - VALUES ($1, $2, $3, $4) - """ - ) { - bind("$1", eventGroupKey.measurementConsumerReferenceId) - bind("$2", eventGroupKey.dataProviderReferenceId) - bind("$3", eventGroupKey.eventGroupReferenceId) - bind("$4", reportingSetId) - } - - transactionContext.executeStatement(statement) - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/SetMeasurementFailure.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/SetMeasurementFailure.kt deleted file mode 100644 index 830f7d2a6e1..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/SetMeasurementFailure.kt +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres.writers - -import org.wfanet.measurement.common.db.r2dbc.boundStatement -import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter -import org.wfanet.measurement.internal.reporting.Measurement -import org.wfanet.measurement.internal.reporting.Report -import org.wfanet.measurement.internal.reporting.SetMeasurementFailureRequest -import org.wfanet.measurement.internal.reporting.measurement -import org.wfanet.measurement.reporting.deploy.postgres.readers.MeasurementReader -import org.wfanet.measurement.reporting.service.internal.MeasurementNotFoundException -import org.wfanet.measurement.reporting.service.internal.MeasurementStateInvalidException - -/** - * Update a [Measurement] to be in a failure state along with any dependent [Report]. - * - * Throws the following on [execute]: - * * [MeasurementNotFoundException] Measurement not found. - * * [MeasurementStateInvalidException] Measurement does not have PENDING state. - */ -class SetMeasurementFailure(private val request: SetMeasurementFailureRequest) : - PostgresWriter() { - override suspend fun TransactionScope.runTransaction(): Measurement { - val measurementResult = - MeasurementReader() - .readMeasurementByReferenceIds( - transactionContext, - measurementConsumerReferenceId = request.measurementConsumerReferenceId, - measurementReferenceId = request.measurementReferenceId, - ) ?: throw MeasurementNotFoundException() - - if (measurementResult.measurement.state != Measurement.State.PENDING) { - throw MeasurementStateInvalidException() - } - - val updateMeasurementStatement = - boundStatement( - """ - UPDATE Measurements - SET State = $1, Failure = $2 - WHERE MeasurementConsumerReferenceId = $3 AND MeasurementReferenceId = $4 - """ - ) { - bind("$1", Measurement.State.FAILED_VALUE) - bind("$2", request.failure) - bind("$3", request.measurementConsumerReferenceId) - bind("$4", request.measurementReferenceId) - } - - val updateReportStatement = - boundStatement( - """ - UPDATE Reports - SET State = $1 - FROM ( - SELECT - ReportId - FROM ReportMeasurements - WHERE MeasurementConsumerReferenceId = $2 AND MeasurementReferenceId = $3 - ) AS ReportMeasurements - WHERE Reports.ReportId = ReportMeasurements.ReportId - """ - ) { - bind("$1", Report.State.FAILED_VALUE) - bind("$2", request.measurementConsumerReferenceId) - bind("$3", request.measurementReferenceId) - } - - transactionContext.run { - val numRowsUpdated = executeStatement(updateMeasurementStatement).numRowsUpdated - if (numRowsUpdated == 0L) { - return@run - } - executeStatement(updateReportStatement) - } - - return measurement { - measurementConsumerReferenceId = request.measurementConsumerReferenceId - measurementReferenceId = request.measurementReferenceId - state = Measurement.State.FAILED - failure = request.failure - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/SetMeasurementResult.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/SetMeasurementResult.kt deleted file mode 100644 index 67db2e99ae2..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/writers/SetMeasurementResult.kt +++ /dev/null @@ -1,411 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres.writers - -import com.google.protobuf.Duration -import com.google.protobuf.duration -import com.google.protobuf.util.Durations -import com.google.protobuf.util.Timestamps -import java.time.Instant -import java.util.concurrent.TimeUnit -import org.wfanet.measurement.common.db.r2dbc.boundStatement -import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter -import org.wfanet.measurement.internal.reporting.Measurement -import org.wfanet.measurement.internal.reporting.Metric -import org.wfanet.measurement.internal.reporting.Metric.MeasurementCalculation -import org.wfanet.measurement.internal.reporting.PeriodicTimeInterval -import org.wfanet.measurement.internal.reporting.Report -import org.wfanet.measurement.internal.reporting.ReportKt -import org.wfanet.measurement.internal.reporting.SetMeasurementResultRequest -import org.wfanet.measurement.internal.reporting.TimeInterval -import org.wfanet.measurement.internal.reporting.copy -import org.wfanet.measurement.internal.reporting.measurement -import org.wfanet.measurement.internal.reporting.timeInterval -import org.wfanet.measurement.reporting.deploy.postgres.readers.MeasurementReader -import org.wfanet.measurement.reporting.deploy.postgres.readers.MeasurementResultsReader -import org.wfanet.measurement.reporting.deploy.postgres.readers.ReportReader -import org.wfanet.measurement.reporting.service.internal.MeasurementNotFoundException -import org.wfanet.measurement.reporting.service.internal.MeasurementStateInvalidException - -private const val NANOS_PER_SECOND = 1_000_000_000 - -/** - * Updates the result for a Measurement and for the corresponding Report too if the report result - * has been computed. - * - * Throws the following on [execute]: - * * [MeasurementNotFoundException] Measurement not found. - * * [MeasurementStateInvalidException] Measurement does not have PENDING state. - */ -class SetMeasurementResult(private val request: SetMeasurementResultRequest) : - PostgresWriter() { - data class MeasurementResult(val result: Measurement.Result, val coefficient: Int) - - override suspend fun TransactionScope.runTransaction(): Measurement { - val measurementResult = - MeasurementReader() - .readMeasurementByReferenceIds( - transactionContext, - measurementConsumerReferenceId = request.measurementConsumerReferenceId, - measurementReferenceId = request.measurementReferenceId, - ) ?: throw MeasurementNotFoundException() - - if (measurementResult.measurement.state != Measurement.State.PENDING) { - throw MeasurementStateInvalidException() - } - - val updateMeasurementStatement = - boundStatement( - """ - UPDATE Measurements - SET State = $1, Result = $2 - WHERE MeasurementConsumerReferenceId = $3 AND MeasurementReferenceId = $4 - """ - ) { - bind("$1", Measurement.State.SUCCEEDED_VALUE) - bind("$2", request.result) - bind("$3", request.measurementConsumerReferenceId) - bind("$4", request.measurementReferenceId) - } - - transactionContext.run { - val numRowsUpdated = executeStatement(updateMeasurementStatement).numRowsUpdated - if (numRowsUpdated == 0L) { - return@run - } - - val measurementResultsMap = mutableMapOf() - val reportsSet = mutableSetOf() - val incompleteReportsSet = mutableSetOf() - MeasurementResultsReader() - .listMeasurementsForReportsByMeasurementReferenceId( - transactionContext, - request.measurementConsumerReferenceId, - request.measurementReferenceId, - ) - .collect { result -> - if (result.state == Measurement.State.SUCCEEDED) { - reportsSet.add(result.reportId.value) - measurementResultsMap[result.measurementReferenceId] = result - } else { - incompleteReportsSet.add(result.reportId.value) - } - } - - reportsSet.forEach { - if (incompleteReportsSet.contains(it)) { - return@forEach - } - - val reportResult = - ReportReader() - .getReportById(transactionContext, request.measurementConsumerReferenceId, it) - - val updatedDetails = - reportResult.report.details.copy { - this.result = constructResult(reportResult.report, measurementResultsMap) - } - - val updateReportStatement = - boundStatement( - """ - UPDATE Reports - SET ReportDetails = $1, State = $2 - WHERE MeasurementConsumerReferenceId = $3 - AND ReportId = $4 - """ - ) { - bind("$1", updatedDetails) - bind("$2", Report.State.SUCCEEDED.number) - bind("$3", request.measurementConsumerReferenceId) - bind("$4", it) - } - executeStatement(updateReportStatement) - } - } - - return measurement { - measurementConsumerReferenceId = request.measurementConsumerReferenceId - measurementReferenceId = request.measurementReferenceId - state = Measurement.State.SUCCEEDED - result = request.result - } - } - - private fun constructResult( - report: Report, - measurementResultsMap: Map, - ): Report.Details.Result { - return ReportKt.DetailsKt.result { - val rowHeaders = getRowHeaders(report) - val scalarTableColumnsList = ArrayList() - // Each metric contains results for several columns - for (metric in report.metricsList) { - when (val metricType = metric.details.metricTypeCase) { - // REACH, IMPRESSION_COUNT, and WATCH_DURATION are aggregated in one table. - Metric.Details.MetricTypeCase.REACH, - Metric.Details.MetricTypeCase.IMPRESSION_COUNT, - Metric.Details.MetricTypeCase.WATCH_DURATION -> { - // One namedSetOperation is one column in the report - for (namedSetOperation in metric.namedSetOperationsList) { - scalarTableColumnsList += - ReportKt.DetailsKt.ResultKt.column { - columnHeader = buildColumnHeader(metricType.name, namedSetOperation.displayName) - setOperations += - namedSetOperation.sortedMeasurementCalculations.map { - calculateScalarResult( - metricType, - it.getMeasurementResults(measurementResultsMap), - ) - } - } - } - } - Metric.Details.MetricTypeCase.FREQUENCY_HISTOGRAM -> { - histogramTables += metric.toHistogramTable(rowHeaders, measurementResultsMap) - } - Metric.Details.MetricTypeCase.METRICTYPE_NOT_SET -> error("Metric Type should be set.") - } - } - if (scalarTableColumnsList.size > 0) { - scalarTable = - ReportKt.DetailsKt.ResultKt.scalarTable { - this.rowHeaders += rowHeaders - columns += scalarTableColumnsList - } - } - } - } - - /** Generate row headers of [Report.Details.Result] from a [Report]. */ - private fun getRowHeaders(report: Report): List { - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - return when (report.timeCase) { - Report.TimeCase.TIME_INTERVALS -> { - report.timeIntervals.timeIntervalsList - .sortedWith { a, b -> - val start = Timestamps.compare(a.startTime, b.startTime) - if (start != 0) { - start - } else { - Timestamps.compare(a.endTime, b.endTime) - } - } - .map(TimeInterval::toRowHeader) - } - Report.TimeCase.PERIODIC_TIME_INTERVAL -> { - - report.periodicTimeInterval.toTimeIntervals().map(TimeInterval::toRowHeader) - } - Report.TimeCase.TIME_NOT_SET -> { - error("Time should be set.") - } - } - } - - private fun PeriodicTimeInterval.toTimeIntervals(): List { - val source = this - var startTime = checkNotNull(source.startTime) - return (0 until source.intervalCount).map { - timeInterval { - this.startTime = startTime - this.endTime = Timestamps.add(startTime, source.increment) - startTime = this.endTime - } - } - } - - private val Metric.NamedSetOperation.sortedMeasurementCalculations: List - get() { - return measurementCalculationsList.sortedWith { a, b -> - val start = Timestamps.compare(a.timeInterval.startTime, b.timeInterval.startTime) - if (start != 0) { - start - } else { - Timestamps.compare(a.timeInterval.endTime, b.timeInterval.endTime) - } - } - } - - private fun MeasurementCalculation.getMeasurementResults( - measurementResultsMap: Map - ): List { - return weightedMeasurementsList.map { - MeasurementResult( - measurementResultsMap[it.measurementReferenceId]?.result - ?: throw MeasurementNotFoundException(), - it.coefficient, - ) - } - } - - /** Build a column header given the metric and set operation name. */ - private fun buildColumnHeader(metricTypeName: String, setOperationName: String): String { - return "${metricTypeName}_$setOperationName" - } - - /** Calculate the equation to get the scalar result. */ - private fun calculateScalarResult( - metricType: Metric.Details.MetricTypeCase, - measurementResultsList: List, - ): Double { - return when (metricType) { - Metric.Details.MetricTypeCase.REACH -> calculateReachResult(measurementResultsList) - Metric.Details.MetricTypeCase.IMPRESSION_COUNT -> - calculateImpressionResult(measurementResultsList) - Metric.Details.MetricTypeCase.WATCH_DURATION -> - calculateWatchDurationResult(measurementResultsList) - Metric.Details.MetricTypeCase.FREQUENCY_HISTOGRAM -> - error("$metricType is not a scalar metric type") - Metric.Details.MetricTypeCase.METRICTYPE_NOT_SET -> error("Metric Type should be set.") - } - } - - /** Calculate the reach result by summing up weighted [Measurement]s. */ - private fun calculateReachResult(measurementResultsList: List): Double { - return measurementResultsList - .sumOf { (result, coefficient) -> - if (!result.hasReach()) { - error("Reach measurement is missing.") - } - result.reach.value * coefficient - } - .toDouble() - } - - /** Calculate the impression result by summing up weighted [Measurement]s. */ - private fun calculateImpressionResult( - measurementCoefficientPairsList: List - ): Double { - return measurementCoefficientPairsList - .sumOf { (result, coefficient) -> - if (!result.hasImpression()) { - error("Impression measurement is missing.") - } - result.impression.value * coefficient - } - .toDouble() - } - - /** Calculate the watch duration result by summing up weighted [Measurement]s. */ - private fun calculateWatchDurationResult( - measurementCoefficientPairsList: List - ): Double { - val watchDuration = - measurementCoefficientPairsList - .map { (result, coefficient) -> - if (!result.hasWatchDuration()) { - error("Watch duration measurement is missing.") - } - result.watchDuration.value * coefficient - } - .reduce { sum, element -> sum + element } - - return watchDuration.seconds + (watchDuration.nanos.toDouble() / NANOS_PER_SECOND) - } - - private operator fun Duration.times(coefficient: Int): Duration { - val source = this - return duration { - val weightedTotalNanos: Long = - (TimeUnit.SECONDS.toNanos(source.seconds) + source.nanos) * coefficient - seconds = TimeUnit.NANOSECONDS.toSeconds(weightedTotalNanos) - nanos = (weightedTotalNanos % NANOS_PER_SECOND).toInt() - } - } - - private operator fun Duration.plus(other: Duration): Duration { - val source = this - return Durations.add(source, other) - } - - /** Calculate the frequency histogram result by summing up weighted [Measurement]s. */ - private fun calculateFrequencyHistogram( - measurementCoefficientPairsList: List - ): Map { - val aggregatedFrequencyHistogramMap = - measurementCoefficientPairsList - .map { (result, coefficient) -> - if (!result.hasFrequency() || !result.hasReach()) { - error("Reach-Frequency measurement is missing.") - } - val reach = result.reach.value - result.frequency.relativeFrequencyDistributionMap.mapValues { - it.value * coefficient * reach - } - } - .fold(mutableMapOf().withDefault { 0.0 }) { - aggregatedFrequencyHistogramMap: MutableMap, - weightedFrequencyHistogramMap -> - for ((frequency, count) in weightedFrequencyHistogramMap) { - aggregatedFrequencyHistogramMap[frequency] = - aggregatedFrequencyHistogramMap.getValue(frequency) + count - } - aggregatedFrequencyHistogramMap - } - - return aggregatedFrequencyHistogramMap - } - - /** Convert a [Metric] to a [Report.Details.Result.HistogramTable] of a [Report] */ - private fun Metric.toHistogramTable( - rowHeaders: List, - measurementResultsMap: Map, - ): Report.Details.Result.HistogramTable { - val setOperationFrequencyHistograms: List>> = - namedSetOperationsList.map { - it.sortedMeasurementCalculations.map { calculation -> - calculateFrequencyHistogram(calculation.getMeasurementResults(measurementResultsMap)) - } - } - val largestFrequency = - setOperationFrequencyHistograms.flatten().flatMap { it.keys }.maxOrNull() ?: 1L - - val source = this - return ReportKt.DetailsKt.ResultKt.histogramTable { - for (rowHeader in rowHeaders) { - for (frequency in 1..largestFrequency) { - rows += - ReportKt.DetailsKt.ResultKt.HistogramTableKt.row { - this.rowHeader = rowHeader - this.frequency = frequency.toInt() - } - } - } - for ((namedSetOperation, frequencyHistograms) in - source.namedSetOperationsList.zip(setOperationFrequencyHistograms)) { - columns += - ReportKt.DetailsKt.ResultKt.column { - columnHeader = - buildColumnHeader(source.details.metricTypeCase.name, namedSetOperation.displayName) - for (frequencyHistogram in frequencyHistograms) { - for (frequency in 1..largestFrequency) { - setOperations += frequencyHistogram.getOrDefault(frequency, 0.0) - } - } - } - } - } - } -} - -/** Convert a [TimeInterval] to a row header in String. */ -private fun TimeInterval.toRowHeader(): String { - val source = this - val startTimeInstant = - Instant.ofEpochSecond(source.startTime.seconds, source.startTime.nanos.toLong()) - val endTimeInstant = Instant.ofEpochSecond(source.endTime.seconds, source.endTime.nanos.toLong()) - return "$startTimeInstant-$endTimeInstant" -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/AkidPrincipalLookup.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/AkidPrincipalLookup.kt deleted file mode 100644 index afba11912f2..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/AkidPrincipalLookup.kt +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright 2022 The Cross-Media Measurement Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.wfanet.measurement.reporting.service.api.v1alpha - -import com.google.protobuf.ByteString -import java.io.File -import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey -import org.wfanet.measurement.common.api.AkidConfigPrincipalLookup -import org.wfanet.measurement.common.api.AkidConfigResourceNameLookup -import org.wfanet.measurement.common.api.PrincipalLookup -import org.wfanet.measurement.common.api.ResourceKey -import org.wfanet.measurement.common.api.toResourceKeyLookup -import org.wfanet.measurement.common.parseTextProto -import org.wfanet.measurement.config.AuthorityKeyToPrincipalMap -import org.wfanet.measurement.config.reporting.MeasurementConsumerConfig -import org.wfanet.measurement.config.reporting.MeasurementConsumerConfigs - -/** [PrincipalLookup] of [ReportingPrincipal] by authority key identifier (AKID). */ -class AkidPrincipalLookup( - akidConfig: AuthorityKeyToPrincipalMap, - measurementConsumerConfigs: MeasurementConsumerConfigs, -) : PrincipalLookup { - - /** - * Constructs [AkidConfigPrincipalLookup] from a file. - * - * @param akidConfig a [File] containing an [AuthorityKeyToPrincipalMap] message in text format - * @param measurementConsumerConfigs a [File] containing a [MeasurementConsumerConfigs] message in - * text format - */ - constructor( - akidConfig: File, - measurementConsumerConfigs: File, - ) : this( - parseTextProto(akidConfig, AuthorityKeyToPrincipalMap.getDefaultInstance()), - parseTextProto(measurementConsumerConfigs, MeasurementConsumerConfigs.getDefaultInstance()), - ) - - private val measurementConsumerConfigs: Map = - measurementConsumerConfigs.configsMap - - private val resourceKeyLookup = - AkidConfigResourceNameLookup(akidConfig).toResourceKeyLookup(MeasurementConsumerKey.FACTORY) - - override suspend fun getPrincipal(lookupKey: ByteString): ReportingPrincipal? { - val resourceKey: ResourceKey = resourceKeyLookup.getResourceKey(lookupKey) ?: return null - return when (resourceKey) { - is MeasurementConsumerKey -> { - val resourceName: String = resourceKey.toName() - val config = - measurementConsumerConfigs[resourceName] - ?: error("Missing MeasurementConsumerConfig for $resourceName") - MeasurementConsumerPrincipal(resourceKey, config) - } - else -> null - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/BUILD.bazel deleted file mode 100644 index 445524ac557..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/BUILD.bazel +++ /dev/null @@ -1,143 +0,0 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") - -package(default_visibility = ["//visibility:public"]) - -kt_jvm_library( - name = "resource_key", - srcs = glob(["*Key.kt"]) + ["IdVariable.kt"], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/common/api:resource_key", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", - ], -) - -kt_jvm_library( - name = "reporting_sets_service", - srcs = ["ReportingSetsService.kt"], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/api:public_api_version", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:principal_server_interceptor", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:reporting_principal", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:resource_key", - "//src/main/proto/wfa/measurement/internal/reporting:reporting_sets_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:page_token_kt_jvm_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:reporting_sets_service_kt_jvm_grpc_proto", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/identity", - ], -) - -kt_jvm_library( - name = "event_groups_service", - srcs = ["EventGroupsService.kt"], - deps = [ - ":resource_key", - "//imports/java/org/projectnessie/cel", - "//src/main/kotlin/org/wfanet/measurement/api:api_key_constants", - "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:packed_messages", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api:cel_env_provider", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api:encryption_key_pair_store", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:principal_server_interceptor", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:reporting_principal", - "//src/main/proto/wfa/measurement/api/v2alpha:event_group_kt_jvm_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:event_groups_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:event_group_kt_jvm_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:event_groups_service_kt_jvm_grpc_proto", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto/tink", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", - "@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/measurementconsumer", - ], -) - -kt_jvm_library( - name = "reports_service", - srcs = ["ReportsService.kt"], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/api:api_key_constants", - "//src/main/kotlin/org/wfanet/measurement/api:public_api_version", - "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:packed_messages", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api:encryption_key_pair_store", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:principal_server_interceptor", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:reporting_principal", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:resource_key", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:set_operation_compiler", - "//src/main/proto/wfa/measurement/api/v2alpha:certificate_kt_jvm_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:certificates_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:data_providers_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:measurement_consumers_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:measurement_kt_jvm_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:measurements_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/config/reporting:measurement_spec_config_kt_jvm_proto", - "//src/main/proto/wfa/measurement/internal/reporting:measurements_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/internal/reporting:reporting_sets_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/internal/reporting:reports_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:page_token_kt_jvm_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:reports_service_kt_jvm_grpc_proto", - "@wfa_common_jvm//imports/java/com/google/protobuf/util", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto:security_provider", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/identity", - "@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/dataprovider", - "@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/measurementconsumer", - ], -) - -kt_jvm_library( - name = "akid_principal_lookup", - srcs = ["AkidPrincipalLookup.kt"], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key", - "//src/main/kotlin/org/wfanet/measurement/common/api:akid_config_lookup", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:reporting_principal", - "//src/main/proto/wfa/measurement/config/reporting:measurement_consumer_config_kt_jvm_proto", - ], -) - -kt_jvm_library( - name = "set_operation_compiler", - srcs = ["SetOperationCompiler.kt"], - deps = [ - "//src/main/proto/wfa/measurement/reporting/v1alpha:report_kt_jvm_proto", - "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", - ], -) - -kt_jvm_library( - name = "context_keys", - srcs = ["ContextKeys.kt"], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/common/api:principal", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:reporting_principal", - "@wfa_common_jvm//imports/java/io/grpc:api", - "@wfa_common_jvm//imports/java/io/grpc:context", - ], -) - -kt_jvm_library( - name = "reporting_principal", - srcs = ["ReportingPrincipal.kt"], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key", - "//src/main/kotlin/org/wfanet/measurement/common/api:principal", - "//src/main/proto/wfa/measurement/config/reporting:measurement_consumer_config_kt_jvm_proto", - ], -) - -kt_jvm_library( - name = "principal_server_interceptor", - srcs = ["PrincipalServerInterceptor.kt"], - deps = [ - "context_keys", - ":reporting_principal", - "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key", - "//src/main/kotlin/org/wfanet/measurement/common/api/grpc", - "//src/main/kotlin/org/wfanet/measurement/common/identity", - "@wfa_common_jvm//imports/java/com/google/protobuf", - "@wfa_common_jvm//imports/java/io/grpc:api", - "@wfa_common_jvm//imports/java/io/grpc:context", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/identity", - ], -) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupKey.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupKey.kt deleted file mode 100644 index 98286678227..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupKey.kt +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.service.api.v1alpha - -import org.wfanet.measurement.common.ResourceNameParser -import org.wfanet.measurement.common.api.ResourceKey - -private val parser = - ResourceNameParser( - "measurementConsumers/{measurement_consumer}" + - "/dataProviders/{data_provider}/eventGroups/{event_group}" - ) - -/** [ResourceKey] of an EventGroup. */ -class EventGroupKey( - val measurementConsumerReferenceId: String, - val dataProviderReferenceId: String, - val eventGroupReferenceId: String, -) : ResourceKey { - override fun toName(): String { - return parser.assembleName( - mapOf( - IdVariable.MEASUREMENT_CONSUMER to measurementConsumerReferenceId, - IdVariable.DATA_PROVIDER to dataProviderReferenceId, - IdVariable.EVENT_GROUP to eventGroupReferenceId, - ) - ) - } - - companion object FACTORY : ResourceKey.Factory { - val defaultValue = EventGroupKey("", "", "") - - override fun fromName(resourceName: String): EventGroupKey? { - return parser.parseIdVars(resourceName)?.let { - EventGroupKey( - it.getValue(IdVariable.MEASUREMENT_CONSUMER), - it.getValue(IdVariable.DATA_PROVIDER), - it.getValue(IdVariable.EVENT_GROUP), - ) - } - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupParentKey.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupParentKey.kt deleted file mode 100644 index 53526048649..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupParentKey.kt +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.service.api.v1alpha - -import org.wfanet.measurement.common.ResourceNameParser -import org.wfanet.measurement.common.api.ResourceKey - -private val parser = - ResourceNameParser("measurementConsumers/{measurement_consumer}/dataProviders/{data_provider}") - -/** [ResourceKey] of a EventGroupParent. */ -data class EventGroupParentKey( - val measurementConsumerId: String, - val dataProviderReferenceId: String, -) : ResourceKey { - override fun toName(): String { - return parser.assembleName( - mapOf( - IdVariable.MEASUREMENT_CONSUMER to measurementConsumerId, - IdVariable.DATA_PROVIDER to dataProviderReferenceId, - ) - ) - } - - companion object FACTORY : ResourceKey.Factory { - val defaultValue = EventGroupParentKey("", "") - - override fun fromName(resourceName: String): EventGroupParentKey? { - return parser.parseIdVars(resourceName)?.let { - EventGroupParentKey( - it.getValue(IdVariable.MEASUREMENT_CONSUMER), - it.getValue(IdVariable.DATA_PROVIDER), - ) - } - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupsService.kt deleted file mode 100644 index e180f567cb7..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupsService.kt +++ /dev/null @@ -1,237 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.service.api.v1alpha - -import com.google.protobuf.DynamicMessage -import com.google.protobuf.kotlin.unpack -import io.grpc.Status -import io.grpc.StatusException -import java.security.GeneralSecurityException -import org.projectnessie.cel.common.types.Err -import org.projectnessie.cel.common.types.ref.Val -import org.wfanet.measurement.api.v2alpha.DataProviderKey as CmmsDataProviderKey -import org.wfanet.measurement.api.v2alpha.EncryptionPublicKey as CmmsEncryptionPublicKey -import org.wfanet.measurement.api.v2alpha.EventGroup as CmmsEventGroup -import org.wfanet.measurement.api.v2alpha.EventGroupKey as CmmsEventGroupKey -import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub -import org.wfanet.measurement.api.v2alpha.ListEventGroupsRequestKt.filter -import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey as CmmsMeasurementConsumerKey -import org.wfanet.measurement.api.v2alpha.listEventGroupsRequest as cmmsListEventGroupsRequest -import org.wfanet.measurement.api.withAuthenticationKey -import org.wfanet.measurement.common.api.ResourceKey -import org.wfanet.measurement.common.crypto.PrivateKeyHandle -import org.wfanet.measurement.common.grpc.failGrpc -import org.wfanet.measurement.consent.client.measurementconsumer.decryptMetadata -import org.wfanet.measurement.reporting.service.api.CelEnvProvider -import org.wfanet.measurement.reporting.service.api.EncryptionKeyPairStore -import org.wfanet.measurement.reporting.v1alpha.EventGroup -import org.wfanet.measurement.reporting.v1alpha.EventGroupKt.eventTemplate -import org.wfanet.measurement.reporting.v1alpha.EventGroupKt.metadata -import org.wfanet.measurement.reporting.v1alpha.EventGroupsGrpcKt.EventGroupsCoroutineImplBase -import org.wfanet.measurement.reporting.v1alpha.ListEventGroupsRequest -import org.wfanet.measurement.reporting.v1alpha.ListEventGroupsResponse -import org.wfanet.measurement.reporting.v1alpha.eventGroup -import org.wfanet.measurement.reporting.v1alpha.listEventGroupsResponse - -private const val METADATA_FIELD = "metadata.metadata" - -private const val MIN_PAGE_SIZE = 1 -private const val DEFAULT_PAGE_SIZE = 50 -private const val MAX_PAGE_SIZE = 1000 - -class EventGroupsService( - private val cmmsEventGroupsStub: EventGroupsCoroutineStub, - private val encryptionKeyPairStore: EncryptionKeyPairStore, - private val celEnvProvider: CelEnvProvider, -) : EventGroupsCoroutineImplBase() { - override suspend fun listEventGroups(request: ListEventGroupsRequest): ListEventGroupsResponse { - val principal: ReportingPrincipal = principalFromCurrentContext - - if (principal !is MeasurementConsumerPrincipal) { - failGrpc(Status.PERMISSION_DENIED) { - "Cannot list event groups with entities other than measurement consumer." - } - } - - val principalName = principal.resourceKey.toName() - val apiAuthenticationKey: String = principal.config.apiKey - val parentKey = - EventGroupParentKey.fromName(request.parent) - ?: failGrpc(Status.INVALID_ARGUMENT) { "parent malformed or unspecified" } - val pageSize = - when { - request.pageSize < MIN_PAGE_SIZE -> DEFAULT_PAGE_SIZE - request.pageSize > MAX_PAGE_SIZE -> MAX_PAGE_SIZE - else -> request.pageSize - } - - val cmmsListEventGroupResponse = - try { - cmmsEventGroupsStub - .withAuthenticationKey(apiAuthenticationKey) - .listEventGroups( - cmmsListEventGroupsRequest { - parent = CmmsMeasurementConsumerKey(parentKey.measurementConsumerId).toName() - this.pageSize = pageSize - pageToken = request.pageToken - filter = filter { - if (parentKey.dataProviderReferenceId != ResourceKey.WILDCARD_ID) { - dataProviders += CmmsDataProviderKey(parentKey.dataProviderReferenceId).toName() - } - } - } - ) - } catch (e: StatusException) { - throw when (e.status.code) { - Status.Code.DEADLINE_EXCEEDED -> Status.DEADLINE_EXCEEDED - Status.Code.CANCELLED -> Status.CANCELLED - else -> Status.UNKNOWN - } - .withCause(e) - .asRuntimeException() - } - val cmmsEventGroups = cmmsListEventGroupResponse.eventGroupsList - - val eventGroups = - cmmsEventGroups.map { - val cmmsMetadata: CmmsEventGroup.Metadata? = - if (it.hasEncryptedMetadata()) { - decryptMetadata(it, principalName) - } else { - null - } - - it.toEventGroup(cmmsMetadata) - } - - return listEventGroupsResponse { - this.eventGroups += filterEventGroups(eventGroups, request.filter) - nextPageToken = cmmsListEventGroupResponse.nextPageToken - } - } - - private suspend fun filterEventGroups( - eventGroups: List, - filter: String, - ): List { - if (filter.isEmpty()) { - return eventGroups - } - - val typeRegistryAndEnv = celEnvProvider.getTypeRegistryAndEnv() - val env = typeRegistryAndEnv.env - val typeRegistry = typeRegistryAndEnv.typeRegistry - - val astAndIssues = env.compile(filter) - if (astAndIssues.hasIssues()) { - throw Status.INVALID_ARGUMENT.withDescription( - "filter is not a valid CEL expression: ${astAndIssues.issues}" - ) - .asRuntimeException() - } - val program = env.program(astAndIssues.ast) - - eventGroups - .filter { it.hasMetadata() } - .distinctBy { it.metadata.metadata.typeUrl } - .forEach { - val typeUrl = it.metadata.metadata.typeUrl - typeRegistry.getDescriptorForTypeUrl(typeUrl) - ?: throw Status.FAILED_PRECONDITION.withDescription( - "${it.metadata.eventGroupMetadataDescriptor} does not contain descriptor for $typeUrl" - ) - .asRuntimeException() - } - - return eventGroups.filter { eventGroup -> - val variables: Map = - mutableMapOf().apply { - for (fieldDescriptor in eventGroup.descriptorForType.fields) { - put(fieldDescriptor.name, eventGroup.getField(fieldDescriptor)) - } - // TODO(projectnessie/cel-java#295): Remove when fixed. - if (eventGroup.hasMetadata()) { - val metadata: com.google.protobuf.Any = eventGroup.metadata.metadata - put( - METADATA_FIELD, - DynamicMessage.parseFrom( - typeRegistry.getDescriptorForTypeUrl(metadata.typeUrl), - metadata.value, - ), - ) - } - } - val result: Val = program.eval(variables).`val` - if (result is Err) { - // For when the field in the filter doesn't exist in the event group. - if (result.toString().contains("undeclared reference to")) { - return@filter false - } - throw result.toRuntimeException() - } - result.booleanValue() - } - } - - private suspend fun decryptMetadata( - cmmsEventGroup: CmmsEventGroup, - principalName: String, - ): CmmsEventGroup.Metadata { - if (!cmmsEventGroup.hasMeasurementConsumerPublicKey()) { - failGrpc(Status.FAILED_PRECONDITION) { - "EventGroup ${cmmsEventGroup.name} has encrypted metadata but no encryption public key" - } - } - val encryptionKey: CmmsEncryptionPublicKey = - cmmsEventGroup.measurementConsumerPublicKey.unpack() - val decryptionKeyHandle: PrivateKeyHandle = - encryptionKeyPairStore.getPrivateKeyHandle(principalName, encryptionKey.data) - ?: failGrpc(Status.FAILED_PRECONDITION) { - "Public key does not have corresponding private key" - } - return try { - decryptMetadata(cmmsEventGroup.encryptedMetadata, decryptionKeyHandle) - } catch (e: GeneralSecurityException) { - throw Status.FAILED_PRECONDITION.withCause(e) - .withDescription("Metadata cannot be decrypted") - .asRuntimeException() - } - } -} - -private fun CmmsEventGroup.toEventGroup(cmmsMetadata: CmmsEventGroup.Metadata?): EventGroup { - val source = this - val cmmsEventGroupKey = requireNotNull(CmmsEventGroupKey.fromName(name)) - val measurementConsumerKey = - requireNotNull(CmmsMeasurementConsumerKey.fromName(measurementConsumer)) - return eventGroup { - name = - EventGroupKey( - measurementConsumerKey.measurementConsumerId, - cmmsEventGroupKey.dataProviderId, - cmmsEventGroupKey.eventGroupId, - ) - .toName() - dataProvider = CmmsDataProviderKey(cmmsEventGroupKey.dataProviderId).toName() - eventGroupReferenceId = source.eventGroupReferenceId - eventTemplates += source.eventTemplatesList.map { eventTemplate { type = it.type } } - if (cmmsMetadata != null) { - metadata = metadata { - eventGroupMetadataDescriptor = cmmsMetadata.eventGroupMetadataDescriptor - metadata = cmmsMetadata.metadata - } - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/IdVariable.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/IdVariable.kt deleted file mode 100644 index ed764300324..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/IdVariable.kt +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.service.api.v1alpha - -import java.util.Locale -import org.wfanet.measurement.common.ResourceNameParser - -internal enum class IdVariable { - DATA_PROVIDER, - EVENT_GROUP, - MEASUREMENT_CONSUMER, - REPORTING_SET, - REPORT, -} - -internal fun ResourceNameParser.assembleName(idMap: Map): String { - return assembleName(idMap.mapKeys { it.key.name.lowercase(Locale.getDefault()) }) -} - -internal fun ResourceNameParser.parseIdVars(resourceName: String): Map? { - return parseIdSegments(resourceName)?.mapKeys { - IdVariable.valueOf(it.key.uppercase(Locale.getDefault())) - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/PrincipalServerInterceptor.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/PrincipalServerInterceptor.kt deleted file mode 100644 index e52fbc49861..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/PrincipalServerInterceptor.kt +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.service.api.v1alpha - -import com.google.protobuf.ByteString -import io.grpc.BindableService -import io.grpc.Context -import io.grpc.ServerInterceptors -import io.grpc.ServerServiceDefinition -import io.grpc.Status -import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey -import org.wfanet.measurement.common.api.PrincipalLookup -import org.wfanet.measurement.common.api.grpc.AkidPrincipalServerInterceptor -import org.wfanet.measurement.common.grpc.failGrpc -import org.wfanet.measurement.common.identity.AuthorityKeyServerInterceptor -import org.wfanet.measurement.config.reporting.MeasurementConsumerConfig - -/** - * Returns a [ReportingPrincipal] in the current gRPC context. Requires [PrincipalServerInterceptor] - * to be installed. - * - * Callers can trust that the [ReportingPrincipal] is authenticated (but not necessarily - * authorized). - */ -val principalFromCurrentContext: ReportingPrincipal - get() = - ContextKeys.PRINCIPAL_CONTEXT_KEY.get() - ?: failGrpc(Status.UNAUTHENTICATED) { "No ReportingPrincipal found" } - -/** - * Executes [block] with [principal] installed in a new [Context]. - * - * The caller of [withPrincipal] is responsible for guaranteeing that [block] can act as [principal] - * -- in other words, [principal] is treated as already authenticated. - */ -fun withPrincipal(principal: ReportingPrincipal, block: () -> T): T { - return Context.current().withPrincipal(principal).call(block) -} - -/** Executes [block] with a [MeasurementConsumerPrincipal] installed in a new [Context]. */ -fun withMeasurementConsumerPrincipal( - measurementConsumerName: String, - config: MeasurementConsumerConfig, - block: () -> T, -): T { - return Context.current() - .withPrincipal( - MeasurementConsumerPrincipal( - MeasurementConsumerKey.fromName(measurementConsumerName)!!, - config, - ) - ) - .call(block) -} - -/** Adds [principal] to the receiver and returns the new [Context]. */ -fun Context.withPrincipal(principal: ReportingPrincipal): Context { - return withValue(ContextKeys.PRINCIPAL_CONTEXT_KEY, principal) -} - -/** Convenience helper for [AkidPrincipalServerInterceptor]. */ -fun BindableService.withPrincipalsFromX509AuthorityKeyIdentifiers( - akidPrincipalLookup: PrincipalLookup -): ServerServiceDefinition { - return ServerInterceptors.interceptForward( - this, - AuthorityKeyServerInterceptor(), - AkidPrincipalServerInterceptor( - ContextKeys.PRINCIPAL_CONTEXT_KEY, - AuthorityKeyServerInterceptor.AUTHORITY_KEY_IDENTIFIERS_CONTEXT_KEY, - akidPrincipalLookup, - ), - ) -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportKey.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportKey.kt deleted file mode 100644 index 11037f6824a..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportKey.kt +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.service.api.v1alpha - -import org.wfanet.measurement.common.ResourceNameParser -import org.wfanet.measurement.common.api.ResourceKey - -private val parser = - ResourceNameParser("measurementConsumers/{measurement_consumer}/reports/{report}") - -/** [ResourceKey] of a Report. */ -data class ReportKey(val measurementConsumerId: String, val reportId: String) : ResourceKey { - override fun toName(): String { - return parser.assembleName( - mapOf(IdVariable.MEASUREMENT_CONSUMER to measurementConsumerId, IdVariable.REPORT to reportId) - ) - } - - companion object FACTORY : ResourceKey.Factory { - val defaultValue = ReportKey("", "") - - override fun fromName(resourceName: String): ReportKey? { - return parser.parseIdVars(resourceName)?.let { - ReportKey(it.getValue(IdVariable.MEASUREMENT_CONSUMER), it.getValue(IdVariable.REPORT)) - } - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportingPrincipal.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportingPrincipal.kt deleted file mode 100644 index 76c43f1d704..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportingPrincipal.kt +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.service.api.v1alpha - -import org.wfanet.measurement.api.v2alpha.MeasurementConsumerCertificateKey -import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey -import org.wfanet.measurement.common.api.Principal -import org.wfanet.measurement.common.api.ResourcePrincipal -import org.wfanet.measurement.config.reporting.MeasurementConsumerConfig - -/** Identifies the sender of an inbound gRPC request. */ -sealed interface ReportingPrincipal : Principal { - val config: MeasurementConsumerConfig - - companion object { - fun fromConfigs(name: String, config: MeasurementConsumerConfig): ReportingPrincipal? { - return when (name.substringBefore('/')) { - MeasurementConsumerKey.COLLECTION_NAME -> { - require( - config.apiKey.isNotBlank() && - MeasurementConsumerCertificateKey.fromName(config.signingCertificateName) != null && - config.signingPrivateKeyPath.isNotBlank() - ) - MeasurementConsumerKey.fromName(name)?.let { MeasurementConsumerPrincipal(it, config) } - } - else -> null - } - } - } -} - -data class MeasurementConsumerPrincipal( - override val resourceKey: MeasurementConsumerKey, - override val config: MeasurementConsumerConfig, -) : ReportingPrincipal, ResourcePrincipal diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportingSetKey.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportingSetKey.kt deleted file mode 100644 index 7869652a1ff..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportingSetKey.kt +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.service.api.v1alpha - -import org.wfanet.measurement.common.ResourceNameParser -import org.wfanet.measurement.common.api.ResourceKey - -private val parser = - ResourceNameParser("measurementConsumers/{measurement_consumer}/reportingSets/{reporting_set}") - -/** [ResourceKey] of a ReportingSet. */ -data class ReportingSetKey(val measurementConsumerId: String, val reportingSetId: String) : - ResourceKey { - override fun toName(): String { - return parser.assembleName( - mapOf( - IdVariable.MEASUREMENT_CONSUMER to measurementConsumerId, - IdVariable.REPORTING_SET to reportingSetId, - ) - ) - } - - companion object FACTORY : ResourceKey.Factory { - val defaultValue = ReportingSetKey("", "") - - override fun fromName(resourceName: String): ReportingSetKey? { - return parser.parseIdVars(resourceName)?.let { - ReportingSetKey( - it.getValue(IdVariable.MEASUREMENT_CONSUMER), - it.getValue(IdVariable.REPORTING_SET), - ) - } - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportingSetsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportingSetsService.kt deleted file mode 100644 index 12ecd237e27..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportingSetsService.kt +++ /dev/null @@ -1,235 +0,0 @@ -/* - * Copyright 2022 The Cross-Media Measurement Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.wfanet.measurement.reporting.service.api.v1alpha - -import io.grpc.Status -import kotlin.math.min -import kotlinx.coroutines.flow.toList -import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey -import org.wfanet.measurement.common.base64UrlDecode -import org.wfanet.measurement.common.base64UrlEncode -import org.wfanet.measurement.common.grpc.failGrpc -import org.wfanet.measurement.common.grpc.grpcRequire -import org.wfanet.measurement.common.grpc.grpcRequireNotNull -import org.wfanet.measurement.common.identity.externalIdToApiId -import org.wfanet.measurement.internal.reporting.ReportingSet as InternalReportingSet -import org.wfanet.measurement.internal.reporting.ReportingSet.EventGroupKey as InternalEventGroupKey -import org.wfanet.measurement.internal.reporting.ReportingSetKt.eventGroupKey as internalEventGroupKey -import org.wfanet.measurement.internal.reporting.ReportingSetsGrpcKt.ReportingSetsCoroutineStub -import org.wfanet.measurement.internal.reporting.StreamReportingSetsRequest -import org.wfanet.measurement.internal.reporting.StreamReportingSetsRequestKt.filter -import org.wfanet.measurement.internal.reporting.reportingSet as internalReportingSet -import org.wfanet.measurement.internal.reporting.streamReportingSetsRequest -import org.wfanet.measurement.reporting.v1alpha.CreateReportingSetRequest -import org.wfanet.measurement.reporting.v1alpha.ListReportingSetsPageToken -import org.wfanet.measurement.reporting.v1alpha.ListReportingSetsPageTokenKt.previousPageEnd -import org.wfanet.measurement.reporting.v1alpha.ListReportingSetsRequest -import org.wfanet.measurement.reporting.v1alpha.ListReportingSetsResponse -import org.wfanet.measurement.reporting.v1alpha.ReportingSet -import org.wfanet.measurement.reporting.v1alpha.ReportingSetsGrpcKt.ReportingSetsCoroutineImplBase -import org.wfanet.measurement.reporting.v1alpha.copy -import org.wfanet.measurement.reporting.v1alpha.listReportingSetsPageToken -import org.wfanet.measurement.reporting.v1alpha.listReportingSetsResponse -import org.wfanet.measurement.reporting.v1alpha.reportingSet - -private const val MIN_PAGE_SIZE = 1 -private const val DEFAULT_PAGE_SIZE = 50 -private const val MAX_PAGE_SIZE = 1000 - -class ReportingSetsService(private val internalReportingSetsStub: ReportingSetsCoroutineStub) : - ReportingSetsCoroutineImplBase() { - override suspend fun createReportingSet(request: CreateReportingSetRequest): ReportingSet { - val parentKey = - grpcRequireNotNull(MeasurementConsumerKey.fromName(request.parent)) { - "Parent is either unspecified or invalid." - } - - when (val principal: ReportingPrincipal = principalFromCurrentContext) { - is MeasurementConsumerPrincipal -> { - if (request.parent != principal.resourceKey.toName()) { - failGrpc(Status.PERMISSION_DENIED) { - "Cannot create a ReportingSet for another MeasurementConsumer." - } - } - } - } - - grpcRequire(request.hasReportingSet()) { "ReportingSet is not specified." } - - grpcRequire(request.reportingSet.eventGroupsList.isNotEmpty()) { - "EventGroups in ReportingSet cannot be empty." - } - - return internalReportingSetsStub - .createReportingSet(request.reportingSet.toInternal(parentKey)) - .toReportingSet() - } - - override suspend fun listReportingSets( - request: ListReportingSetsRequest - ): ListReportingSetsResponse { - val listReportingSetsPageToken = request.toListReportingSetsPageToken() - - // Based on AIP-132#Errors - when (val principal: ReportingPrincipal = principalFromCurrentContext) { - is MeasurementConsumerPrincipal -> { - if (request.parent != principal.resourceKey.toName()) { - failGrpc(Status.PERMISSION_DENIED) { - "Cannot list ReportingSets belonging to other MeasurementConsumers." - } - } - } - } - - val results: List = - internalReportingSetsStub - .streamReportingSets(listReportingSetsPageToken.toStreamReportingSetsRequest()) - .toList() - - if (results.isEmpty()) { - return ListReportingSetsResponse.getDefaultInstance() - } - - return listReportingSetsResponse { - reportingSets += - results - .subList(0, min(results.size, listReportingSetsPageToken.pageSize)) - .map(InternalReportingSet::toReportingSet) - - if (results.size > listReportingSetsPageToken.pageSize) { - val pageToken = - listReportingSetsPageToken.copy { - lastReportingSet = previousPageEnd { - measurementConsumerReferenceId = - results[results.lastIndex - 1].measurementConsumerReferenceId - externalReportingSetId = results[results.lastIndex - 1].externalReportingSetId - } - } - nextPageToken = pageToken.toByteString().base64UrlEncode() - } - } - } -} - -/** - * Converts an internal [ListReportingSetsPageToken] to an internal [StreamReportingSetsRequest]. - */ -private fun ListReportingSetsPageToken.toStreamReportingSetsRequest(): StreamReportingSetsRequest { - val source = this - return streamReportingSetsRequest { - // get 1 more than the actual page size for deciding whether or not to set page token - limit = pageSize + 1 - filter = filter { - measurementConsumerReferenceId = source.measurementConsumerReferenceId - externalReportingSetIdAfter = source.lastReportingSet.externalReportingSetId - } - } -} - -/** Converts a public [ListReportingSetsRequest] to an internal [ListReportingSetsPageToken]. */ -private fun ListReportingSetsRequest.toListReportingSetsPageToken(): ListReportingSetsPageToken { - grpcRequire(pageSize >= 0) { "Page size cannot be less than 0" } - - val source = this - val parentKey: MeasurementConsumerKey = - grpcRequireNotNull(MeasurementConsumerKey.fromName(parent)) { - "Parent is either unspecified or invalid." - } - val measurementConsumerReferenceId = parentKey.measurementConsumerId - - return if (pageToken.isNotBlank()) { - ListReportingSetsPageToken.parseFrom(pageToken.base64UrlDecode()).copy { - grpcRequire(this.measurementConsumerReferenceId == measurementConsumerReferenceId) { - "Arguments must be kept the same when using a page token" - } - - if ( - source.pageSize != 0 && source.pageSize >= MIN_PAGE_SIZE && source.pageSize <= MAX_PAGE_SIZE - ) { - pageSize = source.pageSize - } - } - } else { - listReportingSetsPageToken { - pageSize = - when { - source.pageSize < MIN_PAGE_SIZE -> DEFAULT_PAGE_SIZE - source.pageSize > MAX_PAGE_SIZE -> MAX_PAGE_SIZE - else -> source.pageSize - } - this.measurementConsumerReferenceId = measurementConsumerReferenceId - } - } -} - -/** Converts a public [ReportingSet] to an internal [InternalReportingSet]. */ -private fun ReportingSet.toInternal( - measurementConsumerKey: MeasurementConsumerKey -): InternalReportingSet { - val source = this - - return internalReportingSet { - measurementConsumerReferenceId = measurementConsumerKey.measurementConsumerId - - for (eventGroup: String in source.eventGroupsList) { - eventGroupKeys += - grpcRequireNotNull(EventGroupKey.fromName(eventGroup)) { - "EventGroup is either unspecified or invalid." - } - .toInternal(measurementConsumerKey.measurementConsumerId) - } - filter = source.filter - displayName = source.displayName - } -} - -/** Converts a public [EventGroupKey] to an internal [InternalEventGroupKey] */ -private fun EventGroupKey.toInternal( - measurementConsumerReferenceId: String -): InternalEventGroupKey { - val source = this - return internalEventGroupKey { - dataProviderReferenceId = source.dataProviderReferenceId - eventGroupReferenceId = source.eventGroupReferenceId - this.measurementConsumerReferenceId = measurementConsumerReferenceId - } -} - -/** Converts an internal [InternalReportingSet] to a public [ReportingSet]. */ -private fun InternalReportingSet.toReportingSet(): ReportingSet { - val source = this - return reportingSet { - name = - ReportingSetKey( - measurementConsumerId = source.measurementConsumerReferenceId, - reportingSetId = externalIdToApiId(source.externalReportingSetId), - ) - .toName() - eventGroups.addAll( - eventGroupKeysList.map { - EventGroupKey( - measurementConsumerReferenceId = it.measurementConsumerReferenceId, - dataProviderReferenceId = it.dataProviderReferenceId, - eventGroupReferenceId = it.eventGroupReferenceId, - ) - .toName() - } - ) - filter = source.filter - displayName = source.displayName - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportsService.kt deleted file mode 100644 index 6308cdbbc9e..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportsService.kt +++ /dev/null @@ -1,2260 +0,0 @@ -/* - * Copyright 2022 The Cross-Media Measurement Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.wfanet.measurement.reporting.service.api.v1alpha - -import com.google.protobuf.Any as ProtoAny -import com.google.protobuf.ByteString -import com.google.protobuf.Duration as ProtoDuration -import com.google.protobuf.kotlin.unpack -import com.google.protobuf.util.Durations -import com.google.protobuf.util.Timestamps -import com.google.type.Interval -import com.google.type.interval -import io.grpc.Status -import io.grpc.StatusException -import java.io.File -import java.security.PrivateKey -import java.security.SignatureException -import java.security.cert.CertPathValidatorException -import java.security.cert.X509Certificate -import java.time.Instant -import kotlin.math.min -import kotlin.random.Random -import kotlinx.coroutines.Deferred -import kotlinx.coroutines.async -import kotlinx.coroutines.awaitAll -import kotlinx.coroutines.coroutineScope -import kotlinx.coroutines.flow.toList -import org.wfanet.measurement.api.v2alpha.Certificate -import org.wfanet.measurement.api.v2alpha.CertificatesGrpcKt.CertificatesCoroutineStub -import org.wfanet.measurement.api.v2alpha.CreateMeasurementRequest -import org.wfanet.measurement.api.v2alpha.DataProvider -import org.wfanet.measurement.api.v2alpha.DataProviderKey -import org.wfanet.measurement.api.v2alpha.DataProvidersGrpcKt.DataProvidersCoroutineStub -import org.wfanet.measurement.api.v2alpha.EncryptionPublicKey -import org.wfanet.measurement.api.v2alpha.EventGroupKey as CmmsEventGroupKey -import org.wfanet.measurement.api.v2alpha.Measurement -import org.wfanet.measurement.api.v2alpha.Measurement.DataProviderEntry -import org.wfanet.measurement.api.v2alpha.MeasurementConsumer -import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey -import org.wfanet.measurement.api.v2alpha.MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineStub -import org.wfanet.measurement.api.v2alpha.MeasurementKey -import org.wfanet.measurement.api.v2alpha.MeasurementKt.DataProviderEntryKt -import org.wfanet.measurement.api.v2alpha.MeasurementKt.dataProviderEntry -import org.wfanet.measurement.api.v2alpha.MeasurementSpec -import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt -import org.wfanet.measurement.api.v2alpha.MeasurementsGrpcKt.MeasurementsCoroutineStub -import org.wfanet.measurement.api.v2alpha.RequisitionSpec.EventGroupEntry -import org.wfanet.measurement.api.v2alpha.RequisitionSpecKt -import org.wfanet.measurement.api.v2alpha.SignedMessage -import org.wfanet.measurement.api.v2alpha.createMeasurementRequest -import org.wfanet.measurement.api.v2alpha.differentialPrivacyParams -import org.wfanet.measurement.api.v2alpha.getCertificateRequest -import org.wfanet.measurement.api.v2alpha.getDataProviderRequest -import org.wfanet.measurement.api.v2alpha.getMeasurementConsumerRequest -import org.wfanet.measurement.api.v2alpha.getMeasurementRequest -import org.wfanet.measurement.api.v2alpha.measurement -import org.wfanet.measurement.api.v2alpha.measurementSpec -import org.wfanet.measurement.api.v2alpha.requisitionSpec -import org.wfanet.measurement.api.v2alpha.unpack -import org.wfanet.measurement.api.withAuthenticationKey -import org.wfanet.measurement.common.base64UrlDecode -import org.wfanet.measurement.common.base64UrlEncode -import org.wfanet.measurement.common.crypto.Hashing -import org.wfanet.measurement.common.crypto.PrivateKeyHandle -import org.wfanet.measurement.common.crypto.SigningKeyHandle -import org.wfanet.measurement.common.crypto.authorityKeyIdentifier -import org.wfanet.measurement.common.crypto.readCertificate -import org.wfanet.measurement.common.crypto.readPrivateKey -import org.wfanet.measurement.common.grpc.failGrpc -import org.wfanet.measurement.common.grpc.grpcRequire -import org.wfanet.measurement.common.grpc.grpcRequireNotNull -import org.wfanet.measurement.common.identity.apiIdToExternalId -import org.wfanet.measurement.common.identity.externalIdToApiId -import org.wfanet.measurement.common.readByteString -import org.wfanet.measurement.config.reporting.MeasurementSpecConfig -import org.wfanet.measurement.consent.client.measurementconsumer.decryptResult -import org.wfanet.measurement.consent.client.measurementconsumer.encryptRequisitionSpec -import org.wfanet.measurement.consent.client.measurementconsumer.signMeasurementSpec -import org.wfanet.measurement.consent.client.measurementconsumer.signRequisitionSpec -import org.wfanet.measurement.consent.client.measurementconsumer.verifyEncryptionPublicKey -import org.wfanet.measurement.consent.client.measurementconsumer.verifyResult -import org.wfanet.measurement.internal.reporting.CreateReportRequest as InternalCreateReportRequest -import org.wfanet.measurement.internal.reporting.CreateReportRequest.MeasurementKey as InternalMeasurementKey -import org.wfanet.measurement.internal.reporting.CreateReportRequestKt as InternalCreateReportRequestKt -import org.wfanet.measurement.internal.reporting.Measurement as InternalMeasurement -import org.wfanet.measurement.internal.reporting.Measurement.Result as InternalMeasurementResult -import org.wfanet.measurement.internal.reporting.MeasurementKt as InternalMeasurementKt -import org.wfanet.measurement.internal.reporting.MeasurementsGrpcKt.MeasurementsCoroutineStub as InternalMeasurementsCoroutineStub -import org.wfanet.measurement.internal.reporting.Metric as InternalMetric -import org.wfanet.measurement.internal.reporting.Metric.Details as InternalMetricDetails -import org.wfanet.measurement.internal.reporting.Metric.Details.MetricTypeCase as InternalMetricTypeCase -import org.wfanet.measurement.internal.reporting.Metric.FrequencyHistogramParams as InternalFrequencyHistogramParams -import org.wfanet.measurement.internal.reporting.Metric.ImpressionCountParams as InternalImpressionCountParams -import org.wfanet.measurement.internal.reporting.Metric.MeasurementCalculation -import org.wfanet.measurement.internal.reporting.Metric.NamedSetOperation as InternalNamedSetOperation -import org.wfanet.measurement.internal.reporting.Metric.SetOperation as InternalSetOperation -import org.wfanet.measurement.internal.reporting.Metric.SetOperation.Operand as InternalOperand -import org.wfanet.measurement.internal.reporting.Metric.WatchDurationParams as InternalWatchDurationParams -import org.wfanet.measurement.internal.reporting.MetricKt as InternalMetricKt -import org.wfanet.measurement.internal.reporting.MetricKt.MeasurementCalculationKt -import org.wfanet.measurement.internal.reporting.MetricKt.SetOperationKt as InternalSetOperationKt -import org.wfanet.measurement.internal.reporting.PeriodicTimeInterval as InternalPeriodicTimeInterval -import org.wfanet.measurement.internal.reporting.Report as InternalReport -import org.wfanet.measurement.internal.reporting.ReportKt as InternalReportKt -import org.wfanet.measurement.internal.reporting.ReportingSet as InternalReportingSet -import org.wfanet.measurement.internal.reporting.ReportingSetsGrpcKt -import org.wfanet.measurement.internal.reporting.ReportingSetsGrpcKt.ReportingSetsCoroutineStub as InternalReportingSetsCoroutineStub -import org.wfanet.measurement.internal.reporting.ReportsGrpcKt.ReportsCoroutineStub as InternalReportsCoroutineStub -import org.wfanet.measurement.internal.reporting.SetMeasurementResultRequest as SetInternalMeasurementResultRequest -import org.wfanet.measurement.internal.reporting.StreamReportsRequest as StreamInternalReportsRequest -import org.wfanet.measurement.internal.reporting.StreamReportsRequestKt.filter -import org.wfanet.measurement.internal.reporting.TimeInterval as InternalTimeInterval -import org.wfanet.measurement.internal.reporting.TimeIntervals as InternalTimeIntervals -import org.wfanet.measurement.internal.reporting.batchCreateMeasurementsRequest -import org.wfanet.measurement.internal.reporting.batchGetReportingSetRequest -import org.wfanet.measurement.internal.reporting.createReportRequest as internalCreateReportRequest -import org.wfanet.measurement.internal.reporting.getReportByIdempotencyKeyRequest -import org.wfanet.measurement.internal.reporting.getReportRequest as getInternalReportRequest -import org.wfanet.measurement.internal.reporting.measurement as internalMeasurement -import org.wfanet.measurement.internal.reporting.metric as internalMetric -import org.wfanet.measurement.internal.reporting.periodicTimeInterval as internalPeriodicTimeInterval -import org.wfanet.measurement.internal.reporting.report as internalReport -import org.wfanet.measurement.internal.reporting.setMeasurementFailureRequest as setInternalMeasurementFailureRequest -import org.wfanet.measurement.internal.reporting.setMeasurementResultRequest as setInternalMeasurementResultRequest -import org.wfanet.measurement.internal.reporting.streamReportsRequest as streamInternalReportsRequest -import org.wfanet.measurement.internal.reporting.timeInterval as internalTimeInterval -import org.wfanet.measurement.internal.reporting.timeIntervals as internalTimeIntervals -import org.wfanet.measurement.reporting.service.api.EncryptionKeyPairStore -import org.wfanet.measurement.reporting.v1alpha.CreateReportRequest -import org.wfanet.measurement.reporting.v1alpha.GetReportRequest -import org.wfanet.measurement.reporting.v1alpha.ListReportsPageToken -import org.wfanet.measurement.reporting.v1alpha.ListReportsPageTokenKt.previousPageEnd -import org.wfanet.measurement.reporting.v1alpha.ListReportsRequest -import org.wfanet.measurement.reporting.v1alpha.ListReportsResponse -import org.wfanet.measurement.reporting.v1alpha.Metric -import org.wfanet.measurement.reporting.v1alpha.Metric.FrequencyHistogramParams -import org.wfanet.measurement.reporting.v1alpha.Metric.ImpressionCountParams -import org.wfanet.measurement.reporting.v1alpha.Metric.MetricTypeCase -import org.wfanet.measurement.reporting.v1alpha.Metric.NamedSetOperation -import org.wfanet.measurement.reporting.v1alpha.Metric.SetOperation -import org.wfanet.measurement.reporting.v1alpha.Metric.SetOperation.Operand -import org.wfanet.measurement.reporting.v1alpha.Metric.WatchDurationParams -import org.wfanet.measurement.reporting.v1alpha.MetricKt.SetOperationKt.operand -import org.wfanet.measurement.reporting.v1alpha.MetricKt.frequencyHistogramParams -import org.wfanet.measurement.reporting.v1alpha.MetricKt.impressionCountParams -import org.wfanet.measurement.reporting.v1alpha.MetricKt.namedSetOperation -import org.wfanet.measurement.reporting.v1alpha.MetricKt.reachParams -import org.wfanet.measurement.reporting.v1alpha.MetricKt.setOperation -import org.wfanet.measurement.reporting.v1alpha.MetricKt.watchDurationParams -import org.wfanet.measurement.reporting.v1alpha.PeriodicTimeInterval -import org.wfanet.measurement.reporting.v1alpha.Report -import org.wfanet.measurement.reporting.v1alpha.Report.Result -import org.wfanet.measurement.reporting.v1alpha.ReportKt.EventGroupUniverseKt -import org.wfanet.measurement.reporting.v1alpha.ReportKt.ResultKt.HistogramTableKt.row -import org.wfanet.measurement.reporting.v1alpha.ReportKt.ResultKt.column -import org.wfanet.measurement.reporting.v1alpha.ReportKt.ResultKt.histogramTable -import org.wfanet.measurement.reporting.v1alpha.ReportKt.ResultKt.scalarTable -import org.wfanet.measurement.reporting.v1alpha.ReportKt.eventGroupUniverse -import org.wfanet.measurement.reporting.v1alpha.ReportKt.result -import org.wfanet.measurement.reporting.v1alpha.ReportsGrpcKt.ReportsCoroutineImplBase -import org.wfanet.measurement.reporting.v1alpha.TimeInterval -import org.wfanet.measurement.reporting.v1alpha.TimeIntervals -import org.wfanet.measurement.reporting.v1alpha.copy -import org.wfanet.measurement.reporting.v1alpha.listReportsPageToken -import org.wfanet.measurement.reporting.v1alpha.listReportsResponse -import org.wfanet.measurement.reporting.v1alpha.metric -import org.wfanet.measurement.reporting.v1alpha.periodicTimeInterval -import org.wfanet.measurement.reporting.v1alpha.report -import org.wfanet.measurement.reporting.v1alpha.timeInterval -import org.wfanet.measurement.reporting.v1alpha.timeIntervals - -private const val MIN_PAGE_SIZE = 1 -private const val DEFAULT_PAGE_SIZE = 50 -private const val MAX_PAGE_SIZE = 1000 - -private val timeIntervalComparator: (TimeInterval, TimeInterval) -> Int = { a, b -> - val start = Timestamps.compare(a.startTime, b.startTime) - if (start != 0) { - start - } else { - Timestamps.compare(a.endTime, b.endTime) - } -} - -class ReportsService( - private val internalReportsStub: InternalReportsCoroutineStub, - private val internalReportingSetsStub: InternalReportingSetsCoroutineStub, - private val internalMeasurementsStub: InternalMeasurementsCoroutineStub, - private val dataProvidersStub: DataProvidersCoroutineStub, - private val measurementConsumersStub: MeasurementConsumersCoroutineStub, - private val measurementsStub: MeasurementsCoroutineStub, - private val certificateStub: CertificatesCoroutineStub, - private val encryptionKeyPairStore: EncryptionKeyPairStore, - private val secureRandom: Random, - private val signingPrivateKeyDir: File, - private val trustedCertificates: Map, - measurementSpecConfig: MeasurementSpecConfig, -) : ReportsCoroutineImplBase() { - private val setOperationCompiler = SetOperationCompiler() - private val measurementSpecComponentFactory = - MeasurementSpecComponentFactory(measurementSpecConfig, secureRandom) - - private data class ReportInfo( - val measurementConsumerReferenceId: String, - val reportIdempotencyKey: String, - val eventGroupFilters: Map, - ) - - private data class SigningConfig( - val signingCertificateName: String, - val signingCertificateDer: ByteString, - val signingPrivateKey: PrivateKey, - ) - - private data class WeightedMeasurementInfo( - val reportingMeasurementId: String, - val weightedMeasurement: WeightedMeasurement, - val timeInterval: TimeInterval, - val reportTimeInterval: TimeInterval, - var kingdomMeasurementId: String? = null, - ) - - private data class SetOperationResult( - val weightedMeasurementInfoList: List, - val internalMetricDetails: InternalMetricDetails, - ) - - private data class DataProviderInfo( - val dataProviderName: String, - val publicKey: SignedMessage, - val certificateName: String, - ) - - override suspend fun createReport(request: CreateReportRequest): Report { - grpcRequireNotNull(MeasurementConsumerKey.fromName(request.parent)) { - "Parent is either unspecified or invalid." - } - val principal: ReportingPrincipal = principalFromCurrentContext - - when (principal) { - is MeasurementConsumerPrincipal -> { - if (request.parent != principal.resourceKey.toName()) { - failGrpc(Status.PERMISSION_DENIED) { - "Cannot create a Report for another MeasurementConsumer." - } - } - } - } - - val resourceKey = principal.resourceKey - val apiAuthenticationKey: String = principal.config.apiKey - - grpcRequire(request.hasReport()) { "Report is not specified." } - - // TODO(@riemanli) Put the check here as the reportIdempotencyKey will be moved to the request - // level in the future. - grpcRequire(request.report.reportIdempotencyKey.isNotBlank()) { - "ReportIdempotencyKey is not specified." - } - grpcRequire(request.report.measurementConsumer == request.parent) { - "Cannot create a Report for another MeasurementConsumer." - } - - grpcRequire(request.report.metricsList.isNotEmpty()) { "Metrics in Report cannot be empty." } - request.report.metricsList.forEach { - grpcRequire(it.setOperationsList.isNotEmpty()) { "Metric setOperationsList cannot be empty." } - it.setOperationsList.forEach { namedSetOperation -> - grpcRequire(namedSetOperation.uniqueName.isNotBlank()) { - "NamedSetOperation uniqueName is unspecified." - } - grpcRequire( - !namedSetOperation.setOperation.lhs.operandCase.equals( - Operand.OperandCase.OPERAND_NOT_SET - ) - ) { - "NamedSetOperation SetOperation Operand is unspecified." - } - grpcRequire( - !namedSetOperation.setOperation.type.equals(SetOperation.Type.TYPE_UNSPECIFIED) - ) { - "NamedSetOperation SetOperation Type is unspecified." - } - } - } - checkSetOperationNamesUniqueness(request.report.metricsList) - - val existingInternalReport: InternalReport? = - getInternalReport(resourceKey.measurementConsumerId, request.report.reportIdempotencyKey) - - if (existingInternalReport != null) return existingInternalReport.toReport() - - val reportInfo: ReportInfo = buildReportInfo(request, resourceKey.measurementConsumerId) - - val metrics = - Metrics(reportInfo, internalReportingSetsStub, setOperationCompiler, request.report) - metrics.process() - val namedSetOperationResults: Map = - metrics.getNamedSetOperationResults() - val internalReportingSetMap: Map = - metrics.getInternalReportingSetsMap() - - val measurementConsumer = - try { - measurementConsumersStub - .withAuthenticationKey(apiAuthenticationKey) - .getMeasurementConsumer( - getMeasurementConsumerRequest { - name = MeasurementConsumerKey(resourceKey.measurementConsumerId).toName() - } - ) - } catch (e: StatusException) { - throw Exception( - "Unable to retrieve the measurement consumer " + - "[${MeasurementConsumerKey(resourceKey.measurementConsumerId).toName()}].", - e, - ) - } - - val dataProviderNames = mutableSetOf() - for (internalReportingSet in internalReportingSetMap.values) { - for (eventGroupKey in internalReportingSet.eventGroupKeysList) { - dataProviderNames.add(DataProviderKey(eventGroupKey.dataProviderReferenceId).toName()) - } - } - val dataProviderInfoMap: Map = - buildDataProviderInfoMap(apiAuthenticationKey, dataProviderNames) - - // TODO: Factor this out to a separate class similar to EncryptionKeyPairStore. - val signingPrivateKeyDer: ByteString = - signingPrivateKeyDir.resolve(principal.config.signingPrivateKeyPath).readByteString() - - val signingCertificateDer: ByteString = - getSigningCertificateDer(apiAuthenticationKey, principal.config.signingCertificateName) - - val signingConfig = - SigningConfig( - principal.config.signingCertificateName, - signingCertificateDer, - readPrivateKey( - signingPrivateKeyDer, - readCertificate(signingCertificateDer).publicKey.algorithm, - ), - ) - - createMeasurements( - request, - namedSetOperationResults, - reportInfo, - measurementConsumer, - apiAuthenticationKey, - signingConfig, - internalReportingSetMap, - dataProviderInfoMap, - ) - - val internalCreateReportRequest: InternalCreateReportRequest = - buildInternalCreateReportRequest(request, reportInfo, namedSetOperationResults) - try { - return internalReportsStub.createReport(internalCreateReportRequest).toReport() - } catch (e: StatusException) { - throw Exception("Unable to create a report in the reporting database.", e) - } - } - - /** Gets a signing certificate x509Der in ByteString. */ - private suspend fun getSigningCertificateDer( - apiAuthenticationKey: String, - signingCertificateName: String, - ): ByteString { - // TODO: Replace this with caching certificates or having them stored alongside the private key. - return try { - certificateStub - .withAuthenticationKey(apiAuthenticationKey) - .getCertificate(getCertificateRequest { name = signingCertificateName }) - .x509Der - } catch (e: StatusException) { - throw Exception( - "Unable to retrieve the signing certificate for the measurement consumer " + - "[$signingCertificateName].", - e, - ) - } - } - - /** Builds a [ReportInfo] from a [CreateReportRequest]. */ - private fun buildReportInfo( - request: CreateReportRequest, - measurementConsumerReferenceId: String, - ): ReportInfo { - grpcRequire(request.report.hasEventGroupUniverse()) { "EventGroupUniverse is not specified." } - grpcRequire(request.report.eventGroupUniverse.eventGroupEntriesList.isNotEmpty()) { - "EventGroupUniverse's eventGroupEntriesList cannot be empty." - } - - val eventGroupFilters = - request.report.eventGroupUniverse.eventGroupEntriesList.associate { - grpcRequireNotNull(EventGroupKey.fromName(it.key)) { - "EventGroupEntry key is not specified or invalid." - } - it.key to it.value - } - - return ReportInfo( - measurementConsumerReferenceId, - request.report.reportIdempotencyKey, - eventGroupFilters, - ) - } - - /** Creates CMM public [Measurement]s and [InternalMeasurement]s from [SetOperationResult]s. */ - private suspend fun createMeasurements( - request: CreateReportRequest, - namedSetOperationResults: Map, - reportInfo: ReportInfo, - measurementConsumer: MeasurementConsumer, - apiAuthenticationKey: String, - signingConfig: SigningConfig, - internalReportingSetMap: Map, - dataProviderInfoMap: Map, - ) = coroutineScope { - val deferredMeasurements = mutableListOf>() - for (metric in request.report.metricsList) { - val internalMetricDetails = buildInternalMetricDetails(metric) - - for (namedSetOperation in metric.setOperationsList) { - val setOperationId = - buildSetOperationId( - reportInfo.reportIdempotencyKey, - internalMetricDetails, - namedSetOperation.uniqueName, - ) - - val setOperationResult: SetOperationResult = - namedSetOperationResults[setOperationId] ?: continue - - for (weightedMeasurementInfo in setOperationResult.weightedMeasurementInfoList) { - deferredMeasurements.add( - async { - createMeasurement( - weightedMeasurementInfo, - reportInfo, - setOperationResult.internalMetricDetails, - measurementConsumer, - apiAuthenticationKey, - signingConfig, - internalReportingSetMap, - dataProviderInfoMap, - ) - .also { - weightedMeasurementInfo.kingdomMeasurementId = - checkNotNull(MeasurementKey.fromName(it.name)).measurementId - } - } - ) - } - } - } - - val internalMeasurements = mutableListOf() - for (measurement in deferredMeasurements.awaitAll()) { - val measurementKey = checkNotNull(MeasurementKey.fromName(measurement.name)) - internalMeasurements.add( - internalMeasurement { - this.measurementConsumerReferenceId = measurementKey.measurementConsumerId - this.measurementReferenceId = measurementKey.measurementId - state = InternalMeasurement.State.PENDING - } - ) - } - - try { - internalMeasurementsStub.batchCreateMeasurements( - batchCreateMeasurementsRequest { measurements += internalMeasurements } - ) - } catch (e: StatusException) { - throw Status.UNKNOWN.withDescription( - "Unable to create measurement in the reporting database." - ) - .withCause(e) - .asRuntimeException() - } - } - - /** Creates a kingdom measurement for a [WeightedMeasurement]. */ - private suspend fun createMeasurement( - weightedMeasurementInfo: WeightedMeasurementInfo, - reportInfo: ReportInfo, - internalMetricDetails: InternalMetricDetails, - measurementConsumer: MeasurementConsumer, - apiAuthenticationKey: String, - signingConfig: SigningConfig, - internalReportingSetMap: Map, - dataProviderInfoMap: Map, - ): Measurement { - val eventGroupEntriesByDataProvider = - groupEventGroupEntriesByDataProvider( - weightedMeasurementInfo.weightedMeasurement.reportingSets, - weightedMeasurementInfo.timeInterval.toMeasurementTimeInterval(), - reportInfo.eventGroupFilters, - internalReportingSetMap, - ) - - val createMeasurementRequest: CreateMeasurementRequest = - buildCreateMeasurementRequest( - measurementConsumer, - eventGroupEntriesByDataProvider, - internalMetricDetails, - weightedMeasurementInfo.reportingMeasurementId, - signingConfig, - dataProviderInfoMap, - ) - - try { - return measurementsStub - .withAuthenticationKey(apiAuthenticationKey) - .createMeasurement(createMeasurementRequest) - } catch (e: StatusException) { - throw Exception( - "Unable to create Measurement with request ID ${createMeasurementRequest.requestId}", - e, - ) - } - } - - /** Gets an [InternalReport]. */ - private suspend fun getInternalReport( - measurementConsumerReferenceId: String, - reportIdempotencyKey: String, - ): InternalReport? { - return try { - internalReportsStub.getReportByIdempotencyKey( - getReportByIdempotencyKeyRequest { - this.measurementConsumerReferenceId = measurementConsumerReferenceId - this.reportIdempotencyKey = reportIdempotencyKey - } - ) - } catch (e: StatusException) { - if (e.status.code != Status.Code.NOT_FOUND) { - throw Exception( - "Unable to retrieve a report from the reporting database using the provided " + - "reportIdempotencyKey [$reportIdempotencyKey].", - e, - ) - } - null - } - } - - override suspend fun listReports(request: ListReportsRequest): ListReportsResponse { - val listReportsPageToken = request.toListReportsPageToken() - - // Based on AIP-132#Errors - val principal: ReportingPrincipal = principalFromCurrentContext - when (principal) { - is MeasurementConsumerPrincipal -> { - if (request.parent != principal.resourceKey.toName()) { - failGrpc(Status.PERMISSION_DENIED) { - "Cannot list Reports belonging to other MeasurementConsumers." - } - } - } - } - val principalName = principal.resourceKey.toName() - - val apiAuthenticationKey: String = principal.config.apiKey - - val streamInternalReportsRequest: StreamInternalReportsRequest = - listReportsPageToken.toStreamReportsRequest() - val results: List = - try { - internalReportsStub.streamReports(streamInternalReportsRequest).toList() - } catch (e: StatusException) { - throw Exception("Unable to list reports from the reporting database.", e) - } - - if (results.isEmpty()) { - return ListReportsResponse.getDefaultInstance() - } - - val nextPageToken: ListReportsPageToken? = - if (results.size > listReportsPageToken.pageSize) { - listReportsPageToken.copy { - lastReport = previousPageEnd { - measurementConsumerReferenceId = - results[results.lastIndex - 1].measurementConsumerReferenceId - externalReportId = results[results.lastIndex - 1].externalReportId - } - } - } else null - - return listReportsResponse { - reports += - results - .subList(0, min(results.size, listReportsPageToken.pageSize)) - .map { syncReport(it, apiAuthenticationKey, principalName) } - .map(InternalReport::toReport) - - if (nextPageToken != null) { - this.nextPageToken = nextPageToken.toByteString().base64UrlEncode() - } - } - } - - override suspend fun getReport(request: GetReportRequest): Report { - val reportKey = - grpcRequireNotNull(ReportKey.fromName(request.name)) { - "Report name is either unspecified or invalid" - } - - val principal: ReportingPrincipal = principalFromCurrentContext - when (principal) { - is MeasurementConsumerPrincipal -> { - if (reportKey.measurementConsumerId != principal.resourceKey.measurementConsumerId) { - failGrpc(Status.PERMISSION_DENIED) { - "Cannot get Report belonging to other MeasurementConsumers." - } - } - } - } - val principalName = principal.resourceKey.toName() - - val apiAuthenticationKey: String = principal.config.apiKey - - val internalReport = - try { - internalReportsStub.getReport( - getInternalReportRequest { - measurementConsumerReferenceId = reportKey.measurementConsumerId - externalReportId = apiIdToExternalId(reportKey.reportId) - } - ) - } catch (e: StatusException) { - throw Exception("Unable to get the report from the reporting database.", e) - } - - val syncedInternalReport = syncReport(internalReport, apiAuthenticationKey, principalName) - - return syncedInternalReport.toReport() - } - - /** Syncs the [InternalReport] and all [InternalMeasurement]s used by it. */ - private suspend fun syncReport( - internalReport: InternalReport, - apiAuthenticationKey: String, - principalName: String, - ): InternalReport { - // Report with SUCCEEDED or FAILED state is already synced. - if ( - internalReport.state == InternalReport.State.SUCCEEDED || - internalReport.state == InternalReport.State.FAILED - ) { - return internalReport - } else if ( - internalReport.state == InternalReport.State.STATE_UNSPECIFIED || - internalReport.state == InternalReport.State.UNRECOGNIZED - ) { - error( - "The measurements cannot be synced because the report state was not set correctly as it " + - "should've been." - ) - } - - // Syncs measurements - syncMeasurements( - internalReport.measurementsMap, - internalReport.measurementConsumerReferenceId, - apiAuthenticationKey, - principalName, - ) - - return try { - internalReportsStub.getReport( - getInternalReportRequest { - measurementConsumerReferenceId = internalReport.measurementConsumerReferenceId - externalReportId = internalReport.externalReportId - } - ) - } catch (e: StatusException) { - val reportName = - ReportKey( - internalReport.measurementConsumerReferenceId, - externalIdToApiId(internalReport.externalReportId), - ) - .toName() - throw Exception("Unable to get the report [$reportName] from the reporting database.", e) - } - } - - /** Syncs [InternalMeasurement]s. */ - private suspend fun syncMeasurements( - measurementsMap: Map, - measurementConsumerReferenceId: String, - apiAuthenticationKey: String, - principalName: String, - ) { - for ((measurementReferenceId, internalMeasurement) in measurementsMap) { - // Measurement with SUCCEEDED state is already synced - if (internalMeasurement.state == InternalMeasurement.State.SUCCEEDED) continue - - syncMeasurement( - measurementReferenceId, - measurementConsumerReferenceId, - apiAuthenticationKey, - principalName, - ) - } - } - - /** Syncs [InternalMeasurement] with the CMM [Measurement] given the measurement reference ID. */ - private suspend fun syncMeasurement( - measurementReferenceId: String, - measurementConsumerReferenceId: String, - apiAuthenticationKey: String, - principalName: String, - ) { - val measurementResourceName = - MeasurementKey(measurementConsumerReferenceId, measurementReferenceId).toName() - val measurement = - try { - measurementsStub - .withAuthenticationKey(apiAuthenticationKey) - .getMeasurement(getMeasurementRequest { name = measurementResourceName }) - } catch (e: StatusException) { - throw Exception("Unable to retrieve the measurement [$measurementResourceName].", e) - } - - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - when (measurement.state) { - Measurement.State.SUCCEEDED -> { - // Converts a Measurement to an InternalMeasurement and store it into the database with - // SUCCEEDED state - val measurementSpec: MeasurementSpec = measurement.measurementSpec.unpack() - val encryptionPrivateKeyHandle = - encryptionKeyPairStore.getPrivateKeyHandle( - principalName, - measurementSpec.measurementPublicKey.unpack().data, - ) ?: failGrpc(Status.PERMISSION_DENIED) { "Encryption private key not found" } - - val setInternalMeasurementResultRequest = - buildSetInternalMeasurementResultRequest( - measurementConsumerReferenceId, - measurementReferenceId, - measurement.resultsList, - encryptionPrivateKeyHandle, - apiAuthenticationKey, - ) - - try { - internalMeasurementsStub.setMeasurementResult(setInternalMeasurementResultRequest) - } catch (e: StatusException) { - throw Exception( - "Unable to update the measurement [$measurementResourceName] in the reporting " + - "database.", - e, - ) - } - } - Measurement.State.AWAITING_REQUISITION_FULFILLMENT, - Measurement.State.COMPUTING -> {} // No action needed - Measurement.State.FAILED, - Measurement.State.CANCELLED -> { - val setInternalMeasurementFailureRequest = setInternalMeasurementFailureRequest { - this.measurementConsumerReferenceId = measurementConsumerReferenceId - this.measurementReferenceId = measurementReferenceId - failure = measurement.failure.toInternal() - } - - try { - internalMeasurementsStub.setMeasurementFailure(setInternalMeasurementFailureRequest) - } catch (e: StatusException) { - throw Exception( - "Unable to update the measurement [$measurementResourceName] in the reporting " + - "database.", - e, - ) - } - } - Measurement.State.STATE_UNSPECIFIED -> error("The measurement state should've been set.") - Measurement.State.UNRECOGNIZED -> error("Unrecognized measurement state.") - } - } - - /** Builds a [SetInternalMeasurementResultRequest]. */ - private suspend fun buildSetInternalMeasurementResultRequest( - measurementConsumerReferenceId: String, - measurementReferenceId: String, - resultsList: List, - privateKeyHandle: PrivateKeyHandle, - apiAuthenticationKey: String, - ): SetInternalMeasurementResultRequest { - - return setInternalMeasurementResultRequest { - this.measurementConsumerReferenceId = measurementConsumerReferenceId - this.measurementReferenceId = measurementReferenceId - result = - aggregateResults( - resultsList - .map { decryptMeasurementResultOutput(it, privateKeyHandle, apiAuthenticationKey) } - .map(Measurement.Result::toInternal) - ) - } - } - - /** Decrypts a [Measurement.ResultOutput] to [Measurement.Result] */ - private suspend fun decryptMeasurementResultOutput( - measurementResultOutput: Measurement.ResultOutput, - encryptionPrivateKeyHandle: PrivateKeyHandle, - apiAuthenticationKey: String, - ): Measurement.Result { - // TODO: Cache the certificate - val certificate = - try { - certificateStub - .withAuthenticationKey(apiAuthenticationKey) - .getCertificate(getCertificateRequest { name = measurementResultOutput.certificate }) - } catch (e: StatusException) { - throw Exception( - "Unable to retrieve the certificate [${measurementResultOutput.certificate}].", - e, - ) - } - - val signedResult: SignedMessage = - decryptResult(measurementResultOutput.encryptedResult, encryptionPrivateKeyHandle) - - val x509Certificate: X509Certificate = readCertificate(certificate.x509Der) - val trustedIssuer: X509Certificate = - checkNotNull(trustedCertificates[checkNotNull(x509Certificate.authorityKeyIdentifier)]) { - "${certificate.name} not issued by trusted CA" - } - // TODO: Record verification failure in internal Measurement rather than having the RPC fail. - try { - verifyResult(signedResult, x509Certificate, trustedIssuer) - } catch (e: CertPathValidatorException) { - throw Exception("Certificate path for ${certificate.name} is invalid", e) - } catch (e: SignatureException) { - throw Exception("Measurement result signature is invalid", e) - } - return signedResult.unpack() - } - - /** Builds an [InternalCreateReportRequest] from a public [CreateReportRequest]. */ - private suspend fun buildInternalCreateReportRequest( - request: CreateReportRequest, - reportInfo: ReportInfo, - namedSetOperationResults: Map, - ): InternalCreateReportRequest { - val internalReport: InternalReport = internalReport { - this.measurementConsumerReferenceId = reportInfo.measurementConsumerReferenceId - - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - when (request.report.timeCase) { - Report.TimeCase.TIME_INTERVALS -> { - this.timeIntervals = request.report.timeIntervals.toInternal() - } - Report.TimeCase.PERIODIC_TIME_INTERVAL -> { - this.periodicTimeInterval = request.report.periodicTimeInterval.toInternal() - } - Report.TimeCase.TIME_NOT_SET -> - failGrpc(Status.INVALID_ARGUMENT) { "The time in Report is not specified." } - } - - for (metric in request.report.metricsList) { - this@internalReport.metrics += - buildInternalMetric(metric, reportInfo, namedSetOperationResults) - } - - details = - InternalReportKt.details { this.eventGroupFilters.putAll(reportInfo.eventGroupFilters) } - - this.reportIdempotencyKey = reportInfo.reportIdempotencyKey - } - - return internalCreateReportRequest { - report = internalReport - measurements += - internalReport.metricsList.flatMap { internalMetric -> - buildInternalMeasurementKeys(internalMetric, reportInfo.measurementConsumerReferenceId) - } - } - } - - /** Builds an [InternalMetric] from a public [Metric]. */ - private suspend fun buildInternalMetric( - metric: Metric, - reportInfo: ReportInfo, - namedSetOperationResults: Map, - ): InternalMetric { - return internalMetric { - details = buildInternalMetricDetails(metric) - - metric.setOperationsList.map { setOperation -> - val setOperationId = - buildSetOperationId(reportInfo.reportIdempotencyKey, details, setOperation.uniqueName) - - namedSetOperationResults[setOperationId]?.let { setOperationResult -> - val internalNamedSetOperation = - buildInternalNamedSetOperation(setOperation, reportInfo, setOperationResult) - namedSetOperations += internalNamedSetOperation - } - } - } - } - - /** Builds an [InternalNamedSetOperation] from a public [NamedSetOperation]. */ - private suspend fun buildInternalNamedSetOperation( - namedSetOperation: NamedSetOperation, - reportInfo: ReportInfo, - setOperationResult: SetOperationResult, - ): InternalNamedSetOperation { - return InternalMetricKt.namedSetOperation { - displayName = namedSetOperation.uniqueName - setOperation = - buildInternalSetOperation( - namedSetOperation.setOperation, - reportInfo.measurementConsumerReferenceId, - ) - - this.measurementCalculations += buildMeasurementCalculationList(setOperationResult) - } - } - - /** Builds an [InternalSetOperation] from a public [SetOperation]. */ - private suspend fun buildInternalSetOperation( - setOperation: SetOperation, - measurementConsumerReferenceId: String, - ): InternalSetOperation { - return InternalMetricKt.setOperation { - this.type = setOperation.type.toInternal() - this.lhs = buildInternalOperand(setOperation.lhs, measurementConsumerReferenceId) - this.rhs = buildInternalOperand(setOperation.rhs, measurementConsumerReferenceId) - } - } - - /** Builds an [InternalOperand] from an [Operand]. */ - private suspend fun buildInternalOperand( - operand: Operand, - measurementConsumerReferenceId: String, - ): InternalOperand { - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - return when (operand.operandCase) { - Operand.OperandCase.OPERATION -> - InternalSetOperationKt.operand { - operation = buildInternalSetOperation(operand.operation, measurementConsumerReferenceId) - } - Operand.OperandCase.REPORTING_SET -> { - val reportingSetId = - grpcRequireNotNull(ReportingSetKey.fromName(operand.reportingSet)) { - "Invalid reporting set name ${operand.reportingSet}." - } - .reportingSetId - - InternalSetOperationKt.operand { - this.reportingSetId = - InternalSetOperationKt.reportingSetKey { - this.measurementConsumerReferenceId = measurementConsumerReferenceId - externalReportingSetId = apiIdToExternalId(reportingSetId) - } - } - } - Operand.OperandCase.OPERAND_NOT_SET -> InternalSetOperationKt.operand {} - } - } - - /** - * Builds a list of [MeasurementCalculation]s from a list of [WeightedMeasurement]s and a list of - * [InternalTimeInterval]s. - */ - private fun buildMeasurementCalculationList( - setOperationResult: SetOperationResult - ): List { - val measurementCalculations = mutableListOf() - setOperationResult.weightedMeasurementInfoList - .groupBy { it.reportTimeInterval } - .forEach { (reportTimeInterval, weightedMeasurementInfos) -> - measurementCalculations.add( - InternalMetricKt.measurementCalculation { - timeInterval = reportTimeInterval.toInternal() - - for (weightedMeasurementInfo in weightedMeasurementInfos) { - weightedMeasurements += - MeasurementCalculationKt.weightedMeasurement { - measurementReferenceId = - checkNotNull(weightedMeasurementInfo.kingdomMeasurementId) - coefficient = weightedMeasurementInfo.weightedMeasurement.coefficient - } - } - } - ) - } - return measurementCalculations - } - - /** Builds a [CreateMeasurementRequest]. */ - private fun buildCreateMeasurementRequest( - measurementConsumer: MeasurementConsumer, - eventGroupEntriesByDataProvider: Map>, - internalMetricDetails: InternalMetricDetails, - requestId: String, - signingConfig: SigningConfig, - dataProviderInfoMap: Map, - ): CreateMeasurementRequest { - val measurementConsumerCertificate: X509Certificate = - readCertificate(signingConfig.signingCertificateDer) - val measurementConsumerSigningKey = - SigningKeyHandle(measurementConsumerCertificate, signingConfig.signingPrivateKey) - val measurementEncryptionPublicKey: ProtoAny = measurementConsumer.publicKey.message - - return createMeasurementRequest { - parent = measurementConsumer.name - measurement = measurement { - this.measurementConsumerCertificate = signingConfig.signingCertificateName - - dataProviders += - buildDataProviderEntries( - eventGroupEntriesByDataProvider, - measurementEncryptionPublicKey, - measurementConsumerSigningKey, - dataProviderInfoMap, - ) - - val unsignedMeasurementSpec: MeasurementSpec = - buildUnsignedMeasurementSpec( - measurementEncryptionPublicKey, - dataProviders.map { it.value.nonceHash }, - internalMetricDetails, - ) - - measurementSpec = - signMeasurementSpec(unsignedMeasurementSpec, measurementConsumerSigningKey) - } - this.requestId = requestId - } - } - - /** - * Converts internal event group entries into [EventGroupEntry] messages, grouping them by - * DataProvider. - */ - private fun groupEventGroupEntriesByDataProvider( - reportingSetNames: List, - timeInterval: Interval, - eventGroupFilters: Map, - internalReportingSetMap: Map, - ): Map> { - return reportingSetNames - .flatMap { - val reportingSetKey = - grpcRequireNotNull(ReportingSetKey.fromName(it)) { "Invalid reporting set name $it" } - val internalReportingSet = - internalReportingSetMap.getValue(apiIdToExternalId(reportingSetKey.reportingSetId)) - internalReportingSet.eventGroupKeysList.map { internalEventGroupKey -> - val eventGroupKey = - EventGroupKey( - internalEventGroupKey.measurementConsumerReferenceId, - internalEventGroupKey.dataProviderReferenceId, - internalEventGroupKey.eventGroupReferenceId, - ) - val eventGroupName = eventGroupKey.toName() - val filter = - combineEventGroupFilters(internalReportingSet.filter, eventGroupFilters[eventGroupName]) - - eventGroupKey to - RequisitionSpecKt.eventGroupEntry { - key = - CmmsEventGroupKey( - internalEventGroupKey.dataProviderReferenceId, - internalEventGroupKey.eventGroupReferenceId, - ) - .toName() - value = - RequisitionSpecKt.EventGroupEntryKt.value { - collectionInterval = timeInterval - if (filter != null) { - this.filter = RequisitionSpecKt.eventFilter { expression = filter } - } - } - } - } - } - .groupBy( - { (eventGroupKey, _) -> DataProviderKey(eventGroupKey.dataProviderReferenceId) }, - { (_, eventGroupEntry) -> eventGroupEntry }, - ) - } - - /** Builds a [Map] of [DataProvider] name to [DataProviderInfo]. */ - private suspend fun buildDataProviderInfoMap( - apiAuthenticationKey: String, - dataProviderNames: Collection, - ): Map { - val dataProviderInfoMap = mutableMapOf() - - if (dataProviderNames.isEmpty()) { - return dataProviderInfoMap - } - - val deferredDataProviderInfoList = mutableListOf>() - coroutineScope { - for (dataProviderName in dataProviderNames) { - deferredDataProviderInfoList.add( - async { - val dataProvider: DataProvider = - try { - dataProvidersStub - .withAuthenticationKey(apiAuthenticationKey) - .getDataProvider(getDataProviderRequest { name = dataProviderName }) - } catch (e: StatusException) { - throw when (e.status.code) { - Status.Code.NOT_FOUND -> - Status.FAILED_PRECONDITION.withDescription("$dataProviderName not found") - else -> Status.UNKNOWN.withDescription("Unable to retrieve $dataProviderName") - } - .withCause(e) - .asRuntimeException() - } - - val certificate: Certificate = - try { - certificateStub - .withAuthenticationKey(apiAuthenticationKey) - .getCertificate(getCertificateRequest { name = dataProvider.certificate }) - } catch (e: StatusException) { - throw Exception("Unable to retrieve Certificate ${dataProvider.certificate}", e) - } - if ( - certificate.revocationState != - Certificate.RevocationState.REVOCATION_STATE_UNSPECIFIED - ) { - throw Status.FAILED_PRECONDITION.withDescription( - "${certificate.name} revocation state is ${certificate.revocationState}" - ) - .asRuntimeException() - } - - val x509Certificate: X509Certificate = readCertificate(certificate.x509Der) - val trustedIssuer: X509Certificate = - trustedCertificates[checkNotNull(x509Certificate.authorityKeyIdentifier)] - ?: throw Status.FAILED_PRECONDITION.withDescription( - "${certificate.name} not issued by trusted CA" - ) - .asRuntimeException() - try { - verifyEncryptionPublicKey(dataProvider.publicKey, x509Certificate, trustedIssuer) - } catch (e: CertPathValidatorException) { - throw Status.FAILED_PRECONDITION.withCause(e) - .withDescription("Certificate path for ${certificate.name} is invalid") - .asRuntimeException() - } catch (e: SignatureException) { - throw Status.FAILED_PRECONDITION.withCause(e) - .withDescription("DataProvider public key signature is invalid") - .asRuntimeException() - } - - DataProviderInfo(dataProvider.name, dataProvider.publicKey, certificate.name) - } - ) - } - - for (deferredDataProviderInfo in deferredDataProviderInfoList.awaitAll()) { - dataProviderInfoMap[deferredDataProviderInfo.dataProviderName] = deferredDataProviderInfo - } - } - - return dataProviderInfoMap - } - - /** Builds a [List] of [DataProviderEntry] messages from [eventGroupEntriesByDataProvider]. */ - private fun buildDataProviderEntries( - eventGroupEntriesByDataProvider: Map>, - measurementEncryptionPublicKey: ProtoAny, - measurementConsumerSigningKey: SigningKeyHandle, - dataProviderInfoMap: Map, - ): List { - return eventGroupEntriesByDataProvider.map { (dataProviderKey, eventGroupEntriesList) -> - // TODO(@SanjayVas): Consider caching the public key and certificate. - val dataProviderName: String = dataProviderKey.toName() - val dataProviderInfo = dataProviderInfoMap.getValue(dataProviderName) - - val requisitionSpec = requisitionSpec { - events = RequisitionSpecKt.events { eventGroups += eventGroupEntriesList } - measurementPublicKey = measurementEncryptionPublicKey - // TODO(world-federation-of-advertisers/cross-media-measurement#1301): Stop setting this - // field. - serializedMeasurementPublicKey = measurementEncryptionPublicKey.value - nonce = secureRandom.nextLong() - } - val encryptRequisitionSpec = - encryptRequisitionSpec( - signRequisitionSpec(requisitionSpec, measurementConsumerSigningKey), - dataProviderInfo.publicKey.unpack(), - ) - - dataProviderEntry { - key = dataProviderName - value = - DataProviderEntryKt.value { - dataProviderCertificate = dataProviderInfo.certificateName - dataProviderPublicKey = dataProviderInfo.publicKey.message - this.encryptedRequisitionSpec = encryptRequisitionSpec - nonceHash = Hashing.hashSha256(requisitionSpec.nonce) - } - } - } - } - - /** Builds the unsigned [MeasurementSpec]. */ - private fun buildUnsignedMeasurementSpec( - measurementEncryptionPublicKey: ProtoAny, - nonceHashes: List, - internalMetricDetails: InternalMetricDetails, - ): MeasurementSpec { - val isSingleDataProvider = nonceHashes.size == 1 - return measurementSpec { - measurementPublicKey = measurementEncryptionPublicKey - // TODO(world-federation-of-advertisers/cross-media-measurement#1301): Stop setting this - // field. - serializedMeasurementPublicKey = measurementEncryptionPublicKey.value - this.nonceHashes += nonceHashes - - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - when (internalMetricDetails.metricTypeCase) { - InternalMetricTypeCase.REACH -> { - if (isSingleDataProvider) { - reach = measurementSpecComponentFactory.getReachSingleDataProviderType() - vidSamplingInterval = - measurementSpecComponentFactory.getReachSingleDataProviderVidSamplingInterval() - } else { - reach = measurementSpecComponentFactory.getReachType() - vidSamplingInterval = measurementSpecComponentFactory.getReachVidSamplingInterval() - } - } - InternalMetricTypeCase.FREQUENCY_HISTOGRAM -> { - if (isSingleDataProvider) { - reachAndFrequency = - measurementSpecComponentFactory.getReachAndFrequencySingleDataProviderType( - internalMetricDetails.frequencyHistogram.maximumFrequency - ) - vidSamplingInterval = - measurementSpecComponentFactory - .getReachAndFrequencySingleDataProviderVidSamplingInterval() - } else { - reachAndFrequency = - measurementSpecComponentFactory.getReachAndFrequencyType( - internalMetricDetails.frequencyHistogram.maximumFrequency - ) - vidSamplingInterval = - measurementSpecComponentFactory.getReachAndFrequencyVidSamplingInterval() - } - } - InternalMetricTypeCase.IMPRESSION_COUNT -> { - impression = - measurementSpecComponentFactory.getImpressionType( - internalMetricDetails.impressionCount.maximumFrequencyPerUser - ) - vidSamplingInterval = measurementSpecComponentFactory.getImpressionVidSamplingInterval() - } - InternalMetricTypeCase.WATCH_DURATION -> { - duration = - measurementSpecComponentFactory.getDurationType( - internalMetricDetails.watchDuration.maximumWatchDurationPerUserSeconds - ) - vidSamplingInterval = measurementSpecComponentFactory.getDurationVidSamplingInterval() - } - InternalMetricTypeCase.METRICTYPE_NOT_SET -> - error("Unset metric type should've already raised error.") - } - } - } - - private class Metrics( - private val reportInfo: ReportInfo, - private val internalReportingSetsStub: ReportingSetsGrpcKt.ReportingSetsCoroutineStub, - private val setOperationCompiler: SetOperationCompiler, - private val report: Report, - ) { - - private val reportingSetExternalIds: MutableSet = mutableSetOf() - private lateinit var namedSetOperationResults: Map - private lateinit var internalReportingSetMap: Map - - suspend fun process() { - if (this::namedSetOperationResults.isInitialized) { - return - } - namedSetOperationResults = compileAllSetOperations() - internalReportingSetMap = getReportingSets() - internalReportingSetMap.values.forEach { - it.checkReportingSetEventGroupFilters(reportInfo.eventGroupFilters) - } - } - - suspend fun getNamedSetOperationResults(): Map { - if (!this::namedSetOperationResults.isInitialized) { - process() - } - return namedSetOperationResults - } - - suspend fun getInternalReportingSetsMap(): Map { - if (!this::namedSetOperationResults.isInitialized) { - process() - } - return internalReportingSetMap - } - - /** Compiles all [SetOperation]s and outputs each result with measurement reference ID. */ - private suspend fun compileAllSetOperations(): Map { - val namedSetOperationResults = mutableMapOf() - - var hasCumulativeMetric = false - for (metric in report.metricsList) { - if (metric.cumulative) { - hasCumulativeMetric = true - break - } - } - val timeIntervalsList = report.timeIntervalsList(hasCumulativeMetric) - val sortedTimeIntervalsList = timeIntervalsList.sortedWith(timeIntervalComparator) - val cumulativeTimeIntervalsList = - sortedTimeIntervalsList.map { timeInterval -> - timeInterval.copy { this.startTime = sortedTimeIntervalsList.first().startTime } - } - - for (metric in report.metricsList) { - val metricTimeIntervalsList = - if (metric.cumulative) cumulativeTimeIntervalsList else sortedTimeIntervalsList - val internalMetricDetails: InternalMetricDetails = buildInternalMetricDetails(metric) - - for (namedSetOperation in metric.setOperationsList) { - checkSetOperationReportingSetName(namedSetOperation.setOperation) - - val setOperationId = - buildSetOperationId( - reportInfo.reportIdempotencyKey, - internalMetricDetails, - namedSetOperation.uniqueName, - ) - - val weightedMeasurementInfoList = - compileSetOperation( - namedSetOperation.setOperation, - setOperationId, - metricTimeIntervalsList, - sortedTimeIntervalsList, - ) - namedSetOperationResults[setOperationId] = - SetOperationResult(weightedMeasurementInfoList, internalMetricDetails) - } - } - - return namedSetOperationResults.toMap() - } - - /** Checks if all reporting sets under a [SetOperation] have valid names. */ - private fun checkSetOperationReportingSetName(setOperation: SetOperation) { - checkOperandReportingSetName(setOperation.lhs) - checkOperandReportingSetName(setOperation.rhs) - } - - /** Checks if all reporting sets under a [Operand] have valid names. */ - private fun checkOperandReportingSetName(operand: Operand) { - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - when (operand.operandCase) { - Operand.OperandCase.OPERATION -> checkSetOperationReportingSetName(operand.operation) - Operand.OperandCase.REPORTING_SET -> checkReportingSetName(operand.reportingSet) - Operand.OperandCase.OPERAND_NOT_SET -> {} - } - } - - /** Check if the event groups in the public [ReportingSet] have valid names. */ - private fun checkReportingSetName(reportingSetName: String) { - val reportingSetKey = - grpcRequireNotNull(ReportingSetKey.fromName(reportingSetName)) { - "Invalid reporting set name $reportingSetName." - } - - grpcRequire( - reportingSetKey.measurementConsumerId == reportInfo.measurementConsumerReferenceId - ) { - "No access to the reporting set [$reportingSetName]." - } - - reportingSetExternalIds.add(apiIdToExternalId(reportingSetKey.reportingSetId)) - } - - /** - * Compiles a [SetOperation] and outputs each result with measurement reference ID. - * - * metricTimeIntervalsList and timeIntervalsList are required to be the same size. - */ - private suspend fun compileSetOperation( - setOperation: SetOperation, - setOperationId: String, - metricTimeIntervalsList: List, - reportTimeIntervalsList: List, - ): List { - if (metricTimeIntervalsList.size != reportTimeIntervalsList.size) { - throw IllegalArgumentException() - } - val sortedReportTimeIntervalsList = reportTimeIntervalsList.sortedWith(timeIntervalComparator) - - val weightedMeasurementsList = setOperationCompiler.compileSetOperation(setOperation) - - return metricTimeIntervalsList.sortedWith(timeIntervalComparator).flatMapIndexed { - timeIntervalsIndex, - timeInterval -> - weightedMeasurementsList.mapIndexed { index, weightedMeasurement -> - val measurementReferenceId = - buildMeasurementReferenceId(setOperationId, timeInterval, index) - - WeightedMeasurementInfo( - measurementReferenceId, - weightedMeasurement, - timeInterval = timeInterval, - reportTimeInterval = sortedReportTimeIntervalsList[timeIntervalsIndex], - ) - } - } - } - - private suspend fun getReportingSets(): Map { - val batchGetReportingSetRequest = batchGetReportingSetRequest { - measurementConsumerReferenceId = reportInfo.measurementConsumerReferenceId - reportingSetExternalIds.forEach { externalReportingSetIds += it } - } - - val internalReportingSetsList = - internalReportingSetsStub.batchGetReportingSet(batchGetReportingSetRequest).toList() - - if (internalReportingSetsList.size < reportingSetExternalIds.size) { - val errorMessage = StringBuilder("The following reporting set names were not found:") - internalReportingSetsList.forEach { - reportingSetExternalIds.remove(it.externalReportingSetId) - } - reportingSetExternalIds.forEach { - errorMessage.append( - " ${ReportingSetKey(reportInfo.measurementConsumerReferenceId, externalIdToApiId(it)).toName()}" - ) - } - failGrpc(Status.NOT_FOUND) { errorMessage.toString() } - } - - return internalReportingSetsList.associateBy { it.externalReportingSetId } - } - } - - private class MeasurementSpecComponentFactory( - private val measurementSpecConfig: MeasurementSpecConfig, - private val secureRandom: Random, - ) { - private val DEFAULT_VID_START = 0.0f - private val DEFAULT_VID_WIDTH = 1.0f - - private val reachSingleDataProviderType = - MeasurementSpecKt.reach { - privacyParams = differentialPrivacyParams { - epsilon = measurementSpecConfig.reachSingleDataProvider.privacyParams.epsilon - delta = measurementSpecConfig.reachSingleDataProvider.privacyParams.delta - } - } - - private val reachType = - MeasurementSpecKt.reach { - privacyParams = differentialPrivacyParams { - epsilon = measurementSpecConfig.reach.privacyParams.epsilon - delta = measurementSpecConfig.reach.privacyParams.delta - } - } - - private val reachAndFrequencySingleDataProviderReachPrivacyParams = differentialPrivacyParams { - epsilon = measurementSpecConfig.reachAndFrequencySingleDataProvider.reachPrivacyParams.epsilon - delta = measurementSpecConfig.reachAndFrequencySingleDataProvider.reachPrivacyParams.delta - } - - private val reachAndFrequencySingleDataProviderFrequencyPrivacyParams = - differentialPrivacyParams { - epsilon = - measurementSpecConfig.reachAndFrequencySingleDataProvider.frequencyPrivacyParams.epsilon - delta = - measurementSpecConfig.reachAndFrequencySingleDataProvider.frequencyPrivacyParams.delta - } - - private val reachAndFrequencyReachPrivacyParams = differentialPrivacyParams { - epsilon = measurementSpecConfig.reachAndFrequency.reachPrivacyParams.epsilon - delta = measurementSpecConfig.reachAndFrequency.reachPrivacyParams.delta - } - - private val reachAndFrequencyFrequencyPrivacyParams = differentialPrivacyParams { - epsilon = measurementSpecConfig.reachAndFrequency.frequencyPrivacyParams.epsilon - delta = measurementSpecConfig.reachAndFrequency.frequencyPrivacyParams.delta - } - - private val impressionPrivacyParams = differentialPrivacyParams { - epsilon = measurementSpecConfig.impression.privacyParams.epsilon - delta = measurementSpecConfig.impression.privacyParams.delta - } - - private val durationPrivacyParams = differentialPrivacyParams { - epsilon = measurementSpecConfig.duration.privacyParams.epsilon - delta = measurementSpecConfig.duration.privacyParams.delta - } - - private fun createVidSamplingInterval( - vidSamplingInterval: MeasurementSpecConfig.VidSamplingInterval - ): MeasurementSpec.VidSamplingInterval { - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") - return when (vidSamplingInterval.startCase) { - MeasurementSpecConfig.VidSamplingInterval.StartCase.FIXED_START -> - MeasurementSpecKt.vidSamplingInterval { - start = vidSamplingInterval.fixedStart.start - width = vidSamplingInterval.fixedStart.width - } - MeasurementSpecConfig.VidSamplingInterval.StartCase.RANDOM_START -> - MeasurementSpecKt.vidSamplingInterval { - start = - calculateRandomVidStart( - vidSamplingInterval.randomStart.width, - vidSamplingInterval.randomStart.numVidBuckets, - ) - width = - vidSamplingInterval.randomStart.width.toFloat() / - vidSamplingInterval.randomStart.numVidBuckets - } - MeasurementSpecConfig.VidSamplingInterval.StartCase.START_NOT_SET -> - MeasurementSpecKt.vidSamplingInterval { - start = DEFAULT_VID_START - width = DEFAULT_VID_WIDTH - } - } - } - - private fun calculateRandomVidStart(width: Int, numVidBuckets: Int): Float { - val maxStart = numVidBuckets - width - val start = secureRandom.nextInt(maxStart + 1) - return start.toFloat() / numVidBuckets - } - - fun getReachSingleDataProviderVidSamplingInterval(): MeasurementSpec.VidSamplingInterval { - val vidSamplingInterval = measurementSpecConfig.reachSingleDataProvider.vidSamplingInterval - return createVidSamplingInterval(vidSamplingInterval) - } - - fun getReachSingleDataProviderType(): MeasurementSpec.Reach { - return reachSingleDataProviderType - } - - fun getReachVidSamplingInterval(): MeasurementSpec.VidSamplingInterval { - val vidSamplingInterval = measurementSpecConfig.reach.vidSamplingInterval - return createVidSamplingInterval(vidSamplingInterval) - } - - fun getReachType(): MeasurementSpec.Reach { - return reachType - } - - fun getReachAndFrequencySingleDataProviderVidSamplingInterval(): - MeasurementSpec.VidSamplingInterval { - val vidSamplingInterval = - measurementSpecConfig.reachAndFrequencySingleDataProvider.vidSamplingInterval - return createVidSamplingInterval(vidSamplingInterval) - } - - fun getReachAndFrequencySingleDataProviderType( - maximumFrequency: Int - ): MeasurementSpec.ReachAndFrequency { - return MeasurementSpecKt.reachAndFrequency { - reachPrivacyParams = reachAndFrequencySingleDataProviderReachPrivacyParams - frequencyPrivacyParams = reachAndFrequencySingleDataProviderFrequencyPrivacyParams - this.maximumFrequency = maximumFrequency - } - } - - fun getReachAndFrequencyVidSamplingInterval(): MeasurementSpec.VidSamplingInterval { - val vidSamplingInterval = measurementSpecConfig.reachAndFrequency.vidSamplingInterval - return createVidSamplingInterval(vidSamplingInterval) - } - - fun getReachAndFrequencyType(maximumFrequency: Int): MeasurementSpec.ReachAndFrequency { - return MeasurementSpecKt.reachAndFrequency { - reachPrivacyParams = reachAndFrequencyReachPrivacyParams - frequencyPrivacyParams = reachAndFrequencyFrequencyPrivacyParams - this.maximumFrequency = maximumFrequency - } - } - - fun getImpressionVidSamplingInterval(): MeasurementSpec.VidSamplingInterval { - val vidSamplingInterval = measurementSpecConfig.impression.vidSamplingInterval - return createVidSamplingInterval(vidSamplingInterval) - } - - fun getImpressionType(maximumFrequencyPerUser: Int): MeasurementSpec.Impression { - return MeasurementSpecKt.impression { - privacyParams = impressionPrivacyParams - this.maximumFrequencyPerUser = maximumFrequencyPerUser - } - } - - fun getDurationVidSamplingInterval(): MeasurementSpec.VidSamplingInterval { - val vidSamplingInterval = measurementSpecConfig.duration.vidSamplingInterval - return createVidSamplingInterval(vidSamplingInterval) - } - - fun getDurationType(maximumWatchDurationPerUserSeconds: Int): MeasurementSpec.Duration { - return MeasurementSpecKt.duration { - privacyParams = durationPrivacyParams - maximumWatchDurationPerUser = - Durations.fromSeconds(maximumWatchDurationPerUserSeconds.toLong()) - } - } - } -} - -/** Converts the time in [Report] to a list of [TimeInterval]. */ -private fun Report.timeIntervalsList(hasCumulativeMetric: Boolean): List { - val source = this - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - return when (source.timeCase) { - Report.TimeCase.TIME_INTERVALS -> { - if (hasCumulativeMetric) { - failGrpc(Status.INVALID_ARGUMENT) { "Cannot use TimeIntervals with a cumulative Metric." } - } - grpcRequire(source.timeIntervals.timeIntervalsList.isNotEmpty()) { - "TimeIntervals timeIntervalsList is empty." - } - source.timeIntervals.timeIntervalsList.forEach { - grpcRequire(it.startTime.seconds > 0 || it.startTime.nanos > 0) { - "TimeInterval startTime is unspecified." - } - grpcRequire(it.endTime.seconds > 0 || it.endTime.nanos > 0) { - "TimeInterval endTime is unspecified." - } - grpcRequire( - it.endTime.seconds > it.startTime.seconds || it.endTime.nanos > it.startTime.nanos - ) { - "TimeInterval endTime is not later than startTime." - } - } - source.timeIntervals.timeIntervalsList.map { it } - } - Report.TimeCase.PERIODIC_TIME_INTERVAL -> { - grpcRequire( - source.periodicTimeInterval.startTime.seconds > 0 || - source.periodicTimeInterval.startTime.nanos > 0 - ) { - "PeriodicTimeInterval startTime is unspecified." - } - grpcRequire( - source.periodicTimeInterval.increment.seconds > 0 || - source.periodicTimeInterval.increment.nanos > 0 - ) { - "PeriodicTimeInterval increment is unspecified." - } - grpcRequire(source.periodicTimeInterval.intervalCount > 0) { - "PeriodicTimeInterval intervalCount is unspecified." - } - source.periodicTimeInterval.toTimeIntervalsList() - } - Report.TimeCase.TIME_NOT_SET -> - failGrpc(Status.INVALID_ARGUMENT) { "The time in Report is not specified." } - } -} - -/** - * Check if the event groups in the internal [InternalReportingSet] are covered by the event group - * universe. - */ -private fun InternalReportingSet.checkReportingSetEventGroupFilters( - eventGroupFilters: Map -) { - for (eventGroupKey in this.eventGroupKeysList) { - val eventGroupName = - EventGroupKey( - eventGroupKey.measurementConsumerReferenceId, - eventGroupKey.dataProviderReferenceId, - eventGroupKey.eventGroupReferenceId, - ) - .toName() - val internalReportingSetDisplayName = this.displayName - grpcRequire(eventGroupFilters.containsKey(eventGroupName)) { - "The event group [$eventGroupName] in the reporting set " + - "[$internalReportingSetDisplayName] is not included in the event group universe." - } - } -} - -/** Check if the names of the set operations within the same metric type are unique. */ -private fun checkSetOperationNamesUniqueness(metricsList: List) { - val seenNames = mutableMapOf>().withDefault { mutableSetOf() } - - for (metric in metricsList) { - for (setOperation in metric.setOperationsList) { - grpcRequire(!seenNames.getValue(metric.metricTypeCase).contains(setOperation.uniqueName)) { - "The names of the set operations within the same metric type should be unique." - } - seenNames.getOrPut(metric.metricTypeCase, ::mutableSetOf) += setOperation.uniqueName - } - } -} - -/** Builds a list of [InternalMeasurementKey]s from an [InternalMetric]. */ -private fun buildInternalMeasurementKeys( - internalMetric: InternalMetric, - measurementConsumerReferenceId: String, -): List { - return internalMetric.namedSetOperationsList - .flatMap { namedSetOperation -> - namedSetOperation.measurementCalculationsList.flatMap { measurementCalculation -> - measurementCalculation.weightedMeasurementsList.map { it.measurementReferenceId } - } - } - .map { measurementReferenceId -> - InternalCreateReportRequestKt.measurementKey { - this.measurementConsumerReferenceId = measurementConsumerReferenceId - this.measurementReferenceId = measurementReferenceId - } - } -} - -/** Converts an [TimeInterval] to a [Interval] for measurement request. */ -private fun TimeInterval.toMeasurementTimeInterval(): Interval { - val source = this - return interval { - startTime = source.startTime - endTime = source.endTime - } -} - -/** Builds an [InternalMetricDetails] from a [Metric]. */ -private fun buildInternalMetricDetails(metric: Metric): InternalMetricDetails { - return InternalMetricKt.details { - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - when (metric.metricTypeCase) { - MetricTypeCase.REACH -> reach = InternalMetricKt.reachParams {} - MetricTypeCase.FREQUENCY_HISTOGRAM -> - frequencyHistogram = metric.frequencyHistogram.toInternal() - MetricTypeCase.IMPRESSION_COUNT -> impressionCount = metric.impressionCount.toInternal() - MetricTypeCase.WATCH_DURATION -> watchDuration = metric.watchDuration.toInternal() - MetricTypeCase.METRICTYPE_NOT_SET -> - failGrpc(Status.INVALID_ARGUMENT) { "The metric type in Report is not specified." } - } - - cumulative = metric.cumulative - } -} - -/** Builds a unique ID for a [SetOperation]. */ -private fun buildSetOperationId( - reportIdempotencyKey: String, - internalMetricDetails: InternalMetricDetails, - setOperationUniqueName: String, -): String { - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - val metricType = - when (internalMetricDetails.metricTypeCase) { - InternalMetricTypeCase.REACH -> "Reach" - InternalMetricTypeCase.FREQUENCY_HISTOGRAM -> "FrequencyHistogram" - InternalMetricTypeCase.IMPRESSION_COUNT -> "ImpressionCount" - InternalMetricTypeCase.WATCH_DURATION -> "WatchDuration" - InternalMetricTypeCase.METRICTYPE_NOT_SET -> - error("Unset metric type should've already raised error.") - } - - return "$reportIdempotencyKey-$metricType-$setOperationUniqueName" -} - -/** Builds a unique reference ID for a [Measurement]. */ -private fun buildMeasurementReferenceId( - setOperationId: String, - timeInterval: TimeInterval, - index: Int, -): String { - val rowHeader = buildRowHeader(timeInterval) - return "$setOperationId-$rowHeader-measurement-$index" -} - -/** Combines two event group filters. */ -private fun combineEventGroupFilters(filter1: String?, filter2: String?): String? { - if (filter1.isNullOrBlank()) return filter2 - - return if (filter2.isNullOrBlank()) filter1 - else { - "($filter1) && ($filter2)" - } -} - -/** Converts a public [SetOperation.Type] to an [InternalSetOperation.Type]. */ -private fun SetOperation.Type.toInternal(): InternalSetOperation.Type { - return when (this) { - SetOperation.Type.UNION -> InternalSetOperation.Type.UNION - SetOperation.Type.INTERSECTION -> InternalSetOperation.Type.INTERSECTION - SetOperation.Type.DIFFERENCE -> InternalSetOperation.Type.DIFFERENCE - SetOperation.Type.TYPE_UNSPECIFIED -> error("Set operator type is not specified.") - SetOperation.Type.UNRECOGNIZED -> error("Unrecognized Set operator type.") - } -} - -/** Converts a [WatchDurationParams] to an [InternalWatchDurationParams]. */ -private fun WatchDurationParams.toInternal(): InternalWatchDurationParams { - val source = this - return InternalMetricKt.watchDurationParams { - maximumWatchDurationPerUserSeconds = source.maximumWatchDurationPerUser - } -} - -/** Converts a [ImpressionCountParams] to an [InternalImpressionCountParams]. */ -private fun ImpressionCountParams.toInternal(): InternalImpressionCountParams { - val source = this - return InternalMetricKt.impressionCountParams { - maximumFrequencyPerUser = source.maximumFrequencyPerUser - } -} - -/** Converts a [FrequencyHistogramParams] to an [InternalFrequencyHistogramParams]. */ -private fun FrequencyHistogramParams.toInternal(): InternalFrequencyHistogramParams { - val source = this - return InternalMetricKt.frequencyHistogramParams { - maximumFrequency = source.maximumFrequencyPerUser - } -} - -/** Converts a public [PeriodicTimeInterval] to an [InternalPeriodicTimeInterval]. */ -private fun PeriodicTimeInterval.toInternal(): InternalPeriodicTimeInterval { - val source = this - return internalPeriodicTimeInterval { - startTime = source.startTime - increment = source.increment - intervalCount = source.intervalCount - } -} - -/** Converts a public [TimeInterval] to an [InternalTimeInterval]. */ -private fun TimeInterval.toInternal(): InternalTimeInterval { - val source = this - return internalTimeInterval { - startTime = source.startTime - endTime = source.endTime - } -} - -/** Converts a public [TimeIntervals] to an [InternalTimeIntervals]. */ -private fun TimeIntervals.toInternal(): InternalTimeIntervals { - val source = this - return internalTimeIntervals { - for (timeInternal in source.timeIntervalsList) { - this.timeIntervals += internalTimeInterval { - startTime = timeInternal.startTime - endTime = timeInternal.endTime - } - } - } -} - -/** Convert an [PeriodicTimeInterval] to a list of [TimeInterval]s. */ -private fun PeriodicTimeInterval.toTimeIntervalsList(): List { - val source = this - var startTime = checkNotNull(source.startTime) - return (0 until source.intervalCount).map { - timeInterval { - this.startTime = startTime - this.endTime = Timestamps.add(startTime, source.increment) - startTime = this.endTime - } - } -} - -/** Builds a row header in String from an [TimeInterval]. */ -private fun buildRowHeader(timeInterval: TimeInterval): String { - val startTimeInstant = - Instant.ofEpochSecond(timeInterval.startTime.seconds, timeInterval.startTime.nanos.toLong()) - val endTimeInstant = - Instant.ofEpochSecond(timeInterval.endTime.seconds, timeInterval.endTime.nanos.toLong()) - return "$startTimeInstant-$endTimeInstant" -} - -private operator fun ProtoDuration.plus(other: ProtoDuration): ProtoDuration { - return Durations.add(this, other) -} - -/** Converts a CMM [Measurement.Failure] to an [InternalMeasurement.Failure]. */ -private fun Measurement.Failure.toInternal(): InternalMeasurement.Failure { - val source = this - - return InternalMeasurementKt.failure { - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - reason = - when (source.reason) { - Measurement.Failure.Reason.REASON_UNSPECIFIED -> - InternalMeasurement.Failure.Reason.REASON_UNSPECIFIED - Measurement.Failure.Reason.CERTIFICATE_REVOKED -> - InternalMeasurement.Failure.Reason.CERTIFICATE_REVOKED - Measurement.Failure.Reason.REQUISITION_REFUSED -> - InternalMeasurement.Failure.Reason.REQUISITION_REFUSED - Measurement.Failure.Reason.COMPUTATION_PARTICIPANT_FAILED -> - InternalMeasurement.Failure.Reason.COMPUTATION_PARTICIPANT_FAILED - Measurement.Failure.Reason.UNRECOGNIZED -> InternalMeasurement.Failure.Reason.UNRECOGNIZED - } - message = source.message - } -} - -/** Aggregate a list of [InternalMeasurementResult] to a [InternalMeasurementResult] */ -private fun aggregateResults( - internalResultsList: List -): InternalMeasurementResult { - if (internalResultsList.isEmpty()) { - error("No measurement result.") - } - - var reachValue = 0L - var impressionValue = 0L - val frequencyDistribution = mutableMapOf() - var watchDurationValue = ProtoDuration.getDefaultInstance() - - // Aggregation - for (result in internalResultsList) { - if (result.hasFrequency()) { - if (!result.hasReach()) { - error("Missing reach measurement in the Reach-Frequency measurement.") - } - for ((frequency, percentage) in result.frequency.relativeFrequencyDistributionMap) { - val previousTotalReachCount = - frequencyDistribution.getOrDefault(frequency, 0.0) * reachValue - val currentReachCount = percentage * result.reach.value - frequencyDistribution[frequency] = - (previousTotalReachCount + currentReachCount) / (reachValue + result.reach.value) - } - } - if (result.hasReach()) { - reachValue += result.reach.value - } - if (result.hasImpression()) { - impressionValue += result.impression.value - } - if (result.hasWatchDuration()) { - watchDurationValue += result.watchDuration.value - } - } - - return InternalMeasurementKt.result { - if (internalResultsList.first().hasReach()) { - this.reach = InternalMeasurementKt.ResultKt.reach { value = reachValue } - } - if (internalResultsList.first().hasFrequency()) { - this.frequency = - InternalMeasurementKt.ResultKt.frequency { - relativeFrequencyDistribution.putAll(frequencyDistribution) - } - } - if (internalResultsList.first().hasImpression()) { - this.impression = InternalMeasurementKt.ResultKt.impression { value = impressionValue } - } - if (internalResultsList.first().hasWatchDuration()) { - this.watchDuration = - InternalMeasurementKt.ResultKt.watchDuration { value = watchDurationValue } - } - } -} - -/** Converts a CMM [Measurement.Result] to an [InternalMeasurementResult]. */ -private fun Measurement.Result.toInternal(): InternalMeasurementResult { - val source = this - - return InternalMeasurementKt.result { - if (source.hasReach()) { - this.reach = InternalMeasurementKt.ResultKt.reach { value = source.reach.value } - } - if (source.hasFrequency()) { - this.frequency = - InternalMeasurementKt.ResultKt.frequency { - relativeFrequencyDistribution.putAll(source.frequency.relativeFrequencyDistributionMap) - } - } - if (source.hasImpression()) { - this.impression = - InternalMeasurementKt.ResultKt.impression { value = source.impression.value } - } - if (source.hasWatchDuration()) { - this.watchDuration = - InternalMeasurementKt.ResultKt.watchDuration { value = source.watchDuration.value } - } - } -} - -/** Converts an internal [InternalReport] to a public [Report]. */ -private fun InternalReport.toReport(): Report { - val source = this - val reportResourceName = - ReportKey( - measurementConsumerId = source.measurementConsumerReferenceId, - reportId = externalIdToApiId(source.externalReportId), - ) - .toName() - val measurementConsumerResourceName = - MeasurementConsumerKey(source.measurementConsumerReferenceId).toName() - val eventGroupEntries = - source.details.eventGroupFiltersMap.map { (eventGroupResourceName, filterPredicate) -> - EventGroupUniverseKt.eventGroupEntry { - key = eventGroupResourceName - value = filterPredicate - } - } - - return report { - name = reportResourceName - reportIdempotencyKey = source.reportIdempotencyKey - measurementConsumer = measurementConsumerResourceName - eventGroupUniverse = eventGroupUniverse { this.eventGroupEntries += eventGroupEntries } - - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - when (source.timeCase) { - InternalReport.TimeCase.TIME_INTERVALS -> - this.timeIntervals = source.timeIntervals.toTimeIntervals() - InternalReport.TimeCase.PERIODIC_TIME_INTERVAL -> - this.periodicTimeInterval = source.periodicTimeInterval.toPeriodicTimeInterval() - InternalReport.TimeCase.TIME_NOT_SET -> - error("The time in the internal report should've been set.") - } - - for (metric in source.metricsList) { - this.metrics += metric.toMetric() - } - - this.state = source.state.toState() - if (source.details.hasResult()) { - this.result = source.details.result.toResult() - } - } -} - -/** Converts an [InternalReport.State] to a public [Report.State]. */ -private fun InternalReport.State.toState(): Report.State { - return when (this) { - InternalReport.State.RUNNING -> Report.State.RUNNING - InternalReport.State.SUCCEEDED -> Report.State.SUCCEEDED - InternalReport.State.FAILED -> Report.State.FAILED - InternalReport.State.STATE_UNSPECIFIED -> error("Report state should've been set.") - InternalReport.State.UNRECOGNIZED -> error("Unrecognized report state.") - } -} - -/** Converts an [InternalReport.Details.Result] to a public [Report.Result]. */ -private fun InternalReport.Details.Result.toResult(): Result { - val source = this - return result { - scalarTable = scalarTable { - rowHeaders += source.scalarTable.rowHeadersList - for (sourceColumn in source.scalarTable.columnsList) { - columns += column { - columnHeader = sourceColumn.columnHeader - setOperations += sourceColumn.setOperationsList - } - } - } - for (sourceHistogram in source.histogramTablesList) { - histogramTables += histogramTable { - for (sourceRow in sourceHistogram.rowsList) { - rows += row { - rowHeader = sourceRow.rowHeader - frequency = sourceRow.frequency - } - } - for (sourceColumn in sourceHistogram.columnsList) { - columns += column { - columnHeader = sourceColumn.columnHeader - setOperations += sourceColumn.setOperationsList - } - } - } - } - } -} - -/** Converts an internal [InternalMetric] to a public [Metric]. */ -private fun InternalMetric.toMetric(): Metric { - val source = this - - return metric { - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - when (source.details.metricTypeCase) { - InternalMetricTypeCase.REACH -> reach = reachParams {} - InternalMetricTypeCase.FREQUENCY_HISTOGRAM -> - frequencyHistogram = source.details.frequencyHistogram.toFrequencyHistogram() - InternalMetricTypeCase.IMPRESSION_COUNT -> - impressionCount = source.details.impressionCount.toImpressionCount() - InternalMetricTypeCase.WATCH_DURATION -> - watchDuration = source.details.watchDuration.toWatchDuration() - InternalMetricTypeCase.METRICTYPE_NOT_SET -> - error("The metric type in the internal report should've been set.") - } - - cumulative = source.details.cumulative - - for (internalSetOperation in source.namedSetOperationsList) { - setOperations += internalSetOperation.toNamedSetOperation() - } - } -} - -/** Converts an internal [InternalNamedSetOperation] to a public [NamedSetOperation]. */ -private fun InternalNamedSetOperation.toNamedSetOperation(): NamedSetOperation { - val source = this - - return namedSetOperation { - uniqueName = source.displayName - setOperation = source.setOperation.toSetOperation() - } -} - -/** Converts an internal [InternalSetOperation] to a public [SetOperation]. */ -private fun InternalSetOperation.toSetOperation(): SetOperation { - val source = this - - return setOperation { - this.type = source.type.toType() - this.lhs = source.lhs.toOperand() - this.rhs = source.rhs.toOperand() - } -} - -/** Converts an internal [InternalOperand] to a public [Operand]. */ -private fun InternalOperand.toOperand(): Operand { - val source = this - - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - return when (source.operandCase) { - InternalOperand.OperandCase.OPERATION -> - operand { operation = source.operation.toSetOperation() } - InternalOperand.OperandCase.REPORTINGSETID -> - operand { - reportingSet = - ReportingSetKey( - source.reportingSetId.measurementConsumerReferenceId, - externalIdToApiId(source.reportingSetId.externalReportingSetId), - ) - .toName() - } - InternalOperand.OperandCase.OPERAND_NOT_SET -> operand {} - } -} - -/** Converts an internal [InternalSetOperation.Type] to a public [SetOperation.Type]. */ -private fun InternalSetOperation.Type.toType(): SetOperation.Type { - return when (this) { - InternalSetOperation.Type.UNION -> SetOperation.Type.UNION - InternalSetOperation.Type.INTERSECTION -> SetOperation.Type.INTERSECTION - InternalSetOperation.Type.DIFFERENCE -> SetOperation.Type.DIFFERENCE - InternalSetOperation.Type.TYPE_UNSPECIFIED -> error("Set operator type should've been set.") - InternalSetOperation.Type.UNRECOGNIZED -> error("Unrecognized Set operator type.") - } -} - -/** Converts an internal [InternalWatchDurationParams] to a public [WatchDurationParams]. */ -private fun InternalWatchDurationParams.toWatchDuration(): WatchDurationParams { - val source = this - return watchDurationParams { - maximumWatchDurationPerUser = source.maximumWatchDurationPerUserSeconds - } -} - -/** Converts an internal [InternalImpressionCountParams] to a public [ImpressionCountParams]. */ -private fun InternalImpressionCountParams.toImpressionCount(): ImpressionCountParams { - val source = this - return impressionCountParams { maximumFrequencyPerUser = source.maximumFrequencyPerUser } -} - -/** - * Converts an internal [InternalFrequencyHistogramParams] to a public [FrequencyHistogramParams]. - */ -private fun InternalFrequencyHistogramParams.toFrequencyHistogram(): FrequencyHistogramParams { - val source = this - return frequencyHistogramParams { maximumFrequencyPerUser = source.maximumFrequency } -} - -/** Converts an internal [InternalPeriodicTimeInterval] to a public [PeriodicTimeInterval]. */ -private fun InternalPeriodicTimeInterval.toPeriodicTimeInterval(): PeriodicTimeInterval { - val source = this - return periodicTimeInterval { - startTime = source.startTime - increment = source.increment - intervalCount = source.intervalCount - } -} - -/** Converts an internal [InternalTimeIntervals] to a public [TimeIntervals]. */ -private fun InternalTimeIntervals.toTimeIntervals(): TimeIntervals { - val source = this - return timeIntervals { - for (internalTimeInternal in source.timeIntervalsList) { - this.timeIntervals += timeInterval { - startTime = internalTimeInternal.startTime - endTime = internalTimeInternal.endTime - } - } - } -} - -/** Converts an internal [ListReportsPageToken] to an internal [StreamInternalReportsRequest]. */ -private fun ListReportsPageToken.toStreamReportsRequest(): StreamInternalReportsRequest { - val source = this - return streamInternalReportsRequest { - // get 1 more than the actual page size for deciding whether or not to set page token - limit = pageSize + 1 - filter = filter { - measurementConsumerReferenceId = source.measurementConsumerReferenceId - externalReportIdAfter = source.lastReport.externalReportId - } - } -} - -/** Converts a public [ListReportsRequest] to an internal [ListReportsPageToken]. */ -private fun ListReportsRequest.toListReportsPageToken(): ListReportsPageToken { - grpcRequire(pageSize >= 0) { "Page size cannot be less than 0" } - - val source = this - val parentKey: MeasurementConsumerKey = - grpcRequireNotNull(MeasurementConsumerKey.fromName(parent)) { - "Parent is either unspecified or invalid." - } - val measurementConsumerReferenceId = parentKey.measurementConsumerId - - val isValidPageSize = - source.pageSize != 0 && source.pageSize >= MIN_PAGE_SIZE && source.pageSize <= MAX_PAGE_SIZE - - return if (pageToken.isNotBlank()) { - ListReportsPageToken.parseFrom(pageToken.base64UrlDecode()).copy { - grpcRequire(this.measurementConsumerReferenceId == measurementConsumerReferenceId) { - "Arguments must be kept the same when using a page token" - } - - if (isValidPageSize) { - pageSize = source.pageSize - } - } - } else { - listReportsPageToken { - pageSize = - when { - source.pageSize < MIN_PAGE_SIZE -> DEFAULT_PAGE_SIZE - source.pageSize > MAX_PAGE_SIZE -> MAX_PAGE_SIZE - else -> source.pageSize - } - this.measurementConsumerReferenceId = measurementConsumerReferenceId - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/SetOperationCompiler.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/SetOperationCompiler.kt deleted file mode 100644 index b2be4a9f684..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/SetOperationCompiler.kt +++ /dev/null @@ -1,500 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.service.api.v1alpha - -import kotlin.math.pow -import kotlinx.coroutines.coroutineScope -import kotlinx.coroutines.launch -import org.wfanet.measurement.reporting.v1alpha.Metric.SetOperation - -/** - * A primitive region of a Venn diagram is the intersection of a set of reporting sets, and it is - * represented by a bit representation of an integer. Only the reporting sets with IDs equal to the - * bit positions of set bits constitute the primitive region. Ex: Given a Venn Diagram of 3 - * reporting sets (rs0, rs1, rs2), a primitive region with an integer value equal to 3 has the bit - * representation b’011’. This means this primitive region is only covered by the intersection of - * rs0 and rs1 and not covered by rs2 (the order of the bit positions is from right to left). In - * other words, rs0 INTERSECT rs1 INTERSECT COMPLEMENT(rs2). Note that primitive regions are - * disjoint. - */ -private typealias PrimitiveRegion = ULong - -/** - * A union set is the union of a set of reporting sets, and it is represented by a bit - * representation of an integer. Only the reporting sets with IDs equal to the bit positions of set - * bits constitute the union set. Given a Venn Diagram of 3 reporting sets (rs0, rs1, rs2), a - * union-set with an integer value equal to 3 has the bit representation b’011’. This means this - * union-set is only covered by the union of rs0 and rs1 (the order of the bit positions is from - * right to left). - */ -private typealias UnionSet = ULong - -private typealias NumberReportingSets = Int - -/** A mapping from a [UnionSet] to its coefficient in the Venn diagram region decomposition. */ -private typealias UnionSetCoefficientMap = Map - -/** - * A mapping for cardinality computation from a [PrimitiveRegion] to its decomposition in terms of - * union-sets represented by [UnionSetCoefficientMap]. Take a case of 3 reporting sets (rs0, rs1, - * rs2) as an example. A primitive region with its value equal to 3 (b'011') means rs0 INTERSECT rs1 - * INTERSECT COMPLEMENT(rs2). The decomposition of the cardinality of the region = - * PrimitiveRegionToUnionSetCoefficientMap\[region\] = {4: -1, 5: 1, 6: 1, 7: -1}, i.e. |union-set5| - * + |union-set6| - |union-set4| - |union-set7|. - */ -private typealias PrimitiveRegionToUnionSetCoefficientMap = - MutableMap - -/** - * A memory cache that stores the Venn diagram region cardinality decompositions for different - * numbers of reporting sets. - */ -private typealias PrimitiveRegionCache = - MutableMap - -private enum class Operator { - UNION, - INTERSECT, - DIFFERENCE -} - -private interface Operand - -private data class ReportingSet(val id: Int, val resourceName: String) : Operand - -private data class SetOperationExpression( - val setOperator: Operator, - val lhs: Operand, - val rhs: Operand?, -) : Operand - -data class WeightedMeasurement(val reportingSets: List, val coefficient: Int) - -class SetOperationCompiler { - - private var primitiveRegionCache: PrimitiveRegionCache = mutableMapOf() - - // For unit test only. - fun getPrimitiveRegionCache(): - Map> { - return primitiveRegionCache.mapValues { it.value.toMap() }.toMap() - } - - /** - * Compiles a set operation to a list of [WeightedMeasurement]s which will be used for the - * cardinality computation. For example, given a set = primitiveRegion1 UNION primitiveRegion2, - * Count(set) = Count(primitiveRegion1) + Count(primitiveRegion2) = Count(unionSet1) - - * Count(unionSet2) + Count(unionSet3) - Count(unionSet2) = Count(unionSet1) + Count(unionSet3) - - * 2 * Count(unionSet2). - */ - suspend fun compileSetOperation(setOperation: SetOperation): List { - val reportingSetNames = mutableSetOf() - setOperation.storeReportingSetNames(reportingSetNames) - - // Sorts the list in alphabetical order to make sure the IDs are consistent for the same run. - val sortedReportingSetNames = reportingSetNames.sortedBy { it } - val reportingSetsMap = createReportingSetsMap(sortedReportingSetNames) - val numReportingSets = reportingSetsMap.size - - val setOperationExpression = setOperation.toSetOperationExpression(reportingSetsMap) - - // Step 1 - Gets the primitive regions that form the set operation - val primitiveRegions = - setOperationExpressionToPrimitiveRegions(numReportingSets, setOperationExpression) - - // Step 2 - Converts a set of primitive regions to a map of union-set to its coefficients for - // cardinality computation. - val unionSetCoefficientMap = - convertPrimitiveRegionsToUnionSetCoefficientMap(numReportingSets, primitiveRegions) - - return unionSetCoefficientMap.map { (unionSet, coefficient) -> - convertUnionSetToWeightedMeasurements(unionSet, coefficient, sortedReportingSetNames) - } - } - - /** Converts unionSetCoefficientMap to WeightedMeasurements. */ - private fun convertUnionSetToWeightedMeasurements( - unionSet: UnionSet, - coefficient: Int, - sortedReportingSetNames: List, - ): WeightedMeasurement { - // Find the reporting sets in the union-set. - val reportingSetNames = - (sortedReportingSetNames.indices).mapNotNull { bitPosition -> - if (isBitSet(unionSet, bitPosition)) sortedReportingSetNames[bitPosition] else null - } - - return WeightedMeasurement(reportingSetNames, coefficient) - } - - /** - * Converts a set of primitive regions to a map of union-set to its coefficients for cardinality - * computation - */ - private suspend fun convertPrimitiveRegionsToUnionSetCoefficientMap( - numReportingSets: Int, - primitiveRegions: Set, - ): UnionSetCoefficientMap { - - val primitiveRegionsToUnionSetCoefficients: PrimitiveRegionToUnionSetCoefficientMap = - mutableMapOf() - - coroutineScope { - for (region in primitiveRegions) { - // Reuse previous computation if available - if ( - reusePreviousComputation(numReportingSets, region, primitiveRegionsToUnionSetCoefficients) - ) { - continue - } - - launch { - convertSinglePrimitiveRegionToUnionSetCoefficientMap( - numReportingSets, - region, - primitiveRegionsToUnionSetCoefficients, - ) - } - } - } - - // Updates the memory cache with new computation result. - primitiveRegionCache.getOrPut(numReportingSets, ::mutableMapOf) += - primitiveRegionsToUnionSetCoefficients - - return aggregateCoefficientsByUnionSets(primitiveRegionsToUnionSetCoefficients) - } - - /** Aggregates the coefficients by union-sets. */ - private fun aggregateCoefficientsByUnionSets( - primitiveRegionsToUnionSetCoefficients: PrimitiveRegionToUnionSetCoefficientMap - ): UnionSetCoefficientMap { - val aggregatedResult = mutableMapOf() - for ((_, unionSetCoefficients) in primitiveRegionsToUnionSetCoefficients) { - for ((unionSet, coefficient) in unionSetCoefficients) { - aggregatedResult[unionSet] = aggregatedResult.getOrDefault(unionSet, 0) + coefficient - - // Remove the entry if its coefficient is zero. - if (aggregatedResult[unionSet] == 0) { - aggregatedResult.remove(unionSet) - } - } - } - // Sort the aggregatedResult to make sure the result is consistent every time. - return aggregatedResult.toSortedMap().toMap() - } - - /** - * Converts a single primitive region to a map of union-set to its coefficients, where the - * cardinality of the input primitive region is equal to the linear combination of the - * cardinalities of the union-sets with the coefficients. - * - * The algorithm is based on the observation on the linear transformation matrix from primitive - * regions to union-sets. Ex: - * ``` - * b'01' b'10' b'11' - * A 0 -1 1 - * B -1 0 1 - * A U B 1 1 -1 - * ``` - */ - private fun convertSinglePrimitiveRegionToUnionSetCoefficientMap( - numReportingSets: Int, - region: PrimitiveRegion, - primitiveRegionsToUnionSetCoefficients: PrimitiveRegionToUnionSetCoefficientMap, - ) { - // For a given region, we first find which bit positions are set and which are not. - val setBitPositions = mutableListOf() - val unsetBitPositions = mutableListOf() - - for (bitPosition in 0 until numReportingSets) { - if (isBitSet(region, bitPosition)) { - setBitPositions.add(bitPosition) - } else { - unsetBitPositions.add(bitPosition) - } - } - val primitiveRegionWeight = setBitPositions.size - - // Always starts from -1 unless the primitive region = (2^numReportingSets - 1) = b'11...1'. - val baseSign = if (primitiveRegionWeight != numReportingSets) -1 else 1 - var count = 0 - - val unionSetCoefficients = mutableMapOf() - - for (size in 1..numReportingSets) { - // Skips it if the union-only set is too light - if (size + primitiveRegionWeight < numReportingSets) { - continue - } - - // Instead of flipping the sign at the end of the loop, this could avoid the race condition. - count++ - val sign = if (count % 2 == 1) baseSign else -baseSign - - val composingUnionSets = findComposingUnionSets(setBitPositions, unsetBitPositions, size) - unionSetCoefficients += composingUnionSets.associateWith { sign } - } - - if (unionSetCoefficients.isNotEmpty()) { - primitiveRegionsToUnionSetCoefficients[region] = unionSetCoefficients.toMap() - } - } - - /** Reuses previous result in the memory cache if there is any. */ - private fun reusePreviousComputation( - numReportingSets: Int, - region: PrimitiveRegion, - primitiveRegionsToUnionSetCoefficients: PrimitiveRegionToUnionSetCoefficientMap, - ): Boolean { - // If the compiler has already run the case where the number of reporting sets equal to - // `numReportingSets`. - primitiveRegionCache[numReportingSets]?.also { cachedPrimitiveRegionsToUnionSetCoefficients -> - // If the compiler has calculated this region before. - cachedPrimitiveRegionsToUnionSetCoefficients[region]?.also { cachedUnionSetCoefficientMap -> - primitiveRegionsToUnionSetCoefficients[region] = cachedUnionSetCoefficientMap - } - } - return primitiveRegionsToUnionSetCoefficients.containsKey(region) - } - - /** - * Finds the union-sets which will be part of the combination to form the target region. - * Essentially, given the size of a combination, we are finding all the combinations of the bit - * positions where at least unset bit positions are selected. - */ - private fun findComposingUnionSets( - setBitPositions: MutableList, - unsetBitPositions: MutableList, - size: Int, - ): MutableList { - val composingUnionSets = mutableListOf() - - // If the size is not large enough to at least contain all unset bit positions or the size is - // too large to fill, return empty result. - if (unsetBitPositions.size > size || setBitPositions.size + unsetBitPositions.size < size) { - return composingUnionSets - } - - findValidUnionSets(size, 0, setBitPositions, unsetBitPositions, composingUnionSets) - - return composingUnionSets - } - - /** Finds the valid combinations as [UnionSet]s using backtracking. */ - private fun findValidUnionSets( - size: Int, - start: Int, - choices: MutableList, - combination: MutableList, - result: MutableList, - ) { - if (combination.size == size) { - result.add(combination.sumOf { 1.toUnionSet() shl it }) - return - } - - for (i in start until choices.size) { - combination.add(choices[i]) - findValidUnionSets(size, i + 1, choices, combination, result) - combination.removeLast() - } - - return - } - - /** Gets the set of the primitive regions that form the set from the set operation expression. */ - private fun setOperationExpressionToPrimitiveRegions( - numReportingSets: Int, - setOperationExpression: SetOperationExpression, - ): Set { - val allPrimitiveRegionSetsList = buildAllPrimitiveRegions(numReportingSets) - return setOperationExpression.decompose(allPrimitiveRegionSetsList) - } -} - -/** - * Decomposes the set operation expression to a set of primitive regions by calculating the set - * operation between each two operands. - */ -private fun SetOperationExpression.decompose( - allPrimitiveRegionSetsList: List> -): Set { - val source = this - val lhsPrimitiveRegions = source.lhs.decompose(allPrimitiveRegionSetsList) - val rhsPrimitiveRegions = source.rhs.decompose(allPrimitiveRegionSetsList) - return calculateBinarySetOperation(lhsPrimitiveRegions, rhsPrimitiveRegions, source.setOperator) -} - -/** Decomposes the operand to a set of primitive regions. */ -private fun Operand?.decompose( - allPrimitiveRegionSetsList: List> -): Set { - return when (val operand = this) { - is SetOperationExpression -> { - operand.decompose(allPrimitiveRegionSetsList) - } - is ReportingSet -> { - allPrimitiveRegionSetsList[operand.id] - } - else -> setOf() - } -} - -/** Calculates the binary set operation. */ -private fun calculateBinarySetOperation( - lhs: Set, - rhs: Set, - operator: Operator, -): Set { - return when (operator) { - Operator.UNION -> lhs union rhs - Operator.INTERSECT -> lhs intersect rhs - Operator.DIFFERENCE -> lhs subtract rhs - } -} - -/** - * Builds a list of primitive regions where the index represents the reporting set ID and the - * element is the set of primitive regions which forms the corresponding reporting set. For example, - * if reportingSetId = 1, then allPrimitiveRegionSetsList\[reportingSetId\] = setOf(1(=b’001’), - * 3(=b’011’), 5(=b’101’), 7(=b’111’)). - */ -private fun buildAllPrimitiveRegions(numReportingSets: Int): List> { - val numPrimitiveRegions = 2.0.pow(numReportingSets).toPrimitiveRegion() - 1.toPrimitiveRegion() - val allPrimitiveRegionSetsList: List> = - List(numReportingSets) { mutableSetOf() } - - // A region is in the set of reportingSet when its bit at bit position == reportingSetId is set. - for (region in 1.toPrimitiveRegion()..numPrimitiveRegions) { - for (reportingSetId in 0 until numReportingSets) { - if (isBitSet(region, reportingSetId)) { - allPrimitiveRegionSetsList[reportingSetId].add(region) - } - } - } - - return allPrimitiveRegionSetsList.map(MutableSet::toSet) -} - -/** Converts a [Int] to a [PrimitiveRegion] */ -private fun Int.toPrimitiveRegion(): PrimitiveRegion { - return this.toULong() -} - -/** Converts a [Double] to a [PrimitiveRegion] */ -private fun Double.toPrimitiveRegion(): PrimitiveRegion { - return this.toULong() -} - -/** Converts a [Int] to a [UnionSet] */ -private fun Int.toUnionSet(): UnionSet { - return this.toULong() -} - -/** Checks if the bit at `bitPosition` of a number is set or not. */ -fun isBitSet(number: ULong, bitPosition: Int): Boolean { - return (number and (1UL shl bitPosition)) != 0UL -} - -/** Creates a map of resource names of reporting sets to [ReportingSet]s. */ -private fun createReportingSetsMap( - sortedReportingSetNames: List -): Map { - val reportingSetsMap: MutableMap = mutableMapOf() - for ((id, reportingSetName) in sortedReportingSetNames.withIndex()) { - reportingSetsMap[reportingSetName] = ReportingSet(id, reportingSetName) - } - return reportingSetsMap.toMap() -} - -/** Gets all resource names of the reporting sets used in this [SetOperation]. */ -private fun SetOperation.storeReportingSetNames(reportingSetNames: MutableSet) { - val root = this - if (!root.hasLhs()) { - throw IllegalArgumentException("lhs in SetOperation must be set.") - } - if (!root.lhs.hasReportingSet() && !root.lhs.hasOperation()) { - throw IllegalArgumentException("Operand type of lhs in SetOperation must be set.") - } - - root.lhs.storeReportingSetNames(reportingSetNames) - root.rhs.storeReportingSetNames(reportingSetNames) -} - -/** Gets all resource names of the reporting sets used in this [SetOperation.Operand]. */ -private fun SetOperation.Operand.storeReportingSetNames(reportingSetNames: MutableSet) { - val node = this - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - when (node.operandCase) { - // Leaf node - SetOperation.Operand.OperandCase.REPORTING_SET -> reportingSetNames.add(node.reportingSet) - SetOperation.Operand.OperandCase.OPERATION -> { - node.operation.storeReportingSetNames(reportingSetNames) - } - // Empty node. No further action. - SetOperation.Operand.OperandCase.OPERAND_NOT_SET -> return - } -} - -/** Converts a public [SetOperation] to a [SetOperationExpression]. */ -private fun SetOperation.toSetOperationExpression( - reportingSetsMap: Map -): SetOperationExpression { - val root = this - - if (!root.hasLhs()) { - throw IllegalArgumentException("lhs in SetOperation must be set.") - } - - val lhs = - root.lhs.toOperand(reportingSetsMap) - ?: throw IllegalArgumentException("Operand type of lhs in SetOperation must be set.") - - val rhs = root.rhs.toOperand(reportingSetsMap) - - return SetOperationExpression(root.type.toOperator(), lhs, rhs) -} - -/** Converts a public [SetOperation.Operand] to a nullable [Operand]. */ -private fun SetOperation.Operand.toOperand(reportingSetsMap: Map): Operand? { - val source = this - - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - return when (source.operandCase) { - SetOperation.Operand.OperandCase.REPORTING_SET -> { - return reportingSetsMap[source.reportingSet] - } - SetOperation.Operand.OperandCase.OPERATION -> - return source.operation.toSetOperationExpression(reportingSetsMap) - SetOperation.Operand.OperandCase.OPERAND_NOT_SET -> null - } -} - -/** Converts a public [SetOperation.Type] to a [Operator]. */ -private fun SetOperation.Type.toOperator(): Operator { - val source = this - - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - return when (source) { - SetOperation.Type.TYPE_UNSPECIFIED -> - throw IllegalArgumentException("Set operator type is not specified.") - SetOperation.Type.UNION -> Operator.UNION - SetOperation.Type.DIFFERENCE -> Operator.DIFFERENCE - SetOperation.Type.INTERSECTION -> Operator.INTERSECT - SetOperation.Type.UNRECOGNIZED -> - throw IllegalArgumentException("Unrecognized Set operator type.") - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/BUILD.bazel deleted file mode 100644 index eba73091447..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/BUILD.bazel +++ /dev/null @@ -1,32 +0,0 @@ -load("@rules_java//java:defs.bzl", "java_binary") -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") - -package( - default_visibility = ["//src/test/kotlin/org/wfanet/measurement/reporting:__subpackages__"], -) - -kt_jvm_library( - name = "reporting", - srcs = ["Reporting.kt"], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common:flags", - "//src/main/proto/wfa/measurement/reporting/v1alpha:event_group_kt_jvm_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:event_groups_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:report_kt_jvm_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:reporting_set_kt_jvm_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:reporting_sets_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:reports_service_kt_jvm_grpc_proto", - "@wfa_common_jvm//imports/java/picocli", - "@wfa_common_jvm//imports/kotlin/com/google/protobuf/kotlin", - "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", - ], -) - -java_binary( - name = "Reporting", - main_class = "org.wfanet.measurement.reporting.service.api.v1alpha.tools.ReportingKt", - tags = ["manual"], - runtime_deps = [":reporting"], -) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/README.md b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/README.md deleted file mode 100644 index 1126c7effd8..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/README.md +++ /dev/null @@ -1,138 +0,0 @@ -# Reporting CLI Tools - -Command-line tools for Reporting API. Use the `help` subcommand for help with -any of the subcommands. - -Note that instead of specifying arguments on the command-line, you can specify -an arguments file using `@` followed by the path. For example, - -```shell -Reporting @/home/foo/args.txt -``` - -## Certificate Host - -In the event that the host you specify to the `--reporting-server-api-target` -option doesn't match what's in the Subject Alternative Name (SAN) extension of -the server's certificate, you'll need to specify a host that does match using -the `--reporting-server-api-cert-host` option. - -## Examples - -### reporting-sets - -#### create - -```shell -Reporting \ - --tls-cert-file=src/main/k8s/testing/secretfiles/mc_tls.pem \ - --tls-key-file=src/main/k8s/testing/secretfiles/mc_tls.key \ - --cert-collection-file src/main/k8s/testing/secretfiles/reporting_root.pem \ - --reporting-server-api-target v1alpha.reporting.dev.halo-cmm.org:8443 \ - reporting-sets create --parent=measurementConsumers/VCTqwV_vFXw \ - --event-group=measurementConsumers/VCTqwV_vFXw/dataProviders/1/eventGroups/1 \ - --event-group=measurementConsumers/VCTqwV_vFXw/dataProviders/1/eventGroups/2 \ - --event-group=measurementConsumers/VCTqwV_vFXw/dataProviders/2/eventGroups/1 \ - --filter='video_ad.age == 1' --display-name='test-reporting-set' -``` - -#### list - -```shell -Reporting \ - --tls-cert-file=src/main/k8s/testing/secretfiles/mc_tls.pem \ - --tls-key-file=src/main/k8s/testing/secretfiles/mc_tls.key \ - --cert-collection-file src/main/k8s/testing/secretfiles/reporting_root.pem \ - --reporting-server-api-target v1alpha.reporting.dev.halo-cmm.org:8443 \ - reporting-sets list --parent=measurementConsumers/VCTqwV_vFXw -``` - -To retrieve the next page of reports, use the `--page-token` option to specify -the token returned from the previous response. - -### reports - -#### create - -```shell -Reporting \ - --tls-cert-file=secretfiles/mc_tls.pem \ - --tls-key-file=secretfiles/mc_tls.key \ - --cert-collection-file=secretfiles/reporting_root.pem \ - --reporting-server-api-target=v1alpha.reporting.dev.halo-cmm.org:8443 \ - reports create \ - --idempotency-key="report001" \ - --parent=measurementConsumers/777 \ - --event-group-key=$EVENT_GROUP_NAME_1 \ - --event-group-value="video_ad.age == 1" \ - --event-group-key=$EVENT_GROUP_NAME_2 \ - --event-group-value="video_ad.age == 12" \ - --interval-start-time=2017-01-15T01:30:15.01Z \ - --interval-end-time=2018-10-27T23:19:12.99Z \ - --interval-start-time=2019-01-19T09:48:35.57Z \ - --interval-end-time=2022-06-13T11:57:54.21Z \ - --metric=' - reach { } - set_operations { - unique_name: "operation1" - set_operation { - type: 1 - lhs { - reporting_set: "measurementConsumers/1/reportingSets/1" - } - rhs { - reporting_set: "measurementConsumers/1/reportingSets/2" - } - } - } - ' -``` - -User specifies the type of time args by using either repeated interval params( -`--interval-start-time`, `--interval-end-time`) or periodic time args( -`--periodic-interval-start-time`, `--periodic-interval-increment` and -`--periodic-interval-count`) - -The `--metric` option expects a -[`Metric`](../../../../../../../../../proto/wfa/measurement/reporting/v1alpha/metric.proto) -protobuf message in text format. See -[`metric1.textproto`](../../../../../../../../../../test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/metric1.textproto) -for a complicated example. You can use shell quoting for a multiline string, or -use command substitution to read the message from a file e.g. `--metric=$(cat -metric1.textproto)`. - -#### list - -```shell -Reporting \ - --tls-cert-file=secretfiles/mc_tls.pem \ - --tls-key-file=secretfiles/mc_tls.key \ - --cert-collection-file=secretfiles/reporting_root.pem \ - --reporting-server-api-target=v1alpha.reporting.dev.halo-cmm.org:8443 \ - reports list --parent=measurementConsumers/777 -``` - -#### get - -```shell -Reporting \ - --tls-cert-file=secretfiles/mc_tls.pem \ - --tls-key-file=secretfiles/mc_tls.key \ - --cert-collection-file=secretfiles/reporting_root.pem \ - --reporting-server-api-target=v1alpha.reporting.dev.halo-cmm.org:8443 \ - reports get measurementConsumers/777/reports/5 -``` - -### event-groups - -#### list - -```shell -Reporting \ - --tls-cert-file=secretfiles/mc_tls.pem \ - --tls-key-file=secretfiles/mc_tls.key \ - --cert-collection-file=secretfiles/reporting_root.pem \ - --reporting-server-api-target=v1alpha.reporting.dev.halo-cmm.org:8443 \ - event-groups list \ - --parent=measurementConsumers/777/dataProviders/1 -``` diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/Reporting.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/Reporting.kt deleted file mode 100644 index c2254b6eb01..00000000000 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/Reporting.kt +++ /dev/null @@ -1,497 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.service.api.v1alpha.tools - -import io.grpc.ManagedChannel -import java.time.Duration -import java.time.Instant -import kotlin.properties.Delegates -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.runBlocking -import org.wfanet.measurement.common.DurationFormat -import org.wfanet.measurement.common.commandLineMain -import org.wfanet.measurement.common.crypto.SigningCerts -import org.wfanet.measurement.common.grpc.TlsFlags -import org.wfanet.measurement.common.grpc.buildMutualTlsChannel -import org.wfanet.measurement.common.grpc.withShutdownTimeout -import org.wfanet.measurement.common.parseTextProto -import org.wfanet.measurement.common.toProtoDuration -import org.wfanet.measurement.common.toProtoTime -import org.wfanet.measurement.reporting.v1alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub -import org.wfanet.measurement.reporting.v1alpha.Metric -import org.wfanet.measurement.reporting.v1alpha.ReportKt.EventGroupUniverseKt.eventGroupEntry -import org.wfanet.measurement.reporting.v1alpha.ReportKt.eventGroupUniverse -import org.wfanet.measurement.reporting.v1alpha.ReportingSetsGrpcKt.ReportingSetsCoroutineStub -import org.wfanet.measurement.reporting.v1alpha.ReportsGrpcKt.ReportsCoroutineStub -import org.wfanet.measurement.reporting.v1alpha.createReportRequest -import org.wfanet.measurement.reporting.v1alpha.createReportingSetRequest -import org.wfanet.measurement.reporting.v1alpha.getReportRequest -import org.wfanet.measurement.reporting.v1alpha.listEventGroupsRequest -import org.wfanet.measurement.reporting.v1alpha.listReportingSetsRequest -import org.wfanet.measurement.reporting.v1alpha.listReportsRequest -import org.wfanet.measurement.reporting.v1alpha.periodicTimeInterval -import org.wfanet.measurement.reporting.v1alpha.report -import org.wfanet.measurement.reporting.v1alpha.reportingSet -import org.wfanet.measurement.reporting.v1alpha.timeInterval -import org.wfanet.measurement.reporting.v1alpha.timeIntervals -import picocli.CommandLine - -private class ReportingApiFlags { - @CommandLine.Option( - names = ["--reporting-server-api-target"], - description = ["gRPC target (authority) of the reporting server's public API"], - required = true, - ) - lateinit var apiTarget: String - private set - - @CommandLine.Option( - names = ["--reporting-server-api-cert-host"], - description = - [ - "Expected hostname (DNS-ID) in the reporting server's TLS certificate.", - "This overrides derivation of the TLS DNS-ID from --reporting-server-api-target.", - ], - required = false, - ) - var apiCertHost: String? = null - private set -} - -private class PageParams { - @CommandLine.Option( - names = ["--page-size"], - description = ["The maximum number of items to return. The maximum value is 1000"], - required = false, - ) - var pageSize: Int = 1000 - private set - - @CommandLine.Option( - names = ["--page-token"], - description = ["Page token from a previous list call to retrieve the next page"], - defaultValue = "", - required = false, - ) - lateinit var pageToken: String - private set -} - -@CommandLine.Command(name = "create", description = ["Creates a reporting set"]) -class CreateReportingSetCommand : Runnable { - @CommandLine.ParentCommand private lateinit var parent: ReportingSetsCommand - - @CommandLine.Option( - names = ["--parent"], - description = ["API resource name of the Measurement Consumer"], - required = true, - ) - private lateinit var measurementConsumerName: String - - @CommandLine.Option( - names = ["--event-group"], - description = ["List of EventGroup's API resource names"], - required = true, - ) - private lateinit var eventGroups: List - - @CommandLine.Option( - names = ["--filter"], - description = ["CEL filter predicate that applies to all `event_groups`"], - required = false, - defaultValue = "", - ) - private lateinit var filterExpression: String - - @CommandLine.Option( - names = ["--display-name"], - description = ["Human-readable name for display purposes"], - required = false, - defaultValue = "", - ) - private lateinit var displayNameInput: String - - override fun run() { - val request = createReportingSetRequest { - parent = measurementConsumerName - reportingSet = reportingSet { - eventGroups += this@CreateReportingSetCommand.eventGroups - filter = filterExpression - displayName = displayNameInput - } - } - val reportingSet = - runBlocking(Dispatchers.IO) { parent.reportingSetStub.createReportingSet(request) } - println(reportingSet) - } -} - -@CommandLine.Command(name = "list", description = ["List reporting sets"]) -class ListReportingSetsCommand : Runnable { - @CommandLine.ParentCommand private lateinit var parent: ReportingSetsCommand - - @CommandLine.Option( - names = ["--parent"], - description = ["API resource name of the Measurement Consumer"], - required = true, - ) - private lateinit var measurementConsumerName: String - - @CommandLine.Mixin private lateinit var pageParams: PageParams - - override fun run() { - val request = listReportingSetsRequest { - parent = measurementConsumerName - pageSize = pageParams.pageSize - pageToken = pageParams.pageToken - } - - val response = - runBlocking(Dispatchers.IO) { parent.reportingSetStub.listReportingSets(request) } - - println(response) - } -} - -@CommandLine.Command( - name = "reporting-sets", - sortOptions = false, - subcommands = - [ - CommandLine.HelpCommand::class, - CreateReportingSetCommand::class, - ListReportingSetsCommand::class, - ], -) -class ReportingSetsCommand : Runnable { - @CommandLine.ParentCommand lateinit var parent: Reporting - val reportingSetStub: ReportingSetsCoroutineStub by lazy { - ReportingSetsCoroutineStub(parent.channel) - } - - override fun run() {} -} - -@CommandLine.Command(name = "create", description = ["Create a set operation report"]) -class CreateReportCommand : Runnable { - @CommandLine.ParentCommand private lateinit var parent: ReportsCommand - - @CommandLine.Option( - names = ["--parent"], - description = ["API resource name of the Measurement Consumer"], - required = true, - ) - private lateinit var measurementConsumerName: String - - @CommandLine.Option( - names = ["--idempotency-key"], - description = ["Used as the prefix of the idempotency keys of measurements"], - required = true, - ) - private lateinit var idempotencyKey: String - - class EventGroupInput { - @CommandLine.Option( - names = ["--event-group-key"], - description = ["Event Group Entry's key"], - required = true, - ) - lateinit var key: String - private set - - @CommandLine.Option( - names = ["--event-group-value"], - description = ["Event Group Entry's value"], - defaultValue = "", - required = false, - ) - lateinit var value: String - private set - } - - @CommandLine.ArgGroup(exclusive = false, multiplicity = "1..*", heading = "Event Group Entries\n") - private lateinit var eventGroups: List - - class TimeInput { - class TimeIntervalInput { - @CommandLine.Option( - names = ["--interval-start-time"], - description = ["Start of time interval in ISO 8601 format of UTC"], - required = true, - ) - lateinit var intervalStartTime: Instant - private set - - @CommandLine.Option( - names = ["--interval-end-time"], - description = ["End of time interval in ISO 8601 format of UTC"], - required = true, - ) - lateinit var intervalEndTime: Instant - private set - } - - class PeriodicTimeIntervalInput { - @CommandLine.Option( - names = ["--periodic-interval-start-time"], - description = ["Start of the first time interval in ISO 8601 format of UTC"], - required = true, - ) - lateinit var periodicIntervalStartTime: Instant - private set - - @CommandLine.Option( - names = ["--periodic-interval-increment"], - description = ["Increment for each time interval in ISO-8601 format of PnDTnHnMn"], - required = true, - ) - lateinit var periodicIntervalIncrement: Duration - private set - - @set:CommandLine.Option( - names = ["--periodic-interval-count"], - description = ["Number of periodic intervals"], - required = true, - ) - var periodicIntervalCount by Delegates.notNull() - private set - } - - @CommandLine.ArgGroup(exclusive = false, multiplicity = "1..*", heading = "Time intervals\n") - var timeIntervals: List? = null - private set - - @CommandLine.ArgGroup( - exclusive = false, - multiplicity = "1", - heading = "Periodic time interval specification\n", - ) - var periodicTimeIntervalInput: PeriodicTimeIntervalInput? = null - private set - } - - @CommandLine.ArgGroup( - exclusive = true, - multiplicity = "1", - heading = "Time interval or periodic time interval\n", - ) - private lateinit var timeInput: TimeInput - - @CommandLine.Option( - names = ["--metric"], - description = ["Metric protobuf messages in text format"], - required = true, - ) - private lateinit var textFormatMetrics: List - - override fun run() { - val request = createReportRequest { - parent = measurementConsumerName - report = report { - reportIdempotencyKey = idempotencyKey - measurementConsumer = measurementConsumerName - eventGroupUniverse = eventGroupUniverse { - eventGroups.forEach { - eventGroupEntries += eventGroupEntry { - key = it.key - value = it.value - } - } - } - - // Either timeIntervals or periodicTimeIntervalInput are set. - if (timeInput.timeIntervals != null) { - val intervals = checkNotNull(timeInput.timeIntervals) - timeIntervals = timeIntervals { - intervals.forEach { - timeIntervals += timeInterval { - startTime = it.intervalStartTime.toProtoTime() - endTime = it.intervalEndTime.toProtoTime() - } - } - } - } else { - val periodicIntervals = checkNotNull(timeInput.periodicTimeIntervalInput) - periodicTimeInterval = periodicTimeInterval { - startTime = periodicIntervals.periodicIntervalStartTime.toProtoTime() - increment = periodicIntervals.periodicIntervalIncrement.toProtoDuration() - intervalCount = periodicIntervals.periodicIntervalCount - } - } - - for (textFormatMetric in textFormatMetrics) { - metrics += - textFormatMetric.reader().use { parseTextProto(it, Metric.getDefaultInstance()) } - } - } - } - val report = runBlocking(Dispatchers.IO) { parent.reportsStub.createReport(request) } - - println(report) - } -} - -@CommandLine.Command(name = "list", description = ["List set operation reports"]) -class ListReportsCommand : Runnable { - @CommandLine.ParentCommand private lateinit var parent: ReportsCommand - - @CommandLine.Option( - names = ["--parent"], - description = ["API resource name of the Measurement Consumer"], - required = true, - ) - private lateinit var measurementConsumerName: String - - @CommandLine.Mixin private lateinit var pageParams: PageParams - - override fun run() { - val request = listReportsRequest { - parent = measurementConsumerName - pageSize = pageParams.pageSize - pageToken = pageParams.pageToken - } - - val response = runBlocking(Dispatchers.IO) { parent.reportsStub.listReports(request) } - - response.reportsList.forEach { println(it.name + " " + it.state.toString()) } - if (response.nextPageToken.isNotEmpty()) { - println("nextPageToken: ${response.nextPageToken}") - } - } -} - -@CommandLine.Command(name = "get", description = ["Get a set operation report"]) -class GetReportCommand : Runnable { - @CommandLine.ParentCommand private lateinit var parent: ReportsCommand - - @CommandLine.Parameters(description = ["API resource name of the Report"]) - private lateinit var reportName: String - - override fun run() { - val request = getReportRequest { name = reportName } - - val report = runBlocking(Dispatchers.IO) { parent.reportsStub.getReport(request) } - println(report) - } -} - -@CommandLine.Command( - name = "reports", - sortOptions = false, - subcommands = - [ - CommandLine.HelpCommand::class, - CreateReportCommand::class, - ListReportsCommand::class, - GetReportCommand::class, - ], -) -class ReportsCommand : Runnable { - @CommandLine.ParentCommand lateinit var parent: Reporting - - val reportsStub: ReportsCoroutineStub by lazy { ReportsCoroutineStub(parent.channel) } - - override fun run() {} -} - -@CommandLine.Command(name = "list", description = ["List event groups"]) -class ListEventGroups : Runnable { - @CommandLine.ParentCommand private lateinit var parent: EventGroupsCommand - - @CommandLine.Option( - names = ["--parent"], - description = ["API resource name of the Data Provider"], - required = true, - ) - private lateinit var dataProviderName: String - - @CommandLine.Option( - names = ["--filter"], - description = ["Result filter in format of raw CEL expression"], - required = false, - defaultValue = "", - ) - private lateinit var celFilter: String - - @CommandLine.Mixin private lateinit var pageParams: PageParams - - override fun run() { - val request = listEventGroupsRequest { - parent = dataProviderName - pageSize = pageParams.pageSize - pageToken = pageParams.pageToken - filter = celFilter - } - - val response = runBlocking(Dispatchers.IO) { parent.eventGroupStub.listEventGroups(request) } - - println(response) - } -} - -@CommandLine.Command( - name = "event-groups", - sortOptions = false, - subcommands = [CommandLine.HelpCommand::class, ListEventGroups::class], -) -class EventGroupsCommand : Runnable { - @CommandLine.ParentCommand lateinit var parent: Reporting - - val eventGroupStub: EventGroupsCoroutineStub by lazy { EventGroupsCoroutineStub(parent.channel) } - - override fun run() {} -} - -@CommandLine.Command( - name = "reporting", - description = ["Reporting CLI tool"], - sortOptions = false, - subcommands = - [ - CommandLine.HelpCommand::class, - ReportingSetsCommand::class, - ReportsCommand::class, - EventGroupsCommand::class, - ], -) -class Reporting : Runnable { - @CommandLine.Mixin private lateinit var tlsFlags: TlsFlags - @CommandLine.Mixin private lateinit var apiFlags: ReportingApiFlags - - val channel: ManagedChannel by lazy { - val clientCerts = - SigningCerts.fromPemFiles( - certificateFile = tlsFlags.certFile, - privateKeyFile = tlsFlags.privateKeyFile, - trustedCertCollectionFile = tlsFlags.certCollectionFile, - ) - buildMutualTlsChannel(apiFlags.apiTarget, clientCerts, apiFlags.apiCertHost) - .withShutdownTimeout(Duration.ofSeconds(1)) - } - - override fun run() {} - - companion object { - @JvmStatic - fun main(args: Array) = commandLineMain(Reporting(), args, DurationFormat.ISO_8601) - } -} - -/** - * Create, List and Get reporting set or report. - * - * Use the `help` command to see usage details. - */ -fun main(args: Array) = commandLineMain(Reporting(), args) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/BUILD.bazel index 22d6343b537..fc750cdc976 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/BUILD.bazel @@ -52,7 +52,7 @@ kt_jvm_library( "context_keys", ":reporting_principal", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key", - "//src/main/kotlin/org/wfanet/measurement/common/api/grpc", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:akid_principal_server_interceptor", "//src/main/kotlin/org/wfanet/measurement/common/identity", "@wfa_common_jvm//imports/java/com/google/protobuf", "@wfa_common_jvm//imports/java/io/grpc:api", @@ -120,6 +120,7 @@ kt_jvm_library( "//imports/java/org/projectnessie/cel", "//src/main/kotlin/org/wfanet/measurement/api:api_key_constants", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:packed_messages", + "//src/main/kotlin/org/wfanet/measurement/common/api/grpc:list_resources", "//src/main/kotlin/org/wfanet/measurement/reporting/service/api:cel_env_provider", "//src/main/kotlin/org/wfanet/measurement/reporting/service/api:encryption_key_pair_store", "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha:principal_server_interceptor", diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsService.kt index 825e90c2bff..650a883fa68 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsService.kt @@ -20,10 +20,13 @@ import com.google.protobuf.DynamicMessage import com.google.protobuf.kotlin.unpack import io.grpc.Context import io.grpc.Deadline +import io.grpc.Deadline.Ticker import io.grpc.Status import io.grpc.StatusException import java.security.GeneralSecurityException import java.util.concurrent.TimeUnit +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.transformWhile import org.projectnessie.cel.common.types.Err import org.projectnessie.cel.common.types.ref.Val import org.wfanet.measurement.api.v2alpha.DataProviderKey @@ -31,9 +34,12 @@ import org.wfanet.measurement.api.v2alpha.EncryptionPublicKey import org.wfanet.measurement.api.v2alpha.EventGroup as CmmsEventGroup import org.wfanet.measurement.api.v2alpha.EventGroupKey as CmmsEventGroupKey import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub as CmmsEventGroupsCoroutineStub +import org.wfanet.measurement.api.v2alpha.ListEventGroupsResponse as CmmsListEventGroupsResponse import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey import org.wfanet.measurement.api.v2alpha.listEventGroupsRequest import org.wfanet.measurement.api.withAuthenticationKey +import org.wfanet.measurement.common.api.grpc.ResourceList +import org.wfanet.measurement.common.api.grpc.listResources import org.wfanet.measurement.common.crypto.PrivateKeyHandle import org.wfanet.measurement.common.grpc.grpcRequire import org.wfanet.measurement.common.grpc.grpcRequireNotNull @@ -52,6 +58,7 @@ class EventGroupsService( private val cmmsEventGroupsStub: CmmsEventGroupsCoroutineStub, private val encryptionKeyPairStore: EncryptionKeyPairStore, private val celEnvProvider: CelEnvProvider, + private val ticker: Ticker = Deadline.getSystemTicker(), ) : EventGroupsCoroutineImplBase() { override suspend fun listEventGroups(request: ListEventGroupsRequest): ListEventGroupsResponse { val parentKey = @@ -71,66 +78,79 @@ class EventGroupsService( } } + val deadline: Deadline = + Context.current().deadline + ?: Deadline.after(RPC_DEFAULT_DEADLINE_MILLIS, TimeUnit.MILLISECONDS, ticker) val apiAuthenticationKey: String = principal.config.apiKey grpcRequire(request.pageSize >= 0) { "page_size cannot be negative" } - val pageSize = - when { - request.pageSize < MIN_PAGE_SIZE -> DEFAULT_PAGE_SIZE - request.pageSize > MAX_PAGE_SIZE -> MAX_PAGE_SIZE - else -> request.pageSize - } + val limit = + if (request.pageSize > 0) request.pageSize.coerceAtMost(MAX_PAGE_SIZE) else DEFAULT_PAGE_SIZE + val parent = parentKey.toName() + val eventGroupLists: Flow> = + cmmsEventGroupsStub.withAuthenticationKey(apiAuthenticationKey).listResources( + limit, + request.pageToken, + ) { pageToken, remaining -> + val response: CmmsListEventGroupsResponse = + listEventGroups( + listEventGroupsRequest { + this.parent = parent + this.pageSize = remaining + this.pageToken = pageToken + } + ) - var nextPageToken = request.pageToken - val deadline = Context.current().deadline ?: Deadline.after(30, TimeUnit.SECONDS) - do { - val cmmsListEventGroupResponse = - try { - cmmsEventGroupsStub - .withAuthenticationKey(apiAuthenticationKey) - .listEventGroups( - listEventGroupsRequest { - parent = parentKey.toName() - this.pageSize = pageSize - pageToken = nextPageToken + val eventGroups: List = + response.eventGroupsList.map { + val cmmsMetadata: CmmsEventGroup.Metadata? = + if (it.hasEncryptedMetadata()) { + decryptMetadata(it, principal.resourceKey.toName()) + } else { + null } - ) - } catch (e: StatusException) { - throw when (e.status.code) { - Status.Code.DEADLINE_EXCEEDED -> Status.DEADLINE_EXCEEDED - Status.Code.CANCELLED -> Status.CANCELLED - else -> Status.UNKNOWN - } - .withCause(e) - .asRuntimeException() - } - val cmmsEventGroups = cmmsListEventGroupResponse.eventGroupsList - val eventGroups = - cmmsEventGroups.map { - val cmmsMetadata: CmmsEventGroup.Metadata? = - if (it.hasEncryptedMetadata()) { - decryptMetadata(it, principal.resourceKey.toName()) - } else { - null - } + it.toEventGroup(cmmsMetadata) + } - it.toEventGroup(cmmsMetadata) - } + ResourceList(filterEventGroups(eventGroups, request.filter), response.nextPageToken) + } - val filteredEventGroups = filterEventGroups(eventGroups, request.filter) - if (filteredEventGroups.size > 0) { - return listEventGroupsResponse { - this.eventGroups += filteredEventGroups - this.nextPageToken = cmmsListEventGroupResponse.nextPageToken + var hasResponse = false + return listEventGroupsResponse { + try { + eventGroupLists + .transformWhile { + emit(it) + deadline.timeRemaining(TimeUnit.MILLISECONDS) > RPC_DEADLINE_OVERHEAD_MILLIS + } + .collect { eventGroupList -> + this.eventGroups += eventGroupList + nextPageToken = eventGroupList.nextPageToken + hasResponse = true + } + } catch (e: StatusException) { + when (e.status.code) { + Status.Code.DEADLINE_EXCEEDED, + Status.Code.CANCELLED -> { + if (!hasResponse) { + // Only throw an error if we don't have any response yet. Otherwise, just return what + // we have so far. + throw Status.DEADLINE_EXCEEDED.withDescription( + "Timed out listing EventGroups from backend" + ) + .withCause(e) + .asRuntimeException() + } + } + else -> + throw Status.UNKNOWN.withDescription("Error listing EventGroups from backend") + .withCause(e) + .asRuntimeException() } - } else { - nextPageToken = cmmsListEventGroupResponse.nextPageToken } - } while (deadline.timeRemaining(TimeUnit.SECONDS) > 5) - - return listEventGroupsResponse { this.nextPageToken = nextPageToken } + } } private suspend fun filterEventGroups( @@ -259,8 +279,13 @@ class EventGroupsService( companion object { private const val METADATA_FIELD = "metadata.metadata" - private const val MIN_PAGE_SIZE = 1 private const val DEFAULT_PAGE_SIZE = 50 private const val MAX_PAGE_SIZE = 1000 + + /** Overhead to allow for RPC deadlines in milliseconds. */ + private const val RPC_DEADLINE_OVERHEAD_MILLIS = 100L + + /** Default RPC deadline in milliseconds. */ + private const val RPC_DEFAULT_DEADLINE_MILLIS = 30_000L } } diff --git a/src/main/proto/api-linter.yaml b/src/main/proto/api-linter.yaml index 3a90b3a78a6..65ad0664967 100644 --- a/src/main/proto/api-linter.yaml +++ b/src/main/proto/api-linter.yaml @@ -55,3 +55,17 @@ # gRPC-only. REST/HTTP not supported. - 'core::0127::http-annotation' - 'core::0133::http-uri-parent' + +- included_paths: + - 'wfa/measurement/access/**/*.proto' + disabled_rules: + # Java package uses full domain, while protobuf package does not. + - 'core::0191::java-package' + # List methods aren't useful for some resource types. + - 'core::0121::resource-must-support-list' + # gRPC-only for now. + - 'core::0127::http-annotation' + # Allow custom "Lookup" methods. + - 'core::0131::synonyms' + # Allow methods that operate on resources by lookup key. + - 'core::0136::response-message-name' diff --git a/src/main/proto/wfa/measurement/access/README.md b/src/main/proto/wfa/measurement/access/README.md new file mode 100644 index 00000000000..76cb22fc880 --- /dev/null +++ b/src/main/proto/wfa/measurement/access/README.md @@ -0,0 +1,15 @@ +# Access API Definition + +This API definition attempts to follow the +[API Improvement Proposals (AIPs)](https://google.aip.dev/) with the following +notable exceptions: + +* This is a gRPC-only API with no HTTP annotations. + +## Resource IDs + +The format of a resource ID is documented either on the `name` field of the +resource type or on the `{resource}_id` field of the corresponding Create +request. Resource IDs may never be UUIDs nor appear to look like UUIDs. + +Resource IDs are immutable; they cannot be changed after resource creation. diff --git a/src/main/proto/wfa/measurement/reporting/v1alpha/BUILD.bazel b/src/main/proto/wfa/measurement/access/v1alpha/BUILD.bazel similarity index 50% rename from src/main/proto/wfa/measurement/reporting/v1alpha/BUILD.bazel rename to src/main/proto/wfa/measurement/access/v1alpha/BUILD.bazel index 5df47be011f..43dc8c3abf3 100644 --- a/src/main/proto/wfa/measurement/reporting/v1alpha/BUILD.bazel +++ b/src/main/proto/wfa/measurement/access/v1alpha/BUILD.bazel @@ -9,27 +9,19 @@ package(default_visibility = ["//visibility:public"]) IMPORT_PREFIX = "/src/main/proto" -# Resources and shared message types. - proto_library( - name = "event_group_proto", - srcs = ["event_group.proto"], + name = "principal_proto", + srcs = ["principal.proto"], strip_import_prefix = IMPORT_PREFIX, deps = [ "@com_google_googleapis//google/api:field_behavior_proto", "@com_google_googleapis//google/api:resource_proto", - "@com_google_protobuf//:any_proto", ], ) -kt_jvm_proto_library( - name = "event_group_kt_jvm_proto", - deps = [":event_group_proto"], -) - proto_library( - name = "metric", - srcs = ["metric.proto"], + name = "permission_proto", + srcs = ["permission.proto"], strip_import_prefix = IMPORT_PREFIX, deps = [ "@com_google_googleapis//google/api:field_behavior_proto", @@ -38,25 +30,18 @@ proto_library( ) proto_library( - name = "report_proto", - srcs = ["report.proto"], + name = "role_proto", + srcs = ["role.proto"], strip_import_prefix = IMPORT_PREFIX, deps = [ - ":metric", - ":time_interval_proto", "@com_google_googleapis//google/api:field_behavior_proto", "@com_google_googleapis//google/api:resource_proto", ], ) -kt_jvm_proto_library( - name = "report_kt_jvm_proto", - deps = [":report_proto"], -) - proto_library( - name = "reporting_set_proto", - srcs = ["reporting_set.proto"], + name = "policy_proto", + srcs = ["policy.proto"], strip_import_prefix = IMPORT_PREFIX, deps = [ "@com_google_googleapis//google/api:field_behavior_proto", @@ -64,85 +49,81 @@ proto_library( ], ) -kt_jvm_proto_library( - name = "reporting_set_kt_jvm_proto", - deps = [":reporting_set_proto"], -) - proto_library( - name = "time_interval_proto", - srcs = ["time_interval.proto"], + name = "roles_service_proto", + srcs = ["roles_service.proto"], strip_import_prefix = IMPORT_PREFIX, deps = [ + ":role_proto", + "@com_google_googleapis//google/api:client_proto", "@com_google_googleapis//google/api:field_behavior_proto", - "@com_google_protobuf//:duration_proto", - "@com_google_protobuf//:timestamp_proto", + "@com_google_googleapis//google/api:resource_proto", + "@com_google_protobuf//:empty_proto", ], ) proto_library( - name = "page_token_proto", - srcs = ["page_token.proto"], - deps = [], -) - -kt_jvm_proto_library( - name = "page_token_kt_jvm_proto", - deps = [":page_token_proto"], -) - -# Services. - -proto_library( - name = "event_groups_service_proto", - srcs = ["event_groups_service.proto"], + name = "permissions_service_proto", + srcs = ["permissions_service.proto"], strip_import_prefix = IMPORT_PREFIX, deps = [ - ":event_group_proto", - "@com_google_googleapis//google/api:annotations_proto", + ":permission_proto", "@com_google_googleapis//google/api:client_proto", "@com_google_googleapis//google/api:field_behavior_proto", "@com_google_googleapis//google/api:resource_proto", ], ) -kt_jvm_grpc_proto_library( - name = "event_groups_service_kt_jvm_grpc_proto", - deps = [":event_groups_service_proto"], -) - proto_library( - name = "reports_service_proto", - srcs = ["reports_service.proto"], - strip_import_prefix = IMPORT_PREFIX, + name = "principals_service_proto", + srcs = ["principals_service.proto"], deps = [ - ":report_proto", - "@com_google_googleapis//google/api:annotations_proto", + ":principal_proto", "@com_google_googleapis//google/api:client_proto", "@com_google_googleapis//google/api:field_behavior_proto", "@com_google_googleapis//google/api:resource_proto", + "@com_google_protobuf//:empty_proto", ], ) -kt_jvm_grpc_proto_library( - name = "reports_service_kt_jvm_grpc_proto", - deps = [":reports_service_proto"], -) - proto_library( - name = "reporting_sets_service_proto", - srcs = ["reporting_sets_service.proto"], + name = "policies_service_proto", + srcs = ["policies_service.proto"], strip_import_prefix = IMPORT_PREFIX, deps = [ - ":reporting_set_proto", - "@com_google_googleapis//google/api:annotations_proto", + ":policy_proto", "@com_google_googleapis//google/api:client_proto", "@com_google_googleapis//google/api:field_behavior_proto", "@com_google_googleapis//google/api:resource_proto", ], ) -kt_jvm_grpc_proto_library( - name = "reporting_sets_service_kt_jvm_grpc_proto", - deps = [":reporting_sets_service_proto"], -) +MESSAGE_LIBS = [ + "permission", + "policy", + "principal", + "role", +] + +SERVICE_LIBS = [ + "permissions", + "policies", + "principals", + "roles", +] + +[ + kt_jvm_proto_library( + name = "{prefix}_kt_jvm_proto".format(prefix = prefix), + deps = [":{prefix}_proto".format(prefix = prefix)], + ) + for prefix in MESSAGE_LIBS +] + +[ + kt_jvm_grpc_proto_library( + name = "{prefix}_service_kt_jvm_grpc_proto".format(prefix = prefix), + deps = [":{prefix}_service_proto".format(prefix = prefix)], + ) + for prefix in SERVICE_LIBS +] diff --git a/src/main/proto/wfa/measurement/access/v1alpha/permission.proto b/src/main/proto/wfa/measurement/access/v1alpha/permission.proto new file mode 100644 index 00000000000..b92a6350d0b --- /dev/null +++ b/src/main/proto/wfa/measurement/access/v1alpha/permission.proto @@ -0,0 +1,47 @@ +// Copyright 2024 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package wfa.measurement.access.v1alpha; + +import "google/api/field_behavior.proto"; +import "google/api/resource.proto"; + +option java_package = "org.wfanet.measurement.access.v1alpha"; +option java_multiple_files = true; +option java_outer_classname = "PermissionProto"; + +// Resource representing a single permission. +message Permission { + option (google.api.resource) = { + type: "access.halo-cmm.org/Permission" + pattern: "permissions/{permission}" + singular: "permission" + plural: "permissions" + }; + + // Resource name. + // + // The resource ID will confirm to RFC-1034 with the following exceptions: + // * IDs are case-sensitive. + // * IDs may contain `.` as an interior character. + string name = 1 [(google.api.field_behavior) = IDENTIFIER]; + + // Set of resource types that this `Permission` can apply to. + repeated string resource_types = 2 [ + (google.api.field_behavior) = REQUIRED, + (google.api.field_behavior) = UNORDERED_LIST + ]; +} diff --git a/src/main/proto/wfa/measurement/access/v1alpha/permissions_service.proto b/src/main/proto/wfa/measurement/access/v1alpha/permissions_service.proto new file mode 100644 index 00000000000..c3f0c984c23 --- /dev/null +++ b/src/main/proto/wfa/measurement/access/v1alpha/permissions_service.proto @@ -0,0 +1,113 @@ +// Copyright 2024 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package wfa.measurement.access.v1alpha; + +import "google/api/client.proto"; +import "google/api/field_behavior.proto"; +import "google/api/resource.proto"; +import "wfa/measurement/access/v1alpha/permission.proto"; + +option java_package = "org.wfanet.measurement.access.v1alpha"; +option java_multiple_files = true; +option java_outer_classname = "PermissionsServiceProto"; + +// Service for interacting with `Permission` resources. +service Permissions { + // Gets a `Permission` by resource name. + rpc GetPermission(GetPermissionRequest) returns (Permission) { + option (google.api.method_signature) = "name"; + } + + // Lists `Permission` resources. + rpc ListPermissions(ListPermissionsRequest) returns (ListPermissionsResponse); + + // Checks what permissions a `Principal` has on a given resource. + // + // This can be used to check if the `Principal` has sufficient `Permission`s + // for a given operation on the protected resource. If the `Permission`s that + // the `Principal` has is a strict subset of the required `Permission`s, then + // the operation should not be permitted. + rpc CheckPermissions(CheckPermissionsRequest) + returns (CheckPermissionsResponse); +} + +// Request message for `GetPermission` method. +message GetPermissionRequest { + // Resource name. + string name = 1 [ + (google.api.resource_reference) = { + type: "access.halo-cmm.org/Permission" + }, + (google.api.field_behavior) = REQUIRED + ]; +} + +// Request message for `ListPermissions` method. +message ListPermissionsRequest { + // The maximum number of resources to return. + // + // The service may return fewer resources than this value. If unspecified, at + // most 50 resources will be returned. The maximum value is 100; values above + // this will be coerced to the maximum value. + int32 page_size = 1 [(google.api.field_behavior) = OPTIONAL]; + + // A page token from a previous call to the method, used to retrieve the + // subsequent page of results. + // + // If specified, all other request parameter values must match the call that + // provided the token. + string page_token = 2 [(google.api.field_behavior) = OPTIONAL]; +} + +// Response message for `ListPermissions` method. +message ListPermissionsResponse { + // Resources. + repeated Permission permissions = 1; + + // A token, which can be sent as `page_token` to retrieve the next page. + // If this field is omitted, there are no subsequent pages. + string next_page_token = 2; +} + +// Request message for `CheckPermissions` method. +message CheckPermissionsRequest { + // Name of the resource on which to check permissions. If not specified, this + // means the root of the protected API. + string protected_resource = 1 [ + (google.api.resource_reference) = { type: "*" }, + (google.api.field_behavior) = OPTIONAL + ]; + + // Resource name of the `Principal`. + string principal = 2 [ + (google.api.field_behavior) = REQUIRED, + (google.api.resource_reference).type = "access.halo-cmm.org/Principal" + ]; + + // Set of permissions to check. + repeated string permissions = 3 [ + (google.api.field_behavior) = REQUIRED, + (google.api.field_behavior) = UNORDERED_LIST + ]; +} + +// Response message for `CheckPermissions` method. +message CheckPermissionsResponse { + // Subset of request `permissions` that `principal` has on `resource`. + repeated string permissions = 1 + [(google.api.field_behavior) = UNORDERED_LIST]; +} diff --git a/src/main/proto/wfa/measurement/access/v1alpha/policies_service.proto b/src/main/proto/wfa/measurement/access/v1alpha/policies_service.proto new file mode 100644 index 00000000000..a6a332146f4 --- /dev/null +++ b/src/main/proto/wfa/measurement/access/v1alpha/policies_service.proto @@ -0,0 +1,138 @@ +// Copyright 2024 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package wfa.measurement.access.v1alpha; + +import "google/api/client.proto"; +import "google/api/field_behavior.proto"; +import "google/api/resource.proto"; +import "wfa/measurement/access/v1alpha/policy.proto"; + +option java_package = "org.wfanet.measurement.access.v1alpha"; +option java_multiple_files = true; +option java_outer_classname = "PoliciesServiceProto"; + +// Service for interacting with `Policy` resources. +service Policies { + // Retrieves a `Policy` by name. + rpc GetPolicy(GetPolicyRequest) returns (Policy) { + option (google.api.method_signature) = "name"; + } + + // Creates a `Policy`. + rpc CreatePolicy(CreatePolicyRequest) returns (Policy) { + option (google.api.method_signature) = "policy,policy_id"; + } + + // Looks up a `Policy` by lookup key. + rpc LookupPolicy(LookupPolicyRequest) returns (Policy); + + // Adds members to a `Policy.Binding`. + rpc AddPolicyBindingMembers(AddPolicyBindingMembersRequest) returns (Policy); + + // Removes members from a `Policy.Binding`. + rpc RemovePolicyBindingMembers(RemovePolicyBindingMembersRequest) + returns (Policy); +} + +// Request message for the `GetPolicy` method. +message GetPolicyRequest { + // Resource name. + string name = 1 [ + (google.api.field_behavior) = REQUIRED, + (google.api.resource_reference) = { type: "access.halo-cmm.org/Policy" } + ]; +} + +// Request message for `CreatePolicy` method. +message CreatePolicyRequest { + // Resource to create. + Policy policy = 1 [(google.api.field_behavior) = REQUIRED]; + + // Resource ID. + // + // This must confirm to RFC-1034 with the following exceptions: + // * IDs are case-sensitive. + string policy_id = 2 [(google.api.field_behavior) = REQUIRED]; +} + +// Request message for the `LookupPolicy` method. +message LookupPolicyRequest { + // Lookup key. Required. + oneof lookup_key { + // Name of the resource to which the policy applies. + string protected_resource = 1 [ + (google.api.resource_reference) = { type: "*" }, + (google.api.field_behavior) = OPTIONAL + ]; + } +} + +// Request message for the `AddPolicyBindingMembers` method. +message AddPolicyBindingMembersRequest { + // Resource name. + string name = 1 [ + (google.api.field_behavior) = REQUIRED, + (google.api.resource_reference) = { type: "access.halo-cmm.org/Policy" } + ]; + + // Resource name of the role. + string role = 2 [ + (google.api.field_behavior) = REQUIRED, + (google.api.resource_reference) = { type: "access.halo-cmm.org/Role" } + ]; + + // Resource names of the members to add. + repeated string members = 3 [ + (google.api.field_behavior) = OPTIONAL, + (google.api.field_behavior) = UNORDERED_LIST, + (google.api.resource_reference) = { type: "access.halo-cmm.org/Principal" } + ]; + + // Current etag of the resource. + // + // If it is specified and the value does not match the current etag, the + // operation will not occur and will result in an ABORTED status. + string etag = 4 [(google.api.field_behavior) = OPTIONAL]; +} + +// Request message for the `RemovePolicyBindingMembers` method. +message RemovePolicyBindingMembersRequest { + // Resource name. + string name = 1 [ + (google.api.field_behavior) = REQUIRED, + (google.api.resource_reference) = { type: "access.halo-cmm.org/Policy" } + ]; + + // Resource name of the role. + string role = 2 [ + (google.api.field_behavior) = REQUIRED, + (google.api.resource_reference) = { type: "access.halo-cmm.org/Role" } + ]; + + // Resource names of the members to remove. + repeated string members = 3 [ + (google.api.field_behavior) = OPTIONAL, + (google.api.field_behavior) = UNORDERED_LIST, + (google.api.resource_reference) = { type: "access.halo-cmm.org/Principal" } + ]; + + // Current etag of the resource. + // + // If it is specified and the value does not match the current etag, the + // operation will not occur and will result in an ABORTED status. + optional string etag = 4 [(google.api.field_behavior) = OPTIONAL]; +} diff --git a/src/main/proto/wfa/measurement/access/v1alpha/policy.proto b/src/main/proto/wfa/measurement/access/v1alpha/policy.proto new file mode 100644 index 00000000000..56c9d5a3b6a --- /dev/null +++ b/src/main/proto/wfa/measurement/access/v1alpha/policy.proto @@ -0,0 +1,75 @@ +// Copyright 2024 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package wfa.measurement.access.v1alpha; + +import "google/api/field_behavior.proto"; +import "google/api/resource.proto"; + +option java_package = "org.wfanet.measurement.access.v1alpha"; +option java_multiple_files = true; +option java_outer_classname = "PolicyProto"; + +// Resource representing the access policy for a single resource. +message Policy { + option (google.api.resource) = { + type: "access.halo-cmm.org/Policy" + pattern: "policies/{policy}" + singular: "policy" + plural: "policies" + }; + + // Resource name. + string name = 1 [(google.api.field_behavior) = IDENTIFIER]; + + // Name of the resource protected by this `Policy`. If not specified, this + // means the root of the protected API. + // + // Must be unique across all `Policy` resources. + string protected_resource = 2 [ + (google.api.resource_reference) = { type: "*" }, + (google.api.field_behavior) = OPTIONAL, + (google.api.field_behavior) = IMMUTABLE + ]; + + // A policy binding. + message Binding { + // Resource name of the role. + string role = 1 [ + (google.api.field_behavior) = REQUIRED, + (google.api.resource_reference) = { type: "access.halo-cmm.org/Role" } + ]; + // Resource names of the principals which are members of this role on + // `resource`. + repeated string members = 2 [ + (google.api.field_behavior) = REQUIRED, + (google.api.field_behavior) = UNORDERED_LIST, + (google.api.resource_reference) = { + type: "access.halo-cmm.org/Principal" + } + ]; + } + // Policy bindings. + // + // This effectively forms a map of role to members. + repeated Binding bindings = 3 [ + (google.api.field_behavior) = OPTIONAL, + (google.api.field_behavior) = UNORDERED_LIST + ]; + + // Entity tag. + string etag = 4; +} diff --git a/src/main/proto/wfa/measurement/access/v1alpha/principal.proto b/src/main/proto/wfa/measurement/access/v1alpha/principal.proto new file mode 100644 index 00000000000..46a61e1ae03 --- /dev/null +++ b/src/main/proto/wfa/measurement/access/v1alpha/principal.proto @@ -0,0 +1,72 @@ +// Copyright 2024 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package wfa.measurement.access.v1alpha; + +import "google/api/field_behavior.proto"; +import "google/api/resource.proto"; + +option java_package = "org.wfanet.measurement.access.v1alpha"; +option java_multiple_files = true; +option java_outer_classname = "PrincipalProto"; + +// Resource representing a principal (API caller). +message Principal { + option (google.api.resource) = { + type: "access.halo-cmm.org/Principal" + pattern: "principals/{principal}" + singular: "principal" + plural: "principals" + }; + + // Resource name. + string name = 1 [(google.api.field_behavior) = IDENTIFIER]; + + // An OAuth user identity. + message OAuthUser { + // OAuth issuer identifier. + string issuer = 1 [ + (google.api.field_behavior) = REQUIRED, + (google.api.field_behavior) = IMMUTABLE + ]; + + // OAuth subject identifier. + string subject = 2 [ + (google.api.field_behavior) = REQUIRED, + (google.api.field_behavior) = IMMUTABLE + ]; + } + // A TLS client. + message TlsClient { + // Authority key identifier (AKID). + // + // (-- api-linter: core::0140::abbreviations=disabled + // aip.dev/not-precedent: Refers to an external concept. --) + bytes authority_key_identifier = 1 [ + (google.api.field_behavior) = REQUIRED, + (google.api.field_behavior) = IMMUTABLE + ]; + } + // The identity associated with the principal. Required. + // + // Must be unique across all `Principal` resources. + oneof identity { + // Single user. + OAuthUser user = 2; + // TLS client. + TlsClient tls_client = 3; + } +} diff --git a/src/main/proto/wfa/measurement/access/v1alpha/principals_service.proto b/src/main/proto/wfa/measurement/access/v1alpha/principals_service.proto new file mode 100644 index 00000000000..0b88b46e3aa --- /dev/null +++ b/src/main/proto/wfa/measurement/access/v1alpha/principals_service.proto @@ -0,0 +1,91 @@ +// Copyright 2024 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package wfa.measurement.access.v1alpha; + +import "google/api/client.proto"; +import "google/api/field_behavior.proto"; +import "google/api/resource.proto"; +import "google/protobuf/empty.proto"; +import "wfa/measurement/access/v1alpha/principal.proto"; + +option java_package = "org.wfanet.measurement.access.v1alpha"; +option java_multiple_files = true; +option java_outer_classname = "PrincipalsServiceProto"; + +// Service for interacting with `Principal` resources. +service Principals { + // Retrieves a `Principal` by resource name. + rpc GetPrincipal(GetPrincipalRequest) returns (Principal) { + option (google.api.method_signature) = "name"; + } + + // Creates a `Principal`. + rpc CreatePrincipal(CreatePrincipalRequest) returns (Principal) { + option (google.api.method_signature) = "principal,principal_id"; + } + + // Deletes a `Principal`. + // + // This will also remove the `Principal` from all `Policy` bindings. + rpc DeletePrincipal(DeletePrincipalRequest) returns (google.protobuf.Empty) { + option (google.api.method_signature) = "name"; + } + + // Looks up a `Principal` by lookup key. + rpc LookupPrincipal(LookupPrincipalRequest) returns (Principal); +} + +// Request message for `GetPrincipal` method. +message GetPrincipalRequest { + // Resource name. + string name = 1 [ + (google.api.field_behavior) = REQUIRED, + (google.api.resource_reference).type = "access.halo-cmm.org/Principal" + ]; +} + +// Request message for `CreatePrincipal` method. +message CreatePrincipalRequest { + // Resource to create. + Principal principal = 1 [(google.api.field_behavior) = REQUIRED]; + + // Resource ID. + // + // This must confirm to RFC-1034 with the following exceptions: + // * IDs are case-sensitive. + string principal_id = 2 [(google.api.field_behavior) = REQUIRED]; +} + +// Request message for `DeletePrincipal` method. +message DeletePrincipalRequest { + // Resource name. + string name = 1 [ + (google.api.field_behavior) = REQUIRED, + (google.api.resource_reference).type = "access.halo-cmm.org/Principal" + ]; +} + +// Request message for the `LookupPrincipal` method. +message LookupPrincipalRequest { + // Lookup key. Required. + oneof lookup_key { + // User identity. + Principal.OAuthUser user = 1 [(google.api.field_behavior) = OPTIONAL]; + // TLS client identity. + Principal.TlsClient tls_client = 2 [(google.api.field_behavior) = OPTIONAL]; + } +} diff --git a/src/main/proto/wfa/measurement/access/v1alpha/role.proto b/src/main/proto/wfa/measurement/access/v1alpha/role.proto new file mode 100644 index 00000000000..ea6ef704ae1 --- /dev/null +++ b/src/main/proto/wfa/measurement/access/v1alpha/role.proto @@ -0,0 +1,56 @@ +// Copyright 2024 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package wfa.measurement.access.v1alpha; + +import "google/api/field_behavior.proto"; +import "google/api/resource.proto"; + +option java_package = "org.wfanet.measurement.access.v1alpha"; +option java_multiple_files = true; +option java_outer_classname = "RoleProto"; + +// Resource representing a role. +message Role { + option (google.api.resource) = { + type: "access.halo-cmm.org/Role" + pattern: "roles/{role}" + singular: "role" + plural: "roles" + }; + + // Resource name. + string name = 1 [(google.api.field_behavior) = IDENTIFIER]; + + // Set of resource types that this `Role` can be granted on. + repeated string resource_types = 2 [ + (google.api.field_behavior) = REQUIRED, + (google.api.field_behavior) = UNORDERED_LIST + ]; + + // Set of resource names of permissions granted by this role. + // + // Each of the permissions must have all of the resource types indicated in + // `resource_type`. + repeated string permissions = 3 [ + (google.api.field_behavior) = REQUIRED, + (google.api.field_behavior) = UNORDERED_LIST, + (google.api.resource_reference).type = "access.halo-cmm.org/Permission" + ]; + + // Entity tag. + string etag = 4; +} diff --git a/src/main/proto/wfa/measurement/access/v1alpha/roles_service.proto b/src/main/proto/wfa/measurement/access/v1alpha/roles_service.proto new file mode 100644 index 00000000000..3809d1d8864 --- /dev/null +++ b/src/main/proto/wfa/measurement/access/v1alpha/roles_service.proto @@ -0,0 +1,129 @@ +// Copyright 2024 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package wfa.measurement.access.v1alpha; + +import "google/api/client.proto"; +import "google/api/field_behavior.proto"; +import "google/api/resource.proto"; +import "google/protobuf/empty.proto"; +import "wfa/measurement/access/v1alpha/role.proto"; + +option java_package = "org.wfanet.measurement.access.v1alpha"; +option java_multiple_files = true; +option java_outer_classname = "RolesServiceProto"; + +// Service for interacting with `Role` resources. +service Roles { + // Retrieves a `Role` by resource name. + rpc GetRole(GetRoleRequest) returns (Role) { + option (google.api.method_signature) = "name"; + } + + // Lists `Role` resources. + rpc ListRoles(ListRolesRequest) returns (ListRolesResponse) { + option (google.api.method_signature) = "parent"; + } + + // Creates a `Role`. + rpc CreateRole(CreateRoleRequest) returns (Role) { + option (google.api.method_signature) = "role,role_id"; + } + + // Updates a `Role`. + // + // (-- api-linter: core::0134::method-signature=disabled + // aip.dev/not-precedent: Partial update not supported. --) + rpc UpdateRole(UpdateRoleRequest) returns (Role) { + option (google.api.method_signature) = "role"; + } + + // Deletes a `Role`. + // + // This will also remove all `Policy` bindings which reference the `Role`. + rpc DeleteRole(DeleteRoleRequest) returns (google.protobuf.Empty) { + option (google.api.method_signature) = "name"; + } +} + +// Request message for `GetRole` method. +message GetRoleRequest { + // Resource name. + string name = 1 [ + (google.api.field_behavior) = REQUIRED, + (google.api.resource_reference).type = "access.halo-cmm.org/Role" + ]; +} + +// Request message for `ListRoles` method. +message ListRolesRequest { + // The maximum number of resources to return. + // + // The service may return fewer resources than this value. If unspecified, at + // most 50 resources will be returned. The maximum value is 100; values above + // this will be coerced to the maximum value. + int32 page_size = 1 [(google.api.field_behavior) = OPTIONAL]; + + // A page token from a previous call to the method, used to retrieve the + // subsequent page of results. + // + // If specified, all other request parameter values must match the call that + // provided the token. + string page_token = 2 [(google.api.field_behavior) = OPTIONAL]; +} + +// Response message for `ListRoles` method. +message ListRolesResponse { + // Resources. + repeated Role roles = 1; + + // A token, which can be sent as `page_token` to retrieve the next page. + // If this field is omitted, there are no subsequent pages. + string next_page_token = 2; +} + +// Request message for `CreateRole` method. +message CreateRoleRequest { + // Resource to create. + Role role = 1 [(google.api.field_behavior) = REQUIRED]; + + // Resource ID. + // + // This must confirm to RFC-1034 with the following exceptions: + // * IDs are case-sensitive. + string role_id = 2 [(google.api.field_behavior) = REQUIRED]; +} + +// Request message for the `UpdateRole` method. +// +// (-- api-linter: core::0134::request-mask-required=disabled +// aip.dev/not-precedent: Partial update not supported. --) +message UpdateRoleRequest { + // Resource to update. + // + // The `name` field is used to identify the resource. The `etag` field is + // required. + Role role = 1 [(google.api.field_behavior) = REQUIRED]; +} + +// Request message for `DeleteRole` method. +message DeleteRoleRequest { + // Resource name. + string name = 1 [ + (google.api.field_behavior) = REQUIRED, + (google.api.resource_reference).type = "access.halo-cmm.org/Role" + ]; +} diff --git a/src/main/proto/wfa/measurement/config/access/BUILD.bazel b/src/main/proto/wfa/measurement/config/access/BUILD.bazel new file mode 100644 index 00000000000..67b69a58395 --- /dev/null +++ b/src/main/proto/wfa/measurement/config/access/BUILD.bazel @@ -0,0 +1,20 @@ +load("@rules_proto//proto:defs.bzl", "proto_library") +load( + "@wfa_rules_kotlin_jvm//kotlin:defs.bzl", + "kt_jvm_proto_library", +) + +package(default_visibility = ["//visibility:public"]) + +IMPORT_PREFIX = "/src/main/proto" + +proto_library( + name = "permissions_config_proto", + srcs = ["permissions_config.proto"], + strip_import_prefix = IMPORT_PREFIX, +) + +kt_jvm_proto_library( + name = "permissions_config_kt_jvm_proto", + deps = [":permissions_config_proto"], +) diff --git a/src/main/proto/wfa/measurement/config/access/permissions_config.proto b/src/main/proto/wfa/measurement/config/access/permissions_config.proto new file mode 100644 index 00000000000..6862b4d5dd0 --- /dev/null +++ b/src/main/proto/wfa/measurement/config/access/permissions_config.proto @@ -0,0 +1,42 @@ +// Copyright 2024 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package wfa.measurement.config.access; + +option java_package = "org.wfanet.measurement.config.access"; +option java_multiple_files = true; +option java_outer_classname = "PermissionsConfigProto"; + +message PermissionsConfig { + message Permission { + // Types of protected resources which this Permission is applicable to. + // + // The type comes from the `google.api.ResourceDescriptor` of the resource. + // Note that in the case that the permission can be granted on the parent + // type for all children, this may not be the type mentioned in the + // Permission resource ID. For example, permission `books.read` may apply to + // resource type `example.com/Shelf` to grant reading all books on the + // shelf. + repeated string protected_resource_types = 1; + } + + // Map of permission resource ID to `Permission`. + // + // The resource ID is typically in lower camel case, and often consists of the + // plural resource type followed by an operation. For example, + // `chapterHeadings.list`. + map permissions = 1; +} diff --git a/src/main/proto/wfa/measurement/config/reporting/BUILD.bazel b/src/main/proto/wfa/measurement/config/reporting/BUILD.bazel index ff5989eb7bd..ebcb30f5080 100644 --- a/src/main/proto/wfa/measurement/config/reporting/BUILD.bazel +++ b/src/main/proto/wfa/measurement/config/reporting/BUILD.bazel @@ -27,17 +27,6 @@ kt_jvm_proto_library( deps = [":measurement_consumer_config_proto"], ) -proto_library( - name = "measurement_spec_config_proto", - srcs = ["measurement_spec_config.proto"], - strip_import_prefix = IMPORT_PREFIX, -) - -kt_jvm_proto_library( - name = "measurement_spec_config_kt_jvm_proto", - deps = [":measurement_spec_config_proto"], -) - proto_library( name = "metric_spec_config_proto", srcs = ["metric_spec_config.proto"], diff --git a/src/main/proto/wfa/measurement/config/reporting/measurement_spec_config.proto b/src/main/proto/wfa/measurement/config/reporting/measurement_spec_config.proto deleted file mode 100644 index 15cfe90fd12..00000000000 --- a/src/main/proto/wfa/measurement/config/reporting/measurement_spec_config.proto +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Copyright 2023 The Cross-Media Measurement Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -syntax = "proto3"; - -package wfa.measurement.config; - -option java_package = "org.wfanet.measurement.config.reporting"; -option java_multiple_files = true; - -// The configuration for MeasurementSpecs used in Measurements created by the -// Reporting server. -message MeasurementSpecConfig { - // Parameters for differential privacy (DP). - // - // For detail, refer to "Dwork, C. and Roth, A., 2014. The algorithmic - // foundations of differential privacy. Foundations and Trends in Theoretical - // Computer Science, 9(3-4), pp.211-407." - message DifferentialPrivacyParams { - double epsilon = 1; - double delta = 2; - } - - // Specifies a range of VIDs to be sampled. - message VidSamplingInterval { - message FixedStart { - // The start of the sampling interval in [0, 1) - float start = 1; - // The width of the sampling interval. - float width = 2; - } - - message RandomStart { - int32 num_vid_buckets = 1; - // The width of the sampling interval in [1, `num_vid_buckets`]. For - // example, if `num_vid_buckets` is 300, then this width can be in the - // range [1, 300]. If 100 is chosen, then the width in - // `vid_sampling_interval` of `MeasurementSpec` is 100/300. - int32 width = 2; - } - - // Defaults to start of 0 width of 1 if not set. - oneof start { - FixedStart fixed_start = 1; - RandomStart random_start = 2; - } - } - - message ReachSingleDataProvider { - // Differential privacy parameters for reach. - DifferentialPrivacyParams privacy_params = 1; - VidSamplingInterval vid_sampling_interval = 2; - } - ReachSingleDataProvider reach_single_data_provider = 1; - - message Reach { - // Differential privacy parameters for reach. - DifferentialPrivacyParams privacy_params = 1; - VidSamplingInterval vid_sampling_interval = 2; - } - Reach reach = 2; - - message ReachAndFrequencySingleDataProvider { - // Differential privacy parameters for reach. - DifferentialPrivacyParams reach_privacy_params = 1; - // Differential privacy parameters for frequency. - DifferentialPrivacyParams frequency_privacy_params = 2; - VidSamplingInterval vid_sampling_interval = 3; - } - ReachAndFrequencySingleDataProvider reach_and_frequency_single_data_provider = - 3; - - message ReachAndFrequency { - // Differential privacy parameters for reach. - DifferentialPrivacyParams reach_privacy_params = 1; - // Differential privacy parameters for frequency. - DifferentialPrivacyParams frequency_privacy_params = 2; - VidSamplingInterval vid_sampling_interval = 3; - } - ReachAndFrequency reach_and_frequency = 4; - - message Impression { - // Differential privacy parameters. - DifferentialPrivacyParams privacy_params = 1; - VidSamplingInterval vid_sampling_interval = 2; - } - Impression impression = 5; - - message Duration { - // Differential privacy parameters. - DifferentialPrivacyParams privacy_params = 1; - VidSamplingInterval vid_sampling_interval = 2; - } - Duration duration = 6; -} diff --git a/src/main/proto/wfa/measurement/integration/k8s/testing/correctness_test_config.proto b/src/main/proto/wfa/measurement/integration/k8s/testing/correctness_test_config.proto index 4c894b4ca4f..25b14b56827 100644 --- a/src/main/proto/wfa/measurement/integration/k8s/testing/correctness_test_config.proto +++ b/src/main/proto/wfa/measurement/integration/k8s/testing/correctness_test_config.proto @@ -38,4 +38,13 @@ message CorrectnessTestConfig { // Authentication key for the CMMS public API. string api_authentication_key = 4; + + // gRPC target of Reporting public API server. + string reporting_public_api_target = 5; + + // Expected hostname (DNS-ID) in the reporting public API server's TLS + // certificate. + // + // If not specified, standard TLS DNS-ID derivation will be used. + string reporting_public_api_cert_host = 6; } diff --git a/src/main/proto/wfa/measurement/internal/access/BUILD.bazel b/src/main/proto/wfa/measurement/internal/access/BUILD.bazel new file mode 100644 index 00000000000..b185245ed07 --- /dev/null +++ b/src/main/proto/wfa/measurement/internal/access/BUILD.bazel @@ -0,0 +1,112 @@ +load("@rules_proto//proto:defs.bzl", "proto_library") +load( + "@wfa_rules_kotlin_jvm//kotlin:defs.bzl", + "kt_jvm_grpc_proto_library", + "kt_jvm_proto_library", +) + +package(default_visibility = [ + "//src/main/kotlin/org/wfanet/measurement/access:__subpackages__", + "//src/test/kotlin/org/wfanet/measurement/access:__subpackages__", +]) + +IMPORT_PREFIX = "/src/main/proto" + +proto_library( + name = "principal_proto", + srcs = ["principal.proto"], + strip_import_prefix = IMPORT_PREFIX, + deps = [ + "@com_google_protobuf//:timestamp_proto", + ], +) + +proto_library( + name = "principals_service_proto", + srcs = ["principals_service.proto"], + strip_import_prefix = IMPORT_PREFIX, + deps = [ + ":principal_proto", + "@com_google_protobuf//:empty_proto", + ], +) + +proto_library( + name = "permission_proto", + srcs = ["permission.proto"], + strip_import_prefix = IMPORT_PREFIX, +) + +proto_library( + name = "permissions_service_proto", + srcs = ["permissions_service.proto"], + strip_import_prefix = IMPORT_PREFIX, + deps = [ + ":permission_proto", + ], +) + +proto_library( + name = "role_proto", + srcs = ["role.proto"], + strip_import_prefix = IMPORT_PREFIX, + deps = [ + "@com_google_protobuf//:timestamp_proto", + ], +) + +proto_library( + name = "roles_service_proto", + srcs = ["roles_service.proto"], + deps = [ + ":role_proto", + "@com_google_protobuf//:empty_proto", + ], +) + +proto_library( + name = "policy_proto", + srcs = ["policy.proto"], + strip_import_prefix = IMPORT_PREFIX, + deps = [ + "@com_google_protobuf//:timestamp_proto", + ], +) + +proto_library( + name = "policies_service_proto", + srcs = ["policies_service.proto"], + deps = [ + ":policy_proto", + ], +) + +MESSAGE_LIBS = [ + "principal", + "permission", + "role", + "policy", +] + +SERVICE_LIBS = [ + "principals", + "permissions", + "roles", + "policies", +] + +[ + kt_jvm_proto_library( + name = "{prefix}_kt_jvm_proto".format(prefix = prefix), + deps = [":{prefix}_proto".format(prefix = prefix)], + ) + for prefix in MESSAGE_LIBS +] + +[ + kt_jvm_grpc_proto_library( + name = "{prefix}_service_kt_jvm_grpc_proto".format(prefix = prefix), + deps = [":{prefix}_service_proto".format(prefix = prefix)], + ) + for prefix in SERVICE_LIBS +] diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ContextKeys.kt b/src/main/proto/wfa/measurement/internal/access/permission.proto similarity index 51% rename from src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ContextKeys.kt rename to src/main/proto/wfa/measurement/internal/access/permission.proto index 23cb010cca8..ef774649372 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ContextKeys.kt +++ b/src/main/proto/wfa/measurement/internal/access/permission.proto @@ -1,10 +1,10 @@ -// Copyright 2022 The Cross-Media Measurement Authors +// Copyright 2024 The Cross-Media Measurement Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, @@ -12,11 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -package org.wfanet.measurement.reporting.service.api.v1alpha +syntax = "proto3"; -import io.grpc.Context +package wfa.measurement.internal.access; -object ContextKeys { - /** This is the context key for the authenticated [ReportingPrincipal]. */ - val PRINCIPAL_CONTEXT_KEY: Context.Key = Context.key("principal") +option java_package = "org.wfanet.measurement.internal.access"; +option java_multiple_files = true; +option java_outer_classname = "PermissionProto"; + +message Permission { + string permission_resource_id = 1; + + // Set of resource types that this `Permission` can apply to. + repeated string resource_types = 2; } diff --git a/src/main/proto/wfa/measurement/internal/access/permissions_service.proto b/src/main/proto/wfa/measurement/internal/access/permissions_service.proto new file mode 100644 index 00000000000..9b4190be330 --- /dev/null +++ b/src/main/proto/wfa/measurement/internal/access/permissions_service.proto @@ -0,0 +1,80 @@ +// Copyright 2024 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package wfa.measurement.internal.access; + +import "wfa/measurement/internal/access/permission.proto"; + +option java_package = "org.wfanet.measurement.internal.access"; +option java_multiple_files = true; +option java_outer_classname = "PermissionsServiceProto"; + +service Permissions { + // Gets a `Permission` by resource ID. + rpc GetPermission(GetPermissionRequest) returns (Permission); + + // Lists `Permission`s ordered by resource ID. + rpc ListPermissions(ListPermissionsRequest) returns (ListPermissionsResponse); + + // Checks what permissions a `Principal` has on a given resource. + rpc CheckPermissions(CheckPermissionsRequest) + returns (CheckPermissionsResponse); +} + +message GetPermissionRequest { + string permission_resource_id = 1; +} + +message ListPermissionsRequest { + // The maximum number of results to return. + // + // If unspecified, at most 50 results will be returned. The maximum value is + // 100; values above this will be coerced to the maximum value. + int32 page_size = 1; + ListPermissionsPageToken page_token = 2; +} + +message ListPermissionsResponse { + repeated Permission permissions = 1; + + // Token for requesting subsequent pages. + // + // If not specified, there are no more results. + ListPermissionsPageToken next_page_token = 2; +} + +message CheckPermissionsRequest { + // Name of the resource on which to check permissions. + string protected_resource_name = 1; + + // Resource name of the `Principal`. + string principal_resource_id = 2; + + // Set of permissions to check. + repeated string permission_resource_ids = 3; +} + +message CheckPermissionsResponse { + // Subset of request `permissions` that `principal` has on `resource`. + repeated string permission_resource_ids = 1; +} + +message ListPermissionsPageToken { + message After { + string permission_resource_id = 1; + } + After after = 1; +} diff --git a/src/main/proto/wfa/measurement/internal/access/policies_service.proto b/src/main/proto/wfa/measurement/internal/access/policies_service.proto new file mode 100644 index 00000000000..68b5ab0dbcd --- /dev/null +++ b/src/main/proto/wfa/measurement/internal/access/policies_service.proto @@ -0,0 +1,64 @@ +// Copyright 2024 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package wfa.measurement.internal.access; + +import "wfa/measurement/internal/access/policy.proto"; + +option java_package = "org.wfanet.measurement.internal.access"; +option java_multiple_files = true; +option java_outer_classname = "PoliciesServiceProto"; + +service Policies { + rpc GetPolicy(GetPolicyRequest) returns (Policy); + + rpc CreatePolicy(Policy) returns (Policy); + + rpc LookupPolicy(LookupPolicyRequest) returns (Policy); + + rpc AddPolicyBindingMembers(AddPolicyBindingMembersRequest) returns (Policy); + + rpc RemovePolicyBindingMembers(RemovePolicyBindingMembersRequest) + returns (Policy); +} + +message GetPolicyRequest { + string policy_resource_id = 1; +} + +message LookupPolicyRequest { + oneof lookup_key { + string protected_resource_name = 1; + } +} + +message AddPolicyBindingMembersRequest { + string policy_resource_id = 1; + string role_resource_id = 2; + repeated string member_principal_resource_ids = 3; + + // Current etag. Optional. + string etag = 4; +} + +message RemovePolicyBindingMembersRequest { + string policy_resource_id = 1; + string role_resource_id = 2; + repeated string member_principal_resource_ids = 3; + + // Current etag. Optional. + string etag = 4; +} diff --git a/src/main/proto/wfa/measurement/internal/access/policy.proto b/src/main/proto/wfa/measurement/internal/access/policy.proto new file mode 100644 index 00000000000..ea6344beead --- /dev/null +++ b/src/main/proto/wfa/measurement/internal/access/policy.proto @@ -0,0 +1,50 @@ +// Copyright 2024 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package wfa.measurement.internal.access; + +import "google/protobuf/timestamp.proto"; + +option java_package = "org.wfanet.measurement.internal.access"; +option java_multiple_files = true; +option java_outer_classname = "PolicyProto"; + +message Policy { + string policy_resource_id = 1; + + // Name of the resource protected by this `Policy`. If not specified, this + // means the root of the protected API. + // + // Must be unique across all `Policy` instances. + string protected_resource_name = 2; + + message Members { + // Resource IDs of the principals which are members of this role on the + // protected resource. + repeated string member_principal_resource_ids = 1; + } + // Policy bindings as a map of `Role` resource ID to `Members`. + map bindings = 3; + + // Entity tag. + string etag = 4; + + // Create time. Output-only. + google.protobuf.Timestamp create_time = 5; + + // Update time. Output-only. + google.protobuf.Timestamp update_time = 6; +} diff --git a/src/main/proto/wfa/measurement/internal/access/principal.proto b/src/main/proto/wfa/measurement/internal/access/principal.proto new file mode 100644 index 00000000000..a1667805a23 --- /dev/null +++ b/src/main/proto/wfa/measurement/internal/access/principal.proto @@ -0,0 +1,57 @@ +// Copyright 2024 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package wfa.measurement.internal.access; + +import "google/protobuf/timestamp.proto"; + +option java_package = "org.wfanet.measurement.internal.access"; +option java_multiple_files = true; +option java_outer_classname = "PrincipalProto"; + +message Principal { + string principal_resource_id = 1; + + // An OAuth user identity. + message OAuthUser { + // OAuth issuer identifier. Required. + string issuer = 1; + + // OAuth subject identifier. Required. + string subject = 2; + } + + message TlsClient { + // Authority key identifier (AKID). Required. + bytes authority_key_identifier = 1; + } + + // Identity. Required. + oneof identity { + OAuthUser user = 2; + TlsClient tls_client = 3; + } + + // Create time. Output-only. + // + // Not specified when `tls_client` is set. + google.protobuf.Timestamp create_time = 4; + + // Update time. Output-only. + // + // Not specified when `tls_client` is set. + google.protobuf.Timestamp update_time = 5; +} diff --git a/src/main/proto/wfa/measurement/internal/access/principals_service.proto b/src/main/proto/wfa/measurement/internal/access/principals_service.proto new file mode 100644 index 00000000000..255ad6ba7f6 --- /dev/null +++ b/src/main/proto/wfa/measurement/internal/access/principals_service.proto @@ -0,0 +1,61 @@ +// Copyright 2024 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package wfa.measurement.internal.access; + +import "google/protobuf/empty.proto"; +import "wfa/measurement/internal/access/principal.proto"; + +option java_package = "org.wfanet.measurement.internal.access"; +option java_multiple_files = true; +option java_outer_classname = "PrincipalsServiceProto"; + +service Principals { + rpc GetPrincipal(GetPrincipalRequest) returns (Principal); + + rpc CreateUserPrincipal(CreateUserPrincipalRequest) returns (Principal); + + // Deletes a `Principal`. + // + // TLS client principals may not be deleted by this method. + rpc DeletePrincipal(DeletePrincipalRequest) returns (google.protobuf.Empty); + + // Looks up a `Principal` by lookup key. + rpc LookupPrincipal(LookupPrincipalRequest) returns (Principal); +} + +message GetPrincipalRequest { + string principal_resource_id = 1; +} + +message CreateUserPrincipalRequest { + string principal_resource_id = 1; + Principal.OAuthUser user = 2; +} + +message DeletePrincipalRequest { + string principal_resource_id = 1; +} + +message LookupPrincipalRequest { + // Lookup key. Required. + oneof lookup_key { + // User identity. + Principal.OAuthUser user = 1; + // TLS client identity. + Principal.TlsClient tls_client = 2; + } +} diff --git a/src/main/proto/wfa/measurement/internal/access/role.proto b/src/main/proto/wfa/measurement/internal/access/role.proto new file mode 100644 index 00000000000..2a26e10583e --- /dev/null +++ b/src/main/proto/wfa/measurement/internal/access/role.proto @@ -0,0 +1,42 @@ +// Copyright 2024 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package wfa.measurement.internal.access; + +import "google/protobuf/timestamp.proto"; + +option java_package = "org.wfanet.measurement.internal.access"; +option java_multiple_files = true; +option java_outer_classname = "RoleProto"; + +message Role { + string role_resource_id = 1; + + // Set of resource types that this `Role` can be granted on. + repeated string resource_types = 2; + + // Set of resource IDs of permissions granted by this role. + repeated string permission_resource_ids = 3; + + // Entity tag. + string etag = 4; + + // Create time. Output-only. + google.protobuf.Timestamp create_time = 5; + + // Update time. Output-only. + google.protobuf.Timestamp update_time = 6; +} diff --git a/src/main/proto/wfa/measurement/internal/access/roles_service.proto b/src/main/proto/wfa/measurement/internal/access/roles_service.proto new file mode 100644 index 00000000000..0525fd0a928 --- /dev/null +++ b/src/main/proto/wfa/measurement/internal/access/roles_service.proto @@ -0,0 +1,70 @@ +// Copyright 2024 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package wfa.measurement.internal.access; + +import "google/protobuf/empty.proto"; +import "wfa/measurement/internal/access/role.proto"; + +option java_package = "org.wfanet.measurement.internal.access"; +option java_multiple_files = true; +option java_outer_classname = "RolesServiceProto"; + +service Roles { + rpc GetRole(GetRoleRequest) returns (Role); + + // Lists Roles ordered by `role_resource_id`. + rpc ListRoles(ListRolesRequest) returns (ListRolesResponse); + + rpc CreateRole(Role) returns (Role); + + rpc UpdateRole(Role) returns (Role); + + rpc DeleteRole(DeleteRoleRequest) returns (google.protobuf.Empty); +} + +message GetRoleRequest { + string role_resource_id = 1; +} + +message DeleteRoleRequest { + string role_resource_id = 2; +} + +message ListRolesRequest { + // The maximum number of results to return. + // + // If unspecified, at most 50 results will be returned. The maximum value is + // 100; values above this will be coerced to the maximum value. + int32 page_size = 1; + ListRolesPageToken page_token = 2; +} + +message ListRolesResponse { + repeated Role roles = 1; + + // Token for requesting subsequent pages. + // + // If not specified, there are no more results. + ListRolesPageToken next_page_token = 2; +} + +message ListRolesPageToken { + message After { + string role_resource_id = 1; + } + After after = 1; +} diff --git a/src/main/proto/wfa/measurement/reporting/v1alpha/event_group.proto b/src/main/proto/wfa/measurement/reporting/v1alpha/event_group.proto deleted file mode 100644 index a40be039697..00000000000 --- a/src/main/proto/wfa/measurement/reporting/v1alpha/event_group.proto +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package wfa.measurement.reporting.v1alpha; - -import "google/api/field_behavior.proto"; -import "google/api/resource.proto"; -import "google/protobuf/any.proto"; - -option java_package = "org.wfanet.measurement.reporting.v1alpha"; -option java_multiple_files = true; -option java_outer_classname = "EventGroupProto"; - -// A grouping of events defined by a `DataProvider`. For example, a single -// campaign or creative defined in a publisher's ad system. -message EventGroup { - option (google.api.resource) = { - type: "reporting.halo-cmm.org/EventGroup" - pattern: "measurementConsumers/{measurement_consumer}/dataProviders/{data_provider}/eventGroups/{event_group}" - }; - - // Resource name. - string name = 1; - - // Resource name of the `DataProvider` associated with this `EventGroup`. - string data_provider = 2 [ - (google.api.field_behavior) = IMMUTABLE, - (google.api.resource_reference).type = "halo.wfanet.org/DataProvider" - ]; - - // ID referencing the `EventGroup` in an external system, provided by the - // `DataProvider`. - // - // If set, this value must be unique among `EventGroup`s for the parent - // `DataProvider`. - string event_group_reference_id = 3; - - // The template that events associated with this `EventGroup` conform to. - message EventTemplate { - // The type of the Event Template. A fully-qualified protobuf message type. - string type = 1 [(google.api.field_behavior) = REQUIRED]; - } - // The `EventTemplate`s that events associated with this `EventGroup` conform - // to. - repeated EventTemplate event_templates = 4; - - // Wrapper for per-EDP Event Group metadata. - message Metadata { - // The resource name of the metadata descriptor. - string event_group_metadata_descriptor = 1 - [(google.api.resource_reference).type = - "halo.wfanet.org/EventGroupMetadataDescriptor"]; - - // The serialized value of the metadata message. - google.protobuf.Any metadata = 2; - } - - Metadata metadata = 5; -} diff --git a/src/main/proto/wfa/measurement/reporting/v1alpha/event_groups_service.proto b/src/main/proto/wfa/measurement/reporting/v1alpha/event_groups_service.proto deleted file mode 100644 index 79bc69d0a0c..00000000000 --- a/src/main/proto/wfa/measurement/reporting/v1alpha/event_groups_service.proto +++ /dev/null @@ -1,79 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package wfa.measurement.reporting.v1alpha; - -import "google/api/client.proto"; -import "google/api/field_behavior.proto"; -import "google/api/annotations.proto"; -import "google/api/resource.proto"; -import "wfa/measurement/reporting/v1alpha/event_group.proto"; - -option java_package = "org.wfanet.measurement.reporting.v1alpha"; -option java_multiple_files = true; -option java_outer_classname = "EventGroupsServiceProto"; - -// Service for interacting with `EventGroup` resources. -service EventGroups { - // Lists `EventGroup`s. Results in a `PERMISSION_DENIED` error if attempting - // to list `EventGroup`s that the authenticated user does not have access to. - rpc ListEventGroups(ListEventGroupsRequest) - returns (ListEventGroupsResponse) { - option (google.api.http) = { - get: "/v1alpha/{parent=measurementConsumers/*/dataProviders/*}/eventGroups" - }; - option (google.api.method_signature) = "parent"; - } -} - -// Request message for `ListEventGroups` method. -message ListEventGroupsRequest { - // Resource name of the parent `DataProvider` under a given - // `MeasurementConsumer`, in the form - // `measurementConsumers/{measurement_consumer}/dataProviders/{data_provider}`. - // The wildcard ID (`-`) may be used in place of the `DataProvider` ID to list - // across `DataProvider`s, in which case a filter should be specified. - string parent = 1 [ - (google.api.field_behavior) = REQUIRED, - (google.api.resource_reference) = { - child_type: "reporting.halo-cmm.org/EventGroup" - } - ]; - - // The maximum number of `EventGroup`s to return. The service may return fewer - // than this value. If unspecified, at most 50 `EventGroup`s will be - // returned. The maximum value is 1000; values above 1000 will be coerced to - // 1000. - int32 page_size = 2; - - // A token from a previous call, specified to retrieve the next page. See - // https://aip.dev/158. - string page_token = 3; - - // Result filter. Raw CEL expression that is applied to a message which has a - // field for each event group template. - string filter = 4; -} - -// Response message for `ListEventGroups` method. -message ListEventGroupsResponse { - // The `EventGroup` resources. - repeated EventGroup event_groups = 1; - - // A token that can be specified in a subsequent call to retrieve the next - // page. See https://aip.dev/158. - string next_page_token = 2; -} diff --git a/src/main/proto/wfa/measurement/reporting/v1alpha/metric.proto b/src/main/proto/wfa/measurement/reporting/v1alpha/metric.proto deleted file mode 100644 index 2a18e2b4650..00000000000 --- a/src/main/proto/wfa/measurement/reporting/v1alpha/metric.proto +++ /dev/null @@ -1,129 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package wfa.measurement.reporting.v1alpha; - -import "google/api/field_behavior.proto"; -import "google/api/resource.proto"; - -option java_package = "org.wfanet.measurement.reporting.v1alpha"; -option java_multiple_files = true; -option java_outer_classname = "MetricProto"; - -// Definition of a computed metric for a `Report` in terms of set operations. -message Metric { - // Parameters that are used to generate `Reach` metric. - message ReachParams {} - // Parameters that are used to generate `Frequency Histogram` metric. - message FrequencyHistogramParams { - // Maximum frequency to reveal in the histogram. - int32 maximum_frequency_per_user = 1; - } - // Parameters that are used to generate `Impression Count` metric. - message ImpressionCountParams { - // Setting the maximum frequency for each user is for noising the impression - // estimation with the noise proportional to maximum_frequency_per_user to - // guarantee epsilon-DP, i.e. the higher maximum_frequency_per_user, the - // larger the variance. On the other hand, if maximum_frequency_per_user is - // too small, there's truncation bias. Through optimization, the recommended - // value for maximum_frequency_per_user = 60 for the case with 1M audience - // size. - int32 maximum_frequency_per_user = 1; - } - // Parameters that are used to generate `Watch Duration` metric. - message WatchDurationParams { - // Maximum frequency per user that will be included in this measurement. - // - // Deprecated: Not supported by the CMMS. - int32 maximum_frequency_per_user = 1 [deprecated = true]; - // Maximum watch duration per user that will be included in this - // measurement. Recommended maximum_watch_duration_per_user = cap on the - // total watch duration of all the impressions of a user = 4000 sec for the - // case with 1M audience size. - int32 maximum_watch_duration_per_user = 2; - } - - // Types of metrics that can be selected to be in a `Report`. - // REQUIRED - oneof metric_type { - // The count of unique audiences reached given a set of event groups. - ReachParams reach = 1; - // The reach frequency histogram given a set of event groups. Currently, we - // only support union operations for frequency histograms. Any other - // operations on frequency histograms won't guarantee the result is a - // frequency histogram. - FrequencyHistogramParams frequency_histogram = 2; - // The impression count given a set of event groups. - ImpressionCountParams impression_count = 3; - // The watch duration given a set of event groups. - WatchDurationParams watch_duration = 4; - } - - // Whether the results for a given time interval is cumulative with those of - // previous time intervals. Only supported when using `PeriodicTimeInterval`. - bool cumulative = 6; - - // Represents a binary set operation. - message SetOperation { - // Types of set operators. - enum Type { - // Default value. This value is unused. - TYPE_UNSPECIFIED = 0; - // The set union operation. - UNION = 1; - // The set difference operation. - DIFFERENCE = 2; - // The set intersection operation. - INTERSECTION = 3; - } - // The type of set operator that will be applied on the operands. - Type type = 1 [(google.api.field_behavior) = REQUIRED]; - - // The object of a set operation. - message Operand { - oneof operand { - // Resource name of a `ReportingSet` describing a set operand. Note that - // the reporting set is constrained by the `EventGroupUniverse` defined - // in the `Report`. - string reporting_set = 1 [(google.api.resource_reference).type = - "reporting.halo-cmm.org/ReportingSet"]; - // Nested `SetOperation` to allow for expressions with more terms. - SetOperation operation = 2; - } - } - - // Left-hand side operand of the operation. - Operand lhs = 3 [(google.api.field_behavior) = REQUIRED]; - // Right-hand side operand of the operation. If not specified, implies the - // empty set. - Operand rhs = 4; - } - - // A `SetOperation` associated with a name. - message NamedSetOperation { - // Unique name of the set operation for display purposes and creation of - // measurement reference ID. The name should be unique for the SAME metric - // type among all metrics in a report. - string unique_name = 1 [(google.api.field_behavior) = REQUIRED]; - - // A set operation that specifies the set of event groups. - SetOperation set_operation = 2 [(google.api.field_behavior) = REQUIRED]; - } - - // A list of named `SetOperations` on which the same metric will be applied. - repeated NamedSetOperation set_operations = 7 - [(google.api.field_behavior) = REQUIRED]; -} diff --git a/src/main/proto/wfa/measurement/reporting/v1alpha/page_token.proto b/src/main/proto/wfa/measurement/reporting/v1alpha/page_token.proto deleted file mode 100644 index 159cee195ad..00000000000 --- a/src/main/proto/wfa/measurement/reporting/v1alpha/page_token.proto +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2023 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package wfa.measurement.reporting.v1alpha; - -option java_package = "org.wfanet.measurement.reporting.v1alpha"; -option java_multiple_files = true; - -message ListReportingSetsPageToken { - int32 page_size = 1; - string measurement_consumer_reference_id = 2; - message PreviousPageEnd { - string measurement_consumer_reference_id = 1; - fixed64 external_reporting_set_id = 2; - } - PreviousPageEnd last_reporting_set = 3; -} - -message ListReportsPageToken { - int32 page_size = 1; - string measurement_consumer_reference_id = 2; - message PreviousPageEnd { - string measurement_consumer_reference_id = 1; - fixed64 external_report_id = 2; - } - PreviousPageEnd last_report = 3; -} diff --git a/src/main/proto/wfa/measurement/reporting/v1alpha/report.proto b/src/main/proto/wfa/measurement/reporting/v1alpha/report.proto deleted file mode 100644 index d41acb92203..00000000000 --- a/src/main/proto/wfa/measurement/reporting/v1alpha/report.proto +++ /dev/null @@ -1,146 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package wfa.measurement.reporting.v1alpha; - -import "google/api/field_behavior.proto"; -import "google/api/resource.proto"; -import "wfa/measurement/reporting/v1alpha/metric.proto"; -import "wfa/measurement/reporting/v1alpha/time_interval.proto"; - -option java_package = "org.wfanet.measurement.reporting.v1alpha"; -option java_multiple_files = true; -option java_outer_classname = "ReportProto"; - -// Resource representing a report. -message Report { - option (google.api.resource) = { - type: "reporting.halo-cmm.org/Report" - pattern: "measurementConsumers/{measurement_consumer}/reports/{report}" - }; - - // Resource name. - string name = 1; - - // Used as the prefix of the idempotency keys of internal measurements. This - // value must be unique among `Report`s for the parent `MeasurementConsumer`. - // TODO(@riemanli) Moved the idempotency key to request messages. - string report_idempotency_key = 2 [(google.api.field_behavior) = REQUIRED]; - - // Representation of a Measurement Consumer entity. - string measurement_consumer = 3 [ - (google.api.field_behavior) = REQUIRED, - (google.api.resource_reference) = { - type: "halo.wfanet.org/MeasurementConsumer" - } - ]; - - // Define the universe for all sets of event groups in `metrics`. - message EventGroupUniverse { - // Map entry of `EventGroup` to filter predicate. - message EventGroupEntry { - // Key of the map entry, which is an `EventGroup` resource name. - string key = 1 [ - (google.api.field_behavior) = REQUIRED, - (google.api.resource_reference) = { - type: "reporting.halo-cmm.org/EventGroup" - } - ]; - - // Filter predicate in CEL. If unspecified, evaluates to `true`. - string value = 2; - } - - // A Map with `EventGroup`s as the keys and filter predicates as the values. - repeated EventGroupEntry event_group_entries = 1 - [(google.api.field_behavior) = REQUIRED]; - } - - // The universe that defines all sets of event groups in `metrics`. Events - // from event groups that aren't listed here won't be included. - EventGroupUniverse event_group_universe = 4 - [(google.api.field_behavior) = REQUIRED]; - - // Types of time intervals for metric aggregation. - // REQUIRED - oneof time { - // A list of time intervals with different start times and end times. - TimeIntervals time_intervals = 5; - // A series of time intervals with the same length. - PeriodicTimeInterval periodic_time_interval = 6; - } - - // The metrics that are included in `Report`. - repeated Metric metrics = 7 [(google.api.field_behavior) = REQUIRED]; - - // Possible states of a `Report`. - enum State { - // Default value. This value is unused. - STATE_UNSPECIFIED = 0; - // Computation is running. - RUNNING = 1; - // Completed successfully. Terminal state. - SUCCEEDED = 2; - // Completed with failure. Terminal state. - FAILED = 3; - } - - // Report state. - State state = 8 [(google.api.field_behavior) = OUTPUT_ONLY]; - - // Set operation calculations are done for each set operation for a given - // `Metric`, which is column by column. - message Result { - // A column with a header and values. - message Column { - // The header of the column. - string column_header = 1; - // Must be in the same order as `row_headers`. - repeated double set_operations = 2; - } - // The measurement result of Reach, ImpressionCount, or WatchDuration - message ScalarTable { - // Must be in the same order as column's set_operations. - repeated string row_headers = 1; - // Must be in the same order as `row_headers`. - repeated Column columns = 2; - } - // For Reach, ImpressionCount, or WatchDuration - ScalarTable scalar_table = 1; - - // HistogramTable has an additional field frequency for each row, thus - // different from ScalarTable. - message HistogramTable { - // Combine header and frequency as Row. - message Row { - // Must be in the same order as column's set_operations. - string row_header = 1; - // frequency of the row - int32 frequency = 2; - } - // Must be in the same order as column's set_operations. - repeated Row rows = 1; - // Must be in the same order as row's `row_headers`. - repeated Column columns = 2; - } - // Each HistogramTable contains the result of one frequency histogram - // metric. - repeated HistogramTable histogram_tables = 2; - } - - // Set only when `state` is `SUCCEEDED`. - Result result = 9 [(google.api.field_behavior) = OUTPUT_ONLY]; -} diff --git a/src/main/proto/wfa/measurement/reporting/v1alpha/reporting_set.proto b/src/main/proto/wfa/measurement/reporting/v1alpha/reporting_set.proto deleted file mode 100644 index e1c5ce1ce45..00000000000 --- a/src/main/proto/wfa/measurement/reporting/v1alpha/reporting_set.proto +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package wfa.measurement.reporting.v1alpha; - -import "google/api/field_behavior.proto"; -import "google/api/resource.proto"; - -option java_package = "org.wfanet.measurement.reporting.v1alpha"; -option java_multiple_files = true; -option java_outer_classname = "ReportingSetProto"; - -// Resource that describes a set of events that can be used as an operand for a -// `SetOperation`. -message ReportingSet { - option (google.api.resource) = { - type: "reporting.halo-cmm.org/ReportingSet" - pattern: "measurementConsumers/{measurement_consumer}/reportingSets/{reporting_set}" - }; - - // Resource name. - string name = 1; - - // Set of EventGroup resource names. - repeated string event_groups = 2 [ - (google.api.resource_reference).type = "reporting.halo-cmm.org/EventGroup", - (google.api.field_behavior) = REQUIRED, - (google.api.field_behavior) = IMMUTABLE - ]; - - // CEL filter predicate that applies to all `event_groups`. - // - // This filter and the one in the `event_group_filters` form a conjunction - // that specifies which impressions are included in this set. If - // unspecified, evaluates to `true`. - string filter = 3 [(google.api.field_behavior) = IMMUTABLE]; - - // Human-readable name for display purposes. - string display_name = 4; -} diff --git a/src/main/proto/wfa/measurement/reporting/v1alpha/reporting_sets_service.proto b/src/main/proto/wfa/measurement/reporting/v1alpha/reporting_sets_service.proto deleted file mode 100644 index 721df2f02fc..00000000000 --- a/src/main/proto/wfa/measurement/reporting/v1alpha/reporting_sets_service.proto +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package wfa.measurement.reporting.v1alpha; - -import "google/api/client.proto"; -import "google/api/field_behavior.proto"; -import "google/api/annotations.proto"; -import "google/api/resource.proto"; -import "wfa/measurement/reporting/v1alpha/reporting_set.proto"; - -option java_package = "org.wfanet.measurement.reporting.v1alpha"; -option java_multiple_files = true; -option java_outer_classname = "ReportingSetsServiceProto"; - -// Service for interacting with `ReportingSet` resources. -service ReportingSets { - // Lists `ReportingSet`s. - rpc ListReportingSets(ListReportingSetsRequest) - returns (ListReportingSetsResponse) { - option (google.api.http) = { - get: "/v1alpha/{parent=measurementConsumers/*}/reportingSets" - }; - option (google.api.method_signature) = "parent"; - } - - // Creates a `ReportingSet`. - rpc CreateReportingSet(CreateReportingSetRequest) returns (ReportingSet) { - option (google.api.http) = { - post: "/v1alpha/{parent=measurementConsumers/*}/reportingSets" - body: "reporting_set" - }; - option (google.api.method_signature) = "parent,reporting_set"; - } -} - -// Request message for `ListReportingSet` method. -message ListReportingSetsRequest { - // Format: measurementConsumers/{measurement_consumer} - string parent = 1 [ - (google.api.field_behavior) = REQUIRED, - (google.api.resource_reference) = { - child_type: "reporting.halo-cmm.org/ReportingSet" - } - ]; - - // The maximum number of reportingSets to return. The service may return fewer - // than this value. - // If unspecified, at most 50 reportingSets will be returned. - // The maximum value is 1000; values above 1000 will be coerced to 1000. - int32 page_size = 2; - - // A page token, received from a previous `ListReportingSets` call. - // Provide this to retrieve the subsequent page. - // - // When paginating, all other parameters provided to `ListReportingSets` must - // match the call that provided the page token. - string page_token = 3; -} - -// Response message for `ListReportingSet` method. -message ListReportingSetsResponse { - // The reportingSets from the specified measurement consumer. - repeated ReportingSet reporting_sets = 1; - - // A token, which can be sent as `page_token` to retrieve the next page. - // If this field is omitted, there are no subsequent pages. - string next_page_token = 2; -} - -// Request message for `CreateReportingSet` method. -message CreateReportingSetRequest { - // The parent resource where this reportingSet will be created. - // Format: measurementConsumers/{measurement_consumer} - string parent = 1 [ - (google.api.field_behavior) = REQUIRED, - (google.api.resource_reference) = { - child_type: "reporting.halo-cmm.org/ReportingSet" - } - ]; - - // The ReportingSet to create. - ReportingSet reporting_set = 2 [(google.api.field_behavior) = REQUIRED]; -} diff --git a/src/main/proto/wfa/measurement/reporting/v1alpha/reports_service.proto b/src/main/proto/wfa/measurement/reporting/v1alpha/reports_service.proto deleted file mode 100644 index e36ce26c3ca..00000000000 --- a/src/main/proto/wfa/measurement/reporting/v1alpha/reports_service.proto +++ /dev/null @@ -1,114 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package wfa.measurement.reporting.v1alpha; - -import "google/api/client.proto"; -import "google/api/field_behavior.proto"; -import "google/api/annotations.proto"; -import "google/api/resource.proto"; -import "wfa/measurement/reporting/v1alpha/report.proto"; - -option java_package = "org.wfanet.measurement.reporting.v1alpha"; -option java_multiple_files = true; -option java_outer_classname = "ReportsServiceProto"; - -// Service for interacting with `Report` resources. -service Reports { - // Returns the `Report` with the specified resource key. - rpc GetReport(GetReportRequest) returns (Report) { - option (google.api.http) = { - get: "/v1alpha/{name=measurementConsumers/*/reports/*}" - }; - option (google.api.method_signature) = "name"; - } - - // Lists `Report`s. - rpc ListReports(ListReportsRequest) returns (ListReportsResponse) { - option (google.api.http) = { - get: "/v1alpha/{parent=measurementConsumers/*}/reports" - }; - option (google.api.method_signature) = "parent"; - } - - // Creates a `Report`. - rpc CreateReport(CreateReportRequest) returns (Report) { - option (google.api.http) = { - post: "/v1alpha/{parent=measurementConsumers/*}/reports" - body: "report" - }; - option (google.api.method_signature) = "parent,report"; - } -} - -// Request message for `GetReport` method. -message GetReportRequest { - // The name of the report to retrieve. - // Format: measurementConsumers/{measurement_consumer}/reports/{report} - string name = 1 [ - (google.api.field_behavior) = REQUIRED, - (google.api.resource_reference) = { type: "reporting.halo-cmm.org/Report" } - ]; -} - -// Request message for `ListReports` method. -message ListReportsRequest { - // Format: measurementConsumers/{measurement_consumer} - string parent = 1 [ - (google.api.field_behavior) = REQUIRED, - (google.api.resource_reference) = { - child_type: "reporting.halo-cmm.org/Report" - } - ]; - - // The maximum number of reports to return. The service may return fewer than - // this value. - // If unspecified, at most 50 reports will be returned. - // The maximum value is 1000; values above 1000 will be coerced to 1000. - int32 page_size = 2; - - // A page token, received from a previous `ListReports` call. - // Provide this to retrieve the subsequent page. - // - // When paginating, all other parameters provided to `ListReports` must match - // the call that provided the page token. - string page_token = 3; -} - -// Response message for `ListReports` method. -message ListReportsResponse { - // The reports from the specified measurement consumer. - repeated Report reports = 1; - - // A token, which can be sent as `page_token` to retrieve the next page. - // If this field is omitted, there are no subsequent pages. - string next_page_token = 2; -} - -// Request message for `CreateReport` method. -message CreateReportRequest { - // The parent resource where this report will be created. - // Format: measurementConsumers/{measurement_consumer} - string parent = 1 [ - (google.api.field_behavior) = REQUIRED, - (google.api.resource_reference) = { - child_type: "reporting.halo-cmm.org/Report" - } - ]; - - // The report to create. - Report report = 2 [(google.api.field_behavior) = REQUIRED]; -} diff --git a/src/main/proto/wfa/measurement/reporting/v1alpha/time_interval.proto b/src/main/proto/wfa/measurement/reporting/v1alpha/time_interval.proto deleted file mode 100644 index ef7ea15f083..00000000000 --- a/src/main/proto/wfa/measurement/reporting/v1alpha/time_interval.proto +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package wfa.measurement.reporting.v1alpha; - -import "google/api/field_behavior.proto"; -import "google/protobuf/duration.proto"; -import "google/protobuf/timestamp.proto"; - -option java_package = "org.wfanet.measurement.reporting.v1alpha"; -option java_multiple_files = true; -option java_outer_classname = "TimeIntervalProto"; - -// A time interval. -message TimeInterval { - // Start of the time interval, inclusive. - google.protobuf.Timestamp start_time = 1 - [(google.api.field_behavior) = REQUIRED]; - // End of the time interval, exclusive. This must be later than the start - // time. - google.protobuf.Timestamp end_time = 2 - [(google.api.field_behavior) = REQUIRED]; -} - -// A list of time intervals with different start times and end times. -message TimeIntervals { - // A list of time intervals. - repeated TimeInterval time_intervals = 1 - [(google.api.field_behavior) = REQUIRED]; -} - -// A series of time intervals with the same length. -message PeriodicTimeInterval { - // Start of the first time interval, inclusive. - google.protobuf.Timestamp start_time = 1 - [(google.api.field_behavior) = REQUIRED]; - - // Increment for each time interval. The first interval will be [start_time, - // start_time + increment), the second [start_time + increment, start_time + - // increment * 2) and so forth. - // - // TODO(@SanjayVas): Consider whether we want to use civil time instead. - google.protobuf.Duration increment = 2 - [(google.api.field_behavior) = REQUIRED]; - - // Number of intervals. - int32 interval_count = 3 [(google.api.field_behavior) = REQUIRED]; -} diff --git a/src/main/terraform/gcloud/cmms/reporting.tf b/src/main/terraform/gcloud/cmms/reporting.tf deleted file mode 100644 index d0e416cba56..00000000000 --- a/src/main/terraform/gcloud/cmms/reporting.tf +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2023 The Cross-Media Measurement Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -module "reporting_cluster" { - source = "../modules/cluster" - - name = local.reporting_cluster_name - location = local.cluster_location - release_channel = var.cluster_release_channel - secret_key = module.common.cluster_secret_key -} - -module "reporting_default_node_pool" { - source = "../modules/node-pool" - - name = "default" - cluster = module.reporting_cluster.cluster - service_account = module.common.cluster_service_account - machine_type = "e2-small" - max_node_count = 8 -} - -module "reporting" { - source = "../modules/reporting" - - iam_service_account_name = "reporting-internal" - postgres_instance = google_sql_database_instance.postgres - postgres_database_name = "reporting" -} - -resource "google_compute_address" "reporting_v1alpha" { - name = "reporting-v1alpha" -} diff --git a/src/main/terraform/gcloud/modules/cluster/main.tf b/src/main/terraform/gcloud/modules/cluster/main.tf index 9faedd43a8e..5ec3ca36693 100644 --- a/src/main/terraform/gcloud/modules/cluster/main.tf +++ b/src/main/terraform/gcloud/modules/cluster/main.tf @@ -65,6 +65,8 @@ resource "google_container_cluster" "cluster" { } } + deletion_protection = var.deletion_protection + lifecycle { ignore_changes = [ # This only affects the default node pool, which is immediately deleted. diff --git a/src/main/terraform/gcloud/modules/cluster/variables.tf b/src/main/terraform/gcloud/modules/cluster/variables.tf index 6a5b0416010..fab0e2ba2a0 100644 --- a/src/main/terraform/gcloud/modules/cluster/variables.tf +++ b/src/main/terraform/gcloud/modules/cluster/variables.tf @@ -42,3 +42,9 @@ variable "autoscaling_profile" { type = string default = null } + +variable "deletion_protection" { + description = "Whether deletion protection is enabled" + type = bool + default = null +} diff --git a/src/test/kotlin/org/wfanet/measurement/integration/deploy/common/postgres/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/integration/deploy/common/postgres/BUILD.bazel deleted file mode 100644 index b0ef7121c51..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/integration/deploy/common/postgres/BUILD.bazel +++ /dev/null @@ -1,21 +0,0 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_test") - -kt_jvm_test( - name = "PostgresInProcessLifeOfAReportIntegrationTest", - srcs = ["PostgresInProcessLifeOfAReportIntegrationTest.kt"], - tags = [ - "cpu:2", - "no-remote-exec", - ], - test_class = "org.wfanet.measurement.integration.deploy.common.postgres.PostgresInProcessLifeOfAReportIntegrationTest", - runtime_deps = [ - "@wfa_common_jvm//imports/java/org/yaml:snakeyaml", - ], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/integration/common/reporting:in_process_life_of_a_report_integration_test", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/common/server/postgres:services", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/testing", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres/testing:database_provider", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/identity", - ], -) diff --git a/src/test/kotlin/org/wfanet/measurement/integration/deploy/common/postgres/PostgresInProcessLifeOfAReportIntegrationTest.kt b/src/test/kotlin/org/wfanet/measurement/integration/deploy/common/postgres/PostgresInProcessLifeOfAReportIntegrationTest.kt deleted file mode 100644 index 35c3dfad5de..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/integration/deploy/common/postgres/PostgresInProcessLifeOfAReportIntegrationTest.kt +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.integration.deploy.common.postgres - -import java.time.Clock -import org.junit.ClassRule -import org.wfanet.measurement.common.db.r2dbc.postgres.testing.PostgresDatabaseProviderRule -import org.wfanet.measurement.common.identity.RandomIdGenerator -import org.wfanet.measurement.integration.common.reporting.InProcessLifeOfAReportIntegrationTest -import org.wfanet.measurement.reporting.deploy.common.server.postgres.PostgresServices -import org.wfanet.measurement.reporting.deploy.postgres.testing.Schemata - -/** Implementation of [InProcessLifeOfAReportIntegrationTest] for Postgres. */ -class PostgresInProcessLifeOfAReportIntegrationTest : InProcessLifeOfAReportIntegrationTest() { - override val reportingServerDataServices by lazy { - PostgresServices.create(RandomIdGenerator(Clock.systemUTC()), databaseProvider.createDatabase()) - } - - companion object { - @get:ClassRule - @JvmStatic - val databaseProvider = PostgresDatabaseProviderRule(Schemata.REPORTING_CHANGELOG_PATH) - } -} diff --git a/src/test/kotlin/org/wfanet/measurement/integration/k8s/AbstractCorrectnessTest.kt b/src/test/kotlin/org/wfanet/measurement/integration/k8s/AbstractCorrectnessTest.kt index f4a28f4d1cd..64ed27e3e9b 100644 --- a/src/test/kotlin/org/wfanet/measurement/integration/k8s/AbstractCorrectnessTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/integration/k8s/AbstractCorrectnessTest.kt @@ -28,6 +28,7 @@ import org.wfanet.measurement.common.crypto.SigningKeyHandle import org.wfanet.measurement.integration.common.loadEncryptionPrivateKey import org.wfanet.measurement.integration.common.loadSigningKey import org.wfanet.measurement.loadtest.measurementconsumer.MeasurementConsumerSimulator +import org.wfanet.measurement.loadtest.reporting.ReportingUserSimulator /** Test for correctness of the CMMS on Kubernetes. */ abstract class AbstractCorrectnessTest(private val measurementSystem: MeasurementSystem) { @@ -37,6 +38,9 @@ abstract class AbstractCorrectnessTest(private val measurementSystem: Measuremen private val testHarness: MeasurementConsumerSimulator get() = measurementSystem.testHarness + private val reportingTestHarness: ReportingUserSimulator + get() = measurementSystem.reportingTestHarness + @Test(timeout = 1 * 60 * 1000) fun `impression measurement completes with expected result`() = runBlocking { testHarness.testImpression("$runId-impression") @@ -63,9 +67,15 @@ abstract class AbstractCorrectnessTest(private val measurementSystem: Measuremen ) } + @Test(timeout = 1 * 60 * 1000) + fun `report can be created`() = runBlocking { + reportingTestHarness.testCreateReport("$runId-test-report") + } + interface MeasurementSystem { val runId: String val testHarness: MeasurementConsumerSimulator + val reportingTestHarness: ReportingUserSimulator } companion object { @@ -97,6 +107,14 @@ abstract class AbstractCorrectnessTest(private val measurementSystem: Measuremen SigningCerts.fromPemFiles(cert, key, trustedCerts) } + val REPORTING_SIGNING_CERTS: SigningCerts by lazy { + val secretFiles = getRuntimePath(SECRET_FILES_PATH) + val trustedCerts = secretFiles.resolve("reporting_root.pem").toFile() + val cert = secretFiles.resolve("mc_tls.pem").toFile() + val key = secretFiles.resolve("mc_tls.key").toFile() + SigningCerts.fromPemFiles(cert, key, trustedCerts) + } + val MC_ENCRYPTION_PRIVATE_KEY: PrivateKeyHandle by lazy { loadEncryptionPrivateKey(MC_ENCRYPTION_PRIVATE_KEY_NAME) } diff --git a/src/test/kotlin/org/wfanet/measurement/integration/k8s/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/integration/k8s/BUILD.bazel index 3883661dbdb..2f0097c1e6f 100644 --- a/src/test/kotlin/org/wfanet/measurement/integration/k8s/BUILD.bazel +++ b/src/test/kotlin/org/wfanet/measurement/integration/k8s/BUILD.bazel @@ -13,11 +13,13 @@ kt_jvm_library( srcs = ["AbstractCorrectnessTest.kt"], data = [ "//src/main/k8s/testing/secretfiles:mc_trusted_certs.pem", + "//src/main/k8s/testing/secretfiles:reporting_root.pem", "//src/main/k8s/testing/secretfiles:secret_files", ], deps = [ "//src/main/kotlin/org/wfanet/measurement/integration/common:configs", "//src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer:simulator", + "//src/main/kotlin/org/wfanet/measurement/loadtest/reporting:simulator", "@wfa_common_jvm//imports/java/com/google/common/truth", "@wfa_common_jvm//imports/java/org/junit", "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", @@ -58,6 +60,7 @@ kt_jvm_library( "//src/main/kotlin/org/wfanet/measurement/integration/common:synthetic_generation_specs", "//src/main/kotlin/org/wfanet/measurement/loadtest/config:vid_sampling", "//src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer:synthetic_generator_event_query", + "//src/main/kotlin/org/wfanet/measurement/loadtest/reporting:simulator", "//src/main/proto/wfa/measurement/integration/k8s/testing:correctness_test_config_kt_jvm_proto", "@wfa_common_jvm//imports/java/org/junit", "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", @@ -74,6 +77,8 @@ expand_template( "{kingdom_public_api_cert_host}": "localhost", "{mc_name}": TEST_K8S_SETTINGS.mc_name, "{mc_api_key}": TEST_K8S_SETTINGS.mc_api_key, + "{reporting_public_api_target}": "$(reporting_public_api_target)", + "{reporting_public_api_cert_host}": "localhost", }, tags = ["manual"], template = "correctness_test_config.tmpl.textproto", diff --git a/src/test/kotlin/org/wfanet/measurement/integration/k8s/EmptyClusterCorrectnessTest.kt b/src/test/kotlin/org/wfanet/measurement/integration/k8s/EmptyClusterCorrectnessTest.kt index 0ddd08ec41b..95f24dc4307 100644 --- a/src/test/kotlin/org/wfanet/measurement/integration/k8s/EmptyClusterCorrectnessTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/integration/k8s/EmptyClusterCorrectnessTest.kt @@ -69,10 +69,14 @@ import org.wfanet.measurement.internal.kingdom.AccountsGrpcKt import org.wfanet.measurement.loadtest.measurementconsumer.MeasurementConsumerData import org.wfanet.measurement.loadtest.measurementconsumer.MeasurementConsumerSimulator import org.wfanet.measurement.loadtest.measurementconsumer.MetadataSyntheticGeneratorEventQuery +import org.wfanet.measurement.loadtest.reporting.ReportingUserSimulator import org.wfanet.measurement.loadtest.resourcesetup.DuchyCert import org.wfanet.measurement.loadtest.resourcesetup.EntityContent import org.wfanet.measurement.loadtest.resourcesetup.ResourceSetup import org.wfanet.measurement.loadtest.resourcesetup.Resources +import org.wfanet.measurement.reporting.v2alpha.MetricCalculationSpecsGrpcKt +import org.wfanet.measurement.reporting.v2alpha.ReportingSetsGrpcKt +import org.wfanet.measurement.reporting.v2alpha.ReportsGrpcKt /** * Test for correctness of the CMMS on a single "empty" Kubernetes cluster using the `local` @@ -108,6 +112,7 @@ class EmptyClusterCorrectnessTest : AbstractCorrectnessTest(measurementSystem) { val worker1Cert: String, val worker2Cert: String, val measurementConsumer: String, + val measurementConsumerCert: String, val apiKey: String, val dataProviders: Map, ) { @@ -117,6 +122,7 @@ class EmptyClusterCorrectnessTest : AbstractCorrectnessTest(measurementSystem) { var worker1Cert: String? = null var worker2Cert: String? = null var measurementConsumer: String? = null + var measurementConsumerCert: String? = null var apiKey: String? = null val dataProviders = mutableMapOf() @@ -126,6 +132,7 @@ class EmptyClusterCorrectnessTest : AbstractCorrectnessTest(measurementSystem) { Resources.Resource.ResourceCase.MEASUREMENT_CONSUMER -> { measurementConsumer = resource.name apiKey = resource.measurementConsumer.apiKey + measurementConsumerCert = resource.measurementConsumer.certificate } Resources.Resource.ResourceCase.DATA_PROVIDER -> { val displayName = resource.dataProvider.displayName @@ -146,12 +153,13 @@ class EmptyClusterCorrectnessTest : AbstractCorrectnessTest(measurementSystem) { } return ResourceInfo( - requireNotNull(aggregatorCert), - requireNotNull(worker1Cert), - requireNotNull(worker2Cert), - requireNotNull(measurementConsumer), - requireNotNull(apiKey), - dataProviders, + aggregatorCert = requireNotNull(aggregatorCert), + worker1Cert = requireNotNull(worker1Cert), + worker2Cert = requireNotNull(worker2Cert), + measurementConsumer = requireNotNull(measurementConsumer), + measurementConsumerCert = requireNotNull(measurementConsumerCert), + apiKey = requireNotNull(apiKey), + dataProviders = dataProviders, ) } } @@ -174,6 +182,10 @@ class EmptyClusterCorrectnessTest : AbstractCorrectnessTest(measurementSystem) { override val testHarness: MeasurementConsumerSimulator get() = _testHarness + private lateinit var _reportingTestHarness: ReportingUserSimulator + override val reportingTestHarness: ReportingUserSimulator + get() = _reportingTestHarness + override fun apply(base: Statement, description: Description): Statement { return object : Statement() { override fun evaluate() { @@ -182,6 +194,7 @@ class EmptyClusterCorrectnessTest : AbstractCorrectnessTest(measurementSystem) { withTimeout(Duration.ofMinutes(5)) { val measurementConsumerData = populateCluster() _testHarness = createTestHarness(measurementConsumerData) + _reportingTestHarness = createReportingUserSimulator(measurementConsumerData) } } base.evaluate() @@ -212,7 +225,12 @@ class EmptyClusterCorrectnessTest : AbstractCorrectnessTest(measurementSystem) { val resourceSetupOutput = runResourceSetup(duchyCerts, edpEntityContents, measurementConsumerContent) val resourceInfo = ResourceInfo.from(resourceSetupOutput.resources) - loadFullCmms(resourceInfo, resourceSetupOutput.akidPrincipalMap) + loadFullCmms( + resourceInfo, + resourceSetupOutput.akidPrincipalMap, + resourceSetupOutput.measurementConsumerConfig, + resourceSetupOutput.encryptionKeyPairConfig, + ) val encryptionPrivateKey: TinkPrivateKeyHandle = withContext(Dispatchers.IO) { @@ -259,6 +277,35 @@ class EmptyClusterCorrectnessTest : AbstractCorrectnessTest(measurementSystem) { ) } + private suspend fun createReportingUserSimulator( + measurementConsumerData: MeasurementConsumerData + ): ReportingUserSimulator { + val reportingPublicPod: V1Pod = getPod(REPORTING_PUBLIC_DEPLOYMENT_NAME) + + val publicApiForwarder = PortForwarder(reportingPublicPod, SERVER_PORT) + portForwarders.add(publicApiForwarder) + + val publicApiAddress: InetSocketAddress = + withContext(Dispatchers.IO) { publicApiForwarder.start() } + val publicApiChannel: Channel = + buildMutualTlsChannel(publicApiAddress.toTarget(), REPORTING_SIGNING_CERTS) + .also { channels.add(it) } + .withDefaultDeadline(DEFAULT_RPC_DEADLINE) + + return ReportingUserSimulator( + measurementConsumerName = measurementConsumerData.name, + dataProvidersClient = DataProvidersGrpcKt.DataProvidersCoroutineStub(publicApiChannel), + eventGroupsClient = + org.wfanet.measurement.reporting.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub( + publicApiChannel + ), + reportingSetsClient = ReportingSetsGrpcKt.ReportingSetsCoroutineStub(publicApiChannel), + metricCalculationSpecsClient = + MetricCalculationSpecsGrpcKt.MetricCalculationSpecsCoroutineStub(publicApiChannel), + reportsClient = ReportsGrpcKt.ReportsCoroutineStub(publicApiChannel), + ) + } + fun stopPortForwarding() { for (channel in channels) { channel.shutdown() @@ -268,7 +315,12 @@ class EmptyClusterCorrectnessTest : AbstractCorrectnessTest(measurementSystem) { } } - private suspend fun loadFullCmms(resourceInfo: ResourceInfo, akidPrincipalMap: File) { + private suspend fun loadFullCmms( + resourceInfo: ResourceInfo, + akidPrincipalMap: File, + measurementConsumerConfig: File, + encryptionKeyPairConfig: File, + ) { val appliedObjects: List = withContext(Dispatchers.IO) { val outputDir = tempDir.newFolder("cmms") @@ -278,6 +330,13 @@ class EmptyClusterCorrectnessTest : AbstractCorrectnessTest(measurementSystem) { logger.info("Copying $akidPrincipalMap to $CONFIG_FILES_PATH") akidPrincipalMap.copyTo(configFilesDir.resolve(akidPrincipalMap.name)) + logger.info("Copying $encryptionKeyPairConfig to $CONFIG_FILES_PATH") + encryptionKeyPairConfig.copyTo(configFilesDir.resolve(encryptionKeyPairConfig.name)) + + val mcConfigDir = outputDir.toPath().resolve(MC_CONFIG_PATH).toFile() + logger.info("Copying $measurementConsumerConfig to $MC_CONFIG_PATH") + measurementConsumerConfig.copyTo(mcConfigDir.resolve(measurementConsumerConfig.name)) + val configTemplate: File = outputDir.resolve("config.yaml") kustomize( outputDir.toPath().resolve(LOCAL_K8S_TESTING_PATH).resolve("cmms").toFile(), @@ -291,6 +350,8 @@ class EmptyClusterCorrectnessTest : AbstractCorrectnessTest(measurementSystem) { .replace("{worker1_cert_name}", resourceInfo.worker1Cert) .replace("{worker2_cert_name}", resourceInfo.worker2Cert) .replace("{mc_name}", resourceInfo.measurementConsumer) + .replace("{mc_api_key}", resourceInfo.apiKey) + .replace("{mc_cert_name}", resourceInfo.measurementConsumerCert) .let { var config = it for ((displayName, resource) in resourceInfo.dataProviders) { @@ -378,6 +439,8 @@ class EmptyClusterCorrectnessTest : AbstractCorrectnessTest(measurementSystem) { return ResourceSetupOutput( resources, outputDir.resolve(ResourceSetup.AKID_PRINCIPAL_MAP_FILE), + outputDir.resolve(ResourceSetup.MEASUREMENT_CONSUMER_CONFIG_FILE), + outputDir.resolve(ResourceSetup.ENCRYPTION_KEY_PAIR_CONFIG_FILE), ) } @@ -439,6 +502,8 @@ class EmptyClusterCorrectnessTest : AbstractCorrectnessTest(measurementSystem) { data class ResourceSetupOutput( val resources: List, val akidPrincipalMap: File, + val measurementConsumerConfig: File, + val encryptionKeyPairConfig: File, ) } @@ -456,6 +521,8 @@ class EmptyClusterCorrectnessTest : AbstractCorrectnessTest(measurementSystem) { private val DEFAULT_RPC_DEADLINE = Duration.ofSeconds(30) private const val KINGDOM_INTERNAL_DEPLOYMENT_NAME = "gcp-kingdom-data-server-deployment" private const val KINGDOM_PUBLIC_DEPLOYMENT_NAME = "v2alpha-public-api-server-deployment" + private const val REPORTING_PUBLIC_DEPLOYMENT_NAME = + "reporting-v2alpha-public-api-server-deployment" private const val NUM_DATA_PROVIDERS = 6 private val EDP_DISPLAY_NAMES: List = (1..NUM_DATA_PROVIDERS).map { "edp$it" } private val READY_TIMEOUT = Duration.ofMinutes(2L) @@ -463,6 +530,7 @@ class EmptyClusterCorrectnessTest : AbstractCorrectnessTest(measurementSystem) { private val LOCAL_K8S_PATH = Paths.get("src", "main", "k8s", "local") private val LOCAL_K8S_TESTING_PATH = LOCAL_K8S_PATH.resolve("testing") private val CONFIG_FILES_PATH = LOCAL_K8S_TESTING_PATH.resolve("config_files") + private val MC_CONFIG_PATH = LOCAL_K8S_TESTING_PATH.resolve("mc_config") private val IMAGE_PUSHER_PATH = Paths.get("src", "main", "docker", "push_all_local_images.bash") private val tempDir = TemporaryFolder() diff --git a/src/test/kotlin/org/wfanet/measurement/integration/k8s/SyntheticGeneratorCorrectnessTest.kt b/src/test/kotlin/org/wfanet/measurement/integration/k8s/SyntheticGeneratorCorrectnessTest.kt index 86608dd80fa..6108c68c62f 100644 --- a/src/test/kotlin/org/wfanet/measurement/integration/k8s/SyntheticGeneratorCorrectnessTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/integration/k8s/SyntheticGeneratorCorrectnessTest.kt @@ -41,6 +41,9 @@ import org.wfanet.measurement.loadtest.dataprovider.SyntheticGeneratorEventQuery import org.wfanet.measurement.loadtest.measurementconsumer.MeasurementConsumerData import org.wfanet.measurement.loadtest.measurementconsumer.MeasurementConsumerSimulator import org.wfanet.measurement.loadtest.measurementconsumer.MetadataSyntheticGeneratorEventQuery +import org.wfanet.measurement.loadtest.reporting.ReportingUserSimulator +import org.wfanet.measurement.reporting.v2alpha.MetricCalculationSpecsGrpcKt +import org.wfanet.measurement.reporting.v2alpha.ReportsGrpcKt /** * Test for correctness of an existing CMMS on Kubernetes where the EDP simulators use @@ -48,16 +51,21 @@ import org.wfanet.measurement.loadtest.measurementconsumer.MetadataSyntheticGene * The computation composition is using ACDP by assumption. * * This currently assumes that the CMMS instance is using the certificates and keys from this Bazel - * workspace. + * workspace. It also assumes that there is a Reporting system connected to the CMMS. */ class SyntheticGeneratorCorrectnessTest : AbstractCorrectnessTest(measurementSystem) { private class RunningMeasurementSystem : MeasurementSystem, TestRule { override val runId: String by lazy { UUID.randomUUID().toString() } private lateinit var _testHarness: MeasurementConsumerSimulator + private lateinit var _reportingTestHarness: ReportingUserSimulator + override val testHarness: MeasurementConsumerSimulator get() = _testHarness + override val reportingTestHarness: ReportingUserSimulator + get() = _reportingTestHarness + private val channels = mutableListOf() override fun apply(base: Statement, description: Description): Statement { @@ -65,6 +73,7 @@ class SyntheticGeneratorCorrectnessTest : AbstractCorrectnessTest(measurementSys override fun evaluate() { try { _testHarness = createTestHarness() + _reportingTestHarness = createReportingTestHarness() base.evaluate() } finally { shutDownChannels() @@ -110,6 +119,33 @@ class SyntheticGeneratorCorrectnessTest : AbstractCorrectnessTest(measurementSys ) } + private fun createReportingTestHarness(): ReportingUserSimulator { + val publicApiChannel = + buildMutualTlsChannel( + TEST_CONFIG.reportingPublicApiTarget, + REPORTING_SIGNING_CERTS, + TEST_CONFIG.reportingPublicApiCertHost, + ) + .also { channels.add(it) } + .withDefaultDeadline(RPC_DEADLINE_DURATION) + + return ReportingUserSimulator( + measurementConsumerName = TEST_CONFIG.measurementConsumer, + dataProvidersClient = DataProvidersGrpcKt.DataProvidersCoroutineStub(publicApiChannel), + eventGroupsClient = + org.wfanet.measurement.reporting.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub( + publicApiChannel + ), + reportingSetsClient = + org.wfanet.measurement.reporting.v2alpha.ReportingSetsGrpcKt.ReportingSetsCoroutineStub( + publicApiChannel + ), + metricCalculationSpecsClient = + MetricCalculationSpecsGrpcKt.MetricCalculationSpecsCoroutineStub(publicApiChannel), + reportsClient = ReportsGrpcKt.ReportsCoroutineStub(publicApiChannel), + ) + } + private fun shutDownChannels() { for (channel in channels) { channel.shutdown() diff --git a/src/test/kotlin/org/wfanet/measurement/integration/k8s/correctness_test_config.tmpl.textproto b/src/test/kotlin/org/wfanet/measurement/integration/k8s/correctness_test_config.tmpl.textproto index a6b54b93d85..1242cd723ec 100644 --- a/src/test/kotlin/org/wfanet/measurement/integration/k8s/correctness_test_config.tmpl.textproto +++ b/src/test/kotlin/org/wfanet/measurement/integration/k8s/correctness_test_config.tmpl.textproto @@ -4,3 +4,5 @@ kingdom_public_api_target: "{kingdom_public_api_target}" kingdom_public_api_cert_host: "{kingdom_public_api_cert_host}" measurement_consumer: "{mc_name}" api_authentication_key: "{mc_api_key}" +reporting_public_api_target: "{reporting_public_api_target}" +reporting_public_api_cert_host: "{reporting_public_api_cert_host}" diff --git a/src/test/kotlin/org/wfanet/measurement/loadtest/dataprovider/PopulationRequisitionFulfillerTest.kt b/src/test/kotlin/org/wfanet/measurement/loadtest/dataprovider/PopulationRequisitionFulfillerTest.kt index d09f1e38ec1..056256f2ffb 100644 --- a/src/test/kotlin/org/wfanet/measurement/loadtest/dataprovider/PopulationRequisitionFulfillerTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/loadtest/dataprovider/PopulationRequisitionFulfillerTest.kt @@ -186,7 +186,6 @@ class PopulationRequisitionFulfillerTest { requisitionsStub, dummyThrottler, TRUSTED_CERTIFICATES, - MC_NAME, modelRolloutsStub, modelReleasesStub, POPULATION_INFO_MAP, @@ -236,7 +235,6 @@ class PopulationRequisitionFulfillerTest { requisitionsStub, dummyThrottler, TRUSTED_CERTIFICATES, - MC_NAME, modelRolloutsStub, modelReleasesStub, POPULATION_INFO_MAP, @@ -289,7 +287,6 @@ class PopulationRequisitionFulfillerTest { requisitionsStub, dummyThrottler, TRUSTED_CERTIFICATES, - MC_NAME, modelRolloutsStub, modelReleasesStub, POPULATION_INFO_MAP, @@ -343,7 +340,6 @@ class PopulationRequisitionFulfillerTest { requisitionsStub, dummyThrottler, TRUSTED_CERTIFICATES, - MC_NAME, modelRolloutsStub, modelReleasesStub, POPULATION_INFO_MAP, @@ -400,7 +396,6 @@ class PopulationRequisitionFulfillerTest { requisitionsStub, dummyThrottler, TRUSTED_CERTIFICATES, - MC_NAME, modelRolloutsStub, modelReleasesStub, POPULATION_INFO_MAP, @@ -455,7 +450,6 @@ class PopulationRequisitionFulfillerTest { requisitionsStub, dummyThrottler, TRUSTED_CERTIFICATES, - MC_NAME, modelRolloutsStub, modelReleasesStub, invalidPopulationInfoMap, @@ -505,7 +499,6 @@ class PopulationRequisitionFulfillerTest { requisitionsStub, dummyThrottler, TRUSTED_CERTIFICATES, - MC_NAME, modelRolloutsStub, modelReleasesStub, INVALID_POPULATION_INFO_MAP, @@ -556,7 +549,6 @@ class PopulationRequisitionFulfillerTest { requisitionsStub, dummyThrottler, TRUSTED_CERTIFICATES, - MC_NAME, modelRolloutsStub, modelReleasesStub, INVALID_POPULATION_INFO_MAP, @@ -605,7 +597,6 @@ class PopulationRequisitionFulfillerTest { requisitionsStub, dummyThrottler, TRUSTED_CERTIFICATES, - MC_NAME, modelRolloutsStub, modelReleasesStub, POPULATION_INFO_MAP, diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/common/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/reporting/deploy/common/BUILD.bazel deleted file mode 100644 index 95b258038e9..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/common/BUILD.bazel +++ /dev/null @@ -1,21 +0,0 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_test") - -filegroup( - name = "textproto_files", - srcs = glob(["*.textproto"]), -) - -kt_jvm_test( - name = "EncryptionKeyPairMapTest", - srcs = ["EncryptionKeyPairMapTest.kt"], - data = [ - "//src/main/k8s/testing/secretfiles:secret_files", - "//src/test/kotlin/org/wfanet/measurement/reporting/deploy/common:textproto_files", - ], - test_class = "org.wfanet.measurement.reporting.deploy.common.EncryptionKeyPairMapTest", - deps = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/common:encryption_key_pair_map", - "@wfa_common_jvm//imports/java/com/google/common/truth", - "@wfa_common_jvm//imports/kotlin/kotlin/test", - ], -) diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/common/EncryptionKeyPairMapTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/deploy/common/EncryptionKeyPairMapTest.kt deleted file mode 100644 index 3c3c4c0a771..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/common/EncryptionKeyPairMapTest.kt +++ /dev/null @@ -1,144 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.common - -import com.google.common.truth.Truth.assertThat -import com.google.protobuf.ByteString -import com.google.protobuf.kotlin.toByteStringUtf8 -import java.nio.file.Path -import java.nio.file.Paths -import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 -import org.wfanet.measurement.common.crypto.PrivateKeyHandle -import org.wfanet.measurement.common.crypto.tink.TinkPublicKeyHandle -import org.wfanet.measurement.common.getRuntimePath -import org.wfanet.measurement.common.readByteString -import picocli.CommandLine -import picocli.CommandLine.Mixin - -private val SECRETS_DIR: Path = - getRuntimePath( - Paths.get("wfa_measurement_system", "src", "main", "k8s", "testing", "secretfiles") - )!! - -private val ENCRYPTION_KEY_PAIR_MAP: Path = - getRuntimePath( - Paths.get( - "wfa_measurement_system", - "src", - "test", - "kotlin", - "org", - "wfanet", - "measurement", - "reporting", - "deploy", - "common", - "key_pair_map.textproto", - ) - )!! - -private val PUBLIC_KEY_FILE_1 = SECRETS_DIR.resolve("mc_enc_public.tink").toFile() -private val PUBLIC_KEY_1 = PUBLIC_KEY_FILE_1.readByteString() -private val PUBLIC_KEY_FILE_2 = SECRETS_DIR.resolve("edp1_enc_public.tink").toFile() -private val PUBLIC_KEY_2 = PUBLIC_KEY_FILE_2.readByteString() -private val PUBLIC_KEY_FILE_3 = SECRETS_DIR.resolve("edp2_enc_public.tink").toFile() -private val PUBLIC_KEY_3 = PUBLIC_KEY_FILE_3.readByteString() -private val NON_EXISTENT_PUBLIC_KEY = "non existent public key".toByteStringUtf8() -private val MEASUREMENT_CONSUMER1 = "measurement_consumer1" -private val MEASUREMENT_CONSUMER2 = "measurement_consumer2" -private val NON_EXISTENT_MEASUREMENT_CONSUMER = "non_existent_measurement_consumer" -private val PLAIN_TEXT = "This is plain text".toByteStringUtf8() - -@RunWith(JUnit4::class) -class EncryptionKeyPairMapTest { - - private fun findKeyPair( - principal: String, - publicKey: ByteString, - keyPairMap: Map>>, - ) = - keyPairMap[principal]?.find { (key, _): Pair -> key == publicKey } - - @Test - fun `keyPairMap returns corresponding private keys`() { - val args = - arrayOf("--key-pair-dir=$SECRETS_DIR", "--key-pair-config-file=$ENCRYPTION_KEY_PAIR_MAP") - - runTest(args) { keyPairMap -> - val findPrivateKey = { principal: String, publicKey: ByteString -> - findKeyPair(principal, publicKey, keyPairMap)?.second - } - verifyKeyPair( - PUBLIC_KEY_1, - requireNotNull(findPrivateKey(MEASUREMENT_CONSUMER1, PUBLIC_KEY_1)), - ) - verifyKeyPair( - PUBLIC_KEY_2, - requireNotNull(findPrivateKey(MEASUREMENT_CONSUMER1, PUBLIC_KEY_2)), - ) - verifyKeyPair( - PUBLIC_KEY_3, - requireNotNull(findPrivateKey(MEASUREMENT_CONSUMER2, PUBLIC_KEY_3)), - ) - } - } - - @Test - fun `keyPairMap returns null when private key is not found`() { - val args = - arrayOf("--key-pair-dir=$SECRETS_DIR", "--key-pair-config-file=$ENCRYPTION_KEY_PAIR_MAP") - - runTest(args) { keyPairMap -> - assertThat(findKeyPair(MEASUREMENT_CONSUMER1, NON_EXISTENT_PUBLIC_KEY, keyPairMap)).isNull() - } - } - - @Test - fun `keyPairMap returns null when principal is not found`() { - val args = - arrayOf("--key-pair-dir=$SECRETS_DIR", "--key-pair-config-file=$ENCRYPTION_KEY_PAIR_MAP") - - runTest(args) { keyPairMap -> - assertThat(findKeyPair(NON_EXISTENT_MEASUREMENT_CONSUMER, PUBLIC_KEY_1, keyPairMap)).isNull() - } - } - - private class KeyPairMapWrapper( - val verifyBlock: (keyPairs: Map>>) -> Unit - ) : Runnable { - @Mixin lateinit var encryptionKeyPairMap: EncryptionKeyPairMap - - override fun run() { - verifyBlock(encryptionKeyPairMap.keyPairs) - } - } - - private fun runTest( - args: Array, - verifyBlock: (Map>>) -> Unit, - ) { - val returnCode = CommandLine(KeyPairMapWrapper(verifyBlock)).execute(*args) - assertThat(returnCode).isEqualTo(0) - } - - private fun verifyKeyPair(publicKeyData: ByteString, privateKeyHandle: PrivateKeyHandle) { - val publicKeyHandle = TinkPublicKeyHandle(publicKeyData) - val encryptedText = publicKeyHandle.hybridEncrypt(PLAIN_TEXT) - val decryptedText = privateKeyHandle.hybridDecrypt(encryptedText) - assertThat(decryptedText).isEqualTo(PLAIN_TEXT) - } -} diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/common/key_pair_map.textproto b/src/test/kotlin/org/wfanet/measurement/reporting/deploy/common/key_pair_map.textproto deleted file mode 100644 index 3a8e52ec9ce..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/common/key_pair_map.textproto +++ /dev/null @@ -1,26 +0,0 @@ -# proto-file: -# src/main/proto/wfa/measurement/config/reporting/encryption_key_pair_config.proto -# proto-message: EncryptionKeyPairConfig - -principal_key_pairs { - principal: "measurement_consumer1" - key_pairs { - public_key_file: "mc_enc_public.tink" - private_key_file: "mc_enc_private.tink" - } - key_pairs { - public_key_file: "edp1_enc_public.tink" - private_key_file: "edp1_enc_private.tink" - } -} -principal_key_pairs { - principal: "measurement_consumer2" - key_pairs { - public_key_file: "edp2_enc_public.tink" - private_key_file: "edp2_enc_private.tink" - } - key_pairs { - public_key_file: "edp3_enc_public.tink" - private_key_file: "edp3_enc_private.tink" - } -} diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/config/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/reporting/deploy/config/BUILD.bazel deleted file mode 100644 index 6a543aadc48..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/config/BUILD.bazel +++ /dev/null @@ -1,12 +0,0 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_test") - -kt_jvm_test( - name = "MeasurementSpecConfigValidatorTest", - srcs = ["MeasurementSpecConfigValidatorTest.kt"], - test_class = "org.wfanet.measurement.reporting.deploy.config.MeasurementSpecConfigValidatorTest", - deps = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/config:measurement_spec_config_validator", - "@wfa_common_jvm//imports/java/com/google/common/truth", - "@wfa_common_jvm//imports/kotlin/kotlin/test", - ], -) diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/config/MeasurementSpecConfigValidatorTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/deploy/config/MeasurementSpecConfigValidatorTest.kt deleted file mode 100644 index a13494f6e3a..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/config/MeasurementSpecConfigValidatorTest.kt +++ /dev/null @@ -1,295 +0,0 @@ -/* - * Copyright 2023 The Cross-Media Measurement Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.wfanet.measurement.reporting.deploy.config - -import com.google.common.truth.Truth.assertThat -import kotlin.test.assertFailsWith -import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 -import org.wfanet.measurement.config.reporting.MeasurementSpecConfigKt -import org.wfanet.measurement.config.reporting.copy -import org.wfanet.measurement.config.reporting.measurementSpecConfig - -@RunWith(JUnit4::class) -class MeasurementSpecConfigValidatorTest { - @Test - fun `validate throws no exception when config is valid`() { - MEASUREMENT_SPEC_CONFIG.validate() - } - - @Test - fun `validate throws ILLEGAL_STATE_EXCEPTION when epsilon is 0`() { - val invalidMeasurementSpecConfig = - MEASUREMENT_SPEC_CONFIG.copy { - reachSingleDataProvider = - MEASUREMENT_SPEC_CONFIG.reachSingleDataProvider.copy { - privacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.0 - delta = 0.0 - } - } - } - val exception = - assertFailsWith { invalidMeasurementSpecConfig.validate() } - assertThat(exception.message).contains("privacy_params") - } - - @Test - fun `validate throws ILLEGAL_STATE_EXCEPTION when delta is negaitve`() { - val invalidMeasurementSpecConfig = - MEASUREMENT_SPEC_CONFIG.copy { - reachSingleDataProvider = - MEASUREMENT_SPEC_CONFIG.reachSingleDataProvider.copy { - privacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 1.0 - delta = -1.0 - } - } - } - val exception = - assertFailsWith { invalidMeasurementSpecConfig.validate() } - assertThat(exception.message).contains("privacy_params") - } - - @Test - fun `validate throws ILLEGAL_STATE_EXCEPTION when fixed_start width is 0`() { - val invalidMeasurementSpecConfig = - MEASUREMENT_SPEC_CONFIG.copy { - reachSingleDataProvider = - MEASUREMENT_SPEC_CONFIG.reachSingleDataProvider.copy { - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - fixedStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.fixedStart { width = 0.0f } - } - } - } - val exception = - assertFailsWith { invalidMeasurementSpecConfig.validate() } - assertThat(exception.message).contains("vid_sampling_interval") - } - - @Test - fun `validate throws ILLEGAL_STATE_EXCEPTION when fixed_start start is negative`() { - val invalidMeasurementSpecConfig = - MEASUREMENT_SPEC_CONFIG.copy { - reachSingleDataProvider = - MEASUREMENT_SPEC_CONFIG.reachSingleDataProvider.copy { - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - fixedStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.fixedStart { start = -1.0f } - } - } - } - val exception = - assertFailsWith { invalidMeasurementSpecConfig.validate() } - assertThat(exception.message).contains("vid_sampling_interval") - } - - @Test - fun `validate throws ILLEGAL_STATE_EXCEPTION when fixed_start interval is more than 1`() { - val invalidMeasurementSpecConfig = - MEASUREMENT_SPEC_CONFIG.copy { - reachSingleDataProvider = - MEASUREMENT_SPEC_CONFIG.reachSingleDataProvider.copy { - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - fixedStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.fixedStart { width = 1.1f } - } - } - } - val exception = - assertFailsWith { invalidMeasurementSpecConfig.validate() } - assertThat(exception.message).contains("vid_sampling_interval") - } - - @Test - fun `validate throws ILLEGAL_STATE_EXCEPTION when random_start width is negative`() { - val invalidMeasurementSpecConfig = - MEASUREMENT_SPEC_CONFIG.copy { - reachSingleDataProvider = - MEASUREMENT_SPEC_CONFIG.reachSingleDataProvider.copy { - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - randomStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.randomStart { - width = -5 - numVidBuckets = 5 - } - } - } - } - val exception = - assertFailsWith { invalidMeasurementSpecConfig.validate() } - assertThat(exception.message).contains("vid_sampling_interval") - } - - @Test - fun `validate throws ILLEGAL_STATE_EXCEPTION when random_start num_vid_buckets is 0`() { - val invalidMeasurementSpecConfig = - MEASUREMENT_SPEC_CONFIG.copy { - reachSingleDataProvider = - MEASUREMENT_SPEC_CONFIG.reachSingleDataProvider.copy { - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - randomStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.randomStart { numVidBuckets = 0 } - } - } - } - val exception = - assertFailsWith { invalidMeasurementSpecConfig.validate() } - assertThat(exception.message).contains("vid_sampling_interval") - } - - @Test - fun `validate throws ILLEGAL_STATE_EXCEPTION when random_start width exceeds vid buckets`() { - val invalidMeasurementSpecConfig = - MEASUREMENT_SPEC_CONFIG.copy { - reachSingleDataProvider = - MEASUREMENT_SPEC_CONFIG.reachSingleDataProvider.copy { - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - randomStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.randomStart { - width = 6 - numVidBuckets = 5 - } - } - } - } - val exception = - assertFailsWith { invalidMeasurementSpecConfig.validate() } - assertThat(exception.message).contains("vid_sampling_interval") - } - - companion object { - private val MEASUREMENT_SPEC_CONFIG = measurementSpecConfig { - reachSingleDataProvider = - MeasurementSpecConfigKt.reachSingleDataProvider { - privacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.000207 - delta = 1e-15 - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - fixedStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.fixedStart { - start = 0f - width = 1f - } - } - } - reach = - MeasurementSpecConfigKt.reach { - privacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.0007444 - delta = 1e-15 - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - randomStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.randomStart { - width = 256 - numVidBuckets = 300 - } - } - } - reachAndFrequencySingleDataProvider = - MeasurementSpecConfigKt.reachAndFrequencySingleDataProvider { - reachPrivacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.000207 - delta = 1e-15 - } - frequencyPrivacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.004728 - delta = 1e-15 - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - fixedStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.fixedStart { - start = 0f - width = 1f - } - } - } - reachAndFrequency = - MeasurementSpecConfigKt.reachAndFrequency { - reachPrivacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.0007444 - delta = 1e-15 - } - frequencyPrivacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.014638 - delta = 1e-15 - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - randomStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.randomStart { - width = 256 - numVidBuckets = 300 - } - } - } - impression = - MeasurementSpecConfigKt.impression { - privacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.003592 - delta = 1e-15 - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - fixedStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.fixedStart { - start = 0f - width = 1f - } - } - } - duration = - MeasurementSpecConfigKt.duration { - privacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.007418 - delta = 1e-15 - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - fixedStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.fixedStart { - start = 0f - width = 1f - } - } - } - } - } -} diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/postgres/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/reporting/deploy/postgres/BUILD.bazel deleted file mode 100644 index ef43a61efeb..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/postgres/BUILD.bazel +++ /dev/null @@ -1,58 +0,0 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_test") - -kt_jvm_test( - name = "PostgresMeasurementsServiceTest", - srcs = ["PostgresMeasurementsServiceTest.kt"], - tags = [ - "cpu:2", - "no-remote-exec", - ], - test_class = "org.wfanet.measurement.reporting.deploy.postgres.PostgresMeasurementsServiceTest", - runtime_deps = [ - "@wfa_common_jvm//imports/java/org/yaml:snakeyaml", - ], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres:services", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/testing", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres/testing:database_provider", - ], -) - -kt_jvm_test( - name = "PostgresReportingSetsServiceTest", - srcs = ["PostgresReportingSetsServiceTest.kt"], - tags = [ - "cpu:2", - "no-remote-exec", - ], - test_class = "org.wfanet.measurement.reporting.deploy.postgres.PostgresReportingSetsServiceTest", - runtime_deps = [ - "@wfa_common_jvm//imports/java/org/yaml:snakeyaml", - ], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres:services", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/testing", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres/testing:database_provider", - ], -) - -kt_jvm_test( - name = "PostgresReportsServiceTest", - srcs = ["PostgresReportsServiceTest.kt"], - tags = [ - "cpu:2", - "no-remote-exec", - ], - test_class = "org.wfanet.measurement.reporting.deploy.postgres.PostgresReportsServiceTest", - runtime_deps = [ - "@wfa_common_jvm//imports/java/org/yaml:snakeyaml", - ], - deps = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres:services", - "//src/main/kotlin/org/wfanet/measurement/reporting/deploy/postgres/testing", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres/testing:database_provider", - ], -) diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresMeasurementsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresMeasurementsServiceTest.kt deleted file mode 100644 index 37ae12e69cb..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresMeasurementsServiceTest.kt +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres - -import org.junit.ClassRule -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 -import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresDatabaseClient -import org.wfanet.measurement.common.db.r2dbc.postgres.testing.PostgresDatabaseProviderRule -import org.wfanet.measurement.common.identity.IdGenerator -import org.wfanet.measurement.reporting.deploy.postgres.testing.Schemata -import org.wfanet.measurement.reporting.service.internal.testing.MeasurementsServiceTest - -@RunWith(JUnit4::class) -class PostgresMeasurementsServiceTest : MeasurementsServiceTest() { - - override fun newServices(idGenerator: IdGenerator): Services { - val client: PostgresDatabaseClient = databaseProvider.createDatabase() - return Services( - PostgresMeasurementsService(idGenerator, client), - PostgresReportsService(idGenerator, client), - ) - } - - companion object { - @get:ClassRule - @JvmStatic - val databaseProvider = PostgresDatabaseProviderRule(Schemata.REPORTING_CHANGELOG_PATH) - } -} diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresReportingSetsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresReportingSetsServiceTest.kt deleted file mode 100644 index 538b69a136d..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresReportingSetsServiceTest.kt +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres - -import org.junit.ClassRule -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 -import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresDatabaseClient -import org.wfanet.measurement.common.db.r2dbc.postgres.testing.PostgresDatabaseProviderRule -import org.wfanet.measurement.common.identity.IdGenerator -import org.wfanet.measurement.reporting.deploy.postgres.testing.Schemata -import org.wfanet.measurement.reporting.service.internal.testing.ReportingSetsServiceTest - -@RunWith(JUnit4::class) -class PostgresReportingSetsServiceTest : ReportingSetsServiceTest() { - override fun newService(idGenerator: IdGenerator): PostgresReportingSetsService { - val dbClient: PostgresDatabaseClient = databaseProvider.createDatabase() - return PostgresReportingSetsService(idGenerator, dbClient) - } - - companion object { - @get:ClassRule - @JvmStatic - val databaseProvider = PostgresDatabaseProviderRule(Schemata.REPORTING_CHANGELOG_PATH) - } -} diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresReportsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresReportsServiceTest.kt deleted file mode 100644 index d928a4cfcd9..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/postgres/PostgresReportsServiceTest.kt +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.deploy.postgres - -import org.junit.ClassRule -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 -import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresDatabaseClient -import org.wfanet.measurement.common.db.r2dbc.postgres.testing.PostgresDatabaseProviderRule -import org.wfanet.measurement.common.identity.IdGenerator -import org.wfanet.measurement.reporting.deploy.postgres.testing.Schemata -import org.wfanet.measurement.reporting.service.internal.testing.ReportsServiceTest - -@RunWith(JUnit4::class) -class PostgresReportsServiceTest : ReportsServiceTest() { - override fun newServices(idGenerator: IdGenerator): Services { - val client: PostgresDatabaseClient = databaseProvider.createDatabase() - return Services( - PostgresReportsService(idGenerator, client), - PostgresMeasurementsService(idGenerator, client), - PostgresReportingSetsService(idGenerator, client), - ) - } - - companion object { - @get:ClassRule - @JvmStatic - val databaseProvider = PostgresDatabaseProviderRule(Schemata.REPORTING_CHANGELOG_PATH) - } -} diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/BUILD.bazel index e601ec97363..bfad3ea9c46 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/BUILD.bazel +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/BUILD.bazel @@ -10,7 +10,7 @@ kt_jvm_test( "//src/main/proto/wfa/measurement/api/v2alpha:event_group_metadata_descriptors_service_kt_jvm_grpc_proto", "//src/main/proto/wfa/measurement/api/v2alpha/event_group_metadata/testing:test_metadata_message_kt_jvm_proto", "//src/main/proto/wfa/measurement/api/v2alpha/event_group_metadata/testing:test_parent_metadata_message_kt_jvm_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:event_group_kt_jvm_proto", + "//src/main/proto/wfa/measurement/reporting/v2alpha:event_group_kt_jvm_proto", "@wfa_common_jvm//imports/java/com/google/common/truth", "@wfa_common_jvm//imports/kotlin/kotlin/test", "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/CelEnvProviderTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/CelEnvProviderTest.kt index 3d76581af90..fa6a2397805 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/CelEnvProviderTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/CelEnvProviderTest.kt @@ -53,9 +53,9 @@ import org.wfanet.measurement.common.ProtoReflection import org.wfanet.measurement.common.grpc.testing.GrpcTestServerRule import org.wfanet.measurement.common.grpc.testing.mockService import org.wfanet.measurement.common.testing.verifyProtoArgument -import org.wfanet.measurement.reporting.v1alpha.EventGroup -import org.wfanet.measurement.reporting.v1alpha.EventGroupKt -import org.wfanet.measurement.reporting.v1alpha.eventGroup +import org.wfanet.measurement.reporting.v2alpha.EventGroup +import org.wfanet.measurement.reporting.v2alpha.EventGroupKt +import org.wfanet.measurement.reporting.v2alpha.eventGroup private const val METADATA_FIELD = "metadata.metadata" private const val MAX_PAGE_SIZE = 1000 diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/BUILD.bazel deleted file mode 100644 index b7e72ce8ea5..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/BUILD.bazel +++ /dev/null @@ -1,108 +0,0 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_test") - -kt_jvm_test( - name = "ReportingSetsServiceTest", - srcs = ["ReportingSetsServiceTest.kt"], - associates = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:reporting_sets_service", - ], - test_class = "org.wfanet.measurement.reporting.service.api.v1alpha.ReportingSetsServiceTest", - deps = [ - "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:principal_server_interceptor", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:principal_server_interceptor", - "//src/main/proto/wfa/measurement/config/reporting:measurement_consumer_config_kt_jvm_proto", - "@wfa_common_jvm//imports/java/com/google/common/truth", - "@wfa_common_jvm//imports/java/com/google/common/truth/extensions/proto", - "@wfa_common_jvm//imports/java/com/google/protobuf", - "@wfa_common_jvm//imports/kotlin/kotlin/test", - "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", - "@wfa_common_jvm//imports/kotlin/org/mockito/kotlin", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc/testing", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/testing", - ], -) - -kt_jvm_test( - name = "SetOperationCompilerTest", - srcs = ["SetOperationCompilerTest.kt"], - associates = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:set_operation_compiler", - ], - test_class = "org.wfanet.measurement.reporting.service.api.v1alpha.SetOperationCompilerTest", - deps = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:resource_key", - "@wfa_common_jvm//imports/java/com/google/common/truth", - "@wfa_common_jvm//imports/java/com/google/common/truth/extensions/proto", - "@wfa_common_jvm//imports/java/com/google/protobuf", - "@wfa_common_jvm//imports/kotlin/kotlin/test", - "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/identity", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/testing", - ], -) - -kt_jvm_test( - name = "EventGroupsServiceTest", - srcs = ["EventGroupsServiceTest.kt"], - data = [ - "//src/main/k8s/testing/secretfiles:all_der_files", - "//src/main/k8s/testing/secretfiles:all_tink_keysets", - ], - test_class = "org.wfanet.measurement.reporting.service.api.v1alpha.EventGroupsServiceTest", - deps = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api:cel_env_provider", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:event_groups_service", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:principal_server_interceptor", - "//src/main/proto/wfa/measurement/api/v2alpha:event_group_kt_jvm_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:event_group_metadata_descriptors_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:event_groups_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/api/v2alpha/event_group_metadata/testing:test_metadata_message_kt_jvm_proto", - "//src/main/proto/wfa/measurement/api/v2alpha/event_group_metadata/testing:test_parent_metadata_message_kt_jvm_proto", - "//src/main/proto/wfa/measurement/config/reporting:measurement_consumer_config_kt_jvm_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:event_groups_service_kt_jvm_grpc_proto", - "@wfa_common_jvm//imports/kotlin/kotlin/test", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto/testing", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto/tink", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc/testing", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/testing", - "@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/dataprovider", - ], -) - -kt_jvm_test( - name = "ReportsServiceTest", - srcs = ["ReportsServiceTest.kt"], - associates = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:reports_service", - ], - data = [ - "//src/main/k8s/testing/secretfiles:root_certs", - "//src/main/k8s/testing/secretfiles:secret_files", - ], - test_class = "org.wfanet.measurement.reporting.service.api.v1alpha.ReportsServiceTest", - deps = [ - "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:principal_server_interceptor", - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha:principal_server_interceptor", - "//src/main/proto/wfa/measurement/api/v2alpha:certificate_kt_jvm_proto", - "//src/main/proto/wfa/measurement/api/v2alpha:certificates_service_kt_jvm_grpc_proto", - "//src/main/proto/wfa/measurement/config/reporting:measurement_consumer_config_kt_jvm_proto", - "//src/main/proto/wfa/measurement/config/reporting:measurement_spec_config_kt_jvm_proto", - "//src/main/proto/wfa/measurement/reporting/v1alpha:reporting_set_kt_jvm_proto", - "@wfa_common_jvm//imports/java/com/google/common/truth", - "@wfa_common_jvm//imports/java/com/google/common/truth/extensions/proto", - "@wfa_common_jvm//imports/java/com/google/protobuf", - "@wfa_common_jvm//imports/kotlin/kotlin/test", - "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", - "@wfa_common_jvm//imports/kotlin/org/mockito/kotlin", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto/testing", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto/tink", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc/testing", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/testing", - "@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/common:key_handles", - "@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/dataprovider", - "@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/duchy", - ], -) diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupsServiceTest.kt deleted file mode 100644 index 17238a9d814..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupsServiceTest.kt +++ /dev/null @@ -1,489 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.service.api.v1alpha - -import com.google.common.truth.Truth.assertThat -import com.google.common.truth.extensions.proto.ProtoTruth.assertThat -import com.google.protobuf.ByteString -import io.grpc.Status -import io.grpc.StatusRuntimeException -import java.nio.file.Path -import java.nio.file.Paths -import java.time.Duration -import kotlin.test.assertFailsWith -import kotlinx.coroutines.runBlocking -import org.junit.Before -import org.junit.Rule -import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 -import org.mockito.kotlin.any -import org.mockito.kotlin.stub -import org.mockito.kotlin.whenever -import org.wfanet.measurement.api.v2alpha.EventGroupKt as CmmsEventGroup -import org.wfanet.measurement.api.v2alpha.EventGroupMetadataDescriptorsGrpcKt.EventGroupMetadataDescriptorsCoroutineImplBase -import org.wfanet.measurement.api.v2alpha.EventGroupMetadataDescriptorsGrpcKt.EventGroupMetadataDescriptorsCoroutineStub -import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineImplBase -import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub -import org.wfanet.measurement.api.v2alpha.ListEventGroupsRequest as CmmsListEventGroupsRequest -import org.wfanet.measurement.api.v2alpha.ListEventGroupsRequestKt -import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey -import org.wfanet.measurement.api.v2alpha.copy -import org.wfanet.measurement.api.v2alpha.encryptionPublicKey -import org.wfanet.measurement.api.v2alpha.eventGroup as cmmsEventGroup -import org.wfanet.measurement.api.v2alpha.eventGroupMetadataDescriptor -import org.wfanet.measurement.api.v2alpha.event_group_metadata.testing.testMetadataMessage -import org.wfanet.measurement.api.v2alpha.event_group_metadata.testing.testParentMetadataMessage -import org.wfanet.measurement.api.v2alpha.listEventGroupMetadataDescriptorsResponse -import org.wfanet.measurement.api.v2alpha.listEventGroupsRequest as cmmsListEventGroupsRequest -import org.wfanet.measurement.api.v2alpha.listEventGroupsResponse as cmmsListEventGroupsResponse -import org.wfanet.measurement.common.ProtoReflection -import org.wfanet.measurement.common.crypto.tink.TinkPrivateKeyHandle -import org.wfanet.measurement.common.crypto.tink.TinkPublicKeyHandle -import org.wfanet.measurement.common.crypto.tink.loadPrivateKey -import org.wfanet.measurement.common.crypto.tink.loadPublicKey -import org.wfanet.measurement.common.getRuntimePath -import org.wfanet.measurement.common.grpc.testing.GrpcTestServerRule -import org.wfanet.measurement.common.grpc.testing.mockService -import org.wfanet.measurement.common.pack -import org.wfanet.measurement.common.testing.verifyProtoArgument -import org.wfanet.measurement.config.reporting.measurementConsumerConfig -import org.wfanet.measurement.consent.client.common.toEncryptionPublicKey -import org.wfanet.measurement.consent.client.dataprovider.encryptMetadata -import org.wfanet.measurement.reporting.service.api.CelEnvCacheProvider -import org.wfanet.measurement.reporting.service.api.InMemoryEncryptionKeyPairStore -import org.wfanet.measurement.reporting.v1alpha.EventGroup -import org.wfanet.measurement.reporting.v1alpha.EventGroupKt.metadata -import org.wfanet.measurement.reporting.v1alpha.eventGroup -import org.wfanet.measurement.reporting.v1alpha.listEventGroupsRequest -import org.wfanet.measurement.reporting.v1alpha.listEventGroupsResponse - -private const val DEFAULT_PAGE_SIZE = 50 - -private const val API_AUTHENTICATION_KEY = "nR5QPN7ptx" -private val CONFIG = measurementConsumerConfig { apiKey = API_AUTHENTICATION_KEY } -private val SECRET_FILES_PATH: Path = - checkNotNull( - getRuntimePath( - Paths.get("wfa_measurement_system", "src", "main", "k8s", "testing", "secretfiles") - ) - ) -private val ENCRYPTION_PRIVATE_KEY_HANDLE = loadEncryptionPrivateKey("mc_enc_private.tink") -private val ENCRYPTION_PUBLIC_KEY = - loadEncryptionPublicKey("mc_enc_public.tink").toEncryptionPublicKey() -private const val MEASUREMENT_CONSUMER_REFERENCE_ID = "measurementConsumerRefId" -private val MEASUREMENT_CONSUMER_NAME = - MeasurementConsumerKey(MEASUREMENT_CONSUMER_REFERENCE_ID).toName() -private val ENCRYPTION_KEY_PAIR_STORE = - InMemoryEncryptionKeyPairStore( - mapOf( - MEASUREMENT_CONSUMER_NAME to - listOf(ENCRYPTION_PUBLIC_KEY.data to ENCRYPTION_PRIVATE_KEY_HANDLE) - ) - ) -private val TEST_MESSAGE = testMetadataMessage { publisherId = 15 } -private const val CMMS_EVENT_GROUP_ID = "AAAAAAAAAHs" -private val CMMS_EVENT_GROUP = cmmsEventGroup { - name = "$DATA_PROVIDER_NAME/eventGroups/$CMMS_EVENT_GROUP_ID" - measurementConsumer = MEASUREMENT_CONSUMER_NAME - eventGroupReferenceId = EVENT_GROUP_REFERENCE_ID - measurementConsumerPublicKey = ENCRYPTION_PUBLIC_KEY.pack() - encryptedMetadata = - encryptMetadata( - CmmsEventGroup.metadata { - eventGroupMetadataDescriptor = METADATA_NAME - metadata = TEST_MESSAGE.pack() - }, - ENCRYPTION_PUBLIC_KEY, - ) -} -private val TEST_MESSAGE_2 = testMetadataMessage { publisherId = 5 } -private const val CMMS_EVENT_GROUP_ID_2 = "AAAAAAAAAGs" -private val CMMS_EVENT_GROUP_2 = - CMMS_EVENT_GROUP.copy { - name = "$DATA_PROVIDER_NAME/eventGroups/$CMMS_EVENT_GROUP_ID_2" - eventGroupReferenceId = "id2" - encryptedMetadata = - encryptMetadata( - CmmsEventGroup.metadata { - eventGroupMetadataDescriptor = METADATA_NAME - metadata = TEST_MESSAGE_2.pack() - }, - ENCRYPTION_PUBLIC_KEY, - ) - } -private val EVENT_GROUP = eventGroup { - name = - EventGroupKey( - MEASUREMENT_CONSUMER_REFERENCE_ID, - DATA_PROVIDER_REFERENCE_ID, - CMMS_EVENT_GROUP_ID, - ) - .toName() - dataProvider = DATA_PROVIDER_NAME - eventGroupReferenceId = EVENT_GROUP_REFERENCE_ID - metadata = metadata { - eventGroupMetadataDescriptor = METADATA_NAME - metadata = TEST_MESSAGE.pack() - } -} -private const val PAGE_TOKEN = "base64encodedtoken" -private const val NEXT_PAGE_TOKEN = "base64encodedtoken2" -private const val DATA_PROVIDER_REFERENCE_ID = "123" -private const val DATA_PROVIDER_NAME = "dataProviders/$DATA_PROVIDER_REFERENCE_ID" -private const val EVENT_GROUP_REFERENCE_ID = "edpRefId1" -private const val EVENT_GROUP_PARENT = - "measurementConsumers/$MEASUREMENT_CONSUMER_REFERENCE_ID/dataProviders/$DATA_PROVIDER_REFERENCE_ID" -private const val METADATA_NAME = "$DATA_PROVIDER_NAME/eventGroupMetadataDescriptors/abc" -private val EVENT_GROUP_METADATA_DESCRIPTOR = eventGroupMetadataDescriptor { - name = METADATA_NAME - descriptorSet = ProtoReflection.buildFileDescriptorSet(TEST_MESSAGE.descriptorForType) -} - -@RunWith(JUnit4::class) -class EventGroupsServiceTest { - private val cmmsEventGroupsServiceMock: EventGroupsCoroutineImplBase = mockService { - onBlocking { listEventGroups(any()) } - .thenReturn( - cmmsListEventGroupsResponse { - eventGroups += listOf(CMMS_EVENT_GROUP, CMMS_EVENT_GROUP_2) - nextPageToken = NEXT_PAGE_TOKEN - } - ) - } - private val cmmsEventGroupMetadataDescriptorsServiceMock: - EventGroupMetadataDescriptorsCoroutineImplBase = - mockService { - onBlocking { listEventGroupMetadataDescriptors(any()) } - .thenReturn( - listEventGroupMetadataDescriptorsResponse { - eventGroupMetadataDescriptors += EVENT_GROUP_METADATA_DESCRIPTOR - } - ) - } - - @get:Rule - val grpcTestServerRule = GrpcTestServerRule { - addService(cmmsEventGroupsServiceMock) - addService(cmmsEventGroupMetadataDescriptorsServiceMock) - } - - private lateinit var service: EventGroupsService - - @Before - fun initService() { - val celEnvCacheProvider = - CelEnvCacheProvider( - EventGroupMetadataDescriptorsCoroutineStub(grpcTestServerRule.channel), - EventGroup.getDescriptor(), - Duration.ofSeconds(5), - emptyList(), - ) - - service = - EventGroupsService( - EventGroupsCoroutineStub(grpcTestServerRule.channel), - ENCRYPTION_KEY_PAIR_STORE, - celEnvCacheProvider, - ) - } - - @Test - fun `listEventGroups returns list with no filter`() { - cmmsEventGroupsServiceMock.stub { - onBlocking { listEventGroups(any()) } - .thenReturn( - cmmsListEventGroupsResponse { - eventGroups += - listOf( - CMMS_EVENT_GROUP, - // When there's no filter applied to metadata, it doesn't need to be set on all EGs. - CMMS_EVENT_GROUP_2.copy { clearEncryptedMetadata() }, - ) - nextPageToken = NEXT_PAGE_TOKEN - } - ) - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { - service.listEventGroups( - listEventGroupsRequest { - parent = EVENT_GROUP_PARENT - pageSize = 10 - pageToken = PAGE_TOKEN - } - ) - } - } - - assertThat(result) - .isEqualTo( - listEventGroupsResponse { - eventGroups += - listOf( - EVENT_GROUP, - eventGroup { - name = - EventGroupKey( - MeasurementConsumerKey.fromName(CMMS_EVENT_GROUP_2.measurementConsumer)!! - .measurementConsumerId, - DATA_PROVIDER_REFERENCE_ID, - CMMS_EVENT_GROUP_ID_2, - ) - .toName() - dataProvider = DATA_PROVIDER_NAME - eventGroupReferenceId = "id2" - }, - ) - nextPageToken = NEXT_PAGE_TOKEN - } - ) - - val expectedCmmsEventGroupsRequest = cmmsListEventGroupsRequest { - parent = MEASUREMENT_CONSUMER_NAME - pageSize = 10 - pageToken = PAGE_TOKEN - filter = ListEventGroupsRequestKt.filter { dataProviders += DATA_PROVIDER_NAME } - } - - verifyProtoArgument(cmmsEventGroupsServiceMock, EventGroupsCoroutineImplBase::listEventGroups) - .isEqualTo(expectedCmmsEventGroupsRequest) - } - - @Test - fun `listEventGroups returns list with filter`() { - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { - service.listEventGroups( - listEventGroupsRequest { - parent = EVENT_GROUP_PARENT - filter = "metadata.metadata.publisher_id > 10" - pageToken = PAGE_TOKEN - } - ) - } - } - - assertThat(result) - .isEqualTo( - listEventGroupsResponse { - eventGroups += EVENT_GROUP - nextPageToken = NEXT_PAGE_TOKEN - } - ) - - val expectedCmmsEventGroupsRequest = cmmsListEventGroupsRequest { - parent = MEASUREMENT_CONSUMER_NAME - pageSize = DEFAULT_PAGE_SIZE - pageToken = PAGE_TOKEN - filter = ListEventGroupsRequestKt.filter { dataProviders += DATA_PROVIDER_NAME } - } - - verifyProtoArgument(cmmsEventGroupsServiceMock, EventGroupsCoroutineImplBase::listEventGroups) - .isEqualTo(expectedCmmsEventGroupsRequest) - } - - @Test - fun `listEventGroups omits DataProvider filter in CMMS request when ID is wildcard`() { - val request = listEventGroupsRequest { - parent = "measurementConsumers/$MEASUREMENT_CONSUMER_REFERENCE_ID/dataProviders/-" - } - - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { service.listEventGroups(request) } - } - - verifyProtoArgument(cmmsEventGroupsServiceMock, EventGroupsCoroutineImplBase::listEventGroups) - .isEqualTo( - cmmsListEventGroupsRequest { - parent = MEASUREMENT_CONSUMER_NAME - pageSize = DEFAULT_PAGE_SIZE - filter = CmmsListEventGroupsRequest.Filter.getDefaultInstance() - } - ) - } - - @Test - fun `listEventGroups returns list with filter when event group with metadata and one without`() { - runBlocking { - whenever(cmmsEventGroupsServiceMock.listEventGroups(any())) - .thenReturn( - cmmsListEventGroupsResponse { - eventGroups += - listOf(CMMS_EVENT_GROUP, CMMS_EVENT_GROUP_2.copy { clearEncryptedMetadata() }) - } - ) - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { - service.listEventGroups( - listEventGroupsRequest { - parent = EVENT_GROUP_PARENT - filter = "metadata.metadata.publisher_id > 10" - pageToken = PAGE_TOKEN - } - ) - } - } - - assertThat(result).isEqualTo(listEventGroupsResponse { eventGroups += EVENT_GROUP }) - - val expectedCmmsEventGroupsRequest = cmmsListEventGroupsRequest { - parent = MEASUREMENT_CONSUMER_NAME - pageSize = DEFAULT_PAGE_SIZE - pageToken = PAGE_TOKEN - filter = ListEventGroupsRequestKt.filter { dataProviders += DATA_PROVIDER_NAME } - } - - verifyProtoArgument(cmmsEventGroupsServiceMock, EventGroupsCoroutineImplBase::listEventGroups) - .isEqualTo(expectedCmmsEventGroupsRequest) - } - - @Test - fun `listEventGroups throws FAILED_PRECONDITION if message descriptor not found`() { - val eventGroupInvalidMetadata = cmmsEventGroup { - name = "$DATA_PROVIDER_NAME/eventGroups/$CMMS_EVENT_GROUP_ID" - measurementConsumer = MEASUREMENT_CONSUMER_NAME - eventGroupReferenceId = "id1" - measurementConsumerPublicKey = ENCRYPTION_PUBLIC_KEY.pack() - encryptedMetadata = - encryptMetadata( - CmmsEventGroup.metadata { - eventGroupMetadataDescriptor = METADATA_NAME - metadata = testParentMetadataMessage { name = "name" }.pack() - }, - ENCRYPTION_PUBLIC_KEY, - ) - } - cmmsEventGroupsServiceMock.stub { - onBlocking { listEventGroups(any()) } - .thenReturn( - cmmsListEventGroupsResponse { - eventGroups += listOf(CMMS_EVENT_GROUP, CMMS_EVENT_GROUP_2, eventGroupInvalidMetadata) - } - ) - } - - val result = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { - service.listEventGroups( - listEventGroupsRequest { - parent = EVENT_GROUP_PARENT - filter = "metadata.metadata.publisher_id > 10" - } - ) - } - } - } - - assertThat(result.status.code).isEqualTo(Status.Code.FAILED_PRECONDITION) - } - - @Test - fun `listEventGroups throws FAILED_PRECONDITION if private key not found`() { - val eventGroupInvalidPublicKey = cmmsEventGroup { - name = "$DATA_PROVIDER_NAME/eventGroups/$CMMS_EVENT_GROUP_ID" - measurementConsumer = MEASUREMENT_CONSUMER_NAME - eventGroupReferenceId = "id1" - measurementConsumerPublicKey = - encryptionPublicKey { data = ByteString.copyFromUtf8("consumerkey") }.pack() - encryptedMetadata = - encryptMetadata( - CmmsEventGroup.metadata { - eventGroupMetadataDescriptor = METADATA_NAME - metadata = testParentMetadataMessage { name = "name" }.pack() - }, - ENCRYPTION_PUBLIC_KEY, - ) - } - cmmsEventGroupsServiceMock.stub { - onBlocking { listEventGroups(any()) } - .thenReturn( - cmmsListEventGroupsResponse { - eventGroups += listOf(CMMS_EVENT_GROUP, CMMS_EVENT_GROUP_2, eventGroupInvalidPublicKey) - } - ) - } - - val result = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { - service.listEventGroups( - listEventGroupsRequest { - parent = EVENT_GROUP_PARENT - filter = "metadata.metadata.publisher_id > 10" - } - ) - } - } - } - - assertThat(result.status.code).isEqualTo(Status.Code.FAILED_PRECONDITION) - } - - @Test - fun `listEventGroups throws INVALID_ARGUMENT if parent not specified`() { - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { - service.listEventGroups( - listEventGroupsRequest { - filter = "metadata.metadata.publisher_id > 10" - pageToken = PAGE_TOKEN - ENCRYPTION_KEY_PAIR_STORE - } - ) - } - } - } - - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - assertThat(exception).hasMessageThat().ignoringCase().contains("parent") - } - - @Test - fun `listEventGroups throws UNAUTHENTICATED if principal not found`() { - val result = - assertFailsWith { - runBlocking { - service.listEventGroups( - listEventGroupsRequest { - parent = EVENT_GROUP_PARENT - filter = "metadata.metadata.publisher_id > 10" - } - ) - } - } - - assertThat(result.status.code).isEqualTo(Status.Code.UNAUTHENTICATED) - } -} - -private fun loadEncryptionPrivateKey(fileName: String): TinkPrivateKeyHandle { - return loadPrivateKey(SECRET_FILES_PATH.resolve(fileName).toFile()) -} - -private fun loadEncryptionPublicKey(fileName: String): TinkPublicKeyHandle { - return loadPublicKey(SECRET_FILES_PATH.resolve(fileName).toFile()) -} diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportingSetsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportingSetsServiceTest.kt deleted file mode 100644 index 20b13d92928..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportingSetsServiceTest.kt +++ /dev/null @@ -1,750 +0,0 @@ -/* - * Copyright 2022 The Cross-Media Measurement Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.wfanet.measurement.reporting.service.api.v1alpha - -import com.google.common.truth.Truth.assertThat -import com.google.common.truth.extensions.proto.ProtoTruth.assertThat -import io.grpc.Status -import io.grpc.StatusRuntimeException -import kotlin.test.assertFailsWith -import kotlinx.coroutines.flow.flowOf -import kotlinx.coroutines.runBlocking -import org.junit.Before -import org.junit.Rule -import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 -import org.mockito.kotlin.any -import org.wfanet.measurement.api.v2alpha.DataProviderKey -import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey -import org.wfanet.measurement.api.v2alpha.withDataProviderPrincipal -import org.wfanet.measurement.common.base64UrlEncode -import org.wfanet.measurement.common.grpc.testing.GrpcTestServerRule -import org.wfanet.measurement.common.grpc.testing.mockService -import org.wfanet.measurement.common.identity.externalIdToApiId -import org.wfanet.measurement.common.testing.verifyProtoArgument -import org.wfanet.measurement.config.reporting.measurementConsumerConfig -import org.wfanet.measurement.internal.reporting.ReportingSet as InternalReportingSet -import org.wfanet.measurement.internal.reporting.ReportingSetKt.eventGroupKey -import org.wfanet.measurement.internal.reporting.ReportingSetsGrpcKt.ReportingSetsCoroutineImplBase -import org.wfanet.measurement.internal.reporting.ReportingSetsGrpcKt.ReportingSetsCoroutineStub -import org.wfanet.measurement.internal.reporting.StreamReportingSetsRequestKt -import org.wfanet.measurement.internal.reporting.copy -import org.wfanet.measurement.internal.reporting.reportingSet as internalReportingSet -import org.wfanet.measurement.internal.reporting.streamReportingSetsRequest -import org.wfanet.measurement.reporting.v1alpha.ListReportingSetsPageTokenKt.previousPageEnd -import org.wfanet.measurement.reporting.v1alpha.ListReportingSetsRequest -import org.wfanet.measurement.reporting.v1alpha.ReportingSet -import org.wfanet.measurement.reporting.v1alpha.copy -import org.wfanet.measurement.reporting.v1alpha.createReportingSetRequest -import org.wfanet.measurement.reporting.v1alpha.listReportingSetsPageToken -import org.wfanet.measurement.reporting.v1alpha.listReportingSetsRequest -import org.wfanet.measurement.reporting.v1alpha.listReportingSetsResponse -import org.wfanet.measurement.reporting.v1alpha.reportingSet - -private const val DEFAULT_PAGE_SIZE = 50 -private const val MAX_PAGE_SIZE = 1000 -private const val PAGE_SIZE = 2 - -private const val API_AUTHENTICATION_KEY = "nR5QPN7ptx" -private val CONFIG = measurementConsumerConfig { apiKey = API_AUTHENTICATION_KEY } - -// Measurement consumer IDs and names -private const val MEASUREMENT_CONSUMER_EXTERNAL_ID = 111L -private const val MEASUREMENT_CONSUMER_EXTERNAL_ID_2 = 112L -private val MEASUREMENT_CONSUMER_REFERENCE_ID = externalIdToApiId(MEASUREMENT_CONSUMER_EXTERNAL_ID) -private val MEASUREMENT_CONSUMER_REFERENCE_ID_2 = - externalIdToApiId(MEASUREMENT_CONSUMER_EXTERNAL_ID_2) -private val MEASUREMENT_CONSUMER_NAME = - MeasurementConsumerKey(MEASUREMENT_CONSUMER_REFERENCE_ID).toName() -private val MEASUREMENT_CONSUMER_NAME_2 = - MeasurementConsumerKey(MEASUREMENT_CONSUMER_REFERENCE_ID_2).toName() - -// Data provider IDs and names -private const val DATA_PROVIDER_EXTERNAL_ID = 221L -private const val DATA_PROVIDER_EXTERNAL_ID_2 = 222L -private const val DATA_PROVIDER_EXTERNAL_ID_3 = 223L -private val DATA_PROVIDER_REFERENCE_ID = externalIdToApiId(DATA_PROVIDER_EXTERNAL_ID) -private val DATA_PROVIDER_REFERENCE_ID_2 = externalIdToApiId(DATA_PROVIDER_EXTERNAL_ID_2) -private val DATA_PROVIDER_REFERENCE_ID_3 = externalIdToApiId(DATA_PROVIDER_EXTERNAL_ID_3) - -private val DATA_PROVIDER_NAME = DataProviderKey(DATA_PROVIDER_REFERENCE_ID).toName() -private val DATA_PROVIDER_NAME_2 = DataProviderKey(DATA_PROVIDER_REFERENCE_ID_2).toName() -private val DATA_PROVIDER_NAME_3 = DataProviderKey(DATA_PROVIDER_REFERENCE_ID_3).toName() - -// Reporting set IDs and names -private val REPORTING_SET_EXTERNAL_ID = 331L -private val REPORTING_SET_EXTERNAL_ID_2 = 332L -private val REPORTING_SET_EXTERNAL_ID_3 = 333L - -private val REPORTING_SET_NAME = - ReportingSetKey(MEASUREMENT_CONSUMER_REFERENCE_ID, externalIdToApiId(REPORTING_SET_EXTERNAL_ID)) - .toName() -private val REPORTING_SET_NAME_2 = - ReportingSetKey(MEASUREMENT_CONSUMER_REFERENCE_ID, externalIdToApiId(REPORTING_SET_EXTERNAL_ID_2)) - .toName() -private val REPORTING_SET_NAME_3 = - ReportingSetKey(MEASUREMENT_CONSUMER_REFERENCE_ID, externalIdToApiId(REPORTING_SET_EXTERNAL_ID_3)) - .toName() - -// Event group IDs and names -private val EVENT_GROUP_EXTERNAL_ID = 441L -private val EVENT_GROUP_EXTERNAL_ID_2 = 442L -private val EVENT_GROUP_EXTERNAL_ID_3 = 443L -private val EVENT_GROUP_REFERENCE_ID = externalIdToApiId(EVENT_GROUP_EXTERNAL_ID) -private val EVENT_GROUP_REFERENCE_ID_2 = externalIdToApiId(EVENT_GROUP_EXTERNAL_ID_2) -private val EVENT_GROUP_REFERENCE_ID_3 = externalIdToApiId(EVENT_GROUP_EXTERNAL_ID_3) - -private val EVENT_GROUP_NAME = - EventGroupKey( - MEASUREMENT_CONSUMER_REFERENCE_ID, - DATA_PROVIDER_REFERENCE_ID, - EVENT_GROUP_REFERENCE_ID, - ) - .toName() -private val EVENT_GROUP_NAME_2 = - EventGroupKey( - MEASUREMENT_CONSUMER_REFERENCE_ID, - DATA_PROVIDER_REFERENCE_ID_2, - EVENT_GROUP_REFERENCE_ID_2, - ) - .toName() -private val EVENT_GROUP_NAME_3 = - EventGroupKey( - MEASUREMENT_CONSUMER_REFERENCE_ID, - DATA_PROVIDER_REFERENCE_ID_3, - EVENT_GROUP_REFERENCE_ID_3, - ) - .toName() -private val EVENT_GROUP_NAMES = listOf(EVENT_GROUP_NAME, EVENT_GROUP_NAME_2, EVENT_GROUP_NAME_3) - -// Event group keys -private val EVENT_GROUP_KEY = eventGroupKey { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - dataProviderReferenceId = DATA_PROVIDER_REFERENCE_ID - eventGroupReferenceId = EVENT_GROUP_REFERENCE_ID -} -private val EVENT_GROUP_KEY_2 = eventGroupKey { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - dataProviderReferenceId = DATA_PROVIDER_REFERENCE_ID_2 - eventGroupReferenceId = EVENT_GROUP_REFERENCE_ID_2 -} -private val EVENT_GROUP_KEY_3 = eventGroupKey { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - dataProviderReferenceId = DATA_PROVIDER_REFERENCE_ID_3 - eventGroupReferenceId = EVENT_GROUP_REFERENCE_ID_3 -} -private val EVENT_GROUP_KEYS = listOf(EVENT_GROUP_KEY, EVENT_GROUP_KEY_2, EVENT_GROUP_KEY_3) - -// Event filters -private const val FILTER = "AGE>20" - -// Reporting sets -private val DISPLAY_NAME = REPORTING_SET_NAME + FILTER -private val DISPLAY_NAME_2 = REPORTING_SET_NAME_2 + FILTER -private val DISPLAY_NAME_3 = REPORTING_SET_NAME_3 + FILTER - -private val REPORTING_SET: ReportingSet = reportingSet { - name = REPORTING_SET_NAME - eventGroups.addAll(EVENT_GROUP_NAMES) - filter = FILTER - displayName = DISPLAY_NAME -} - -// Internal reporting sets -private val INTERNAL_REPORTING_SET: InternalReportingSet = internalReportingSet { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - externalReportingSetId = REPORTING_SET_EXTERNAL_ID - eventGroupKeys.addAll(EVENT_GROUP_KEYS) - filter = FILTER - displayName = DISPLAY_NAME -} - -@RunWith(JUnit4::class) -class ReportingSetsServiceTest { - - private val internalReportingSetsMock: ReportingSetsCoroutineImplBase = - mockService() { - onBlocking { createReportingSet(any()) }.thenReturn(INTERNAL_REPORTING_SET) - onBlocking { streamReportingSets(any()) } - .thenReturn( - flowOf( - INTERNAL_REPORTING_SET, - INTERNAL_REPORTING_SET.copy { - externalReportingSetId = REPORTING_SET_EXTERNAL_ID_2 - displayName = DISPLAY_NAME_2 - }, - INTERNAL_REPORTING_SET.copy { - externalReportingSetId = REPORTING_SET_EXTERNAL_ID_3 - displayName = DISPLAY_NAME_3 - }, - ) - ) - } - - @get:Rule val grpcTestServerRule = GrpcTestServerRule { addService(internalReportingSetsMock) } - - private lateinit var service: ReportingSetsService - - @Before - fun initService() { - service = ReportingSetsService(ReportingSetsCoroutineStub(grpcTestServerRule.channel)) - } - - @Test - fun `createReportingSet returns reporting set`() { - val request = createReportingSetRequest { - parent = MEASUREMENT_CONSUMER_NAME - reportingSet = REPORTING_SET - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { service.createReportingSet(request) } - } - - val expected = REPORTING_SET - - verifyProtoArgument( - internalReportingSetsMock, - ReportingSetsCoroutineImplBase::createReportingSet, - ) - .isEqualTo(INTERNAL_REPORTING_SET.copy { clearExternalReportingSetId() }) - - assertThat(result).isEqualTo(expected) - } - - @Test - fun `createReportingSet throws UNAUTHENTICATED when no principal is found`() { - val request = createReportingSetRequest { - parent = MEASUREMENT_CONSUMER_NAME - reportingSet = REPORTING_SET - } - - val exception = - assertFailsWith { - runBlocking { service.createReportingSet(request) } - } - assertThat(exception.status.code).isEqualTo(Status.Code.UNAUTHENTICATED) - } - - @Test - fun `createReportingSet throws PERMISSION_DENIED when MC caller doesn't match`() { - val request = createReportingSetRequest { - parent = MEASUREMENT_CONSUMER_NAME - reportingSet = REPORTING_SET - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME_2, CONFIG) { - runBlocking { service.createReportingSet(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.PERMISSION_DENIED) - assertThat(exception.status.description) - .isEqualTo("Cannot create a ReportingSet for another MeasurementConsumer.") - } - - @Test - fun `createReportingSet throws UNAUTHENTICATED when caller is not MeasurementConsumer`() { - val request = createReportingSetRequest { - parent = MEASUREMENT_CONSUMER_NAME - reportingSet = REPORTING_SET - } - - val exception = - assertFailsWith { - withDataProviderPrincipal(DATA_PROVIDER_NAME) { - runBlocking { service.createReportingSet(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.UNAUTHENTICATED) - assertThat(exception.status.description).isEqualTo("No ReportingPrincipal found") - } - - @Test - fun `createReportingSet throws INVALID_ARGUMENT when parent is missing`() { - val request = createReportingSetRequest { reportingSet = REPORTING_SET } - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { service.createReportingSet(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `createReportingSet throws INVALID_ARGUMENT if ReportingSet is not specified`() { - val request = createReportingSetRequest { parent = MEASUREMENT_CONSUMER_NAME } - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { service.createReportingSet(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `createReportingSet throws INVALID_ARGUMENT when EventGroups in ReportingSet is empty`() { - val request = createReportingSetRequest { - parent = MEASUREMENT_CONSUMER_NAME - reportingSet = REPORTING_SET.copy { eventGroups.clear() } - } - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { service.createReportingSet(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - assertThat(exception.status.description) - .isEqualTo("EventGroups in ReportingSet cannot be empty.") - } - - @Test - fun `createReportingSet throws INVALID_ARGUMENT when there is any invalid EventGroup`() { - val invalidEventGroupName = "invalid" - val request = createReportingSetRequest { - parent = MEASUREMENT_CONSUMER_NAME - reportingSet = reportingSet { - name = REPORTING_SET_NAME - eventGroups.addAll(listOf(EVENT_GROUP_NAME, invalidEventGroupName)) - filter = FILTER - displayName = DISPLAY_NAME - } - } - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { service.createReportingSet(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - assertThat(exception.status.description) - .isEqualTo("EventGroup is either unspecified or invalid.") - } - - @Test - fun `listReportingSets returns without a next page token when there is no previous page token`() { - val request = listReportingSetsRequest { parent = MEASUREMENT_CONSUMER_NAME } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { service.listReportingSets(request) } - } - - val expected = listReportingSetsResponse { - reportingSets += REPORTING_SET - reportingSets += - REPORTING_SET.copy { - name = REPORTING_SET_NAME_2 - displayName = DISPLAY_NAME_2 - } - reportingSets += - REPORTING_SET.copy { - name = REPORTING_SET_NAME_3 - displayName = DISPLAY_NAME_3 - } - } - - verifyProtoArgument( - internalReportingSetsMock, - ReportingSetsCoroutineImplBase::streamReportingSets, - ) - .isEqualTo( - streamReportingSetsRequest { - limit = DEFAULT_PAGE_SIZE + 1 - filter = - StreamReportingSetsRequestKt.filter { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - } - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReportingSets returns with a next page token when there is no previous page token`() { - val request = listReportingSetsRequest { - parent = MEASUREMENT_CONSUMER_NAME - pageSize = PAGE_SIZE - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { service.listReportingSets(request) } - } - - val expected = listReportingSetsResponse { - reportingSets += REPORTING_SET - reportingSets += - REPORTING_SET.copy { - name = REPORTING_SET_NAME_2 - displayName = DISPLAY_NAME_2 - } - nextPageToken = - listReportingSetsPageToken { - pageSize = PAGE_SIZE - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - lastReportingSet = previousPageEnd { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - externalReportingSetId = REPORTING_SET_EXTERNAL_ID_2 - } - } - .toByteString() - .base64UrlEncode() - } - - verifyProtoArgument( - internalReportingSetsMock, - ReportingSetsCoroutineImplBase::streamReportingSets, - ) - .isEqualTo( - streamReportingSetsRequest { - limit = PAGE_SIZE + 1 - filter = - StreamReportingSetsRequestKt.filter { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - } - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReportingSets returns with a next page token when there is a previous page token`() { - val request = listReportingSetsRequest { - parent = MEASUREMENT_CONSUMER_NAME - pageSize = PAGE_SIZE - pageToken = - listReportingSetsPageToken { - pageSize = PAGE_SIZE - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - lastReportingSet = previousPageEnd { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - externalReportingSetId = REPORTING_SET_EXTERNAL_ID - } - } - .toByteString() - .base64UrlEncode() - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { service.listReportingSets(request) } - } - - val expected = listReportingSetsResponse { - reportingSets += REPORTING_SET - reportingSets += - REPORTING_SET.copy { - name = REPORTING_SET_NAME_2 - displayName = DISPLAY_NAME_2 - } - nextPageToken = - listReportingSetsPageToken { - pageSize = PAGE_SIZE - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - lastReportingSet = previousPageEnd { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - externalReportingSetId = REPORTING_SET_EXTERNAL_ID_2 - } - } - .toByteString() - .base64UrlEncode() - } - - verifyProtoArgument( - internalReportingSetsMock, - ReportingSetsCoroutineImplBase::streamReportingSets, - ) - .isEqualTo( - streamReportingSetsRequest { - limit = PAGE_SIZE + 1 - filter = - StreamReportingSetsRequestKt.filter { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - externalReportingSetIdAfter = REPORTING_SET_EXTERNAL_ID - } - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReportingSets with page size replaced with a valid value and no previous page token`() { - val invalidPageSize = MAX_PAGE_SIZE * 2 - val request = listReportingSetsRequest { - parent = MEASUREMENT_CONSUMER_NAME - pageSize = invalidPageSize - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { service.listReportingSets(request) } - } - - val expected = listReportingSetsResponse { - reportingSets += REPORTING_SET - reportingSets += - REPORTING_SET.copy { - name = REPORTING_SET_NAME_2 - displayName = DISPLAY_NAME_2 - } - reportingSets += - REPORTING_SET.copy { - name = REPORTING_SET_NAME_3 - displayName = DISPLAY_NAME_3 - } - } - - verifyProtoArgument( - internalReportingSetsMock, - ReportingSetsCoroutineImplBase::streamReportingSets, - ) - .isEqualTo( - streamReportingSetsRequest { - limit = MAX_PAGE_SIZE + 1 - filter = - StreamReportingSetsRequestKt.filter { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - } - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReportingSets with invalid page size replaced with the one in previous page token`() { - val invalidPageSize = MAX_PAGE_SIZE * 2 - val previousPageSize = PAGE_SIZE - val request = listReportingSetsRequest { - parent = MEASUREMENT_CONSUMER_NAME - pageSize = invalidPageSize - pageToken = - listReportingSetsPageToken { - pageSize = previousPageSize - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - lastReportingSet = previousPageEnd { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - externalReportingSetId = REPORTING_SET_EXTERNAL_ID - } - } - .toByteString() - .base64UrlEncode() - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { service.listReportingSets(request) } - } - - val expected = listReportingSetsResponse { - reportingSets += REPORTING_SET - reportingSets += - REPORTING_SET.copy { - name = REPORTING_SET_NAME_2 - displayName = DISPLAY_NAME_2 - } - nextPageToken = - listReportingSetsPageToken { - pageSize = previousPageSize - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - lastReportingSet = previousPageEnd { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - externalReportingSetId = REPORTING_SET_EXTERNAL_ID_2 - } - } - .toByteString() - .base64UrlEncode() - } - - verifyProtoArgument( - internalReportingSetsMock, - ReportingSetsCoroutineImplBase::streamReportingSets, - ) - .isEqualTo( - streamReportingSetsRequest { - limit = previousPageSize + 1 - filter = - StreamReportingSetsRequestKt.filter { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - externalReportingSetIdAfter = REPORTING_SET_EXTERNAL_ID - } - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReportingSets with page size replacing the one in previous page token`() { - val newPageSize = PAGE_SIZE - val previousPageSize = 1 - val request = listReportingSetsRequest { - parent = MEASUREMENT_CONSUMER_NAME - pageSize = newPageSize - pageToken = - listReportingSetsPageToken { - pageSize = previousPageSize - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - lastReportingSet = previousPageEnd { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - externalReportingSetId = REPORTING_SET_EXTERNAL_ID - } - } - .toByteString() - .base64UrlEncode() - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { service.listReportingSets(request) } - } - - val expected = listReportingSetsResponse { - reportingSets += REPORTING_SET - reportingSets += - REPORTING_SET.copy { - name = REPORTING_SET_NAME_2 - displayName = DISPLAY_NAME_2 - } - nextPageToken = - listReportingSetsPageToken { - pageSize = newPageSize - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - lastReportingSet = previousPageEnd { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - externalReportingSetId = REPORTING_SET_EXTERNAL_ID_2 - } - } - .toByteString() - .base64UrlEncode() - } - - verifyProtoArgument( - internalReportingSetsMock, - ReportingSetsCoroutineImplBase::streamReportingSets, - ) - .isEqualTo( - streamReportingSetsRequest { - limit = newPageSize + 1 - filter = - StreamReportingSetsRequestKt.filter { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID - externalReportingSetIdAfter = REPORTING_SET_EXTERNAL_ID - } - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReportingSets throws UNAUTHENTICATED when no principal is found`() { - val request = listReportingSetsRequest { parent = MEASUREMENT_CONSUMER_NAME } - val exception = - assertFailsWith { runBlocking { service.listReportingSets(request) } } - assertThat(exception.status.code).isEqualTo(Status.Code.UNAUTHENTICATED) - } - - @Test - fun `listReportingSets throws PERMISSION_DENIED when MeasurementConsumer caller doesn't match`() { - val request = listReportingSetsRequest { parent = MEASUREMENT_CONSUMER_NAME } - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME_2, CONFIG) { - runBlocking { service.listReportingSets(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.PERMISSION_DENIED) - assertThat(exception.status.description) - .isEqualTo("Cannot list ReportingSets belonging to other MeasurementConsumers.") - } - - @Test - fun `listReportingSets throws UNAUTHENTICATED when the caller is not MeasurementConsumer`() { - val request = listReportingSetsRequest { parent = MEASUREMENT_CONSUMER_NAME } - val exception = - assertFailsWith { - withDataProviderPrincipal(DATA_PROVIDER_NAME) { - runBlocking { service.listReportingSets(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.UNAUTHENTICATED) - assertThat(exception.status.description).isEqualTo("No ReportingPrincipal found") - } - - @Test - fun `listReportingSets throws INVALID_ARGUMENT when page size is less than 0`() { - val request = listReportingSetsRequest { - parent = MEASUREMENT_CONSUMER_NAME - pageSize = -1 - } - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { service.listReportingSets(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - assertThat(exception.status.description).isEqualTo("Page size cannot be less than 0") - } - - @Test - fun `listReportingSets throws INVALID_ARGUMENT when parent is unspecified`() { - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { service.listReportingSets(ListReportingSetsRequest.getDefaultInstance()) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `listReportingSets throws INVALID_ARGUMENT when mc id doesn't match one in page token`() { - val request = listReportingSetsRequest { - parent = MEASUREMENT_CONSUMER_NAME - pageToken = - listReportingSetsPageToken { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID_2 - lastReportingSet = previousPageEnd { - measurementConsumerReferenceId = MEASUREMENT_CONSUMER_REFERENCE_ID_2 - externalReportingSetId = REPORTING_SET_EXTERNAL_ID - } - } - .toByteString() - .base64UrlEncode() - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { - runBlocking { service.listReportingSets(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } -} diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportsServiceTest.kt deleted file mode 100644 index 3a31d8ad53a..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportsServiceTest.kt +++ /dev/null @@ -1,4656 +0,0 @@ -/* - * Copyright 2022 The Cross-Media Measurement Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.wfanet.measurement.reporting.service.api.v1alpha - -import com.google.common.truth.Truth.assertThat -import com.google.common.truth.extensions.proto.ProtoTruth.assertThat -import com.google.protobuf.duration -import com.google.protobuf.kotlin.toByteString -import com.google.protobuf.kotlin.toByteStringUtf8 -import com.google.protobuf.timestamp -import com.google.protobuf.util.Durations -import com.google.protobuf.util.Timestamps -import com.google.type.interval -import io.grpc.Status -import io.grpc.StatusException -import io.grpc.StatusRuntimeException -import java.nio.file.Paths -import java.security.cert.X509Certificate -import java.time.Duration -import java.time.Instant -import kotlin.random.Random -import kotlin.test.assertFails -import kotlin.test.assertFailsWith -import kotlinx.coroutines.flow.flowOf -import kotlinx.coroutines.runBlocking -import org.junit.Before -import org.junit.Rule -import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 -import org.mockito.kotlin.KArgumentCaptor -import org.mockito.kotlin.any -import org.mockito.kotlin.argumentCaptor -import org.mockito.kotlin.doReturn -import org.mockito.kotlin.eq -import org.mockito.kotlin.mock -import org.mockito.kotlin.stub -import org.mockito.kotlin.times -import org.mockito.kotlin.verify -import org.mockito.kotlin.verifyBlocking -import org.mockito.kotlin.whenever -import org.wfanet.measurement.api.v2alpha.Certificate -import org.wfanet.measurement.api.v2alpha.CertificatesGrpcKt.CertificatesCoroutineImplBase -import org.wfanet.measurement.api.v2alpha.CertificatesGrpcKt.CertificatesCoroutineStub -import org.wfanet.measurement.api.v2alpha.CreateMeasurementRequest -import org.wfanet.measurement.api.v2alpha.DataProviderCertificateKey -import org.wfanet.measurement.api.v2alpha.DataProviderKey -import org.wfanet.measurement.api.v2alpha.DataProvidersGrpcKt.DataProvidersCoroutineImplBase -import org.wfanet.measurement.api.v2alpha.DataProvidersGrpcKt.DataProvidersCoroutineStub -import org.wfanet.measurement.api.v2alpha.EncryptionPublicKey -import org.wfanet.measurement.api.v2alpha.EventGroupKey as CmmsEventGroupKey -import org.wfanet.measurement.api.v2alpha.GetDataProviderRequest -import org.wfanet.measurement.api.v2alpha.Measurement -import org.wfanet.measurement.api.v2alpha.Measurement.DataProviderEntry.Value.ENCRYPTED_REQUISITION_SPEC_FIELD_NUMBER -import org.wfanet.measurement.api.v2alpha.MeasurementConsumer -import org.wfanet.measurement.api.v2alpha.MeasurementConsumerCertificateKey -import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey -import org.wfanet.measurement.api.v2alpha.MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineImplBase -import org.wfanet.measurement.api.v2alpha.MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineStub -import org.wfanet.measurement.api.v2alpha.MeasurementKey -import org.wfanet.measurement.api.v2alpha.MeasurementKt -import org.wfanet.measurement.api.v2alpha.MeasurementKt.DataProviderEntryKt -import org.wfanet.measurement.api.v2alpha.MeasurementKt.dataProviderEntry -import org.wfanet.measurement.api.v2alpha.MeasurementKt.failure -import org.wfanet.measurement.api.v2alpha.MeasurementKt.resultOutput -import org.wfanet.measurement.api.v2alpha.MeasurementSpec -import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt -import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt.vidSamplingInterval -import org.wfanet.measurement.api.v2alpha.MeasurementsGrpcKt.MeasurementsCoroutineImplBase -import org.wfanet.measurement.api.v2alpha.MeasurementsGrpcKt.MeasurementsCoroutineStub -import org.wfanet.measurement.api.v2alpha.RequisitionSpec -import org.wfanet.measurement.api.v2alpha.RequisitionSpecKt -import org.wfanet.measurement.api.v2alpha.RequisitionSpecKt.EventGroupEntryKt -import org.wfanet.measurement.api.v2alpha.RequisitionSpecKt.eventFilter -import org.wfanet.measurement.api.v2alpha.RequisitionSpecKt.eventGroupEntry -import org.wfanet.measurement.api.v2alpha.certificate -import org.wfanet.measurement.api.v2alpha.copy -import org.wfanet.measurement.api.v2alpha.createMeasurementRequest -import org.wfanet.measurement.api.v2alpha.dataProvider -import org.wfanet.measurement.api.v2alpha.differentialPrivacyParams -import org.wfanet.measurement.api.v2alpha.encryptionPublicKey -import org.wfanet.measurement.api.v2alpha.getCertificateRequest -import org.wfanet.measurement.api.v2alpha.getDataProviderRequest -import org.wfanet.measurement.api.v2alpha.getMeasurementConsumerRequest -import org.wfanet.measurement.api.v2alpha.getMeasurementRequest -import org.wfanet.measurement.api.v2alpha.measurement -import org.wfanet.measurement.api.v2alpha.measurementConsumer -import org.wfanet.measurement.api.v2alpha.measurementSpec -import org.wfanet.measurement.api.v2alpha.requisitionSpec -import org.wfanet.measurement.api.v2alpha.unpack -import org.wfanet.measurement.api.v2alpha.withDataProviderPrincipal -import org.wfanet.measurement.common.base64UrlEncode -import org.wfanet.measurement.common.crypto.Hashing -import org.wfanet.measurement.common.crypto.PrivateKeyHandle -import org.wfanet.measurement.common.crypto.SigningKeyHandle -import org.wfanet.measurement.common.crypto.readCertificate -import org.wfanet.measurement.common.crypto.subjectKeyIdentifier -import org.wfanet.measurement.common.crypto.testing.loadSigningKey -import org.wfanet.measurement.common.crypto.tink.loadPrivateKey -import org.wfanet.measurement.common.getRuntimePath -import org.wfanet.measurement.common.grpc.testing.GrpcTestServerRule -import org.wfanet.measurement.common.grpc.testing.mockService -import org.wfanet.measurement.common.identity.ExternalId -import org.wfanet.measurement.common.identity.externalIdToApiId -import org.wfanet.measurement.common.pack -import org.wfanet.measurement.common.readByteString -import org.wfanet.measurement.common.testing.captureFirst -import org.wfanet.measurement.common.testing.verifyProtoArgument -import org.wfanet.measurement.common.toProtoDuration -import org.wfanet.measurement.common.toProtoTime -import org.wfanet.measurement.config.reporting.MeasurementSpecConfigKt -import org.wfanet.measurement.config.reporting.measurementConsumerConfig -import org.wfanet.measurement.config.reporting.measurementSpecConfig -import org.wfanet.measurement.consent.client.dataprovider.decryptRequisitionSpec -import org.wfanet.measurement.consent.client.dataprovider.verifyMeasurementSpec -import org.wfanet.measurement.consent.client.dataprovider.verifyRequisitionSpec -import org.wfanet.measurement.consent.client.duchy.encryptResult -import org.wfanet.measurement.consent.client.duchy.signResult -import org.wfanet.measurement.consent.client.measurementconsumer.encryptRequisitionSpec -import org.wfanet.measurement.consent.client.measurementconsumer.signEncryptionPublicKey -import org.wfanet.measurement.consent.client.measurementconsumer.signMeasurementSpec -import org.wfanet.measurement.consent.client.measurementconsumer.signRequisitionSpec -import org.wfanet.measurement.internal.reporting.CreateReportRequestKt as InternalCreateReportRequestKt -import org.wfanet.measurement.internal.reporting.GetReportRequest as GetInternalReportRequest -import org.wfanet.measurement.internal.reporting.Measurement as InternalMeasurement -import org.wfanet.measurement.internal.reporting.MeasurementKt as InternalMeasurementKt -import org.wfanet.measurement.internal.reporting.MeasurementKt.ResultKt as InternalMeasurementResultKt -import org.wfanet.measurement.internal.reporting.MeasurementsGrpcKt.MeasurementsCoroutineImplBase as InternalMeasurementsCoroutineImplBase -import org.wfanet.measurement.internal.reporting.MeasurementsGrpcKt.MeasurementsCoroutineStub as InternalMeasurementsCoroutineStub -import org.wfanet.measurement.internal.reporting.Metric as InternalMetric -import org.wfanet.measurement.internal.reporting.MetricKt as InternalMetricKt -import org.wfanet.measurement.internal.reporting.MetricKt.MeasurementCalculationKt.weightedMeasurement -import org.wfanet.measurement.internal.reporting.MetricKt.SetOperationKt.reportingSetKey -import org.wfanet.measurement.internal.reporting.MetricKt.measurementCalculation -import org.wfanet.measurement.internal.reporting.Report as InternalReport -import org.wfanet.measurement.internal.reporting.ReportKt as InternalReportKt -import org.wfanet.measurement.internal.reporting.ReportKt.DetailsKt as InternalReportDetailsKt -import org.wfanet.measurement.internal.reporting.ReportKt.DetailsKt.ResultKt as InternalReportResultKt -import org.wfanet.measurement.internal.reporting.ReportingSet as InternalReportingSet -import org.wfanet.measurement.internal.reporting.ReportingSetKt as InternalReportingSetKt -import org.wfanet.measurement.internal.reporting.ReportingSetsGrpcKt.ReportingSetsCoroutineImplBase as InternalReportingSetsCoroutineImplBase -import org.wfanet.measurement.internal.reporting.ReportingSetsGrpcKt.ReportingSetsCoroutineStub as InternalReportingSetsCoroutineStub -import org.wfanet.measurement.internal.reporting.ReportsGrpcKt.ReportsCoroutineImplBase -import org.wfanet.measurement.internal.reporting.ReportsGrpcKt.ReportsCoroutineStub as InternalReportsCoroutineStub -import org.wfanet.measurement.internal.reporting.StreamReportsRequestKt.filter -import org.wfanet.measurement.internal.reporting.batchCreateMeasurementsRequest -import org.wfanet.measurement.internal.reporting.batchGetReportingSetRequest -import org.wfanet.measurement.internal.reporting.copy -import org.wfanet.measurement.internal.reporting.createReportRequest as internalCreateReportRequest -import org.wfanet.measurement.internal.reporting.getReportByIdempotencyKeyRequest -import org.wfanet.measurement.internal.reporting.getReportRequest as getInternalReportRequest -import org.wfanet.measurement.internal.reporting.measurement as internalMeasurement -import org.wfanet.measurement.internal.reporting.metric as internalMetric -import org.wfanet.measurement.internal.reporting.periodicTimeInterval as internalPeriodicTimeInterval -import org.wfanet.measurement.internal.reporting.report as internalReport -import org.wfanet.measurement.internal.reporting.reportingSet as internalReportingSet -import org.wfanet.measurement.internal.reporting.setMeasurementFailureRequest -import org.wfanet.measurement.internal.reporting.setMeasurementResultRequest -import org.wfanet.measurement.internal.reporting.streamReportsRequest -import org.wfanet.measurement.internal.reporting.timeInterval as internalTimeInterval -import org.wfanet.measurement.internal.reporting.timeIntervals as internalTimeIntervals -import org.wfanet.measurement.reporting.service.api.InMemoryEncryptionKeyPairStore -import org.wfanet.measurement.reporting.v1alpha.ListReportsPageTokenKt.previousPageEnd -import org.wfanet.measurement.reporting.v1alpha.ListReportsRequest -import org.wfanet.measurement.reporting.v1alpha.Metric.SetOperation -import org.wfanet.measurement.reporting.v1alpha.MetricKt.SetOperationKt -import org.wfanet.measurement.reporting.v1alpha.MetricKt.frequencyHistogramParams -import org.wfanet.measurement.reporting.v1alpha.MetricKt.impressionCountParams -import org.wfanet.measurement.reporting.v1alpha.MetricKt.namedSetOperation -import org.wfanet.measurement.reporting.v1alpha.MetricKt.reachParams -import org.wfanet.measurement.reporting.v1alpha.MetricKt.setOperation -import org.wfanet.measurement.reporting.v1alpha.MetricKt.watchDurationParams -import org.wfanet.measurement.reporting.v1alpha.Report -import org.wfanet.measurement.reporting.v1alpha.ReportKt -import org.wfanet.measurement.reporting.v1alpha.ReportKt.EventGroupUniverseKt -import org.wfanet.measurement.reporting.v1alpha.ReportKt.ResultKt.HistogramTableKt.row -import org.wfanet.measurement.reporting.v1alpha.ReportKt.ResultKt.column -import org.wfanet.measurement.reporting.v1alpha.ReportKt.ResultKt.histogramTable -import org.wfanet.measurement.reporting.v1alpha.ReportKt.ResultKt.scalarTable -import org.wfanet.measurement.reporting.v1alpha.ReportKt.eventGroupUniverse -import org.wfanet.measurement.reporting.v1alpha.copy -import org.wfanet.measurement.reporting.v1alpha.createReportRequest -import org.wfanet.measurement.reporting.v1alpha.getReportRequest -import org.wfanet.measurement.reporting.v1alpha.listReportsPageToken -import org.wfanet.measurement.reporting.v1alpha.listReportsRequest -import org.wfanet.measurement.reporting.v1alpha.listReportsResponse -import org.wfanet.measurement.reporting.v1alpha.metric -import org.wfanet.measurement.reporting.v1alpha.periodicTimeInterval -import org.wfanet.measurement.reporting.v1alpha.report -import org.wfanet.measurement.reporting.v1alpha.timeInterval -import org.wfanet.measurement.reporting.v1alpha.timeIntervals - -private const val DEFAULT_PAGE_SIZE = 50 -private const val MAX_PAGE_SIZE = 1000 -private const val PAGE_SIZE = 3 - -private const val NUMBER_VID_BUCKETS = 300 -private const val WIDTH = 256 -private const val DELTA = 1e-15 - -private const val MAXIMUM_FREQUENCY = 10 - -private val MEASUREMENT_SPEC_CONFIG = measurementSpecConfig { - reachSingleDataProvider = - MeasurementSpecConfigKt.reachSingleDataProvider { - privacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.000207 - delta = DELTA - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - fixedStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.fixedStart { - start = 0f - width = 1f - } - } - } - reach = - MeasurementSpecConfigKt.reach { - privacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.0007444 - delta = DELTA - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - randomStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.randomStart { - width = WIDTH - numVidBuckets = NUMBER_VID_BUCKETS - } - } - } - reachAndFrequencySingleDataProvider = - MeasurementSpecConfigKt.reachAndFrequencySingleDataProvider { - reachPrivacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.004728 - delta = DELTA - } - frequencyPrivacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.004728 - delta = DELTA - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - fixedStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.fixedStart { - start = 0f - width = 1f - } - } - } - reachAndFrequency = - MeasurementSpecConfigKt.reachAndFrequency { - reachPrivacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.014638 - delta = DELTA - } - frequencyPrivacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.014638 - delta = DELTA - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - randomStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.randomStart { - width = WIDTH - numVidBuckets = NUMBER_VID_BUCKETS - } - } - } - impression = - MeasurementSpecConfigKt.impression { - privacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.003592 - delta = DELTA - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - fixedStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.fixedStart { - start = 0f - width = 1f - } - } - } - duration = - MeasurementSpecConfigKt.duration { - privacyParams = - MeasurementSpecConfigKt.differentialPrivacyParams { - epsilon = 0.007418 - delta = DELTA - } - vidSamplingInterval = - MeasurementSpecConfigKt.vidSamplingInterval { - fixedStart = - MeasurementSpecConfigKt.VidSamplingIntervalKt.fixedStart { - start = 0f - width = 1f - } - } - } -} - -private const val SECURE_RANDOM_OUTPUT_INT = 0 -private const val SECURE_RANDOM_OUTPUT_LONG = 0L - -private val SECRETS_DIR = - getRuntimePath( - Paths.get("wfa_measurement_system", "src", "main", "k8s", "testing", "secretfiles") - )!! - .toFile() - -// Authentication key -private const val API_AUTHENTICATION_KEY = "nR5QPN7ptx" - -// Aggregator certificate - -private val AGGREGATOR_SIGNING_KEY: SigningKeyHandle by lazy { - loadSigningKey( - SECRETS_DIR.resolve("aggregator_cs_cert.der"), - SECRETS_DIR.resolve("aggregator_cs_private.der"), - ) -} -private val AGGREGATOR_CERTIFICATE = certificate { - name = "duchies/aggregator/certificates/abc123" - x509Der = AGGREGATOR_SIGNING_KEY.certificate.encoded.toByteString() -} -private val AGGREGATOR_ROOT_CERTIFICATE: X509Certificate = - readCertificate(SECRETS_DIR.resolve("aggregator_root.pem")) - -private val INVALID_MEASUREMENT_PUBLIC_KEY_DATA = "Invalid public key".toByteStringUtf8() - -// Measurement consumer crypto - -private val TRUSTED_MEASUREMENT_CONSUMER_ISSUER: X509Certificate = - readCertificate(SECRETS_DIR.resolve("mc_root.pem")) -private val MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE = - loadSigningKey(SECRETS_DIR.resolve("mc_cs_cert.der"), SECRETS_DIR.resolve("mc_cs_private.der")) -private val MEASUREMENT_CONSUMER_CERTIFICATE = MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE.certificate -private val MEASUREMENT_CONSUMER_PRIVATE_KEY_HANDLE: PrivateKeyHandle = - loadPrivateKey(SECRETS_DIR.resolve("mc_enc_private.tink")) -private val MEASUREMENT_CONSUMER_PUBLIC_KEY = encryptionPublicKey { - format = EncryptionPublicKey.Format.TINK_KEYSET - data = SECRETS_DIR.resolve("mc_enc_public.tink").readByteString() -} - -private val MEASUREMENT_CONSUMERS: Map = - (1L..2L).associate { - val measurementConsumerKey = MeasurementConsumerKey(ExternalId(it + 110L).apiId.value) - val certificateKey = - MeasurementConsumerCertificateKey( - measurementConsumerKey.measurementConsumerId, - ExternalId(it + 120L).apiId.value, - ) - measurementConsumerKey to - measurementConsumer { - name = measurementConsumerKey.toName() - certificate = certificateKey.toName() - certificateDer = MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE.certificate.encoded.toByteString() - publicKey = - signEncryptionPublicKey( - MEASUREMENT_CONSUMER_PUBLIC_KEY, - MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE, - ) - } - } - -private val CONFIG = measurementConsumerConfig { - apiKey = API_AUTHENTICATION_KEY - signingCertificateName = MEASUREMENT_CONSUMERS.values.first().certificate - signingPrivateKeyPath = "mc_cs_private.der" -} - -// InMemoryEncryptionKeyPairStore -private val ENCRYPTION_KEY_PAIR_STORE = - InMemoryEncryptionKeyPairStore( - MEASUREMENT_CONSUMERS.values.associateBy( - { it.name }, - { - listOf( - it.publicKey.unpack().data to MEASUREMENT_CONSUMER_PRIVATE_KEY_HANDLE - ) - }, - ) - ) - -// Report IDs and names -private val REPORT_EXTERNAL_IDS = listOf(331L, 332L, 333L, 334L) -private val REPORT_NAMES = - REPORT_EXTERNAL_IDS.map { - ReportKey(MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId, externalIdToApiId(it)) - .toName() - } - -// Typo causes invalid name -private const val INVALID_REPORT_NAME = "measurementConsumer/AAAAAAAAAG8/report/AAAAAAAAAU0" - -private val DATA_PROVIDER_PUBLIC_KEY = encryptionPublicKey { - format = EncryptionPublicKey.Format.TINK_KEYSET - data = SECRETS_DIR.resolve("edp1_enc_public.tink").readByteString() -} -private val DATA_PROVIDER_PRIVATE_KEY_HANDLE = - loadPrivateKey(SECRETS_DIR.resolve("edp1_enc_private.tink")) -private val DATA_PROVIDER_SIGNING_KEY = - loadSigningKey( - SECRETS_DIR.resolve("edp1_cs_cert.der"), - SECRETS_DIR.resolve("edp1_cs_private.der"), - ) -private val DATA_PROVIDER_ROOT_CERTIFICATE = readCertificate(SECRETS_DIR.resolve("edp1_root.pem")) - -// Data providers - -private val DATA_PROVIDERS = - (1L..3L).associate { - val dataProviderKey = DataProviderKey(ExternalId(it + 550L).apiId.value) - val certificateKey = - DataProviderCertificateKey(dataProviderKey.dataProviderId, ExternalId(it + 560L).apiId.value) - dataProviderKey to - dataProvider { - name = dataProviderKey.toName() - certificate = certificateKey.toName() - publicKey = signEncryptionPublicKey(DATA_PROVIDER_PUBLIC_KEY, DATA_PROVIDER_SIGNING_KEY) - } - } -private val DATA_PROVIDERS_LIST = DATA_PROVIDERS.values.toList() - -// Event group keys - -private val COVERED_EVENT_GROUP_KEYS = - DATA_PROVIDERS.keys.mapIndexed { index, dataProviderKey -> - val measurementConsumerKey = MEASUREMENT_CONSUMERS.keys.first() - EventGroupKey( - measurementConsumerKey.measurementConsumerId, - dataProviderKey.dataProviderId, - ExternalId(index + 660L).apiId.value, - ) - } -private val UNCOVERED_EVENT_GROUP_KEY = - EventGroupKey( - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId, - DATA_PROVIDERS.keys.last().dataProviderId, - ExternalId(664L).apiId.value, - ) -private val UNCOVERED_EVENT_GROUP_NAME = UNCOVERED_EVENT_GROUP_KEY.toName() -private val UNCOVERED_INTERNAL_EVENT_GROUP_KEY = UNCOVERED_EVENT_GROUP_KEY.toInternal() - -// Reporting sets -private const val REPORTING_SET_FILTER = "AGE>18" - -private val INTERNAL_REPORTING_SETS = - COVERED_EVENT_GROUP_KEYS.mapIndexed { index, eventGroupKey -> - internalReportingSet { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetId = index + 220L - eventGroupKeys += eventGroupKey.toInternal() - filter = REPORTING_SET_FILTER - displayName = "$measurementConsumerReferenceId-$externalReportingSetId-$filter" - } - } -private val UNCOVERED_INTERNAL_REPORTING_SET = internalReportingSet { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetId = INTERNAL_REPORTING_SETS.last().externalReportingSetId + 1 - eventGroupKeys += UNCOVERED_INTERNAL_EVENT_GROUP_KEY - filter = REPORTING_SET_FILTER - displayName = "$measurementConsumerReferenceId-$externalReportingSetId-$filter" -} - -// Reporting set IDs and names - -private const val REPORTING_SET_EXTERNAL_ID_FOR_MC_2 = 241L - -private const val INVALID_REPORTING_SET_NAME = "INVALID_REPORTING_SET_NAME" -private val REPORTING_SET_NAME_FOR_MC_2 = - ReportingSetKey( - MEASUREMENT_CONSUMERS.keys.last().measurementConsumerId, - externalIdToApiId(REPORTING_SET_EXTERNAL_ID_FOR_MC_2), - ) - .toName() - -// Time intervals -private val START_INSTANT = Instant.now() -private val END_INSTANT = START_INSTANT.plus(Duration.ofDays(1)) - -private val START_TIME = START_INSTANT.toProtoTime() -private val TIME_INTERVAL_INCREMENT = Duration.ofDays(1).toProtoDuration() -private const val INTERVAL_COUNT = 1 -private val END_TIME = END_INSTANT.toProtoTime() -private val MEASUREMENT_TIME_INTERVAL = interval { - startTime = START_TIME - endTime = END_TIME -} -private val INTERNAL_TIME_INTERVAL = internalTimeInterval { - startTime = START_TIME - endTime = END_TIME -} -private val INTERNAL_PERIODIC_TIME_INTERVAL = internalPeriodicTimeInterval { - startTime = START_TIME - increment = TIME_INTERVAL_INCREMENT - intervalCount = INTERVAL_COUNT -} -private val PERIODIC_TIME_INTERVAL = periodicTimeInterval { - startTime = START_TIME - increment = TIME_INTERVAL_INCREMENT - intervalCount = INTERVAL_COUNT -} - -// Report idempotency keys -private const val REACH_REPORT_IDEMPOTENCY_KEY = "TEST_REACH_REPORT" -private const val IMPRESSION_REPORT_IDEMPOTENCY_KEY = "TEST_IMPRESSION_REPORT" -private const val WATCH_DURATION_REPORT_IDEMPOTENCY_KEY = "TEST_WATCH_DURATION_REPORT" -private const val FREQUENCY_HISTOGRAM_REPORT_IDEMPOTENCY_KEY = "TEST_FREQUENCY_HISTOGRAM_REPORT" - -// Set operation unique names -private const val REACH_SET_OPERATION_UNIQUE_NAME = "Reach Set Operation" -private const val FREQUENCY_HISTOGRAM_SET_OPERATION_UNIQUE_NAME = - "Frequency Histogram Set Operation" -private const val IMPRESSION_SET_OPERATION_UNIQUE_NAME = "Impression Set Operation" -private const val WATCH_DURATION_SET_OPERATION_UNIQUE_NAME = "Watch Duration Set Operation" - -// Measurement IDs and names -private val REACH_MEASUREMENT_CREATE_REQUEST_ID = - "$REACH_REPORT_IDEMPOTENCY_KEY-Reach-$REACH_SET_OPERATION_UNIQUE_NAME-$START_INSTANT-" + - "$END_INSTANT-measurement-0" - -private val FREQUENCY_HISTOGRAM_MEASUREMENT_CREATE_REQUEST_ID = - "$FREQUENCY_HISTOGRAM_REPORT_IDEMPOTENCY_KEY-FrequencyHistogram-$FREQUENCY_HISTOGRAM_SET_OPERATION_UNIQUE_NAME-$START_INSTANT-" + - "$END_INSTANT-measurement-0" - -private val IMPRESSION_MEASUREMENT_CREATE_REQUEST_ID = - "$IMPRESSION_REPORT_IDEMPOTENCY_KEY-ImpressionCount-$IMPRESSION_SET_OPERATION_UNIQUE_NAME-$START_INSTANT-" + - "$END_INSTANT-measurement-0" - -private val WATCH_DURATION_MEASUREMENT_CREATE_REQUEST_ID = - "$WATCH_DURATION_REPORT_IDEMPOTENCY_KEY-WatchDuration-$WATCH_DURATION_SET_OPERATION_UNIQUE_NAME-$START_INSTANT-" + - "$END_INSTANT-measurement-0" - -private val REACH_MEASUREMENT_KEY = - MeasurementKey( - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId, - ExternalId(111).apiId.value, - ) -private val REACH_MEASUREMENT_KEY_2 = - MeasurementKey(REACH_MEASUREMENT_KEY.measurementConsumerId, ExternalId(222).apiId.value) -private val FREQUENCY_HISTOGRAM_MEASUREMENT_KEY = - MeasurementKey(REACH_MEASUREMENT_KEY.measurementConsumerId, ExternalId(333).apiId.value) -private val IMPRESSION_MEASUREMENT_KEY = - MeasurementKey(REACH_MEASUREMENT_KEY.measurementConsumerId, ExternalId(444).apiId.value) -private val WATCH_DURATION_MEASUREMENT_KEY = - MeasurementKey(REACH_MEASUREMENT_KEY.measurementConsumerId, ExternalId(555).apiId.value) - -// Set operations -private val INTERNAL_SET_OPERATION = - InternalMetricKt.setOperation { - type = InternalMetric.SetOperation.Type.UNION - lhs = - InternalMetricKt.SetOperationKt.operand { - reportingSetId = reportingSetKey { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetId = INTERNAL_REPORTING_SETS[0].externalReportingSetId - } - } - rhs = - InternalMetricKt.SetOperationKt.operand { - reportingSetId = reportingSetKey { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetId = INTERNAL_REPORTING_SETS[1].externalReportingSetId - } - } - } - -private val SET_OPERATION = setOperation { - type = SetOperation.Type.UNION - lhs = SetOperationKt.operand { reportingSet = INTERNAL_REPORTING_SETS[0].resourceName } - rhs = SetOperationKt.operand { reportingSet = INTERNAL_REPORTING_SETS[1].resourceName } -} -private val DATA_PROVIDER_KEYS_IN_SET_OPERATION = DATA_PROVIDERS.keys.take(2) - -private val SET_OPERATION_WITH_INVALID_REPORTING_SET = setOperation { - type = SetOperation.Type.UNION - lhs = SetOperationKt.operand { reportingSet = INVALID_REPORTING_SET_NAME } - rhs = SetOperationKt.operand { reportingSet = INTERNAL_REPORTING_SETS[1].resourceName } -} - -private val SET_OPERATION_WITH_INACCESSIBLE_REPORTING_SET = setOperation { - type = SetOperation.Type.UNION - lhs = SetOperationKt.operand { reportingSet = REPORTING_SET_NAME_FOR_MC_2 } - rhs = SetOperationKt.operand { reportingSet = INTERNAL_REPORTING_SETS[1].resourceName } -} - -// Event group filters -private const val EVENT_GROUP_FILTER = "AGE>20" -private val EVENT_GROUP_FILTERS_MAP = - COVERED_EVENT_GROUP_KEYS.associateBy(EventGroupKey::toName) { EVENT_GROUP_FILTER } - -// Event group entries -private val EVENT_GROUP_ENTRIES = - COVERED_EVENT_GROUP_KEYS.groupBy( - { DataProviderKey(it.dataProviderReferenceId) }, - { - eventGroupEntry { - key = CmmsEventGroupKey(it.dataProviderReferenceId, it.eventGroupReferenceId).toName() - value = - EventGroupEntryKt.value { - collectionInterval = MEASUREMENT_TIME_INTERVAL - filter = eventFilter { - expression = "($REPORTING_SET_FILTER) AND ($EVENT_GROUP_FILTER)" - } - } - } - }, - ) - -// Requisition specs -private val REQUISITION_SPECS: Map = - EVENT_GROUP_ENTRIES.mapValues { - requisitionSpec { - events = RequisitionSpecKt.events { eventGroups += it.value } - measurementPublicKey = MEASUREMENT_CONSUMERS.values.first().publicKey.message - nonce = SECURE_RANDOM_OUTPUT_LONG - } - } - -// Data provider entries -private val DATA_PROVIDER_ENTRIES = - REQUISITION_SPECS.mapValues { (dataProviderKey, requisitionSpec) -> - val dataProvider = DATA_PROVIDERS.getValue(dataProviderKey) - dataProviderEntry { - key = dataProvider.name - value = - DataProviderEntryKt.value { - dataProviderCertificate = dataProvider.certificate - dataProviderPublicKey = dataProvider.publicKey.message - encryptedRequisitionSpec = - encryptRequisitionSpec( - signRequisitionSpec(requisitionSpec, MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE), - dataProvider.publicKey.unpack(), - ) - nonceHash = Hashing.hashSha256(requisitionSpec.nonce) - } - } - } - -// Measurements -private val BASE_MEASUREMENT = measurement { - measurementConsumerCertificate = MEASUREMENT_CONSUMERS.values.first().certificate -} -private val BASE_MEASUREMENT_SPEC = measurementSpec { - measurementPublicKey = MEASUREMENT_CONSUMER_PUBLIC_KEY.pack() - // TODO(world-federation-of-advertisers/cross-media-measurement#1301): Stop setting this field. - serializedMeasurementPublicKey = measurementPublicKey.value -} - -// Measurement values -private const val REACH_VALUE = 100_000L -private val FREQUENCY_DISTRIBUTION = mapOf(1L to 1.0 / 6, 2L to 2.0 / 6, 3L to 3.0 / 6) -private val IMPRESSION_VALUES = listOf(100L, 150L) -private val TOTAL_IMPRESSION_VALUE = IMPRESSION_VALUES.sum() -private val WATCH_DURATION_SECOND_LIST = listOf(100L, 200L) -private val WATCH_DURATION_LIST = WATCH_DURATION_SECOND_LIST.map { duration { seconds = it } } -private val TOTAL_WATCH_DURATION = duration { seconds = WATCH_DURATION_SECOND_LIST.sum() } - -// Reach measurement -private val BASE_REACH_MEASUREMENT = BASE_MEASUREMENT.copy { name = REACH_MEASUREMENT_KEY.toName() } -private val BASE_REACH_MEASUREMENT_2 = - BASE_MEASUREMENT.copy { name = REACH_MEASUREMENT_KEY_2.toName() } - -private val PENDING_REACH_MEASUREMENT = - BASE_REACH_MEASUREMENT.copy { state = Measurement.State.COMPUTING } - -private val REACH_SINGLE_DATA_PROVIDER_MEASUREMENT_SPEC = - BASE_MEASUREMENT_SPEC.copy { - nonceHashes.addAll(listOf(Hashing.hashSha256(SECURE_RANDOM_OUTPUT_LONG))) - - reach = - MeasurementSpecKt.reach { - privacyParams = differentialPrivacyParams { - epsilon = MEASUREMENT_SPEC_CONFIG.reachSingleDataProvider.privacyParams.epsilon - delta = MEASUREMENT_SPEC_CONFIG.reachSingleDataProvider.privacyParams.delta - } - } - vidSamplingInterval = vidSamplingInterval { - start = 0f - width = 1f - } - } - -private val REACH_SINGLE_DATA_PROVIDER_MEASUREMENT_REQUEST = createMeasurementRequest { - parent = MeasurementConsumerKey(REACH_MEASUREMENT_KEY.measurementConsumerId).toName() - measurement = - BASE_MEASUREMENT.copy { - dataProviders += DATA_PROVIDER_ENTRIES.getValue(DataProviderKey(ExternalId(551L).apiId.value)) - measurementSpec = - signMeasurementSpec( - REACH_SINGLE_DATA_PROVIDER_MEASUREMENT_SPEC, - MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE, - ) - } - requestId = REACH_MEASUREMENT_CREATE_REQUEST_ID -} - -private val REACH_MEASUREMENT_SPEC = - BASE_MEASUREMENT_SPEC.copy { - nonceHashes.addAll( - listOf( - Hashing.hashSha256(SECURE_RANDOM_OUTPUT_LONG), - Hashing.hashSha256(SECURE_RANDOM_OUTPUT_LONG), - ) - ) - - reach = - MeasurementSpecKt.reach { - privacyParams = differentialPrivacyParams { - epsilon = MEASUREMENT_SPEC_CONFIG.reach.privacyParams.epsilon - delta = MEASUREMENT_SPEC_CONFIG.reach.privacyParams.delta - } - } - vidSamplingInterval = vidSamplingInterval { - start = SECURE_RANDOM_OUTPUT_INT.toFloat() / NUMBER_VID_BUCKETS - width = WIDTH.toFloat() / NUMBER_VID_BUCKETS - } - } - -private val REACH_MEASUREMENT_REQUEST = createMeasurementRequest { - parent = MeasurementConsumerKey(REACH_MEASUREMENT_KEY.measurementConsumerId).toName() - measurement = - BASE_MEASUREMENT.copy { - dataProviders += - DATA_PROVIDER_KEYS_IN_SET_OPERATION.map { DATA_PROVIDER_ENTRIES.getValue(it) } - measurementSpec = - signMeasurementSpec(REACH_MEASUREMENT_SPEC, MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE) - } - requestId = REACH_MEASUREMENT_CREATE_REQUEST_ID -} - -private val SUCCEEDED_REACH_MEASUREMENT = - BASE_REACH_MEASUREMENT.copy { - dataProviders += DATA_PROVIDER_KEYS_IN_SET_OPERATION.map { DATA_PROVIDER_ENTRIES.getValue(it) } - - measurementSpec = - signMeasurementSpec(REACH_MEASUREMENT_SPEC, MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE) - - state = Measurement.State.SUCCEEDED - - results += resultOutput { - val result = - MeasurementKt.result { - reach = MeasurementKt.ResultKt.reach { value = REACH_VALUE } - frequency = - MeasurementKt.ResultKt.frequency { - relativeFrequencyDistribution.putAll(FREQUENCY_DISTRIBUTION) - } - } - encryptedResult = - encryptResult(signResult(result, AGGREGATOR_SIGNING_KEY), MEASUREMENT_CONSUMER_PUBLIC_KEY) - certificate = AGGREGATOR_CERTIFICATE.name - } - } - -private val INTERNAL_PENDING_REACH_MEASUREMENT = internalMeasurement { - measurementConsumerReferenceId = REACH_MEASUREMENT_KEY.measurementConsumerId - measurementReferenceId = REACH_MEASUREMENT_KEY.measurementId - state = InternalMeasurement.State.PENDING -} -private val INTERNAL_SUCCEEDED_REACH_MEASUREMENT = - INTERNAL_PENDING_REACH_MEASUREMENT.copy { - state = InternalMeasurement.State.SUCCEEDED - result = - InternalMeasurementKt.result { - reach = InternalMeasurementResultKt.reach { value = REACH_VALUE } - frequency = - InternalMeasurementResultKt.frequency { - relativeFrequencyDistribution.putAll(FREQUENCY_DISTRIBUTION) - } - } - } - -// Frequency histogram measurement -private val BASE_REACH_FREQUENCY_HISTOGRAM_MEASUREMENT = - BASE_MEASUREMENT.copy { name = FREQUENCY_HISTOGRAM_MEASUREMENT_KEY.toName() } - -private val REACH_FREQUENCY_SINGLE_DATA_PROVIDER_MEASUREMENT_SPEC = - BASE_MEASUREMENT_SPEC.copy { - nonceHashes.addAll(listOf(Hashing.hashSha256(SECURE_RANDOM_OUTPUT_LONG))) - - reachAndFrequency = - MeasurementSpecKt.reachAndFrequency { - reachPrivacyParams = differentialPrivacyParams { - epsilon = - MEASUREMENT_SPEC_CONFIG.reachAndFrequencySingleDataProvider.reachPrivacyParams.epsilon - delta = - MEASUREMENT_SPEC_CONFIG.reachAndFrequencySingleDataProvider.reachPrivacyParams.delta - } - frequencyPrivacyParams = differentialPrivacyParams { - epsilon = - MEASUREMENT_SPEC_CONFIG.reachAndFrequencySingleDataProvider.frequencyPrivacyParams - .epsilon - delta = - MEASUREMENT_SPEC_CONFIG.reachAndFrequencySingleDataProvider.frequencyPrivacyParams.delta - } - maximumFrequency = MAXIMUM_FREQUENCY - } - vidSamplingInterval = vidSamplingInterval { - start = 0f - width = 1f - } - } - -private val REACH_FREQUENCY_SINGLE_DATA_PROVIDER_MEASUREMENT_REQUEST = createMeasurementRequest { - parent = MeasurementConsumerKey(REACH_MEASUREMENT_KEY.measurementConsumerId).toName() - measurement = - BASE_MEASUREMENT.copy { - dataProviders += DATA_PROVIDER_ENTRIES.getValue(DataProviderKey(ExternalId(551L).apiId.value)) - measurementSpec = - signMeasurementSpec( - REACH_FREQUENCY_SINGLE_DATA_PROVIDER_MEASUREMENT_SPEC, - MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE, - ) - } - requestId = FREQUENCY_HISTOGRAM_MEASUREMENT_CREATE_REQUEST_ID -} - -private val REACH_FREQUENCY_MEASUREMENT_SPEC = - BASE_MEASUREMENT_SPEC.copy { - nonceHashes.addAll( - listOf( - Hashing.hashSha256(SECURE_RANDOM_OUTPUT_LONG), - Hashing.hashSha256(SECURE_RANDOM_OUTPUT_LONG), - ) - ) - - reachAndFrequency = - MeasurementSpecKt.reachAndFrequency { - reachPrivacyParams = differentialPrivacyParams { - epsilon = MEASUREMENT_SPEC_CONFIG.reachAndFrequency.reachPrivacyParams.epsilon - delta = MEASUREMENT_SPEC_CONFIG.reachAndFrequency.reachPrivacyParams.delta - } - frequencyPrivacyParams = differentialPrivacyParams { - epsilon = MEASUREMENT_SPEC_CONFIG.reachAndFrequency.frequencyPrivacyParams.epsilon - delta = MEASUREMENT_SPEC_CONFIG.reachAndFrequency.frequencyPrivacyParams.delta - } - maximumFrequency = MAXIMUM_FREQUENCY - } - vidSamplingInterval = vidSamplingInterval { - start = SECURE_RANDOM_OUTPUT_INT.toFloat() / NUMBER_VID_BUCKETS - width = WIDTH.toFloat() / NUMBER_VID_BUCKETS - } - } - -private val REACH_FREQUENCY_MEASUREMENT_REQUEST = createMeasurementRequest { - parent = MeasurementConsumerKey(REACH_MEASUREMENT_KEY.measurementConsumerId).toName() - measurement = - BASE_MEASUREMENT.copy { - dataProviders += - DATA_PROVIDER_KEYS_IN_SET_OPERATION.map { DATA_PROVIDER_ENTRIES.getValue(it) } - measurementSpec = - signMeasurementSpec( - REACH_FREQUENCY_MEASUREMENT_SPEC, - MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE, - ) - } - requestId = FREQUENCY_HISTOGRAM_MEASUREMENT_CREATE_REQUEST_ID -} - -private val SUCCEEDED_FREQUENCY_HISTOGRAM_MEASUREMENT = - BASE_REACH_FREQUENCY_HISTOGRAM_MEASUREMENT.copy { - dataProviders += DATA_PROVIDER_KEYS_IN_SET_OPERATION.map { DATA_PROVIDER_ENTRIES.getValue(it) } - - measurementSpec = - signMeasurementSpec(REACH_FREQUENCY_MEASUREMENT_SPEC, MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE) - - state = Measurement.State.SUCCEEDED - results += resultOutput { - val result = - MeasurementKt.result { - reach = MeasurementKt.ResultKt.reach { value = REACH_VALUE } - frequency = - MeasurementKt.ResultKt.frequency { - relativeFrequencyDistribution.putAll(FREQUENCY_DISTRIBUTION) - } - } - encryptedResult = - encryptResult(signResult(result, AGGREGATOR_SIGNING_KEY), MEASUREMENT_CONSUMER_PUBLIC_KEY) - certificate = AGGREGATOR_CERTIFICATE.name - } - } - -private val INTERNAL_PENDING_FREQUENCY_HISTOGRAM_MEASUREMENT = internalMeasurement { - measurementConsumerReferenceId = FREQUENCY_HISTOGRAM_MEASUREMENT_KEY.measurementConsumerId - measurementReferenceId = FREQUENCY_HISTOGRAM_MEASUREMENT_KEY.measurementId - state = InternalMeasurement.State.PENDING -} - -private val INTERNAL_SUCCEEDED_FREQUENCY_HISTOGRAM_MEASUREMENT = - INTERNAL_PENDING_FREQUENCY_HISTOGRAM_MEASUREMENT.copy { - state = InternalMeasurement.State.SUCCEEDED - result = - InternalMeasurementKt.result { - reach = InternalMeasurementResultKt.reach { value = REACH_VALUE } - frequency = - InternalMeasurementResultKt.frequency { - relativeFrequencyDistribution.putAll(FREQUENCY_DISTRIBUTION) - } - } - } - -// Impression measurement -private val BASE_IMPRESSION_MEASUREMENT = - BASE_MEASUREMENT.copy { name = IMPRESSION_MEASUREMENT_KEY.toName() } - -private val IMPRESSION_MEASUREMENT_SPEC = - BASE_MEASUREMENT_SPEC.copy { - nonceHashes.addAll( - listOf( - Hashing.hashSha256(SECURE_RANDOM_OUTPUT_LONG), - Hashing.hashSha256(SECURE_RANDOM_OUTPUT_LONG), - ) - ) - - impression = - MeasurementSpecKt.impression { - privacyParams = differentialPrivacyParams { - epsilon = MEASUREMENT_SPEC_CONFIG.impression.privacyParams.epsilon - delta = MEASUREMENT_SPEC_CONFIG.impression.privacyParams.delta - } - maximumFrequencyPerUser = MAXIMUM_FREQUENCY - } - vidSamplingInterval = vidSamplingInterval { - start = MEASUREMENT_SPEC_CONFIG.impression.vidSamplingInterval.fixedStart.start - width = MEASUREMENT_SPEC_CONFIG.impression.vidSamplingInterval.fixedStart.width - } - } - -private val IMPRESSION_MEASUREMENT_REQUEST = createMeasurementRequest { - parent = MeasurementConsumerKey(REACH_MEASUREMENT_KEY.measurementConsumerId).toName() - measurement = - BASE_MEASUREMENT.copy { - dataProviders += - DATA_PROVIDER_KEYS_IN_SET_OPERATION.map { DATA_PROVIDER_ENTRIES.getValue(it) } - measurementSpec = - signMeasurementSpec(IMPRESSION_MEASUREMENT_SPEC, MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE) - } - requestId = IMPRESSION_MEASUREMENT_CREATE_REQUEST_ID -} - -private val SUCCEEDED_IMPRESSION_MEASUREMENT = - BASE_IMPRESSION_MEASUREMENT.copy { - dataProviders += DATA_PROVIDER_KEYS_IN_SET_OPERATION.map { DATA_PROVIDER_ENTRIES.getValue(it) } - - measurementSpec = - signMeasurementSpec(IMPRESSION_MEASUREMENT_SPEC, MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE) - - state = Measurement.State.SUCCEEDED - - results += - DATA_PROVIDER_KEYS_IN_SET_OPERATION.zip(IMPRESSION_VALUES).map { - (dataProviderKey, numImpressions) -> - val dataProvider = DATA_PROVIDERS.getValue(dataProviderKey) - resultOutput { - val result = - MeasurementKt.result { - impression = MeasurementKt.ResultKt.impression { value = numImpressions } - } - encryptedResult = - encryptResult( - signResult(result, DATA_PROVIDER_SIGNING_KEY), - MEASUREMENT_CONSUMER_PUBLIC_KEY, - ) - certificate = dataProvider.certificate - } - } - } - -private val INTERNAL_PENDING_IMPRESSION_MEASUREMENT = internalMeasurement { - measurementConsumerReferenceId = IMPRESSION_MEASUREMENT_KEY.measurementConsumerId - measurementReferenceId = IMPRESSION_MEASUREMENT_KEY.measurementId - state = InternalMeasurement.State.PENDING -} - -private val INTERNAL_SUCCEEDED_IMPRESSION_MEASUREMENT = - INTERNAL_PENDING_IMPRESSION_MEASUREMENT.copy { - state = InternalMeasurement.State.SUCCEEDED - result = - InternalMeasurementKt.result { - impression = InternalMeasurementResultKt.impression { value = TOTAL_IMPRESSION_VALUE } - } - } - -// Watch Duration measurement -private val BASE_WATCH_DURATION_MEASUREMENT = - BASE_MEASUREMENT.copy { name = WATCH_DURATION_MEASUREMENT_KEY.toName() } - -private val PENDING_WATCH_DURATION_MEASUREMENT = - BASE_WATCH_DURATION_MEASUREMENT.copy { state = Measurement.State.COMPUTING } - -private val WATCH_DURATION_MEASUREMENT_SPEC = - BASE_MEASUREMENT_SPEC.copy { - nonceHashes.addAll( - listOf( - Hashing.hashSha256(SECURE_RANDOM_OUTPUT_LONG), - Hashing.hashSha256(SECURE_RANDOM_OUTPUT_LONG), - ) - ) - - duration = - MeasurementSpecKt.duration { - privacyParams = differentialPrivacyParams { - epsilon = MEASUREMENT_SPEC_CONFIG.duration.privacyParams.epsilon - delta = MEASUREMENT_SPEC_CONFIG.duration.privacyParams.delta - } - maximumWatchDurationPerUser = - Durations.fromSeconds(MAXIMUM_WATCH_DURATION_PER_USER.toLong()) - } - vidSamplingInterval = vidSamplingInterval { - start = MEASUREMENT_SPEC_CONFIG.duration.vidSamplingInterval.fixedStart.start - width = MEASUREMENT_SPEC_CONFIG.duration.vidSamplingInterval.fixedStart.width - } - } - -private val WATCH_DURATION_MEASUREMENT_REQUEST = createMeasurementRequest { - parent = MeasurementConsumerKey(REACH_MEASUREMENT_KEY.measurementConsumerId).toName() - measurement = - BASE_MEASUREMENT.copy { - dataProviders += - DATA_PROVIDER_KEYS_IN_SET_OPERATION.map { DATA_PROVIDER_ENTRIES.getValue(it) } - measurementSpec = - signMeasurementSpec( - WATCH_DURATION_MEASUREMENT_SPEC, - MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE, - ) - } - requestId = WATCH_DURATION_MEASUREMENT_CREATE_REQUEST_ID -} - -private val SUCCEEDED_WATCH_DURATION_MEASUREMENT = - BASE_WATCH_DURATION_MEASUREMENT.copy { - dataProviders += DATA_PROVIDER_KEYS_IN_SET_OPERATION.map { DATA_PROVIDER_ENTRIES.getValue(it) } - - measurementSpec = - signMeasurementSpec(WATCH_DURATION_MEASUREMENT_SPEC, MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE) - - state = Measurement.State.SUCCEEDED - - results += - DATA_PROVIDER_KEYS_IN_SET_OPERATION.zip(WATCH_DURATION_LIST).map { - (dataProviderKey, watchDuration) -> - val dataProvider = DATA_PROVIDERS.getValue(dataProviderKey) - resultOutput { - val result = - MeasurementKt.result { - this.watchDuration = MeasurementKt.ResultKt.watchDuration { value = watchDuration } - } - encryptedResult = - encryptResult( - signResult(result, DATA_PROVIDER_SIGNING_KEY), - MEASUREMENT_CONSUMER_PUBLIC_KEY, - ) - certificate = dataProvider.certificate - } - } - } - -private val INTERNAL_PENDING_WATCH_DURATION_MEASUREMENT = internalMeasurement { - measurementConsumerReferenceId = WATCH_DURATION_MEASUREMENT_KEY.measurementConsumerId - measurementReferenceId = WATCH_DURATION_MEASUREMENT_KEY.measurementId - state = InternalMeasurement.State.PENDING -} -private val INTERNAL_SUCCEEDED_WATCH_DURATION_MEASUREMENT = - INTERNAL_PENDING_WATCH_DURATION_MEASUREMENT.copy { - state = InternalMeasurement.State.SUCCEEDED - result = - InternalMeasurementKt.result { - watchDuration = InternalMeasurementResultKt.watchDuration { value = TOTAL_WATCH_DURATION } - } - } - -// Weighted measurements -private val WEIGHTED_REACH_MEASUREMENT = weightedMeasurement { - measurementReferenceId = REACH_MEASUREMENT_KEY.measurementId - coefficient = 1 -} - -private val WEIGHTED_FREQUENCY_HISTOGRAM_MEASUREMENT = weightedMeasurement { - measurementReferenceId = FREQUENCY_HISTOGRAM_MEASUREMENT_KEY.measurementId - coefficient = 1 -} - -private val WEIGHTED_IMPRESSION_MEASUREMENT = weightedMeasurement { - measurementReferenceId = IMPRESSION_MEASUREMENT_KEY.measurementId - coefficient = 1 -} - -private val WEIGHTED_WATCH_DURATION_MEASUREMENT = weightedMeasurement { - measurementReferenceId = WATCH_DURATION_MEASUREMENT_KEY.measurementId - coefficient = 1 -} - -// Measurement Calculations -private val REACH_MEASUREMENT_CALCULATION = measurementCalculation { - timeInterval = INTERNAL_TIME_INTERVAL - weightedMeasurements.add(WEIGHTED_REACH_MEASUREMENT) -} - -private val FREQUENCY_HISTOGRAM_MEASUREMENT_CALCULATION = measurementCalculation { - timeInterval = INTERNAL_TIME_INTERVAL - weightedMeasurements.add(WEIGHTED_FREQUENCY_HISTOGRAM_MEASUREMENT) -} - -private val IMPRESSION_MEASUREMENT_CALCULATION = measurementCalculation { - timeInterval = INTERNAL_TIME_INTERVAL - weightedMeasurements.add(WEIGHTED_IMPRESSION_MEASUREMENT) -} - -private val WATCH_DURATION_MEASUREMENT_CALCULATION = measurementCalculation { - timeInterval = INTERNAL_TIME_INTERVAL - weightedMeasurements.add(WEIGHTED_WATCH_DURATION_MEASUREMENT) -} - -// Named set operations -// Reach set operation -private val INTERNAL_NAMED_REACH_SET_OPERATION = - InternalMetricKt.namedSetOperation { - displayName = REACH_SET_OPERATION_UNIQUE_NAME - setOperation = INTERNAL_SET_OPERATION - measurementCalculations += REACH_MEASUREMENT_CALCULATION - } -private val NAMED_REACH_SET_OPERATION = namedSetOperation { - uniqueName = REACH_SET_OPERATION_UNIQUE_NAME - setOperation = SET_OPERATION -} - -// Frequency histogram set operation -private val INTERNAL_NAMED_FREQUENCY_HISTOGRAM_SET_OPERATION = - InternalMetricKt.namedSetOperation { - displayName = FREQUENCY_HISTOGRAM_SET_OPERATION_UNIQUE_NAME - setOperation = INTERNAL_SET_OPERATION - measurementCalculations += FREQUENCY_HISTOGRAM_MEASUREMENT_CALCULATION - } -private val NAMED_FREQUENCY_HISTOGRAM_SET_OPERATION = namedSetOperation { - uniqueName = FREQUENCY_HISTOGRAM_SET_OPERATION_UNIQUE_NAME - setOperation = SET_OPERATION -} - -// Impression set operation -private val INTERNAL_NAMED_IMPRESSION_SET_OPERATION = - InternalMetricKt.namedSetOperation { - displayName = IMPRESSION_SET_OPERATION_UNIQUE_NAME - setOperation = INTERNAL_SET_OPERATION - measurementCalculations += IMPRESSION_MEASUREMENT_CALCULATION - } -private val NAMED_IMPRESSION_SET_OPERATION = namedSetOperation { - uniqueName = IMPRESSION_SET_OPERATION_UNIQUE_NAME - setOperation = SET_OPERATION -} - -// Watch duration set operation -private val INTERNAL_NAMED_WATCH_DURATION_SET_OPERATION = - InternalMetricKt.namedSetOperation { - displayName = WATCH_DURATION_SET_OPERATION_UNIQUE_NAME - setOperation = INTERNAL_SET_OPERATION - measurementCalculations += WATCH_DURATION_MEASUREMENT_CALCULATION - } -private val NAMED_WATCH_DURATION_SET_OPERATION = namedSetOperation { - uniqueName = WATCH_DURATION_SET_OPERATION_UNIQUE_NAME - setOperation = SET_OPERATION -} - -// Internal metrics -private const val MAXIMUM_FREQUENCY_PER_USER = 10 -private const val MAXIMUM_WATCH_DURATION_PER_USER = 300 - -// Reach metric -private val REACH_METRIC = metric { - reach = reachParams {} - cumulative = false - setOperations.add(NAMED_REACH_SET_OPERATION) -} -private val INTERNAL_REACH_METRIC = internalMetric { - details = - InternalMetricKt.details { - reach = InternalMetricKt.reachParams {} - cumulative = false - } - namedSetOperations.add(INTERNAL_NAMED_REACH_SET_OPERATION) -} - -// Frequency histogram metric -private val FREQUENCY_HISTOGRAM_METRIC = metric { - frequencyHistogram = frequencyHistogramParams { - maximumFrequencyPerUser = MAXIMUM_FREQUENCY_PER_USER - } - cumulative = false - setOperations.add(NAMED_FREQUENCY_HISTOGRAM_SET_OPERATION) -} -private val INTERNAL_FREQUENCY_HISTOGRAM_METRIC = internalMetric { - details = - InternalMetricKt.details { - frequencyHistogram = - InternalMetricKt.frequencyHistogramParams { maximumFrequency = MAXIMUM_FREQUENCY_PER_USER } - cumulative = false - } - namedSetOperations.add(INTERNAL_NAMED_FREQUENCY_HISTOGRAM_SET_OPERATION) -} - -// Impression metric -private val IMPRESSION_METRIC = metric { - impressionCount = impressionCountParams { maximumFrequencyPerUser = MAXIMUM_FREQUENCY_PER_USER } - cumulative = false - setOperations.add(NAMED_IMPRESSION_SET_OPERATION) -} -private val INTERNAL_IMPRESSION_METRIC = internalMetric { - details = - InternalMetricKt.details { - impressionCount = - InternalMetricKt.impressionCountParams { - maximumFrequencyPerUser = MAXIMUM_FREQUENCY_PER_USER - } - cumulative = false - } - namedSetOperations.add(INTERNAL_NAMED_IMPRESSION_SET_OPERATION) -} - -// Watch duration metric -private val WATCH_DURATION_METRIC = metric { - watchDuration = watchDurationParams { - maximumWatchDurationPerUser = MAXIMUM_WATCH_DURATION_PER_USER - } - cumulative = false - setOperations.add(NAMED_WATCH_DURATION_SET_OPERATION) -} -private val INTERNAL_WATCH_DURATION_METRIC = internalMetric { - details = - InternalMetricKt.details { - watchDuration = - InternalMetricKt.watchDurationParams { - maximumWatchDurationPerUserSeconds = MAXIMUM_WATCH_DURATION_PER_USER - } - cumulative = false - } - namedSetOperations.add(INTERNAL_NAMED_WATCH_DURATION_SET_OPERATION) -} - -// Internal reports with running states -// Internal reports of reach -private val INTERNAL_PENDING_REACH_REPORT = internalReport { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[0] - periodicTimeInterval = INTERNAL_PERIODIC_TIME_INTERVAL - metrics.add(INTERNAL_REACH_METRIC) - state = InternalReport.State.RUNNING - measurements.put(REACH_MEASUREMENT_KEY.measurementId, INTERNAL_PENDING_REACH_MEASUREMENT) - details = InternalReportKt.details { eventGroupFilters.putAll(EVENT_GROUP_FILTERS_MAP) } - createTime = timestamp { seconds = 1000 } - reportIdempotencyKey = REACH_REPORT_IDEMPOTENCY_KEY -} -private val INTERNAL_SUCCEEDED_REACH_REPORT = - INTERNAL_PENDING_REACH_REPORT.copy { - state = InternalReport.State.SUCCEEDED - measurements.put(REACH_MEASUREMENT_KEY.measurementId, INTERNAL_SUCCEEDED_REACH_MEASUREMENT) - } - -// Internal reports of impression -private val INTERNAL_PENDING_IMPRESSION_REPORT = internalReport { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[1] - periodicTimeInterval = INTERNAL_PERIODIC_TIME_INTERVAL - metrics.add(INTERNAL_IMPRESSION_METRIC) - state = InternalReport.State.RUNNING - measurements.put( - IMPRESSION_MEASUREMENT_KEY.measurementId, - INTERNAL_PENDING_IMPRESSION_MEASUREMENT, - ) - details = InternalReportKt.details { eventGroupFilters.putAll(EVENT_GROUP_FILTERS_MAP) } - createTime = timestamp { seconds = 2000 } - reportIdempotencyKey = IMPRESSION_REPORT_IDEMPOTENCY_KEY -} -private val INTERNAL_SUCCEEDED_IMPRESSION_REPORT = - INTERNAL_PENDING_IMPRESSION_REPORT.copy { - state = InternalReport.State.SUCCEEDED - measurements.put( - IMPRESSION_MEASUREMENT_KEY.measurementId, - INTERNAL_SUCCEEDED_IMPRESSION_MEASUREMENT, - ) - } - -// Internal reports of watch duration -private val INTERNAL_PENDING_WATCH_DURATION_REPORT = internalReport { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[2] - periodicTimeInterval = INTERNAL_PERIODIC_TIME_INTERVAL - metrics.add(INTERNAL_WATCH_DURATION_METRIC) - state = InternalReport.State.RUNNING - measurements.put( - WATCH_DURATION_MEASUREMENT_KEY.measurementId, - INTERNAL_PENDING_WATCH_DURATION_MEASUREMENT, - ) - details = InternalReportKt.details { eventGroupFilters.putAll(EVENT_GROUP_FILTERS_MAP) } - createTime = timestamp { seconds = 3000 } - reportIdempotencyKey = WATCH_DURATION_REPORT_IDEMPOTENCY_KEY -} -private val INTERNAL_SUCCEEDED_WATCH_DURATION_REPORT = - INTERNAL_PENDING_WATCH_DURATION_REPORT.copy { - state = InternalReport.State.SUCCEEDED - measurements.put( - WATCH_DURATION_MEASUREMENT_KEY.measurementId, - INTERNAL_SUCCEEDED_WATCH_DURATION_MEASUREMENT, - ) - } - -// Internal reports of frequency histogram -private val INTERNAL_PENDING_FREQUENCY_HISTOGRAM_REPORT = internalReport { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[3] - periodicTimeInterval = INTERNAL_PERIODIC_TIME_INTERVAL - metrics.add(INTERNAL_FREQUENCY_HISTOGRAM_METRIC) - state = InternalReport.State.RUNNING - measurements.put( - FREQUENCY_HISTOGRAM_MEASUREMENT_KEY.measurementId, - INTERNAL_PENDING_FREQUENCY_HISTOGRAM_MEASUREMENT, - ) - details = InternalReportKt.details { eventGroupFilters.putAll(EVENT_GROUP_FILTERS_MAP) } - createTime = timestamp { seconds = 4000 } - reportIdempotencyKey = FREQUENCY_HISTOGRAM_REPORT_IDEMPOTENCY_KEY -} -private val INTERNAL_SUCCEEDED_FREQUENCY_HISTOGRAM_REPORT = - INTERNAL_PENDING_FREQUENCY_HISTOGRAM_REPORT.copy { - state = InternalReport.State.SUCCEEDED - measurements.put( - FREQUENCY_HISTOGRAM_MEASUREMENT_KEY.measurementId, - INTERNAL_SUCCEEDED_FREQUENCY_HISTOGRAM_MEASUREMENT, - ) - } - -private val EVENT_GROUP_UNIVERSE = eventGroupUniverse { - eventGroupEntries += - COVERED_EVENT_GROUP_KEYS.map { - EventGroupUniverseKt.eventGroupEntry { - key = it.toName() - value = EVENT_GROUP_FILTER - } - } -} - -private val EVENT_GROUP_UNIVERSE_WITH_ONE_DATA_PROVIDER = eventGroupUniverse { - eventGroupEntries += - EventGroupUniverseKt.eventGroupEntry { - key = COVERED_EVENT_GROUP_KEYS[0].toName() - value = EVENT_GROUP_FILTER - } -} - -// Public reports with running states -// Reports of reach -private val PENDING_REACH_REPORT = report { - name = REPORT_NAMES[0] - reportIdempotencyKey = REACH_REPORT_IDEMPOTENCY_KEY - measurementConsumer = MEASUREMENT_CONSUMERS.values.first().name - eventGroupUniverse = EVENT_GROUP_UNIVERSE - periodicTimeInterval = PERIODIC_TIME_INTERVAL - metrics.add(REACH_METRIC) - state = Report.State.RUNNING -} -private val SUCCEEDED_REACH_REPORT = PENDING_REACH_REPORT.copy { state = Report.State.SUCCEEDED } - -// Reports of impression -private val PENDING_IMPRESSION_REPORT = report { - name = REPORT_NAMES[1] - reportIdempotencyKey = IMPRESSION_REPORT_IDEMPOTENCY_KEY - measurementConsumer = MEASUREMENT_CONSUMERS.values.first().name - eventGroupUniverse = EVENT_GROUP_UNIVERSE - periodicTimeInterval = PERIODIC_TIME_INTERVAL - metrics.add(IMPRESSION_METRIC) - state = Report.State.RUNNING -} -private val SUCCEEDED_IMPRESSION_REPORT = - PENDING_IMPRESSION_REPORT.copy { state = Report.State.SUCCEEDED } - -// Reports of watch duration -private val PENDING_WATCH_DURATION_REPORT = report { - name = REPORT_NAMES[2] - reportIdempotencyKey = WATCH_DURATION_REPORT_IDEMPOTENCY_KEY - measurementConsumer = MEASUREMENT_CONSUMERS.values.first().name - eventGroupUniverse = EVENT_GROUP_UNIVERSE - periodicTimeInterval = PERIODIC_TIME_INTERVAL - metrics.add(WATCH_DURATION_METRIC) - state = Report.State.RUNNING -} -private val SUCCEEDED_WATCH_DURATION_REPORT = - PENDING_WATCH_DURATION_REPORT.copy { state = Report.State.SUCCEEDED } - -// Reports of frequency histogram -private val PENDING_FREQUENCY_HISTOGRAM_REPORT = report { - name = REPORT_NAMES[3] - reportIdempotencyKey = FREQUENCY_HISTOGRAM_REPORT_IDEMPOTENCY_KEY - measurementConsumer = MEASUREMENT_CONSUMERS.values.first().name - eventGroupUniverse = EVENT_GROUP_UNIVERSE - periodicTimeInterval = PERIODIC_TIME_INTERVAL - metrics.add(FREQUENCY_HISTOGRAM_METRIC) - state = Report.State.RUNNING -} -private val SUCCEEDED_FREQUENCY_HISTOGRAM_REPORT = - PENDING_FREQUENCY_HISTOGRAM_REPORT.copy { state = Report.State.SUCCEEDED } - -@RunWith(JUnit4::class) -class ReportsServiceTest { - - private val internalReportsMock: ReportsCoroutineImplBase = mockService { - onBlocking { createReport(any()) } - .thenReturn( - INTERNAL_PENDING_REACH_REPORT, - INTERNAL_PENDING_IMPRESSION_REPORT, - INTERNAL_PENDING_WATCH_DURATION_REPORT, - INTERNAL_PENDING_FREQUENCY_HISTOGRAM_REPORT, - ) - onBlocking { streamReports(any()) } - .thenReturn( - flowOf( - INTERNAL_PENDING_REACH_REPORT, - INTERNAL_PENDING_IMPRESSION_REPORT, - INTERNAL_PENDING_WATCH_DURATION_REPORT, - INTERNAL_PENDING_FREQUENCY_HISTOGRAM_REPORT, - ) - ) - onBlocking { getReport(any()) } - .thenReturn( - INTERNAL_SUCCEEDED_REACH_REPORT, - INTERNAL_SUCCEEDED_IMPRESSION_REPORT, - INTERNAL_SUCCEEDED_WATCH_DURATION_REPORT, - INTERNAL_SUCCEEDED_FREQUENCY_HISTOGRAM_REPORT, - ) - onBlocking { getReportByIdempotencyKey(any()) } - .thenThrow(StatusRuntimeException(Status.NOT_FOUND)) - } - - private val internalReportingSetsMock: InternalReportingSetsCoroutineImplBase = mockService { - onBlocking { batchGetReportingSet(any()) } - .thenReturn( - flowOf( - INTERNAL_REPORTING_SETS[0], - INTERNAL_REPORTING_SETS[1], - INTERNAL_REPORTING_SETS[0], - INTERNAL_REPORTING_SETS[1], - ) - ) - } - - private val internalMeasurementsMock: InternalMeasurementsCoroutineImplBase = mockService { - onBlocking { getMeasurement(any()) }.thenThrow(StatusRuntimeException(Status.NOT_FOUND)) - } - - private val measurementsMock: MeasurementsCoroutineImplBase = mockService { - onBlocking { getMeasurement(any()) } - .thenReturn( - SUCCEEDED_REACH_MEASUREMENT, - SUCCEEDED_IMPRESSION_MEASUREMENT, - SUCCEEDED_WATCH_DURATION_MEASUREMENT, - SUCCEEDED_FREQUENCY_HISTOGRAM_MEASUREMENT, - ) - - onBlocking { createMeasurement(any()) }.thenReturn(BASE_REACH_MEASUREMENT) - } - - private val measurementConsumersMock: MeasurementConsumersCoroutineImplBase = mockService { - onBlocking { getMeasurementConsumer(any()) }.thenReturn(MEASUREMENT_CONSUMERS.values.first()) - } - - private val dataProvidersMock: DataProvidersCoroutineImplBase = mockService { - var stubbing = onBlocking { getDataProvider(any()) } - for (dataProvider in DATA_PROVIDERS.values) { - stubbing = stubbing.thenReturn(dataProvider) - } - } - - private val certificateMock: CertificatesCoroutineImplBase = mockService { - onBlocking { getCertificate(eq(getCertificateRequest { name = AGGREGATOR_CERTIFICATE.name })) } - .thenReturn(AGGREGATOR_CERTIFICATE) - for (dataProvider in DATA_PROVIDERS.values) { - onBlocking { getCertificate(eq(getCertificateRequest { name = dataProvider.certificate })) } - .thenReturn( - certificate { - name = dataProvider.certificate - x509Der = DATA_PROVIDER_SIGNING_KEY.certificate.encoded.toByteString() - } - ) - } - for (measurementConsumer in MEASUREMENT_CONSUMERS.values) { - onBlocking { - getCertificate(eq(getCertificateRequest { name = measurementConsumer.certificate })) - } - .thenReturn( - certificate { - name = measurementConsumer.certificate - x509Der = measurementConsumer.certificateDer - } - ) - } - } - - private val randomMock: Random = mock() - - @get:Rule - val grpcTestServerRule = GrpcTestServerRule { - addService(internalReportsMock) - addService(internalReportingSetsMock) - addService(internalMeasurementsMock) - addService(measurementsMock) - addService(measurementConsumersMock) - addService(dataProvidersMock) - addService(certificateMock) - } - - private lateinit var service: ReportsService - - @Before - fun initService() { - randomMock.stub { - on { nextInt(any()) } doReturn SECURE_RANDOM_OUTPUT_INT - on { nextLong() } doReturn SECURE_RANDOM_OUTPUT_LONG - } - - service = - ReportsService( - InternalReportsCoroutineStub(grpcTestServerRule.channel), - InternalReportingSetsCoroutineStub(grpcTestServerRule.channel), - InternalMeasurementsCoroutineStub(grpcTestServerRule.channel), - DataProvidersCoroutineStub(grpcTestServerRule.channel), - MeasurementConsumersCoroutineStub(grpcTestServerRule.channel), - MeasurementsCoroutineStub(grpcTestServerRule.channel), - CertificatesCoroutineStub(grpcTestServerRule.channel), - ENCRYPTION_KEY_PAIR_STORE, - randomMock, - SECRETS_DIR, - listOf(AGGREGATOR_ROOT_CERTIFICATE, DATA_PROVIDER_ROOT_CERTIFICATE).associateBy { - it.subjectKeyIdentifier!! - }, - MEASUREMENT_SPEC_CONFIG, - ) - } - - @Test - fun `createReport returns a report of reach with RUNNING state`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - - val expected = PENDING_REACH_REPORT - - // Verify proto argument of ReportsCoroutineImplBase::getReportByIdempotencyKey - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::getReportByIdempotencyKey) - .isEqualTo( - getReportByIdempotencyKeyRequest { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - reportIdempotencyKey = REACH_REPORT_IDEMPOTENCY_KEY - } - ) - - // Verify proto argument of InternalReportingSetsCoroutineImplBase::batchGetReportingSet - verifyProtoArgument( - internalReportingSetsMock, - InternalReportingSetsCoroutineImplBase::batchGetReportingSet, - ) - .ignoringRepeatedFieldOrder() - .isEqualTo( - batchGetReportingSetRequest { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetIds += INTERNAL_REPORTING_SETS[0].externalReportingSetId - externalReportingSetIds += INTERNAL_REPORTING_SETS[1].externalReportingSetId - } - ) - - // Verify proto argument of MeasurementConsumersCoroutineImplBase::getMeasurementConsumer - verifyProtoArgument( - measurementConsumersMock, - MeasurementConsumersCoroutineImplBase::getMeasurementConsumer, - ) - .isEqualTo(getMeasurementConsumerRequest { name = MEASUREMENT_CONSUMERS.values.first().name }) - - // Verify proto argument of DataProvidersCoroutineImplBase::getDataProvider - val dataProvidersCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(dataProvidersMock, times(2)) { getDataProvider(dataProvidersCaptor.capture()) } - val capturedDataProviderRequests = dataProvidersCaptor.allValues - assertThat(capturedDataProviderRequests) - .containsExactly( - getDataProviderRequest { name = DATA_PROVIDERS_LIST[0].name }, - getDataProviderRequest { name = DATA_PROVIDERS_LIST[1].name }, - ) - - // Verify proto argument of MeasurementsCoroutineImplBase::createMeasurement - val capturedMeasurementRequest = - captureFirst { - runBlocking { verify(measurementsMock).createMeasurement(capture()) } - } - assertThat(capturedMeasurementRequest) - .ignoringRepeatedFieldOrder() - .ignoringFieldDescriptors( - MEASUREMENT_SPEC_FIELD_DESCRIPTOR, - ENCRYPTED_REQUISITION_SPEC_FIELD_DESCRIPTOR, - ) - .isEqualTo(REACH_MEASUREMENT_REQUEST) - - verifyMeasurementSpec( - capturedMeasurementRequest.measurement.measurementSpec, - MEASUREMENT_CONSUMER_CERTIFICATE, - TRUSTED_MEASUREMENT_CONSUMER_ISSUER, - ) - val measurementSpec: MeasurementSpec = - capturedMeasurementRequest.measurement.measurementSpec.unpack() - assertThat(measurementSpec).isEqualTo(REACH_MEASUREMENT_SPEC) - - val dataProvidersList = - capturedMeasurementRequest.measurement.dataProvidersList.sortedBy { it.key } - - dataProvidersList.map { dataProviderEntry -> - val signedRequisitionSpec = - decryptRequisitionSpec( - dataProviderEntry.value.encryptedRequisitionSpec, - DATA_PROVIDER_PRIVATE_KEY_HANDLE, - ) - val requisitionSpec: RequisitionSpec = signedRequisitionSpec.unpack() - verifyRequisitionSpec( - signedRequisitionSpec, - requisitionSpec, - measurementSpec, - MEASUREMENT_CONSUMER_CERTIFICATE, - TRUSTED_MEASUREMENT_CONSUMER_ISSUER, - ) - } - - // Verify proto argument of InternalMeasurementsCoroutineImplBase::batchCreateMeasurements - verifyProtoArgument( - internalMeasurementsMock, - InternalMeasurementsCoroutineImplBase::batchCreateMeasurements, - ) - .isEqualTo( - batchCreateMeasurementsRequest { measurements += INTERNAL_PENDING_REACH_MEASUREMENT } - ) - - // Verify proto argument of InternalReportsCoroutineImplBase::createReport - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::createReport) - .ignoringRepeatedFieldOrder() - .isEqualTo( - internalCreateReportRequest { - report = - INTERNAL_PENDING_REACH_REPORT.copy { - clearState() - clearExternalReportId() - measurements.clear() - clearCreateTime() - } - measurements += - InternalCreateReportRequestKt.measurementKey { - measurementConsumerReferenceId = REACH_MEASUREMENT_KEY.measurementConsumerId - measurementReferenceId = REACH_MEASUREMENT_KEY.measurementId - } - } - ) - - assertThat(result).isEqualTo(expected) - } - - @Test - fun `createReport creates reach single data provider measurement when report needs reach`() { - runBlocking { - whenever(internalReportingSetsMock.batchGetReportingSet(any())) - .thenReturn(flowOf(INTERNAL_REPORTING_SETS[0])) - } - - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - eventGroupUniverse = EVENT_GROUP_UNIVERSE_WITH_ONE_DATA_PROVIDER - metrics.clear() - metrics += metric { - reach = reachParams {} - setOperations += namedSetOperation { - uniqueName = REACH_SET_OPERATION_UNIQUE_NAME - setOperation = setOperation { - type = SetOperation.Type.UNION - lhs = - SetOperationKt.operand { reportingSet = INTERNAL_REPORTING_SETS[0].resourceName } - } - } - } - clearState() - } - } - - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - - // Verify proto argument of MeasurementsCoroutineImplBase::createMeasurement - val capturedMeasurementRequest = - captureFirst { - runBlocking { verify(measurementsMock).createMeasurement(capture()) } - } - assertThat(capturedMeasurementRequest) - .ignoringRepeatedFieldOrder() - .ignoringFieldDescriptors( - MEASUREMENT_SPEC_FIELD_DESCRIPTOR, - ENCRYPTED_REQUISITION_SPEC_FIELD_DESCRIPTOR, - ) - .isEqualTo(REACH_SINGLE_DATA_PROVIDER_MEASUREMENT_REQUEST) - - verifyMeasurementSpec( - capturedMeasurementRequest.measurement.measurementSpec, - MEASUREMENT_CONSUMER_CERTIFICATE, - TRUSTED_MEASUREMENT_CONSUMER_ISSUER, - ) - val measurementSpec: MeasurementSpec = - capturedMeasurementRequest.measurement.measurementSpec.unpack() - assertThat(measurementSpec).isEqualTo(REACH_SINGLE_DATA_PROVIDER_MEASUREMENT_SPEC) - } - - @Test - fun `createReport creates rf single data provider measurement when report needs rf`() { - runBlocking { - whenever(internalReportingSetsMock.batchGetReportingSet(any())) - .thenReturn(flowOf(INTERNAL_REPORTING_SETS[0])) - } - - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_FREQUENCY_HISTOGRAM_REPORT.copy { - eventGroupUniverse = EVENT_GROUP_UNIVERSE_WITH_ONE_DATA_PROVIDER - metrics.clear() - metrics += metric { - frequencyHistogram = frequencyHistogramParams { - maximumFrequencyPerUser = MAXIMUM_FREQUENCY - } - setOperations += namedSetOperation { - uniqueName = FREQUENCY_HISTOGRAM_SET_OPERATION_UNIQUE_NAME - setOperation = setOperation { - type = SetOperation.Type.UNION - lhs = - SetOperationKt.operand { reportingSet = INTERNAL_REPORTING_SETS[0].resourceName } - } - } - } - clearState() - } - } - - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - - // Verify proto argument of MeasurementsCoroutineImplBase::createMeasurement - val capturedMeasurementRequest = - captureFirst { - runBlocking { verify(measurementsMock).createMeasurement(capture()) } - } - assertThat(capturedMeasurementRequest) - .ignoringRepeatedFieldOrder() - .ignoringFieldDescriptors( - MEASUREMENT_SPEC_FIELD_DESCRIPTOR, - ENCRYPTED_REQUISITION_SPEC_FIELD_DESCRIPTOR, - ) - .isEqualTo(REACH_FREQUENCY_SINGLE_DATA_PROVIDER_MEASUREMENT_REQUEST) - - verifyMeasurementSpec( - capturedMeasurementRequest.measurement.measurementSpec, - MEASUREMENT_CONSUMER_CERTIFICATE, - TRUSTED_MEASUREMENT_CONSUMER_ISSUER, - ) - val measurementSpec: MeasurementSpec = - capturedMeasurementRequest.measurement.measurementSpec.unpack() - assertThat(measurementSpec).isEqualTo(REACH_FREQUENCY_SINGLE_DATA_PROVIDER_MEASUREMENT_SPEC) - } - - @Test - fun `createReport creates reach and requency measurement when report needs frequency`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_FREQUENCY_HISTOGRAM_REPORT.copy { clearState() } - } - - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - - // Verify proto argument of MeasurementsCoroutineImplBase::createMeasurement - val capturedMeasurementRequest = - captureFirst { - runBlocking { verify(measurementsMock).createMeasurement(capture()) } - } - assertThat(capturedMeasurementRequest) - .ignoringRepeatedFieldOrder() - .ignoringFieldDescriptors( - MEASUREMENT_SPEC_FIELD_DESCRIPTOR, - ENCRYPTED_REQUISITION_SPEC_FIELD_DESCRIPTOR, - ) - .isEqualTo(REACH_FREQUENCY_MEASUREMENT_REQUEST) - - verifyMeasurementSpec( - capturedMeasurementRequest.measurement.measurementSpec, - MEASUREMENT_CONSUMER_CERTIFICATE, - TRUSTED_MEASUREMENT_CONSUMER_ISSUER, - ) - val measurementSpec: MeasurementSpec = - capturedMeasurementRequest.measurement.measurementSpec.unpack() - assertThat(measurementSpec).isEqualTo(REACH_FREQUENCY_MEASUREMENT_SPEC) - } - - @Test - fun `createReport creates impression measurement when report needs impression`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_IMPRESSION_REPORT.copy { clearState() } - } - - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - - // Verify proto argument of MeasurementsCoroutineImplBase::createMeasurement - val capturedMeasurementRequest = - captureFirst { - runBlocking { verify(measurementsMock).createMeasurement(capture()) } - } - assertThat(capturedMeasurementRequest) - .ignoringRepeatedFieldOrder() - .ignoringFieldDescriptors( - MEASUREMENT_SPEC_FIELD_DESCRIPTOR, - ENCRYPTED_REQUISITION_SPEC_FIELD_DESCRIPTOR, - ) - .isEqualTo(IMPRESSION_MEASUREMENT_REQUEST) - - verifyMeasurementSpec( - capturedMeasurementRequest.measurement.measurementSpec, - MEASUREMENT_CONSUMER_CERTIFICATE, - TRUSTED_MEASUREMENT_CONSUMER_ISSUER, - ) - val measurementSpec: MeasurementSpec = - capturedMeasurementRequest.measurement.measurementSpec.unpack() - assertThat(measurementSpec).isEqualTo(IMPRESSION_MEASUREMENT_SPEC) - } - - @Test - fun `createReport creates duration measurement when report needs duration`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_WATCH_DURATION_REPORT.copy { clearState() } - } - - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - - // Verify proto argument of MeasurementsCoroutineImplBase::createMeasurement - val capturedMeasurementRequest = - captureFirst { - runBlocking { verify(measurementsMock).createMeasurement(capture()) } - } - assertThat(capturedMeasurementRequest) - .ignoringRepeatedFieldOrder() - .ignoringFieldDescriptors( - MEASUREMENT_SPEC_FIELD_DESCRIPTOR, - ENCRYPTED_REQUISITION_SPEC_FIELD_DESCRIPTOR, - ) - .isEqualTo(WATCH_DURATION_MEASUREMENT_REQUEST) - - verifyMeasurementSpec( - capturedMeasurementRequest.measurement.measurementSpec, - MEASUREMENT_CONSUMER_CERTIFICATE, - TRUSTED_MEASUREMENT_CONSUMER_ISSUER, - ) - val measurementSpec: MeasurementSpec = - capturedMeasurementRequest.measurement.measurementSpec.unpack() - assertThat(measurementSpec).isEqualTo(WATCH_DURATION_MEASUREMENT_SPEC) - } - - @Test - fun `createReport returns a report of reach when no event filter at all`(): Unit = runBlocking { - val internalReportingSets: List = - INTERNAL_REPORTING_SETS.map { internalReportingSet -> - internalReportingSet.copy { - clearFilter() - displayName = "$measurementConsumerReferenceId-$externalReportingSetId-$filter" - } - } - - whenever( - internalReportingSetsMock.batchGetReportingSet( - eq( - batchGetReportingSetRequest { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetIds += internalReportingSets[0].externalReportingSetId - externalReportingSetIds += internalReportingSets[1].externalReportingSetId - } - ) - ) - ) - .thenReturn( - flowOf( - internalReportingSets[0], - internalReportingSets[1], - internalReportingSets[0], - internalReportingSets[1], - ) - ) - - val requestingReport = - PENDING_REACH_REPORT.copy { - clearState() - eventGroupUniverse = - EVENT_GROUP_UNIVERSE.copy { - eventGroupEntries.clear() - eventGroupEntries += - COVERED_EVENT_GROUP_KEYS.map { - EventGroupUniverseKt.eventGroupEntry { key = it.toName() } - } - } - } - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = requestingReport - } - - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - - // Verify proto argument of MeasurementsCoroutineImplBase::createMeasurement - val dataProviderEntries = - REQUISITION_SPECS.mapValues { (dataProviderKey, requisitionSpec) -> - val dataProvider = DATA_PROVIDERS.getValue(dataProviderKey) - dataProviderEntry { - key = dataProvider.name - - val requisitionSpecWithNoFilter = - requisitionSpec.copy { - events = - RequisitionSpecKt.events { - val eventGroupsWithNoFilter = - eventGroups.map { eventGroup -> - eventGroup.copy { - value = - EventGroupEntryKt.value { - collectionInterval = MEASUREMENT_TIME_INTERVAL - filter = eventFilter { expression = "" } - } - } - } - eventGroups.clear() - eventGroups += eventGroupsWithNoFilter - } - } - value = - DataProviderEntryKt.value { - dataProviderCertificate = dataProvider.certificate - dataProviderPublicKey = dataProvider.publicKey.message - encryptedRequisitionSpec = - encryptRequisitionSpec( - signRequisitionSpec( - requisitionSpecWithNoFilter, - MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE, - ), - dataProvider.publicKey.unpack(), - ) - nonceHash = Hashing.hashSha256(requisitionSpecWithNoFilter.nonce) - } - } - } - - val reachMeasurementRequest = - REACH_MEASUREMENT_REQUEST.copy { - measurement = - measurement.copy { - dataProviders.clear() - dataProviders += - DATA_PROVIDER_KEYS_IN_SET_OPERATION.map { dataProviderEntries.getValue(it) } - } - } - - val capturedMeasurementRequest = - captureFirst { - runBlocking { verify(measurementsMock).createMeasurement(capture()) } - } - assertThat(capturedMeasurementRequest) - .ignoringRepeatedFieldOrder() - .ignoringFieldDescriptors( - MEASUREMENT_SPEC_FIELD_DESCRIPTOR, - ENCRYPTED_REQUISITION_SPEC_FIELD_DESCRIPTOR, - ) - .isEqualTo(reachMeasurementRequest) - - verifyMeasurementSpec( - capturedMeasurementRequest.measurement.measurementSpec, - MEASUREMENT_CONSUMER_CERTIFICATE, - TRUSTED_MEASUREMENT_CONSUMER_ISSUER, - ) - val measurementSpec: MeasurementSpec = - capturedMeasurementRequest.measurement.measurementSpec.unpack() - assertThat(measurementSpec).isEqualTo(REACH_MEASUREMENT_SPEC) - - val dataProvidersList = - capturedMeasurementRequest.measurement.dataProvidersList.sortedBy { it.key } - - val filters = - dataProvidersList.flatMap { dataProviderEntry -> - val signedRequisitionSpec = - decryptRequisitionSpec( - dataProviderEntry.value.encryptedRequisitionSpec, - DATA_PROVIDER_PRIVATE_KEY_HANDLE, - ) - val requisitionSpec: RequisitionSpec = signedRequisitionSpec.unpack() - - verifyRequisitionSpec( - signedRequisitionSpec, - requisitionSpec, - measurementSpec, - MEASUREMENT_CONSUMER_CERTIFICATE, - TRUSTED_MEASUREMENT_CONSUMER_ISSUER, - ) - - requisitionSpec.events.eventGroupsList.map { eventGroupEntry -> - eventGroupEntry.value.filter.expression - } - } - - for (filter in filters) { - assertThat(filter).isEqualTo("") - } - } - - @Test - fun `createReport returns a report of reach with RUNNING state when timeIntervals set`() { - val internalReport = - INTERNAL_PENDING_REACH_REPORT.copy { - clearTime() - timeIntervals = internalTimeIntervals { - timeIntervals += internalTimeInterval { - startTime = START_TIME - endTime = Timestamps.add(START_TIME, TIME_INTERVAL_INCREMENT) - } - } - } - runBlocking { whenever(internalReportsMock.createReport(any())).thenReturn(internalReport) } - - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - clearTime() - timeIntervals = timeIntervals { - timeIntervals += timeInterval { - startTime = START_TIME - endTime = Timestamps.add(START_TIME, TIME_INTERVAL_INCREMENT) - } - } - } - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - - val expected = - PENDING_REACH_REPORT.copy { - clearTime() - timeIntervals = timeIntervals { - timeIntervals += timeInterval { - startTime = START_TIME - endTime = Timestamps.add(START_TIME, TIME_INTERVAL_INCREMENT) - } - } - } - - // Verify proto argument of ReportsCoroutineImplBase::getReportByIdempotencyKey - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::getReportByIdempotencyKey) - .isEqualTo( - getReportByIdempotencyKeyRequest { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - reportIdempotencyKey = REACH_REPORT_IDEMPOTENCY_KEY - } - ) - - // Verify proto argument of InternalReportingSetsCoroutineImplBase::batchGetReportingSet - verifyProtoArgument( - internalReportingSetsMock, - InternalReportingSetsCoroutineImplBase::batchGetReportingSet, - ) - .ignoringRepeatedFieldOrder() - .isEqualTo( - batchGetReportingSetRequest { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetIds += INTERNAL_REPORTING_SETS[0].externalReportingSetId - externalReportingSetIds += INTERNAL_REPORTING_SETS[1].externalReportingSetId - } - ) - - // Verify proto argument of MeasurementConsumersCoroutineImplBase::getMeasurementConsumer - verifyProtoArgument( - measurementConsumersMock, - MeasurementConsumersCoroutineImplBase::getMeasurementConsumer, - ) - .isEqualTo(getMeasurementConsumerRequest { name = MEASUREMENT_CONSUMERS.values.first().name }) - - // Verify proto argument of DataProvidersCoroutineImplBase::getDataProvider - val dataProvidersCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(dataProvidersMock, times(2)) { getDataProvider(dataProvidersCaptor.capture()) } - val capturedDataProviderRequests = dataProvidersCaptor.allValues - assertThat(capturedDataProviderRequests) - .containsExactly( - getDataProviderRequest { name = DATA_PROVIDERS_LIST[0].name }, - getDataProviderRequest { name = DATA_PROVIDERS_LIST[1].name }, - ) - - // Verify proto argument of MeasurementsCoroutineImplBase::createMeasurement - val capturedMeasurementRequest = - captureFirst { - runBlocking { verify(measurementsMock).createMeasurement(capture()) } - } - - assertThat(capturedMeasurementRequest) - .ignoringRepeatedFieldOrder() - .ignoringFieldDescriptors( - MEASUREMENT_SPEC_FIELD_DESCRIPTOR, - ENCRYPTED_REQUISITION_SPEC_FIELD_DESCRIPTOR, - ) - .isEqualTo(REACH_MEASUREMENT_REQUEST) - - verifyMeasurementSpec( - capturedMeasurementRequest.measurement.measurementSpec, - MEASUREMENT_CONSUMER_CERTIFICATE, - TRUSTED_MEASUREMENT_CONSUMER_ISSUER, - ) - val measurementSpec: MeasurementSpec = - capturedMeasurementRequest.measurement.measurementSpec.unpack() - assertThat(measurementSpec).isEqualTo(REACH_MEASUREMENT_SPEC) - - val dataProvidersList = - capturedMeasurementRequest.measurement.dataProvidersList.sortedBy { it.key } - - dataProvidersList.map { dataProviderEntry -> - val signedRequisitionSpec = - decryptRequisitionSpec( - dataProviderEntry.value.encryptedRequisitionSpec, - DATA_PROVIDER_PRIVATE_KEY_HANDLE, - ) - val requisitionSpec: RequisitionSpec = signedRequisitionSpec.unpack() - verifyRequisitionSpec( - signedRequisitionSpec, - requisitionSpec, - measurementSpec, - MEASUREMENT_CONSUMER_CERTIFICATE, - TRUSTED_MEASUREMENT_CONSUMER_ISSUER, - ) - } - - // Verify proto argument of InternalMeasurementsCoroutineImplBase::batchCreateMeasurements - verifyProtoArgument( - internalMeasurementsMock, - InternalMeasurementsCoroutineImplBase::batchCreateMeasurements, - ) - .isEqualTo( - batchCreateMeasurementsRequest { measurements += INTERNAL_PENDING_REACH_MEASUREMENT } - ) - - // Verify proto argument of InternalReportsCoroutineImplBase::createReport - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::createReport) - .ignoringRepeatedFieldOrder() - .isEqualTo( - internalCreateReportRequest { - report = - internalReport.copy { - clearState() - clearExternalReportId() - measurements.clear() - clearCreateTime() - } - measurements += - InternalCreateReportRequestKt.measurementKey { - measurementConsumerReferenceId = REACH_MEASUREMENT_KEY.measurementConsumerId - measurementReferenceId = REACH_MEASUREMENT_KEY.measurementId - } - } - ) - - assertThat(result).isEqualTo(expected) - } - - @Test - fun `createReport returns a report with a cumulative metric`() { - val internalCumulativeReport = - INTERNAL_PENDING_REACH_REPORT.copy { - metrics.clear() - metrics += - INTERNAL_PENDING_REACH_REPORT.metricsList[0].copy { - details = - INTERNAL_PENDING_REACH_REPORT.metricsList[0].details.copy { cumulative = true } - } - } - runBlocking { - whenever(internalReportsMock.createReport(any())).thenReturn(internalCumulativeReport) - } - - val cumulativeReport = - PENDING_REACH_REPORT.copy { - metrics.clear() - metrics += PENDING_REACH_REPORT.metricsList[0].copy { cumulative = true } - } - - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = cumulativeReport.copy { clearState() } - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - - // Verify proto argument of ReportsCoroutineImplBase::getReportByIdempotencyKey - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::getReportByIdempotencyKey) - .isEqualTo( - getReportByIdempotencyKeyRequest { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - reportIdempotencyKey = REACH_REPORT_IDEMPOTENCY_KEY - } - ) - - // Verify proto argument of InternalReportingSetsCoroutineImplBase::batchGetReportingSet - verifyProtoArgument( - internalReportingSetsMock, - InternalReportingSetsCoroutineImplBase::batchGetReportingSet, - ) - .ignoringRepeatedFieldOrder() - .isEqualTo( - batchGetReportingSetRequest { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetIds += INTERNAL_REPORTING_SETS[0].externalReportingSetId - externalReportingSetIds += INTERNAL_REPORTING_SETS[1].externalReportingSetId - } - ) - - // Verify proto argument of MeasurementConsumersCoroutineImplBase::getMeasurementConsumer - verifyProtoArgument( - measurementConsumersMock, - MeasurementConsumersCoroutineImplBase::getMeasurementConsumer, - ) - .isEqualTo(getMeasurementConsumerRequest { name = MEASUREMENT_CONSUMERS.values.first().name }) - - // Verify proto argument of DataProvidersCoroutineImplBase::getDataProvider - val dataProvidersCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(dataProvidersMock, times(2)) { getDataProvider(dataProvidersCaptor.capture()) } - val capturedDataProviderRequests = dataProvidersCaptor.allValues - assertThat(capturedDataProviderRequests) - .containsExactly( - getDataProviderRequest { name = DATA_PROVIDERS_LIST[0].name }, - getDataProviderRequest { name = DATA_PROVIDERS_LIST[1].name }, - ) - - // Verify proto argument of MeasurementsCoroutineImplBase::createMeasurement - val capturedMeasurementRequest = - captureFirst { - runBlocking { verify(measurementsMock).createMeasurement(capture()) } - } - assertThat(capturedMeasurementRequest) - .ignoringRepeatedFieldOrder() - .ignoringFieldDescriptors( - MEASUREMENT_SPEC_FIELD_DESCRIPTOR, - ENCRYPTED_REQUISITION_SPEC_FIELD_DESCRIPTOR, - ) - .isEqualTo(REACH_MEASUREMENT_REQUEST) - - verifyMeasurementSpec( - capturedMeasurementRequest.measurement.measurementSpec, - MEASUREMENT_CONSUMER_CERTIFICATE, - TRUSTED_MEASUREMENT_CONSUMER_ISSUER, - ) - val measurementSpec: MeasurementSpec = - capturedMeasurementRequest.measurement.measurementSpec.unpack() - assertThat(measurementSpec).isEqualTo(REACH_MEASUREMENT_SPEC) - - val dataProvidersList = - capturedMeasurementRequest.measurement.dataProvidersList.sortedBy { it.key } - - dataProvidersList.map { dataProviderEntry -> - val signedRequisitionSpec = - decryptRequisitionSpec( - dataProviderEntry.value.encryptedRequisitionSpec, - DATA_PROVIDER_PRIVATE_KEY_HANDLE, - ) - val requisitionSpec: RequisitionSpec = signedRequisitionSpec.unpack() - verifyRequisitionSpec( - signedRequisitionSpec, - requisitionSpec, - measurementSpec, - MEASUREMENT_CONSUMER_CERTIFICATE, - TRUSTED_MEASUREMENT_CONSUMER_ISSUER, - ) - } - - // Verify proto argument of InternalMeasurementsCoroutineImplBase::batchCreateMeasurements - verifyProtoArgument( - internalMeasurementsMock, - InternalMeasurementsCoroutineImplBase::batchCreateMeasurements, - ) - .isEqualTo( - batchCreateMeasurementsRequest { measurements += INTERNAL_PENDING_REACH_MEASUREMENT } - ) - - // Verify proto argument of InternalReportsCoroutineImplBase::createReport - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::createReport) - .ignoringRepeatedFieldOrder() - .isEqualTo( - internalCreateReportRequest { - report = - internalCumulativeReport.copy { - clearState() - clearExternalReportId() - measurements.clear() - clearCreateTime() - } - measurements += - InternalCreateReportRequestKt.measurementKey { - measurementConsumerReferenceId = REACH_MEASUREMENT_KEY.measurementConsumerId - measurementReferenceId = REACH_MEASUREMENT_KEY.measurementId - } - } - ) - - assertThat(result).isEqualTo(cumulativeReport) - } - - @Test - fun `createReport returns a report with set operation type DIFFERENCE`() { - val internalPendingReachReportWithSetDifference = - INTERNAL_PENDING_REACH_REPORT.copy { - val source = this - measurements.clear() - clearCreateTime() - val metric = internalMetric { - details = InternalMetricKt.details { reach = InternalMetricKt.reachParams {} } - namedSetOperations += - source.metrics[0].namedSetOperationsList[0].copy { - setOperation = - setOperation.copy { type = InternalMetric.SetOperation.Type.DIFFERENCE } - measurementCalculations.clear() - measurementCalculations += - source.metrics[0].namedSetOperationsList[0].measurementCalculationsList[0].copy { - weightedMeasurements.clear() - weightedMeasurements += weightedMeasurement { - measurementReferenceId = REACH_MEASUREMENT_KEY.measurementId - coefficient = -1 - } - weightedMeasurements += weightedMeasurement { - measurementReferenceId = REACH_MEASUREMENT_KEY_2.measurementId - coefficient = 1 - } - } - } - } - metrics.clear() - metrics += metric - } - - runBlocking { - whenever(internalReportsMock.createReport(any())) - .thenReturn(internalPendingReachReportWithSetDifference) - whenever(measurementsMock.createMeasurement(any())) - .thenReturn(BASE_REACH_MEASUREMENT, BASE_REACH_MEASUREMENT_2) - } - - val pendingReachReportWithSetDifference = - PENDING_REACH_REPORT.copy { - metrics.clear() - metrics += metric { - reach = reachParams {} - cumulative = false - setOperations += namedSetOperation { - uniqueName = REACH_SET_OPERATION_UNIQUE_NAME - setOperation = setOperation { - type = SetOperation.Type.DIFFERENCE - lhs = - SetOperationKt.operand { reportingSet = INTERNAL_REPORTING_SETS[0].resourceName } - rhs = - SetOperationKt.operand { reportingSet = INTERNAL_REPORTING_SETS[1].resourceName } - } - } - } - } - - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = pendingReachReportWithSetDifference - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - - // Verify proto argument of ReportsCoroutineImplBase::getReportByIdempotencyKey - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::getReportByIdempotencyKey) - .isEqualTo( - getReportByIdempotencyKeyRequest { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - reportIdempotencyKey = REACH_REPORT_IDEMPOTENCY_KEY - } - ) - - // Verify proto argument of InternalReportingSetsCoroutineImplBase::batchGetReportingSet - verifyProtoArgument( - internalReportingSetsMock, - InternalReportingSetsCoroutineImplBase::batchGetReportingSet, - ) - .ignoringRepeatedFieldOrder() - .isEqualTo( - batchGetReportingSetRequest { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetIds += INTERNAL_REPORTING_SETS[0].externalReportingSetId - externalReportingSetIds += INTERNAL_REPORTING_SETS[1].externalReportingSetId - } - ) - - // Verify proto argument of MeasurementConsumersCoroutineImplBase::getMeasurementConsumer - verifyProtoArgument( - measurementConsumersMock, - MeasurementConsumersCoroutineImplBase::getMeasurementConsumer, - ) - .isEqualTo(getMeasurementConsumerRequest { name = MEASUREMENT_CONSUMERS.values.first().name }) - - // Verify proto argument of DataProvidersCoroutineImplBase::getDataProvider - val dataProvidersCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(dataProvidersMock, times(2)) { getDataProvider(dataProvidersCaptor.capture()) } - val capturedDataProviderRequests = dataProvidersCaptor.allValues - assertThat(capturedDataProviderRequests) - .containsExactly( - getDataProviderRequest { name = DATA_PROVIDERS_LIST[0].name }, - getDataProviderRequest { name = DATA_PROVIDERS_LIST[1].name }, - ) - - // Verify proto argument of MeasurementsCoroutineImplBase::createMeasurement - val measurementCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(measurementsMock, times(2)) { createMeasurement(measurementCaptor.capture()) } - assertThat(measurementCaptor.allValues.map { it.measurement }).containsNoDuplicates() - - // Verify proto argument of InternalMeasurementsCoroutineImplBase::batchCreateMeasurements - verifyProtoArgument( - internalMeasurementsMock, - InternalMeasurementsCoroutineImplBase::batchCreateMeasurements, - ) - .ignoringRepeatedFieldOrder() - .isEqualTo( - batchCreateMeasurementsRequest { - measurements += INTERNAL_PENDING_REACH_MEASUREMENT - measurements += - INTERNAL_PENDING_REACH_MEASUREMENT.copy { - measurementReferenceId = REACH_MEASUREMENT_KEY_2.measurementId - } - } - ) - - // Verify proto argument of InternalReportsCoroutineImplBase::createReport - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::createReport) - .ignoringRepeatedFieldOrder() - .isEqualTo( - internalCreateReportRequest { - report = - internalPendingReachReportWithSetDifference.copy { - clearState() - clearExternalReportId() - } - measurements += - InternalCreateReportRequestKt.measurementKey { - measurementConsumerReferenceId = REACH_MEASUREMENT_KEY.measurementConsumerId - measurementReferenceId = REACH_MEASUREMENT_KEY.measurementId - } - measurements += - InternalCreateReportRequestKt.measurementKey { - measurementConsumerReferenceId = REACH_MEASUREMENT_KEY_2.measurementConsumerId - measurementReferenceId = REACH_MEASUREMENT_KEY_2.measurementId - } - } - ) - - assertThat(result).isEqualTo(pendingReachReportWithSetDifference) - } - - @Test - fun `createReport succeeds when the internal createMeasurement throws ALREADY_EXISTS`() = - runBlocking { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - - val report = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - assertThat(report.state).isEqualTo(Report.State.RUNNING) - } - - @Test - fun `createReport throws UNAUTHENTICATED when no principal is found`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - val exception = - assertFailsWith { runBlocking { service.createReport(request) } } - assertThat(exception.status.code).isEqualTo(Status.Code.UNAUTHENTICATED) - } - - @Test - fun `createReport throws PERMISSION_DENIED when MeasurementConsumer caller doesn't match`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.last().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.PERMISSION_DENIED) - assertThat(exception.status.description) - .isEqualTo("Cannot create a Report for another MeasurementConsumer.") - } - - @Test - fun `createReport throws PERMISSION_DENIED when report doesn't belong to caller`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.last().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.PERMISSION_DENIED) - assertThat(exception.status.description) - .isEqualTo("Cannot create a Report for another MeasurementConsumer.") - } - - @Test - fun `createReport throws UNAUTHENTICATED when the caller is not MeasurementConsumer`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - val exception = - assertFailsWith { - withDataProviderPrincipal(DATA_PROVIDERS_LIST[0].name) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.UNAUTHENTICATED) - assertThat(exception.status.description).isEqualTo("No ReportingPrincipal found") - } - - @Test - fun `createReport throws INVALID_ARGUMENT when parent is unspecified`() { - val request = createReportRequest { report = PENDING_REACH_REPORT.copy { clearState() } } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - assertThat(exception.status.description).isEqualTo("Parent is either unspecified or invalid.") - } - - @Test - fun `createReport throws INVALID_ARGUMENT when report is unspecified`() { - val request = createReportRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - assertThat(exception.status.description).isEqualTo("Report is not specified.") - } - - @Test - fun `createReport throws INVALID_ARGUMENT when reportIdempotencyKey is unspecified`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - clearReportIdempotencyKey() - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - assertThat(exception.status.description).isEqualTo("ReportIdempotencyKey is not specified.") - } - - @Test - fun `createReport throws INVALID_ARGUMENT when eventGroupUniverse in Report is unspecified`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - clearEventGroupUniverse() - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - assertThat(exception.status.description).isEqualTo("EventGroupUniverse is not specified.") - } - - @Test - fun `createReport throws INVALID_ARGUMENT when eventGroupUniverse entries list is empty`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - eventGroupUniverse = eventGroupUniverse { eventGroupEntries.clear() } - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `createReport throws INVALID_ARGUMENT when eventGroupUniverse entry is missing key`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - eventGroupUniverse = eventGroupUniverse { - eventGroupEntries += EventGroupUniverseKt.eventGroupEntry {} - } - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `createReport throws INVALID_ARGUMENT when setOperationName duplicate for same metricType`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - metrics.add(REACH_METRIC) - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - assertThat(exception.status.description) - .isEqualTo("The names of the set operations within the same metric type should be unique.") - } - - @Test - fun `createReport throws INVALID_ARGUMENT when time in Report is unspecified`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - clearTime() - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - assertThat(exception.status.description).isEqualTo("The time in Report is not specified.") - } - - @Test - fun `createReport throws INVALID_ARGUMENT when TimeIntervals is set and cumulative is true`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - clearTime() - timeIntervals = timeIntervals { - timeIntervals += timeInterval { - startTime = timestamp { seconds = 1 } - endTime = timestamp { seconds = 5 } - } - } - metrics.clear() - metrics += PENDING_REACH_REPORT.metricsList[0].copy { cumulative = true } - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `createReport throws INVALID_ARGUMENT when TimeIntervals timeIntervalsList is empty`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - clearTime() - timeIntervals = timeIntervals {} - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `createReport throws INVALID_ARGUMENT when TimeInterval startTime is unspecified`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - clearTime() - timeIntervals = timeIntervals { - timeIntervals += timeInterval { endTime = timestamp { seconds = 5 } } - } - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `createReport throws INVALID_ARGUMENT when TimeInterval endTime is unspecified`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - clearTime() - timeIntervals = timeIntervals { - timeIntervals += timeInterval { startTime = timestamp { seconds = 5 } } - } - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `createReport throws INVALID_ARGUMENT when TimeInterval endTime is before startTime`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - clearTime() - timeIntervals = timeIntervals { - timeIntervals += timeInterval { - startTime = timestamp { - seconds = 5 - nanos = 5 - } - endTime = timestamp { - seconds = 5 - nanos = 1 - } - } - } - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `createReport throws INVALID_ARGUMENT when PeriodicTimeInterval startTime is unspecified`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - clearTime() - periodicTimeInterval = periodicTimeInterval { - increment = duration { seconds = 5 } - intervalCount = 3 - } - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `createReport throws INVALID_ARGUMENT when PeriodicTimeInterval increment is unspecified`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - clearTime() - periodicTimeInterval = periodicTimeInterval { - startTime = timestamp { - seconds = 5 - nanos = 5 - } - intervalCount = 3 - } - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `createReport throws INVALID_ARGUMENT when PeriodicTimeInterval intervalCount is 0`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - clearTime() - periodicTimeInterval = periodicTimeInterval { - startTime = timestamp { - seconds = 5 - nanos = 5 - } - increment = duration { seconds = 5 } - } - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `createReport throws INVALID_ARGUMENT when any metric type in Report is unspecified`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - metrics.clear() - metrics.add(REACH_METRIC.copy { clearReach() }) - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - assertThat(exception.status.description) - .isEqualTo("The metric type in Report is not specified.") - } - - @Test - fun `createReport throws INVALID_ARGUMENT when metrics list is empty`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - metrics.clear() - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `createReport throws INVALID_ARGUMENT when namedSetOperation uniqueName is unspecified`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - metrics.clear() - metrics += metric { - reach = reachParams {} - setOperations += namedSetOperation { - setOperation = setOperation { - type = SetOperation.Type.UNION - lhs = - SetOperationKt.operand { reportingSet = INTERNAL_REPORTING_SETS[0].resourceName } - } - } - } - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `createReport throws INVALID_ARGUMENT when setOperation type is unspecified`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - metrics.clear() - metrics += metric { - reach = reachParams {} - setOperations += namedSetOperation { - uniqueName = "name" - setOperation = setOperation { - lhs = - SetOperationKt.operand { reportingSet = INTERNAL_REPORTING_SETS[0].resourceName } - } - } - } - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `createReport throws INVALID_ARGUMENT when setOperation lhs is unspecified`() { - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - metrics.clear() - metrics += metric { - reach = reachParams {} - setOperations += namedSetOperation { - uniqueName = "name" - setOperation = setOperation { type = SetOperation.Type.UNION } - } - } - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `createReport throws INVALID_ARGUMENT when provided reporting set name is invalid`() { - val invalidMetric = metric { - reach = reachParams {} - cumulative = false - setOperations.add( - NAMED_REACH_SET_OPERATION.copy { setOperation = SET_OPERATION_WITH_INVALID_REPORTING_SET } - ) - } - - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - metrics.clear() - metrics.add(invalidMetric) - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - assertThat(exception.status.description) - .isEqualTo("Invalid reporting set name $INVALID_REPORTING_SET_NAME.") - } - - @Test - fun `createReport throws INVALID_ARGUMENT when any reporting set is not accessible to caller`() { - val invalidMetric = metric { - reach = reachParams {} - cumulative = false - setOperations.add( - NAMED_REACH_SET_OPERATION.copy { - setOperation = SET_OPERATION_WITH_INACCESSIBLE_REPORTING_SET - } - ) - } - - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = - PENDING_REACH_REPORT.copy { - clearState() - metrics.clear() - metrics.add(invalidMetric) - } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - assertThat(exception.status.description) - .isEqualTo("No access to the reporting set [$REPORTING_SET_NAME_FOR_MC_2].") - } - - @Test - fun `createReport throws INVALID_ARGUMENT when eventGroup isn't covered by eventGroupUniverse`() = - runBlocking { - whenever(internalReportingSetsMock.batchGetReportingSet(any())) - .thenReturn(flowOf(INTERNAL_REPORTING_SETS[0], UNCOVERED_INTERNAL_REPORTING_SET)) - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - val expectedExceptionDescription = - "The event group [$UNCOVERED_EVENT_GROUP_NAME] in the reporting set" + - " [${UNCOVERED_INTERNAL_REPORTING_SET.displayName}] is not included in the event group " + - "universe." - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - assertThat(exception.status.description).isEqualTo(expectedExceptionDescription) - } - - @Test - fun `createReport throws NOT_FOUND when reporting set is not found`() = runBlocking { - whenever(internalReportingSetsMock.batchGetReportingSet(any())) - .thenReturn(flowOf(INTERNAL_REPORTING_SETS[0])) - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.NOT_FOUND) - } - - @Test - fun `createReport throws FAILED_PRECONDITION when EDP cert is revoked`() = runBlocking { - val dataProvider = DATA_PROVIDERS.values.first() - whenever( - certificateMock.getCertificate( - eq(getCertificateRequest { name = dataProvider.certificate }) - ) - ) - .thenReturn( - certificate { - name = dataProvider.certificate - x509Der = DATA_PROVIDER_SIGNING_KEY.certificate.encoded.toByteString() - revocationState = Certificate.RevocationState.REVOKED - } - ) - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - - assertThat(exception).hasMessageThat().ignoringCase().contains("revoked") - } - - @Test - fun `createReport throws FAILED_PRECONDITION when EDP public key signature is invalid`() = - runBlocking { - val dataProvider = DATA_PROVIDERS.values.first() - whenever( - dataProvidersMock.getDataProvider(eq(getDataProviderRequest { name = dataProvider.name })) - ) - .thenReturn( - dataProvider.copy { - publicKey = publicKey.copy { signature = "invalid sig".toByteStringUtf8() } - } - ) - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - - assertThat(exception).hasMessageThat().ignoringCase().contains("signature") - } - - @Test - fun `createReport throws exception from getReportByIdempotencyKey when status isn't NOT_FOUND`() = - runBlocking { - whenever(internalReportsMock.getReportByIdempotencyKey(any())) - .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) - - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - - val exception = - assertFailsWith(Exception::class) { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - val expectedExceptionDescription = - "Unable to retrieve a report from the reporting database using the provided " + - "reportIdempotencyKey [${PENDING_REACH_REPORT.reportIdempotencyKey}]." - assertThat(exception.message).isEqualTo(expectedExceptionDescription) - } - - @Test - fun `createReport throws exception when internal createReport throws exception`() = runBlocking { - val status = Status.INVALID_ARGUMENT.withDescription("Bad CreateReport request") - whenever(internalReportsMock.createReport(any())).thenThrow(StatusRuntimeException(status)) - - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.cause).isInstanceOf(StatusException::class.java) - val actualStatus = (exception.cause as StatusException).status - assertThat(actualStatus.code).isEqualTo(status.code) - assertThat(actualStatus.description).isEqualTo(status.description) - } - - @Test - fun `createReport throws exception when the CMM createMeasurement throws exception`() = - runBlocking { - whenever(measurementsMock.createMeasurement(any())) - .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) - - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - - val exception = - assertFailsWith(Exception::class) { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.message).contains(REACH_MEASUREMENT_CREATE_REQUEST_ID) - } - - @Test - fun `createReport throws exception when the internal createMeasurement throws exception`() = - runBlocking { - whenever(internalMeasurementsMock.batchCreateMeasurements(any())) - .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) - - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.UNKNOWN) - assertThat(exception.message).contains("Unable to create measurement") - } - - @Test - fun `createReport throws exception when getMeasurementConsumer throws exception`() = runBlocking { - whenever(measurementConsumersMock.getMeasurementConsumer(any())) - .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) - - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - - val exception = - assertFailsWith(Exception::class) { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - val expectedExceptionDescription = - "Unable to retrieve the measurement consumer [${MEASUREMENT_CONSUMERS.values.first().name}]." - assertThat(exception.message).isEqualTo(expectedExceptionDescription) - } - - @Test - fun `createReport throws exception when the internal batchGetReportingSet throws exception`(): - Unit = runBlocking { - whenever(internalReportingSetsMock.batchGetReportingSet(any())) - .thenThrow(StatusRuntimeException(Status.UNKNOWN)) - - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - - assertFails { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - } - - @Test - fun `createReport throws exception when getDataProvider throws exception`() = runBlocking { - whenever(dataProvidersMock.getDataProvider(any())) - .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) - - val request = createReportRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - report = PENDING_REACH_REPORT.copy { clearState() } - } - - val exception = - assertFailsWith(Exception::class) { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createReport(request) } - } - } - assertThat(exception).hasMessageThat().contains("dataProviders/") - } - - @Test - fun `listReports returns without a next page token when there is no previous page token`() { - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - - val expected = listReportsResponse { - reports.add(SUCCEEDED_REACH_REPORT) - reports.add(SUCCEEDED_IMPRESSION_REPORT) - reports.add(SUCCEEDED_WATCH_DURATION_REPORT) - reports.add(SUCCEEDED_FREQUENCY_HISTOGRAM_REPORT) - } - - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::streamReports) - .isEqualTo( - streamReportsRequest { - limit = DEFAULT_PAGE_SIZE + 1 - this.filter = filter { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - } - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReports returns with a next page token when there is no previous page token`() { - val request = listReportsRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - pageSize = PAGE_SIZE - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - - val expected = listReportsResponse { - reports.add(SUCCEEDED_REACH_REPORT) - reports.add(SUCCEEDED_IMPRESSION_REPORT) - reports.add(SUCCEEDED_WATCH_DURATION_REPORT) - - nextPageToken = - listReportsPageToken { - pageSize = PAGE_SIZE - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - lastReport = previousPageEnd { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[2] - } - } - .toByteString() - .base64UrlEncode() - } - - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::streamReports) - .isEqualTo( - streamReportsRequest { - limit = PAGE_SIZE + 1 - this.filter = filter { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - } - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReports returns with a next page token when there is a previous page token`() { - val request = listReportsRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - pageSize = PAGE_SIZE - pageToken = - listReportsPageToken { - pageSize = PAGE_SIZE - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - lastReport = previousPageEnd { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[0] - } - } - .toByteString() - .base64UrlEncode() - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - - val expected = listReportsResponse { - reports.add(SUCCEEDED_REACH_REPORT) - reports.add(SUCCEEDED_IMPRESSION_REPORT) - reports.add(SUCCEEDED_WATCH_DURATION_REPORT) - - nextPageToken = - listReportsPageToken { - pageSize = PAGE_SIZE - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - lastReport = previousPageEnd { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[2] - } - } - .toByteString() - .base64UrlEncode() - } - - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::streamReports) - .isEqualTo( - streamReportsRequest { - limit = PAGE_SIZE + 1 - this.filter = filter { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportIdAfter = REPORT_EXTERNAL_IDS[0] - } - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReports with page size replaced with a valid value and no previous page token`() { - val invalidPageSize = MAX_PAGE_SIZE * 2 - val request = listReportsRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - pageSize = invalidPageSize - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - - val expected = listReportsResponse { - reports.add(SUCCEEDED_REACH_REPORT) - reports.add(SUCCEEDED_IMPRESSION_REPORT) - reports.add(SUCCEEDED_WATCH_DURATION_REPORT) - reports.add(SUCCEEDED_FREQUENCY_HISTOGRAM_REPORT) - } - - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::streamReports) - .isEqualTo( - streamReportsRequest { - limit = MAX_PAGE_SIZE + 1 - this.filter = filter { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - } - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReports with invalid page size replaced with the one in previous page token`() { - val invalidPageSize = MAX_PAGE_SIZE * 2 - val previousPageSize = PAGE_SIZE - val request = listReportsRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - pageSize = invalidPageSize - pageToken = - listReportsPageToken { - pageSize = previousPageSize - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - lastReport = previousPageEnd { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[0] - } - } - .toByteString() - .base64UrlEncode() - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - - val expected = listReportsResponse { - reports.add(SUCCEEDED_REACH_REPORT) - reports.add(SUCCEEDED_IMPRESSION_REPORT) - reports.add(SUCCEEDED_WATCH_DURATION_REPORT) - - nextPageToken = - listReportsPageToken { - pageSize = previousPageSize - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - lastReport = previousPageEnd { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[2] - } - } - .toByteString() - .base64UrlEncode() - } - - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::streamReports) - .isEqualTo( - streamReportsRequest { - limit = previousPageSize + 1 - this.filter = filter { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportIdAfter = REPORT_EXTERNAL_IDS[0] - } - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReports with page size replacing the one in previous page token`() { - val newPageSize = PAGE_SIZE - val previousPageSize = 1 - val request = listReportsRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - pageSize = newPageSize - pageToken = - listReportsPageToken { - pageSize = previousPageSize - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - lastReport = previousPageEnd { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[0] - } - } - .toByteString() - .base64UrlEncode() - } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - - val expected = listReportsResponse { - reports.add(SUCCEEDED_REACH_REPORT) - reports.add(SUCCEEDED_IMPRESSION_REPORT) - reports.add(SUCCEEDED_WATCH_DURATION_REPORT) - - nextPageToken = - listReportsPageToken { - pageSize = newPageSize - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - lastReport = previousPageEnd { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[2] - } - } - .toByteString() - .base64UrlEncode() - } - - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::streamReports) - .isEqualTo( - streamReportsRequest { - limit = newPageSize + 1 - this.filter = filter { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportIdAfter = REPORT_EXTERNAL_IDS[0] - } - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReports throws UNAUTHENTICATED when no principal is found`() { - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - val exception = - assertFailsWith { runBlocking { service.listReports(request) } } - assertThat(exception.status.code).isEqualTo(Status.Code.UNAUTHENTICATED) - } - - @Test - fun `listReports throws PERMISSION_DENIED when MeasurementConsumer caller doesn't match`() { - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.last().name, CONFIG) { - runBlocking { service.listReports(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.PERMISSION_DENIED) - assertThat(exception.status.description) - .isEqualTo("Cannot list Reports belonging to other MeasurementConsumers.") - } - - @Test - fun `listReports throws UNAUTHENTICATED when the caller is not MeasurementConsumer`() { - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - val exception = - assertFailsWith { - withDataProviderPrincipal(DATA_PROVIDERS.values.first().name) { - runBlocking { service.listReports(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.UNAUTHENTICATED) - assertThat(exception.status.description).isEqualTo("No ReportingPrincipal found") - } - - @Test - fun `listReports throws INVALID_ARGUMENT when page size is less than 0`() { - val request = listReportsRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - pageSize = -1 - } - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - assertThat(exception.status.description).isEqualTo("Page size cannot be less than 0") - } - - @Test - fun `listReports throws INVALID_ARGUMENT when parent is unspecified`() { - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(ListReportsRequest.getDefaultInstance()) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `listReports throws INVALID_ARGUMENT when mc id doesn't match one in page token`() { - val measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.last().measurementConsumerId - val request = listReportsRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - pageToken = - listReportsPageToken { - this.measurementConsumerReferenceId = measurementConsumerReferenceId - lastReport = previousPageEnd { - this.measurementConsumerReferenceId = measurementConsumerReferenceId - externalReportId = REPORT_EXTERNAL_IDS[0] - } - } - .toByteString() - .base64UrlEncode() - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `listReports throws Exception when the internal streamReports throws Exception`() = - runBlocking { - whenever(internalReportsMock.streamReports(any())) - .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) - - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - - val exception = - assertFailsWith(Exception::class) { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - } - val expectedExceptionDescription = "Unable to list reports from the reporting database." - assertThat(exception.message).isEqualTo(expectedExceptionDescription) - } - - @Test - fun `listReports throws Exception when the internal getReport throws Exception`() = runBlocking { - whenever(internalReportsMock.getReport(any())) - .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) - - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - - val exception = - assertFailsWith(Exception::class) { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - } - val expectedExceptionDescription = - "Unable to get the report [${REPORT_NAMES[0]}] from the reporting database." - assertThat(exception.message).isEqualTo(expectedExceptionDescription) - } - - @Test - fun `listReports throws Exception when the CMM getMeasurement throws Exception`() = runBlocking { - whenever(measurementsMock.getMeasurement(any())) - .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) - - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - - val exception = - assertFailsWith(Exception::class) { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - } - val expectedExceptionDescription = - "Unable to retrieve the measurement [${REACH_MEASUREMENT_KEY.toName()}]." - assertThat(exception.message).isEqualTo(expectedExceptionDescription) - } - - @Test - fun `listReports throws Exception when the internal setMeasurementResult throws Exception`() = - runBlocking { - whenever(internalMeasurementsMock.setMeasurementResult(any())) - .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) - - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - - val exception = - assertFailsWith(Exception::class) { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - } - val expectedExceptionDescription = - "Unable to update the measurement [${REACH_MEASUREMENT_KEY.toName()}] in the reporting database." - assertThat(exception.message).isEqualTo(expectedExceptionDescription) - } - - @Test - fun `listReports throws Exception when the internal setMeasurementFailure throws Exception`() = - runBlocking { - whenever(internalMeasurementsMock.setMeasurementFailure(any())) - .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) - - whenever(internalReportsMock.streamReports(any())) - .thenReturn(flowOf(INTERNAL_PENDING_REACH_REPORT)) - whenever(measurementsMock.getMeasurement(any())) - .thenReturn( - PENDING_REACH_MEASUREMENT.copy { - state = Measurement.State.FAILED - failure = failure { - reason = Measurement.Failure.Reason.REQUISITION_REFUSED - message = "Privacy budget exceeded." - } - } - ) - whenever(internalReportsMock.getReport(any())) - .thenReturn(INTERNAL_PENDING_REACH_REPORT.copy { state = InternalReport.State.FAILED }) - - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - - val exception = - assertFailsWith(Exception::class) { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - } - val expectedExceptionDescription = - "Unable to update the measurement [${REACH_MEASUREMENT_KEY.toName()}] in the reporting database." - assertThat(exception.message).isEqualTo(expectedExceptionDescription) - } - - @Test - fun `listReports throws Exception when the getCertificate throws Exception`() = runBlocking { - whenever(certificateMock.getCertificate(any())) - .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) - - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - - val exception = - assertFailsWith(Exception::class) { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - } - - assertThat(exception).hasMessageThat().contains(AGGREGATOR_CERTIFICATE.name) - } - - @Test - fun `listReports returns reports with SUCCEEDED states when reports are already succeeded`() { - whenever(internalReportsMock.streamReports(any())) - .thenReturn( - flowOf( - INTERNAL_SUCCEEDED_REACH_REPORT, - INTERNAL_SUCCEEDED_IMPRESSION_REPORT, - INTERNAL_SUCCEEDED_WATCH_DURATION_REPORT, - INTERNAL_SUCCEEDED_FREQUENCY_HISTOGRAM_REPORT, - ) - ) - - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - - val expected = listReportsResponse { - reports.add(SUCCEEDED_REACH_REPORT) - reports.add(SUCCEEDED_IMPRESSION_REPORT) - reports.add(SUCCEEDED_WATCH_DURATION_REPORT) - reports.add(SUCCEEDED_FREQUENCY_HISTOGRAM_REPORT) - } - - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::streamReports) - .isEqualTo( - streamReportsRequest { - limit = DEFAULT_PAGE_SIZE + 1 - this.filter = filter { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - } - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReports returns reports with FAILED states when reports are already failed`() { - whenever(internalReportsMock.streamReports(any())) - .thenReturn( - flowOf( - INTERNAL_PENDING_REACH_REPORT.copy { state = InternalReport.State.FAILED }, - INTERNAL_PENDING_IMPRESSION_REPORT.copy { state = InternalReport.State.FAILED }, - INTERNAL_PENDING_WATCH_DURATION_REPORT.copy { state = InternalReport.State.FAILED }, - INTERNAL_PENDING_FREQUENCY_HISTOGRAM_REPORT.copy { state = InternalReport.State.FAILED }, - ) - ) - - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - - val expected = listReportsResponse { - reports.add(PENDING_REACH_REPORT.copy { state = Report.State.FAILED }) - reports.add(PENDING_IMPRESSION_REPORT.copy { state = Report.State.FAILED }) - reports.add(PENDING_WATCH_DURATION_REPORT.copy { state = Report.State.FAILED }) - reports.add(PENDING_FREQUENCY_HISTOGRAM_REPORT.copy { state = Report.State.FAILED }) - } - - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::streamReports) - .isEqualTo( - streamReportsRequest { - limit = DEFAULT_PAGE_SIZE + 1 - this.filter = filter { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - } - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReports returns reports with RUNNING states when measurements are PENDING`() = - runBlocking { - whenever(internalReportsMock.streamReports(any())) - .thenReturn(flowOf(INTERNAL_PENDING_REACH_REPORT)) - whenever(measurementsMock.getMeasurement(any())) - .thenReturn( - PENDING_REACH_MEASUREMENT.copy { - state = Measurement.State.COMPUTING - results.clear() - } - ) - whenever(internalReportsMock.getReport(any())).thenReturn(INTERNAL_PENDING_REACH_REPORT) - - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - - val expected = listReportsResponse { reports.add(PENDING_REACH_REPORT) } - - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::streamReports) - .isEqualTo( - streamReportsRequest { - limit = DEFAULT_PAGE_SIZE + 1 - this.filter = filter { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - } - } - ) - verifyProtoArgument(measurementsMock, MeasurementsCoroutineImplBase::getMeasurement) - .isEqualTo(getMeasurementRequest { name = REACH_MEASUREMENT_KEY.toName() }) - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::getReport) - .isEqualTo( - getInternalReportRequest { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[0] - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReports returns reports with FAILED states when measurements are FAILED`() = - runBlocking { - whenever(internalReportsMock.streamReports(any())) - .thenReturn(flowOf(INTERNAL_PENDING_REACH_REPORT)) - whenever(measurementsMock.getMeasurement(any())) - .thenReturn( - PENDING_REACH_MEASUREMENT.copy { - state = Measurement.State.FAILED - failure = failure { - reason = Measurement.Failure.Reason.REQUISITION_REFUSED - message = "Privacy budget exceeded." - } - } - ) - whenever(internalReportsMock.getReport(any())) - .thenReturn(INTERNAL_PENDING_REACH_REPORT.copy { state = InternalReport.State.FAILED }) - - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - - val expected = listReportsResponse { - reports.add(PENDING_REACH_REPORT.copy { state = Report.State.FAILED }) - } - - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::streamReports) - .isEqualTo( - streamReportsRequest { - limit = DEFAULT_PAGE_SIZE + 1 - this.filter = filter { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - } - } - ) - verifyProtoArgument(measurementsMock, MeasurementsCoroutineImplBase::getMeasurement) - .isEqualTo(getMeasurementRequest { name = REACH_MEASUREMENT_KEY.toName() }) - verifyProtoArgument( - internalMeasurementsMock, - InternalMeasurementsCoroutineImplBase::setMeasurementFailure, - ) - .isEqualTo( - setMeasurementFailureRequest { - measurementConsumerReferenceId = REACH_MEASUREMENT_KEY.measurementConsumerId - measurementReferenceId = REACH_MEASUREMENT_KEY.measurementId - failure = - InternalMeasurementKt.failure { - reason = InternalMeasurement.Failure.Reason.REQUISITION_REFUSED - message = "Privacy budget exceeded." - } - } - ) - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::getReport) - .isEqualTo( - getInternalReportRequest { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[0] - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReports returns reports with SUCCEEDED states when measurements are SUCCEEDED`() = - runBlocking { - whenever(internalReportsMock.streamReports(any())) - .thenReturn(flowOf(INTERNAL_PENDING_REACH_REPORT)) - whenever(measurementsMock.getMeasurement(any())).thenReturn(SUCCEEDED_REACH_MEASUREMENT) - whenever(internalReportsMock.getReport(any())).thenReturn(INTERNAL_SUCCEEDED_REACH_REPORT) - - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - - val expected = listReportsResponse { reports.add(SUCCEEDED_REACH_REPORT) } - - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::streamReports) - .isEqualTo( - streamReportsRequest { - limit = DEFAULT_PAGE_SIZE + 1 - this.filter = filter { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - } - } - ) - verifyProtoArgument(measurementsMock, MeasurementsCoroutineImplBase::getMeasurement) - .isEqualTo(getMeasurementRequest { name = REACH_MEASUREMENT_KEY.toName() }) - verifyProtoArgument( - internalMeasurementsMock, - InternalMeasurementsCoroutineImplBase::setMeasurementResult, - ) - .usingDoubleTolerance(1e-12) - .isEqualTo( - setMeasurementResultRequest { - measurementConsumerReferenceId = REACH_MEASUREMENT_KEY.measurementConsumerId - measurementReferenceId = REACH_MEASUREMENT_KEY.measurementId - this.result = - InternalMeasurementKt.result { - reach = InternalMeasurementResultKt.reach { value = REACH_VALUE } - frequency = - InternalMeasurementResultKt.frequency { - relativeFrequencyDistribution.putAll(FREQUENCY_DISTRIBUTION) - } - } - } - ) - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::getReport) - .isEqualTo( - getInternalReportRequest { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[0] - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReports returns an impression report with aggregated results`() = runBlocking { - whenever(internalReportsMock.streamReports(any())) - .thenReturn(flowOf(INTERNAL_PENDING_IMPRESSION_REPORT)) - whenever(measurementsMock.getMeasurement(any())).thenReturn(SUCCEEDED_IMPRESSION_MEASUREMENT) - whenever(internalReportsMock.getReport(any())).thenReturn(INTERNAL_SUCCEEDED_IMPRESSION_REPORT) - - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - - val expected = listReportsResponse { reports.add(SUCCEEDED_IMPRESSION_REPORT) } - - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::streamReports) - .isEqualTo( - streamReportsRequest { - limit = DEFAULT_PAGE_SIZE + 1 - this.filter = filter { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - } - } - ) - verifyProtoArgument(measurementsMock, MeasurementsCoroutineImplBase::getMeasurement) - .isEqualTo(getMeasurementRequest { name = IMPRESSION_MEASUREMENT_KEY.toName() }) - verifyProtoArgument( - internalMeasurementsMock, - InternalMeasurementsCoroutineImplBase::setMeasurementResult, - ) - .isEqualTo( - setMeasurementResultRequest { - measurementConsumerReferenceId = IMPRESSION_MEASUREMENT_KEY.measurementConsumerId - measurementReferenceId = IMPRESSION_MEASUREMENT_KEY.measurementId - this.result = - InternalMeasurementKt.result { - impression = InternalMeasurementResultKt.impression { value = TOTAL_IMPRESSION_VALUE } - } - } - ) - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::getReport) - .isEqualTo( - getInternalReportRequest { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[1] - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `listReports returns a watch duration report with aggregated results`() = runBlocking { - whenever(internalReportsMock.streamReports(any())) - .thenReturn(flowOf(INTERNAL_PENDING_WATCH_DURATION_REPORT)) - whenever(measurementsMock.getMeasurement(any())) - .thenReturn(SUCCEEDED_WATCH_DURATION_MEASUREMENT) - whenever(internalReportsMock.getReport(any())) - .thenReturn(INTERNAL_SUCCEEDED_WATCH_DURATION_REPORT) - - val request = listReportsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listReports(request) } - } - - val expected = listReportsResponse { reports.add(SUCCEEDED_WATCH_DURATION_REPORT) } - - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::streamReports) - .isEqualTo( - streamReportsRequest { - limit = DEFAULT_PAGE_SIZE + 1 - this.filter = filter { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - } - } - ) - verifyProtoArgument(measurementsMock, MeasurementsCoroutineImplBase::getMeasurement) - .isEqualTo(getMeasurementRequest { name = WATCH_DURATION_MEASUREMENT_KEY.toName() }) - verifyProtoArgument( - internalMeasurementsMock, - InternalMeasurementsCoroutineImplBase::setMeasurementResult, - ) - .isEqualTo( - setMeasurementResultRequest { - measurementConsumerReferenceId = WATCH_DURATION_MEASUREMENT_KEY.measurementConsumerId - measurementReferenceId = WATCH_DURATION_MEASUREMENT_KEY.measurementId - this.result = - InternalMeasurementKt.result { - watchDuration = - InternalMeasurementResultKt.watchDuration { value = TOTAL_WATCH_DURATION } - } - } - ) - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::getReport) - .isEqualTo( - getInternalReportRequest { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[2] - } - ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } - - @Test - fun `getReport returns the report with SUCCEEDED when the report is already succeeded`() = - runBlocking { - whenever(internalReportsMock.getReport(any())) - .thenReturn(INTERNAL_SUCCEEDED_WATCH_DURATION_REPORT) - - val request = getReportRequest { name = REPORT_NAMES[2] } - - val report = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.getReport(request) } - } - - assertThat(report).isEqualTo(SUCCEEDED_WATCH_DURATION_REPORT) - - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::getReport) - .isEqualTo( - getInternalReportRequest { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[2] - } - ) - } - - @Test - fun `getReport returns the report with FAILED when the report is already failed`() = runBlocking { - whenever(internalReportsMock.getReport(any())) - .thenReturn( - INTERNAL_PENDING_WATCH_DURATION_REPORT.copy { state = InternalReport.State.FAILED } - ) - - val request = getReportRequest { name = REPORT_NAMES[2] } - - val report = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.getReport(request) } - } - - assertThat(report).isEqualTo(PENDING_WATCH_DURATION_REPORT.copy { state = Report.State.FAILED }) - - verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::getReport) - .isEqualTo( - getInternalReportRequest { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[2] - } - ) - } - - @Test - fun `getReport returns the report with RUNNING when measurements are pending`(): Unit = - runBlocking { - whenever(internalReportsMock.getReport(any())) - .thenReturn(INTERNAL_PENDING_WATCH_DURATION_REPORT) - whenever(measurementsMock.getMeasurement(any())) - .thenReturn(PENDING_WATCH_DURATION_MEASUREMENT) - - val request = getReportRequest { name = REPORT_NAMES[2] } - - val report = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.getReport(request) } - } - - assertThat(report).isEqualTo(PENDING_WATCH_DURATION_REPORT) - - verifyProtoArgument(measurementsMock, MeasurementsCoroutineImplBase::getMeasurement) - .comparingExpectedFieldsOnly() - .isEqualTo(getMeasurementRequest { name = WATCH_DURATION_MEASUREMENT_KEY.toName() }) - - val internalReportCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(internalReportsMock, times(2)) { getReport(internalReportCaptor.capture()) } - assertThat(internalReportCaptor.allValues) - .containsExactly( - getInternalReportRequest { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[2] - }, - getInternalReportRequest { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[2] - }, - ) - } - - @Test - fun `getReport syncs and returns an SUCCEEDED report with aggregated results`(): Unit = - runBlocking { - whenever(measurementsMock.getMeasurement(any())).thenReturn(SUCCEEDED_IMPRESSION_MEASUREMENT) - whenever(internalReportsMock.getReport(any())) - .thenReturn(INTERNAL_PENDING_IMPRESSION_REPORT, INTERNAL_SUCCEEDED_IMPRESSION_REPORT) - - val request = getReportRequest { name = REPORT_NAMES[1] } - - val report = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.getReport(request) } - } - - assertThat(report).isEqualTo(SUCCEEDED_IMPRESSION_REPORT) - - verifyProtoArgument(measurementsMock, MeasurementsCoroutineImplBase::getMeasurement) - .isEqualTo(getMeasurementRequest { name = IMPRESSION_MEASUREMENT_KEY.toName() }) - verifyProtoArgument( - internalMeasurementsMock, - InternalMeasurementsCoroutineImplBase::setMeasurementResult, - ) - .isEqualTo( - setMeasurementResultRequest { - measurementConsumerReferenceId = IMPRESSION_MEASUREMENT_KEY.measurementConsumerId - measurementReferenceId = IMPRESSION_MEASUREMENT_KEY.measurementId - this.result = - InternalMeasurementKt.result { - impression = - InternalMeasurementResultKt.impression { value = TOTAL_IMPRESSION_VALUE } - } - } - ) - - val internalReportCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(internalReportsMock, times(2)) { getReport(internalReportCaptor.capture()) } - assertThat(internalReportCaptor.allValues) - .containsExactly( - getInternalReportRequest { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[1] - }, - getInternalReportRequest { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[1] - }, - ) - } - - @Test - fun `getReport syncs and returns an FAILED report when measurements failed`(): Unit = - runBlocking { - whenever(measurementsMock.getMeasurement(any())) - .thenReturn( - BASE_IMPRESSION_MEASUREMENT.copy { - state = Measurement.State.FAILED - failure = failure { - reason = Measurement.Failure.Reason.REQUISITION_REFUSED - message = "Privacy budget exceeded." - } - } - ) - whenever(internalReportsMock.getReport(any())) - .thenReturn( - INTERNAL_PENDING_IMPRESSION_REPORT, - INTERNAL_PENDING_IMPRESSION_REPORT.copy { state = InternalReport.State.FAILED }, - ) - - val request = getReportRequest { name = REPORT_NAMES[1] } - - val report = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.getReport(request) } - } - - assertThat(report).isEqualTo(PENDING_IMPRESSION_REPORT.copy { state = Report.State.FAILED }) - - verifyProtoArgument(measurementsMock, MeasurementsCoroutineImplBase::getMeasurement) - .isEqualTo(getMeasurementRequest { name = IMPRESSION_MEASUREMENT_KEY.toName() }) - verifyProtoArgument( - internalMeasurementsMock, - InternalMeasurementsCoroutineImplBase::setMeasurementFailure, - ) - .isEqualTo( - setMeasurementFailureRequest { - measurementConsumerReferenceId = IMPRESSION_MEASUREMENT_KEY.measurementConsumerId - measurementReferenceId = IMPRESSION_MEASUREMENT_KEY.measurementId - failure = - InternalMeasurementKt.failure { - reason = InternalMeasurement.Failure.Reason.REQUISITION_REFUSED - message = "Privacy budget exceeded." - } - } - ) - - val internalReportCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(internalReportsMock, times(2)) { getReport(internalReportCaptor.capture()) } - assertThat(internalReportCaptor.allValues) - .containsExactly( - getInternalReportRequest { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[1] - }, - getInternalReportRequest { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportId = REPORT_EXTERNAL_IDS[1] - }, - ) - } - - @Test - fun `getReport throws INVALID_ARGUMENT when Report name is invalid`() { - val request = getReportRequest { name = INVALID_REPORT_NAME } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.getReport(request) } - } - } - - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `getReport throws PERMISSION_DENIED when MeasurementConsumer's identity does not match`() { - val request = getReportRequest { name = REPORT_NAMES[0] } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.last().name, CONFIG) { - runBlocking { service.getReport(request) } - } - } - - assertThat(exception.status.code).isEqualTo(Status.Code.PERMISSION_DENIED) - } - - @Test - fun `getReport throws UNAUTHENTICATED when the caller is not a MeasurementConsumer`() { - val request = getReportRequest { name = REPORT_NAMES[0] } - - val exception = - assertFailsWith { - withDataProviderPrincipal(DATA_PROVIDERS.values.first().name) { - runBlocking { service.getReport(request) } - } - } - - assertThat(exception.status.code).isEqualTo(Status.Code.UNAUTHENTICATED) - } - - @Test - fun `getReport throws PERMISSION_DENIED when encryption private key not found`() = runBlocking { - whenever(internalReportsMock.getReport(any())) - .thenReturn(INTERNAL_PENDING_WATCH_DURATION_REPORT) - - whenever(measurementsMock.getMeasurement(any())) - .thenReturn( - SUCCEEDED_WATCH_DURATION_MEASUREMENT.copy { - val measurementSpec = measurementSpec { - measurementPublicKey = - MEASUREMENT_CONSUMER_PUBLIC_KEY.copy { data = INVALID_MEASUREMENT_PUBLIC_KEY_DATA } - .pack() - } - this.measurementSpec = - signMeasurementSpec(measurementSpec, MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE) - } - ) - - val request = getReportRequest { name = REPORT_NAMES[2] } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.getReport(request) } - } - } - - assertThat(exception.status.code).isEqualTo(Status.Code.PERMISSION_DENIED) - assertThat(exception.status.description).contains("private key") - } - - @Test - fun `getReport throws Exception when the internal GetReport throws Exception`() = runBlocking { - whenever(internalReportsMock.getReport(any())) - .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) - - val request = getReportRequest { name = REPORT_NAMES[2] } - - val exception = - assertFailsWith(Exception::class) { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.getReport(request) } - } - } - val expectedExceptionDescription = "Unable to get the report from the reporting database." - assertThat(exception.message).isEqualTo(expectedExceptionDescription) - } - - @Test - fun `toResult converts internal result to external result with the same content`() = runBlocking { - val internalResult = - InternalReportDetailsKt.result { - scalarTable = - InternalReportResultKt.scalarTable { - rowHeaders += listOf("row1", "row2", "row3") - columns += - InternalReportResultKt.column { - columnHeader = "column1" - setOperations += listOf(1.0, 2.0, 3.0) - } - } - histogramTables += - InternalReportResultKt.histogramTable { - rows += - InternalReportResultKt.HistogramTableKt.row { - rowHeader = "row4" - frequency = 100 - } - rows += - InternalReportResultKt.HistogramTableKt.row { - rowHeader = "row5" - frequency = 101 - } - columns += - InternalReportResultKt.column { - columnHeader = "column1" - setOperations += listOf(10.0, 11.0, 12.0) - } - columns += - InternalReportResultKt.column { - columnHeader = "column2" - setOperations += listOf(20.0, 21.0, 22.0) - } - } - } - - whenever(internalReportsMock.getReport(any())) - .thenReturn( - INTERNAL_SUCCEEDED_WATCH_DURATION_REPORT.copy { - details = InternalReportKt.details { result = internalResult } - } - ) - - val request = getReportRequest { name = REPORT_NAMES[2] } - - val report = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.getReport(request) } - } - - assertThat(report.result) - .isEqualTo( - ReportKt.result { - scalarTable = scalarTable { - rowHeaders += listOf("row1", "row2", "row3") - columns += column { - columnHeader = "column1" - setOperations += listOf(1.0, 2.0, 3.0) - } - } - histogramTables += histogramTable { - rows += row { - rowHeader = "row4" - frequency = 100 - } - rows += row { - rowHeader = "row5" - frequency = 101 - } - columns += column { - columnHeader = "column1" - setOperations += listOf(10.0, 11.0, 12.0) - } - columns += column { - columnHeader = "column2" - setOperations += listOf(20.0, 21.0, 22.0) - } - } - } - ) - } - - companion object { - private val MEASUREMENT_SPEC_FIELD_DESCRIPTOR = - Measurement.getDescriptor().findFieldByNumber(Measurement.MEASUREMENT_SPEC_FIELD_NUMBER) - private val ENCRYPTED_REQUISITION_SPEC_FIELD_DESCRIPTOR = - Measurement.DataProviderEntry.Value.getDescriptor() - .findFieldByNumber(ENCRYPTED_REQUISITION_SPEC_FIELD_NUMBER) - } -} - -private fun EventGroupKey.toInternal(): InternalReportingSet.EventGroupKey { - val source = this - return InternalReportingSetKt.eventGroupKey { - measurementConsumerReferenceId = source.measurementConsumerReferenceId - dataProviderReferenceId = source.dataProviderReferenceId - eventGroupReferenceId = source.eventGroupReferenceId - } -} - -private val InternalReportingSet.resourceKey: ReportingSetKey - get() = - ReportingSetKey(measurementConsumerReferenceId, ExternalId(externalReportingSetId).apiId.value) -private val InternalReportingSet.resourceName: String - get() = resourceKey.toName() diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/SetOperationCompilerTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/SetOperationCompilerTest.kt deleted file mode 100644 index 0fe757be12d..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/SetOperationCompilerTest.kt +++ /dev/null @@ -1,197 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.service.api.v1alpha - -import com.google.common.truth.Truth.assertThat -import kotlinx.coroutines.runBlocking -import org.junit.Assert.assertThrows -import org.junit.Before -import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 -import org.wfanet.measurement.common.identity.externalIdToApiId -import org.wfanet.measurement.reporting.v1alpha.Metric -import org.wfanet.measurement.reporting.v1alpha.MetricKt.SetOperationKt.operand -import org.wfanet.measurement.reporting.v1alpha.MetricKt.namedSetOperation -import org.wfanet.measurement.reporting.v1alpha.MetricKt.setOperation -import org.wfanet.measurement.reporting.v1alpha.copy - -// Measurement consumer IDs and names -private const val MEASUREMENT_CONSUMER_EXTERNAL_ID = 111L -private val MEASUREMENT_CONSUMER_REFERENCE_ID = externalIdToApiId(MEASUREMENT_CONSUMER_EXTERNAL_ID) - -// Reporting set IDs and names -private const val REPORTING_SET_EXTERNAL_ID = 331L -private const val REPORTING_SET_EXTERNAL_ID_2 = 332L -private const val REPORTING_SET_EXTERNAL_ID_3 = 333L - -private val REPORTING_SET_NAME = - ReportingSetKey(MEASUREMENT_CONSUMER_REFERENCE_ID, externalIdToApiId(REPORTING_SET_EXTERNAL_ID)) - .toName() -private val REPORTING_SET_NAME_2 = - ReportingSetKey(MEASUREMENT_CONSUMER_REFERENCE_ID, externalIdToApiId(REPORTING_SET_EXTERNAL_ID_2)) - .toName() -private val REPORTING_SET_NAME_3 = - ReportingSetKey(MEASUREMENT_CONSUMER_REFERENCE_ID, externalIdToApiId(REPORTING_SET_EXTERNAL_ID_3)) - .toName() - -private val EXPECTED_REPORTING_SET_NAMES_LIST_ALL_UNION = - listOf(REPORTING_SET_NAME, REPORTING_SET_NAME_2, REPORTING_SET_NAME_3).sorted() - -private val SET_OPERATION = setOperation { - type = Metric.SetOperation.Type.UNION - lhs = operand { reportingSet = REPORTING_SET_NAME } - rhs = operand { reportingSet = REPORTING_SET_NAME_2 } -} - -private val SET_OPERATION_ALL_UNION = setOperation { - type = Metric.SetOperation.Type.UNION - lhs = operand { operation = SET_OPERATION } - rhs = operand { reportingSet = REPORTING_SET_NAME_3 } -} - -// SetOperation = A + B + C - B -private val SET_OPERATION_ALL_UNION_BUT_ONE = setOperation { - type = Metric.SetOperation.Type.DIFFERENCE - lhs = operand { operation = SET_OPERATION_ALL_UNION } - rhs = operand { reportingSet = EXPECTED_REPORTING_SET_NAMES_LIST_ALL_UNION[1] } -} - -private const val SET_OPERATION_ALL_UNION_DISPLAY_NAME = "SET_OPERATION_ALL_UNION" -private const val SET_OPERATION_ALL_UNION_BUT_ONE_DISPLAY_NAME = "SET_OPERATION_ALL_UNION_BUT_ONE" - -private val NAMED_SET_OPERATION_ALL_UNION = namedSetOperation { - uniqueName = SET_OPERATION_ALL_UNION_DISPLAY_NAME - setOperation = SET_OPERATION_ALL_UNION -} - -private val NAMED_SET_OPERATION_ALL_UNION_BUT_ONE = namedSetOperation { - uniqueName = SET_OPERATION_ALL_UNION_BUT_ONE_DISPLAY_NAME - setOperation = SET_OPERATION_ALL_UNION_BUT_ONE -} - -private val EXPECTED_RESULT_FOR_ALL_UNION_SET_OPERATION = - listOf(WeightedMeasurement(EXPECTED_REPORTING_SET_NAMES_LIST_ALL_UNION, coefficient = 1)) - -private val EXPECTED_RESULT_FOR_ALL_UNION_BUT_ONE_SET_OPERATION = - listOf( - WeightedMeasurement(EXPECTED_REPORTING_SET_NAMES_LIST_ALL_UNION, coefficient = 1), - WeightedMeasurement(listOf(EXPECTED_REPORTING_SET_NAMES_LIST_ALL_UNION[1]), coefficient = -1), - ) - -private val EXPECTED_CACHE_FOR_ALL_UNION_SET_OPERATION = - mapOf( - EXPECTED_REPORTING_SET_NAMES_LIST_ALL_UNION.size to - mapOf( - 1UL to mapOf(6UL to -1, 7UL to 1), - 2UL to mapOf(5UL to -1, 7UL to 1), - 3UL to mapOf(4UL to -1, 5UL to 1, 6UL to 1, 7UL to -1), - 4UL to mapOf(3UL to -1, 7UL to 1), - 5UL to mapOf(2UL to -1, 3UL to 1, 6UL to 1, 7UL to -1), - 6UL to mapOf(1UL to -1, 3UL to 1, 5UL to 1, 7UL to -1), - 7UL to mapOf(1UL to 1, 2UL to 1, 3UL to -1, 4UL to 1, 5UL to -1, 6UL to -1, 7UL to 1), - ) - ) - -// {4: {3: -1, 7: 1}, 1: {6: -1, 7: 1}, 5: {2: -1, 3: 1, 6: 1, 7: -1}} -private val EXPECTED_CACHE_FOR_ALL_UNION_BUT_ONE_SET_OPERATION = - mapOf( - EXPECTED_REPORTING_SET_NAMES_LIST_ALL_UNION.size to - mapOf( - 1UL to mapOf(6UL to -1, 7UL to 1), - 4UL to mapOf(3UL to -1, 7UL to 1), - 5UL to mapOf(2UL to -1, 3UL to 1, 6UL to 1, 7UL to -1), - ) - ) - -@RunWith(JUnit4::class) -class SetOperationCompilerTest { - private lateinit var reportResultCompiler: SetOperationCompiler - - @Before - fun initService() { - reportResultCompiler = SetOperationCompiler() - } - - @Test - fun `compileSetOperation returns a list of weightedMeasurements and store it in the cache`() { - val resultAllUnionButOne = runBlocking { - reportResultCompiler.compileSetOperation(SET_OPERATION_ALL_UNION_BUT_ONE) - } - val primitiveRegionCacheAllUnionButOne = reportResultCompiler.getPrimitiveRegionCache() - - assertThat(resultAllUnionButOne) - .containsExactlyElementsIn(EXPECTED_RESULT_FOR_ALL_UNION_BUT_ONE_SET_OPERATION) - assertThat(primitiveRegionCacheAllUnionButOne) - .isEqualTo(EXPECTED_CACHE_FOR_ALL_UNION_BUT_ONE_SET_OPERATION) - - val resultAllUnion = runBlocking { - reportResultCompiler.compileSetOperation(SET_OPERATION_ALL_UNION) - } - val primitiveRegionCacheAllUnion = reportResultCompiler.getPrimitiveRegionCache() - - assertThat(resultAllUnion) - .containsExactlyElementsIn(EXPECTED_RESULT_FOR_ALL_UNION_SET_OPERATION) - assertThat(primitiveRegionCacheAllUnion).isEqualTo(EXPECTED_CACHE_FOR_ALL_UNION_SET_OPERATION) - } - - @Test - fun `compileSetOperation reuses the computation in the cache when there exists one`() { - runBlocking { reportResultCompiler.compileSetOperation(SET_OPERATION_ALL_UNION) } - val firstRoundPrimitiveRegionCache = reportResultCompiler.getPrimitiveRegionCache() - - runBlocking { reportResultCompiler.compileSetOperation(SET_OPERATION_ALL_UNION) } - val secondRoundPrimitiveRegionCache = reportResultCompiler.getPrimitiveRegionCache() - - assertThat(firstRoundPrimitiveRegionCache).isEqualTo(secondRoundPrimitiveRegionCache) - } - - @Test - fun `compileSetOperation throws IllegalArgumentException when lhs in SetOperation is not set`() { - val setOperationWithLhsNotSet = SET_OPERATION_ALL_UNION.copy { clearLhs() } - - val exception = - assertThrows(IllegalArgumentException::class.java) { - runBlocking { reportResultCompiler.compileSetOperation(setOperationWithLhsNotSet) } - } - assertThat(exception.message).isEqualTo("lhs in SetOperation must be set.") - } - - @Test - fun `compileSetOperation throws IllegalArgumentException when lhs operand type is not set`() { - val setOperationWithLhsOperandTypeNotSet = SET_OPERATION_ALL_UNION.copy { lhs = operand {} } - - val exception = - assertThrows(IllegalArgumentException::class.java) { - runBlocking { - reportResultCompiler.compileSetOperation(setOperationWithLhsOperandTypeNotSet) - } - } - assertThat(exception.message).isEqualTo("Operand type of lhs in SetOperation must be set.") - } - - @Test - fun `compileSetOperation throws IllegalArgumentException when a set operator type is not set`() { - val setOperationWithSetOperatorTypeNotSet = SET_OPERATION_ALL_UNION.copy { clearType() } - - val exception = - assertThrows(IllegalArgumentException::class.java) { - runBlocking { - reportResultCompiler.compileSetOperation(setOperationWithSetOperatorTypeNotSet) - } - } - assertThat(exception.message).isEqualTo("Set operator type is not specified.") - } -} diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/BUILD.bazel deleted file mode 100644 index 28ecd80eb1d..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/BUILD.bazel +++ /dev/null @@ -1,29 +0,0 @@ -load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_test") - -filegroup( - name = "textproto_files", - srcs = glob(["*.textproto"]), -) - -kt_jvm_test( - name = "ReportingTest", - srcs = ["ReportingTest.kt"], - data = [ - "textproto_files", - "//src/main/k8s/testing/secretfiles:root_certs", - "//src/main/k8s/testing/secretfiles:secret_files", - ], - jvm_flags = ["-Dcom.google.testing.junit.runner.shouldInstallTestSecurityManager=false"], - test_class = "org.wfanet.measurement.reporting.service.api.v1alpha.tools.ReportingTest", - deps = [ - "//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools:reporting", - "@wfa_common_jvm//imports/java/com/google/common/truth", - "@wfa_common_jvm//imports/java/com/google/common/truth/extensions/proto", - "@wfa_common_jvm//imports/java/io/grpc/netty", - "@wfa_common_jvm//imports/java/org/junit", - "@wfa_common_jvm//imports/kotlin/kotlin/test", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc/testing", - "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/testing", - ], -) diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/ReportingTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/ReportingTest.kt deleted file mode 100644 index 90060571776..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/ReportingTest.kt +++ /dev/null @@ -1,561 +0,0 @@ -// Copyright 2022 The Cross-Media Measurement Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package org.wfanet.measurement.reporting.service.api.v1alpha.tools - -import com.google.common.truth.Truth.assertThat -import io.grpc.Server -import io.grpc.ServerServiceDefinition -import io.grpc.netty.NettyServerBuilder -import java.nio.file.Path -import java.nio.file.Paths -import java.time.Duration -import java.time.Instant -import java.time.LocalDate -import java.time.ZoneOffset -import java.util.concurrent.TimeUnit.SECONDS -import org.junit.After -import org.junit.Before -import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 -import org.mockito.kotlin.any -import org.wfanet.measurement.common.crypto.SigningCerts -import org.wfanet.measurement.common.getRuntimePath -import org.wfanet.measurement.common.grpc.testing.mockService -import org.wfanet.measurement.common.grpc.toServerTlsContext -import org.wfanet.measurement.common.parseTextProto -import org.wfanet.measurement.common.testing.CommandLineTesting -import org.wfanet.measurement.common.testing.CommandLineTesting.assertThat -import org.wfanet.measurement.common.testing.ExitInterceptingSecurityManager -import org.wfanet.measurement.common.testing.verifyProtoArgument -import org.wfanet.measurement.common.toProtoDuration -import org.wfanet.measurement.common.toProtoTime -import org.wfanet.measurement.reporting.v1alpha.EventGroupsGrpcKt.EventGroupsCoroutineImplBase -import org.wfanet.measurement.reporting.v1alpha.ListEventGroupsResponse -import org.wfanet.measurement.reporting.v1alpha.ListReportingSetsResponse -import org.wfanet.measurement.reporting.v1alpha.Metric -import org.wfanet.measurement.reporting.v1alpha.MetricKt.SetOperationKt.operand -import org.wfanet.measurement.reporting.v1alpha.MetricKt.namedSetOperation -import org.wfanet.measurement.reporting.v1alpha.MetricKt.reachParams -import org.wfanet.measurement.reporting.v1alpha.MetricKt.setOperation -import org.wfanet.measurement.reporting.v1alpha.Report -import org.wfanet.measurement.reporting.v1alpha.ReportKt.EventGroupUniverseKt.eventGroupEntry -import org.wfanet.measurement.reporting.v1alpha.ReportKt.eventGroupUniverse -import org.wfanet.measurement.reporting.v1alpha.ReportingSet -import org.wfanet.measurement.reporting.v1alpha.ReportingSetsGrpcKt.ReportingSetsCoroutineImplBase -import org.wfanet.measurement.reporting.v1alpha.ReportsGrpcKt.ReportsCoroutineImplBase -import org.wfanet.measurement.reporting.v1alpha.createReportRequest -import org.wfanet.measurement.reporting.v1alpha.createReportingSetRequest -import org.wfanet.measurement.reporting.v1alpha.eventGroup -import org.wfanet.measurement.reporting.v1alpha.getReportRequest -import org.wfanet.measurement.reporting.v1alpha.listEventGroupsRequest -import org.wfanet.measurement.reporting.v1alpha.listEventGroupsResponse -import org.wfanet.measurement.reporting.v1alpha.listReportingSetsRequest -import org.wfanet.measurement.reporting.v1alpha.listReportingSetsResponse -import org.wfanet.measurement.reporting.v1alpha.listReportsRequest -import org.wfanet.measurement.reporting.v1alpha.listReportsResponse -import org.wfanet.measurement.reporting.v1alpha.metric -import org.wfanet.measurement.reporting.v1alpha.periodicTimeInterval -import org.wfanet.measurement.reporting.v1alpha.report -import org.wfanet.measurement.reporting.v1alpha.reportingSet -import org.wfanet.measurement.reporting.v1alpha.timeInterval -import org.wfanet.measurement.reporting.v1alpha.timeIntervals - -private const val HOST = "localhost" -private val SECRETS_DIR: Path = - getRuntimePath( - Paths.get("wfa_measurement_system", "src", "main", "k8s", "testing", "secretfiles") - )!! - -private val TEXTPROTO_DIR: Path = - getRuntimePath( - Paths.get( - "wfa_measurement_system", - "src", - "test", - "kotlin", - "org", - "wfanet", - "measurement", - "reporting", - "service", - "api", - "v1alpha", - "tools", - ) - )!! - -private const val REPORT_IDEMPOTENCY_KEY = "report001" -private const val MEASUREMENT_CONSUMER_NAME = "measurementConsumers/1" -private const val EVENT_GROUP_NAME_1 = "$MEASUREMENT_CONSUMER_NAME/dataProviders/1/eventGroups/1" -private const val EVENT_GROUP_NAME_2 = "$MEASUREMENT_CONSUMER_NAME/dataProviders/1/eventGroups/2" -private const val EVENT_GROUP_NAME_3 = "$MEASUREMENT_CONSUMER_NAME/dataProviders/2/eventGroups/1" - -private val REPORTING_SET = reportingSet { name = "$MEASUREMENT_CONSUMER_NAME/reportingSets/1" } -private val LIST_REPORTING_SETS_RESPONSE = listReportingSetsResponse { - reportingSets += reportingSet { - name = "$MEASUREMENT_CONSUMER_NAME/reportingSets/1" - eventGroups += listOf(EVENT_GROUP_NAME_1, EVENT_GROUP_NAME_2, EVENT_GROUP_NAME_3) - filter = "some.filter1" - displayName = "test-reporting-set1" - } - reportingSets += reportingSet { - name = "$MEASUREMENT_CONSUMER_NAME/reportingSets/2" - eventGroups += listOf(EVENT_GROUP_NAME_1) - filter = "some.filter2" - displayName = "test-reporting-set2" - } - nextPageToken = "TokenToGetTheNextPage" -} - -private const val REPORT_NAME = "$MEASUREMENT_CONSUMER_NAME/reports/1" -private val REPORT = report { - name = REPORT_NAME - measurementConsumer = MEASUREMENT_CONSUMER_NAME - eventGroupUniverse = eventGroupUniverse { - eventGroupEntries += eventGroupEntry { - key = "measurementConsumers/1/dataProviders/1/eventGroups/1" - value = "" - } - eventGroupEntries += eventGroupEntry { - key = "measurementConsumers/1/dataProviders/2/eventGroups/3" - value = "partner=abc" - } - } - timeIntervals = timeIntervals { - timeIntervals += timeInterval { - startTime = - LocalDate.now().minusDays(1).atStartOfDay().toInstant(ZoneOffset.UTC).toProtoTime() - endTime = LocalDate.now().atStartOfDay().toInstant(ZoneOffset.UTC).toProtoTime() - } - } -} - -private val LIST_REPORTS_RESPONSE = listReportsResponse { - reports += - listOf( - report { - name = "$MEASUREMENT_CONSUMER_NAME/reports/1" - measurementConsumer = MEASUREMENT_CONSUMER_NAME - state = Report.State.RUNNING - }, - report { - name = "$MEASUREMENT_CONSUMER_NAME/reports/2" - measurementConsumer = MEASUREMENT_CONSUMER_NAME - state = Report.State.SUCCEEDED - }, - report { - name = "$MEASUREMENT_CONSUMER_NAME/reports/3" - measurementConsumer = MEASUREMENT_CONSUMER_NAME - state = Report.State.FAILED - }, - ) -} - -private const val DATA_PROVIDER_NAME = "dataProviders/1" - -private val LIST_EVENT_GROUPS_RESPONSE = listEventGroupsResponse { eventGroups += eventGroup {} } - -@RunWith(JUnit4::class) -class ReportingTest { - private val reportingSetsServiceMock: ReportingSetsCoroutineImplBase = - mockService() { - onBlocking { createReportingSet(any()) }.thenReturn(REPORTING_SET) - onBlocking { listReportingSets(any()) }.thenReturn(LIST_REPORTING_SETS_RESPONSE) - } - private val reportsServiceMock: ReportsCoroutineImplBase = - mockService() { - onBlocking { createReport(any()) }.thenReturn(REPORT) - onBlocking { listReports(any()) }.thenReturn(LIST_REPORTS_RESPONSE) - onBlocking { getReport(any()) }.thenReturn(REPORT) - } - private val eventGroupsServiceMock: EventGroupsCoroutineImplBase = - mockService() { onBlocking { listEventGroups(any()) }.thenReturn(LIST_EVENT_GROUPS_RESPONSE) } - - private val serverCerts = - SigningCerts.fromPemFiles( - certificateFile = SECRETS_DIR.resolve("reporting_tls.pem").toFile(), - privateKeyFile = SECRETS_DIR.resolve("reporting_tls.key").toFile(), - trustedCertCollectionFile = SECRETS_DIR.resolve("reporting_root.pem").toFile(), - ) - - private val services: List = - listOf( - reportingSetsServiceMock.bindService(), - reportsServiceMock.bindService(), - eventGroupsServiceMock.bindService(), - ) - - private val server: Server = - NettyServerBuilder.forPort(0) - .sslContext(serverCerts.toServerTlsContext()) - .addServices(services) - .build() - - @Before - fun initServer() { - server.start() - } - - @After - fun shutdownServer() { - server.shutdown() - server.awaitTermination(1, SECONDS) - } - - private fun callCli(args: Array): String { - val capturedOutput = CommandLineTesting.capturingOutput(args, Reporting::main) - assertThat(capturedOutput).status().isEqualTo(0) - return capturedOutput.out - } - - @Test - fun `reporting_sets create calls api with valid request`() { - val args = - arrayOf( - "--tls-cert-file=$SECRETS_DIR/mc_tls.pem", - "--tls-key-file=$SECRETS_DIR/mc_tls.key", - "--cert-collection-file=$SECRETS_DIR/reporting_root.pem", - "--reporting-server-api-target=$HOST:${server.port}", - "reporting-sets", - "create", - "--parent=$MEASUREMENT_CONSUMER_NAME", - "--event-group=$EVENT_GROUP_NAME_1", - "--event-group=$EVENT_GROUP_NAME_2", - "--event-group=$EVENT_GROUP_NAME_3", - "--filter=some.filter", - "--display-name=test-reporting-set", - ) - - val output = callCli(args) - - verifyProtoArgument( - reportingSetsServiceMock, - ReportingSetsCoroutineImplBase::createReportingSet, - ) - .isEqualTo( - createReportingSetRequest { - parent = MEASUREMENT_CONSUMER_NAME - reportingSet = reportingSet { - eventGroups += listOf(EVENT_GROUP_NAME_1, EVENT_GROUP_NAME_2, EVENT_GROUP_NAME_3) - filter = "some.filter" - displayName = "test-reporting-set" - } - } - ) - assertThat(parseTextProto(output.reader(), ReportingSet.getDefaultInstance())) - .isEqualTo(REPORTING_SET) - } - - @Test - fun `reporting_sets list calls api with valid request`() { - val args = - arrayOf( - "--tls-cert-file=$SECRETS_DIR/mc_tls.pem", - "--tls-key-file=$SECRETS_DIR/mc_tls.key", - "--cert-collection-file=$SECRETS_DIR/reporting_root.pem", - "--reporting-server-api-target=$HOST:${server.port}", - "reporting-sets", - "list", - "--parent=$MEASUREMENT_CONSUMER_NAME", - "--page-size=50", - ) - - val output = callCli(args) - - verifyProtoArgument(reportingSetsServiceMock, ReportingSetsCoroutineImplBase::listReportingSets) - .isEqualTo( - listReportingSetsRequest { - parent = MEASUREMENT_CONSUMER_NAME - pageSize = 50 - } - ) - assertThat(parseTextProto(output.reader(), ListReportingSetsResponse.getDefaultInstance())) - .isEqualTo(LIST_REPORTING_SETS_RESPONSE) - } - - @Test - fun `Reports create calls api with valid request`() { - val metric = - """ - reach { } - set_operations { - unique_name: "operation1" - set_operation { - type: 1 - lhs { - reporting_set: "measurementConsumers/1/reportingSets/1" - } - rhs { - reporting_set: "measurementConsumers/1/reportingSets/2" - } - } - } - """ - .trimIndent() - val args = - arrayOf( - "--tls-cert-file=$SECRETS_DIR/mc_tls.pem", - "--tls-key-file=$SECRETS_DIR/mc_tls.key", - "--cert-collection-file=$SECRETS_DIR/reporting_root.pem", - "--reporting-server-api-target=$HOST:${server.port}", - "reports", - "create", - "--idempotency-key=$REPORT_IDEMPOTENCY_KEY", - "--parent=$MEASUREMENT_CONSUMER_NAME", - "--event-group-key=$EVENT_GROUP_NAME_1", - "--event-group-value=", - "--event-group-key=$EVENT_GROUP_NAME_2", - "--event-group-value=partner=abc", - "--periodic-interval-start-time=2017-01-15T01:30:15.01Z", - "--periodic-interval-increment=P1DT3H5M12.99S", - "--periodic-interval-count=3", - "--metric=$metric", - ) - - val output = callCli(args) - - verifyProtoArgument(reportsServiceMock, ReportsCoroutineImplBase::createReport) - .isEqualTo( - createReportRequest { - parent = MEASUREMENT_CONSUMER_NAME - report = report { - reportIdempotencyKey = REPORT_IDEMPOTENCY_KEY - measurementConsumer = MEASUREMENT_CONSUMER_NAME - eventGroupUniverse = eventGroupUniverse { - eventGroupEntries += eventGroupEntry { key = EVENT_GROUP_NAME_1 } - eventGroupEntries += eventGroupEntry { - key = EVENT_GROUP_NAME_2 - value = "partner=abc" - } - } - periodicTimeInterval = periodicTimeInterval { - startTime = Instant.parse("2017-01-15T01:30:15.01Z").toProtoTime() - increment = Duration.parse("P1DT3H5M12.99S").toProtoDuration() - intervalCount = 3 - } - metrics += metric { - reach = reachParams {} - setOperations += namedSetOperation { - uniqueName = "operation1" - setOperation = setOperation { - type = Metric.SetOperation.Type.UNION - lhs = operand { reportingSet = "measurementConsumers/1/reportingSets/1" } - rhs = operand { reportingSet = "measurementConsumers/1/reportingSets/2" } - } - } - } - } - } - ) - assertThat(parseTextProto(output.reader(), Report.getDefaultInstance())).isEqualTo(REPORT) - } - - @Test - fun `Reports create calls api with correct time intervals params`() { - val textFormatMetric = - """ - reach { } - set_operations { - unique_name: "operation1" - set_operation { - type: 1 - lhs { - reporting_set: "measurementConsumers/1/reportingSets/1" - } - rhs { - reporting_set: "measurementConsumers/1/reportingSets/2" - } - } - } - """ - .trimIndent() - - val args = - arrayOf( - "--tls-cert-file=$SECRETS_DIR/mc_tls.pem", - "--tls-key-file=$SECRETS_DIR/mc_tls.key", - "--cert-collection-file=$SECRETS_DIR/reporting_root.pem", - "--reporting-server-api-target=$HOST:${server.port}", - "reports", - "create", - "--idempotency-key=$REPORT_IDEMPOTENCY_KEY", - "--parent=$MEASUREMENT_CONSUMER_NAME", - "--event-group-key=$EVENT_GROUP_NAME_1", - "--event-group-value=", - "--interval-start-time=2017-01-15T01:30:15.01Z", - "--interval-end-time=2018-10-27T23:19:12.99Z", - "--interval-start-time=2019-01-19T09:48:35.57Z", - "--interval-end-time=2022-06-13T11:57:54.21Z", - "--metric=$textFormatMetric", - ) - val output = callCli(args) - - verifyProtoArgument(reportsServiceMock, ReportsCoroutineImplBase::createReport) - .comparingExpectedFieldsOnly() - .isEqualTo( - createReportRequest { - report = report { - timeIntervals = timeIntervals { - timeIntervals += timeInterval { - startTime = Instant.parse("2017-01-15T01:30:15.01Z").toProtoTime() - endTime = Instant.parse("2018-10-27T23:19:12.99Z").toProtoTime() - } - timeIntervals += timeInterval { - startTime = Instant.parse("2019-01-19T09:48:35.57Z").toProtoTime() - endTime = Instant.parse("2022-06-13T11:57:54.21Z").toProtoTime() - } - } - } - } - ) - assertThat(parseTextProto(output.reader(), Report.getDefaultInstance())).isEqualTo(REPORT) - } - - @Test - fun `Reports create calls api with complex metric`() { - val textFormatMetric = TEXTPROTO_DIR.resolve("metric2.textproto").toFile().readText() - - val args = - arrayOf( - "--tls-cert-file=$SECRETS_DIR/mc_tls.pem", - "--tls-key-file=$SECRETS_DIR/mc_tls.key", - "--cert-collection-file=$SECRETS_DIR/reporting_root.pem", - "--reporting-server-api-target=$HOST:${server.port}", - "reports", - "create", - "--idempotency-key=$REPORT_IDEMPOTENCY_KEY", - "--parent=$MEASUREMENT_CONSUMER_NAME", - "--event-group-key=$EVENT_GROUP_NAME_1", - "--event-group-value=", - "--periodic-interval-start-time=2017-01-15T01:30:15.01Z", - "--periodic-interval-increment=P1DT3H5M12.99S", - "--periodic-interval-count=3", - "--metric=$textFormatMetric", - ) - val output = callCli(args) - - verifyProtoArgument(reportsServiceMock, ReportsCoroutineImplBase::createReport) - .comparingExpectedFieldsOnly() - .isEqualTo( - createReportRequest { - report = report { - metrics += metric { - reach = reachParams {} - cumulative = true - setOperations += namedSetOperation { - uniqueName = "operation1" - setOperation = setOperation { - type = Metric.SetOperation.Type.UNION - lhs = operand { reportingSet = "measurementConsumers/1/reportingSets/1" } - rhs = operand { reportingSet = "measurementConsumers/1/reportingSets/2" } - } - } - setOperations += namedSetOperation { - uniqueName = "operation2" - setOperation = setOperation { - type = Metric.SetOperation.Type.DIFFERENCE - lhs = operand { reportingSet = "measurementConsumers/1/reportingSets/3" } - rhs = operand { - operation = setOperation { - type = Metric.SetOperation.Type.INTERSECTION - lhs = operand { reportingSet = "measurementConsumers/1/reportingSets/4" } - rhs = operand { reportingSet = "measurementConsumers/1/reportingSets/5" } - } - } - } - } - } - } - } - ) - assertThat(parseTextProto(output.reader(), Report.getDefaultInstance())).isEqualTo(REPORT) - } - - @Test - fun `Reports list calls api with valid request`() { - val args = - arrayOf( - "--tls-cert-file=$SECRETS_DIR/mc_tls.pem", - "--tls-key-file=$SECRETS_DIR/mc_tls.key", - "--cert-collection-file=$SECRETS_DIR/reporting_root.pem", - "--reporting-server-api-target=$HOST:${server.port}", - "reports", - "list", - "--parent=$MEASUREMENT_CONSUMER_NAME", - ) - callCli(args) - - verifyProtoArgument(reportsServiceMock, ReportsCoroutineImplBase::listReports) - .isEqualTo( - listReportsRequest { - parent = MEASUREMENT_CONSUMER_NAME - pageSize = 1000 - } - ) - } - - @Test - fun `Reports get calls api with valid request`() { - val args = - arrayOf( - "--tls-cert-file=$SECRETS_DIR/mc_tls.pem", - "--tls-key-file=$SECRETS_DIR/mc_tls.key", - "--cert-collection-file=$SECRETS_DIR/reporting_root.pem", - "--reporting-server-api-target=$HOST:${server.port}", - "reports", - "get", - REPORT_NAME, - ) - val output = callCli(args) - - verifyProtoArgument(reportsServiceMock, ReportsCoroutineImplBase::getReport) - .isEqualTo(getReportRequest { name = REPORT_NAME }) - assertThat(parseTextProto(output.reader(), Report.getDefaultInstance())).isEqualTo(REPORT) - } - - @Test - fun `EventGroups list callls api with valid request`() { - val args = - arrayOf( - "--tls-cert-file=$SECRETS_DIR/mc_tls.pem", - "--tls-key-file=$SECRETS_DIR/mc_tls.key", - "--cert-collection-file=$SECRETS_DIR/reporting_root.pem", - "--reporting-server-api-target=$HOST:${server.port}", - "event-groups", - "list", - "--parent=$DATA_PROVIDER_NAME", - "--filter=abcd", - ) - val output = callCli(args) - - verifyProtoArgument(eventGroupsServiceMock, EventGroupsCoroutineImplBase::listEventGroups) - .isEqualTo( - listEventGroupsRequest { - parent = DATA_PROVIDER_NAME - filter = "abcd" - pageSize = 1000 - } - ) - assertThat(parseTextProto(output.reader(), ListEventGroupsResponse.getDefaultInstance())) - .isEqualTo(LIST_EVENT_GROUPS_RESPONSE) - } - - companion object { - init { - System.setSecurityManager(ExitInterceptingSecurityManager) - } - } -} diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/metric1.textproto b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/metric1.textproto deleted file mode 100644 index c5bf5885454..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/metric1.textproto +++ /dev/null @@ -1,15 +0,0 @@ -# proto-file: wfa/measurement/reporting/v1alpha/metric.proto -# proto-message: Metric -reach { } -set_operations { - unique_name: "operation1" - set_operation { - type: 1 - lhs { - reporting_set: "measurementConsumers/1/reportingSets/1" - } - rhs { - reporting_set: "measurementConsumers/1/reportingSets/2" - } - } -} diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/metric2.textproto b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/metric2.textproto deleted file mode 100644 index 48a44bd3e55..00000000000 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/tools/metric2.textproto +++ /dev/null @@ -1,36 +0,0 @@ -# proto-file: wfa/measurement/reporting/v1alpha/metric.proto -# proto-message: Metric -reach { } -cumulative: true -set_operations { - unique_name: "operation1" - set_operation { - type: 1 - lhs { - reporting_set: "measurementConsumers/1/reportingSets/1" - } - rhs { - reporting_set: "measurementConsumers/1/reportingSets/2" - } - } -} -set_operations { - unique_name: "operation2" - set_operation { - type: 2 - lhs { - reporting_set: "measurementConsumers/1/reportingSets/3" - } - rhs { - operation { - type: 3 - lhs { - reporting_set: "measurementConsumers/1/reportingSets/4" - } - rhs { - reporting_set: "measurementConsumers/1/reportingSets/5" - } - } - } - } -} \ No newline at end of file diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsServiceTest.kt index befbdcedd0b..42df07ae256 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/EventGroupsServiceTest.kt @@ -19,11 +19,13 @@ package org.wfanet.measurement.reporting.service.api.v2alpha import com.google.common.truth.Truth.assertThat import com.google.common.truth.extensions.proto.ProtoTruth.assertThat import com.google.protobuf.Any +import io.grpc.Deadline import io.grpc.Status import io.grpc.StatusRuntimeException import java.nio.file.Path import java.nio.file.Paths import java.time.Duration +import java.util.concurrent.TimeUnit import kotlin.test.assertFailsWith import kotlinx.coroutines.runBlocking import org.junit.After @@ -110,6 +112,7 @@ class EventGroupsServiceTest { private lateinit var celEnvCacheProvider: CelEnvCacheProvider private lateinit var service: EventGroupsService + private val fakeTicker = SettableSystemTicker() @Before fun initService() { @@ -126,6 +129,7 @@ class EventGroupsServiceTest { EventGroupsCoroutineStub(grpcTestServerRule.channel), ENCRYPTION_KEY_PAIR_STORE, celEnvCacheProvider, + fakeTicker, ) } @@ -152,13 +156,13 @@ class EventGroupsServiceTest { whenever(publicKingdomEventGroupsMock.listEventGroups(any())) .thenReturn( listEventGroupsResponse { - nextPageToken = "1" eventGroups += cmmsEventGroup2 + nextPageToken = "1" } ) .thenReturn( listEventGroupsResponse { - eventGroups += listOf(CMMS_EVENT_GROUP, cmmsEventGroup2) + eventGroups += CMMS_EVENT_GROUP nextPageToken = "2" } ) @@ -170,6 +174,7 @@ class EventGroupsServiceTest { listEventGroupsRequest { parent = MEASUREMENT_CONSUMER_NAME filter = "metadata.metadata.publisher_id > 5" + pageSize = 1 } ) } @@ -200,7 +205,10 @@ class EventGroupsServiceTest { eventGroups += CMMS_EVENT_GROUP } ) - + .then { + // Advance time. + fakeTicker.setNanoTime(fakeTicker.nanoTime() + TimeUnit.SECONDS.toNanos(30)) + } val response = withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { runBlocking { @@ -443,7 +451,12 @@ class EventGroupsServiceTest { val response = withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { runBlocking { - service.listEventGroups(listEventGroupsRequest { parent = MEASUREMENT_CONSUMER_NAME }) + service.listEventGroups( + listEventGroupsRequest { + parent = MEASUREMENT_CONSUMER_NAME + pageSize = 2 + } + ) } } @@ -455,7 +468,7 @@ class EventGroupsServiceTest { .isEqualTo( cmmsListEventGroupsRequest { parent = MEASUREMENT_CONSUMER_NAME - pageSize = DEFAULT_PAGE_SIZE + pageSize = 2 } ) } @@ -813,4 +826,20 @@ class EventGroupsServiceTest { } } } + + /** + * Fake [Deadline.Ticker] implementation that allows time to be specified to override delegation + * to the system ticker. + */ + private class SettableSystemTicker : Deadline.Ticker() { + private var nanoTime: Long? = null + + fun setNanoTime(value: Long) { + nanoTime = value + } + + override fun nanoTime(): Long { + return this.nanoTime ?: Deadline.getSystemTicker().nanoTime() + } + } } diff --git a/src/test/kotlin/org/wfanet/panelmatch/integration/k8s/EmptyClusterPanelMatchCorrectnessTest.kt b/src/test/kotlin/org/wfanet/panelmatch/integration/k8s/EmptyClusterPanelMatchCorrectnessTest.kt index 8c126916597..07de2db818b 100644 --- a/src/test/kotlin/org/wfanet/panelmatch/integration/k8s/EmptyClusterPanelMatchCorrectnessTest.kt +++ b/src/test/kotlin/org/wfanet/panelmatch/integration/k8s/EmptyClusterPanelMatchCorrectnessTest.kt @@ -494,7 +494,7 @@ class EmptyClusterPanelMatchCorrectnessTest : AbstractPanelMatchCorrectnessTest( private val LOCAL_K8S_PATH = Paths.get("src", "main", "k8s", "local") private val LOCAL_K8S_TESTING_PATH = LOCAL_K8S_PATH.resolve("testing") - private val CONFIG_FILES_PATH = LOCAL_K8S_TESTING_PATH.resolve("config_files") + private val CONFIG_FILES_PATH = LOCAL_K8S_TESTING_PATH.resolve("config_files_for_panel_match") private val LOCAL_K8S_PANELMATCH_PATH = Paths.get("src", "main", "k8s", "panelmatch", "local") private val PANELMATCH_CONFIG_FILES_PATH = LOCAL_K8S_PANELMATCH_PATH.resolve("config_files")