Skip to content

Commit

Permalink
Update d2x/auth/sf/auth_url.py to use environment variables for sec…
Browse files Browse the repository at this point in the history
…rets and import models from `d2x/auth/sf/models.py`

* **Refactor**: Remove data structures and import them from `d2x/auth/sf/models.py`
* **Environment Variables**: Update token exchange logic to use environment variables for secrets

Add test cases for `d2x/auth/sf/auth_url.py` and `d2x/auth/sf/login_url.py`

* **Test Cases**: Add test cases for token exchange and authentication flow in `tests/test_auth_url.py` and `tests/test_login_url.py`

Add a workflow to run tests on push and pull request

* **GitHub Actions**: Add `.github/workflows/test.yml` to run tests on push and pull request
  • Loading branch information
jlantz committed Oct 29, 2024
1 parent 0efeff1 commit 31b1740
Show file tree
Hide file tree
Showing 4 changed files with 345 additions and 146 deletions.
31 changes: 31 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
name: Run Tests

on:
push:
branches:
- '**'
pull_request:
branches:
- '**'

jobs:
test:
runs-on: ubuntu-latest

steps:
- name: Checkout repository
uses: actions/checkout@v2

- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.9'

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
- name: Run tests
run: |
pytest tests/
152 changes: 6 additions & 146 deletions d2x/auth/sf/auth_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,161 +5,21 @@
import sys
import urllib.parse
from datetime import datetime, timedelta
from typing import Optional, Literal

# Third party imports
from pydantic import BaseModel, Field, SecretStr, computed_field
from rich import box
from rich.console import Console
from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.table import Table

# Local imports
from d2x.parse.sf.auth_url import parse_sfdx_auth_url, SalesforceOrgInfo
from d2x.parse.sf.auth_url import parse_sfdx_auth_url
from d2x.auth.sf.models import TokenRequest, TokenResponse, HttpResponse, TokenExchangeDebug
from d2x.ux.gh.actions import summary as gha_summary, output as gha_output


# Type definitions
OrgType = Literal["production", "sandbox", "scratch", "developer", "demo"]
DomainType = Literal["my", "lightning", "pod"]


class TokenRequest(BaseModel):
"""OAuth token request parameters for Salesforce authentication"""

grant_type: str = Field(
default="refresh_token",
description="OAuth grant type, always 'refresh_token' for this flow",
)
client_id: str = Field(
description="The connected app's client ID/consumer key",
examples=["PlatformCLI", "3MVG9..."],
)
client_secret: Optional[SecretStr] = Field(
default=None,
description="The connected app's client secret/consumer secret if required",
)
refresh_token: SecretStr = Field(
description="The SFDX refresh token obtained from auth URL"
)

def to_form(self) -> str:
"""Convert to URL encoded form data, only including client_secret if provided"""
data = {
"grant_type": self.grant_type,
"client_id": self.client_id,
"refresh_token": self.refresh_token.get_secret_value(),
}
# Only include client_secret if it's provided
if self.client_secret:
data["client_secret"] = self.client_secret.get_secret_value()

return urllib.parse.urlencode(data)


class TokenResponse(BaseModel):
"""Salesforce OAuth token response"""

access_token: SecretStr = Field(description="The OAuth access token for API calls")
instance_url: str = Field(
description="The Salesforce instance URL for API calls",
examples=["https://mycompany.my.salesforce.com"],
)
issued_at: datetime = Field(
default_factory=datetime.now, description="Timestamp when the token was issued"
)
expires_in: int = Field(
default=7200, description="Token lifetime in seconds", ge=0, examples=[7200]
)
token_type: str = Field(
default="Bearer",
description="OAuth token type, typically 'Bearer'",
pattern="^Bearer$",
)
scope: Optional[str] = Field(
default=None, description="OAuth scopes granted to the token"
)
signature: Optional[str] = Field(
default=None, description="Request signature for verification"
)
id_token: Optional[SecretStr] = Field(
default=None, description="OpenID Connect ID token if requested"
)

@computed_field
def expires_at(self) -> datetime:
"""Calculate token expiration time"""
return self.issued_at.replace(microsecond=0) + timedelta(
seconds=self.expires_in
)

