Skip to content

Commit

Permalink
feat: Update v1beta1 sdk for a few new protos
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 700552835
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Nov 27, 2024
1 parent 4288fec commit 7127d97
Show file tree
Hide file tree
Showing 16 changed files with 755 additions and 493 deletions.
2 changes: 2 additions & 0 deletions google/cloud/aiplatform/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@
tensorboard_service_client_v1.TensorboardServiceClient,
vizier_service_client_v1.VizierServiceClient,
vertex_rag_service_client_v1.VertexRagServiceClient,
vertex_rag_data_service_async_client_v1.VertexRagDataServiceAsyncClient,
vertex_rag_data_service_client_v1.VertexRagDataServiceClient,
)


Expand Down
76 changes: 74 additions & 2 deletions tests/unit/vertex_rag/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,18 @@
from google import auth
from google.api_core import operation as ga_operation
from google.auth import credentials as auth_credentials
from vertexai.preview import rag
from google.cloud.aiplatform_v1beta1 import (
from vertexai import rag
from vertexai.preview import rag as rag_preview
from google.cloud.aiplatform_v1 import (
DeleteRagCorpusRequest,
VertexRagDataServiceAsyncClient,
VertexRagDataServiceClient,
)
from google.cloud.aiplatform_v1beta1 import (
DeleteRagCorpusRequest as DeleteRagCorpusRequestPreview,
VertexRagDataServiceAsyncClient as VertexRagDataServiceAsyncClientPreview,
VertexRagDataServiceClient as VertexRagDataServiceClientPreview,
)
import test_rag_constants_preview
import mock
import pytest
Expand Down Expand Up @@ -75,6 +81,30 @@ def rag_data_client_mock():
yield rag_data_client_mock


@pytest.fixture
def rag_data_client_preview_mock():
with mock.patch.object(
rag_preview.utils._gapic_utils, "create_rag_data_service_client"
) as rag_data_client_mock:
api_client_mock = mock.Mock(spec=VertexRagDataServiceClientPreview)

# get_rag_corpus
api_client_mock.get_rag_corpus.return_value = (
test_rag_constants_preview.TEST_GAPIC_RAG_CORPUS
)
# delete_rag_corpus
delete_rag_corpus_lro_mock = mock.Mock(ga_operation.Operation)
delete_rag_corpus_lro_mock.result.return_value = DeleteRagCorpusRequestPreview()
api_client_mock.delete_rag_corpus.return_value = delete_rag_corpus_lro_mock
# get_rag_file
api_client_mock.get_rag_file.return_value = (
test_rag_constants_preview.TEST_GAPIC_RAG_FILE
)

rag_data_client_mock.return_value = api_client_mock
yield rag_data_client_mock


@pytest.fixture
def rag_data_client_mock_exception():
with mock.patch.object(
Expand Down Expand Up @@ -105,6 +135,36 @@ def rag_data_client_mock_exception():
yield rag_data_client_mock_exception


@pytest.fixture
def rag_data_client_preview_mock_exception():
with mock.patch.object(
rag_preview.utils._gapic_utils, "create_rag_data_service_client"
) as rag_data_client_mock_exception:
api_client_mock = mock.Mock(spec=VertexRagDataServiceClientPreview)
# create_rag_corpus
api_client_mock.create_rag_corpus.side_effect = Exception
# update_rag_corpus
api_client_mock.update_rag_corpus.side_effect = Exception
# get_rag_corpus
api_client_mock.get_rag_corpus.side_effect = Exception
# list_rag_corpora
api_client_mock.list_rag_corpora.side_effect = Exception
# delete_rag_corpus
api_client_mock.delete_rag_corpus.side_effect = Exception
# upload_rag_file
api_client_mock.upload_rag_file.side_effect = Exception
# import_rag_files
api_client_mock.import_rag_files.side_effect = Exception
# get_rag_file
api_client_mock.get_rag_file.side_effect = Exception
# list_rag_files
api_client_mock.list_rag_files.side_effect = Exception
# delete_rag_file
api_client_mock.delete_rag_file.side_effect = Exception
rag_data_client_mock_exception.return_value = api_client_mock
yield rag_data_client_mock_exception


@pytest.fixture
def rag_data_async_client_mock_exception():
with mock.patch.object(
Expand All @@ -115,3 +175,15 @@ def rag_data_async_client_mock_exception():
api_client_mock.import_rag_files.side_effect = Exception
rag_data_client_mock_exception.return_value = api_client_mock
yield rag_data_async_client_mock_exception


@pytest.fixture
def rag_data_async_client_preview_mock_exception():
with mock.patch.object(
rag_preview.utils._gapic_utils, "create_rag_data_service_async_client"
) as rag_data_async_client_mock_exception:
api_client_mock = mock.Mock(spec=VertexRagDataServiceAsyncClientPreview)
# import_rag_files
api_client_mock.import_rag_files.side_effect = Exception
rag_data_client_mock_exception.return_value = api_client_mock
yield rag_data_async_client_mock_exception
112 changes: 23 additions & 89 deletions tests/unit/vertex_rag/test_rag_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@
VertexFeatureStore,
)

from google.cloud.aiplatform_v1beta1 import (
from google.cloud.aiplatform_v1 import (
GoogleDriveSource,
RagFileChunkingConfig,
RagFileParsingConfig,
RagFileTransformationConfig,
ImportRagFilesConfig,
ImportRagFilesRequest,
ImportRagFilesResponse,
Expand All @@ -55,7 +55,7 @@
RagContexts,
RetrieveContextsResponse,
)
from google.cloud.aiplatform_v1beta1.types import api_auth
from google.cloud.aiplatform_v1.types import api_auth
from google.protobuf import timestamp_pb2


Expand Down Expand Up @@ -99,61 +99,6 @@
display_name=TEST_CORPUS_DISPLAY_NAME,
description=TEST_CORPUS_DISCRIPTION,
)
TEST_GAPIC_RAG_CORPUS.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = (
"projects/{}/locations/{}/publishers/google/models/textembedding-gecko".format(
TEST_PROJECT, TEST_REGION
)
)
TEST_GAPIC_RAG_CORPUS_WEAVIATE = GapicRagCorpus(
name=TEST_RAG_CORPUS_RESOURCE_NAME,
display_name=TEST_CORPUS_DISPLAY_NAME,
description=TEST_CORPUS_DISCRIPTION,
rag_vector_db_config=RagVectorDbConfig(
weaviate=RagVectorDbConfig.Weaviate(
http_endpoint=TEST_WEAVIATE_HTTP_ENDPOINT,
collection_name=TEST_WEAVIATE_COLLECTION_NAME,
),
api_auth=api_auth.ApiAuth(
api_key_config=api_auth.ApiAuth.ApiKeyConfig(
api_key_secret_version=TEST_WEAVIATE_API_KEY_SECRET_VERSION
),
),
),
)
TEST_GAPIC_RAG_CORPUS_VERTEX_FEATURE_STORE = GapicRagCorpus(
name=TEST_RAG_CORPUS_RESOURCE_NAME,
display_name=TEST_CORPUS_DISPLAY_NAME,
description=TEST_CORPUS_DISCRIPTION,
rag_vector_db_config=RagVectorDbConfig(
vertex_feature_store=RagVectorDbConfig.VertexFeatureStore(
feature_view_resource_name=TEST_VERTEX_FEATURE_STORE_RESOURCE_NAME
),
),
)
TEST_GAPIC_RAG_CORPUS_VERTEX_VECTOR_SEARCH = GapicRagCorpus(
name=TEST_RAG_CORPUS_RESOURCE_NAME,
display_name=TEST_CORPUS_DISPLAY_NAME,
description=TEST_CORPUS_DISCRIPTION,
rag_vector_db_config=RagVectorDbConfig(
vertex_vector_search=RagVectorDbConfig.VertexVectorSearch(
index_endpoint=TEST_VERTEX_VECTOR_SEARCH_INDEX_ENDPOINT,
index=TEST_VERTEX_VECTOR_SEARCH_INDEX,
),
),
)
TEST_GAPIC_RAG_CORPUS_PINECONE = GapicRagCorpus(
name=TEST_RAG_CORPUS_RESOURCE_NAME,
display_name=TEST_CORPUS_DISPLAY_NAME,
description=TEST_CORPUS_DISCRIPTION,
rag_vector_db_config=RagVectorDbConfig(
pinecone=RagVectorDbConfig.Pinecone(index_name=TEST_PINECONE_INDEX_NAME),
api_auth=api_auth.ApiAuth(
api_key_config=api_auth.ApiAuth.ApiKeyConfig(
api_key_secret_version=TEST_PINECONE_API_KEY_SECRET_VERSION
),
),
),
)
TEST_EMBEDDING_MODEL_CONFIG = EmbeddingModelConfig(
publisher_model="publishers/google/models/textembedding-gecko",
)
Expand Down Expand Up @@ -198,7 +143,7 @@
TEST_FILE_DISPLAY_NAME = "my-file.txt"
TEST_FILE_DESCRIPTION = "my file."
TEST_HEADERS = {"X-Goog-Upload-Protocol": "multipart"}
TEST_UPLOAD_REQUEST_URI = "https://{}/upload/v1beta1/projects/{}/locations/{}/ragCorpora/{}/ragFiles:upload".format(
TEST_UPLOAD_REQUEST_URI = "https://{}/upload/v1/projects/{}/locations/{}/ragCorpora/{}/ragFiles:upload".format(
TEST_API_ENDPOINT, TEST_PROJECT_NUMBER, TEST_REGION, TEST_RAG_CORPUS_ID
)
TEST_RAG_FILE_ID = "generate-456"
Expand All @@ -215,10 +160,19 @@
TEST_RAG_FILE_JSON_ERROR = {"error": {"code": 13}}
TEST_CHUNK_SIZE = 512
TEST_CHUNK_OVERLAP = 100
TEST_RAG_FILE_TRANSFORMATION_CONFIG = RagFileTransformationConfig(
rag_file_chunking_config=RagFileChunkingConfig(
fixed_length_chunking=RagFileChunkingConfig.FixedLengthChunking(
chunk_size=TEST_CHUNK_SIZE,
chunk_overlap=TEST_CHUNK_OVERLAP,
),
),
)
# GCS
TEST_IMPORT_FILES_CONFIG_GCS = ImportRagFilesConfig()
TEST_IMPORT_FILES_CONFIG_GCS = ImportRagFilesConfig(
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
)
TEST_IMPORT_FILES_CONFIG_GCS.gcs_source.uris = [TEST_GCS_PATH]
TEST_IMPORT_FILES_CONFIG_GCS.rag_file_parsing_config.use_advanced_pdf_parsing = False
TEST_IMPORT_REQUEST_GCS = ImportRagFilesRequest(
parent=TEST_RAG_CORPUS_RESOURCE_NAME,
import_rag_files_config=TEST_IMPORT_FILES_CONFIG_GCS,
Expand All @@ -231,26 +185,22 @@
TEST_DRIVE_FOLDER_2 = (
f"https://drive.google.com/drive/folders/{TEST_DRIVE_FOLDER_ID}?resourcekey=0-eiOT3"
)
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER = ImportRagFilesConfig()
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER = ImportRagFilesConfig(
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
)
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER.google_drive_source.resource_ids = [
GoogleDriveSource.ResourceId(
resource_id=TEST_DRIVE_FOLDER_ID,
resource_type=GoogleDriveSource.ResourceId.ResourceType.RESOURCE_TYPE_FOLDER,
)
]
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER.rag_file_parsing_config.use_advanced_pdf_parsing = (
False
)
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING = ImportRagFilesConfig()
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING.google_drive_source.resource_ids = [
GoogleDriveSource.ResourceId(
resource_id=TEST_DRIVE_FOLDER_ID,
resource_type=GoogleDriveSource.ResourceId.ResourceType.RESOURCE_TYPE_FOLDER,
)
]
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING.rag_file_parsing_config.use_advanced_pdf_parsing = (
True
)
TEST_IMPORT_REQUEST_DRIVE_FOLDER = ImportRagFilesRequest(
parent=TEST_RAG_CORPUS_RESOURCE_NAME,
import_rag_files_config=TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER,
Expand All @@ -263,11 +213,7 @@
TEST_DRIVE_FILE_ID = "456"
TEST_DRIVE_FILE = f"https://drive.google.com/file/d/{TEST_DRIVE_FILE_ID}"
TEST_IMPORT_FILES_CONFIG_DRIVE_FILE = ImportRagFilesConfig(
rag_file_chunking_config=RagFileChunkingConfig(
chunk_size=TEST_CHUNK_SIZE,
chunk_overlap=TEST_CHUNK_OVERLAP,
),
rag_file_parsing_config=RagFileParsingConfig(use_advanced_pdf_parsing=False),
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
)
TEST_IMPORT_FILES_CONFIG_DRIVE_FILE.max_embedding_requests_per_min = 800

Expand Down Expand Up @@ -322,10 +268,7 @@
],
)
TEST_IMPORT_FILES_CONFIG_SLACK_SOURCE = ImportRagFilesConfig(
rag_file_chunking_config=RagFileChunkingConfig(
chunk_size=TEST_CHUNK_SIZE,
chunk_overlap=TEST_CHUNK_OVERLAP,
)
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
)
TEST_IMPORT_FILES_CONFIG_SLACK_SOURCE.slack_source.channels = [
GapicSlackSource.SlackChannels(
Expand Down Expand Up @@ -377,10 +320,7 @@
],
)
TEST_IMPORT_FILES_CONFIG_JIRA_SOURCE = ImportRagFilesConfig(
rag_file_chunking_config=RagFileChunkingConfig(
chunk_size=TEST_CHUNK_SIZE,
chunk_overlap=TEST_CHUNK_OVERLAP,
)
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
)
TEST_IMPORT_FILES_CONFIG_JIRA_SOURCE.jira_source.jira_queries = [
GapicJiraSource.JiraQueries(
Expand Down Expand Up @@ -412,10 +352,7 @@
],
)
TEST_IMPORT_FILES_CONFIG_SHARE_POINT_SOURCE = ImportRagFilesConfig(
rag_file_chunking_config=RagFileChunkingConfig(
chunk_size=TEST_CHUNK_SIZE,
chunk_overlap=TEST_CHUNK_OVERLAP,
),
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
share_point_sources=GapicSharePointSources(
share_point_sources=[
GapicSharePointSources.SharePointSource(
Expand Down Expand Up @@ -490,10 +427,7 @@
)

TEST_IMPORT_FILES_CONFIG_SHARE_POINT_SOURCE_NO_FOLDERS = ImportRagFilesConfig(
rag_file_chunking_config=RagFileChunkingConfig(
chunk_size=TEST_CHUNK_SIZE,
chunk_overlap=TEST_CHUNK_OVERLAP,
),
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
share_point_sources=GapicSharePointSources(
share_point_sources=[
GapicSharePointSources.SharePointSource(
Expand Down
Loading

0 comments on commit 7127d97

Please sign in to comment.