diff --git a/lnurl/core.py b/lnurl/core.py index be06cb4..f6a4f66 100644 --- a/lnurl/core.py +++ b/lnurl/core.py @@ -7,9 +7,19 @@ from .exceptions import InvalidLnurl, InvalidUrl, LnurlResponseException from .helpers import lnurlauth_signature, url_encode -from .models import LnurlAuthResponse, LnurlPayResponse, LnurlResponse, LnurlResponseModel, LnurlWithdrawResponse +from .models import ( + LnurlAuthResponse, + LnurlPayActionResponse, + LnurlPayResponse, + LnurlResponse, + LnurlResponseModel, + LnurlWithdrawResponse, +) from .types import ClearnetUrl, DebugUrl, LnAddress, Lnurl, OnionUrl +USER_AGENT = "lnbits/lnurl" +TIMEOUT = 5 + def decode(bech32_lnurl: str) -> Union[OnionUrl, ClearnetUrl, DebugUrl]: try: @@ -25,13 +35,20 @@ def encode(url: str) -> Lnurl: raise InvalidUrl -async def get(url: str, *, response_class: Optional[Any] = None) -> LnurlResponseModel: - async with httpx.AsyncClient() as client: +async def get( + url: str, + *, + response_class: Optional[Any] = None, + user_agent: Optional[str] = None, + timeout: Optional[int] = None, +) -> LnurlResponseModel: + headers = {"User-Agent": user_agent or USER_AGENT} + async with httpx.AsyncClient(headers=headers, follow_redirects=True) as client: try: - res = await client.get(url) + res = await client.get(url, timeout=timeout or TIMEOUT) res.raise_for_status() - except Exception as e: - raise LnurlResponseException(str(e)) + except Exception as exc: + raise LnurlResponseException(str(exc)) from exc if response_class: assert issubclass(response_class, LnurlResponseModel), "Use a valid `LnurlResponseModel` subclass." @@ -43,11 +60,13 @@ async def get(url: str, *, response_class: Optional[Any] = None) -> LnurlRespons async def handle( bech32_lnurl: str, response_class: Optional[LnurlResponseModel] = None, + user_agent: Optional[str] = None, + timeout: Optional[int] = None, ) -> LnurlResponseModel: try: if "@" in bech32_lnurl: lnaddress = LnAddress(bech32_lnurl) - return await get(lnaddress.url, response_class=response_class) + return await get(lnaddress.url, response_class=response_class, user_agent=user_agent, timeout=timeout) lnurl = Lnurl(bech32_lnurl) except (ValidationError, ValueError): raise InvalidLnurl @@ -55,53 +74,81 @@ async def handle( if lnurl.is_login: return LnurlAuthResponse(callback=lnurl.url, k1=lnurl.url.query_params["k1"]) - return await get(lnurl.url, response_class=response_class) + return await get(lnurl.url, response_class=response_class, user_agent=user_agent, timeout=timeout) -async def execute(bech32_or_address: str, value: str) -> LnurlResponseModel: +async def execute( + bech32_or_address: str, + value: str, + user_agent: Optional[str] = None, + timeout: Optional[int] = None, +) -> LnurlResponseModel: try: - res = handle(bech32_or_address) + res = await handle(bech32_or_address, user_agent=user_agent, timeout=timeout) except Exception as exc: raise LnurlResponseException(str(exc)) if isinstance(res, LnurlPayResponse) and res.tag == "payRequest": - return await execute_pay_request(res, value) + return await execute_pay_request(res, value, user_agent=user_agent, timeout=timeout) elif isinstance(res, LnurlAuthResponse) and res.tag == "login": - return await execute_login(res, value) + return await execute_login(res, value, user_agent=user_agent, timeout=timeout) elif isinstance(res, LnurlWithdrawResponse) and res.tag == "withdrawRequest": - return await execute_withdraw(res, value) + return await execute_withdraw(res, value, user_agent=user_agent, timeout=timeout) raise LnurlResponseException(f"{res.tag} not implemented") # type: ignore -async def execute_pay_request(res: LnurlPayResponse, msat: str) -> LnurlResponseModel: +async def execute_pay_request( + res: LnurlPayResponse, + msat: str, + user_agent: Optional[str] = None, + timeout: Optional[int] = None, +) -> LnurlResponseModel: if not res.min_sendable <= MilliSatoshi(msat) <= res.max_sendable: raise LnurlResponseException(f"Amount {msat} not in range {res.min_sendable} - {res.max_sendable}") + try: - async with httpx.AsyncClient() as client: + headers = {"User-Agent": user_agent or USER_AGENT} + async with httpx.AsyncClient(headers=headers, follow_redirects=True) as client: res2 = await client.get( url=res.callback, params={ "amount": msat, }, + timeout=timeout or TIMEOUT, ) res2.raise_for_status() - return LnurlResponse.from_dict(res2.json()) + pay_res = LnurlResponse.from_dict(res2.json()) + assert isinstance(pay_res, LnurlPayActionResponse), "Invalid response in execute_pay_request." + invoice = bolt11_decode(pay_res.pr) + if invoice.amount_msat != int(msat): + raise LnurlResponseException( + f"{res.callback.host} returned an invalid invoice." + f"Excepted `{msat}` msat, got `{invoice.amount_msat}`." + ) + return pay_res except Exception as exc: raise LnurlResponseException(str(exc)) -async def execute_login(res: LnurlAuthResponse, secret: str) -> LnurlResponseModel: +async def execute_login( + res: LnurlAuthResponse, + secret: str, + user_agent: Optional[str] = None, + timeout: Optional[int] = None, +) -> LnurlResponseModel: try: assert res.callback.host, "LNURLauth host does not exist" key, sig = lnurlauth_signature(res.callback.host, secret, res.k1) - async with httpx.AsyncClient() as client: + headers = {"User-Agent": user_agent or USER_AGENT} + async with httpx.AsyncClient(headers=headers, follow_redirects=True) as client: res2 = await client.get( url=res.callback, params={ "key": key, "sig": sig, }, + timeout=timeout or TIMEOUT, ) res2.raise_for_status() return LnurlResponse.from_dict(res2.json()) @@ -109,7 +156,12 @@ async def execute_login(res: LnurlAuthResponse, secret: str) -> LnurlResponseMod raise LnurlResponseException(str(e)) -async def execute_withdraw(res: LnurlWithdrawResponse, pr: str) -> LnurlResponseModel: +async def execute_withdraw( + res: LnurlWithdrawResponse, + pr: str, + user_agent: Optional[str] = None, + timeout: Optional[int] = None, +) -> LnurlResponseModel: try: invoice = bolt11_decode(pr) except Bolt11Exception as exc: @@ -119,13 +171,15 @@ async def execute_withdraw(res: LnurlWithdrawResponse, pr: str) -> LnurlResponse if not res.min_withdrawable <= MilliSatoshi(amount) <= res.max_withdrawable: raise LnurlResponseException(f"Amount {amount} not in range {res.min_withdrawable} - {res.max_withdrawable}") try: - async with httpx.AsyncClient() as client: + headers = {"User-Agent": user_agent or USER_AGENT} + async with httpx.AsyncClient(headers=headers, follow_redirects=True) as client: res2 = await client.get( url=res.callback, params={ "k1": res.k1, "pr": pr, }, + timeout=timeout or TIMEOUT, ) res2.raise_for_status() return LnurlResponse.from_dict(res2.json()) diff --git a/lnurl/models.py b/lnurl/models.py index 5dbea0b..9902ce9 100644 --- a/lnurl/models.py +++ b/lnurl/models.py @@ -138,6 +138,7 @@ class LnurlPayActionResponse(LnurlResponseModel): pr: LightningInvoice success_action: Optional[Union[MessageAction, UrlAction, AesAction]] = Field(None, alias="successAction") routes: List[List[LnurlPayRouteHop]] = [] + verify: Optional[str] = None class LnurlWithdrawResponse(LnurlResponseModel): diff --git a/tests/test_core.py b/tests/test_core.py index d3b6688..8867c4b 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -72,8 +72,9 @@ def test_encode_nohttps(self, url): class TestHandle: - """Responses from the LNURL: https://legend.lnbits.com/""" + """Responses from the LNURL: https://demo.lnbits.com/""" + @pytest.mark.xfail(reason="legend.lnbits.com is down") @pytest.mark.asyncio @pytest.mark.parametrize( "bech32", @@ -86,7 +87,7 @@ async def test_handle_withdraw(self, bech32): res = await handle(bech32) assert isinstance(res, LnurlWithdrawResponse) assert res.tag == "withdrawRequest" - assert res.callback.host == "legend.lnbits.com" + assert res.callback.host == "demo.lnbits.com" assert res.default_description == "sample withdraw" assert res.max_withdrawable >= res.min_withdrawable @@ -104,8 +105,9 @@ async def test_get_requests_error(self, url): class TestPayFlow: - """Full LNURL-pay flow interacting with https://legend.lnbits.com/""" + """Full LNURL-pay flow interacting with https://demo.lnbits.com/""" + @pytest.mark.xfail(reason="legend.lnbits.com is down") @pytest.mark.asyncio @pytest.mark.parametrize( "bech32, amount", @@ -114,14 +116,14 @@ class TestPayFlow: "LNURL1DP68GURN8GHJ7MR9VAJKUEPWD3HXY6T5WVHXXMMD9AKXUATJD3CZ7JN9F4EHQJQC25ZZY", "1000", ), - ("donate@legend.lnbits.com", "100000"), + ("donate@demo.lnbits.com", "100000"), ], ) async def test_pay_flow(self, bech32: str, amount: str): res = await handle(bech32) assert isinstance(res, LnurlPayResponse) assert res.tag == "payRequest" - assert res.callback.host == "legend.lnbits.com" + assert res.callback.host == "demo.lnbits.com" assert len(res.metadata.list()) >= 1 assert res.metadata.text != ""