diff --git a/transmogrifier/cli.py b/transmogrifier/cli.py index 288f230..ec168f9 100644 --- a/transmogrifier/cli.py +++ b/transmogrifier/cli.py @@ -5,10 +5,7 @@ import click from transmogrifier.config import SOURCES, configure_logger, configure_sentry -from transmogrifier.helpers import ( - write_deleted_records_to_file, - write_timdex_records_to_json, -) +from transmogrifier.helpers import write_deleted_records_to_file from transmogrifier.sources.transformer import Transformer logger = logging.getLogger(__name__) @@ -44,10 +41,8 @@ def main(source, input_file, output_file, verbose): logger.info(configure_sentry()) logger.info("Running transform for source %s", source) - transformer_class = Transformer.get_transformer(source) - source_records = transformer_class.parse_source_file(input_file) - transformer = transformer_class(source, source_records) - write_timdex_records_to_json(transformer, output_file) + transformer = Transformer.load(source, input_file) + transformer.write_timdex_records_to_json(output_file) if transformer.processed_record_count == 0: raise ValueError("No records processed from input file, needs investigation") if deleted_records := transformer.deleted_records: diff --git a/transmogrifier/helpers.py b/transmogrifier/helpers.py index f3f28dd..c749730 100644 --- a/transmogrifier/helpers.py +++ b/transmogrifier/helpers.py @@ -1,18 +1,10 @@ -import json import logging -import os from datetime import datetime -from typing import TYPE_CHECKING, Optional +from typing import Optional -from attrs import asdict from smart_open import open from transmogrifier.config import DATE_FORMATS -from transmogrifier.models import TimdexRecord - -# import XmlTransformer only when type checking to avoid circular dependency -if TYPE_CHECKING: # pragma: no cover - from transmogrifier.sources.transformer import Transformer logger = logging.getLogger(__name__) @@ -154,37 +146,6 @@ def write_deleted_records_to_file(deleted_records: list[str], output_file_path: file.write(f"{record_id}\n") -def write_timdex_records_to_json( - transformer_instance: "Transformer", output_file_path: str -) -> int: - count = 0 - try: - record: TimdexRecord = next(transformer_instance) - except StopIteration: - return count - with open(output_file_path, "w") as file: - file.write("[\n") - while record: - file.write( - json.dumps( - asdict(record, filter=lambda attr, value: value is not None), - indent=2, - ) - ) - count += 1 - if count % int(os.getenv("STATUS_UPDATE_INTERVAL", 1000)) == 0: - logger.info( - "Status update: %s records written to output file so far!", count - ) - try: - record: TimdexRecord = next(transformer_instance) # type: ignore[no-redef] # noqa: E501 - except StopIteration: - break - file.write(",\n") - file.write("\n]") - return count - - class DeletedRecord(Exception): """Exception raised for records with a deleted status. diff --git a/transmogrifier/sources/transformer.py b/transmogrifier/sources/transformer.py index 17f756c..1fc0c5d 100644 --- a/transmogrifier/sources/transformer.py +++ b/transmogrifier/sources/transformer.py @@ -1,11 +1,14 @@ """Transformer module.""" from __future__ import annotations +import json import logging +import os from abc import ABCMeta, abstractmethod from importlib import import_module from typing import Iterator, Optional, TypeAlias, final +from attrs import asdict from bs4 import BeautifulSoup, Tag # Note: the lxml module in defusedxml is deprecated, so we have to use the @@ -69,12 +72,68 @@ def __next__(self) -> TimdexRecord: continue @final - @staticmethod - def get_transformer(source: str) -> type[Transformer]: + def write_timdex_records_to_json(self, output_file: str) -> int: + """ + Write TIMDEX records to JSON file. + + Args: + output_file: The JSON file used for writing TIMDEX records. + + """ + count = 0 + try: + record: TimdexRecord = next(self) + except StopIteration: + return count + with open(output_file, "w") as file: + file.write("[\n") + while record: + file.write( + json.dumps( + asdict(record, filter=lambda attr, value: value is not None), + indent=2, + ) + ) + count += 1 + if count % int(os.getenv("STATUS_UPDATE_INTERVAL", 1000)) == 0: + logger.info( + "Status update: %s records written to output file so far!", + count, + ) + try: + record: TimdexRecord = next(self) # type: ignore[no-redef] # noqa: E501 + except StopIteration: + break + file.write(",\n") + file.write("\n]") + return count + + @final + @classmethod + def load(cls, source: str, source_file: str) -> Transformer: + """ + Instantiate specified transformer class and populate with source records. + + Args: + source: Source repository label. Must match a source key from config.SOURCES. + source_file: A file containing source records to be transformed. + """ + transformer_class = cls.get_transformer(source) + source_records = transformer_class.parse_source_file(source_file) + transformer = transformer_class(source, source_records) + return transformer + + @final + @classmethod + def get_transformer(cls, source: str) -> type[Transformer]: """ Return configured transformer class for a source. Source must be configured with a valid transform class path. + + Args: + source: Source repository label. Must match a source key from config.SOURCES. + """ module_name, class_name = SOURCES[source]["transform-class"].rsplit(".", 1) source_module = import_module(module_name) @@ -82,11 +141,14 @@ def get_transformer(source: str) -> type[Transformer]: @classmethod @abstractmethod - def parse_source_file(cls, input_file: str) -> Iterator[JSON | Tag]: + def parse_source_file(cls, source_file: str) -> Iterator[JSON | Tag]: """ Parse source file and return source records via an iterator. Must be overridden by format subclasses. + + Args: + source_file: A file containing source records to be transformed. """ pass @@ -219,13 +281,16 @@ class XmlTransformer(Transformer): @final @classmethod - def parse_source_file(cls, input_file: str) -> Iterator[Tag]: + def parse_source_file(cls, source_file: str) -> Iterator[Tag]: """ Parse XML file and return source records as bs4 Tags via an iterator. May not be overridden. + + Args: + source_file: A file containing source records to be transformed. """ - with open(input_file, "rb") as file: + with open(source_file, "rb") as file: for _, element in etree.iterparse( file, tag="{*}record",