Skip to content

Commit

Permalink
Add GenAI Strategy save method
Browse files Browse the repository at this point in the history
  • Loading branch information
om-khade-algobulls committed Aug 19, 2023
1 parent 4fbff0d commit c892e3a
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 12 deletions.
44 changes: 35 additions & 9 deletions pyalgotrading/algobulls/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import requests

from .exceptions import AlgoBullsAPIBaseException, AlgoBullsAPIUnauthorizedErrorException, AlgoBullsAPIInsufficientBalanceErrorException, AlgoBullsAPIResourceNotFoundErrorException, AlgoBullsAPIBadRequestException, \
AlgoBullsAPIInternalServerErrorException, AlgoBullsAPIForbiddenErrorException, AlgoBullsAPIGatewayTimeoutErrorException
AlgoBullsAPIInternalServerErrorException, AlgoBullsAPIForbiddenErrorException, AlgoBullsAPIGatewayTimeoutErrorException, AlgoBullsAPITooManyRequestsException
from ..constants import TradingType, TradingReportType
from ..utils.func import get_raw_response

Expand All @@ -34,7 +34,6 @@ def __init__(self, connection):
self.__key_papertrading = {} # strategy-cstc_id mapping
self.__key_realtrading = {} # strategy-cstc_id mapping
self.pattern = re.compile(r'(?<!^)(?=[A-Z])')
self.genai_api_key = None
self.genai_session_id = None
self.genai_sessions_map = None

