Skip to content

Commit

Permalink
Merge pull request #298 from tfukaza/ut-yf
Browse files Browse the repository at this point in the history
Refactor yf ut
  • Loading branch information
tfukaza authored Jun 2, 2024
2 parents 2f4e20d + 878a11a commit 1fd0a39
Show file tree
Hide file tree
Showing 10 changed files with 653 additions and 229 deletions.
1 change: 1 addition & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pip install .
### Unit Testing
After any modifications to the code, conduct unit tests by running:
```bash
pip install . --upgrade --no-deps --force-reinstall
python -m unittest discover -s tests/unittest
```
from the project's root directory. This will run the tests defined in the `tests` directory.
Expand Down
4 changes: 2 additions & 2 deletions harvest/broker/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def step(self) -> None:
This method is called at the interval specified by the user.
It should create a dictionary where each key is the symbol for an asset,
and the value is the corresponding data in the following pandas dataframe format:
Symbol
[TICKER]
open high low close volume
timestamp
--- --- --- --- --- ---
Expand Down Expand Up @@ -308,7 +308,7 @@ def fetch_price_history(
:param interval: The interval of requested historical data.
:param start: The starting date of the period, inclusive.
:param end: The ending date of the period, inclusive.
:returns: A pandas dataframe, same format as main()
:returns: A pandas dataframe, same format as self.step()
"""
raise NotImplementedError(
f"{type(self).__name__} class does not support the method {inspect.currentframe().f_code.co_name}."
Expand Down
73 changes: 59 additions & 14 deletions harvest/broker/yahoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from harvest.broker._base import Broker
from harvest.definitions import Account, Stats
from harvest.enum import Interval
from harvest.util.date import convert_input_to_datetime, date_to_str, utc_current_time, utc_epoch_zero
from harvest.util.date import convert_input_to_datetime, date_to_str, str_to_datetime, utc_current_time, utc_epoch_zero
from harvest.util.helper import (
check_interval,
debugger,
Expand Down Expand Up @@ -172,39 +172,53 @@ def fetch_price_history(
crypto = True

df = yf.download(symbol, period=period, interval=get_fmt, prepost=True, progress=False)
debugger.debug(df)
if crypto:
symbol = "@" + symbol[:-4]
df = self._format_df(df, symbol)
df = df.loc[start:end]
debugger.debug(f"From yfinance got: {df}")
debugger.debug(f"Filtering from {start} to {end}")
df = df.loc[start:end]

return df

@Broker._exception_handler
def fetch_chain_info(self, symbol: str) -> Dict[str, Any]:
"""
Return the list of option expiration dates available for the given symbol.
YFinance returns option chain data as a tuple of expiration dates, formatted as "YYYY-MM-DD".
YFinance gets data from NASDAQ, NYSE, and NYSE America, so option expiration dates
use the Eastern Time Zone. (TODO: Check if this is correct)
"""
option_list = self.watch_ticker[symbol].options
return {
"id": "n/a",
"exp_dates": [convert_input_to_datetime(s, no_tz=True) for s in option_list],
"exp_dates": [str_to_datetime(date) for date in option_list],
"multiplier": 100,
}

@Broker._exception_handler
def fetch_chain_data(self, symbol: str, date: dt.datetime) -> pd.DataFrame:
"""
Return the option chain list for a given symbol and expiration date.
YFinance returns option chain data in the Options class.
This class has two attributes, calls and puts, each of which is a DataFrame in the following format:
contractSymbol lastTradeDate strike lastPrice bid ask change percentChange volume openInterest impliedVolatility inTheMoney contractSize currency
0 MSFT240614P00220000 2024-05-13 13:55:14+00:00 220.0 0.02 ... ... ... ... ... ... 0.937501 False REGULAR USD
"""
if bool(self.option_cache) and symbol in self.option_cache and date in self.option_cache[symbol]:
return self.option_cache[symbol][date]

df = pd.DataFrame(columns=["contractSymbol", "exp_date", "strike", "type"])

chain = self.watch_ticker[symbol].option_chain(date_to_str(date))

print(f"From yfinance got: {chain}")

puts = chain.puts
puts["type"] = "put"
calls = chain.calls
calls["type"] = "call"
df = df.append(puts)
df = df.append(calls)
df = pd.concat([df, puts, calls])

df = df.rename(columns={"contractSymbol": "occ_symbol"})
df["exp_date"] = df.apply(lambda x: self.occ_to_data(x["occ_symbol"])[1], axis=1)
Expand All @@ -219,13 +233,14 @@ def fetch_chain_data(self, symbol: str, date: dt.datetime) -> pd.DataFrame:

@Broker._exception_handler
def fetch_option_market_data(self, occ_symbol: str) -> Dict[str, Any]:
"""
Return the market data for a given option symbol.
"""
occ_symbol = occ_symbol.replace(" ", "")
symbol, date, typ, _ = self.occ_to_data(occ_symbol)
chain = self.watch_ticker[symbol].option_chain(date_to_str(date))
chain = chain.calls if typ == "call" else chain.puts
df = chain[chain["contractSymbol"] == occ_symbol]

debugger.debug(df)
return {
"price": float(df["lastPrice"].iloc[0]),
"ask": float(df["ask"].iloc[0]),
Expand All @@ -234,9 +249,17 @@ def fetch_option_market_data(self, occ_symbol: str) -> Dict[str, Any]:

@Broker._exception_handler
def fetch_market_hours(self, date: datetime.date) -> Dict[str, Any]:
# yfinance does not support getting market hours,
# so use the free Tradier API instead.
# See documentation.tradier.com/brokerage-api/markets/get-clock
"""
Get the market hours for the given date.
yfinance does not support getting market hours, so use the free Tradier API instead.
See documentation.tradier.com/brokerage-api/markets/get-clock
This API cannot be used to check market hours on a specific date, only the current day.
"""

if date.date() != utc_current_time().date():
raise ValueError("Cannot check market hours for a specific date")

response = requests.get(
"https://api.tradier.com/v1/markets/clock",
params={"delayed": "false"},
Expand Down Expand Up @@ -289,17 +312,35 @@ def fetch_market_hours(self, date: datetime.date) -> Dict[str, Any]:
# ------------- Helper methods ------------- #

def _format_df(self, df: pd.DataFrame, symbol: str) -> pd.DataFrame:
df = df.copy()
"""
Format the DataFrame returned by yfinance to the format expected by the BrokerHub.
If the DataFrame contains one ticker, yfinance returns the following columns:
Open High Low Close Adj Close Volume
Index is a pandas datetime index
If the DataFrame contains multiple tickers, yfinance returns the following multi-index columns:
Price: Open High Low Close Adj Close Volume
Ticker: TICK1 TICK2 TICK1 TICK2 TICK1 TICK2 TICK1 TICK2 TICK1 TICK2 TICK1 TICK2
Index is a pandas datetime index
"""
# df = df.copy()
df.reset_index(inplace=True)
ts_name = df.columns[0]
df["timestamp"] = df[ts_name]
print(df)
df = df.set_index(["timestamp"])
d = df.index[0]
if d.tzinfo is None or d.tzinfo.utcoffset(d) is None:
df = df.tz_localize("UTC")
else:
df = df.tz_convert(tz="UTC")
df = df.drop([ts_name], axis=1)
# Drop adjusted close column
df = df.drop(["Adj Close"], axis=1)
print(df)
df = df.rename(
columns={
"Open": "open",
Expand All @@ -313,4 +354,8 @@ def _format_df(self, df: pd.DataFrame, symbol: str) -> pd.DataFrame:

df.columns = pd.MultiIndex.from_product([[symbol], df.columns])

return df.dropna()
print(df)

df.dropna(inplace=True)

return df
30 changes: 20 additions & 10 deletions harvest/util/date.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ def utc_epoch_zero() -> dt.datetime:
return dt.datetime(1970, 1, 1, tzinfo=tz.utc)


def get_local_timezone() -> ZoneInfo:
"""
Returns a datetime timezone instance for the user's current timezone using their system time.
"""
return dt.datetime.now(None).astimezone().tzinfo


def date_to_str(day: dt.date) -> str:
"""
Returns a string representation of the date in the format YYYY-MM-DD
Expand All @@ -37,21 +44,22 @@ def str_to_date(day: str) -> dt.date:
return dt.datetime.strptime(day, "%Y-%m-%d")


def str_to_datetime(date: str) -> dt.datetime:
def str_to_datetime(date: str, timezone: ZoneInfo = None) -> dt.datetime:
"""
Returns a datetime object from a string.
If timezone is not specified, the timezone is assumed to be UTC-0.
:param date: A string in the format YYYY-MM-DD or YYYY-MM-DD hh:mm
"""
if len(date) <= 10:
return dt.datetime.strptime(date, "%Y-%m-%d")
return dt.datetime.strptime(date, "%Y-%m-%d %H:%M")
ret = dt.datetime.strptime(date, "%Y-%m-%d")
else:
ret = dt.datetime.strptime(date, "%Y-%m-%d %H:%M")

if timezone is None:
timezone = tz.utc

def get_local_timezone() -> ZoneInfo:
"""
Returns a datetime timezone instance for the user's current timezone using their system time.
"""
return dt.datetime.now(tz.utc).astimezone().tzinfo
ret = ret.replace(tzinfo=timezone)
return ret


def convert_input_to_datetime(datetime: Union[str, dt.datetime], timezone: ZoneInfo = None, no_tz=False) -> dt.datetime:
Expand All @@ -62,6 +70,10 @@ def convert_input_to_datetime(datetime: Union[str, dt.datetime], timezone: ZoneI
convert it to UTC. If timezone is None then the system's local
timezone is used.
"""

if timezone is None:
timezone = get_local_timezone()

if datetime is None:
return None
elif isinstance(datetime, str):
Expand All @@ -75,8 +87,6 @@ def convert_input_to_datetime(datetime: Union[str, dt.datetime], timezone: ZoneI
return datetime

if not has_timezone(datetime):
if timezone is None:
timezone = get_local_timezone()
datetime = datetime.replace(tzinfo=timezone)

datetime = datetime.astimezone(tz.utc)
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ install_requires =
tqdm == 4.66.4
tzlocal == 5.2
tzdata == 2024.1
yfinance == 0.2.38
yfinance >= 0.2.38
SQLAlchemy == 2.0.0
flask-login == 0.6.3
flask-cors == 3.0.10
Expand Down
129 changes: 0 additions & 129 deletions tests/livetest/test_api_yahoo.py

This file was deleted.

Loading

0 comments on commit 1fd0a39

Please sign in to comment.