From f5fbe7dd40947bd3e9b6c85bf0187a62ef12200f Mon Sep 17 00:00:00 2001 From: tfukaza Date: Mon, 27 May 2024 18:55:26 -0700 Subject: [PATCH 1/2] Refactor yf ut --- CONTRIBUTING.md | 1 + harvest/broker/_base.py | 4 +- harvest/broker/yahoo.py | 73 ++++++-- harvest/util/date.py | 30 ++-- setup.cfg | 2 +- tests/livetest/_test_util.py | 23 +++ tests/livetest/test_api_yahoo.py | 129 -------------- tests/livetest/test_broker_base.py | 258 +++++++++++++++++++++++++++ tests/livetest/test_broker_common.py | 146 +++++++-------- tests/livetest/test_broker_yahoo.py | 216 ++++++++++++++++++++++ 10 files changed, 653 insertions(+), 229 deletions(-) create mode 100644 tests/livetest/_test_util.py delete mode 100644 tests/livetest/test_api_yahoo.py create mode 100644 tests/livetest/test_broker_base.py create mode 100644 tests/livetest/test_broker_yahoo.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e760d4a..855cdad 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -24,6 +24,7 @@ pip install . ### Unit Testing After any modifications to the code, conduct unit tests by running: ```bash +pip install . --upgrade --no-deps --force-reinstall python -m unittest discover -s tests/unittest ``` from the project's root directory. This will run the tests defined in the `tests` directory. diff --git a/harvest/broker/_base.py b/harvest/broker/_base.py index 0c5cfdf..b3890a9 100644 --- a/harvest/broker/_base.py +++ b/harvest/broker/_base.py @@ -241,7 +241,7 @@ def step(self) -> None: This method is called at the interval specified by the user. It should create a dictionary where each key is the symbol for an asset, and the value is the corresponding data in the following pandas dataframe format: - Symbol + [TICKER] open high low close volume timestamp --- --- --- --- --- --- @@ -308,7 +308,7 @@ def fetch_price_history( :param interval: The interval of requested historical data. :param start: The starting date of the period, inclusive. :param end: The ending date of the period, inclusive. - :returns: A pandas dataframe, same format as main() + :returns: A pandas dataframe, same format as self.step() """ raise NotImplementedError( f"{type(self).__name__} class does not support the method {inspect.currentframe().f_code.co_name}." diff --git a/harvest/broker/yahoo.py b/harvest/broker/yahoo.py index 89e2cca..c4cd677 100644 --- a/harvest/broker/yahoo.py +++ b/harvest/broker/yahoo.py @@ -11,7 +11,7 @@ from harvest.broker._base import Broker from harvest.definitions import Account, Stats from harvest.enum import Interval -from harvest.util.date import convert_input_to_datetime, date_to_str, utc_current_time, utc_epoch_zero +from harvest.util.date import convert_input_to_datetime, date_to_str, str_to_datetime, utc_current_time, utc_epoch_zero from harvest.util.helper import ( check_interval, debugger, @@ -172,24 +172,41 @@ def fetch_price_history( crypto = True df = yf.download(symbol, period=period, interval=get_fmt, prepost=True, progress=False) + debugger.debug(df) if crypto: symbol = "@" + symbol[:-4] df = self._format_df(df, symbol) - df = df.loc[start:end] debugger.debug(f"From yfinance got: {df}") + debugger.debug(f"Filtering from {start} to {end}") + df = df.loc[start:end] + return df @Broker._exception_handler def fetch_chain_info(self, symbol: str) -> Dict[str, Any]: + """ + Return the list of option expirations dates available for the given symbol. + YFinance returns option chain data as tuple of expiration dates, formatted as "YYYY-MM-DD". + YFinance gets data from NASDAQ, NYSE, and NYSE America, sp option expiration dates + use the Eastern Time Zone. (TODO: Check if this is correct) + """ option_list = self.watch_ticker[symbol].options return { "id": "n/a", - "exp_dates": [convert_input_to_datetime(s, no_tz=True) for s in option_list], + "exp_dates": [str_to_datetime(date) for date in option_list], "multiplier": 100, } @Broker._exception_handler def fetch_chain_data(self, symbol: str, date: dt.datetime) -> pd.DataFrame: + """ + Return the option chain list for a given symbol and expiration date. + + YFinance returns option chain data in the Options class. + This class has two attributes: calls and puts, each which are DataFrames in the following format: + contractSymbol lastTradeDate strike lastPrice bid ask change percentChange volume openInterest impliedVolatility inTheMoney contractSize currency + 0 MSFT240614P00220000 2024-05-13 13:55:14+00:00 220.0 0.02 ... ... ... ... ... ... 0.937501 False REGULAR USD + """ if bool(self.option_cache) and symbol in self.option_cache and date in self.option_cache[symbol]: return self.option_cache[symbol][date] @@ -197,14 +214,11 @@ def fetch_chain_data(self, symbol: str, date: dt.datetime) -> pd.DataFrame: chain = self.watch_ticker[symbol].option_chain(date_to_str(date)) - print(f"From yfinance got: {chain}") - puts = chain.puts puts["type"] = "put" calls = chain.calls calls["type"] = "call" - df = df.append(puts) - df = df.append(calls) + df = pd.concat([df, puts, calls]) df = df.rename(columns={"contractSymbol": "occ_symbol"}) df["exp_date"] = df.apply(lambda x: self.occ_to_data(x["occ_symbol"])[1], axis=1) @@ -219,13 +233,14 @@ def fetch_chain_data(self, symbol: str, date: dt.datetime) -> pd.DataFrame: @Broker._exception_handler def fetch_option_market_data(self, occ_symbol: str) -> Dict[str, Any]: + """ + Return the market data for a given option symbol. + """ occ_symbol = occ_symbol.replace(" ", "") symbol, date, typ, _ = self.occ_to_data(occ_symbol) chain = self.watch_ticker[symbol].option_chain(date_to_str(date)) chain = chain.calls if typ == "call" else chain.puts df = chain[chain["contractSymbol"] == occ_symbol] - - debugger.debug(df) return { "price": float(df["lastPrice"].iloc[0]), "ask": float(df["ask"].iloc[0]), @@ -234,9 +249,17 @@ def fetch_option_market_data(self, occ_symbol: str) -> Dict[str, Any]: @Broker._exception_handler def fetch_market_hours(self, date: datetime.date) -> Dict[str, Any]: - # yfinance does not support getting market hours, - # so use the free Tradier API instead. - # See documentation.tradier.com/brokerage-api/markets/get-clock + """ + Get the market hours for the given date. + yfinance does not support getting market hours, so use the free Tradier API instead. + See documentation.tradier.com/brokerage-api/markets/get-clock + + This API cannot be used to check market hours on a specific date, only the current day. + """ + + if date.date() != utc_current_time().date(): + raise ValueError("Cannot check market hours for a specific date") + response = requests.get( "https://api.tradier.com/v1/markets/clock", params={"delayed": "false"}, @@ -289,10 +312,25 @@ def fetch_market_hours(self, date: datetime.date) -> Dict[str, Any]: # ------------- Helper methods ------------- # def _format_df(self, df: pd.DataFrame, symbol: str) -> pd.DataFrame: - df = df.copy() + """ + Format the DataFrame returned by yfinance to the format expected by the BrokerHub. + + If the Dataframe contains 1 ticker, Yfinance returns with the following columns: + Open High Low Close Adj Close Volume + Index is a pandas datetime index + + If the Dataframe contains multiple tickers, Yfinance returns the following multi-index columns: + + Price: Open High Low Close Adj Close Volume + Ticker: TICK1 TICK2 TICK1 TICK2 TICK1 TICK2 TICK1 TICK2 TICK1 TICK2 TICK1 TICK2 + Index is a pandas datetime index + + """ + # df = df.copy() df.reset_index(inplace=True) ts_name = df.columns[0] df["timestamp"] = df[ts_name] + print(df) df = df.set_index(["timestamp"]) d = df.index[0] if d.tzinfo is None or d.tzinfo.utcoffset(d) is None: @@ -300,6 +338,9 @@ def _format_df(self, df: pd.DataFrame, symbol: str) -> pd.DataFrame: else: df = df.tz_convert(tz="UTC") df = df.drop([ts_name], axis=1) + # Drop adjusted close column + df = df.drop(["Adj Close"], axis=1) + print(df) df = df.rename( columns={ "Open": "open", @@ -313,4 +354,8 @@ def _format_df(self, df: pd.DataFrame, symbol: str) -> pd.DataFrame: df.columns = pd.MultiIndex.from_product([[symbol], df.columns]) - return df.dropna() + print(df) + + df.dropna(inplace=True) + + return df diff --git a/harvest/util/date.py b/harvest/util/date.py index e9c4181..4092309 100644 --- a/harvest/util/date.py +++ b/harvest/util/date.py @@ -22,6 +22,13 @@ def utc_epoch_zero() -> dt.datetime: return dt.datetime(1970, 1, 1, tzinfo=tz.utc) +def get_local_timezone() -> ZoneInfo: + """ + Returns a datetime timezone instance for the user's current timezone using their system time. + """ + return dt.datetime.now(None).astimezone().tzinfo + + def date_to_str(day: dt.date) -> str: """ Returns a string representation of the date in the format YYYY-MM-DD @@ -37,21 +44,22 @@ def str_to_date(day: str) -> dt.date: return dt.datetime.strptime(day, "%Y-%m-%d") -def str_to_datetime(date: str) -> dt.datetime: +def str_to_datetime(date: str, timezone: ZoneInfo = None) -> dt.datetime: """ Returns a datetime object from a string. + If timezone is not specified, the timezone is assumed to be UTC-0. :date: A string in the format YYYY-MM-DD hh:mm """ if len(date) <= 10: - return dt.datetime.strptime(date, "%Y-%m-%d") - return dt.datetime.strptime(date, "%Y-%m-%d %H:%M") + ret = dt.datetime.strptime(date, "%Y-%m-%d") + else: + ret = dt.datetime.strptime(date, "%Y-%m-%d %H:%M") + if timezone is None: + timezone = tz.utc -def get_local_timezone() -> ZoneInfo: - """ - Returns a datetime timezone instance for the user's current timezone using their system time. - """ - return dt.datetime.now(tz.utc).astimezone().tzinfo + ret = ret.replace(tzinfo=timezone) + return ret def convert_input_to_datetime(datetime: Union[str, dt.datetime], timezone: ZoneInfo = None, no_tz=False) -> dt.datetime: @@ -62,6 +70,10 @@ def convert_input_to_datetime(datetime: Union[str, dt.datetime], timezone: ZoneI covert it to UTC. If timezone is None then the system's local timezone is used. """ + + if timezone is None: + timezone = get_local_timezone() + if datetime is None: return None elif isinstance(datetime, str): @@ -75,8 +87,6 @@ def convert_input_to_datetime(datetime: Union[str, dt.datetime], timezone: ZoneI return datetime if not has_timezone(datetime): - if timezone is None: - timezone = get_local_timezone() datetime = datetime.replace(tzinfo=timezone) datetime = datetime.astimezone(tz.utc) diff --git a/setup.cfg b/setup.cfg index d5b132c..5ecd2ba 100644 --- a/setup.cfg +++ b/setup.cfg @@ -22,7 +22,7 @@ install_requires = tqdm == 4.66.4 tzlocal == 5.2 tzdata == 2024.1 - yfinance == 0.2.38 + yfinance >= 0.2.38 SQLAlchemy == 2.0.0 flask-login == 0.6.3 flask-cors == 3.0.10 diff --git a/tests/livetest/_test_util.py b/tests/livetest/_test_util.py new file mode 100644 index 0000000..76f2767 --- /dev/null +++ b/tests/livetest/_test_util.py @@ -0,0 +1,23 @@ +import datetime as dt + +""" +For testing purposes, assume: +- Current day is September 15th, 2008 +- Current time is 10:00 AM +- Current timezone is US/Eastern +""" + + +def mock_get_local_timezone(): + """ + Return the US/Eastern timezone + """ + return dt.timezone(dt.timedelta(hours=-4)) + + +def mock_utc_current_time(): + """ + Return the current time in UTC timezone + """ + d = dt.datetime(2008, 9, 15, 10, 0, 0, tzinfo=dt.timezone(dt.timedelta(hours=-4))) + return d.astimezone(dt.timezone.utc) diff --git a/tests/livetest/test_api_yahoo.py b/tests/livetest/test_api_yahoo.py deleted file mode 100644 index c725b68..0000000 --- a/tests/livetest/test_api_yahoo.py +++ /dev/null @@ -1,129 +0,0 @@ -import datetime as dt -import time -import unittest - -from harvest.broker.yahoo import YahooBroker -from harvest.definitions import Account, Stats -from harvest.enum import Interval -from harvest.util.helper import debugger, utc_current_time - -debugger.setLevel("DEBUG") - - -class TestYahooStreamer(unittest.TestCase): - def test_current_time(self): - broker = YahooBroker() - - threshold = dt.timedelta(seconds=5) - current_time = broker.get_current_time() - self.assertTrue(utc_current_time() - current_time < threshold) - - time.sleep(60) - - def test_fetch_stock_price(self): - broker = YahooBroker() - - # Use datetime with no timezone for start and end - end = dt.datetime.now() - dt.timedelta(days=7) - start = end - dt.timedelta(days=7) - results = broker.fetch_price_history("AAPL", Interval.MIN_1, start, end) - self.assertTrue(results.shape[0] > 0) - self.assertTrue(results.shape[1] == 5) - - # Use datetime with timezone for start and end - start = start.astimezone(dt.timezone(dt.timedelta(hours=2))) - end = end.astimezone(dt.timezone(dt.timedelta(hours=2))) - results = broker.fetch_price_history("AAPL", Interval.MIN_1, start, end) - self.assertTrue(results.shape[0] > 0) - self.assertTrue(results.shape[1] == 5) - - # Use ISO 8601 string for start and end - start = "2022-01-21T09:00-05:00" - end = "2022-01-25T17:00-05:00" - results = broker.fetch_price_history("AAPL", Interval.MIN_1, start, end) - self.assertTrue(results.shape[0] > 0) - self.assertTrue(results.shape[1] == 5) - - time.sleep(60) - - def test_fetch_prices(self): - yh = YahooBroker() - df = yh.fetch_price_history("SPY", Interval.HR_1) - df = df["SPY"] - self.assertEqual(list(df.columns.values), ["open", "high", "low", "close", "volume"]) - - def test_setup(self): - yh = YahooBroker() - interval = { - "SPY": {"interval": Interval.MIN_15, "aggregations": []}, - "AAPL": {"interval": Interval.MIN_1, "aggregations": []}, - } - stats = Stats(watchlist_cfg=interval) - yh.setup(stats, Account()) - - self.assertEqual(yh.poll_interval, Interval.MIN_1) - self.assertListEqual(list(yh.stats.watchlist_cfg.keys()), ["SPY", "AAPL"]) - - def test_main(self): - interval = { - "SPY": {"interval": Interval.MIN_1, "aggregations": []}, - "AAPL": {"interval": Interval.MIN_1, "aggregations": []}, - "@BTC": {"interval": Interval.MIN_1, "aggregations": []}, - } - - def test_main(df): - self.assertEqual(len(df), 3) - self.assertEqual(df["SPY"].columns[0][0], "SPY") - self.assertEqual(df["AAPL"].columns[0][0], "AAPL") - self.assertEqual(df["@BTC"].columns[0][0], "@BTC") - - yh = YahooBroker() - stats = Stats(watchlist_cfg=interval) - yh.setup(stats, Account(), test_main) - - yh.step() - - def test_main_single(self): - interval = {"SPY": {"interval": Interval.MIN_1, "aggregations": []}} - - def test_main(df): - self.assertEqual(len(df), 1) - self.assertEqual(df["SPY"].columns[0][0], "SPY") - - yh = YahooBroker() - stats = Stats(watchlist_cfg=interval) - yh.setup(stats, Account(), test_main) - - yh.step() - - def test_chain_info(self): - yh = YahooBroker() - - interval = {"SPY": {"interval": Interval.MIN_1, "aggregations": []}} - stats = Stats(watchlist_cfg=interval) - yh.setup(stats, Account()) - - info = yh.fetch_chain_info("SPY") - - self.assertGreater(len(info["exp_dates"]), 0) - - def test_chain_data(self): - yh = YahooBroker() - - interval = {"SPY": {"interval": Interval.MIN_1, "aggregations": []}} - stats = Stats(watchlist_cfg=interval) - yh.setup(stats, Account()) - - dates = yh.fetch_chain_info("SPY")["exp_dates"] - data = yh.fetch_chain_data("SPY", dates[0]) - self.assertGreater(len(data), 0) - self.assertListEqual(list(data.columns), ["exp_date", "strike", "type"]) - - sym = data.index[0] - _ = yh.fetch_option_market_data(sym) - - self.assertTrue(True) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/livetest/test_broker_base.py b/tests/livetest/test_broker_base.py new file mode 100644 index 0000000..5009934 --- /dev/null +++ b/tests/livetest/test_broker_base.py @@ -0,0 +1,258 @@ +import datetime as dt +import unittest +from functools import wraps +from unittest import mock + +import pandas as pd +from _test_util import mock_get_local_timezone, mock_utc_current_time + +from harvest.definitions import Account, Stats +from harvest.enum import Interval +from harvest.util.helper import debugger + +debugger.setLevel("DEBUG") + +""" +For testing purposes, assume: +- Current day is September 15th, 2008 +- Current time is 10:00 AM +- Current timezone is US/Eastern +""" + + +def repeat_test(broker_list): + def decorator_test(func): + @wraps(func) + def wrapper_repeat(*args): + self = args[0] + for broker in broker_list: + print(f"Testing {broker}") + func(self, broker) + + return wrapper_repeat + + return decorator_test + + +class TestBroker(object): + """ + Base class for testing Broker implementations. + Each brokers should inherit from this class and implement the necessary + setup and teardown procedures specific to the broker, and call the code + in this class to test the common functionalities. + + """ + + def _define_patch(self, path, side_effect): + patcher = mock.patch(path) + func = patcher.start() + func.side_effect = side_effect + self.addCleanup(patcher.stop) + + def setUp(self): + self._define_patch("harvest.util.date.get_local_timezone", mock_get_local_timezone) + self._define_patch("harvest.util.date.utc_current_time", mock_utc_current_time) + + def test_fetch_stock_price(self, broker): + """ + Test fetching stock price history + The returned DataFrame should be in the format: + [Ticker] + open high low close volume + timestamp + + Where timestamp is a pandas datetime object in UTC timezone, + and open, high, low, close, and volume are float values. + """ + + broker = broker() + + end = mock_utc_current_time() + start = end - dt.timedelta(days=1) + results = broker.fetch_price_history("AAPL", Interval.MIN_1, start, end) + # Check that the returned DataFrame is not empty + self.assertGreaterEqual(len(results), 1) + self.assertTrue(results.shape[1] == 5) + # Check that the returned DataFrame has the correct columns + self.assertListEqual( + list(results.columns), + [("AAPL", "open"), ("AAPL", "high"), ("AAPL", "low"), ("AAPL", "close"), ("AAPL", "volume")], + ) + # Check that the returned DataFrame has the correct index + self.assertTrue(results.index[0] >= start) + self.assertTrue(results.index[-1] <= end) + # Check that the returned DataFrame has the correct data types + self.assertEqual(results.dtypes["AAPL", "open"], float) + self.assertEqual(results.dtypes["AAPL", "high"], float) + self.assertEqual(results.dtypes["AAPL", "low"], float) + self.assertEqual(results.dtypes["AAPL", "close"], float) + self.assertEqual(results.dtypes["AAPL", "volume"], float) + + # Check that the returned DataFrame has the correct index type + self.assertEqual(type(results.index[0]), pd.Timestamp) + self.assertEqual(results.index.tzinfo, dt.timezone.utc) + + def test_fetch_stock_price_timezone(self, broker): + """ + Test that the price history returned + correctly adjusts the input to utc timezone. + """ + broker = broker() + + # Create an end date in ETC timezone + end = dt.datetime(2008, 9, 15, 10, 0, 0, tzinfo=dt.timezone(dt.timedelta(hours=-4))) + start = end - dt.timedelta(days=1) + results = broker.fetch_price_history("AAPL", Interval.MIN_1, start, end) + + # The last timestamp in the returned DataFrame should be 4 hours ahead of the end date, + # since UTC-0 is 4 hours ahead of UTC-4 + self.assertEqual(results.index[-1], end.astimezone(dt.timezone.utc)) + + def test_fetch_stock_price_str_input(self, broker): + """ + Test fetching stock price history using Yahoo Broker + with string input for start and end dates. + As with datetime objects, time is converted from local timezone to UTC. + """ + broker = broker() + # Use ISO 8601 string for start and end + start = "2008-09-15T09:00" + end = "2008-09-15T10:00" + results = broker.fetch_price_history("AAPL", Interval.MIN_1, start, end) + self.assertEqual(type(results.index[0]), pd.Timestamp) + self.assertEqual(results.index.tzinfo, dt.timezone.utc) + self.assertEqual(results.index[-1], dt.datetime(2008, 9, 15, 14, 0, 0, tzinfo=dt.timezone.utc)) + + def test_setup(self, broker): + """ + Test that the broker is correctly set up with the stats and account objects. + """ + broker = broker() + interval = { + "SPY": {"interval": Interval.MIN_15, "aggregations": []}, + "AAPL": {"interval": Interval.MIN_1, "aggregations": []}, + } + stats = Stats(watchlist_cfg=interval) + broker.setup(stats, Account()) + + self.assertEqual(broker.poll_interval, Interval.MIN_1) + self.assertListEqual(list(broker.stats.watchlist_cfg.keys()), ["SPY", "AAPL"]) + + def test_main(self, broker): + """ + Test that the main function is called with the correct security data. + """ + interval = { + "SPY": {"interval": Interval.MIN_1, "aggregations": []}, + "AAPL": {"interval": Interval.MIN_1, "aggregations": []}, + "@BTC": {"interval": Interval.MIN_1, "aggregations": []}, + } + + def test_main(df): + self.assertEqual(len(df), 3) + self.assertEqual(df["SPY"].columns[0][0], "SPY") + self.assertEqual(df["AAPL"].columns[0][0], "AAPL") + self.assertEqual(df["@BTC"].columns[0][0], "@BTC") + + broker = broker() + + stats = Stats(watchlist_cfg=interval) + broker.setup(stats, Account(), test_main) + + # Call the main function + broker.step() + + def test_chain_info(self, broker): + broker = broker() + + interval = {"SPY": {"interval": Interval.MIN_1, "aggregations": []}} + stats = Stats(watchlist_cfg=interval) + broker.setup(stats, Account()) + + info = broker.fetch_chain_info("SPY") + + self.assertGreater(len(info["exp_dates"]), 0) + + def test_chain_data(self, broker): + broker = broker() + + interval = {"SPY": {"interval": Interval.MIN_1, "aggregations": []}} + stats = Stats(watchlist_cfg=interval) + broker.setup(stats, Account()) + + dates = broker.fetch_chain_info("SPY") + print(dates) + dates = dates["exp_dates"] + chain = broker.fetch_chain_data("SPY", dates[0]) + self.assertGreater(len(chain), 0) + self.assertListEqual(list(chain.columns), ["exp_date", "strike", "type"]) + + print(chain) + + # TODO test getting market data + + # def test_buy_option(self, api): + # api = api(secret_path) + # interval = { + # "TWTR": {"interval": Interval.MIN_5, "aggregations": []}, + # } + # stats = Stats(watchlist_cfg=interval) + # api.setup(stats, Account()) + + # # Get a list of all options + # dates = api.fetch_chain_info("TWTR")["exp_dates"] + # data = api.fetch_chain_data("TWTR", dates[1]) + # option = data.iloc[0] + + # exp_date = option["exp_date"] + # strike = option["strike"] + + # ret = api.order_option_limit("buy", "TWTR", 1, 0.01, "call", exp_date, strike) + + # time.sleep(5) + + # api.cancel_option_order(ret["order_id"]) + + # self.assertTrue(True) + + # def test_buy_stock(self, api): + # """ + # Test that it can buy stocks + # """ + # api = api(secret_path) + # interval = { + # "TWTR": {"interval": Interval.MIN_5, "aggregations": []}, + # } + # stats = Stats(watchlist_cfg=interval) + # api.setup(stats, Account()) + + # # Limit order TWTR stock at an extremely low limit price + # # to ensure the order is not actually filled. + # ret = api.order_stock_limit("buy", "TWTR", 1, 10.0) + + # time.sleep(5) + + # api.cancel_stock_order(ret["order_id"]) + + # def test_buy_crypto(self, api): + # """ + # Test that it can buy crypto + # """ + # api = api(secret_path) + # interval = { + # "@DOGE": {"interval": Interval.MIN_5, "aggregations": []}, + # } + # stats = Stats(watchlist_cfg=interval) + # api.setup(stats, Account()) + + # # Limit order DOGE at an extremely low limit price + # # to ensure the order is not actually filled. + # ret = api.order_crypto_limit("buy", "@DOGE", 1, 0.10) + + # time.sleep(5) + + # api.cancel_crypto_order(ret["order_id"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/livetest/test_broker_common.py b/tests/livetest/test_broker_common.py index 1e057df..ff3cd18 100644 --- a/tests/livetest/test_broker_common.py +++ b/tests/livetest/test_broker_common.py @@ -1,98 +1,98 @@ -import os -import time -import unittest +# import os +# import time +# import unittest -from harvest.broker.robinhood import RobinhoodBroker -from harvest.definitions import Account, Interval, Stats -from harvest.util.helper import debugger +# from harvest.broker.robinhood import RobinhoodBroker +# from harvest.definitions import Account, Interval, Stats +# from harvest.util.helper import debugger -secret_path = os.environ["SECRET_PATH"] -debugger.setLevel("DEBUG") +# secret_path = os.environ["SECRET_PATH"] +# debugger.setLevel("DEBUG") -import functools +# import functools -# A decorator to repeat the same test for all the brokers -def decorator_repeat_test(broker_list): - def decorator_test(func): - @functools.wraps(func) - def wrapper_repeat(*args): - self = args[0] - for api in broker_list: - print(f"Testing {api}") - func(self, api) +# # A decorator to repeat the same test for all the brokers +# def decorator_repeat_test(broker_list): +# def decorator_test(func): +# @functools.wraps(func) +# def wrapper_repeat(*args): +# self = args[0] +# for api in broker_list: +# print(f"Testing {api}") +# func(self, api) - return wrapper_repeat +# return wrapper_repeat - return decorator_test +# return decorator_test -class TestLiveBroker(unittest.TestCase): - @decorator_repeat_test([RobinhoodBroker]) - def test_buy_option(self, api): - api = api(secret_path) - interval = { - "TWTR": {"interval": Interval.MIN_5, "aggregations": []}, - } - stats = Stats(watchlist_cfg=interval) - api.setup(stats, Account()) +# class TestLiveBroker(unittest.TestCase): +# @decorator_repeat_test([RobinhoodBroker]) +# def test_buy_option(self, api): +# api = api(secret_path) +# interval = { +# "TWTR": {"interval": Interval.MIN_5, "aggregations": []}, +# } +# stats = Stats(watchlist_cfg=interval) +# api.setup(stats, Account()) - # Get a list of all options - dates = api.fetch_chain_info("TWTR")["exp_dates"] - data = api.fetch_chain_data("TWTR", dates[1]) - option = data.iloc[0] +# # Get a list of all options +# dates = api.fetch_chain_info("TWTR")["exp_dates"] +# data = api.fetch_chain_data("TWTR", dates[1]) +# option = data.iloc[0] - exp_date = option["exp_date"] - strike = option["strike"] +# exp_date = option["exp_date"] +# strike = option["strike"] - ret = api.order_option_limit("buy", "TWTR", 1, 0.01, "call", exp_date, strike) +# ret = api.order_option_limit("buy", "TWTR", 1, 0.01, "call", exp_date, strike) - time.sleep(5) +# time.sleep(5) - api.cancel_option_order(ret["order_id"]) +# api.cancel_option_order(ret["order_id"]) - self.assertTrue(True) +# self.assertTrue(True) - @decorator_repeat_test([RobinhoodBroker]) - def test_buy_stock(self, api): - """ - Test that it can buy stocks - """ - api = api(secret_path) - interval = { - "TWTR": {"interval": Interval.MIN_5, "aggregations": []}, - } - stats = Stats(watchlist_cfg=interval) - api.setup(stats, Account()) +# @decorator_repeat_test([RobinhoodBroker]) +# def test_buy_stock(self, api): +# """ +# Test that it can buy stocks +# """ +# api = api(secret_path) +# interval = { +# "TWTR": {"interval": Interval.MIN_5, "aggregations": []}, +# } +# stats = Stats(watchlist_cfg=interval) +# api.setup(stats, Account()) - # Limit order TWTR stock at an extremely low limit price - # to ensure the order is not actually filled. - ret = api.order_stock_limit("buy", "TWTR", 1, 10.0) +# # Limit order TWTR stock at an extremely low limit price +# # to ensure the order is not actually filled. +# ret = api.order_stock_limit("buy", "TWTR", 1, 10.0) - time.sleep(5) +# time.sleep(5) - api.cancel_stock_order(ret["order_id"]) +# api.cancel_stock_order(ret["order_id"]) - @decorator_repeat_test([RobinhoodBroker]) - def test_buy_crypto(self, api): - """ - Test that it can buy crypto - """ - api = api(secret_path) - interval = { - "@DOGE": {"interval": Interval.MIN_5, "aggregations": []}, - } - stats = Stats(watchlist_cfg=interval) - api.setup(stats, Account()) +# @decorator_repeat_test([RobinhoodBroker]) +# def test_buy_crypto(self, api): +# """ +# Test that it can buy crypto +# """ +# api = api(secret_path) +# interval = { +# "@DOGE": {"interval": Interval.MIN_5, "aggregations": []}, +# } +# stats = Stats(watchlist_cfg=interval) +# api.setup(stats, Account()) - # Limit order DOGE at an extremely low limit price - # to ensure the order is not actually filled. - ret = api.order_crypto_limit("buy", "@DOGE", 1, 0.10) +# # Limit order DOGE at an extremely low limit price +# # to ensure the order is not actually filled. +# ret = api.order_crypto_limit("buy", "@DOGE", 1, 0.10) - time.sleep(5) +# time.sleep(5) - api.cancel_crypto_order(ret["order_id"]) +# api.cancel_crypto_order(ret["order_id"]) -if __name__ == "__main__": - unittest.main() +# if __name__ == "__main__": +# unittest.main() diff --git a/tests/livetest/test_broker_yahoo.py b/tests/livetest/test_broker_yahoo.py new file mode 100644 index 0000000..2bf01ef --- /dev/null +++ b/tests/livetest/test_broker_yahoo.py @@ -0,0 +1,216 @@ +import datetime as dt +import unittest +from unittest.mock import MagicMock, patch + +import pandas as pd +from _test_util import mock_utc_current_time + +from harvest.broker.yahoo import YahooBroker +from harvest.util.helper import data_to_occ +from tests.livetest.test_broker_base import TestBroker + + +def _mock_yf_download(symbol, period, interval, **_): + delta = None + if period == "1d": + delta = dt.timedelta(days=1) + elif period == "5d": + delta = dt.timedelta(days=5) + elif period == "1mo": + delta = dt.timedelta(days=30) + elif period == "3mo": + delta = dt.timedelta(days=90) + elif period == "1y": + delta = dt.timedelta(days=365) + elif period == "5y": + delta = dt.timedelta(days=365 * 5) + elif period == "max": + delta = dt.timedelta(days=365 * 10) + + end = mock_utc_current_time() + start = end - delta + + if interval == "1m": + freq = "min" + delta = (end - start).total_seconds() / 60 + elif interval == "5m": + freq = "5min" + delta = (end - start).total_seconds() / 60 / 5 + elif interval == "15m": + freq = "15min" + delta = (end - start).total_seconds() / 60 / 15 + elif interval == "30m": + freq = "30min" + delta = (end - start).total_seconds() / 60 / 30 + elif interval == "1h": + freq = "H" + delta = (end - start).total_seconds() / 60 / 60 + elif interval == "1d": + freq = "D" + delta = (end - start).days + else: + raise ValueError(f"Invalid interval: {interval}") + + delta = int(delta) + 1 + + data_range = pd.date_range(start, periods=delta, freq=freq) + + symbols = symbol.split(" ") + + if len(symbols) > 1: + df_columns = { + "Price": ["Open"] * len(symbols) + + ["High"] * len(symbols) + + ["Low"] * len(symbols) + + ["Close"] * len(symbols) + + ["Adj Close"] * len(symbols) + + ["Volume"] * len(symbols), + "Ticker": symbols * 6, + # "Date": [0] * len(symbols) * 6, + } + + dummy_df = pd.DataFrame(df_columns) + dummy_df.set_index(["Price", "Ticker"], inplace=True) + dummy_df = dummy_df.T + dummy_df["Date"] = data_range + dummy_df.set_index("Date", inplace=True) + dummy_df.index = data_range + + # TODO: Populate each column with dummy data + return dummy_df + + else: + dummy_df = pd.DataFrame( + { + "Open": [1.0] * delta, + "High": [2.0] * delta, + "Low": [0.5] * delta, + "Close": [1.5] * delta, + "Adj Close": [1.5] * delta, + "Volume": [1000] * delta, + } + ) + dummy_df.index = data_range + return dummy_df + + +def mock_yf_options(): + return ("2008-09-15", "2008-10-15") + + +def mock_yf_option_chain(date): + """ + contractSymbol lastTradeDate strike lastPrice bid ask change percentChange volume openInterest impliedVolatility inTheMoney contractSize currency + 0 SPY240614P00220000 2024-05-13 13:55:14+00:00 220.0 0.02 ... ... ... ... ... ... 0.937501 False REGULAR USD + """ + date = pd.to_datetime(date) + dummy_df = pd.DataFrame( + { + "contractSymbol": [data_to_occ("SPY", date, "P", 220.0)], + "lastTradeDate": [date], + "strike": [220.0], + "lastPrice": [0.02], + "bid": [0.01], + "ask": [0.03], + "change": [0.01], + "percentChange": [0.5], + "volume": [100], + "openInterest": [1000], + "impliedVolatility": [0.937501], + "inTheMoney": [False], + "contractSize": ["REGULAR"], + "currency": ["USD"], + } + ) + + # Repeat the same row 10 times + dummy_df = pd.concat([dummy_df] * 10, ignore_index=True) + + return dummy_df + + +class TestYahooBroker(TestBroker, unittest.TestCase): + def setUp(self): + super().setUp() + self._define_patch("yfinance.download", _mock_yf_download) + + def test_fetch_stock_price(self): + """ + Test fetching stock price history + The returned DataFrame should be in the format: + [Ticker] + open high low close volume + timestamp + + Where timestamp is a pandas datetime object in UTC timezone, + and open, high, low, close, and volume are float values. + """ + + super().test_fetch_stock_price(YahooBroker) + + def test_fetch_stock_price_timezone(self): + """ + Test that the price history returned + correctly adjusts the input to utc timezone. + """ + + super().test_fetch_stock_price_timezone(YahooBroker) + + def test_fetch_stock_price_str_input(self): + """ + Test fetching stock price history using Yahoo Broker + with string input for start and end dates. + As with datetime objects, time is converted from local timezone to UTC. + """ + + super().test_fetch_stock_price_str_input(YahooBroker) + + def test_setup(self): + """ + Test that the broker is correctly set up with the stats and account objects. + """ + + super().test_setup(YahooBroker) + + def test_main(self): + """ + Test that the main function is called with the correct security data. + """ + + super().test_main(YahooBroker) + + @patch("yfinance.Ticker") + def test_chain_info(self, mock_ticker): + """ + Test that the broker can fetch option chain information. + """ + instance = mock_ticker.return_value + instance.options = mock_yf_options() + + super().test_chain_info(YahooBroker) + + @patch("yfinance.Ticker") + def test_chain_data(self, mock_ticker): + """ + Test that the broker can fetch option chain data. + """ + instance = mock_ticker.return_value + instance.options = mock_yf_options() + + """ + Mock the "Option" class that yFiance returns when fetching option chain data. + """ + + def return_option_chain(date): + mock_option_class = MagicMock() + mock_option_class.calls = mock_yf_option_chain(date) + mock_option_class.puts = mock_yf_option_chain(date) + return mock_option_class + + instance.option_chain = return_option_chain + + super().test_chain_data(YahooBroker) + + +if __name__ == "__main__": + unittest.main() From 878a11aad8fb8c29e723953acf0493a3da8aac63 Mon Sep 17 00:00:00 2001 From: tfukaza Date: Sat, 1 Jun 2024 18:01:03 -0700 Subject: [PATCH 2/2] Update --- tests/livetest/_test_util.py | 23 ------------------- tests/unittest/_util.py | 23 +++++++++++++++++++ .../test_broker_base.py | 2 +- .../test_broker_yahoo.py | 4 ++-- 4 files changed, 26 insertions(+), 26 deletions(-) delete mode 100644 tests/livetest/_test_util.py rename tests/{livetest => unittest}/test_broker_base.py (99%) rename tests/{livetest => unittest}/test_broker_yahoo.py (98%) diff --git a/tests/livetest/_test_util.py b/tests/livetest/_test_util.py deleted file mode 100644 index 76f2767..0000000 --- a/tests/livetest/_test_util.py +++ /dev/null @@ -1,23 +0,0 @@ -import datetime as dt - -""" -For testing purposes, assume: -- Current day is September 15th, 2008 -- Current time is 10:00 AM -- Current timezone is US/Eastern -""" - - -def mock_get_local_timezone(): - """ - Return the US/Eastern timezone - """ - return dt.timezone(dt.timedelta(hours=-4)) - - -def mock_utc_current_time(): - """ - Return the current time in UTC timezone - """ - d = dt.datetime(2008, 9, 15, 10, 0, 0, tzinfo=dt.timezone(dt.timedelta(hours=-4))) - return d.astimezone(dt.timezone.utc) diff --git a/tests/unittest/_util.py b/tests/unittest/_util.py index a520637..1dd387a 100644 --- a/tests/unittest/_util.py +++ b/tests/unittest/_util.py @@ -1,9 +1,32 @@ +import datetime as dt import functools import os from harvest.algo import BaseAlgo from harvest.trader.trader import BrokerHub +""" +For testing purposes, assume: +- Current day is September 15th, 2008 +- Current time is 10:00 AM +- Current timezone is US/Eastern +""" + + +def mock_get_local_timezone(): + """ + Return the US/Eastern timezone + """ + return dt.timezone(dt.timedelta(hours=-4)) + + +def mock_utc_current_time(): + """ + Return the current time in UTC timezone + """ + d = dt.datetime(2008, 9, 15, 10, 0, 0, tzinfo=dt.timezone(dt.timedelta(hours=-4))) + return d.astimezone(dt.timezone.utc) + def create_trader_and_api( streamer=None, diff --git a/tests/livetest/test_broker_base.py b/tests/unittest/test_broker_base.py similarity index 99% rename from tests/livetest/test_broker_base.py rename to tests/unittest/test_broker_base.py index 5009934..af64d3c 100644 --- a/tests/livetest/test_broker_base.py +++ b/tests/unittest/test_broker_base.py @@ -4,7 +4,7 @@ from unittest import mock import pandas as pd -from _test_util import mock_get_local_timezone, mock_utc_current_time +from _util import mock_get_local_timezone, mock_utc_current_time from harvest.definitions import Account, Stats from harvest.enum import Interval diff --git a/tests/livetest/test_broker_yahoo.py b/tests/unittest/test_broker_yahoo.py similarity index 98% rename from tests/livetest/test_broker_yahoo.py rename to tests/unittest/test_broker_yahoo.py index 2bf01ef..edf0ce4 100644 --- a/tests/livetest/test_broker_yahoo.py +++ b/tests/unittest/test_broker_yahoo.py @@ -3,11 +3,11 @@ from unittest.mock import MagicMock, patch import pandas as pd -from _test_util import mock_utc_current_time +from _util import mock_utc_current_time from harvest.broker.yahoo import YahooBroker from harvest.util.helper import data_to_occ -from tests.livetest.test_broker_base import TestBroker +from tests.unittest.test_broker_base import TestBroker def _mock_yf_download(symbol, period, interval, **_):