Expand Down Expand Up @@ -98,6 +97,9 @@ def _send_request(self, method: str = 'get', endpoint: str = '', base_url: str =
elif r.status_code == 404:
r.raw.decode_content = True
raise AlgoBullsAPIResourceNotFoundErrorException(method=method, url=url, response=get_raw_response(r), status_code=404)
elif r.status_code == 429:
r.raw.decode_content = True
raise AlgoBullsAPITooManyRequestsException(method=method, url=url, response=get_raw_response(r), status_code=429)
elif r.status_code == 500:
r.raw.decode_content = True
raise AlgoBullsAPIInternalServerErrorException(method=method, url=url, response=get_raw_response(r), status_code=500)
Expand Down Expand Up @@ -481,6 +483,17 @@ def get_reports(self, strategy_code: str, trading_type: TradingType, report_type

return response

def set_genai_api_key(self, genai_api_key):
endpoint = 'v1/build/python/genai/key'
json_data = {"openaiApiKey": genai_api_key}
response = self._send_request(method='post', endpoint=endpoint, json_data=json_data)
return response

def get_genai_api_key_status(self):
endpoint = f'v1/build/python/genai/key'
response = self._send_request(endpoint=endpoint)
return response

def get_genai_response(self, user_prompt: str, chat_gpt_model: str = ''):
"""
Fetch GenAI response.
Expand All @@ -495,10 +508,17 @@ def get_genai_response(self, user_prompt: str, chat_gpt_model: str = ''):
`GET` v1/build/python/genai Get GenAI response
"""
endpoint = 'v1/build/python/genai'
params = {"userPrompt": user_prompt, 'sessionId': self.genai_session_id, 'openaiApiKey': self.genai_api_key, 'chatGPTModel': chat_gpt_model}
response = self._send_request(endpoint=endpoint, params=params)
if self.genai_session_id is None and 'session_id' in response:
self.genai_session_id = response['session_id']
params = {"userPrompt": user_prompt, 'sessionId': self.genai_session_id, 'chatGPTModel': chat_gpt_model}

try:
response = self._send_request(endpoint=endpoint, params=params)
if self.genai_session_id is None and 'session_id' in response:
self.genai_session_id = response['session_id']
except (AlgoBullsAPIResourceNotFoundErrorException, AlgoBullsAPIForbiddenErrorException, AlgoBullsAPIBadRequestException, AlgoBullsAPITooManyRequestsException) as ex:
print('\nFail.')
print(f'{ex.get_error_type()}: {ex.response}')
response = None

return response

def handle_genai_response_timeout(self):
Expand All @@ -514,9 +534,15 @@ def handle_genai_response_timeout(self):
"""
endpoint = 'v1/build/python/genai/response'
params = {'session_id': self.genai_session_id}
response = self._send_request(endpoint=endpoint, params=params)
if self.genai_session_id is None and 'session_id' in response:
self.genai_session_id = response['session_id']

try:
response = self._send_request(endpoint=endpoint, params=params)
if self.genai_session_id is None and 'session_id' in response:
self.genai_session_id = response['session_id']
except AlgoBullsAPIResourceNotFoundErrorException as ex:
print('\nFail.')
print(f'{ex.get_error_type()}: {ex.response}')
response = None

return response

Expand Down
39 changes: 36 additions & 3 deletions pyalgotrading/algobulls/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def set_generative_ai_keys(self, genai_api_key):
genai_api_key: GenAI API key
"""
assert isinstance(genai_api_key, str), f'Argument "api_key" should be a string'
self.api.genai_api_key = genai_api_key
self.api.set_genai_api_key(genai_api_key)

def get_genai_response_pooling(self, no_of_tries, user_prompt=None, chat_gpt_model=None):
if no_of_tries < GENAI_RESPONSE_POOLING_LIMIT:
Expand Down Expand Up @@ -150,8 +150,10 @@ def display_session_chat_history(self, session_id):
print(f"No available chat history for session id: {session_id}")

def start_chat(self, start_fresh=None, session_id=None, chat_gpt_model=None):
assert self.api.genai_api_key, f"Please set your GenAI key using set_generative_ai_keys()"
# This will set the session_id
response = self.api.get_genai_api_key_status()
assert response['key_available'], f"Please set your GenAI key using set_generative_ai_keys()"

# This will reset the session_id
if start_fresh:
# reset session
self.api.genai_session_id = None
Expand All @@ -173,9 +175,40 @@ def start_chat(self, start_fresh=None, session_id=None, chat_gpt_model=None):
print("Session End")
return

print("Please wait your request is being precessed.")
response = self.get_genai_response_pooling(1, user_prompt, chat_gpt_model)
if not response:
break

self.recent_genai_response = response['message']
print(f"GenAI: {response['message']}", end="\n\n")

def save_last_generated_strategy(self, strategy_code=None, strategy_name=None):
if self.recent_genai_response or strategy_code:
strategy_name = strategy_name or f'GenAI Strategy-{time.time():.0f}'
strategy_details = strategy_code or self.recent_genai_response

pattern = r"```python\n(.*?)\n```\n"
code_matches = re.findall(pattern, strategy_details, re.DOTALL)

if not code_matches:
print(strategy_details)
print("Do you want to save the following strategy? (Yes/No)")

while True:
user_response = input().lower()
if user_response == 'yes':
break
elif user_response == 'no':
return
else:
strategy_details = code_matches[0]

response = self.api.create_strategy(strategy_name=strategy_name, strategy_details=strategy_details, abc_version='3.3.3')
return response
else:
print("Please generate a GenAI strategy")

def create_strategy(self, strategy, overwrite=False, strategy_code=None, abc_version=None):
"""
Method to upload new strategy.
Expand Down
9 changes: 9 additions & 0 deletions pyalgotrading/algobulls/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,15 @@ def get_error_type(self):
return 'Resource Not Found'


class AlgoBullsAPITooManyRequestsException(AlgoBullsAPIBaseException):
"""
Exception class for HTTP status code of 429 (Too Many Requests)
"""

def get_error_type(self):
return 'Too Many Requests'


class AlgoBullsAPIInternalServerErrorException(AlgoBullsAPIBaseException):
"""
Exception class for HTTP status code of 500 (Internal Server Error)
Expand Down

0 comments on commit c892e3a

Please sign in to comment.