Skip to content

Commit

Permalink
Improve run times when using key pair auth by caching the private key (
Browse files Browse the repository at this point in the history
…#1110)

* cache _get_private_key
* update caching functions and add unit tests
* add an integration test for key pair auth method
* generalize oauth test suite to auth_tests
* move generic private key methods into their own module
  • Loading branch information
mikealfare authored Jul 11, 2024
1 parent b2ce704 commit 1193a79
Show file tree
Hide file tree
Showing 7 changed files with 175 additions and 33 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20240710-172345.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Improve run times when using key pair auth by caching the private key
time: 2024-07-10T17:23:45.046905-04:00
custom:
Author: mikealfare aranke
Issue: "1082"
57 changes: 57 additions & 0 deletions dbt/adapters/snowflake/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import base64
import sys
from typing import Optional

if sys.version_info < (3, 9):
from functools import lru_cache

cache = lru_cache(maxsize=None)
else:
from functools import cache

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey


@cache
def private_key_from_string(
    private_key_string: str, passphrase: Optional[str] = None
) -> RSAPrivateKey:
    """Load a private key from its string form.

    A string beginning with "-" is parsed as PEM text; any other string is
    base64-decoded and parsed as DER. The result is memoized by ``cache``, so
    repeated calls with the same arguments skip the parsing step.

    Args:
        private_key_string: PEM text, or base64-encoded DER bytes.
        passphrase: Passphrase protecting the key; a falsy value means the
            key is loaded without a password.

    Returns:
        The key object produced by the ``cryptography`` loader.
    """
    password = passphrase.encode() if passphrase else None

    # Anything that does not look like a PEM header is assumed to be base64 DER.
    if not private_key_string.startswith("-"):
        der_bytes = base64.b64decode(private_key_string)
        return serialization.load_der_private_key(
            data=der_bytes,
            password=password,
            backend=default_backend(),
        )

    pem_bytes = private_key_string.encode("utf-8")
    return serialization.load_pem_private_key(
        data=pem_bytes,
        password=password,
        backend=default_backend(),
    )


@cache
def private_key_from_file(
    private_key_path: str, passphrase: Optional[str] = None
) -> RSAPrivateKey:
    """Load a PEM private key from a file on disk.

    The result is memoized by ``cache`` on the ``(private_key_path,
    passphrase)`` arguments, so the file is only read on the first call for a
    given pair.

    Args:
        private_key_path: Path of the PEM file to read.
        passphrase: Passphrase protecting the key; a falsy value means the
            key is loaded without a password.

    Returns:
        The key object produced by the ``cryptography`` PEM loader.
    """
    password = passphrase.encode() if passphrase else None

    with open(private_key_path, "rb") as key_file:
        pem_bytes = key_file.read()

    return serialization.load_pem_private_key(
        data=pem_bytes,
        password=password,
        backend=default_backend(),
    )
58 changes: 25 additions & 33 deletions dbt/adapters/snowflake/connections.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
import base64
import datetime
import os
import sys

if sys.version_info < (3, 9):
from functools import lru_cache

cache = lru_cache(maxsize=None)
else:
from functools import cache

import pytz
import re
Expand All @@ -11,8 +19,8 @@

from typing import Optional, Tuple, Union, Any, List, Iterable, TYPE_CHECKING

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
import requests
import snowflake.connector
import snowflake.connector.constants
Expand Down Expand Up @@ -44,6 +52,8 @@
from dbt.adapters.events.types import AdapterEventWarning, AdapterEventError
from dbt_common.ui import line_wrap_message, warning_tag

from dbt.adapters.snowflake.auth import private_key_from_file, private_key_from_string

if TYPE_CHECKING:
import agate

Expand All @@ -63,6 +73,15 @@
}


@cache
def snowflake_private_key(private_key: RSAPrivateKey) -> bytes:
    """Serialize ``private_key`` as unencrypted, DER-encoded PKCS8 bytes.

    The result is memoized by ``cache`` on the ``private_key`` argument.
    """
    der = serialization.Encoding.DER
    pkcs8 = serialization.PrivateFormat.PKCS8
    unencrypted = serialization.NoEncryption()
    return private_key.private_bytes(
        encoding=der,
        format=pkcs8,
        encryption_algorithm=unencrypted,
    )


@dataclass
class SnowflakeAdapterResponse(AdapterResponse):
query_id: str = ""
Expand Down Expand Up @@ -273,44 +292,17 @@ def _get_access_token(self) -> str:
)
return result_json["access_token"]

def _get_private_key(self):
def _get_private_key(self) -> Optional[bytes]:
"""Get Snowflake private key by path, from a Base64 encoded DER bytestring or None."""
if self.private_key and self.private_key_path:
raise DbtConfigError("Cannot specify both `private_key` and `private_key_path`")

