diff --git a/main.py b/main.py index 37306f3..3d83986 100644 --- a/main.py +++ b/main.py @@ -1,42 +1,9 @@ from ollama import chat from ollama import ChatResponse -import yfinance as yf import pandas as pd import streamlit as st - - -def get_stock_prices(symbol:str, duration:str) -> pd.DataFrame: - """Get stock prices from yahoo finance api - Args: - symbol (str): stock symbol - duration (str): time duration - Returns: - pd.DataFrame: stock prices - """ - suffix='.NS' - durations = {'1 year': '1y', '1month': '1mo', '1 month': '1mo', '1M': '1mo', '1year': '1y'} - if duration not in durations: - durations[duration] = duration - try: - stock = yf.Ticker(symbol+suffix) - data = stock.history(period=durations[duration]) - return data - except Exception as e: - return None - - -def current_pe_ratio(symbol:str,) -> int: - """Get current pe ratio for a stock - Args: - symbol (str): stock symbol - - Returns: - int: current pe ratio - """ - suffix = '.NS' - stock = yf.Ticker(symbol+suffix) - return stock.info.get('trailingPE', None) - +from tools import (get_address1, get_beta, get_marketcap, get_current_price, get_stock_prices, current_pe_ratio, get_52_week_high, + get_52_week_low,get_current_ratio, get_debt_to_equity, get_free_cash_flow, get_eps, get_price_to_book) st.title('AI Stock Data Retriever') @@ -52,12 +19,25 @@ def current_pe_ratio(symbol:str,) -> int: available_functions = { 'get_stock_prices': get_stock_prices, 'current_pe_ratio': current_pe_ratio, + 'get_current_price': get_current_price, + 'get_marketcap': get_marketcap, + 'get_beta': get_beta, + 'get_address': get_address1, + 'get_52_week_high':get_52_week_high, + 'get_52_week_low':get_52_week_low, + 'get_current_ratio':get_current_ratio, + 'get_debt_to_equity':get_debt_to_equity, + 'get_free_cash_flow':get_free_cash_flow, + 'get_eps':get_eps, + 'get_price_to_book':get_price_to_book + } response: ChatResponse = chat( 'llama3.2:1b', messages=messages, - tools=[get_stock_prices, current_pe_ratio], + tools=[get_stock_prices, current_pe_ratio, get_current_price, get_marketcap, get_beta, get_address1, get_52_week_high, get_52_week_low, + get_current_ratio, get_debt_to_equity, get_free_cash_flow, get_eps, get_price_to_book] ) if response.message.tool_calls: @@ -73,18 +53,12 @@ def current_pe_ratio(symbol:str,) -> int: else: st.write(output) print('Function output:', output) - else: - print('Function', tool.function.name, 'not found') - - # Only needed to chat with the model using the tool call results - if response.message.tool_calls: - # Add the function response to messages for the model to use - messages.append(response.message) - messages.append({'role': 'tool', 'content': str(output), 'name': tool.function.name}) + messages.append(response.message) + messages.append({'role': 'tool', 'function output:': str(output), 'name': tool.function.name}) - # Get final response from model with function outputs - final_response = chat('llama3.2:1b', messages=messages) - print('Final response:', final_response.message.content) - - else: - print('No tool calls returned from model') \ No newline at end of file + # Get final response from model with function outputs + final_response = chat('llama3.2:1b', messages=messages) + st.write(final_response.message.content) + print('Final response:', final_response.message.content) + else: + print('Function', tool.function.name, 'not found') \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index bcc3d79..e365fc6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -751,13 +751,13 @@ files = [ [[package]] name = "narwhals" -version = "1.21.1" +version = "1.22.0" description = "Extremely lightweight compatibility layer between dataframe libraries" optional = false python-versions = ">=3.8" files = [ - {file = "narwhals-1.21.1-py3-none-any.whl", hash = "sha256:f5f2cd33a6fa636de74067f4050d6dd9a9343b39a5a911dc97810d55d8f24cdd"}, - {file = "narwhals-1.21.1.tar.gz", hash = "sha256:44082c6273fd0125a2bde5baae6ddb7465e185c24fd6e1c5e71cab1d746c89cc"}, + {file = "narwhals-1.22.0-py3-none-any.whl", hash = "sha256:5c931bf8696b6dec276f590f1bc5043080606b16ce86d85c9b550312c981970f"}, + {file = "narwhals-1.22.0.tar.gz", hash = "sha256:8e257c5af70a82382796706f39d681290f2c482812474524087f36cb69f9d2f1"}, ] [package.extras] @@ -841,13 +841,13 @@ files = [ [[package]] name = "ollama" -version = "0.4.5" +version = "0.4.6" description = "The official Python client for Ollama." optional = false python-versions = "<4.0,>=3.8" files = [ - {file = "ollama-0.4.5-py3-none-any.whl", hash = "sha256:74936de89a41c87c9745f09f2e1db964b4783002188ac21241bfab747f46d925"}, - {file = "ollama-0.4.5.tar.gz", hash = "sha256:e7fb71a99147046d028ab8b75e51e09437099aea6f8f9a0d91a71f787e97439e"}, + {file = "ollama-0.4.6-py3-none-any.whl", hash = "sha256:cbb4ebe009e10dd12bdd82508ab415fd131945e185753d728a7747c9ebe762e9"}, + {file = "ollama-0.4.6.tar.gz", hash = "sha256:b00717651c829f96094ed4231b9f0d87e33cc92dc235aca50aeb5a2a4e6e95b7"}, ] [package.dependencies] @@ -1739,4 +1739,4 @@ repair = ["scipy (>=1.6.3)"] [metadata] lock-version = "2.0" python-versions = "^3.12" -content-hash = "0ce9fc6b9b5bca227d612f1769f827e0cc818eed18a05c3d67078ed16d569a7f" +content-hash = "9a95e62af78e60afd127e9644d5d29fcdeaeb92402236b2d5b127709a4bbcb1a" diff --git a/pyproject.toml b/pyproject.toml index cbd300b..b45177c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [tool.poetry] -name = "2025" +name = "agent-from-scratch" version = "0.1.0" description = "" authors = ["shubhammandowara "] @@ -7,7 +7,7 @@ readme = "README.md" [tool.poetry.dependencies] python = "^3.12" -ollama = "^0.4.5" +ollama = "^0.4.6" streamlit = "^1.41.1" yfinance = "^0.2.51" pandas = "^2.2.3" diff --git a/tools.py b/tools.py new file mode 100644 index 0000000..2f5693a --- /dev/null +++ b/tools.py @@ -0,0 +1,149 @@ +from typing import Any +import yfinance as yf +import pandas as pd +def get_parameter(symbol, parameter: str) -> Any: + """Retrieve a value for a specific parameter from the data dictionary. + + Args: + symbol (str): stock symbol + parameter (str): The key representing the parameter. + + Returns: + Any: The value of the parameter, or a message if not found. + """ + suffix = '.NS' + stock = yf.Ticker(symbol+suffix) + return stock.info.get(parameter, None) + +def get_address1(symbol:str) -> str: + """Retrieve address1. + Args: + symbol (str): Stock symbol + Return: + str: address + """ + return get_parameter(symbol, 'address1') + + +def get_current_price(symbol:str) -> float: + """Retrieve the current stock price. + Args: + symbol (str): Stock symbol + Return: + float: price + """ + return get_parameter(symbol, 'currentPrice') + + +def get_marketcap(symbol:str) -> float: + """Retrieve the marketcap + Args: + symbol (str): Stock symbol + Return: + float: price + """ + return get_parameter(symbol, 'marketCap') + +def get_beta(symbol:str) -> float: + """Retrieve the beta + Args: + symbol (str): Stock symbol + Return: + float: price + """ + return get_parameter(symbol, 'beta') + +def get_stock_prices(symbol:str, duration:str) -> pd.DataFrame: + """Get stock prices from yahoo finance api + Args: + symbol (str): stock symbol + duration (str): time duration + Returns: + pd.DataFrame: stock prices + """ + suffix='.NS' + durations = {'1 year': '1y', '1month': '1mo', '1 month': '1mo', '1M': '1mo', '1year': '1y'} + if duration not in durations: + durations[duration] = duration + try: + stock = yf.Ticker(symbol+suffix) + data = stock.history(period=durations[duration]) + return data + except Exception as e: + return None + + +def current_pe_ratio(symbol:str) -> int: + """Get current pe ratio for a stock + Args: + symbol (str): stock symbol + + Returns: + int: current pe ratio + """ + return get_parameter(symbol, 'trailingPE') + +import yfinance as yf + +def get_52_week_low(symbol: str) -> float: + """Retrieve the 52 week low price + Args: + symbol (str): Stock symbol + Return: + float: 52week low price + """ + return get_parameter(symbol, 'fiftyTwoWeekLow') + +def get_52_week_high(symbol: str): + """Retrieve the debt-to-equity + Args: + symbol (str): Stock symbol + Return: + float: 52 week high prices + """ + return get_parameter(symbol, 'fiftyTwoWeekHigh') + +def get_price_to_book(symbol: str): + """Retrieve the price to book value + Args: + symbol (str): Stock symbol + Return: + float: price to book value + """ + return get_parameter(symbol, 'priceToBook') + +def get_eps(symbol: str) -> float: + """Retrieve the get EPS + Args: + symbol (str): Stock symbol + Return: + float: EPS + """ + return get_parameter(symbol, 'trailingEps') + +def get_current_ratio(symbol: str)-> float: + """Retrieve the current ratio + Args: + symbol (str): Stock symbol + Return: + float: current ratio + """ + return get_parameter(symbol, 'currentRatio') + +def get_debt_to_equity(symbol: str)-> float: + """Retrieve the debt-to-equity + Args: + symbol (str): Stock symbol + Return: + float: debt to equity + """ + return get_parameter(symbol, 'debtToEquity') + +def get_free_cash_flow(symbol: str): + """Retrieve the free cash flow + Args: + symbol (str): Stock symbol + Return: + float: cashflow + """ + return get_parameter(symbol, 'freeCashflow')