-
-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Upload task worker results to Backend API #22
Changes from all commits
8532b30
6047a93
d57ad1c
245aed8
aaa2a75
9580ffe
4a9084d
ede4da7
80b0b97
3055d20
af271c0
28a40a0
81c60e9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -170,3 +170,4 @@ dev/data/** | |
!dev/data/README.md | ||
!dev/.env | ||
id_rsa | ||
*.json |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,42 +1,65 @@ | ||
import sys | ||
|
||
import pycountry | ||
from cryptography.hazmat.primitives import serialization | ||
|
||
from mirrors_qa_backend import logger | ||
from mirrors_qa_backend.db import Session | ||
from mirrors_qa_backend.db.country import update_countries as update_db_countries | ||
from mirrors_qa_backend.db.worker import create_worker as create_db_worker | ||
from mirrors_qa_backend.db.worker import update_worker as update_db_worker | ||
|
||
|
||
def get_country_mapping(country_codes: list[str]) -> dict[str, str]: | ||
"""Fetch the country names from the country codes. | ||
|
||
def create_worker(worker_id: str, private_key_data: bytes, country_codes: list[str]): | ||
Maps the country code to the country name. | ||
""" | ||
country_mapping: dict[str, str] = {} | ||
# Ensure all the countries are valid country codes | ||
for country_code in country_codes: | ||
if len(country_code) != 2: # noqa: PLR2004 | ||
logger.info(f"Country code '{country_code}' must be two characters long") | ||
sys.exit(1) | ||
|
||
if not pycountry.countries.get(alpha_2=country_code): | ||
logger.info(f"'{country_code}' is not valid country code") | ||
sys.exit(1) | ||
|
||
try: | ||
private_key = serialization.load_pem_private_key( | ||
private_key_data, password=None | ||
) # pyright: ignore[reportReturnType] | ||
except Exception as exc: | ||
logger.info(f"Unable to load private key: {exc}") | ||
sys.exit(1) | ||
|
||
try: | ||
with Session.begin() as session: | ||
create_db_worker( | ||
session, | ||
worker_id, | ||
country_codes, | ||
private_key, # pyright: ignore [reportGeneralTypeIssues, reportArgumentType] | ||
raise ValueError( | ||
f"Country code '{country_code}' must be two characters long" | ||
) | ||
except Exception as exc: | ||
logger.info(f"error while creating worker: {exc}") | ||
sys.exit(1) | ||
|
||
if country := pycountry.countries.get(alpha_2=country_code): | ||
country_mapping[country_code] = country.name | ||
else: | ||
raise ValueError(f"'{country_code}' is not valid country code") | ||
return country_mapping | ||
|
||
|
||
def create_worker( | ||
worker_id: str, private_key_data: bytes, initial_country_codes: list[str] | ||
): | ||
"""Create a worker in the DB. | ||
|
||
Assigns the countries for a worker to run tests from. | ||
""" | ||
country_mapping = get_country_mapping(initial_country_codes) | ||
private_key = serialization.load_pem_private_key( | ||
private_key_data, password=None | ||
) # pyright: ignore[reportReturnType] | ||
|
||
with Session.begin() as session: | ||
# Update the database with the countries in case those countries don't | ||
# exist yet. | ||
update_db_countries(session, country_mapping) | ||
create_db_worker( | ||
session, | ||
worker_id, | ||
initial_country_codes, | ||
private_key, # pyright: ignore [reportGeneralTypeIssues, reportArgumentType] | ||
) | ||
|
||
logger.info(f"Created worker {worker_id} successfully") | ||
|
||
|
||
def update_worker(worker_id: str, country_codes: list[str]): | ||
"""Update worker's data. | ||
|
||
Updates the ountries for a worker to run tests from. | ||
""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. typo You see you have a duplicated block to assign countries to a worker. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay. Makes sense. |
||
country_mapping = get_country_mapping(country_codes) | ||
with Session.begin() as session: | ||
update_db_countries(session, country_mapping) | ||
update_db_worker(session, worker_id, country_codes) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,50 @@ | ||
from sqlalchemy import select | ||
from sqlalchemy.dialects.postgresql import insert | ||
from sqlalchemy.orm import Session as OrmSession | ||
|
||
from mirrors_qa_backend.db.exceptions import RecordDoesNotExistError | ||
from mirrors_qa_backend.db.models import Country | ||
|
||
|
||
def get_countries(session: OrmSession, *country_codes: str) -> list[Country]: | ||
def get_countries(session: OrmSession, country_codes: list[str]) -> list[Country]: | ||
"""Get countries with the provided country codes. | ||
|
||
Gets all available countries if no country codes are provided. | ||
""" | ||
return list( | ||
session.scalars(select(Country).where(Country.code.in_(country_codes))).all() | ||
session.scalars( | ||
select(Country).where( | ||
(Country.code.in_(country_codes)) | (country_codes == []) | ||
) | ||
).all() | ||
) | ||
|
||
|
||
def get_country_or_none(session: OrmSession, country_code: str) -> Country | None: | ||
return session.scalars( | ||
select(Country).where(Country.code == country_code) | ||
).one_or_none() | ||
|
||
|
||
def get_country(session: OrmSession, country_code: str) -> Country: | ||
if country := get_country_or_none(session, country_code): | ||
return country | ||
raise RecordDoesNotExistError(f"Country with code {country_code} does not exist.") | ||
|
||
|
||
def create_country( | ||
session: OrmSession, *, country_code: str, country_name: str | ||
) -> Country: | ||
"""Creates a new country in the database.""" | ||
session.execute( | ||
insert(Country) | ||
.values(code=country_code, name=country_name) | ||
.on_conflict_do_nothing(index_elements=["code"]) | ||
) | ||
return get_country(session, country_code) | ||
|
||
|
||
def update_countries(session: OrmSession, country_mapping: dict[str, str]) -> None: | ||
"""Updates the list of countries in the database.""" | ||
for country_code, country_name in country_mapping.items(): | ||
create_country(session, country_code=country_code, country_name=country_name) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not clear enough. Maybe renaming to
initial_country_codes
would help