Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add follow_redirect, timeout, useragent #33

Merged
merged 5 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 74 additions & 20 deletions lnurl/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,19 @@

from .exceptions import InvalidLnurl, InvalidUrl, LnurlResponseException
from .helpers import lnurlauth_signature, url_encode
from .models import LnurlAuthResponse, LnurlPayResponse, LnurlResponse, LnurlResponseModel, LnurlWithdrawResponse
from .models import (
LnurlAuthResponse,
LnurlPayActionResponse,
LnurlPayResponse,
LnurlResponse,
LnurlResponseModel,
LnurlWithdrawResponse,
)
from .types import ClearnetUrl, DebugUrl, LnAddress, Lnurl, OnionUrl

USER_AGENT = "lnbits/lnurl"
TIMEOUT = 5


def decode(bech32_lnurl: str) -> Union[OnionUrl, ClearnetUrl, DebugUrl]:
try:
Expand All @@ -25,13 +35,20 @@ def encode(url: str) -> Lnurl:
raise InvalidUrl


async def get(url: str, *, response_class: Optional[Any] = None) -> LnurlResponseModel:
async with httpx.AsyncClient() as client:
async def get(
url: str,
*,
response_class: Optional[Any] = None,
user_agent: Optional[str] = None,
timeout: Optional[int] = None,
) -> LnurlResponseModel:
headers = {"User-Agent": user_agent or USER_AGENT}
async with httpx.AsyncClient(headers=headers, follow_redirects=True) as client:
try:
res = await client.get(url)
res = await client.get(url, timeout=timeout or TIMEOUT)
res.raise_for_status()
except Exception as e:
raise LnurlResponseException(str(e))
except Exception as exc:
raise LnurlResponseException(str(exc)) from exc

if response_class:
assert issubclass(response_class, LnurlResponseModel), "Use a valid `LnurlResponseModel` subclass."
Expand All @@ -43,73 +60,108 @@ async def get(url: str, *, response_class: Optional[Any] = None) -> LnurlRespons
async def handle(
bech32_lnurl: str,
response_class: Optional[LnurlResponseModel] = None,
user_agent: Optional[str] = None,
timeout: Optional[int] = None,
) -> LnurlResponseModel:
try:
if "@" in bech32_lnurl:
lnaddress = LnAddress(bech32_lnurl)
return await get(lnaddress.url, response_class=response_class)
return await get(lnaddress.url, response_class=response_class, user_agent=user_agent, timeout=timeout)
lnurl = Lnurl(bech32_lnurl)
except (ValidationError, ValueError):
raise InvalidLnurl

if lnurl.is_login:
return LnurlAuthResponse(callback=lnurl.url, k1=lnurl.url.query_params["k1"])

return await get(lnurl.url, response_class=response_class)
return await get(lnurl.url, response_class=response_class, user_agent=user_agent, timeout=timeout)


async def execute(bech32_or_address: str, value: str) -> LnurlResponseModel:
async def execute(
bech32_or_address: str,
value: str,
user_agent: Optional[str] = None,
timeout: Optional[int] = None,
) -> LnurlResponseModel:
try:
res = handle(bech32_or_address)
res = await handle(bech32_or_address, user_agent=user_agent, timeout=timeout)
except Exception as exc:
raise LnurlResponseException(str(exc))

if isinstance(res, LnurlPayResponse) and res.tag == "payRequest":
return await execute_pay_request(res, value)
return await execute_pay_request(res, value, user_agent=user_agent, timeout=timeout)
elif isinstance(res, LnurlAuthResponse) and res.tag == "login":
return await execute_login(res, value)
return await execute_login(res, value, user_agent=user_agent, timeout=timeout)
elif isinstance(res, LnurlWithdrawResponse) and res.tag == "withdrawRequest":
return await execute_withdraw(res, value)
return await execute_withdraw(res, value, user_agent=user_agent, timeout=timeout)

raise LnurlResponseException(f"{res.tag} not implemented") # type: ignore


async def execute_pay_request(res: LnurlPayResponse, msat: str) -> LnurlResponseModel:
async def execute_pay_request(
res: LnurlPayResponse,
msat: str,
user_agent: Optional[str] = None,
timeout: Optional[int] = None,
) -> LnurlResponseModel:
if not res.min_sendable <= MilliSatoshi(msat) <= res.max_sendable:
raise LnurlResponseException(f"Amount {msat} not in range {res.min_sendable} - {res.max_sendable}")

try:
async with httpx.AsyncClient() as client:
headers = {"User-Agent": user_agent or USER_AGENT}
async with httpx.AsyncClient(headers=headers, follow_redirects=True) as client:
res2 = await client.get(
url=res.callback,
params={
"amount": msat,
},
timeout=timeout or TIMEOUT,
)
res2.raise_for_status()
return LnurlResponse.from_dict(res2.json())
pay_res = LnurlResponse.from_dict(res2.json())
assert isinstance(pay_res, LnurlPayActionResponse), "Invalid response in execute_pay_request."
invoice = bolt11_decode(pay_res.pr)
if invoice.amount_msat != int(msat):
raise LnurlResponseException(
f"{res.callback.host} returned an invalid invoice."
f"Excepted `{msat}` msat, got `{invoice.amount_msat}`."
)
return pay_res
except Exception as exc:
raise LnurlResponseException(str(exc))


