Skip to content

Commit

Permalink
Refactor SubwayFeed API key handling and update related tests (#39)
Browse files Browse the repository at this point in the history
  • Loading branch information
nolanbconaway authored Nov 10, 2024
1 parent 032996e commit f32135d
Show file tree
Hide file tree
Showing 8 changed files with 27 additions and 93 deletions.
15 changes: 1 addition & 14 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,16 @@ Or if you'd like to live dangerously:
pip install git+https://github.com/nolanbconaway/underground.git#egg=underground
```

To request data from the MTA, you'll also need a free API key.
[Register here](https://api.mta.info/).

## Python API

Once you have your API key, use the Python API like:
Use the Python API like:

``` python
import os

from underground import metadata, SubwayFeed

API_KEY = os.getenv('MTA_API_KEY')
ROUTE = 'Q'
feed = SubwayFeed.get(ROUTE, api_key=API_KEY)

# request will read from $MTA_API_KEY if a key is not provided
feed = SubwayFeed.get(ROUTE)

# under the hood, the Q route is mapped to a URL. This call is equivalent:
Expand Down Expand Up @@ -93,9 +86,6 @@ Usage: underground feed [OPTIONS] ROUTE_OR_URL
underground feed $URL --json > feed_nrqw.json
Options:
--api-key TEXT MTA API key. Will be read from $MTA_API_KEY if not
provided.
--json Option to output the feed data as JSON. Otherwise
output will be bytes.
Expand All @@ -120,8 +110,6 @@ Options:
unix timestamp.
-r, --retries INTEGER Retry attempts in case of API connection failure.
Default 100.
--api-key TEXT MTA API key. Will be read from $MTA_API_KEY if not
provided.
-t, --timezone TEXT Output timezone. Ignored if --epoch. Default to NYC
time.
-s, --stalled-timeout INTEGER Number of seconds between the last movement
Expand All @@ -135,7 +123,6 @@ Options:
Stops are printed to stdout in the format `stop_id t1 t2 ... tn` .

``` sh
$ export MTA_API_KEY='...'
$ underground stops Q | tail -2
Q05S 19:01 19:09 19:16 19:25 19:34 19:44 19:51 19:58
Q04S 19:03 19:11 19:18 19:27 19:36 19:46 19:53 20:00
Expand Down
9 changes: 1 addition & 8 deletions src/underground/cli/feed.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,6 @@

@click.command()
@click.argument("route_or_url")
@click.option(
"--api-key",
"api_key",
default=None,
help="MTA API key. Will be read from $MTA_API_KEY if not provided.",
)
@click.option(
"--json",
"output_json",
Expand All @@ -29,7 +23,7 @@
type=int,
help="Retry attempts in case of API connection failure. Default 100.",
)
def main(route_or_url, api_key, output_json, retries):
def main(route_or_url, output_json, retries):
"""Request an MTA feed via a route or URL.
ROUTE_OR_URL may be either a feed URL or a route (which will be used to look up
Expand All @@ -53,7 +47,6 @@ def main(route_or_url, api_key, output_json, retries):
data = feed.request_robust(
route_or_url=route_or_url,
retries=retries,
api_key=api_key,
return_dict=output_json,
)

Expand Down
10 changes: 2 additions & 8 deletions src/underground/cli/stops.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,6 @@ def datetime_to_epoch(dttm: datetime.datetime) -> int:
type=int,
help="Retry attempts in case of API connection failure. Default 100.",
)
@click.option(
"--api-key",
"api_key",
default=None,
help="MTA API key. Will be read from $MTA_API_KEY if not provided.",
)
@click.option(
"-t",
"--timezone",
Expand All @@ -54,10 +48,10 @@ def datetime_to_epoch(dttm: datetime.datetime) -> int:
" update before considering a train stalled. Default is 90 as recommended"
" by the MTA. Numbers less than 1 disable this check.",
)
def main(route, fmt, retries, api_key, timezone, stalled_timeout):
def main(route, fmt, retries, timezone, stalled_timeout):
"""Print out train departure times for all stops on a subway line."""
stops = (
SubwayFeed.get(api_key=api_key, route_or_url=route, retries=retries)
SubwayFeed.get(route_or_url=route, retries=retries)
.extract_stop_dict(timezone=timezone, stalled_timeout=stalled_timeout)
.get(route, dict())
)
Expand Down
30 changes: 6 additions & 24 deletions src/underground/feed.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Interact with the MTA GTFS api."""

import os
import time
import typing

Expand All @@ -22,7 +21,7 @@ def load_protobuf(protobuf_bytes: bytes) -> dict:
Parameters
----------
protobuf_bytes : bytes
Protobuuf data, as returned from the raw request.
Protobuf data, as returned from the raw request.
Returns
-------
Expand All @@ -38,7 +37,7 @@ def load_protobuf(protobuf_bytes: bytes) -> dict:
return feed_dict


def request(route_or_url: str, api_key: typing.Optional[str] = None) -> bytes:
def request(route_or_url: str) -> bytes:
"""Send a HTTP GET request to the MTA for realtime feed data.
Occassionally a feed is requested as the MTA is writing updated data to the file,
Expand All @@ -49,9 +48,6 @@ def request(route_or_url: str, api_key: typing.Optional[str] = None) -> bytes:
----------
route_or_url : str
Route ID or feed url (per ``https://api.mta.info/#/subwayRealTimeFeeds``).
api_key : str
MTA API key. If not provided, it will be read from the $MTA_API_KEY env
variable.
Returns
-------
Expand All @@ -62,26 +58,15 @@ def request(route_or_url: str, api_key: typing.Optional[str] = None) -> bytes:
# check feed
url = metadata.resolve_url(route_or_url)

# get the API key.
api_key = api_key or os.getenv("MTA_API_KEY", None)
if api_key is None:
raise ValueError(
"No API key. pass to the called function "
"or set the $MTA_API_KEY environment variable."
)

# make the request
res = requests.get(url, headers={"x-api-key": api_key})
res = requests.get(url)
res.raise_for_status()

return res.content


def request_robust(
route_or_url: str,
retries: int = 100,
api_key: typing.Optional[str] = None,
return_dict: bool = False,
route_or_url: str, retries: int = 100, return_dict: bool = False
) -> typing.Union[dict, bytes]:
"""Request feed data with validations and retries.
Expand All @@ -97,9 +82,6 @@ def request_robust(
retries : int
Number of retry attempts, with 1 second timeout between attempts.
Set to -1 for unlimited. Default 100.
api_key : str
MTA API key. If not provided, it will be read from the $MTA_API_KEY env
variable.
return_dict : bool
Option to return the process data as a dict rather than as raw protobuf data.
This is equivalent to running ``load_protobuf(request_robust(...))``.
Expand All @@ -112,7 +94,7 @@ def request_robust(
"""
# get protobuf bytes
protobuf_data = request(route_or_url=route_or_url, api_key=api_key)
protobuf_data = request(route_or_url=route_or_url)
for attempt in range(retries + 1):
try:
feed_dict = load_protobuf(protobuf_data)
Expand All @@ -125,6 +107,6 @@ def request_robust(

# wait 1 second and then make new protobuf data
time.sleep(1) # be cool to the MTA
protobuf_data = request(route_or_url=route_or_url, api_key=api_key)
protobuf_data = request(route_or_url=route_or_url)

return feed_dict if return_dict else protobuf_data
18 changes: 4 additions & 14 deletions src/underground/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,8 @@ class SubwayFeed(pydantic.BaseModel):
header: FeedHeader
entity: list[Entity]

@staticmethod
def get(
route_or_url: str, retries: int = 100, api_key: typing.Optional[str] = None
) -> "SubwayFeed":
@classmethod
def get(cls, route_or_url: str, retries: int = 100) -> "SubwayFeed":
"""Request feed data from the MTA.
Parameters
Expand All @@ -193,23 +191,15 @@ def get(
retries : int
Number of retry attempts, with 1 second timeout between attempts.
Set to -1 for unlimited. Default 100.
api_key : str
MTA API key. If not provided, it will be read from the $MTA_API_KEY env
variable.
Returns
-------
SubwayFeed
An instance of the SubwayFeed class with the requested data.
"""
return SubwayFeed(
**feed.request_robust(
route_or_url=route_or_url,
retries=retries,
api_key=api_key,
return_dict=True,
)
return cls(
**feed.request_robust(route_or_url=route_or_url, retries=retries, return_dict=True)
)

def extract_stop_dict(
Expand Down
19 changes: 7 additions & 12 deletions test/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test the CLI."""

import io
import json
import os
Expand Down Expand Up @@ -52,9 +53,7 @@ def test_stops_epoch(monkeypatch):
},
],
}
monkeypatch.setattr(
"underground.SubwayFeed.get", lambda *x, **y: SubwayFeed(**sample_data)
)
monkeypatch.setattr("underground.SubwayFeed.get", lambda *x, **y: SubwayFeed(**sample_data))
runner = CliRunner()
result = runner.invoke(stops_cli.main, ["1", "-f", "epoch"])
assert result.exit_code == 0
Expand Down Expand Up @@ -87,9 +86,7 @@ def test_stops_format(monkeypatch):
],
}

monkeypatch.setattr(
"underground.SubwayFeed.get", lambda *x, **y: SubwayFeed(**sample_data)
)
monkeypatch.setattr("underground.SubwayFeed.get", lambda *x, **y: SubwayFeed(**sample_data))
runner = CliRunner()

# year
Expand Down Expand Up @@ -130,9 +127,7 @@ def test_stops_timezone(monkeypatch):
],
}

monkeypatch.setattr(
"underground.SubwayFeed.get", lambda *x, **y: SubwayFeed(**sample_data)
)
monkeypatch.setattr("underground.SubwayFeed.get", lambda *x, **y: SubwayFeed(**sample_data))
runner = CliRunner()

# in hong kong time
Expand All @@ -155,7 +150,7 @@ def test_feed_bytes(requests_mock, filename):
requests_mock.get(requests_mock_any, content=file.read())

runner = CliRunner()
result = runner.invoke(feed_cli.main, ["1", "--api-key", "FAKE"])
result = runner.invoke(feed_cli.main, ["1"])
assert result.exit_code == 0
assert "entity" in load_protobuf(result.stdout_bytes)

Expand All @@ -167,7 +162,7 @@ def test_feed_json(requests_mock, filename):
requests_mock.get(requests_mock_any, content=file.read())

runner = CliRunner()
result = runner.invoke(feed_cli.main, ["1", "--json", "--api-key", "FAKE"])
result = runner.invoke(feed_cli.main, ["1", "--json"])
assert result.exit_code == 0
assert "entity" in json.loads(result.output)

Expand Down Expand Up @@ -200,7 +195,7 @@ def test_stopstxt_json(monkeypatch, args):
lambda: content,
)
runner = CliRunner()
result = runner.invoke(findstops_cli.main, args + ["--json"])
result = runner.invoke(findstops_cli.main, [*args, "--json"])
assert result.exit_code == 0

for stop in json.loads(result.output):
Expand Down
17 changes: 5 additions & 12 deletions test/test_feed.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test the feed submodule."""

import os
import time

Expand Down Expand Up @@ -37,7 +38,7 @@ def mock_load_protobuf(*a):

time_1 = time.time()
with pytest.raises(feed.EmptyFeedError):
feed.request_robust("1", retries=retries, api_key="FAKE")
feed.request_robust("1", retries=retries)
elapsed = time.time() - time_1

assert elapsed >= retries
Expand All @@ -62,21 +63,13 @@ def test_request_invalid_feed():
feed.request("NOT REAL")


def test_request_no_api_key(monkeypatch):
"""Test that request raises value error when no api key is available."""
monkeypatch.delenv("MTA_API_KEY", raising=False)

with pytest.raises(ValueError):
feed.request(next(iter(metadata.VALID_FEED_URLS)))


@pytest.mark.parametrize("ret_code", [200, 500])
def test_request_raise_status(requests_mock, ret_code):
"""Test the request raise status conditional."""
feed_url = next(iter(metadata.VALID_FEED_URLS))
requests_mock.get(requests_mock_any, content="".encode(), status_code=ret_code)
requests_mock.get(requests_mock_any, content=b"", status_code=ret_code)
if ret_code != 200:
with pytest.raises(requests.HTTPError):
feed.request(feed_url, api_key="FAKE")
feed.request(feed_url)
else:
feed.request(feed_url, api_key="FAKE")
feed.request(feed_url)
2 changes: 1 addition & 1 deletion test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_get(requests_mock, filename):
return_value = file.read()

requests_mock.get(requests_mock_any, content=return_value)
feed = SubwayFeed.get("1", api_key="FAKE") ## valid route but not used at all
feed = SubwayFeed.get("1") ## valid route but not used at all

assert isinstance(feed, SubwayFeed)
assert isinstance(feed.extract_stop_dict(), dict)
Expand Down

0 comments on commit f32135d

Please sign in to comment.