if self.private_key_passphrase:
encoded_passphrase = self.private_key_passphrase.encode()
else:
encoded_passphrase = None

if self.private_key:
if self.private_key.startswith("-"):
p_key = serialization.load_pem_private_key(
data=bytes(self.private_key, "utf-8"),
password=encoded_passphrase,
backend=default_backend(),
)

else:
p_key = serialization.load_der_private_key(
data=base64.b64decode(self.private_key),
password=encoded_passphrase,
backend=default_backend(),
)

elif self.private_key:
private_key = private_key_from_string(self.private_key, self.private_key_passphrase)
elif self.private_key_path:
with open(self.private_key_path, "rb") as key:
p_key = serialization.load_pem_private_key(
key.read(), password=encoded_passphrase, backend=default_backend()
)
private_key = private_key_from_file(self.private_key_path, self.private_key_passphrase)
else:
return None

return p_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
return snowflake_private_key(private_key)


class SnowflakeConnectionManager(SQLConnectionManager):
Expand Down
File renamed without changes.
26 changes: 26 additions & 0 deletions tests/functional/auth_tests/test_key_pair.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import os

from dbt.tests.util import run_dbt
import pytest


class TestKeyPairAuth:
    """Functional test that runs dbt against a key-pair-authenticated target."""

    @pytest.fixture(scope="class", autouse=True)
    def dbt_profile_target(self):
        """Return the profile target; connection values come from the environment."""
        # (profile key, environment variable) pairs, in profile order.
        env_backed = (
            ("account", "SNOWFLAKE_TEST_ACCOUNT"),
            ("user", "SNOWFLAKE_TEST_USER"),
            ("private_key", "SNOWFLAKE_TEST_PRIVATE_KEY"),
            ("private_key_passphrase", "SNOWFLAKE_TEST_PRIVATE_KEY_PASSPHRASE"),
            ("database", "SNOWFLAKE_TEST_DATABASE"),
            ("warehouse", "SNOWFLAKE_TEST_WAREHOUSE"),
        )
        target = {"type": "snowflake", "threads": 4}
        for profile_key, variable in env_backed:
            target[profile_key] = os.getenv(variable)
        return target

    @pytest.fixture(scope="class")
    def models(self):
        """Return a single trivial model for the run."""
        return {"my_model.sql": "select 1 as id"}

    def test_connection(self, project):
        """Invoke dbt with the configured target."""
        run_dbt()
File renamed without changes.
61 changes: 61 additions & 0 deletions tests/unit/test_private_keys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import os
import tempfile
from typing import Generator

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
import pytest

from dbt.adapters.snowflake.auth import private_key_from_file, private_key_from_string


PASSPHRASE = "password1234"


def serialize(private_key: rsa.RSAPrivateKey) -> bytes:
    """Return ``private_key`` as unencrypted DER/PKCS8 bytes, for comparing keys."""
    return private_key.private_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )


@pytest.fixture(scope="session")
def private_key() -> rsa.RSAPrivateKey:
    """Generate one 2048-bit RSA key shared by the whole test session."""
    generated = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    return generated


@pytest.fixture(scope="session")
def private_key_string(private_key) -> str:
    """Return the session key as PEM text encrypted with ``PASSPHRASE``."""
    encryption = serialization.BestAvailableEncryption(PASSPHRASE.encode())
    pem = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=encryption,
    )
    return pem.decode("utf-8")


@pytest.fixture(scope="session")
def private_key_file(private_key) -> Generator[str, None, None]:
    """Yield the path of a temporary file holding the encrypted PEM key.

    The session key is serialized as PEM/PKCS8 encrypted with ``PASSPHRASE``
    and written to a named temporary file. The file stays open while the
    fixture is in use and is closed on teardown.
    """
    private_key_bytes = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.BestAvailableEncryption(PASSPHRASE.encode()),
    )
    # The context manager guarantees the handle is closed even when writing
    # fails during setup, which the manual close() after the yield did not.
    with tempfile.NamedTemporaryFile() as file:
        file.write(private_key_bytes)
        # Flush explicitly so code that opens the path by name sees the bytes.
        file.flush()
        yield file.name


def test_private_key_from_string_pem(private_key_string, private_key):
    """A PEM string loads back to the same key it was serialized from."""
    assert isinstance(private_key_string, str)
    loaded = private_key_from_string(private_key_string, PASSPHRASE)
    expected = serialize(private_key)
    assert serialize(loaded) == expected


def test_private_key_from_file(private_key_file, private_key):
    """A PEM file loads back to the same key it was serialized from."""
    assert os.path.exists(private_key_file)
    loaded = private_key_from_file(private_key_file, PASSPHRASE)
    expected = serialize(private_key)
    assert serialize(loaded) == expected

0 comments on commit 1193a79

Please sign in to comment.