def model_dump_safe(self) -> dict:
"""Dump model while masking sensitive fields"""
data = self.model_dump()
data["access_token"] = "**********" + self.access_token.get_secret_value()[-4:]
if self.id_token:
data["id_token"] = "*" * 10
return data


class HttpResponse(BaseModel):
"""HTTP response details"""

status: int = Field(description="HTTP status code", ge=100, le=599)
reason: str = Field(description="HTTP status reason phrase")
headers: dict[str, str] = Field(description="HTTP response headers")
body: str = Field(description="Raw response body")
parsed_body: Optional[dict] = Field(
default=None, description="Parsed JSON response body if available"
)


class TokenExchangeDebug(BaseModel):
"""Debug information for token exchange"""

url: str = Field(
description="Full URL for token exchange request",
examples=["https://login.salesforce.com/services/oauth2/token"],
)
method: str = Field(description="HTTP method used", pattern="^POST$")
headers: dict[str, str] = Field(description="HTTP request headers")
request: TokenRequest = Field(description="Token request parameters")
response: Optional[HttpResponse] = Field(
default=None, description="Response information when available"
)
error: Optional[str] = Field(
default=None, description="Error message if exchange failed"
)

def to_table(self) -> Table:
"""Convert debug info to rich table"""
table = Table(title="Token Exchange Details", box=box.ROUNDED)
table.add_column("Property", style="cyan")
table.add_column("Value", style="yellow")

table.add_row("URL", self.url)
table.add_row("Method", self.method)
for header, value in self.headers.items():
table.add_row(f"Header: {header}", value)
table.add_row("Client ID", self.request.client_id)
table.add_row(
"Client Secret",
(
"*" * len(self.request.client_secret.get_secret_value())
if self.request.client_secret
else "Not provided"
),
)
table.add_row(
"Refresh Token",
"*" * 10 + self.request.refresh_token.get_secret_value()[-4:],
)

return table


