Commit

update model from PaLM API to Gemini API
davidtamaki committed Jun 26, 2024
1 parent 23a1bde commit 7a4432b
Showing 6 changed files with 623 additions and 518 deletions.
1 change: 1 addition & 0 deletions Pipfile
@@ -10,6 +10,7 @@ sendgrid = "*"
ratelimit = "*"
backoff = "*"
pandas = "*"
+numpy = "==1.26.4"

[dev-packages]

929 changes: 542 additions & 387 deletions Pipfile.lock

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions README.md
@@ -76,11 +76,11 @@ The two variables you must modify are:
1. Deploy 3 cloud functions for action hub listing, action form, and action execute (this may take a few minutes):

```
-gcloud functions deploy vertex-ai-list --entry-point action_list --env-vars-file .env.yaml --trigger-http --runtime=python311 --allow-unauthenticated --timeout=540s --region=${REGION} --project=${PROJECT} --service-account ${SERVICE_ACCOUNT_EMAIL} --set-secrets 'LOOKER_AUTH_TOKEN=LOOKER_AUTH_TOKEN:latest'
+gcloud functions deploy vertex-ai-list --entry-point action_list --env-vars-file .env.yaml --trigger-http --runtime=python311 --allow-unauthenticated --no-gen2 --memory=1024MB --timeout=540s --region=${REGION} --project=${PROJECT} --service-account ${SERVICE_ACCOUNT_EMAIL} --set-secrets 'LOOKER_AUTH_TOKEN=LOOKER_AUTH_TOKEN:latest'
-gcloud functions deploy vertex-ai-form --entry-point action_form --env-vars-file .env.yaml --trigger-http --runtime=python311 --allow-unauthenticated --timeout=540s --region=${REGION} --project=${PROJECT} --service-account ${SERVICE_ACCOUNT_EMAIL} --set-secrets 'LOOKER_AUTH_TOKEN=LOOKER_AUTH_TOKEN:latest'
+gcloud functions deploy vertex-ai-form --entry-point action_form --env-vars-file .env.yaml --trigger-http --runtime=python311 --allow-unauthenticated --no-gen2 --memory=1024MB --timeout=540s --region=${REGION} --project=${PROJECT} --service-account ${SERVICE_ACCOUNT_EMAIL} --set-secrets 'LOOKER_AUTH_TOKEN=LOOKER_AUTH_TOKEN:latest'
-gcloud functions deploy vertex-ai-execute --entry-point action_execute --env-vars-file .env.yaml --trigger-http --runtime=python311 --allow-unauthenticated --timeout=540s --region=${REGION} --project=${PROJECT} --service-account ${SERVICE_ACCOUNT_EMAIL} --set-secrets 'LOOKER_AUTH_TOKEN=LOOKER_AUTH_TOKEN:latest,SENDGRID_API_KEY=SENDGRID_API_KEY:latest' --memory=1024MB
+gcloud functions deploy vertex-ai-execute --entry-point action_execute --env-vars-file .env.yaml --trigger-http --runtime=python311 --allow-unauthenticated --no-gen2 --memory=8192MB --timeout=540s --region=${REGION} --project=${PROJECT} --service-account ${SERVICE_ACCOUNT_EMAIL} --set-secrets 'LOOKER_AUTH_TOKEN=LOOKER_AUTH_TOKEN:latest,SENDGRID_API_KEY=SENDGRID_API_KEY:latest'
```

1. Copy the Action Hub URL (`action_list` endpoint) and the `LOOKER_AUTH_TOKEN` to input into Looker:
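Once the three functions are deployed, it can help to smoke-test the `action_list` endpoint directly before registering it in Looker. A minimal sketch in Python, assuming the default Cloud Functions URL shape (mirroring `BASE_DOMAIN` in `main.py`) and Looker's `Token token="..."` authorization header; the `requests` package and the environment variables are assumptions of this sketch:

```
import os
import requests  # assumed to be installed for this sketch

# Default Cloud Functions URL shape, mirroring BASE_DOMAIN in main.py.
url = 'https://{}-{}.cloudfunctions.net/vertex-ai-list'.format(
    os.environ['REGION'], os.environ['PROJECT'])

# Looker's action hub sends the auth token in this header format.
headers = {'Authorization': 'Token token="{}"'.format(os.environ['LOOKER_AUTH_TOKEN'])}

resp = requests.post(url, headers=headers)
resp.raise_for_status()
print(resp.json())  # should include the integration labeled 'Looker Vertex AI'
```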
54 changes: 13 additions & 41 deletions palm_api.py → gemini_api.py
@@ -1,27 +1,11 @@
import backoff
import ratelimit
-from google.api_core import exceptions
import os
import vertexai
-from vertexai.preview.language_models import TextGenerationModel, CodeGenerationModel
-
-MODEL_TYPES = {
-    'text-bison': {
-        'name': 'text-bison',
-        'version': 'text-bison@001',
-        'label': 'Text Bison',
-        'max_output_tokens': 1024,
-        'model': TextGenerationModel.from_pretrained
-    },
-    'code-bison': {
-        'name': 'code-bison',
-        'version': 'code-bison@001',
-        'label': 'Code Bison',
-        'max_output_tokens': 2048,
-        'model': CodeGenerationModel.from_pretrained
-    }
-}
-DEFAULT_MODEL_TYPE = MODEL_TYPES['text-bison']['name']
+from vertexai.generative_models import GenerationConfig, GenerativeModel
+from google.api_core import exceptions
+
+MODEL_VARIANT = 'gemini-1.5-flash'

# https://cloud.google.com/vertex-ai/docs/quotas#request_quotas
CALL_LIMIT = 50 # Number of calls to allow within a period
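The constants above feed a rate-limit/backoff decorator stack around the predict call; the decorators are only partly visible in this diff (the `ONE_MINUTE` definition and the `backoff` decorator sit in the collapsed region). A minimal sketch of the pattern with the `backoff` and `ratelimit` libraries, under the assumption that the repo retries on quota errors (the exact exception list and timings may differ):

```
import backoff
import ratelimit
from google.api_core import exceptions

CALL_LIMIT = 50
ONE_MINUTE = 60  # assumed value of the collapsed constant


def backoff_hdlr(details):
    # Log each retry so throttling is visible in the function logs.
    print('Backing off {wait:0.1f}s after {tries} tries'.format(**details))


@backoff.on_exception(backoff.expo,  # exponential backoff on quota errors
                      (exceptions.ResourceExhausted, ratelimit.RateLimitException),
                      max_time=300, on_backoff=backoff_hdlr)
@ratelimit.limits(calls=CALL_LIMIT, period=ONE_MINUTE)
def model_prediction(model, content, temperature, max_output_tokens, top_k, top_p):
    ...  # body as shown in the hunk below
```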
@@ -78,35 +62,24 @@ def backoff_hdlr(details):
@ratelimit.limits( # Limit the number of calls to the model per minute
    calls=CALL_LIMIT, period=ONE_MINUTE
)
-def model_prediction(model: TextGenerationModel | CodeGenerationModel,
-                     model_type: str,
+def model_prediction(model: GenerativeModel,
                     content: str,
                     temperature: float,
                     max_output_tokens: int,
                     top_k: int,
                     top_p: float,
                     ):
    """Predict using a Large Language Model."""
-    if model_type == DEFAULT_MODEL_TYPE:
-        response = model.predict(
-            content,
-            temperature=temperature,
-            max_output_tokens=max_output_tokens,
-            top_k=top_k,
-            top_p=top_p)
-    else:
-        response = model.predict(
-            content,
-            temperature=temperature,
-            max_output_tokens=max_output_tokens)
-    print('Response from {} model: {}'.format(model_type, response))
+    config = GenerationConfig(max_output_tokens=max_output_tokens,
+                              temperature=temperature, top_p=top_p, top_k=top_k)
+    response = model.generate_content(content, generation_config=config)
+    print('Response from model: {}'.format(response))
    return response


def model_with_limit_and_backoff(all_data: dict,
                                 question: str,
                                 row_chunks: int,
-                                 model_type: str,
                                 temperature: float,
                                 max_output_tokens: int,
                                 top_k: int,
@@ -115,7 +88,7 @@ def model_with_limit_and_backoff(all_data: dict,
"""Split data into chunks to call model predict function and applies rate limiting."""
vertexai.init(project=os.environ.get('PROJECT'),
location=os.environ.get('REGION'))
model = MODEL_TYPES[model_type]['model'](MODEL_TYPES[model_type]['version'])
model = GenerativeModel(MODEL_VARIANT)
initial_summary = []
list_size = len(all_data)

@@ -125,14 +98,13 @@ def model_with_limit_and_backoff(all_data: dict,
        print('Processing rows {} to {}.'.format(i, i+row_chunks))
        content = initial_prompt_template.format(question=question, data=chunk)
        summary = model_prediction(
-            model, model_type, content, temperature, max_output_tokens, top_k, top_p).text
+            model, content, temperature, max_output_tokens, top_k, top_p).text
        initial_summary.append(summary) # append summary to list of summaries

    return initial_summary
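The loop header driving this map phase is collapsed above; a hypothetical reconstruction, assuming `all_data` is the list of row dicts parsed from the Looker attachment:

```
# Hypothetical reconstruction of the collapsed loop; exact slicing may differ.
for i in range(0, list_size, row_chunks):
    chunk = all_data[i:i + row_chunks]  # summarize row_chunks rows per call
    ...
```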


def reduce(initial_summary: any,
-           model_type: str,
           temperature: float,
           max_output_tokens: int,
           top_k: int,
@@ -142,11 +114,11 @@ def reduce(initial_summary: any,

    vertexai.init(project=os.environ.get('PROJECT'),
                  location=os.environ.get('REGION'))
-    model = MODEL_TYPES[model_type]['model'](MODEL_TYPES[model_type]['version'])
+    model = GenerativeModel(MODEL_VARIANT)
    content = final_prompt_template.format(text=initial_summary)

    # Generate a summary using the model and the prompt
    summary = model_prediction(
-        model, model_type, content, temperature, max_output_tokens, top_k, top_p).text
+        model, content, temperature, max_output_tokens, top_k, top_p).text

    return summary
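Taken together, the rewritten module is a small map-reduce summarizer: `model_with_limit_and_backoff` produces one summary per chunk of rows, and `reduce` folds the joined summaries into a final answer. A minimal usage sketch, with illustrative sample rows and the default parameter values from `main.py` (the prompt templates live elsewhere in the module and are not shown in this diff):

```
from gemini_api import model_with_limit_and_backoff, reduce

# Sample rows standing in for a Looker attachment; purely illustrative.
rows = [{'order_id': 1, 'status': 'returned'},
        {'order_id': 2, 'status': 'complete'}]

# Map: summarize the rows in chunks, with rate limiting and backoff.
chunk_summaries = model_with_limit_and_backoff(
    rows, 'What trends do you see in order status?', 200,
    0.2, 8192, 40, 0.8)  # temperature, max_output_tokens, top_k, top_p

# Reduce: fold the chunk summaries into one final summary.
print(reduce('\n'.join(chunk_summaries), 0.2, 8192, 40, 0.8))
```

This assumes `PROJECT` and `REGION` are set in the environment, since both functions call `vertexai.init` themselves.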
62 changes: 17 additions & 45 deletions main.py
@@ -5,22 +5,24 @@
from sendgrid.helpers.mail import Mail
from icon import icon_data_uri
from utils import authenticate, handle_error, list_to_html, safe_cast, sanitize_and_load_json_str
-from palm_api import model_with_limit_and_backoff, reduce, MODEL_TYPES, DEFAULT_MODEL_TYPE
+from gemini_api import model_with_limit_and_backoff, reduce


BASE_DOMAIN = 'https://{}-{}.cloudfunctions.net/{}-'.format(os.environ.get(
    'REGION'), os.environ.get('PROJECT'), os.environ.get('ACTION_NAME'))

+OUTPUT_TOKEN_LIMIT = 8192
+
# https://github.com/looker-open-source/actions/blob/master/docs/action_api.md#actions-list-endpoint


def action_list(request):
    """Return action hub list endpoint data for action"""
    auth = authenticate(request)
    if auth.status_code != 200:
        return auth

    response = {
-        'label': 'Looker Vertex AI [DEV]',
+        'label': 'Looker Vertex AI',
        'integrations': [{
            'name': os.environ.get('ACTION_NAME'),
            'label': os.environ.get('ACTION_LABEL'),
@@ -67,10 +69,6 @@ def action_form(request):
    if 'default_params' in form_params:
        default_params = form_params['default_params']

-    default_model_type = ''
-    if 'model_type' in form_params:
-        default_model_type = form_params['model_type']
-
    # step 1 - select a prompt
    response = [{
        'name': 'question',
@@ -102,24 +100,8 @@ def action_form(request):
        'interactive': True # dynamic field for model specific options
    }]

-    # step 2 - optional - choose model type
-    if ('default_params' in form_params and
-            form_params['default_params'] == 'no'):
-        response.extend([{
-            'name': 'model_type',
-            'label': 'Model Type',
-            'type': 'select',
-            'default': default_model_type,
-            'options': [{'name': MODEL_TYPES['text-bison']['name'], 'label': MODEL_TYPES['text-bison']['label']},
-                        {'name': MODEL_TYPES['code-bison']['name'], 'label': MODEL_TYPES['code-bison']['label']}],
-            'interactive': True
-        }
-        ])
-
-    # step 3a - optional - customize model params used by both models
-    if ('default_params' in form_params and
-            form_params['default_params'] == 'no' and
-            'model_type' in form_params):
+    # step 2 - optional - customize model params used by both models
+    if ('default_params' in form_params and form_params['default_params'] == 'no'):
    response.extend([{
        'name': 'temperature',
        'label': 'Temperature',
@@ -130,18 +112,11 @@
        {
            'name': 'max_output_tokens',
            'label': 'Max Output Tokens',
-            'description': 'Maximum number of tokens that can be generated in the response (Acceptable values = 1–1024 for Text Bison. Acceptable values = 1-2048 for Code Bison.)',
+            'description': 'Maximum number of tokens that can be generated in the response (Acceptable values = 1 - {})'.format(OUTPUT_TOKEN_LIMIT),
            'type': 'text',
-            'default': '1024',
-        }
-        ])
-
-    # step 3b - optional - customize model params used by text-bison
-    if ('default_params' in form_params and
-            form_params['default_params'] == 'no' and
-            'model_type' in form_params and
-            form_params['model_type'] == DEFAULT_MODEL_TYPE):
-        response.extend([{
+            'default': str(OUTPUT_TOKEN_LIMIT),
+        },
+        {
            'name': 'top_k',
            'label': 'Top-k',
            'description': 'Top-k changes how the model selects tokens for output. Specify a lower value for less random responses and a higher value for more random responses. (Acceptable values = 1-40)',
@@ -176,13 +151,10 @@ def action_execute(request):
    print(action_params)
    print(form_params)

-    maximum_max_output_tokens = MODEL_TYPES[form_params['model_type']]['max_output_tokens'] if 'model_type' in form_params else 1024
-
-    model_type = DEFAULT_MODEL_TYPE if 'model_type' not in form_params else MODEL_TYPES[form_params['model_type']]['name']
    temperature = 0.2 if 'temperature' not in form_params else safe_cast(
        form_params['temperature'], float, 0.0, 1.0, 0.2)
-    max_output_tokens = 1024 if 'max_output_tokens' not in form_params else safe_cast(
-        form_params['max_output_tokens'], int, 1, maximum_max_output_tokens, 1024)
+    max_output_tokens = OUTPUT_TOKEN_LIMIT if 'max_output_tokens' not in form_params else safe_cast(
+        form_params['max_output_tokens'], int, 1, OUTPUT_TOKEN_LIMIT, OUTPUT_TOKEN_LIMIT)
    top_k = 40 if 'top_k' not in form_params else safe_cast(
        form_params['top_k'], int, 1, 40, 40)
    top_p = 0.8 if 'top_p' not in form_params else safe_cast(
@@ -191,15 +163,15 @@
    # placeholder for model error email response
    body = 'There was a problem running the model. Please try again with less data. '
    summary = ''
-    row_chunks = 50 # number of rows to summarize together
+    row_chunks = 200 # number of rows to summarize together
    try:
        all_data = sanitize_and_load_json_str(
            attachment['data'])
        if form_params['row_or_all'] == 'row':
            row_chunks = 1 # run function on each row individually

        summary = model_with_limit_and_backoff(
-            all_data, question, row_chunks, model_type, temperature, max_output_tokens, top_k, top_p)
+            all_data, question, row_chunks, temperature, max_output_tokens, top_k, top_p)

        # if row, zip prompt_result with all_data and send html table
        if form_params['row_or_all'] == 'row':
@@ -214,7 +186,7 @@ def action_execute(request):
                summary[0].replace('\n', '<br>'))
        else:
            reduced_summary = reduce(
-                '\n'.join(summary), model_type, temperature, max_output_tokens, top_k, top_p)
+                '\n'.join(summary), temperature, max_output_tokens, top_k, top_p)
            body = 'Final Prompt Result:<br><strong>{}</strong><br><br>'.format(
                reduced_summary.replace('\n', '<br>'))
            body += '<br><br><strong>Batch Prompt Result:</strong><br>'
@@ -224,7 +196,7 @@ def action_execute(request):
            body += list_to_html(all_data)

    except Exception as e:
-        body += 'PaLM API Error: ' + e.message
+        body += 'Gemini API Error: ' + e.message
        print(body)

    if body == '':
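`safe_cast` comes from `utils` and is outside this diff; its call sites above suggest it casts a form string to a target type and falls back to a default when the cast fails or the value leaves the [minimum, maximum] range. A hypothetical stand-in, for illustration only (whether out-of-range input clamps or falls back is the repo's choice; this sketch falls back):

```
def safe_cast(value, to_type, minimum, maximum, default):
    """Hypothetical stand-in for utils.safe_cast (not the repo's code)."""
    try:
        result = to_type(value)
    except (TypeError, ValueError):
        return default
    if result < minimum or result > maximum:
        return default
    return result


# e.g. user input above the Gemini output-token limit falls back to the limit
assert safe_cast('9000', int, 1, 8192, 8192) == 8192
```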
89 changes: 47 additions & 42 deletions requirements.txt
@@ -1,45 +1,50 @@
-i https://pypi.org/simple
+annotated-types==0.7.0
backoff==2.2.1
-blinker==1.6.2 ; python_version >= '3.7'
-cachetools==5.3.1 ; python_version >= '3.7'
-certifi==2023.5.7 ; python_version >= '3.6'
-charset-normalizer==3.2.0 ; python_full_version >= '3.7.0'
-click==8.1.4 ; python_version >= '3.7'
-flask==2.3.2
-google-api-core[grpc]==2.11.1 ; python_version >= '3.7'
-google-auth==2.22.0 ; python_version >= '3.6'
-google-cloud-aiplatform==1.28.0
-google-cloud-bigquery==3.11.3 ; python_version >= '3.7'
-google-cloud-core==2.3.3 ; python_version >= '3.7'
-google-cloud-resource-manager==1.10.2 ; python_version >= '3.7'
-google-cloud-storage==2.10.0 ; python_version >= '3.7'
-google-crc32c==1.5.0 ; python_version >= '3.7'
-google-resumable-media==2.5.0 ; python_version >= '3.7'
-googleapis-common-protos==1.59.1 ; python_version >= '3.7'
-grpc-google-iam-v1==0.12.6 ; python_version >= '3.7'
-grpcio==1.56.0
-grpcio-status==1.56.0
-idna==3.4 ; python_version >= '3.5'
-itsdangerous==2.1.2 ; python_version >= '3.7'
-jinja2==3.1.2 ; python_version >= '3.7'
-markupsafe==2.1.3 ; python_version >= '3.7'
-numpy==1.25.1 ; python_version >= '3.10'
-packaging==23.1 ; python_version >= '3.7'
-pandas==2.0.3
-proto-plus==1.22.3 ; python_version >= '3.6'
-protobuf==4.23.4 ; python_version >= '3.7'
-pyasn1==0.5.0 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5'
-pyasn1-modules==0.3.0 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5'
-python-dateutil==2.8.2 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
-python-http-client==3.3.7 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
-pytz==2023.3
+blinker==1.8.2
+cachetools==5.3.3
+certifi==2024.6.2
+charset-normalizer==3.3.2
+click==8.1.7
+docstring_parser==0.16
+functions-framework==3.* # set this manually to avoid incompatible dependencies error
+Flask==3.0.3
+google-api-core==2.19.1
+google-auth==2.30.0
+google-cloud-aiplatform==1.56.0
+google-cloud-bigquery==3.25.0
+google-cloud-core==2.4.1
+google-cloud-resource-manager==1.12.3
+google-cloud-storage==2.17.0
+google-crc32c==1.5.0
+google-resumable-media==2.7.1
+googleapis-common-protos==1.63.2
+grpc-google-iam-v1==0.13.1
+grpcio==1.64.1
+grpcio-status==1.62.2
+idna==3.7
+itsdangerous==2.2.0
+Jinja2==3.1.4
+MarkupSafe==2.1.5
+numpy==1.26.4
+packaging==24.1
+pandas==2.2.2
+proto-plus==1.24.0
+protobuf==4.25.3
+pyasn1==0.6.0
+pyasn1_modules==0.4.0
+pydantic==2.7.4
+pydantic_core==2.18.4
+python-dateutil==2.9.0.post0
+python-http-client==3.3.7
+pytz==2024.1
ratelimit==2.2.1
-requests==2.31.0 ; python_version >= '3.7'
-rsa==4.9 ; python_version >= '3.6' and python_version < '4'
-sendgrid==6.10.0
-shapely==1.8.5.post1 ; python_version >= '3.6'
-six==1.16.0 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
+requests==2.32.3
+rsa==4.9
+sendgrid==6.11.0
+shapely==2.0.4
+six==1.16.0
starkbank-ecdsa==2.2.0
-tzdata==2023.3 ; python_version >= '2'
-urllib3==1.26.16 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5'
-werkzeug==2.3.6 ; python_version >= '3.8'
+typing_extensions==4.12.2
+tzdata==2024.1
+urllib3==2.2.2
+Werkzeug==3.0.3
