Skip to content

Commit

Permalink
set up scheduler to create tasks for idle workers
Browse files Browse the repository at this point in the history
  • Loading branch information
elfkuzco committed Jun 20, 2024
1 parent 460ac2d commit c80736d
Show file tree
Hide file tree
Showing 14 changed files with 352 additions and 13 deletions.
5 changes: 3 additions & 2 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ dependencies = [
"pycountry==24.6.1",
"cryptography==42.0.8",
"PyJWT==2.8.0",
"paramiko==3.4.0",
]
license = {text = "GPL-3.0-or-later"}
classifiers = [
Expand All @@ -37,6 +38,7 @@ Homepage = "https://github.com/kiwix/mirrors-qa"

[project.scripts]
update-mirrors = "mirrors_qa_backend.entrypoint:main"
mirrors-qa-scheduler = "mirrors_qa_backend.scheduler:main"

[project.optional-dependencies]
scripts = [
Expand All @@ -53,7 +55,6 @@ test = [
"pytest==8.0.0",
"coverage==7.4.1",
"Faker==25.8.0",
"paramiko==3.4.0",
"httpx==0.27.0",
]
dev = [
Expand Down Expand Up @@ -215,7 +216,7 @@ testpaths = ["tests"]
pythonpath = [".", "src"]
addopts = "--strict-markers"
markers = [
"num_tests: number of tests to create in the database (default: 10)",
"num_tests(num=10, *, status=..., country=...): create num tests in the database using status and/or country. Random data is chosen for country or status if either is not set",
]

[tool.coverage.paths]
Expand Down
35 changes: 33 additions & 2 deletions backend/src/mirrors_qa_backend/cryptography.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
import datetime

import jwt
import paramiko
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey

from mirrors_qa_backend.exceptions import PEMPublicKeyLoadError
from mirrors_qa_backend.settings import Settings
Expand Down Expand Up @@ -44,6 +45,36 @@ def sign_message(private_key: RSAPrivateKey, message: bytes) -> bytes:
)


def generate_private_key(key_size: int = 2048) -> RSAPrivateKey:
    """Generate a new RSA private key of key_size bits (public exponent 65537)."""
    new_key = rsa.generate_private_key(key_size=key_size, public_exponent=65537)
    return new_key


def serialize_private_key(private_key: RSAPrivateKey) -> bytes:
    """Return the private key as unencrypted, PEM-encoded PKCS8 bytes."""
    pem_encoding = serialization.Encoding.PEM
    pkcs8_format = serialization.PrivateFormat.PKCS8
    # No passphrase is applied: the serialized key is stored in the clear.
    no_passphrase = serialization.NoEncryption()
    return private_key.private_bytes(
        encoding=pem_encoding,
        format=pkcs8_format,
        encryption_algorithm=no_passphrase,
    )


def generate_public_key(private_key: RSAPrivateKey) -> RSAPublicKey:
    """Return the public key that belongs to private_key."""
    public_half = private_key.public_key()
    return public_half


def serialize_public_key(public_key: RSAPublicKey) -> bytes:
    """Return the public key as PEM-encoded SubjectPublicKeyInfo bytes."""
    pem_encoding = serialization.Encoding.PEM
    spki_format = serialization.PublicFormat.SubjectPublicKeyInfo
    return public_key.public_bytes(encoding=pem_encoding, format=spki_format)


def get_public_key_fingerprint(public_key: RSAPublicKey) -> str:
    """Return the SHA256 fingerprint of the public key, as reported by paramiko."""
    # Wrap the key in paramiko's RSAKey to reuse its fingerprint computation.
    ssh_key = paramiko.RSAKey(key=public_key)
    return ssh_key.fingerprint  # pyright: ignore[reportUnknownMemberType, UnknownVariableType]


def generate_access_token(worker_id: str) -> str:
issue_time = datetime.datetime.now(datetime.UTC)
expire_time = issue_time + datetime.timedelta(hours=Settings.TOKEN_EXPIRY)
Expand Down
12 changes: 12 additions & 0 deletions backend/src/mirrors_qa_backend/db/country.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from sqlalchemy import select
from sqlalchemy.orm import Session as OrmSession

from mirrors_qa_backend.db import models


def get_countries_by_name(session: OrmSession, *countries: str) -> list[models.Country]:
    """Return the countries in the database whose name is one of countries."""
    query = select(models.Country).where(models.Country.name.in_(countries))
    matches = session.scalars(query).all()
    return list(matches)
4 changes: 4 additions & 0 deletions backend/src/mirrors_qa_backend/db/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@ def __init__(self, message: str, *args: object) -> None:

class EmptyMirrorsError(Exception):
    """Raised when the mirrors in the database are updated with an empty list."""


class DuplicatePrimaryKeyError(Exception):
    """Raised when a database record with the same primary key already exists."""
57 changes: 56 additions & 1 deletion backend/src/mirrors_qa_backend/db/tests.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# ruff: noqa: DTZ005, DTZ001
import datetime
from dataclasses import dataclass
from ipaddress import IPv4Address
from uuid import UUID

from sqlalchemy import UnaryExpression, asc, desc, func, select
from sqlalchemy import UnaryExpression, asc, desc, func, select, update
from sqlalchemy.orm import Session as OrmSession

from mirrors_qa_backend.db import models
Expand Down Expand Up @@ -144,5 +145,59 @@ def create_or_update_test(
test.started_on = started_on if started_on else test.started_on

session.add(test)
session.flush()

return test


def create_test(
    session: OrmSession,
    *,
    worker_id: str | None = None,
    status: StatusEnum = StatusEnum.PENDING,
    error: str | None = None,
    ip_address: IPv4Address | None = None,
    asn: str | None = None,
    country: str | None = None,
    location: str | None = None,
    latency: int | None = None,
    download_size: int | None = None,
    duration: int | None = None,
    speed: float | None = None,
    started_on: datetime.datetime | None = None,
) -> models.Test:
    """Create a test by delegating to create_or_update_test with no test id.

    Every keyword argument is forwarded unchanged; status defaults to PENDING.
    """
    new_test = create_or_update_test(
        session,
        test_id=None,  # no existing test is being referenced
        worker_id=worker_id,
        country=country,
        status=status,
        error=error,
        started_on=started_on,
        ip_address=ip_address,
        asn=asn,
        location=location,
        latency=latency,
        download_size=download_size,
        duration=duration,
        speed=speed,
    )
    return new_test


def expire_tests(
    session: OrmSession, interval: datetime.timedelta
) -> list[models.Test]:
    """Mark PENDING tests requested more than interval ago as MISSED.

    Returns the tests whose status was changed.
    """
    cutoff = datetime.datetime.now() - interval
    epoch = datetime.datetime(1970, 1, 1)
    statement = (
        update(models.Test)
        .where(
            models.Test.requested_on.between(epoch, cutoff),
            models.Test.status == StatusEnum.PENDING,
        )
        .values(status=StatusEnum.MISSED)
        .returning(models.Test)
    )
    missed_tests = session.scalars(statement).all()
    return list(missed_tests)
85 changes: 84 additions & 1 deletion backend/src/mirrors_qa_backend/db/worker.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,93 @@
# ruff: noqa: DTZ005, DTZ001
import datetime
from pathlib import Path

from sqlalchemy import select
from sqlalchemy.orm import Session as OrmSession

from mirrors_qa_backend.db import models
from mirrors_qa_backend import cryptography
from mirrors_qa_backend.db import country, models
from mirrors_qa_backend.db.exceptions import DuplicatePrimaryKeyError


def get_worker(session: OrmSession, worker_id: str) -> models.Worker | None:
    """Return the worker whose id is worker_id, or None when there is none."""
    query = select(models.Worker).where(models.Worker.id == worker_id)
    return session.scalars(query).one_or_none()


def create_worker(
    session: OrmSession,
    worker_id: str,
    countries: list[str],
    private_key_filename: str | Path | None = None,
) -> models.Worker:
    """Create a worker and write its private key to private_key_filename.

    A new RSA key pair is generated for the worker. The serialized private key
    is written to private_key_filename (default: "{worker_id}.pem") in a file
    readable and writable by its owner only. The serialized public key and its
    fingerprint are stored on the worker record. Each country found in the
    database whose name is in countries is assigned to the worker; names with
    no matching country are ignored.

    Raises:
        DuplicatePrimaryKeyError: a worker with id worker_id already exists.
    """
    if get_worker(session, worker_id) is not None:
        raise DuplicatePrimaryKeyError(
            f"A worker with id {worker_id!r} already exists."
        )

    if private_key_filename is None:
        private_key_filename = f"{worker_id}.pem"

    private_key = cryptography.generate_private_key()
    public_key = cryptography.generate_public_key(private_key)
    public_key_pem = cryptography.serialize_public_key(public_key).decode(
        encoding="ascii"
    )

    # Restrict the file to its owner *before* writing any key material, so the
    # private key is never readable by other users. touch() covers a new file,
    # chmod() tightens a pre-existing file that touch() leaves untouched.
    key_path = Path(private_key_filename)
    key_path.touch(mode=0o600, exist_ok=True)
    key_path.chmod(0o600)
    key_path.write_bytes(cryptography.serialize_private_key(private_key))

    worker = models.Worker(
        id=worker_id,
        pubkey_pkcs8=public_key_pem,
        pubkey_fingerprint=cryptography.get_public_key_fingerprint(public_key),
    )
    session.add(worker)

    # Make the new worker responsible for every requested country that exists.
    for db_country in country.get_countries_by_name(session, *countries):
        db_country.worker_id = worker_id
        session.add(db_country)

    return worker


def get_workers_last_seen_in_range(
    session: OrmSession, begin: datetime.datetime, end: datetime.datetime
) -> list[models.Worker]:
    """Return the workers whose last_seen_on lies between begin and end."""
    query = select(models.Worker).where(
        models.Worker.last_seen_on.between(begin, end),
    )
    workers = session.scalars(query).all()
    return list(workers)


def get_idle_workers(
    session: OrmSession, interval: datetime.timedelta
) -> list[models.Worker]:
    """Return workers last seen between 1970-01-01 and (now - interval)."""
    cutoff = datetime.datetime.now() - interval
    epoch = datetime.datetime(1970, 1, 1)
    return get_workers_last_seen_in_range(session, epoch, cutoff)


def get_active_workers(
    session: OrmSession, interval: datetime.timedelta
) -> list[models.Worker]:
    """Return workers last seen between (now - interval) and now."""
    now = datetime.datetime.now()
    window_start = now - interval
    return get_workers_last_seen_in_range(session, window_start, now)


def update_worker_last_seen(
    session: OrmSession, worker: models.Worker
) -> models.Worker:
    """Set worker.last_seen_on to the current time, add it to session, return it."""
    seen_at = datetime.datetime.now()
    worker.last_seen_on = seen_at
    session.add(worker)
    return worker
4 changes: 3 additions & 1 deletion backend/src/mirrors_qa_backend/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ def is_country_row(tag: Tag) -> bool:
resp = requests.get(Settings.MIRRORS_URL, timeout=Settings.REQUESTS_TIMEOUT)
resp.raise_for_status()
except requests.RequestException as exc:
raise MirrorsRequestError from exc
raise MirrorsRequestError(
"network error while fetching mirrors from url"
) from exc

soup = BeautifulSoup(resp.text, features="html.parser")
body = soup.find("tbody")
Expand Down
8 changes: 4 additions & 4 deletions backend/src/mirrors_qa_backend/routes/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from fastapi import status as status_codes

from mirrors_qa_backend import schemas, serializer
from mirrors_qa_backend.db import tests
from mirrors_qa_backend.db import tests, worker
from mirrors_qa_backend.enums import SortDirectionEnum, StatusEnum, TestSortColumnEnum
from mirrors_qa_backend.routes.dependencies import (
CurrentWorker,
Expand Down Expand Up @@ -78,7 +78,7 @@ def get_test(test: RetrievedTest) -> schemas.Test:
)
def update_test(
session: DbSession,
worker: CurrentWorker,
current_worker: CurrentWorker,
test: RetrievedTest,
update: schemas.UpdateTestModel,
) -> schemas.Test:
Expand All @@ -87,7 +87,7 @@ def update_test(
updated_test = tests.create_or_update_test(
session,
test_id=test.id,
worker_id=worker.id,
worker_id=current_worker.id,
status=body.status,
error=body.error,
ip_address=body.ip_address,
Expand All @@ -99,5 +99,5 @@ def update_test(
duration=body.duration,
speed=body.speed,
)

worker.update_worker_last_seen(session, current_worker)
return serializer.serialize_test(updated_test)
75 changes: 75 additions & 0 deletions backend/src/mirrors_qa_backend/scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import datetime
import time

from mirrors_qa_backend import logger
from mirrors_qa_backend.db import Session, tests, worker
from mirrors_qa_backend.enums import StatusEnum
from mirrors_qa_backend.settings import Settings


def main():
    """Run the scheduler loop forever.

    Each cycle opens a session, expires tests whose results have not been
    reported, creates tests for the countries of every idle worker, and then
    sleeps for Settings.SCHEDULER_SLEEP_INTERVAL hours.
    """
    while True:
        with Session.begin() as session:
            _expire_unreported_tests(session)

            idle_workers = worker.get_idle_workers(
                session,
                interval=datetime.timedelta(hours=Settings.IDLE_WORKER_INTERVAL),
            )
            if not idle_workers:
                logger.info("No idle workers found.")

            # Create tests for the countries each idle worker is responsible for.
            for idle_worker in idle_workers:
                _create_tests_for_idle_worker(session, idle_worker)

        # Sleep outside the session block so the session is not held open
        # while the scheduler is idle.
        sleep_interval = datetime.timedelta(
            hours=Settings.SCHEDULER_SLEEP_INTERVAL
        ).total_seconds()

        logger.info(f"Sleeping for {sleep_interval} seconds.")
        time.sleep(sleep_interval)


def _expire_unreported_tests(session):
    """Expire tests whose results have not been reported and log each one."""
    expired_tests = tests.expire_tests(
        session,
        interval=datetime.timedelta(hours=Settings.EXPIRE_TEST_INTERVAL),
    )
    for expired_test in expired_tests:
        logger.info(
            f"Expired test {expired_test.id}, "
            f"country: {expired_test.country}, "
            f"worker: {expired_test.worker_id}"
        )


def _create_tests_for_idle_worker(session, idle_worker):
    """Create a PENDING test for each country of idle_worker with none pending."""
    if not idle_worker.countries:
        logger.info(f"No countries registered for idle worker {idle_worker.id}")
        return

    for country in idle_worker.countries:
        # While we have expired "unreported" tests, it is possible that
        # a test for a country might still be PENDING as the interval
        # for expiration and that of the scheduler might overlap.
        # In such scenarios, we skip creating a test for that country.
        pending_tests = tests.list_tests(
            session,
            worker_id=idle_worker.id,
            statuses=[StatusEnum.PENDING],
            country=country.name,
        )

        if pending_tests.nb_tests:
            logger.info(
                "Skipping creation of new test entries for "
                f"{idle_worker.id} as {pending_tests.nb_tests} "
                "tests are still pending."
            )
            continue

        new_test = tests.create_test(
            session=session,
            worker_id=idle_worker.id,
            country=country.name,
            status=StatusEnum.PENDING,
        )
        logger.info(
            f"Created new test {new_test.id} for worker "
            f"{idle_worker.id} in country {country.name}"
        )
Loading

0 comments on commit c80736d

Please sign in to comment.