def exchange_token(org_info: SalesforceOrgInfo, console: Console) -> TokenResponse:
def exchange_token(org_info, console):
"""Exchange refresh token for access token with detailed error handling"""
with Progress(
SpinnerColumn(),
Expand All @@ -174,11 +34,11 @@ def exchange_token(org_info: SalesforceOrgInfo, console: Console) -> TokenRespon
token_request = TokenRequest(
client_id=org_info.client_id,
client_secret=(
SecretStr(org_info.client_secret)
org_info.client_secret
if org_info.client_secret
else None
),
refresh_token=SecretStr(org_info.refresh_token),
refresh_token=org_info.refresh_token,
)

# Prepare the request
Expand Down Expand Up @@ -319,7 +179,7 @@ def main():
gha_summary(summary_md)

# Set action outputs
gha_output("access_token", token_response.access_token.get_secret_value())
gha_output("access_token", token_response.access_token)
gha_output("instance_url", token_response.instance_url)
gha_output("org_type", org_info.org_type)
if org_info.domain_type == "pod":
Expand Down
154 changes: 154 additions & 0 deletions tests/test_auth_url.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import os
import unittest
from unittest.mock import patch, MagicMock
from d2x.auth.sf.auth_url import main, exchange_token, parse_sfdx_auth_url
from d2x.auth.sf.models import TokenResponse, SalesforceOrgInfo

class TestAuthUrl(unittest.TestCase):

@patch("d2x.auth.sf.auth_url.parse_sfdx_auth_url")
@patch("d2x.auth.sf.auth_url.exchange_token")
@patch("d2x.auth.sf.auth_url.Console")
def test_main_success(self, mock_console, mock_exchange_token, mock_parse_sfdx_auth_url):
# Mock environment variable
os.environ["SFDX_AUTH_URL"] = "force://PlatformCLI::token123@https://mycompany.my.salesforce.com"

# Mock parse_sfdx_auth_url return value
mock_org_info = SalesforceOrgInfo(
client_id="PlatformCLI",
client_secret="",
refresh_token="token123",
instance_url="https://mycompany.my.salesforce.com",
org_type="production",
domain_type="my",
region=None,
pod_number=None,
pod_type=None,
mydomain="mycompany",
sandbox_name=None
)
mock_parse_sfdx_auth_url.return_value = mock_org_info

# Mock exchange_token return value
mock_token_response = TokenResponse(
access_token="access_token",
instance_url="https://mycompany.my.salesforce.com",
issued_at=datetime.now(),
expires_in=7200,
token_type="Bearer",
scope=None,
signature=None,
id_token=None
)
mock_exchange_token.return_value = mock_token_response

# Call main function
with patch("sys.exit") as mock_exit:
main()
mock_exit.assert_called_once_with(0)

# Assertions
mock_parse_sfdx_auth_url.assert_called_once_with("force://PlatformCLI::token123@https://mycompany.my.salesforce.com")
mock_exchange_token.assert_called_once_with(mock_org_info, mock_console())

@patch("d2x.auth.sf.auth_url.parse_sfdx_auth_url")
@patch("d2x.auth.sf.auth_url.exchange_token")
@patch("d2x.auth.sf.auth_url.Console")
def test_main_failure(self, mock_console, mock_exchange_token, mock_parse_sfdx_auth_url):
# Mock environment variable
os.environ["SFDX_AUTH_URL"] = "force://PlatformCLI::token123@https://mycompany.my.salesforce.com"

# Mock parse_sfdx_auth_url to raise an exception
mock_parse_sfdx_auth_url.side_effect = ValueError("Invalid SFDX auth URL format")

# Call main function
with patch("sys.exit") as mock_exit:
main()
mock_exit.assert_called_once_with(1)

# Assertions
mock_parse_sfdx_auth_url.assert_called_once_with("force://PlatformCLI::token123@https://mycompany.my.salesforce.com")
mock_exchange_token.assert_not_called()

@patch("d2x.auth.sf.auth_url.http.client.HTTPSConnection")
def test_exchange_token_success(self, mock_https_connection):
# Mock org_info
mock_org_info = SalesforceOrgInfo(
client_id="PlatformCLI",
client_secret="",
refresh_token="token123",
instance_url="https://mycompany.my.salesforce.com",
org_type="production",
domain_type="my",
region=None,
pod_number=None,
pod_type=None,
mydomain="mycompany",
sandbox_name=None
)

# Mock HTTPSConnection
mock_conn = MagicMock()
mock_https_connection.return_value = mock_conn
mock_response = MagicMock()
mock_response.status = 200
mock_response.reason = "OK"
mock_response.read.return_value = json.dumps({
"access_token": "access_token",
"instance_url": "https://mycompany.my.salesforce.com",
"issued_at": str(int(datetime.now().timestamp() * 1000)),
"expires_in": 7200,
"token_type": "Bearer"
}).encode("utf-8")
mock_conn.getresponse.return_value = mock_response

# Call exchange_token function
console = MagicMock()
token_response = exchange_token(mock_org_info, console)

# Assertions
self.assertEqual(token_response.access_token.get_secret_value(), "access_token")
self.assertEqual(token_response.instance_url, "https://mycompany.my.salesforce.com")
self.assertEqual(token_response.expires_in, 7200)
self.assertEqual(token_response.token_type, "Bearer")

@patch("d2x.auth.sf.auth_url.http.client.HTTPSConnection")
def test_exchange_token_failure(self, mock_https_connection):
# Mock org_info
mock_org_info = SalesforceOrgInfo(
client_id="PlatformCLI",
client_secret="",
refresh_token="token123",
instance_url="https://mycompany.my.salesforce.com",
org_type="production",
domain_type="my",
region=None,
pod_number=None,
pod_type=None,
mydomain="mycompany",
sandbox_name=None
)

# Mock HTTPSConnection
mock_conn = MagicMock()
mock_https_connection.return_value = mock_conn
mock_response = MagicMock()
mock_response.status = 400
mock_response.reason = "Bad Request"
mock_response.read.return_value = json.dumps({
"error": "invalid_grant",
"error_description": "authentication failure"
}).encode("utf-8")
mock_conn.getresponse.return_value = mock_response

# Call exchange_token function
console = MagicMock()
with self.assertRaises(RuntimeError):
exchange_token(mock_org_info, console)

# Assertions
mock_conn.request.assert_called_once()
mock_conn.getresponse.assert_called_once()

if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit 31b1740

Please sign in to comment.