-
Notifications
You must be signed in to change notification settings - Fork 181
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Improve run times when using key pair auth by caching the private key (…
…#1110) * cache _get_private_key * update caching functions and add unit tests * add an integration test for key pair auth method * generalize oauth test suite to auth_tests * move generic private key methods into their own module
- Loading branch information
1 parent
b2ce704
commit 1193a79
Showing
7 changed files
with
175 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
kind: Features | ||
body: Improve run times when using key pair auth by caching the private key | ||
time: 2024-07-10T17:23:45.046905-04:00 | ||
custom: | ||
Author: mikealfare aranke | ||
Issue: "1082" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import base64 | ||
import sys | ||
from typing import Optional | ||
|
||
if sys.version_info < (3, 9): | ||
from functools import lru_cache | ||
|
||
cache = lru_cache(maxsize=None) | ||
else: | ||
from functools import cache | ||
|
||
from cryptography.hazmat.backends import default_backend | ||
from cryptography.hazmat.primitives import serialization | ||
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey | ||
|
||
|
||
@cache | ||
def private_key_from_string( | ||
private_key_string: str, passphrase: Optional[str] = None | ||
) -> RSAPrivateKey: | ||
|
||
if passphrase: | ||
encoded_passphrase = passphrase.encode() | ||
else: | ||
encoded_passphrase = None | ||
|
||
if private_key_string.startswith("-"): | ||
return serialization.load_pem_private_key( | ||
data=bytes(private_key_string, "utf-8"), | ||
password=encoded_passphrase, | ||
backend=default_backend(), | ||
) | ||
return serialization.load_der_private_key( | ||
data=base64.b64decode(private_key_string), | ||
password=encoded_passphrase, | ||
backend=default_backend(), | ||
) | ||
|
||
|
||
@cache | ||
def private_key_from_file( | ||
private_key_path: str, passphrase: Optional[str] = None | ||
) -> RSAPrivateKey: | ||
|
||
if passphrase: | ||
encoded_passphrase = passphrase.encode() | ||
else: | ||
encoded_passphrase = None | ||
|
||
with open(private_key_path, "rb") as file: | ||
private_key_bytes = file.read() | ||
|
||
return serialization.load_pem_private_key( | ||
data=private_key_bytes, | ||
password=encoded_passphrase, | ||
backend=default_backend(), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import os | ||
|
||
from dbt.tests.util import run_dbt | ||
import pytest | ||
|
||
|
||
class TestKeyPairAuth: | ||
@pytest.fixture(scope="class", autouse=True) | ||
def dbt_profile_target(self): | ||
return { | ||
"type": "snowflake", | ||
"threads": 4, | ||
"account": os.getenv("SNOWFLAKE_TEST_ACCOUNT"), | ||
"user": os.getenv("SNOWFLAKE_TEST_USER"), | ||
"private_key": os.getenv("SNOWFLAKE_TEST_PRIVATE_KEY"), | ||
"private_key_passphrase": os.getenv("SNOWFLAKE_TEST_PRIVATE_KEY_PASSPHRASE"), | ||
"database": os.getenv("SNOWFLAKE_TEST_DATABASE"), | ||
"warehouse": os.getenv("SNOWFLAKE_TEST_WAREHOUSE"), | ||
} | ||
|
||
@pytest.fixture(scope="class") | ||
def models(self): | ||
return {"my_model.sql": "select 1 as id"} | ||
|
||
def test_connection(self, project): | ||
run_dbt() |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import os | ||
import tempfile | ||
from typing import Generator | ||
|
||
from cryptography.hazmat.primitives import serialization | ||
from cryptography.hazmat.primitives.asymmetric import rsa | ||
import pytest | ||
|
||
from dbt.adapters.snowflake.auth import private_key_from_file, private_key_from_string | ||
|
||
|
||
PASSPHRASE = "password1234" | ||
|
||
|
||
def serialize(private_key: rsa.RSAPrivateKey) -> bytes: | ||
return private_key.private_bytes( | ||
serialization.Encoding.DER, | ||
serialization.PrivateFormat.PKCS8, | ||
serialization.NoEncryption(), | ||
) | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def private_key() -> rsa.RSAPrivateKey: | ||
return rsa.generate_private_key(public_exponent=65537, key_size=2048) | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def private_key_string(private_key) -> str: | ||
private_key_bytes = private_key.private_bytes( | ||
encoding=serialization.Encoding.PEM, | ||
format=serialization.PrivateFormat.PKCS8, | ||
encryption_algorithm=serialization.BestAvailableEncryption(PASSPHRASE.encode()), | ||
) | ||
return private_key_bytes.decode("utf-8") | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def private_key_file(private_key) -> Generator[str, None, None]: | ||
private_key_bytes = private_key.private_bytes( | ||
encoding=serialization.Encoding.PEM, | ||
format=serialization.PrivateFormat.PKCS8, | ||
encryption_algorithm=serialization.BestAvailableEncryption(PASSPHRASE.encode()), | ||
) | ||
file = tempfile.NamedTemporaryFile() | ||
file.write(private_key_bytes) | ||
file.seek(0) | ||
yield file.name | ||
file.close() | ||
|
||
|
||
def test_private_key_from_string_pem(private_key_string, private_key): | ||
assert isinstance(private_key_string, str) | ||
calculated_private_key = private_key_from_string(private_key_string, PASSPHRASE) | ||
assert serialize(calculated_private_key) == serialize(private_key) | ||
|
||
|
||
def test_private_key_from_file(private_key_file, private_key): | ||
assert os.path.exists(private_key_file) | ||
calculated_private_key = private_key_from_file(private_key_file, PASSPHRASE) | ||
assert serialize(calculated_private_key) == serialize(private_key) |