Skip to content

Commit

Permalink
feat: [genai-modules][models] Add HttpOptions to all method configs f…
Browse files Browse the repository at this point in the history
…or models.

PiperOrigin-RevId: 705295653
  • Loading branch information
google-genai-bot authored and copybara-github committed Jan 13, 2025
1 parent 0e4b0e5 commit 88a7790
Show file tree
Hide file tree
Showing 11 changed files with 693 additions and 568 deletions.
198 changes: 74 additions & 124 deletions google/genai/batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,6 @@ def _CreateBatchJobConfig_to_mldev(
parent_object: dict = None,
) -> dict:
to_object = {}
if getv(from_object, ['http_options']) is not None:
setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))

if getv(from_object, ['display_name']) is not None:
setv(parent_object, ['displayName'], getv(from_object, ['display_name']))
Expand All @@ -135,8 +133,6 @@ def _CreateBatchJobConfig_to_vertex(
parent_object: dict = None,
) -> dict:
to_object = {}
if getv(from_object, ['http_options']) is not None:
setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))

if getv(from_object, ['display_name']) is not None:
setv(parent_object, ['displayName'], getv(from_object, ['display_name']))
Expand Down Expand Up @@ -215,30 +211,6 @@ def _CreateBatchJobParameters_to_vertex(
return to_object


def _GetBatchJobConfig_to_mldev(
api_client: ApiClient,
from_object: Union[dict, object],
parent_object: dict = None,
) -> dict:
to_object = {}
if getv(from_object, ['http_options']) is not None:
setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))

return to_object


def _GetBatchJobConfig_to_vertex(
api_client: ApiClient,
from_object: Union[dict, object],
parent_object: dict = None,
) -> dict:
to_object = {}
if getv(from_object, ['http_options']) is not None:
setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))

return to_object


def _GetBatchJobParameters_to_mldev(
api_client: ApiClient,
from_object: Union[dict, object],
Expand All @@ -248,15 +220,6 @@ def _GetBatchJobParameters_to_mldev(
if getv(from_object, ['name']) is not None:
raise ValueError('name parameter is not supported in Google AI.')

if getv(from_object, ['config']) is not None:
setv(
to_object,
['config'],
_GetBatchJobConfig_to_mldev(
api_client, getv(from_object, ['config']), to_object
),
)

return to_object


Expand All @@ -273,39 +236,6 @@ def _GetBatchJobParameters_to_vertex(
t.t_batch_job_name(api_client, getv(from_object, ['name'])),
)

if getv(from_object, ['config']) is not None:
setv(
to_object,
['config'],
_GetBatchJobConfig_to_vertex(
api_client, getv(from_object, ['config']), to_object
),
)

return to_object


def _CancelBatchJobConfig_to_mldev(
api_client: ApiClient,
from_object: Union[dict, object],
parent_object: dict = None,
) -> dict:
to_object = {}
if getv(from_object, ['http_options']) is not None:
setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))

return to_object


def _CancelBatchJobConfig_to_vertex(
api_client: ApiClient,
from_object: Union[dict, object],
parent_object: dict = None,
) -> dict:
to_object = {}
if getv(from_object, ['http_options']) is not None:
setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))

return to_object


Expand All @@ -318,15 +248,6 @@ def _CancelBatchJobParameters_to_mldev(
if getv(from_object, ['name']) is not None:
raise ValueError('name parameter is not supported in Google AI.')

if getv(from_object, ['config']) is not None:
setv(
to_object,
['config'],
_CancelBatchJobConfig_to_mldev(
api_client, getv(from_object, ['config']), to_object
),
)

return to_object


Expand All @@ -343,15 +264,6 @@ def _CancelBatchJobParameters_to_vertex(
t.t_batch_job_name(api_client, getv(from_object, ['name'])),
)

if getv(from_object, ['config']) is not None:
setv(
to_object,
['config'],
_CancelBatchJobConfig_to_vertex(
api_client, getv(from_object, ['config']), to_object
),
)

return to_object


Expand All @@ -361,8 +273,6 @@ def _ListBatchJobConfig_to_mldev(
parent_object: dict = None,
) -> dict:
to_object = {}
if getv(from_object, ['http_options']) is not None:
setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))