async def execute_login(res: LnurlAuthResponse, secret: str) -> LnurlResponseModel:
async def execute_login(
res: LnurlAuthResponse,
secret: str,
user_agent: Optional[str] = None,
timeout: Optional[int] = None,
) -> LnurlResponseModel:
try:
assert res.callback.host, "LNURLauth host does not exist"
key, sig = lnurlauth_signature(res.callback.host, secret, res.k1)
async with httpx.AsyncClient() as client:
headers = {"User-Agent": user_agent or USER_AGENT}
async with httpx.AsyncClient(headers=headers, follow_redirects=True) as client:
res2 = await client.get(
url=res.callback,
params={
"key": key,
"sig": sig,
},
timeout=timeout or TIMEOUT,
)
res2.raise_for_status()
return LnurlResponse.from_dict(res2.json())
except Exception as e:
raise LnurlResponseException(str(e))


async def execute_withdraw(res: LnurlWithdrawResponse, pr: str) -> LnurlResponseModel:
async def execute_withdraw(
res: LnurlWithdrawResponse,
pr: str,
user_agent: Optional[str] = None,
timeout: Optional[int] = None,
) -> LnurlResponseModel:
try:
invoice = bolt11_decode(pr)
except Bolt11Exception as exc:
Expand All @@ -119,13 +171,15 @@ async def execute_withdraw(res: LnurlWithdrawResponse, pr: str) -> LnurlResponse
if not res.min_withdrawable <= MilliSatoshi(amount) <= res.max_withdrawable:
raise LnurlResponseException(f"Amount {amount} not in range {res.min_withdrawable} - {res.max_withdrawable}")
try:
async with httpx.AsyncClient() as client:
headers = {"User-Agent": user_agent or USER_AGENT}
async with httpx.AsyncClient(headers=headers, follow_redirects=True) as client:
res2 = await client.get(
url=res.callback,
params={
"k1": res.k1,
"pr": pr,
},
timeout=timeout or TIMEOUT,
)
res2.raise_for_status()
return LnurlResponse.from_dict(res2.json())
Expand Down
1 change: 1 addition & 0 deletions lnurl/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ class LnurlPayActionResponse(LnurlResponseModel):
pr: LightningInvoice
success_action: Optional[Union[MessageAction, UrlAction, AesAction]] = Field(None, alias="successAction")
routes: List[List[LnurlPayRouteHop]] = []
verify: Optional[str] = None


class LnurlWithdrawResponse(LnurlResponseModel):
Expand Down
12 changes: 7 additions & 5 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,9 @@ def test_encode_nohttps(self, url):


class TestHandle:
"""Responses from the LNURL: https://legend.lnbits.com/"""
"""Responses from the LNURL: https://demo.lnbits.com/"""

@pytest.mark.xfail(reason="legend.lnbits.com is down")
@pytest.mark.asyncio
@pytest.mark.parametrize(
"bech32",
Expand All @@ -86,7 +87,7 @@ async def test_handle_withdraw(self, bech32):
res = await handle(bech32)
assert isinstance(res, LnurlWithdrawResponse)
assert res.tag == "withdrawRequest"
assert res.callback.host == "legend.lnbits.com"
assert res.callback.host == "demo.lnbits.com"
assert res.default_description == "sample withdraw"
assert res.max_withdrawable >= res.min_withdrawable

Expand All @@ -104,8 +105,9 @@ async def test_get_requests_error(self, url):


class TestPayFlow:
"""Full LNURL-pay flow interacting with https://legend.lnbits.com/"""
"""Full LNURL-pay flow interacting with https://demo.lnbits.com/"""

@pytest.mark.xfail(reason="legend.lnbits.com is down")
@pytest.mark.asyncio
@pytest.mark.parametrize(
"bech32, amount",
Expand All @@ -114,14 +116,14 @@ class TestPayFlow:
"LNURL1DP68GURN8GHJ7MR9VAJKUEPWD3HXY6T5WVHXXMMD9AKXUATJD3CZ7JN9F4EHQJQC25ZZY",
"1000",
),
("donate@legend.lnbits.com", "100000"),
("donate@demo.lnbits.com", "100000"),
],
)
async def test_pay_flow(self, bech32: str, amount: str):
res = await handle(bech32)
assert isinstance(res, LnurlPayResponse)
assert res.tag == "payRequest"
assert res.callback.host == "legend.lnbits.com"
assert res.callback.host == "demo.lnbits.com"
assert len(res.metadata.list()) >= 1
assert res.metadata.text != ""

Expand Down
Loading