if getv(from_object, ['page_size']) is not None:
setv(
Expand All @@ -388,8 +298,6 @@ def _ListBatchJobConfig_to_vertex(
parent_object: dict = None,
) -> dict:
to_object = {}
if getv(from_object, ['http_options']) is not None:
setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))

if getv(from_object, ['page_size']) is not None:
setv(
Expand Down Expand Up @@ -727,9 +635,12 @@ def _create(
query_params = request_dict.get('_query')
if query_params:
path = f'{path}?{urlencode(query_params)}'
# TODO: remove the hack that pops config.
config = request_dict.pop('config', None)
http_options = config.pop('httpOptions', None) if config else None
http_options = (
parameter_model.config.http_options
if (hasattr(parameter_model, 'config') and parameter_model.config)
else None
)
request_dict.pop('config', None)
request_dict = _common.convert_to_dict(request_dict)
request_dict = _common.apply_base64_encoding(request_dict)

Expand Down Expand Up @@ -783,9 +694,12 @@ def get(
query_params = request_dict.get('_query')
if query_params:
path = f'{path}?{urlencode(query_params)}'
# TODO: remove the hack that pops config.
config = request_dict.pop('config', None)
http_options = config.pop('httpOptions', None) if config else None
http_options = (
parameter_model.config.http_options
if (hasattr(parameter_model, 'config') and parameter_model.config)
else None
)
request_dict.pop('config', None)
request_dict = _common.convert_to_dict(request_dict)
request_dict = _common.apply_base64_encoding(request_dict)

Expand Down Expand Up @@ -826,9 +740,12 @@ def cancel(
query_params = request_dict.get('_query')
if query_params:
path = f'{path}?{urlencode(query_params)}'
# TODO: remove the hack that pops config.
config = request_dict.pop('config', None)
http_options = config.pop('httpOptions', None) if config else None
http_options = (
parameter_model.config.http_options
if (hasattr(parameter_model, 'config') and parameter_model.config)
else None
)
request_dict.pop('config', None)
request_dict = _common.convert_to_dict(request_dict)
request_dict = _common.apply_base64_encoding(request_dict)

Expand All @@ -854,9 +771,12 @@ def _list(
query_params = request_dict.get('_query')
if query_params:
path = f'{path}?{urlencode(query_params)}'
# TODO: remove the hack that pops config.
config = request_dict.pop('config', None)
http_options = config.pop('httpOptions', None) if config else None
http_options = (
parameter_model.config.http_options
if (hasattr(parameter_model, 'config') and parameter_model.config)
else None
)
request_dict.pop('config', None)
request_dict = _common.convert_to_dict(request_dict)
request_dict = _common.apply_base64_encoding(request_dict)

Expand All @@ -879,7 +799,12 @@ def _list(
self._api_client._verify_response(return_value)
return return_value

def delete(self, *, name: str) -> types.DeleteResourceJob:
def delete(
self,
*,
name: str,
config: Optional[types.DeleteBatchJobConfigOrDict] = None,
) -> types.DeleteResourceJob:
"""Deletes a batch job.
Args:
Expand All @@ -899,6 +824,7 @@ def delete(self, *, name: str) -> types.DeleteResourceJob:

parameter_model = types._DeleteBatchJobParameters(
name=name,
config=config,
)

if not self._api_client.vertexai:
Expand All @@ -912,9 +838,12 @@ def delete(self, *, name: str) -> types.DeleteResourceJob:
query_params = request_dict.get('_query')
if query_params:
path = f'{path}?{urlencode(query_params)}'
# TODO: remove the hack that pops config.
config = request_dict.pop('config', None)
http_options = config.pop('httpOptions', None) if config else None
http_options = (
parameter_model.config.http_options
if (hasattr(parameter_model, 'config') and parameter_model.config)
else None
)
request_dict.pop('config', None)
request_dict = _common.convert_to_dict(request_dict)
request_dict = _common.apply_base64_encoding(request_dict)

Expand Down Expand Up @@ -1023,9 +952,12 @@ async def _create(
query_params = request_dict.get('_query')
if query_params:
path = f'{path}?{urlencode(query_params)}'
# TODO: remove the hack that pops config.
config = request_dict.pop('config', None)
http_options = config.pop('httpOptions', None) if config else None
http_options = (
parameter_model.config.http_options
if (hasattr(parameter_model, 'config') and parameter_model.config)
else None
)
request_dict.pop('config', None)
request_dict = _common.convert_to_dict(request_dict)
request_dict = _common.apply_base64_encoding(request_dict)

Expand Down Expand Up @@ -1079,9 +1011,12 @@ async def get(
query_params = request_dict.get('_query')
if query_params:
path = f'{path}?{urlencode(query_params)}'
# TODO: remove the hack that pops config.
config = request_dict.pop('config', None)
http_options = config.pop('httpOptions', None) if config else None
http_options = (
parameter_model.config.http_options
if (hasattr(parameter_model, 'config') and parameter_model.config)
else None
)
request_dict.pop('config', None)
request_dict = _common.convert_to_dict(request_dict)
request_dict = _common.apply_base64_encoding(request_dict)

Expand Down Expand Up @@ -1122,9 +1057,12 @@ async def cancel(
query_params = request_dict.get('_query')
if query_params:
path = f'{path}?{urlencode(query_params)}'
# TODO: remove the hack that pops config.
config = request_dict.pop('config', None)
http_options = config.pop('httpOptions', None) if config else None
http_options = (
parameter_model.config.http_options
if (hasattr(parameter_model, 'config') and parameter_model.config)
else None
)
request_dict.pop('config', None)
request_dict = _common.convert_to_dict(request_dict)
request_dict = _common.apply_base64_encoding(request_dict)

Expand All @@ -1150,9 +1088,12 @@ async def _list(
query_params = request_dict.get('_query')
if query_params:
path = f'{path}?{urlencode(query_params)}'
# TODO: remove the hack that pops config.
config = request_dict.pop('config', None)
http_options = config.pop('httpOptions', None) if config else None
http_options = (
parameter_model.config.http_options
if (hasattr(parameter_model, 'config') and parameter_model.config)
else None
)
request_dict.pop('config', None)
request_dict = _common.convert_to_dict(request_dict)
request_dict = _common.apply_base64_encoding(request_dict)

Expand All @@ -1175,7 +1116,12 @@ async def _list(
self._api_client._verify_response(return_value)
return return_value

async def delete(self, *, name: str) -> types.DeleteResourceJob:
async def delete(
self,
*,
name: str,
config: Optional[types.DeleteBatchJobConfigOrDict] = None,
) -> types.DeleteResourceJob:
"""Deletes a batch job.
Args:
Expand All @@ -1195,6 +1141,7 @@ async def delete(self, *, name: str) -> types.DeleteResourceJob:

parameter_model = types._DeleteBatchJobParameters(
name=name,
config=config,
)

if not self._api_client.vertexai:
Expand All @@ -1208,9 +1155,12 @@ async def delete(self, *, name: str) -> types.DeleteResourceJob:
query_params = request_dict.get('_query')
if query_params:
path = f'{path}?{urlencode(query_params)}'
# TODO: remove the hack that pops config.
config = request_dict.pop('config', None)
http_options = config.pop('httpOptions', None) if config else None
http_options = (
parameter_model.config.http_options
if (hasattr(parameter_model, 'config') and parameter_model.config)
else None
)
request_dict.pop('config', None)
request_dict = _common.convert_to_dict(request_dict)
request_dict = _common.apply_base64_encoding(request_dict)

Expand Down
Loading

0 comments on commit 88a7790

Please sign in to comment.