diff --git a/examples/maps/custom-entrez-2-string.py b/examples/maps/custom-entrez-2-string.py
index de2c07a..c799c58 100644
--- a/examples/maps/custom-entrez-2-string.py
+++ b/examples/maps/custom-entrez-2-string.py
@@ -1,16 +1,7 @@
-from koza.cli_utils import get_koza_app
+from koza.runner import KozaTransform
-source_name = 'custom-map-protein-links-detailed'
-map_name = 'custom-entrez-2-string'
-
-koza_app = get_koza_app(source_name)
-
-row = koza_app.get_row(map_name)
-
-map = koza_app.get_map(map_name)
-
-entry = dict()
-
-entry["entrez"] = row["entrez"]
-
-map[row["STRING"]] = entry
+def transform_record(koza: KozaTransform, record: dict):
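+    # Emit one mapping entry per row: the STRING protein ID together with its entrez gene ID.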
+ koza.write({
+ "STRING": record['STRING'],
+ "entrez": record["entrez"],
+ })
diff --git a/examples/maps/custom-entrez-2-string.yaml b/examples/maps/custom-entrez-2-string.yaml
index eeb9839..a159320 100644
--- a/examples/maps/custom-entrez-2-string.yaml
+++ b/examples/maps/custom-entrez-2-string.yaml
@@ -3,23 +3,25 @@ name: 'custom-entrez-2-string'
metadata:
description: 'Mapping file provided by StringDB that contains entrez to protein ID mappings'
-delimiter: '\t'
-header_delimiter: '/'
+reader:
+ delimiter: '\t'
+ header_prefix: '#'
+ header_delimiter: '/'
-# Assumes that no identifiers are overlapping
-# otherwise these should go into separate configs
-files:
- - './examples/data/entrez-2-string.tsv'
- - './examples/data/additional-entrez-2-string.tsv'
+ # Assumes that no identifiers are overlapping
+ # otherwise these should go into separate configs
+ files:
+ - './examples/data/entrez-2-string.tsv'
+ - './examples/data/additional-entrez-2-string.tsv'
-header: 0
+ header_mode: 0
-columns:
+ columns:
- 'NCBI taxid'
- 'entrez'
- 'STRING'
-key: 'STRING'
-
-values:
- - 'entrez'
+transform:
+ key: 'STRING'
+ values:
+ - 'entrez'
diff --git a/examples/maps/entrez-2-string.yaml b/examples/maps/entrez-2-string.yaml
index b5be71d..6c0b586 100644
--- a/examples/maps/entrez-2-string.yaml
+++ b/examples/maps/entrez-2-string.yaml
@@ -3,23 +3,25 @@ name: 'entrez-2-string'
metadata:
description: 'Mapping file provided by StringDB that contains entrez to protein ID mappings'
-delimiter: '\t'
-header_delimiter: '/'
-header: 0
-comment_char: '#'
+reader:
+ delimiter: '\t'
+ header_delimiter: '/'
+ header_mode: 0
+ header_prefix: '#'
+ comment_char: '#'
-# Assumes that no identifiers are overlapping
-# otherwise these should go into separate configs
-files:
- - './examples/data/entrez-2-string.tsv'
- - './examples/data/additional-entrez-2-string.tsv'
+ # Assumes that no identifiers are overlapping
+ # otherwise these should go into separate configs
+ files:
+ - './examples/data/entrez-2-string.tsv'
+ - './examples/data/additional-entrez-2-string.tsv'
-columns:
- - 'NCBI taxid'
- - 'entrez'
- - 'STRING'
+ columns:
+ - 'NCBI taxid'
+ - 'entrez'
+ - 'STRING'
-key: 'STRING'
-
-values:
- - 'entrez'
+transform:
+ key: 'STRING'
+ values:
+ - 'entrez'
diff --git a/examples/minimal.py b/examples/minimal.py
new file mode 100644
index 0000000..7259161
--- /dev/null
+++ b/examples/minimal.py
@@ -0,0 +1,4 @@
+from koza.runner import KozaTransform
+
+def transform(koza: KozaTransform):
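+    # Minimal whole-source transform: receives the KozaTransform handle and emits nothing.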
+ pass
diff --git a/examples/string-declarative/declarative-protein-links-detailed.py b/examples/string-declarative/declarative-protein-links-detailed.py
index e06967d..d3b0205 100644
--- a/examples/string-declarative/declarative-protein-links-detailed.py
+++ b/examples/string-declarative/declarative-protein-links-detailed.py
@@ -1,24 +1,22 @@
import re
+from typing import Any
import uuid
from biolink_model.datamodel.pydanticmodel_v2 import PairwiseGeneToGeneInteraction, Protein
-from koza.cli_utils import get_koza_app
+from koza.runner import KozaTransform
-koza_app = get_koza_app("declarative-protein-links-detailed")
+def transform_record(koza: KozaTransform, record: dict[str, Any]):
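+    # STRING protein IDs carry a leading taxon prefix (e.g. "9606.ENSP..."); strip it before building the ENSEMBL CURIE.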
+ protein_a = Protein(id="ENSEMBL:" + re.sub(r"\d+\.", "", record["protein1"]))
+ protein_b = Protein(id="ENSEMBL:" + re.sub(r"\d+\.", "", record["protein2"]))
-row = koza_app.get_row()
+ pairwise_gene_to_gene_interaction = PairwiseGeneToGeneInteraction(
+ id="uuid:" + str(uuid.uuid1()),
+ subject=protein_a.id,
+ object=protein_b.id,
+ predicate="biolink:interacts_with",
+ knowledge_level="not_provided",
+ agent_type="not_provided",
+ )
-protein_a = Protein(id="ENSEMBL:" + re.sub(r"\d+\.", "", row["protein1"]))
-protein_b = Protein(id="ENSEMBL:" + re.sub(r"\d+\.", "", row["protein2"]))
-
-pairwise_gene_to_gene_interaction = PairwiseGeneToGeneInteraction(
- id="uuid:" + str(uuid.uuid1()),
- subject=protein_a.id,
- object=protein_b.id,
- predicate="biolink:interacts_with",
- knowledge_level="not_provided",
- agent_type="not_provided",
-)
-
-koza_app.write(protein_a, protein_b, pairwise_gene_to_gene_interaction)
+ koza.write(protein_a, protein_b, pairwise_gene_to_gene_interaction)
diff --git a/examples/string-declarative/declarative-protein-links-detailed.yaml b/examples/string-declarative/declarative-protein-links-detailed.yaml
index aa0345c..a1962f6 100644
--- a/examples/string-declarative/declarative-protein-links-detailed.yaml
+++ b/examples/string-declarative/declarative-protein-links-detailed.yaml
@@ -1,49 +1,51 @@
name: 'declarative-protein-links-detailed'
-delimiter: ' '
-
-files:
- - './examples/data/string.tsv'
- - './examples/data/string2.tsv'
-
metadata:
ingest_title: 'String DB'
ingest_url: 'https://string-db.org'
description: 'STRING: functional protein association networks'
rights: 'https://string-db.org/cgi/access.pl?footer_active_subpage=licensing'
-global_table: './examples/translation_table.yaml'
-
-columns:
- - 'protein1'
- - 'protein2'
- - 'neighborhood'
- - 'fusion'
- - 'cooccurence'
- - 'coexpression'
- - 'experimental'
- - 'database'
- - 'textmining'
- - 'combined_score' : 'int'
-
-filters:
- - inclusion: 'include'
- column: 'combined_score'
- filter_code: 'lt'
- value: 700
-
-transform_mode: 'flat'
-
-node_properties:
- - 'id'
- - 'category'
- - 'provided_by'
-
-edge_properties:
- - 'id'
- - 'subject'
- - 'predicate'
- - 'object'
- - 'category'
- - 'relation'
- - 'provided_by'
\ No newline at end of file
+reader:
+ format: csv
+
+ delimiter: ' '
+
+ files:
+ - './examples/data/string.tsv'
+ - './examples/data/string2.tsv'
+
+ columns:
+ - 'protein1'
+ - 'protein2'
+ - 'neighborhood'
+ - 'fusion'
+ - 'cooccurence'
+ - 'coexpression'
+ - 'experimental'
+ - 'database'
+ - 'textmining'
+ - 'combined_score' : 'int'
+
+
+transform:
+ filters:
+ - inclusion: 'include'
+ column: 'combined_score'
+ filter_code: 'lt'
+ value: 700
+
+writer:
+ node_properties:
+ - 'id'
+ - 'category'
+ - 'provided_by'
+
+ edge_properties:
+ - 'id'
+ - 'subject'
+ - 'predicate'
+ - 'object'
+ - 'category'
+ - 'relation'
+ - 'provided_by'
diff --git a/examples/string-w-custom-map/custom-map-protein-links-detailed.py b/examples/string-w-custom-map/custom-map-protein-links-detailed.py
index 1a3ef4f..b239411 100644
--- a/examples/string-w-custom-map/custom-map-protein-links-detailed.py
+++ b/examples/string-w-custom-map/custom-map-protein-links-detailed.py
@@ -2,23 +2,23 @@
from biolink_model.datamodel.pydanticmodel_v2 import Gene, PairwiseGeneToGeneInteraction
-from koza.cli_utils import get_koza_app
+from koza.runner import KozaTransform
-source_name = "custom-map-protein-links-detailed"
-koza_app = get_koza_app(source_name)
-row = koza_app.get_row()
-entrez_2_string = koza_app.get_map("custom-entrez-2-string")
+def transform_record(koza: KozaTransform, record: dict):
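+    # Resolve each STRING protein ID to its entrez gene ID via the custom-entrez-2-string mapping configured in the transform yaml.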
+ a = record["protein1"]
+ b = record["protein2"]
+ mapped_a = koza.lookup(a, "entrez")
+ mapped_b = koza.lookup(b, "entrez")
+ gene_a = Gene(id="NCBIGene:" + mapped_a)
+ gene_b = Gene(id="NCBIGene:" + mapped_b)
-gene_a = Gene(id="NCBIGene:" + entrez_2_string[row["protein1"]]["entrez"])
-gene_b = Gene(id="NCBIGene:" + entrez_2_string[row["protein2"]]["entrez"])
+ pairwise_gene_to_gene_interaction = PairwiseGeneToGeneInteraction(
+ id="uuid:" + str(uuid.uuid1()),
+ subject=gene_a.id,
+ object=gene_b.id,
+ predicate="biolink:interacts_with",
+ knowledge_level="not_provided",
+ agent_type="not_provided",
+ )
-pairwise_gene_to_gene_interaction = PairwiseGeneToGeneInteraction(
- id="uuid:" + str(uuid.uuid1()),
- subject=gene_a.id,
- object=gene_b.id,
- predicate="biolink:interacts_with",
- knowledge_level="not_provided",
- agent_type="not_provided",
-)
-
-koza_app.write(gene_a, gene_b, pairwise_gene_to_gene_interaction)
+ koza.write(gene_a, gene_b, pairwise_gene_to_gene_interaction)
diff --git a/examples/string-w-custom-map/custom-map-protein-links-detailed.yaml b/examples/string-w-custom-map/custom-map-protein-links-detailed.yaml
index 6bf01cb..34863df 100644
--- a/examples/string-w-custom-map/custom-map-protein-links-detailed.yaml
+++ b/examples/string-w-custom-map/custom-map-protein-links-detailed.yaml
@@ -1,46 +1,47 @@
name: 'custom-map-protein-links-detailed'
-delimiter: ' '
-
-files:
- - './examples/data/string.tsv'
- - './examples/data/string2.tsv'
-
metadata: !include './examples/string-w-custom-map/metadata.yaml'
-columns:
- - 'protein1'
- - 'protein2'
- - 'neighborhood'
- - 'fusion'
- - 'cooccurence'
- - 'coexpression'
- - 'experimental'
- - 'database'
- - 'textmining'
- - 'combined_score' : 'int'
-
-filters:
- - inclusion: 'include'
- column: 'combined_score'
- filter_code: 'lt'
- value: 700
-
-depends_on:
- - 'examples/maps/custom-entrez-2-string.yaml'
-
-transform_mode: 'flat'
-
-node_properties:
- - 'id'
- - 'category'
- - 'provided_by'
-
-edge_properties:
- - 'id'
- - 'subject'
- - 'predicate'
- - 'object'
- - 'category'
- - 'relation'
- - 'provided_by'
\ No newline at end of file
+reader:
+ delimiter: ' '
+
+ files:
+ - './examples/data/string.tsv'
+ - './examples/data/string2.tsv'
+
+ columns:
+ - 'protein1'
+ - 'protein2'
+ - 'neighborhood'
+ - 'fusion'
+ - 'cooccurence'
+ - 'coexpression'
+ - 'experimental'
+ - 'database'
+ - 'textmining'
+ - 'combined_score' : 'int'
+
+transform:
+ filters:
+ - inclusion: 'include'
+ column: 'combined_score'
+ filter_code: 'lt'
+ value: 700
+
+ mappings:
+ - 'examples/maps/custom-entrez-2-string.yaml'
+
+writer:
+ node_properties:
+ - 'id'
+ - 'category'
+ - 'provided_by'
+
+ edge_properties:
+ - 'id'
+ - 'subject'
+ - 'predicate'
+ - 'object'
+ - 'category'
+ - 'relation'
+ - 'provided_by'
diff --git a/examples/string-w-map/map-protein-links-detailed.py b/examples/string-w-map/map-protein-links-detailed.py
index 95b87b8..b239411 100644
--- a/examples/string-w-map/map-protein-links-detailed.py
+++ b/examples/string-w-map/map-protein-links-detailed.py
@@ -2,29 +2,23 @@
from biolink_model.datamodel.pydanticmodel_v2 import Gene, PairwiseGeneToGeneInteraction
-from koza.cli_utils import get_koza_app
-
-source_name = "map-protein-links-detailed"
-map_name = "entrez-2-string"
-
-koza_app = get_koza_app(source_name)
-row = koza_app.get_row()
-koza_map = koza_app.get_map(map_name)
-
-from loguru import logger
-
-logger.info(koza_map)
-
-gene_a = Gene(id="NCBIGene:" + koza_map[row["protein1"]]["entrez"])
-gene_b = Gene(id="NCBIGene:" + koza_map[row["protein2"]]["entrez"])
-
-pairwise_gene_to_gene_interaction = PairwiseGeneToGeneInteraction(
- id="uuid:" + str(uuid.uuid1()),
- subject=gene_a.id,
- object=gene_b.id,
- predicate="biolink:interacts_with",
- knowledge_level="not_provided",
- agent_type="not_provided",
-)
-
-koza_app.write(gene_a, gene_b, pairwise_gene_to_gene_interaction)
+from koza.runner import KozaTransform
+
+def transform_record(koza: KozaTransform, record: dict):
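+    # Resolve each STRING protein ID to its entrez gene ID via the entrez-2-string mapping configured in the transform yaml.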
+ a = record["protein1"]
+ b = record["protein2"]
+ mapped_a = koza.lookup(a, "entrez")
+ mapped_b = koza.lookup(b, "entrez")
+ gene_a = Gene(id="NCBIGene:" + mapped_a)
+ gene_b = Gene(id="NCBIGene:" + mapped_b)
+
+ pairwise_gene_to_gene_interaction = PairwiseGeneToGeneInteraction(
+ id="uuid:" + str(uuid.uuid1()),
+ subject=gene_a.id,
+ object=gene_b.id,
+ predicate="biolink:interacts_with",
+ knowledge_level="not_provided",
+ agent_type="not_provided",
+ )
+
+ koza.write(gene_a, gene_b, pairwise_gene_to_gene_interaction)
diff --git a/examples/string-w-map/map-protein-links-detailed.yaml b/examples/string-w-map/map-protein-links-detailed.yaml
index 53dab1c..4d84d14 100644
--- a/examples/string-w-map/map-protein-links-detailed.yaml
+++ b/examples/string-w-map/map-protein-links-detailed.yaml
@@ -1,46 +1,46 @@
name: 'map-protein-links-detailed'
-delimiter: ' '
-
-files:
- - './examples/data/string.tsv'
- - './examples/data/string2.tsv'
-
metadata: !include './examples/string-w-map/metadata.yaml'
-columns:
- - 'protein1'
- - 'protein2'
- - 'neighborhood'
- - 'fusion'
- - 'cooccurence'
- - 'coexpression'
- - 'experimental'
- - 'database'
- - 'textmining'
- - 'combined_score' : 'int'
-
-filters:
- - inclusion: 'include'
- column: 'combined_score'
- filter_code: 'lt'
- value: 700
-
-depends_on:
- - './examples/maps/entrez-2-string.yaml'
-
-transform_mode: 'flat'
-
-node_properties:
- - 'id'
- - 'category'
- - 'provided_by'
-
-edge_properties:
- - 'id'
- - 'subject'
- - 'predicate'
- - 'object'
- - 'category'
- - 'relation'
- - 'provided_by'
\ No newline at end of file
+reader:
+ format: csv
+ delimiter: ' '
+ files:
+ - './examples/data/string.tsv'
+ - './examples/data/string2.tsv'
+
+ columns:
+ - 'protein1'
+ - 'protein2'
+ - 'neighborhood'
+ - 'fusion'
+ - 'cooccurence'
+ - 'coexpression'
+ - 'experimental'
+ - 'database'
+ - 'textmining'
+ - 'combined_score' : 'int'
+
+transform:
+ filters:
+ - inclusion: 'include'
+ column: 'combined_score'
+ filter_code: 'lt'
+ value: 700
+ mappings:
+ - './examples/maps/entrez-2-string.yaml'
+
+writer:
+ node_properties:
+ - 'id'
+ - 'category'
+ - 'provided_by'
+
+ edge_properties:
+ - 'id'
+ - 'subject'
+ - 'predicate'
+ - 'object'
+ - 'category'
+ - 'relation'
+ - 'provided_by'
diff --git a/examples/string/protein-links-detailed.py b/examples/string/protein-links-detailed.py
index 9ce03ae..3539f33 100644
--- a/examples/string/protein-links-detailed.py
+++ b/examples/string/protein-links-detailed.py
@@ -1,23 +1,23 @@
import re
import uuid
-from biolink_model.datamodel.pydanticmodel_v2 import PairwiseGeneToGeneInteraction, Protein
+from biolink_model.datamodel.pydanticmodel_v2 import (
+ PairwiseGeneToGeneInteraction, Protein)
+from koza.runner import KozaTransform
-from koza.cli_utils import get_koza_app
-koza_app = get_koza_app('protein-links-detailed')
+def transform(koza: KozaTransform):
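+    # Whole-source transform: iterate over every row supplied by the reader via koza.data.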
+ for row in koza.data:
+ protein_a = Protein(id='ENSEMBL:' + re.sub(r'\d+\.', '', row['protein1']))
+ protein_b = Protein(id='ENSEMBL:' + re.sub(r'\d+\.', '', row['protein2']))
-for row in koza_app.source:
- protein_a = Protein(id='ENSEMBL:' + re.sub(r'\d+\.', '', row['protein1']))
- protein_b = Protein(id='ENSEMBL:' + re.sub(r'\d+\.', '', row['protein2']))
+ pairwise_gene_to_gene_interaction = PairwiseGeneToGeneInteraction(
+ id="uuid:" + str(uuid.uuid1()),
+ subject=protein_a.id,
+ object=protein_b.id,
+ predicate="biolink:interacts_with",
+ knowledge_level="not_provided",
+ agent_type="not_provided",
+ )
- pairwise_gene_to_gene_interaction = PairwiseGeneToGeneInteraction(
- id="uuid:" + str(uuid.uuid1()),
- subject=protein_a.id,
- object=protein_b.id,
- predicate="biolink:interacts_with",
- knowledge_level="not_provided",
- agent_type="not_provided",
- )
-
- koza_app.write(protein_a, protein_b, pairwise_gene_to_gene_interaction)
+ koza.write(protein_a, protein_b, pairwise_gene_to_gene_interaction)
diff --git a/examples/string/protein-links-detailed.yaml b/examples/string/protein-links-detailed.yaml
index d41cb4a..ddbb1fc 100644
--- a/examples/string/protein-links-detailed.yaml
+++ b/examples/string/protein-links-detailed.yaml
@@ -1,35 +1,33 @@
name: 'protein-links-detailed'
-
-delimiter: ' '
-
-files:
- - './examples/data/string.tsv'
- - './examples/data/string2.tsv'
-
metadata: !include './examples/string/metadata.yaml'
-columns: !include './examples/standards/string.yaml'
-
-filters:
- - inclusion: 'include'
- column: 'combined_score'
- filter_code: 'lt'
- value: 700
-
-transform_code: './examples/string/protein-links-detailed.py'
-
-transform_mode: 'loop'
-
-node_properties:
- - 'id'
- - 'category'
- - 'provided_by'
-
-edge_properties:
- - 'id'
- - 'subject'
- - 'predicate'
- - 'object'
- - 'category'
- - 'relation'
- - 'provided_by'
+reader:
+ format: csv
+ files:
+ - './examples/data/string.tsv'
+ - './examples/data/string2.tsv'
+ delimiter: ' '
+ columns: !include './examples/standards/string.yaml'
+
+transform:
+ code: './examples/string/protein-links-detailed.py'
+ filters:
+ - inclusion: 'include'
+ column: 'combined_score'
+ filter_code: 'lt'
+ value: 700
+
+writer:
+ node_properties:
+ - 'id'
+ - 'category'
+ - 'provided_by'
+
+ edge_properties:
+ - 'id'
+ - 'subject'
+ - 'predicate'
+ - 'object'
+ - 'category'
+ - 'relation'
+ - 'provided_by'
diff --git a/pyproject.toml b/pyproject.toml
index 82843c9..ec9d895 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,13 +25,13 @@ pyyaml = ">=5.0.0"
requests = "^2.24.0"
sssom = ">=0.4"
typer = ">=0.12.3"
+mergedeep = "1.3.4"
[tool.poetry.dev-dependencies]
black = "^24.4"
ruff = "*"
pytest = ">=6.0.0"
biolink-model = ">=4.2"
-dask = ">=2022.5.2"
mkdocs = ">=1.4"
mkdocs-material = ">=9.5"
mkdocstrings = {extras = ["python"], version = ">=0.22.0"}
diff --git a/src/koza/__init__.py b/src/koza/__init__.py
index bf271f0..71322ad 100644
--- a/src/koza/__init__.py
+++ b/src/koza/__init__.py
@@ -1,3 +1,12 @@
from importlib import metadata
+from koza.model.koza import KozaConfig
+from koza.runner import KozaRunner, KozaTransform
+
__version__ = metadata.version("koza")
+
+__all__ = (
+ 'KozaConfig',
+ 'KozaRunner',
+ 'KozaTransform',
+)
diff --git a/src/koza/io/reader/csv_reader.py b/src/koza/io/reader/csv_reader.py
index 18bbd81..facb200 100644
--- a/src/koza/io/reader/csv_reader.py
+++ b/src/koza/io/reader/csv_reader.py
@@ -1,7 +1,7 @@
from csv import reader
-from typing import IO, Any, Dict, Iterator, List, Union
+from typing import IO, Any, Callable, Dict, List
-from koza.model.config.source_config import FieldType, HeaderMode
+from koza.model.reader import FieldType, CSVReaderConfig, HeaderMode
# from koza.utils.log_utils import get_logger
# logger = get_logger(__name__)
@@ -9,7 +9,7 @@
# logger = logging.getLogger(__name__)
from loguru import logger
-FIELDTYPE_CLASS = {
+FIELDTYPE_CLASS: Dict[FieldType, Callable[[str], Any]] = {
FieldType.str: str,
FieldType.int: int,
FieldType.float: float,
@@ -43,50 +43,23 @@ class CSVReader:
def __init__(
self,
io_str: IO[str],
- field_type_map: Dict[str, FieldType] = None,
- delimiter: str = ",",
- header: Union[int, HeaderMode] = HeaderMode.infer,
- header_delimiter: str = None,
- header_prefix: str = None,
- dialect: str = "excel",
- skip_blank_lines: bool = True,
- name: str = "csv file",
- comment_char: str = "#",
- row_limit: int = None,
+ config: CSVReaderConfig,
+ row_limit: int = 0,
*args,
**kwargs,
):
"""
:param io_str: Any IO stream that yields a string
See https://docs.python.org/3/library/io.html#io.IOBase
- :param field_type_map: A dictionary of field names and their type (using the FieldType enum)
- :param delimiter: Field delimiter (eg. '\t' ',' ' ')
- :param header: 0 based index of the file that contains the header,
- or header mode 'infer'|'none' ( default='infer' )
- if 'infer' will use the first non-empty and uncommented line
- if 'none' will use the user supplied columns in field_type_map keys,
- if field_type_map is None this will raise a ValueError
-
- :param header_delimiter: delimiter for the header row, default = self.delimiter
- :param header_prefix: prefix for the header row, default = None
- :param dialect: csv dialect, default=excel
- :param skip_blank_lines: true to skip blank lines, false to insert NaN for blank lines,
- :param name: filename or alias
- :param comment_char: string representing a commented line, eg # or !!
+        :param config: A configuration for the CSV reader. See model/reader.py
:param row_limit: int number of lines to process
:param args: additional args to pass to csv.reader
:param kwargs: additional kwargs to pass to csv.reader
"""
self.io_str = io_str
- self.field_type_map = field_type_map
- self.dialect = dialect
- self.header = header
- self.header_delimiter = header_delimiter if header_delimiter else delimiter
- self.header_prefix = header_prefix
- self.skip_blank_lines = skip_blank_lines
- self.name = name
- self.comment_char = comment_char
+ self.config = config
self.row_limit = row_limit
+ self.field_type_map = config.field_type_map
# used by _set_header
self.line_num = 0
@@ -96,148 +69,189 @@ def __init__(
self._header = None
- if delimiter == '\\s':
+ delimiter = config.delimiter
+
+ if config.delimiter == '\\s':
delimiter = ' '
- kwargs['dialect'] = dialect
- kwargs['delimiter'] = delimiter
- self.reader = reader(io_str, *args, **kwargs)
+ self.csv_args = args
+ self.csv_kwargs = kwargs
+
+ self.csv_kwargs['dialect'] = config.dialect
+ self.csv_kwargs['delimiter'] = delimiter
+ self.csv_reader = reader(io_str, *self.csv_args, **self.csv_kwargs)
- def __iter__(self) -> Iterator:
- return self
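+    # The header is parsed lazily on first access and validated against any configured columns.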
+ @property
+ def header(self):
+ if self._header is None:
+ header = self._consume_header()
+ self._ensure_field_type_map(header)
+ self._compare_headers_to_supplied_columns(header)
+ self._header = header
+ return self._header
- def __next__(self) -> Dict[str, Any]:
- if not self._header:
- self._set_header()
+ def __iter__(self):
+ header = self.header
+ item_ct = 0
+ comment_char = self.config.comment_char
- try:
- if self.line_count == self.row_limit:
+ if self.field_type_map is None:
+ raise ValueError("Field type map not set on CSV source")
+
+ for row in self.csv_reader:
+ if self.row_limit and item_ct >= self.row_limit:
logger.debug("Row limit reached")
- raise StopIteration
- else:
- row = next(self.reader)
- self.line_count += 1
- except StopIteration:
- logger.info(f"Finished processing {self.line_num} rows for {self.name} from file {self.io_str.name}")
- raise StopIteration
- self.line_num = self.reader.line_num
-
- # skip blank lines
- if self.skip_blank_lines:
- while not row:
- row = next(self.reader)
- else:
- row = ['NaN' for _ in range(len(self._header))]
-
- # skip commented lines (this is for footers)
- if self.comment_char is not None:
- while row[0].startswith(self.comment_char):
- row = next(self.reader)
-
- # Check row length discrepancies for each row
- fields_len = len(self._header)
- row_len = len(row)
- stripped_row = [val.strip() for val in row]
-
- # if we've made it here we can convert a row to a dict
- field_map = dict(zip(self._header, stripped_row))
-
- if fields_len > row_len:
- raise ValueError(f"CSV file {self.name} has {fields_len - row_len} fewer columns at {self.reader.line_num}")
-
- elif fields_len < row_len:
- logger.warning(f"CSV file {self.name} has {row_len - fields_len} extra columns at {self.reader.line_num}")
- # # Not sure if this would serve a purpose:
- # if 'extra_cols' not in self.field_type_map:
- # # Create a type map for extra columns
- # self.field_type_map['extra_cols'] = FieldType.str
- # field_map['extra_cols'] = row[fields_len:]
-
- typed_field_map = {}
-
- for field, field_value in field_map.items():
- # Take the value and coerce it using self.field_type_map (field: FieldType)
- # FIELD_TYPE is map of the field_type enum to the python
- # to built-in type or custom extras defined in the source config
- try:
- typed_field_map[field] = FIELDTYPE_CLASS[self.field_type_map[field]](field_value)
- except KeyError as key_error:
- logger.warning(f"Field {field} not found in field_type_map ({key_error})")
-
- return typed_field_map
-
- def _set_header(self):
- if isinstance(self.header, int):
- while self.line_num < self.header:
- next(self.reader)
- self.line_num = self.reader.line_num
- self._header = self._parse_header_line()
-
- if self.field_type_map:
- self._compare_headers_to_supplied_columns()
- else:
- self.field_type_map = {field: FieldType.str for field in self._header}
-
- elif self.header == 'infer':
- self._header = self._parse_header_line(skip_blank_or_commented_lines=True)
- logger.debug(f"headers for {self.name} parsed as {self._header}")
- if self.field_type_map:
- self._compare_headers_to_supplied_columns()
- else:
- self.field_type_map = {field: FieldType.str for field in self._header}
-
- elif self.header == 'none':
- if self.field_type_map:
- self._header = list(self.field_type_map.keys())
- else:
+ return
+
+ if not row:
+ if self.config.skip_blank_lines:
+ continue
+ else:
+ row = ['NaN' for _ in range(len(header))]
+
+ elif comment_char and row[0].startswith(comment_char):
+ continue
+
+ row = [val.strip() for val in row]
+ item = dict(zip(header, row))
+
+            if len(row) > len(header):
+                num_extra_fields = len(row) - len(header)
+                logger.warning(
+                    f"CSV file {self.io_str.name} has {num_extra_fields} extra columns at {self.csv_reader.line_num}"
+                )
+
+            if len(header) > len(row):
+                num_missing_columns = len(header) - len(row)
+                raise ValueError(
+                    f"CSV file {self.io_str.name} is missing {num_missing_columns} "
+                    f"column(s) at {self.csv_reader.line_num}"
+                )
+
+ typed_item: dict[str, Any] = {}
+
+ for k, v in item.items():
+ field_type = self.field_type_map.get(k, None)
+ if field_type is None:
+ # FIXME: is this the right behavior? Or should we raise an error?
+ # raise ValueError(f"No field type found for field {k}")
+ field_type = FieldType.str
+
+ # By default, use `str` as a converter (essentially a noop)
+ converter = FIELDTYPE_CLASS.get(field_type, str)
+
+ typed_item[k] = converter(v)
+
+ item_ct += 1
+ yield typed_item
+
+        logger.info(f"Finished processing {item_ct} rows from file {self.io_str.name}")
+
+ def _consume_header(self):
+ if self.csv_reader.line_num > 0:
+ raise RuntimeError("Can only set header at beginning of file.")
+
+ if self.config.header_mode == HeaderMode.none:
+ if self.config.field_type_map is None:
raise ValueError(
- "there is no header and columns have not been supplied\n"
- "configure the 'columns' property in the source yaml"
+ "Header mode was set to 'none', but no columns were supplied.\n"
+ "Configure the 'columns' property in the transform yaml."
)
+ return list(self.config.field_type_map.keys())
+
+ if self.config.header_mode == HeaderMode.infer:
+ # logger.debug(f"headers for {self.name} parsed as {self._header}")
+ return self._parse_header_line(skip_blank_or_commented_lines=True)
+ elif isinstance(self.config.header_mode, int):
+ while self.csv_reader.line_num < self.config.header_mode:
+ next(self.csv_reader)
+ return self._parse_header_line()
+ else:
+ raise ValueError(f"Invalid header mode given: {self.config.header_mode}.")
def _parse_header_line(self, skip_blank_or_commented_lines: bool = False) -> List[str]:
"""
Parse the header line and return a list of headers
"""
- fieldnames = next(reader(self.io_str, **{'delimiter': self.header_delimiter, 'dialect': self.dialect}))
- if self.header_prefix and fieldnames[0].startswith(self.header_prefix):
- fieldnames[0] = fieldnames[0].lstrip(self.header_prefix)
+ header_prefix = self.config.header_prefix
+ comment_char = self.config.comment_char
+
+ csv_reader = self.csv_reader
+
+ # If the header delimiter is explicitly set create a new CSVReader using that one.
+        # If the header delimiter is explicitly set, create a separate csv reader for the header line using that delimiter.
+ kwargs = self.csv_kwargs | { "delimiter": self.config.header_delimiter }
+ csv_reader = reader(self.io_str, *self.csv_args, **kwargs)
+
+ headers = next(csv_reader)
+
+ # If a header_prefix was defined, remove that string from the first record in the first row.
+ # For example, given the header_prefix of "#" and an initial CSV row of:
+ #
+ # #ID,LABEL,DESCRIPTION
+ #
+ # The headers would be ["ID", "LABEL", "DESCRIPTION"].
+ #
+ # This is run before skipping commented lines since a header prefix may be "#", which is the default comment
+ # character.
+ if headers and header_prefix:
+ headers[0] = headers[0].lstrip(header_prefix)
+
if skip_blank_or_commented_lines:
- # there has to be a cleaner way to do this
- while not fieldnames or (self.comment_char is not None and fieldnames[0].startswith(self.comment_char)):
- fieldnames = next(reader(self.io_str, **{'delimiter': self.header_delimiter, 'dialect': self.dialect}))
- fieldnames[0] = fieldnames[0].lstrip(self.comment_char)
- return [f.strip() for f in fieldnames]
+ while True:
+ # Continue if the line is empty
+ if not headers:
+ headers = next(csv_reader)
+ continue
+
+ # Continue if the line starts with a comment character
+ if comment_char and headers[0].startswith(comment_char):
+ headers = next(csv_reader)
+ continue
+
+ break
+
+ return [field.strip() for field in headers]
- def _compare_headers_to_supplied_columns(self):
+ def _ensure_field_type_map(self, header: list[str]):
+        # The field type map is either set explicitly or derived from config.columns. If
+        # neither of those is set, derive the field type map from the parsed header.
+ if self.field_type_map is None:
+ self.field_type_map = {
+ key: FieldType.str
+ for key in header
+ }
+
+
+ def _compare_headers_to_supplied_columns(self, header: list[str]):
"""
Compares headers to supplied columns
:return:
"""
+ if self.field_type_map is None:
+ raise ValueError("No field type map set for CSV reader")
+
configured_fields = list(self.field_type_map.keys())
- if set(configured_fields) > set(self._header):
+ if set(configured_fields) > set(header):
raise ValueError(
- f"Configured columns missing in source file {self.name}\n"
- f"\t{set(configured_fields) - set(self._header)}"
+ f"Configured columns missing in source file {self.io_str.name}\n"
+ f"\t{set(configured_fields) - set(header)}"
)
- if set(self._header) > set(configured_fields):
+ if set(header) > set(configured_fields):
logger.warning(
- f"Additional column(s) in source file {self.name}\n"
- f"\t{set(self._header) - set(configured_fields)}\n"
+ f"Additional column(s) in source file {self.io_str.name}\n"
+ f"\t{set(header) - set(configured_fields)}\n"
f"\tChecking if new column(s) inserted at end of the row"
)
- # add to type map
- for new_fields in set(self._header) - set(configured_fields):
- self.field_type_map[new_fields] = FieldType.str
# Check if the additional columns are appended
# not sure if this would useful or just noise
- if self._header[: len(configured_fields)] != configured_fields:
+ if header[: len(configured_fields)] != configured_fields:
logger.warning(
f"Additional columns located within configured fields\n"
f"\tgiven: {configured_fields}\n"
- f"\tfound: {self._header}\n"
- f"\tadditional columns: {set(self._header) - set(configured_fields)}"
+ f"\tfound: {header}\n"
+ f"\tadditional columns: {set(header) - set(configured_fields)}"
)
diff --git a/src/koza/io/reader/json_reader.py b/src/koza/io/reader/json_reader.py
index 4b6e8da..5839c4d 100644
--- a/src/koza/io/reader/json_reader.py
+++ b/src/koza/io/reader/json_reader.py
@@ -1,16 +1,11 @@
import json
-import yaml
-from typing import IO, Any, Dict, Iterator, List, Union
-
-# from xmlrpc.client import Boolean
+from typing import IO, Any, Dict, Generator, List, Union
+import yaml
from koza.io.utils import check_data
+from koza.model.reader import JSONReaderConfig, YAMLReaderConfig
-# from koza.utils.log_utils import get_logger
-# logger = get_logger(__name__)
-# import logging
-# logger = logging.getLogger(__name__)
-from loguru import logger
+# FIXME: Add back logging as part of progress
class JSONReader:
@@ -21,75 +16,49 @@ class JSONReader:
def __init__(
self,
io_str: IO[str],
- required_properties: List[str] = None,
- json_path: List[Union[str, int]] = None,
- name: str = 'json file',
- is_yaml: bool = False,
- row_limit: int = None,
+ config: Union[JSONReaderConfig, YAMLReaderConfig],
+ row_limit: int = 0,
):
"""
:param io_str: Any IO stream that yields a string
See https://docs.python.org/3/library/io.html#io.IOBase
- :param required_properties: required top level properties
- :param row_limit: integer number of non-header rows to process
- :param iterate_over: todo
- :param name: todo
+ :param config: The JSON or YAML reader configuration
+ :param row_limit: The number of lines to be read. No limit if 0.
"""
self.io_str = io_str
- self.required_properties = required_properties
- self.json_path = json_path
- self.name = name
+ self.config = config
+ self.row_limit = row_limit
- if self.json_path:
- if is_yaml:
- self.json_obj = yaml.safe_load(self.io_str)
- else:
- self.json_obj = json.load(self.io_str)
- for path in self.json_path:
- self.json_obj = self.json_obj[path]
+ if isinstance(config, YAMLReaderConfig):
+ json_obj = yaml.safe_load(self.io_str)
else:
- if is_yaml:
- self.json_obj = yaml.safe_load(self.io_str)
- else:
- self.json_obj = json.load(self.io_str)
+ json_obj = json.load(self.io_str)
- if isinstance(self.json_obj, list):
- self._len = len(self.json_obj)
- self._line_num = 0
- else:
- self.json_obj = [self.json_obj]
- self._len = 0
- self._line_num = 0
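+        # Descend into the document along the configured json_path before iterating.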
+ if config.json_path:
+ for path in config.json_path:
+ json_obj = json_obj[path]
- if row_limit:
- self._line_limit = row_limit
+ if isinstance(json_obj, list):
+ self.json_obj: List[Any] = json_obj
else:
- self._line_limit = self._len
-
- def __iter__(self) -> Iterator:
- return self
-
- def __next__(self) -> Dict[str, Any]:
- if self._line_num == self._line_limit:
- logger.info(f"Finished processing {self._line_num} rows for {self.name} from file {self.io_str.name}")
- raise StopIteration
+ self.json_obj = [json_obj]
- next_obj = self.json_obj[self._line_num]
+ def __iter__(self) -> Generator[Dict[str, Any], None, None]:
+ for i, item in enumerate(self.json_obj):
+ if self.row_limit and i >= self.row_limit:
+ return
- self._line_num += 1
+ if not isinstance(item, dict):
+                raise ValueError(f"Expected a JSON object but found {type(item).__name__} in {self.io_str.name}")
- # Check that required properties exist in row
- if self.required_properties:
- properties = []
- for prop in self.required_properties:
- new_prop = check_data(next_obj, prop)
- properties.append(new_prop)
+ if self.config.required_properties:
+ missing_properties = [prop for prop in self.config.required_properties if not check_data(item, prop)]
- if False in properties:
- raise ValueError(
- f"Required properties defined for {self.name} are missing from {self.io_str.name}\n"
- f"Missing properties: {set(self.required_properties) - set(next_obj.keys())}\n"
- f"Row: {next_obj}"
- )
+ if missing_properties:
+ raise ValueError(
+ f"Required properties are missing from {self.io_str.name}\n"
+ f"Missing properties: {missing_properties}\n"
+ f"Row: {item}"
+ )
- return next_obj
+ yield item
diff --git a/src/koza/io/reader/jsonl_reader.py b/src/koza/io/reader/jsonl_reader.py
index 7899e5b..9fad406 100644
--- a/src/koza/io/reader/jsonl_reader.py
+++ b/src/koza/io/reader/jsonl_reader.py
@@ -1,13 +1,10 @@
import json
-from typing import IO, Any, Dict, Iterator, List
+from typing import IO
from koza.io.utils import check_data
+from koza.model.reader import JSONLReaderConfig
-# from koza.utils.log_utils import get_logger
-# logger = get_logger(__name__)
-# import logging
-# logger = logging.getLogger(__name__)
-from loguru import logger
+# FIXME: Add back logging as part of progress
class JSONLReader:
@@ -19,49 +16,33 @@ class JSONLReader:
def __init__(
self,
io_str: IO[str],
- required_properties: List[str] = None,
- name: str = 'jsonl file',
- row_limit: int = None,
+ config: JSONLReaderConfig,
+ row_limit: int = 0,
):
"""
:param io_str: Any IO stream that yields a string
See https://docs.python.org/3/library/io.html#io.IOBase
- :param required_properties: List of required top level properties
- :param name: todo
+ :param config: The JSONL reader configuration
+ :param row_limit: The number of lines to be read. No limit if 0.
"""
self.io_str = io_str
- self.required_properties = required_properties
- self.line_num = 0
- self.name = name
- self.line_limit = row_limit
-
- def __iter__(self) -> Iterator:
- return self
-
- def __next__(self) -> Dict[str, Any]:
- next_line = self.io_str.readline()
- if not next_line:
- logger.info(f"Finished processing {self.line_num} lines for {self.name} from {self.io_str.name}")
- raise StopIteration
- self.line_num += 1
- if self.line_limit:
- if self.line_num == self.line_limit:
- raise StopIteration
-
- json_obj = json.loads(next_line)
-
- # Check that required properties exist in row
- if self.required_properties:
- properties = []
- for prop in self.required_properties:
- new_prop = check_data(json_obj, prop)
- properties.append(new_prop)
-
- if False in properties:
- raise ValueError(
- f"Required properties defined for {self.name} are missing from {self.io_str.name}\n"
- f"Missing properties: {set(self.required_properties) - set(json_obj.keys())}\n"
- f"Row: {json_obj}"
- )
-
- return json_obj
+ self.config = config
+ self.row_limit = row_limit
+
+ def __iter__(self):
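+        # Stream the file line by line, parsing each line as a JSON object; stop after row_limit lines (0 means no limit).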
+ for i, line in enumerate(self.io_str):
+            if self.row_limit and i >= self.row_limit:
+ return
+
+ item = json.loads(line)
+ if self.config.required_properties:
+ missing_properties = [prop for prop in self.config.required_properties if not check_data(item, prop)]
+
+ if missing_properties:
+ raise ValueError(
+ f"Required properties are missing from {self.io_str.name}\n"
+ f"Missing properties: {missing_properties}\n"
+ f"Row: {item}"
+ )
+
+ yield item
diff --git a/src/koza/io/utils.py b/src/koza/io/utils.py
index c1b4455..2d78f57 100644
--- a/src/koza/io/utils.py
+++ b/src/koza/io/utils.py
@@ -2,23 +2,48 @@
"""
Set of functions to manage input and output
"""
+import dataclasses
import gzip
+import tarfile
import tempfile
from io import TextIOWrapper
from os import PathLike
from pathlib import Path
-from typing import IO, Any, Dict, Union
+from tarfile import TarFile, is_tarfile
+from typing import Any, Callable, Dict, Generator, Optional, TextIO, Union
from zipfile import ZipFile, is_zipfile
import requests
-
######################
### Reader Helpers ###
######################
-def open_resource(resource: Union[str, PathLike]) -> IO[str]:
+@dataclasses.dataclass
+class SizedResource:
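+    """An open text stream bundled with its name, its size in bytes, and a callable reporting the current read offset."""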
+ name: str
+ size: int
+ reader: TextIO
+ tell: Callable[[], int]
+
+
+def is_gzipped(filename: str):
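+    """Return True if the file at filename is gzip-compressed, determined by attempting to read one byte through gzip."""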
+ with gzip.open(filename, "r") as fh:
+ try:
+ fh.read(1)
+ return True
+ except gzip.BadGzipFile:
+ return False
+ finally:
+ fh.close()
+
+
+def open_resource(resource: Union[str, PathLike]) -> Union[
+ SizedResource,
+ tuple[ZipFile, Generator[SizedResource, None, None]],
+ tuple[TarFile, Generator[SizedResource, None, None]],
+]:
"""
A generic function for opening a local or remote file
@@ -35,43 +60,92 @@ def open_resource(resource: Union[str, PathLike]) -> IO[str]:
:return: str, next line in resource
"""
+
# Check if resource is a remote file
+ resource_name: Optional[Union[str, PathLike]] = None
+
if isinstance(resource, str) and resource.startswith('http'):
- tmp_file = tempfile.TemporaryFile('w+b')
- request = requests.get(resource)
+ tmp_file = tempfile.NamedTemporaryFile('w+b')
+ request = requests.get(resource, timeout=10)
if request.status_code != 200:
raise ValueError(f"Remote file returned {request.status_code}: {request.text}")
tmp_file.write(request.content)
- # request.close() # not sure this is needed
+ request.close()
tmp_file.seek(0)
- if resource.endswith('gz'):
- # This should be more robust, either check headers
- # or use https://github.com/ahupp/python-magic
- remote_file = gzip.open(tmp_file, 'rt')
- return remote_file
- else:
- return TextIOWrapper(tmp_file)
+ resource_name = resource
+ resource = tmp_file.name
+ else:
+ resource_name = resource
# If resource is not remote or local, raise error
- elif not Path(resource).exists():
+ if not Path(resource).exists():
raise ValueError(
- f"Cannot open local or remote file: {resource}. Check the URL/path, and that the file exists, and try again."
+ f"Cannot open local or remote file: {resource}. Check the URL/path, and that the file exists, "
+ "and try again."
)
# If resource is local, check for compression
if is_zipfile(resource):
- with ZipFile(resource, 'r') as zip_file:
- file = TextIOWrapper(zip_file.open(zip_file.namelist()[0], 'r')) # , encoding='utf-8')
- # file = zip_file.read(zip_file.namelist()[0], 'r').decode('utf-8')
- elif str(resource).endswith('gz'):
- file = gzip.open(resource, 'rt')
- file.read(1)
- file.seek(0)
+ zip_fh = ZipFile(resource, 'r')
+
+ def generator():
+ for zip_info in zip_fh.infolist():
+ extracted = zip_fh.open(zip_info, 'r')
+ yield SizedResource(
+ zip_info.filename,
+ zip_info.file_size,
+ TextIOWrapper(extracted),
+ extracted.tell,
+ )
+
+ return zip_fh, generator()
+
+ elif is_tarfile(resource):
+ tar_fh = tarfile.open(resource, mode='r|*')
+
+ def generator():
+ for tarinfo in tar_fh:
+ extracted = tar_fh.extractfile(tarinfo)
+ if extracted:
+ extracted.seekable = lambda: True
+ reader = TextIOWrapper(extracted)
+ yield SizedResource(
+ tarinfo.name,
+ tarinfo.size,
+ reader,
+ reader.tell,
+ )
+
+ return tar_fh, generator()
+
+ elif is_gzipped(str(resource)):
+ path = Path(resource)
+ fh = path.open("rb")
+ gzip_fh = gzip.open(fh, 'rt')
+ assert isinstance(gzip_fh, TextIOWrapper)
+ gzip_fh.read(1)
+ gzip_fh.seek(0)
+ stat = path.stat()
+
+ return SizedResource(
+ str(resource_name),
+ stat.st_size,
+ gzip_fh,
+ lambda: fh.tell(),
+ )
# If resource is local and not compressed, open as text
else:
- file = open(resource, 'r')
- return file
+ path = Path(resource)
+ stat = path.stat()
+ fh = path.open("r")
+
+ return SizedResource(
+ str(resource_name),
+ stat.st_size,
+ fh,
+ fh.tell,
+ )
def check_data(entry, path) -> bool:
@@ -126,7 +200,7 @@ def check_data(entry, path) -> bool:
column_types.update(provenance_slot_types)
-def build_export_row(data: Dict, list_delimiter: str = None) -> Dict:
+def build_export_row(data: Dict, list_delimiter: Optional[str] = None) -> Dict:
"""
Sanitize key-value pairs in dictionary.
This should be used to ensure proper syntax and types for node and edge data as it is exported.
@@ -149,7 +223,7 @@ def build_export_row(data: Dict, list_delimiter: str = None) -> Dict:
return tidy_data
-def _sanitize_export_property(key: str, value: Any, list_delimiter: str = None) -> Any:
+def _sanitize_export_property(key: str, value: Any, list_delimiter: Optional[str] = None) -> Any:
"""
Sanitize value for a key for the purpose of export.
Casts all values to primitive types like str or bool according to the
@@ -181,22 +255,22 @@ def _sanitize_export_property(key: str, value: Any, list_delimiter: str = None)
elif column_types[key] == bool:
try:
new_value = bool(value)
- except:
+ except Exception:
new_value = False
else:
new_value = str(value).replace("\n", " ").replace('\\"', "").replace("\t", " ")
else:
- if type(value) == list:
+ if isinstance(value, list):
value = [
v.replace("\n", " ").replace('\\"', "").replace("\t", " ") if isinstance(v, str) else v for v in value
]
new_value = list_delimiter.join([str(x) for x in value]) if list_delimiter else value
column_types[key] = list
- elif type(value) == bool:
+ elif isinstance(value, bool):
try:
new_value = bool(value)
column_types[key] = bool # this doesn't seem right, shouldn't column_types come from the biolink model?
- except:
+ except Exception:
new_value = False
else:
new_value = str(value).replace("\n", " ").replace('\\"', "").replace("\t", " ")
diff --git a/src/koza/io/writer/jsonl_writer.py b/src/koza/io/writer/jsonl_writer.py
index 2987879..438447e 100644
--- a/src/koza/io/writer/jsonl_writer.py
+++ b/src/koza/io/writer/jsonl_writer.py
@@ -1,10 +1,10 @@
import json
import os
-from typing import Iterable, List, Optional
+from typing import Iterable
from koza.converter.kgx_converter import KGXConverter
from koza.io.writer.writer import KozaWriter
-from koza.model.config.sssom_config import SSSOMConfig
+from koza.model.writer import WriterConfig
class JSONLWriter(KozaWriter):
@@ -12,20 +12,18 @@ def __init__(
self,
output_dir: str,
source_name: str,
- node_properties: List[str],
- edge_properties: Optional[List[str]] = [],
- sssom_config: SSSOMConfig = None,
+ config: WriterConfig,
):
self.output_dir = output_dir
self.source_name = source_name
- self.sssom_config = sssom_config
+ self.sssom_config = config.sssom_config
self.converter = KGXConverter()
os.makedirs(output_dir, exist_ok=True)
- if node_properties:
+ if config.node_properties:
self.nodeFH = open(f"{output_dir}/{source_name}_nodes.jsonl", "w")
- if edge_properties:
+ if config.edge_properties:
self.edgeFH = open(f"{output_dir}/{source_name}_edges.jsonl", "w")
def write(self, entities: Iterable):
diff --git a/src/koza/io/writer/passthrough_writer.py b/src/koza/io/writer/passthrough_writer.py
new file mode 100644
index 0000000..56adf6a
--- /dev/null
+++ b/src/koza/io/writer/passthrough_writer.py
@@ -0,0 +1,17 @@
+from typing import Iterable
+from koza.io.writer.writer import KozaWriter
+
+
+class PassthroughWriter(KozaWriter):
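+    """A writer that accumulates entities in memory instead of serializing them; retrieve them with result()."""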
+ def __init__(self):
+ self.data = []
+
+ def write(self, entities: Iterable):
+ for item in entities:
+ self.data.append(item)
+
+ def finalize(self):
+ pass
+
+ def result(self):
+ return self.data
diff --git a/src/koza/io/writer/tsv_writer.py b/src/koza/io/writer/tsv_writer.py
index 5c586bd..217c419 100644
--- a/src/koza/io/writer/tsv_writer.py
+++ b/src/koza/io/writer/tsv_writer.py
@@ -2,14 +2,14 @@
# NOTE - May want to rename to KGXWriter at some point, if we develop writers for other models non biolink/kgx specific
from pathlib import Path
-from typing import Dict, Iterable, List, Literal, Set, Union
+from typing import Dict, Iterable, List, Literal, Union
from ordered_set import OrderedSet
from koza.converter.kgx_converter import KGXConverter
from koza.io.utils import build_export_row
from koza.io.writer.writer import KozaWriter
-from koza.model.config.sssom_config import SSSOMConfig
+from koza.model.writer import WriterConfig
class TSVWriter(KozaWriter):
@@ -17,19 +17,20 @@ def __init__(
self,
output_dir: Union[str, Path],
source_name: str,
- node_properties: List[str] = None,
- edge_properties: List[str] = None,
- sssom_config: SSSOMConfig = None,
+ config: WriterConfig,
):
self.basename = source_name
self.dirname = output_dir
self.delimiter = "\t"
self.list_delimiter = "|"
self.converter = KGXConverter()
- self.sssom_config = sssom_config
+ self.sssom_config = config.sssom_config
Path(self.dirname).mkdir(parents=True, exist_ok=True)
+ node_properties = config.node_properties
+ edge_properties = config.edge_properties
+
if node_properties: # Make node file
self.node_columns = TSVWriter._order_columns(node_properties, "node")
self.nodes_file_name = Path(self.dirname if self.dirname else "", f"{self.basename}_nodes.tsv")
@@ -37,7 +38,7 @@ def __init__(
self.nodeFH.write(self.delimiter.join(self.node_columns) + "\n")
if edge_properties: # Make edge file
- if sssom_config:
+ if config.sssom_config:
edge_properties = self.add_sssom_columns(edge_properties)
self.edge_columns = TSVWriter._order_columns(edge_properties, "edge")
self.edges_file_name = Path(self.dirname if self.dirname else "", f"{self.basename}_edges.tsv")
@@ -88,7 +89,7 @@ def finalize(self):
self.edgeFH.close()
@staticmethod
- def _order_columns(cols: Set, record_type: Literal["node", "edge"]) -> OrderedSet:
+ def _order_columns(cols: List[str], record_type: Literal["node", "edge"]) -> OrderedSet[str]:
"""Arrange node or edge columns in a defined order.
Args:
@@ -101,7 +102,7 @@ def _order_columns(cols: Set, record_type: Literal["node", "edge"]) -> OrderedSe
core_columns = OrderedSet(["id", "category", "name", "description", "xref", "provided_by", "synonym"])
elif record_type == "edge":
core_columns = OrderedSet(["id", "subject", "predicate", "object", "category", "provided_by"])
- ordered_columns = OrderedSet()
+ ordered_columns: OrderedSet[str] = OrderedSet([])
for c in core_columns:
if c in cols:
ordered_columns.add(c)
diff --git a/src/koza/io/writer/writer.py b/src/koza/io/writer/writer.py
index 881a5ea..78484ee 100644
--- a/src/koza/io/writer/writer.py
+++ b/src/koza/io/writer/writer.py
@@ -18,3 +18,6 @@ def write(self, entities: Iterable):
@abstractmethod
def finalize(self):
pass
+
+ def result(self):
+ raise NotImplementedError()
diff --git a/src/koza/io/yaml_loader.py b/src/koza/io/yaml_loader.py
index a6b0f76..9842445 100644
--- a/src/koza/io/yaml_loader.py
+++ b/src/koza/io/yaml_loader.py
@@ -42,7 +42,10 @@ def include_constructor(self, node: yaml.Node) -> Union[str, IO[str]]:
"""
Opens some resource (local or remote file) that appears after an !include tag
"""
- return yaml.load(open_resource(self.construct_scalar(node)), Loader=UniqueIncludeLoader)
+ resource = open_resource(self.construct_scalar(node))
+ if isinstance(resource, tuple):
+ raise ValueError("Cannot load yaml from archive files")
+ return yaml.load(resource.reader, Loader=UniqueIncludeLoader)
yaml.add_constructor('!include', UniqueIncludeLoader.include_constructor, UniqueIncludeLoader)
diff --git a/src/koza/main.py b/src/koza/main.py
index 2dee4b9..74e357b 100755
--- a/src/koza/main.py
+++ b/src/koza/main.py
@@ -1,15 +1,19 @@
#!/usr/bin/env python3
"""CLI for Koza - wraps the koza library to provide a command line interface"""
+import sys
from pathlib import Path
-from typing import Optional
-
-from koza.cli_utils import transform_source, validate_file, split_file
-from koza.model.config.source_config import FormatType, OutputFormat
+from typing import Annotated, Optional
import typer
+from loguru import logger
+
+from koza.model.formats import OutputFormat
+from koza.runner import KozaRunner
-typer_app = typer.Typer()
+typer_app = typer.Typer(
+ no_args_is_help=True,
+)
@typer_app.callback(invoke_without_command=True)
@@ -23,61 +27,80 @@ def callback(version: Optional[bool] = typer.Option(None, "--version", is_eager=
@typer_app.command()
def transform(
- source: str = typer.Option(..., help="Source metadata file"),
- output_dir: str = typer.Option("./output", help="Path to output directory"),
- output_format: OutputFormat = typer.Option("tsv", help="Output format"),
- global_table: str = typer.Option(None, help="Path to global translation table"),
- local_table: str = typer.Option(None, help="Path to local translation table"),
- schema: str = typer.Option(None, help="Path to schema YAML for validation in writer"),
- row_limit: int = typer.Option(None, help="Number of rows to process (if skipped, processes entire source file)"),
- verbose: Optional[bool] = typer.Option(None, "--debug/--quiet"),
- log: bool = typer.Option(False, help="Optional log mode - set true to save output to ./logs"),
+ configuration_yaml: Annotated[
+ str,
+ typer.Argument(help="Configuration YAML file"),
+ ],
+ output_dir: Annotated[
+ str,
+ typer.Option("--output-dir", "-o", help="Path to output directory"),
+ ] = "./output",
+ output_format: Annotated[
+ OutputFormat,
+ typer.Option("--output-format", "-f", help="Output format"),
+ ] = OutputFormat.tsv,
+ row_limit: Annotated[
+ int,
+ typer.Option("--limit", "-n", help="Number of rows to process (if skipped, processes entire source file)"),
+ ] = 0,
+ quiet: Annotated[
+ bool,
+ typer.Option("--quiet", "-q", help="Disable log output"),
+ ] = False,
) -> None:
"""Transform a source file"""
+ logger.remove()
+
output_path = Path(output_dir)
+
if output_path.exists() and not output_path.is_dir():
raise NotADirectoryError(f"{output_dir} is not a directory")
elif not output_path.exists():
output_path.mkdir(parents=True)
- transform_source(
- source,
- output_dir,
- output_format,
- global_table,
- local_table,
- schema,
- node_type=None,
- edge_type=None,
+
+ if not quiet:
+ prompt = "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level} | {message}"
+ logger.add(sys.stderr, format=prompt, colorize=True)
+
+ # FIXME: Verbosity, logging
+ config, runner = KozaRunner.from_config_file(
+ configuration_yaml,
+ output_dir=output_dir,
+ output_format=output_format,
row_limit=row_limit,
- verbose=verbose,
- log=log,
)
+ logger.info(f"Running transform for {config.name} with output to `{output_dir}`")
-@typer_app.command()
-def validate(
- file: str = typer.Option(..., help="Path or url to the source file"),
- format: FormatType = FormatType.csv,
- delimiter: str = ",",
- header_delimiter: str = None,
- skip_blank_lines: bool = True,
-) -> None:
- """Validate a source file"""
- validate_file(file, format, delimiter, header_delimiter, skip_blank_lines)
+ runner.run()
+ logger.info(f"Finished transform for {config.name}")
-@typer_app.command()
-def split(
- file: str = typer.Argument(..., help="Path to the source kgx file to be split"),
- fields: str = typer.Argument(..., help="Comma separated list of fields to split on"),
- remove_prefixes: bool = typer.Option(
- False,
- help="Remove prefixes from the file names for values from the specified fields. (e.g, NCBIGene:9606 becomes 9606",
- ),
- output_dir: str = typer.Option(default="output", help="Path to output directory"),
-):
- """Split a file by fields"""
- split_file(file, fields, remove_prefixes=remove_prefixes, output_dir=output_dir)
+
+# @typer_app.command()
+# def validate(
+# file: str = typer.Option(..., help="Path or url to the source file"),
+# format: InputFormat = InputFormat.csv,
+# delimiter: str = ",",
+# header_delimiter: str = None,
+# skip_blank_lines: bool = True,
+# ) -> None:
+# """Validate a source file"""
+# validate_file(file, format, delimiter, header_delimiter, skip_blank_lines)
+#
+#
+# @typer_app.command()
+# def split(
+# file: str = typer.Argument(..., help="Path to the source kgx file to be split"),
+# fields: str = typer.Argument(..., help="Comma separated list of fields to split on"),
+# remove_prefixes: bool = typer.Option(
+# False,
+#         help="Remove prefixes from the file names for values from the specified fields (e.g., NCBIGene:9606 becomes 9606)",
+# ),
+# output_dir: str = typer.Option(default="output", help="Path to output directory"),
+# ):
+# """Split a file by fields"""
+# split_file(file, fields, remove_prefixes=remove_prefixes, output_dir=output_dir)
if __name__ == "__main__":
diff --git a/src/koza/model/__init__.py b/src/koza/model/__init__.py
index 652e193..e69de29 100644
--- a/src/koza/model/__init__.py
+++ b/src/koza/model/__init__.py
@@ -1,3 +0,0 @@
-"""
-Biolink Dataclasses
-"""
diff --git a/src/koza/model/config/source_config.py b/src/koza/model/config/source_config.py
index 28a304c..77f59f9 100644
--- a/src/koza/model/config/source_config.py
+++ b/src/koza/model/config/source_config.py
@@ -3,85 +3,15 @@
map config data class
"""
-import os
-import tarfile
-import zipfile
from dataclasses import field
from enum import Enum
from pathlib import Path
-from typing import Dict, List, Union, Optional
-import yaml
-
-from pydantic import StrictFloat, StrictInt, StrictStr
-from pydantic.dataclasses import dataclass
+from typing import Dict, List, Optional, Union
from koza.model.config.pydantic_config import PYDANTIC_CONFIG
from koza.model.config.sssom_config import SSSOMConfig
-
-
-class FilterCode(str, Enum):
- """Enum for filter codes (ex. gt = greater than)
-
- This should be aligned with https://docs.python.org/3/library/operator.html
- """
-
- gt = "gt"
- ge = "ge"
- lt = "lt"
- lte = "le"
- eq = "eq"
- ne = "ne"
- inlist = "in"
- inlist_exact = "in_exact"
-
-
-class FilterInclusion(str, Enum):
- """Enum for filter inclusion/exclusion"""
-
- include = "include"
- exclude = "exclude"
-
-
-class FieldType(str, Enum):
- """Enum for field types"""
-
- str = "str"
- int = "int"
- float = "float"
-
-
-class FormatType(str, Enum):
- """Enum for supported file types"""
-
- csv = "csv"
- jsonl = "jsonl"
- json = "json"
- yaml = "yaml"
- # xml = "xml" # Not yet supported
-
-
-class HeaderMode(str, Enum):
- """Enum for supported header modes in addition to an index based lookup"""
-
- infer = "infer"
- none = "none"
-
-
-class MapErrorEnum(str, Enum):
- """Enum for how to handle key errors in map files"""
-
- warning = "warning"
- error = "error"
-
-
-class OutputFormat(str, Enum):
- """
- Output formats
- """
-
- tsv = "tsv"
- jsonl = "jsonl"
- kgx = "kgx"
+from pydantic import StrictInt, StrictStr, TypeAdapter
+from pydantic.dataclasses import dataclass
class StandardFormat(str, Enum):
@@ -102,36 +32,28 @@ class TransformMode(str, Enum):
loop = "loop"
-@dataclass(frozen=True)
-class ColumnFilter:
- column: str
- inclusion: FilterInclusion
- filter_code: FilterCode
- value: Union[StrictInt, StrictFloat, StrictStr, List[Union[StrictInt, StrictFloat, StrictStr]]]
+# Reader configuration
+# ---
-@dataclass(frozen=True)
-class DatasetDescription:
- """
- These options should be treated as being in alpha, as we need
- to align with various efforts (hcls, translator infores)
+# Transform configuration
+# ---
- These currently do not serve a purpose in koza other than documentation
- """
- # id: Optional[str] = None # Can uncomment when we have a standard
- name: Optional[str] = None # If empty use source name
- ingest_title: Optional[str] = None # Title of source of data, map to biolink name
- ingest_url: Optional[str] = None # URL to source of data, maps to biolink iri
- description: Optional[str] = None # Description of the data/ingest
- # source: Optional[str] = None # Possibly replaced with provided_by
- provided_by: Optional[str] = None # _, ex. hpoa_gene_to_disease
- # license: Optional[str] = None # Possibly redundant, same as rights
- rights: Optional[str] = None # License information for the data source
+# Writer configuration
+# ---
+
+
+# Main Koza configuration
+# ---
+
+
+def SourceConfig(**kwargs):
+ return DEPRECATEDSourceConfig(**kwargs).to_new_transform()
@dataclass(config=PYDANTIC_CONFIG)
-class SourceConfig:
+class DEPRECATEDSourceConfig:
"""
Source config data class
@@ -182,128 +104,7 @@ class SourceConfig:
global_table: Optional[Union[str, Dict]] = None
local_table: Optional[Union[str, Dict]] = None
- def extract_archive(self):
- archive_path = Path(self.file_archive).parent # .absolute()
- if self.file_archive.endswith(".tar.gz") or self.file_archive.endswith(".tar"):
- with tarfile.open(self.file_archive) as archive:
- archive.extractall(archive_path)
- elif self.file_archive.endswith(".zip"):
- with zipfile.ZipFile(self.file_archive, "r") as archive:
- archive.extractall(archive_path)
- else:
- raise ValueError("Error extracting archive. Supported archive types: .tar.gz, .zip")
- if self.files:
- files = [os.path.join(archive_path, file) for file in self.files]
- else:
- files = [os.path.join(archive_path, file) for file in os.listdir(archive_path)]
- return files
-
- def __post_init__(self):
- # Get files as paths, or extract them from an archive
- if self.file_archive:
- files = self.extract_archive()
- else:
- files = self.files
-
- files_as_paths: List[Path] = []
- for file in files:
- if isinstance(file, str):
- files_as_paths.append(Path(file))
- else:
- files_as_paths.append(file)
- object.__setattr__(self, "files", files_as_paths)
-
- # If metadata looks like a file path attempt to load it from the yaml
- if self.metadata and isinstance(self.metadata, str):
- try:
- with open(self.metadata, "r") as meta:
- object.__setattr__(self, "metadata", DatasetDescription(**yaml.safe_load(meta)))
- except Exception as e:
- raise ValueError(f"Unable to load metadata from {self.metadata}: {e}")
-
- # Format tab as delimiter
- if self.delimiter in ["tab", "\\t"]:
- object.__setattr__(self, "delimiter", "\t")
-
- # Filter columns
- filtered_columns = [column_filter.column for column_filter in self.filters]
-
- all_columns = []
- if self.columns:
- all_columns = [next(iter(column)) if isinstance(column, Dict) else column for column in self.columns]
-
- if self.header == HeaderMode.none and not self.columns:
- raise ValueError(
- "there is no header and columns have not been supplied\n"
- "configure the 'columns' field or set header to the 0-based"
- "index in which it appears in the file, or set this value to"
- "'infer'"
- )
-
- for column in filtered_columns:
- if column not in all_columns:
- raise (ValueError(f"Filter column {column} not in column list"))
-
- for column_filter in self.filters:
- if column_filter.filter_code in ["lt", "gt", "lte", "gte"]:
- if not isinstance(column_filter.value, (int, float)):
- raise ValueError(f"Filter value must be int or float for operator {column_filter.filter_code}")
- elif column_filter.filter_code == "eq":
- if not isinstance(column_filter.value, (str, int, float)):
- raise ValueError(
- f"Filter value must be string, int or float for operator {column_filter.filter_code}"
- )
- elif column_filter.filter_code == "in":
- if not isinstance(column_filter.value, List):
- raise ValueError(f"Filter value must be List for operator {column_filter.filter_code}")
-
- # Check for conflicting configurations
- if self.format == FormatType.csv and self.required_properties:
- raise ValueError(
- "CSV specified but required properties have been configured\n"
- "Either set format to jsonl or change properties to columns in the config"
- )
- if self.columns and self.format != FormatType.csv:
- raise ValueError(
- "Columns have been configured but format is not csv\n"
- "Either set format to csv or change columns to properties in the config"
- )
- if self.json_path and self.format != FormatType.json:
- raise ValueError(
- "iterate_over has been configured but format is not json\n"
- "Either set format to json or remove iterate_over in the configuration"
- )
-
- # Create a field_type_map if columns are supplied
- if self.columns:
- field_type_map = {}
- for field in self.columns:
- if isinstance(field, str):
- field_type_map[field] = FieldType.str
- else:
- if len(field) != 1:
- raise ValueError("Field type map contains more than one key")
- for key, val in field.items():
- field_type_map[key] = val
- object.__setattr__(self, "field_type_map", field_type_map)
-
-
-@dataclass(config=PYDANTIC_CONFIG)
-class PrimaryFileConfig(SourceConfig):
- """
- Primary configuration for transforming a source file
-
- Parameters
- ----------
- node_properties: List[str] (optional) - list of node properties/columns to include
- edge_properties: List[str] (optional) - list of edge properties/columns to include
- min_node_count: int (optional) - minimum number of nodes required in output
- min_edge_count: int (optional) - minimum number of edges required in output
- node_report_columns: List[str] (optional) - list of node properties to include in the report
- edge_report_columns: List[str] (optional) - list of edge properties to include in the report
- depends_on: List[str] (optional) - Optional lookup dictionary for basic mapping
- on_map_failure: MapErrorEnum (optional) - How to handle key errors in map files
- """
+ metadata: Optional[Union[DatasetDescription, str]] = None
node_properties: Optional[List[str]] = None
edge_properties: Optional[List[str]] = None
@@ -314,11 +115,43 @@ class PrimaryFileConfig(SourceConfig):
depends_on: List[str] = field(default_factory=list)
on_map_failure: MapErrorEnum = MapErrorEnum.warning
-
-@dataclass(config=PYDANTIC_CONFIG)
-class MapFileConfig(SourceConfig):
- key: Optional[str] = None
- values: Optional[List[str]] = None
- curie_prefix: Optional[str] = None
- add_curie_prefix_to_columns: Optional[List[str]] = None
- depends_on: Optional[List[str]] = None
+ def to_new_transform(self):
+ files = self.files or []
+ if self.file_archive:
+ files.append(self.file_archive)
+
+ config_obj = {
+ "name": self.name,
+ "metadata": self.metadata,
+ "reader": {
+ "format": self.format,
+ "files": files,
+ "columns": self.columns,
+ "field_type_map": self.field_type_map,
+ "required_properties": self.required_properties,
+ "delimiter": self.delimiter,
+ "header_mode": self.header, # Renamed to header_mode
+ "header_delimiter": self.header_delimiter,
+ "header_prefix": self.header_prefix,
+ "comment_char": self.comment_char,
+ "skip_blank_lines": self.skip_blank_lines,
+ "json_path": self.json_path,
+ },
+ "transform": {
+ "code": self.transform_code,
+ "filters": self.filters,
+ "mapping": self.depends_on,
+ "global_table": self.global_table,
+ "local_table": self.local_table,
+ },
+ "writer": {
+ "format": self.format,
+ "sssom_config": self.sssom_config,
+ "node_properties": self.node_properties,
+ "edge_properties": self.edge_properties,
+ "min_node_count": self.min_node_count,
+ "min_edge_count": self.min_edge_count,
+ },
+ }
+
+ return TypeAdapter(KozaConfig).validate_python(config_obj)
diff --git a/src/koza/model/filters.py b/src/koza/model/filters.py
new file mode 100644
index 0000000..06cf880
--- /dev/null
+++ b/src/koza/model/filters.py
@@ -0,0 +1,67 @@
+from enum import Enum
+from typing import Annotated, List, Literal, Union
+
+from koza.model.config.pydantic_config import PYDANTIC_CONFIG
+from pydantic import Field, StrictFloat, StrictInt, StrictStr
+from pydantic.dataclasses import dataclass
+
+
+__all__ = ('ColumnFilter',)
+
+class FilterCode(str, Enum):
+ """Enum for filter codes (ex. gt = greater than)
+
+ This should be aligned with https://docs.python.org/3/library/operator.html
+ """
+
+ gt = "gt"
+ ge = "ge"
+ lt = "lt"
+ lte = "le"
+ eq = "eq"
+ ne = "ne"
+ inlist = "in"
+ inlist_exact = "in_exact"
+
+
+class FilterInclusion(str, Enum):
+ """Enum for filter inclusion/exclusion"""
+
+ include = "include"
+ exclude = "exclude"
+
+
+@dataclass(config=PYDANTIC_CONFIG, frozen=True)
+class BaseColumnFilter:
+ column: str
+ inclusion: FilterInclusion
+
+
+@dataclass(config=PYDANTIC_CONFIG, frozen=True)
+class ComparisonFilter(BaseColumnFilter):
+ filter_code: Literal[FilterCode.lt, FilterCode.gt, FilterCode.lte, FilterCode.ge]
+ value: Union[StrictInt, StrictFloat]
+
+
+@dataclass(config=PYDANTIC_CONFIG, frozen=True)
+class EqualsFilter(BaseColumnFilter):
+ filter_code: Literal[FilterCode.eq]
+ value: Union[StrictInt, StrictFloat, StrictStr]
+
+
+@dataclass(config=PYDANTIC_CONFIG, frozen=True)
+class NotEqualsFilter(BaseColumnFilter):
+ filter_code: Literal[FilterCode.ne]
+ value: Union[StrictInt, StrictFloat, StrictStr]
+
+
+@dataclass(config=PYDANTIC_CONFIG, frozen=True)
+class InListFilter(BaseColumnFilter):
+ filter_code: Literal[FilterCode.inlist, FilterCode.inlist_exact]
+ value: List[Union[StrictInt, StrictFloat, StrictStr]]
+
+
+ColumnFilter = Annotated[
+ Union[ComparisonFilter, EqualsFilter, NotEqualsFilter, InListFilter],
+ Field(..., discriminator="filter_code"),
+]
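
An illustrative sketch (assuming pydantic v2, as used elsewhere in this diff) of how the discriminated ColumnFilter union above is expected to validate: the filter_code tag selects the concrete filter class, which in turn constrains the allowed value type.

    from pydantic import TypeAdapter

    from koza.model.filters import ColumnFilter, ComparisonFilter, InListFilter

    adapter = TypeAdapter(ColumnFilter)

    # "gt" is a comparison code, so the value must be numeric
    score_filter = adapter.validate_python(
        {"column": "combined_score", "inclusion": "include", "filter_code": "gt", "value": 700}
    )
    assert isinstance(score_filter, ComparisonFilter)

    # "in" selects the list filter, so the value must be a list
    taxon_filter = adapter.validate_python(
        {"column": "NCBI taxid", "inclusion": "include", "filter_code": "in", "value": [9606, 10090]}
    )
    assert isinstance(taxon_filter, InListFilter)
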
diff --git a/src/koza/model/formats.py b/src/koza/model/formats.py
new file mode 100644
index 0000000..523d5e3
--- /dev/null
+++ b/src/koza/model/formats.py
@@ -0,0 +1,24 @@
+from enum import Enum
+
+__all__ = ('InputFormat', 'OutputFormat')
+
+
+class InputFormat(str, Enum):
+ """Enum for supported file types"""
+
+ csv = "csv"
+ jsonl = "jsonl"
+ json = "json"
+ yaml = "yaml"
+ # xml = "xml" # Not yet supported
+
+
+class OutputFormat(str, Enum):
+ """
+ Output formats
+ """
+
+ tsv = "tsv"
+ jsonl = "jsonl"
+ kgx = "kgx"
+ passthrough = "passthrough"
diff --git a/src/koza/model/koza.py b/src/koza/model/koza.py
new file mode 100644
index 0000000..ea049e8
--- /dev/null
+++ b/src/koza/model/koza.py
@@ -0,0 +1,65 @@
+from dataclasses import field
+from typing import Optional, Union
+
+import yaml
+from ordered_set import OrderedSet
+from pydantic.dataclasses import dataclass
+
+from koza.model.config.pydantic_config import PYDANTIC_CONFIG
+from koza.model.formats import InputFormat
+from koza.model.reader import CSVReaderConfig, ReaderConfig
+from koza.model.transform import TransformConfig
+from koza.model.writer import WriterConfig
+
+__all__ = ('DatasetDescription', 'KozaConfig')
+
+
+@dataclass(frozen=True)
+class DatasetDescription:
+ """
+ These options should be treated as being in alpha, as we need
+ to align with various efforts (hcls, translator infores)
+
+ These currently do not serve a purpose in koza other than documentation
+ """
+
+ # id: Optional[str] = None # Can uncomment when we have a standard
+ name: Optional[str] = None # If empty use source name
+ ingest_title: Optional[str] = None # Title of source of data, map to biolink name
+ ingest_url: Optional[str] = None # URL to source of data, maps to biolink iri
+ description: Optional[str] = None # Description of the data/ingest
+ # source: Optional[str] = None # Possibly replaced with provided_by
+ provided_by: Optional[str] = None # _, ex. hpoa_gene_to_disease
+ # license: Optional[str] = None # Possibly redundant, same as rights
+ rights: Optional[str] = None # License information for the data source
+
+
+@dataclass(config=PYDANTIC_CONFIG, frozen=True)
+class KozaConfig:
+ name: str
+ reader: ReaderConfig = field(default_factory=CSVReaderConfig)
+ transform: TransformConfig = field(default_factory=TransformConfig)
+ writer: WriterConfig = field(default_factory=WriterConfig)
+ metadata: Optional[Union[DatasetDescription, str]] = None
+
+ def __post_init__(self):
+ # If metadata looks like a file path attempt to load it from the yaml
+ if self.metadata and isinstance(self.metadata, str):
+ try:
+ with open(self.metadata, "r") as meta:
+ object.__setattr__(self, "metadata", DatasetDescription(**yaml.safe_load(meta)))
+ except Exception as e:
+ raise ValueError(f"Unable to load metadata from {self.metadata}: {e}") from e
+
+ if self.reader.format == InputFormat.csv and self.reader.columns is not None:
+ filtered_columns = OrderedSet([column_filter.column for column_filter in self.transform.filters])
+ all_columns = OrderedSet(
+ [column if isinstance(column, str) else list(column.keys())[0] for column in self.reader.columns]
+ )
+ extra_filtered_columns = filtered_columns - all_columns
+ if extra_filtered_columns:
+ quote = "'"
+ raise ValueError(
+ "One or more filter columns not present in designated CSV columns:"
+ f" {', '.join([f'{quote}{c}{quote}' for c in extra_filtered_columns])}"
+ )
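
A rough sketch of validating the new top-level KozaConfig from a plain dict shaped like the YAML configs (the name and columns here are hypothetical); unspecified sections fall back to their defaults and the filter-column check in __post_init__ is applied.

    from pydantic import TypeAdapter

    from koza.model.formats import OutputFormat
    from koza.model.koza import KozaConfig

    config = TypeAdapter(KozaConfig).validate_python({
        "name": "example-config",  # hypothetical name
        "reader": {
            "format": "csv",
            "columns": ["a", "b", "c"],
        },
        "transform": {
            "filters": [
                {"column": "a", "inclusion": "include", "filter_code": "gt", "value": 0},
            ],
        },
    })

    # Sections that are not supplied fall back to their defaults (e.g. TSV output)
    assert config.writer.format == OutputFormat.tsv
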
diff --git a/src/koza/model/map_dict.py b/src/koza/model/map_dict.py
deleted file mode 100644
index 2aa51e3..0000000
--- a/src/koza/model/map_dict.py
+++ /dev/null
@@ -1,14 +0,0 @@
-from koza.utils.exceptions import MapItemException
-
-
-class MapDict(dict):
- """
- A custom dictionary that raises a special KeyError exception
- MapItemException
- """
-
- def __getitem__(self, key):
- try:
- return super().__getitem__(key)
- except KeyError as key_error:
- raise MapItemException(*key_error.args)
diff --git a/src/koza/model/reader.py b/src/koza/model/reader.py
new file mode 100644
index 0000000..abe5cb7
--- /dev/null
+++ b/src/koza/model/reader.py
@@ -0,0 +1,114 @@
+"""
+Reader configuration data classes for each supported input format
+"""
+
+from dataclasses import field
+from enum import Enum
+from typing import Annotated, Any, Dict, List, Literal, Optional, Union
+
+from koza.model.config.pydantic_config import PYDANTIC_CONFIG
+from koza.model.formats import InputFormat
+from pydantic import Discriminator, StrictInt, StrictStr, Tag
+from pydantic.dataclasses import dataclass
+
+
+__all__ = ('ReaderConfig',)
+
+class FieldType(str, Enum):
+ """Enum for field types"""
+
+ str = "str"
+ int = "int"
+ float = "float"
+
+
+class HeaderMode(str, Enum):
+ """Enum for supported header modes in addition to an index based lookup"""
+
+ infer = "infer"
+ none = "none"
+
+
+@dataclass(config=PYDANTIC_CONFIG, frozen=True)
+class BaseReaderConfig:
+ files: List[str] = field(default_factory=list)
+
+
+@dataclass(config=PYDANTIC_CONFIG, frozen=True)
+class CSVReaderConfig(BaseReaderConfig):
+ format: Literal[InputFormat.csv] = InputFormat.csv
+ columns: Optional[List[Union[str, Dict[str, FieldType]]]] = None
+ field_type_map: Optional[dict[str, FieldType]] = None
+ delimiter: str = "\t"
+ header_delimiter: Optional[str] = None
+ dialect: str = "excel"
+ header_mode: Union[int, HeaderMode] = HeaderMode.infer
+ header_prefix: Optional[str] = None
+ skip_blank_lines: bool = True
+ comment_char: str = "#"
+
+ def __post_init__(self):
+ # Format tab as delimiter
+ if self.delimiter in ["tab", "\\t"]:
+ object.__setattr__(self, "delimiter", "\t")
+
+ # Create a field_type_map if columns are supplied
+ if self.columns:
+ field_type_map = {}
+ for field in self.columns:
+ if isinstance(field, str):
+ field_type_map[field] = FieldType.str
+ else:
+ if len(field) != 1:
+ raise ValueError("Field type map contains more than one key")
+ for key, val in field.items():
+ field_type_map[key] = val
+ object.__setattr__(self, "field_type_map", field_type_map)
+
+ if self.header_mode == HeaderMode.none and not self.columns:
+ raise ValueError(
+ "there is no header and columns have not been supplied\n"
+ "configure the 'columns' field or set header to the 0-based"
+ "index in which it appears in the file, or set this value to"
+ "'infer'"
+ )
+
+
+@dataclass(config=PYDANTIC_CONFIG, frozen=True)
+class JSONLReaderConfig(BaseReaderConfig):
+ format: Literal[InputFormat.jsonl] = InputFormat.jsonl
+ required_properties: Optional[List[str]] = None
+
+
+@dataclass(config=PYDANTIC_CONFIG, frozen=True)
+class JSONReaderConfig(BaseReaderConfig):
+ format: Literal[InputFormat.json] = InputFormat.json
+ required_properties: Optional[List[str]] = None
+ json_path: Optional[List[Union[StrictStr, StrictInt]]] = None
+
+
+@dataclass(config=PYDANTIC_CONFIG, frozen=True)
+class YAMLReaderConfig(BaseReaderConfig):
+ format: Literal[InputFormat.yaml] = InputFormat.yaml
+ required_properties: Optional[List[str]] = None
+ json_path: Optional[List[Union[StrictStr, StrictInt]]] = None
+
+
+def get_reader_discriminator(model: Any):
+ if isinstance(model, dict):
+ return model.get("format", InputFormat.csv)
+ return getattr(model, "format", InputFormat.csv)
+
+
+ReaderConfig = Annotated[
+ (
+ Annotated[CSVReaderConfig, Tag(InputFormat.csv)]
+ | Annotated[JSONLReaderConfig, Tag(InputFormat.jsonl)]
+ | Annotated[JSONReaderConfig, Tag(InputFormat.json)]
+ | Annotated[YAMLReaderConfig, Tag(InputFormat.yaml)]
+ ),
+ Discriminator(get_reader_discriminator),
+]
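
A small sketch (file names are hypothetical) of how the reader discriminator above is expected to behave: a missing format key falls back to CSV, and "tab" is normalized to a literal tab character in __post_init__.

    from pydantic import TypeAdapter

    from koza.model.reader import CSVReaderConfig, ReaderConfig

    adapter = TypeAdapter(ReaderConfig)

    # No "format" key: get_reader_discriminator defaults to CSV
    csv_config = adapter.validate_python({"files": ["data.tsv"], "delimiter": "tab"})
    assert isinstance(csv_config, CSVReaderConfig)
    assert csv_config.delimiter == "\t"

    # An explicit format selects the matching reader config
    jsonl_config = adapter.validate_python({"format": "jsonl", "files": ["data.jsonl"]})
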
diff --git a/src/koza/model/source.py b/src/koza/model/source.py
index 7d44e32..7087cc1 100644
--- a/src/koza/model/source.py
+++ b/src/koza/model/source.py
@@ -1,11 +1,14 @@
-from typing import Any, Dict, Iterator, List, Optional, Union
+from tarfile import TarFile
+from typing import Any, Dict, Iterable, List, Optional, TextIO, Union
+from zipfile import ZipFile
from koza.io.reader.csv_reader import CSVReader
from koza.io.reader.json_reader import JSONReader
from koza.io.reader.jsonl_reader import JSONLReader
from koza.io.utils import open_resource
+from koza.model.formats import InputFormat
+from koza.model.koza import KozaConfig
from koza.utils.row_filter import RowFilter
-from koza.model.config.source_config import MapFileConfig, PrimaryFileConfig # , SourceConfig
# from koza.io.yaml_loader import UniqueIncludeLoader
# import yaml
@@ -23,82 +26,60 @@ class Source:
reader: An iterator that takes in an IO[str] and yields a dictionary
"""
- def __init__(self, config: Union[PrimaryFileConfig, MapFileConfig], row_limit: Optional[int] = None):
- self.config = config
+ def __init__(self, config: KozaConfig, row_limit: int = 0):
+ reader_config = config.reader
+
self.row_limit = row_limit
- self._filter = RowFilter(config.filters)
+ self._filter = RowFilter(config.transform.filters)
self._reader = None
- self._readers: List = []
- self.last_row: Optional[Dict] = None
+ self._readers: List[Iterable[Dict[str, Any]]] = []
+ self.last_row: Optional[Dict[str, Any]] = None
+ self._opened: list[Union[ZipFile, TarFile, TextIO]] = []
+
+ for file in reader_config.files:
+ opened_resource = open_resource(file)
+ if isinstance(opened_resource, tuple):
+ archive, resources = opened_resource
+ self._opened.append(archive)
+ else:
+ resources = [opened_resource]
- for file in config.files:
- resource_io = open_resource(file)
- if self.config.format == "csv":
- self._readers.append(
- CSVReader(
- resource_io,
- name=config.name,
- field_type_map=config.field_type_map,
- delimiter=config.delimiter,
- header=config.header,
- header_delimiter=config.header_delimiter,
- header_prefix=config.header_prefix,
- comment_char=self.config.comment_char,
- row_limit=self.row_limit,
+ for resource in resources:
+ self._opened.append(resource.reader)
+ if reader_config.format == InputFormat.csv:
+ self._readers.append(
+ CSVReader(
+ resource.reader,
+ config=reader_config,
+ row_limit=self.row_limit,
+ )
)
- )
- elif self.config.format == "jsonl":
- self._readers.append(
- JSONLReader(
- resource_io,
- name=config.name,
- required_properties=config.required_properties,
- row_limit=self.row_limit,
+ elif reader_config.format == InputFormat.jsonl:
+ self._readers.append(
+ JSONLReader(
+ resource.reader,
+ config=reader_config,
+ row_limit=self.row_limit,
+ )
)
- )
- elif self.config.format == "json" or self.config.format == "yaml":
- self._readers.append(
- JSONReader(
- resource_io,
- name=config.name,
- json_path=config.json_path,
- required_properties=config.required_properties,
- is_yaml=(self.config.format == "yaml"),
- row_limit=self.row_limit,
+ elif reader_config.format == InputFormat.json or reader_config.format == InputFormat.yaml:
+ self._readers.append(
+ JSONReader(
+ resource.reader,
+ config=reader_config,
+ row_limit=self.row_limit,
+ )
)
- )
- else:
- raise ValueError(f"File type {format} not supported")
-
- def __iter__(self) -> Iterator:
- return self
+ else:
+ raise ValueError(f"File type {reader_config.format} not supported")
- def __next__(self) -> Dict[str, Any]:
- if self._reader is None:
- self._reader = self._readers.pop()
- try:
- row = self._get_row()
- except StopIteration as si:
- if len(self._readers) == 0:
- raise si
- else:
- self._reader = self._readers.pop()
- row = self._get_row()
- return row
+ def __iter__(self):
+ for reader in self._readers:
+ for item in reader:
+ if self._filter and not self._filter.include_row(item):
+ continue
+ self.last_row = item
+ yield item
- def _get_row(self):
- # If we built a filter for this source, run extra code to validate each row for inclusion in the final output.
- if self._filter:
- row = next(self._reader)
- reject_current_row = not self._filter.include_row(row)
- # If the filter says we shouldn't include the current row; we filter it out and move onto the next row.
- # We'll only break out of the following loop if "reject_current_row" is false (i.e. include_row is True/we
- # have a valid row to return) or we hit a StopIteration exception from self._reader.
- while reject_current_row:
- row = next(self._reader)
- reject_current_row = not self._filter.include_row(row)
- else:
- row = next(self._reader)
- # Retain the most recent row so that it can be logged alongside validation errors
- self.last_row = row
- return row
+ for fh in self._opened:
+ fh.close()
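
Under the reworked Source, filtering and file cleanup happen inside __iter__, so callers simply build it from a KozaConfig and iterate. A sketch under the assumption that the listed file exists and is a readable TSV (the path and name below are placeholders):

    from koza.model.koza import KozaConfig
    from koza.model.source import Source

    config = KozaConfig(
        name="example-source",  # hypothetical
        reader={"format": "csv", "files": ["./data/example.tsv"]},  # hypothetical path
    )

    # Rows are filtered and yielded lazily; opened resources are closed at the end
    for row in Source(config, row_limit=10):
        print(row)
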
diff --git a/src/koza/model/transform.py b/src/koza/model/transform.py
new file mode 100644
index 0000000..151d4db
--- /dev/null
+++ b/src/koza/model/transform.py
@@ -0,0 +1,62 @@
+from dataclasses import field, fields
+from enum import Enum
+from typing import Any, Dict, List, Optional, Union
+
+from koza.model.config.pydantic_config import PYDANTIC_CONFIG
+from koza.model.filters import ColumnFilter
+from pydantic import model_validator
+from pydantic.dataclasses import dataclass
+from pydantic_core import ArgsKwargs
+
+
+class MapErrorEnum(str, Enum):
+ """Enum for how to handle key errors in map files"""
+
+ warning = "warning"
+ error = "error"
+
+
+@dataclass(config=PYDANTIC_CONFIG, frozen=True, kw_only=True)
+class TransformConfig:
+ """
+ Source config data class
+
+ Parameters
+ ----------
+ name: name of the source
+ code: path to a python file to transform the data
+ mode: how to process the transform file
+ global_table: path to a global table file
+ local_table: path to a local table file
+ """
+
+ code: Optional[str] = None
+ module: Optional[str] = None
+ filters: List[ColumnFilter] = field(default_factory=list)
+ global_table: Optional[Union[str, Dict]] = None
+ local_table: Optional[Union[str, Dict]] = None
+ mappings: List[str] = field(default_factory=list)
+ on_map_failure: MapErrorEnum = MapErrorEnum.warning
+ extra_fields: Dict[str, Any] = field(default_factory=dict)
+
+ @model_validator(mode="before")
+ @classmethod
+ def extract_extra_fields(cls, values: dict | ArgsKwargs) -> Dict[str, Any]:
+ """Take any additional kwargs and put them in the `extra_fields` attribute."""
+ if isinstance(values, dict):
+ kwargs = values.copy()
+ elif isinstance(values, ArgsKwargs) and values.kwargs is not None:
+ kwargs = values.kwargs.copy()
+ else:
+ kwargs = {}
+
+ configured_field_names = {f.name for f in fields(cls) if f.name != "extra_fields"}
+ extra_fields: dict[str, Any] = kwargs.pop("extra_fields", {})
+
+ for field_name in list(kwargs.keys()):
+ if field_name in configured_field_names:
+ continue
+ extra_fields[field_name] = kwargs.pop(field_name)
+ kwargs["extra_fields"] = extra_fields
+
+ return kwargs
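
A short sketch of the extra-field capture, validated the same way as in the updated unit tests: keys that are not declared TransformConfig fields (for example the key/values options used by map configs and read later by KozaRunner.load_mappings) are swept into extra_fields by the before-validator. The code path below is hypothetical.

    from pydantic import TypeAdapter

    from koza.model.transform import TransformConfig

    config = TypeAdapter(TransformConfig).validate_python({
        "code": "./transform.py",  # hypothetical transform module
        "key": "STRING",
        "values": ["entrez"],
    })

    assert config.code == "./transform.py"
    assert config.extra_fields == {"key": "STRING", "values": ["entrez"]}
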
diff --git a/src/koza/model/translation_table.py b/src/koza/model/translation_table.py
deleted file mode 100644
index 2af93ca..0000000
--- a/src/koza/model/translation_table.py
+++ /dev/null
@@ -1,92 +0,0 @@
-from dataclasses import dataclass
-from typing import Dict, Optional
-
-# from koza.utils.log_utils import get_logger
-# logger = get_logger(__name__)
-# import logging
-# logger = logging.getLogger(__name__)
-from loguru import logger
-
-
-def is_dictionary_bimap(dictionary: Dict[str, str]) -> bool:
- """
- Test if a dictionary is a bimap
- :param dictionary:
- :return: boolean
- """
- is_bimap = True
- all_values = set()
- failed_list = []
-
- for val in dictionary.values():
- if val not in all_values:
- all_values.add(val)
- else:
- is_bimap = False
- failed_list.append(val)
-
- if not is_bimap:
- logger.warning(f"Duplicate values in yaml: {failed_list}")
-
- return is_bimap
-
-
-@dataclass(frozen=True)
-class TranslationTable:
- """
- Translation table
- """
-
- global_table: Dict[str, str]
- local_table: Dict[str, str] # maybe bidict
-
- def __post_init__(self):
- if not is_dictionary_bimap(self.global_table):
- raise ValueError("Global table is not a bimap")
-
- def resolve_term(self, word: str, mandatory: Optional[bool] = True, default: Optional[str] = None):
- """
- Resolve a term from a source to its preferred curie
-
- given a term in some source
- return global[ (local[term] | term) ] || local[term] || (term | default)
-
- if finding a mapping is not mandatory
- returns x | default on fall through
-
- This may be generalized further from any mapping
- to a global mapping only; if need be.
-
- :param word: the string to find as a key in translation tables
- :param mandatory: boolean to cause failure when no key exists
- :param default: string to return if nothing is found (& not manandatory)
- :return
- value from global translation table,
- or value from local translation table,
- or the query key if finding a value is not mandatory (in this order)
- """
-
- if word is None:
- raise ValueError("word is required")
-
- # we may not agree with a remote sources use of a global term we have
- # this provides opportunity for us to override
- if word in self.local_table:
- label = self.local_table[word]
- if label in self.global_table:
- term_id = self.global_table[label]
- else:
- logger.info("Translated to '%s' but no global term_id for: '%s'", label, word) #
- term_id = label
- elif word in self.global_table:
- term_id = self.global_table[word]
- else:
- if mandatory:
- raise KeyError("Mapping required for: ", word)
- logger.warning("We have no translation for: '%s'", word)
-
- if default is not None:
- term_id = default
- else:
- term_id = word
- return term_id
diff --git a/src/koza/model/writer.py b/src/koza/model/writer.py
new file mode 100644
index 0000000..264da4d
--- /dev/null
+++ b/src/koza/model/writer.py
@@ -0,0 +1,16 @@
+from typing import List, Optional
+
+from koza.model.config.pydantic_config import PYDANTIC_CONFIG
+from koza.model.config.sssom_config import SSSOMConfig
+from koza.model.formats import OutputFormat
+from pydantic.dataclasses import dataclass
+
+
+@dataclass(config=PYDANTIC_CONFIG, frozen=True)
+class WriterConfig:
+ format: OutputFormat = OutputFormat.tsv
+ sssom_config: Optional[SSSOMConfig] = None
+ node_properties: Optional[List[str]] = None
+ edge_properties: Optional[List[str]] = None
+ min_node_count: Optional[int] = None
+ min_edge_count: Optional[int] = None
diff --git a/src/koza/runner.py b/src/koza/runner.py
new file mode 100644
index 0000000..14b6dd6
--- /dev/null
+++ b/src/koza/runner.py
@@ -0,0 +1,360 @@
+import importlib
+import sys
+from abc import ABC, abstractmethod
+from dataclasses import asdict, dataclass
+from pathlib import Path
+from types import ModuleType
+from typing import Any, Callable, Dict, Iterator, Optional
+from typing_extensions import assert_never
+
+import yaml
+from loguru import logger
+from mergedeep import merge
+
+from koza.io.writer.jsonl_writer import JSONLWriter
+from koza.io.writer.passthrough_writer import PassthroughWriter
+from koza.io.writer.tsv_writer import TSVWriter
+from koza.io.writer.writer import KozaWriter
+from koza.io.yaml_loader import UniqueIncludeLoader
+from koza.model.koza import KozaConfig
+from koza.model.transform import MapErrorEnum
+from koza.model.formats import OutputFormat
+from koza.model.source import Source
+from koza.utils.exceptions import MapItemException, NoTransformException
+
+Record = Dict[str, Any]
+Mappings = dict[str, dict[str, dict[str, str]]]
+
+
+def is_function(obj: object, attr: str):
+ return hasattr(obj, attr) and callable(getattr(obj, attr))
+
+
+@dataclass(kw_only=True)
+class KozaTransform(ABC):
+ extra_fields: Dict[str, Any]
+ writer: KozaWriter
+ mappings: Mappings
+ on_map_failure: MapErrorEnum = MapErrorEnum.warning
+
+ @property
+ @abstractmethod
+ def data(self) -> Iterator[Record]: ...
+
+ def write(self, *records: Record, writer: Optional[str] = None) -> None:
+ """Write a series of records to a writer.
+
+        The optional writer argument names a specific writer to send records to
+        (named writers are not yet implemented).
+ """
+ self.writer.write(records)
+
+ def lookup(self, name: str, map_column: str, map_name: Optional[str] = None) -> str:
+ """Look up a term in the configured mappings.
+
+        When no map name is given:
+
+            koza.lookup("name", "column")
+
+        the first mapping that contains "name" is used, with earlier mappings
+        taking precedence over later ones.
+
+        If a map name is provided, only that named mapping is consulted:
+
+            koza.lookup("name", "column", map_name="mapping_a")
+
+ """
+ try:
+ if map_name:
+ mapping = self.mappings.get(map_name, None)
+ if mapping is None:
+ raise MapItemException(f"Map {map_name} does not exist")
+
+ values = mapping.get(name, None)
+ if values is None:
+ raise MapItemException(f"No record for {name} in map {map_name}")
+
+ mapped_value = values.get(map_column, None)
+ if mapped_value is None:
+ raise MapItemException(f"No record for {name} in column {map_column} in {map_name}")
+
+ return mapped_value
+            else:
+                # Fall through to the next mapping rather than failing on the
+                # first one that is missing the key or column
+                for mapping in self.mappings.values():
+                    values = mapping.get(name, None)
+                    if values is None:
+                        continue
+
+                    mapped_value = values.get(map_column, None)
+                    if mapped_value is None:
+                        continue
+
+                    return mapped_value
+                else:
+                    raise MapItemException(f"No record found in any mapping for {name} in column {map_column}")
+ except MapItemException as e:
+ match self.on_map_failure:
+ case MapErrorEnum.error:
+ raise e
+ case MapErrorEnum.warning:
+ return name
+ case _:
+ assert_never(self.on_map_failure)
+
+ def log(self, msg: str, level: str = "info") -> None:
+ """Log a message."""
+        # loguru level names are registered in uppercase
+        logger.log(level.upper(), msg)
+
+ @property
+ @abstractmethod
+ def current_reader(self) -> str:
+ """Returns the reader for the last row read.
+
+ Useful for getting the filename of the file that a row was read from:
+
+ for row in koza.iter_rows():
+ filename = koza.current_reader.filename
+ """
+ ...
+
+
+@dataclass(kw_only=True)
+class SingleTransform(KozaTransform):
+ _data: Iterator[Record]
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def current_reader(self):
+ raise NotImplementedError()
+
+
+@dataclass(kw_only=True)
+class SerialTransform(KozaTransform):
+ @property
+ def data(self):
+ raise NotImplementedError()
+
+ @property
+ def current_reader(self):
+ raise NotImplementedError()
+
+
+class KozaRunner:
+ def __init__(
+ self,
+ data: Iterator[Record],
+ writer: KozaWriter,
+ mapping_filenames: Optional[list[str]] = None,
+ extra_transform_fields: Optional[dict[str, Any]] = None,
+ transform_record: Optional[Callable[[KozaTransform, Record], None]] = None,
+ transform: Optional[Callable[[KozaTransform], None]] = None,
+ ):
+ self.data = data
+ self.writer = writer
+ self.mapping_filenames = mapping_filenames or []
+ self.transform_record = transform_record
+ self.transform = transform
+ self.extra_transform_fields = extra_transform_fields or {}
+
+ def run_single(self):
+ fn = self.transform
+
+ if fn is None:
+ raise NoTransformException("Can only be run when `transform` is defined")
+
+ mappings = self.load_mappings()
+
+ logger.info("Running single transform")
+ transform = SingleTransform(
+ _data=self.data,
+ mappings=mappings,
+ writer=self.writer,
+ extra_fields=self.extra_transform_fields,
+ )
+ fn(transform)
+
+ def run_serial(self):
+ fn = self.transform_record
+
+ if fn is None:
+ raise NoTransformException("Can only be run when `transform_record` is defined")
+
+ mappings = self.load_mappings()
+
+ logger.info("Running serial transform")
+ transform = SerialTransform(
+ mappings=mappings,
+ writer=self.writer,
+ extra_fields=self.extra_transform_fields,
+ )
+ for item in self.data:
+ fn(transform, item)
+
+ def run(self):
+ if callable(self.transform) and callable(self.transform_record):
+ raise ValueError("Can only define one of `transform` or `transform_record`")
+ elif callable(self.transform):
+ self.run_single()
+ elif callable(self.transform_record):
+ self.run_serial()
+ else:
+ raise NoTransformException("Must define one of `transform` or `transform_record`")
+
+ self.writer.finalize()
+
+ def load_mappings(self):
+ mappings: Mappings = {}
+
+ if self.mapping_filenames:
+ logger.info("Loading mappings")
+
+ for mapping_config_filename in self.mapping_filenames:
+ # Check if a transform has been defined for the mapping
+ config, map_runner = KozaRunner.from_config_file(
+ mapping_config_filename,
+ output_format=OutputFormat.passthrough,
+ )
+ try:
+ map_runner.run()
+ data = map_runner.writer.result()
+ assert isinstance(data, list)
+ except NoTransformException:
+ data = map_runner.data
+
+ mapping_entry: dict[str, dict[str, str]] = {}
+ key_column: Optional[str] = map_runner.extra_transform_fields.get("key", None)
+ value_columns: Optional[list[str]] = map_runner.extra_transform_fields.get("values", None)
+
+ if key_column is None:
+ raise ValueError(f"Must define transform mapping key column in configuration for {config.name}")
+
+ if not isinstance(value_columns, list):
+ raise ValueError(
+                    f"Must define a list of transform mapping value columns in configuration for {config.name}"
+ )
+
+ for row in data:
+ item_key = row[key_column]
+
+ mapping_entry[str(item_key)] = {
+ key: value
+ for key, value in row.items()
+ if key in value_columns
+ }
+
+ mappings[config.name] = mapping_entry
+
+ if self.mapping_filenames:
+ logger.info("Completed loading mappings")
+
+ return mappings
+
+ @classmethod
+ def from_config(
+ cls,
+ config: KozaConfig,
+ output_dir: str = "",
+ row_limit: int = 0,
+ ):
+ module_name: Optional[str] = None
+ transform_module: Optional[ModuleType] = None
+
+ if config.transform.code:
+ transform_code_path = Path(config.transform.code)
+ parent_path = transform_code_path.absolute().parent
+ module_name = transform_code_path.stem
+ logger.debug(f"Adding `{parent_path}` to system path to load transform module")
+ sys.path.append(str(parent_path))
+ # FIXME: Remove this from sys.path
+ elif config.transform.module:
+ module_name = config.transform.module
+
+ if module_name:
+ logger.debug(f"Loading module `{module_name}`")
+ transform_module = importlib.import_module(module_name)
+
+ transform = getattr(transform_module, "transform", None)
+ if transform:
+ logger.debug(f"Found transform function `{module_name}.transform`")
+ transform_record = getattr(transform_module, "transform_record", None)
+ if transform_record:
+ logger.debug(f"Found transform function `{module_name}.transform_record`")
+ source = Source(config, row_limit)
+
+ writer: Optional[KozaWriter] = None
+
+ if config.writer.format == OutputFormat.tsv:
+ writer = TSVWriter(output_dir=output_dir, source_name=config.name, config=config.writer)
+ elif config.writer.format == OutputFormat.jsonl:
+ writer = JSONLWriter(output_dir=output_dir, source_name=config.name, config=config.writer)
+ elif config.writer.format == OutputFormat.passthrough:
+ writer = PassthroughWriter()
+
+ if writer is None:
+ raise ValueError("No writer defined")
+
+ return cls(
+ data=iter(source),
+ writer=writer,
+ mapping_filenames=config.transform.mappings,
+ extra_transform_fields=config.transform.extra_fields,
+ transform=transform,
+ transform_record=transform_record,
+ )
+
+ @classmethod
+ def from_config_file(
+ cls,
+ config_filename: str,
+ output_dir: str = "",
+ output_format: Optional[OutputFormat] = None,
+ row_limit: int = 0,
+ overrides: Optional[dict] = None,
+ ):
+ transform_code_path: Optional[Path] = None
+ config_path = Path(config_filename)
+
+ logger.info(f"Loading configuration from `{config_filename}`")
+
+ with config_path.open("r") as fh:
+ config_dict = yaml.load(fh, Loader=UniqueIncludeLoader) # noqa: S506
+ config = KozaConfig(**config_dict)
+
+ if not config.transform.code and not config.transform.module:
+
+ # If config file is named:
+ # /path/to/transform_name.yaml
+ # then look for a transform at
+ # /path/to/transform_name.py
+ mirrored_path = config_path.parent / f"{config_path.stem}.py"
+
+ # Otherwise, look for a file named transform.py in the same directory
+ transform_literal_path = config_path.parent / "transform.py"
+
+ if mirrored_path.exists():
+ transform_code_path = mirrored_path
+ elif transform_literal_path.exists():
+ transform_code_path = transform_literal_path
+
+ if transform_code_path:
+ logger.debug(f"Using transform code from `{mirrored_path}`")
+
+ # Override any necessary fields
+ config_dict = asdict(config)
+ _overrides = {}
+ if output_format is not None:
+ _overrides["writer"] = {
+ "format": output_format,
+ }
+ if transform_code_path is not None:
+ _overrides["transform"] = {
+ "code": str(transform_code_path),
+ }
+ config_dict = merge(config_dict, _overrides, overrides or {})
+ config = KozaConfig(**config_dict)
+
+ return config, cls.from_config(config, output_dir=output_dir, row_limit=row_limit)
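
End to end, the runner replaces the old transform_source entry point. A sketch mirroring how the updated integration tests drive it (the example config path already exists in this repository; the output directory is a placeholder):

    from koza.model.formats import OutputFormat
    from koza.runner import KozaRunner

    config, runner = KozaRunner.from_config_file(
        "examples/string/protein-links-detailed.yaml",
        output_dir="./output/example-run",  # hypothetical output directory
        output_format=OutputFormat.tsv,
    )
    runner.run()
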
diff --git a/src/koza/utils/exceptions.py b/src/koza/utils/exceptions.py
index 76d6de0..c5af3f6 100644
--- a/src/koza/utils/exceptions.py
+++ b/src/koza/utils/exceptions.py
@@ -12,3 +12,6 @@ class MapItemException(KeyError):
Special case of KeyError for source maps based on configuration,
a source may opt to warn or exit with an error
"""
+
+class NoTransformException(ValueError):
+ """Exception raised when a transform was not passed to KozaRunner"""
diff --git a/src/koza/utils/row_filter.py b/src/koza/utils/row_filter.py
index daa0633..823aaf0 100644
--- a/src/koza/utils/row_filter.py
+++ b/src/koza/utils/row_filter.py
@@ -1,7 +1,7 @@
from operator import eq, ge, gt, le, lt, ne
from typing import List
-from koza.model.config.source_config import ColumnFilter, FilterInclusion
+from koza.model.filters import ColumnFilter, FilterInclusion
class RowFilter:
diff --git a/src/koza/utils/testing_utils.py b/src/koza/utils/testing_utils.py
deleted file mode 100644
index 80af189..0000000
--- a/src/koza/utils/testing_utils.py
+++ /dev/null
@@ -1,83 +0,0 @@
-import types
-from typing import Union, List, Dict, Iterable
-
-import pytest
-from loguru import logger
-
-from koza.app import KozaApp
-from koza.cli_utils import get_koza_app, get_translation_table, _set_koza_app
-from koza.model.config.source_config import PrimaryFileConfig
-from koza.model.source import Source
-
-
-def test_koza(koza: KozaApp):
- """Manually sets KozaApp for testing"""
- global koza_app
- koza_app = koza
-
-
-@pytest.fixture(scope="package")
-def mock_koza():
- """Mock KozaApp for testing"""
-
- def _mock_write(self, *entities):
- if hasattr(self, "_entities"):
- self._entities.extend(list(entities))
- else:
- self._entities = list(entities)
-
- def _make_mock_koza_app(
- name: str,
- data: Iterable,
- transform_code: str,
- map_cache=None,
- filters=None,
- global_table=None,
- local_table=None,
- ):
- mock_source_file_config = PrimaryFileConfig(
- name=name,
- files=[],
- transform_code=transform_code,
- )
- mock_source_file = Source(mock_source_file_config)
- mock_source_file._reader = data
-
- _set_koza_app(
- source=mock_source_file,
- translation_table=get_translation_table(global_table, local_table, logger),
- logger=logger,
- )
- koza = get_koza_app(name)
-
- # TODO filter mocks
- koza._map_cache = map_cache
- koza.write = types.MethodType(_mock_write, koza)
-
- return koza
-
- def _transform(
- name: str,
- data: Union[Dict, List[Dict]],
- transform_code: str,
- map_cache=None,
- filters=None,
- global_table=None,
- local_table=None,
- ):
- koza_app = _make_mock_koza_app(
- name,
- iter(data) if isinstance(data, list) else iter([data]),
- transform_code,
- map_cache=map_cache,
- filters=filters,
- global_table=global_table,
- local_table=local_table,
- )
- test_koza(koza_app)
- koza_app.process_sources()
- if not hasattr(koza_app, "_entities"):
- koza_app._entities = []
- return koza_app._entities
-
- return _transform
diff --git a/tests/conftest.py b/tests/conftest.py
deleted file mode 100644
index 7e5fdb0..0000000
--- a/tests/conftest.py
+++ /dev/null
@@ -1,70 +0,0 @@
-import types
-from typing import Iterable
-
-import pytest
-from loguru import logger
-
-from koza.app import KozaApp
-from koza.utils.testing_utils import test_koza
-from koza.model.config.source_config import PrimaryFileConfig
-from koza.model.source import Source
-
-
-@pytest.fixture
-def caplog(caplog):
- handler_id = logger.add(caplog.handler, format="{message}")
- yield caplog
- logger.remove(handler_id)
-
-
-@pytest.fixture(scope="package")
-def mock_koza():
- # This should be extracted out but for quick prototyping
- def _mock_write(self, *entities):
- self._entities = list(entities)
-
- def _make_mock_koza_app(
- name: str,
- data: Iterable,
- transform_code: str,
- map_cache=None,
- filters=None,
- translation_table=None,
- ):
- mock_source_file_config = PrimaryFileConfig(
- name=name,
- files=[],
- transform_code=transform_code,
- )
- mock_source_file = Source(mock_source_file_config)
- mock_source_file._reader = data
-
- koza = KozaApp(mock_source_file)
- # TODO filter mocks
- koza.translation_table = translation_table
- koza._map_cache = map_cache
- koza.write = types.MethodType(_mock_write, koza)
-
- return koza
-
- def _transform(
- name: str,
- data: Iterable,
- transform_code: str,
- map_cache=None,
- filters=None,
- translation_table=None,
- ):
- koza_app = _make_mock_koza_app(
- name,
- data,
- transform_code,
- map_cache=map_cache,
- filters=filters,
- translation_table=translation_table,
- )
- test_koza(koza_app)
- koza_app.process_sources()
- return koza_app._entities
-
- return _transform
diff --git a/tests/integration/test_archives.py b/tests/integration/test_archives.py
deleted file mode 100644
index 0c4dc27..0000000
--- a/tests/integration/test_archives.py
+++ /dev/null
@@ -1,27 +0,0 @@
-import os
-import yaml
-from pathlib import Path
-
-from koza.io.yaml_loader import UniqueIncludeLoader
-from koza.model.config.source_config import PrimaryFileConfig
-
-
-def test_archive_targz():
- source = Path('tests/resources/string.yaml')
- unzipped_data = Path('tests/resources/source-files/string.tsv.gz')
-
- # Delete unzipped archive if it exists
- if os.path.exists(unzipped_data.absolute()):
- os.remove(unzipped_data.absolute())
-
- # Create a SourceConfig object with test config
- with open(source.absolute(), 'r') as src:
- source_config = PrimaryFileConfig(**yaml.load(src, Loader=UniqueIncludeLoader))
-
- # This method only happens after validation - force it now
- source_config.__post_init__()
-
- assert os.path.exists(unzipped_data)
-
-
-# test_archive_targz()
diff --git a/tests/integration/test_examples.py b/tests/integration/test_examples.py
index cc92ad1..7f177ea 100644
--- a/tests/integration/test_examples.py
+++ b/tests/integration/test_examples.py
@@ -6,8 +6,8 @@
import pytest
-from koza.cli_utils import transform_source
-from koza.model.config.source_config import OutputFormat
+from koza.runner import KozaRunner
+from koza.model.formats import OutputFormat
@pytest.mark.parametrize(
@@ -24,17 +24,19 @@
],
)
def test_examples(source_name, ingest, output_format):
- source_config = f"examples/{source_name}/{ingest}.yaml"
+ config_filename = f"examples/{source_name}/{ingest}.yaml"
output_suffix = str(output_format).split('.')[1]
output_dir = "./output/tests/string-test-examples"
output_files = [f"{output_dir}/{ingest}_nodes.{output_suffix}", f"{output_dir}/{ingest}_edges.{output_suffix}"]
+ for file in output_files:
+ Path(file).unlink(missing_ok=True)
- transform_source(source_config, output_dir, output_format, "examples/translation_table.yaml", None)
+ config, runner = KozaRunner.from_config_file(config_filename, output_dir, output_format)
+ runner.run()
for file in output_files:
assert Path(file).exists()
- # assert Path(file).stat().st_size > 0 # Removed this line because now node files are not
# TODO: at some point, these assertions could get more rigorous, but knowing if we have errors/exceptions is a start
diff --git a/tests/integration/test_parallel.py b/tests/integration/test_parallel.py
deleted file mode 100644
index ea06584..0000000
--- a/tests/integration/test_parallel.py
+++ /dev/null
@@ -1,42 +0,0 @@
-"""
-Test parallel transforms
-"""
-
-# import pytest
-import dask
-
-from koza.cli_utils import transform_source
-from koza.model.config.source_config import OutputFormat
-
-
-def transform(source_file):
- transform_source(
- source=source_file,
- output_dir="output/tests/string-test-parallel",
- output_format=OutputFormat.tsv,
- local_table=None,
- global_table='examples/translation_table.yaml',
- row_limit=10,
- )
- return source_file
-
-
-@dask.delayed
-def transform_string():
- return transform("examples/string/protein-links-detailed.yaml")
-
-
-@dask.delayed
-def transform_string_string_declarative():
- return transform("examples/string-declarative/declarative-protein-links-detailed.yaml")
-
-
-a = transform_string()
-b = transform_string_string_declarative()
-
-
-def test_parallel_transforms():
- results = [a, b]
-
- result = dask.delayed(print)(results)
- result.compute()
diff --git a/tests/integration/test_row_limit.py b/tests/integration/test_row_limit.py
index d6f6e10..2a8d6a0 100644
--- a/tests/integration/test_row_limit.py
+++ b/tests/integration/test_row_limit.py
@@ -3,15 +3,14 @@
Assert correct number of rows has been processed
"""
-# TODO: Parameterize row_limit, and test reading from JSON and JSONL
-# TODO: Address filter in examples/string-declarative/protein-links-detailed.yaml
-
+from pathlib import Path
import pytest
+from koza.model.formats import OutputFormat
+from koza.runner import KozaRunner
-from koza.cli_utils import transform_source
-from koza.model.config.source_config import OutputFormat
-
+# TODO: Parameterize row_limit, and test reading from JSON and JSONL
+# TODO: Address filter in examples/string-declarative/protein-links-detailed.yaml
@pytest.mark.parametrize(
"source_name, ingest, output_format, row_limit, header_len, expected_node_len, expected_edge_len",
@@ -28,28 +27,25 @@
],
)
def test_examples(source_name, ingest, output_format, row_limit, header_len, expected_node_len, expected_edge_len):
- source_config = f"examples/{source_name}/{ingest}.yaml"
+ config_filename = f"examples/{source_name}/{ingest}.yaml"
output_suffix = str(output_format).split('.')[1]
- output_dir = "./output/tests/string-test-row-limit"
+ output_dir = "./output/tests/string-test-examples"
- transform_source(
- source=source_config,
- output_dir=output_dir,
- output_format=output_format,
- global_table="examples/translation_table.yaml",
- row_limit=row_limit,
- )
+ output_files = [f"{output_dir}/{ingest}_nodes.{output_suffix}", f"{output_dir}/{ingest}_edges.{output_suffix}"]
+
+ for file in output_files:
+ Path(file).unlink(missing_ok=True)
+
+ config, runner = KozaRunner.from_config_file(config_filename, output_dir, output_format, row_limit)
+ runner.run()
# hacky check that correct number of rows was processed
# node_file = f"{output_dir}/string/{ingest}-row-limit_nodes{output_suffix}"
# edge_file = f"{output_dir}/string/{ingest}-row-limit_edges{output_suffix}"
- output_files = [f"{output_dir}/{ingest}_nodes.{output_suffix}", f"{output_dir}/{ingest}_edges.{output_suffix}"]
-
- number_of_lines = [sum(1 for line in open(output_files[0])), sum(1 for line in open(output_files[1]))]
-
- assert number_of_lines == [expected_node_len, expected_edge_len]
+ with open(output_files[0], "r") as fp:
+ assert expected_node_len == len([line for line in fp])
- # assert node_lines == expected_node_len
- # assert edge_lines == expected_edge_len
+ with open(output_files[1], "r") as fp:
+ assert expected_edge_len == len([line for line in fp])
diff --git a/tests/integration/test_validator.py b/tests/integration/test_validator.py
index 21eb8a5..3ee0055 100644
--- a/tests/integration/test_validator.py
+++ b/tests/integration/test_validator.py
@@ -12,7 +12,7 @@
import pytest
from koza.cli_utils import transform_source
-from koza.model.config.source_config import OutputFormat
+from koza.model.formats import OutputFormat
# pytest.skip("LinkML issue with `category` slot has `designates_type: true`", allow_module_level=True)
diff --git a/tests/resources/source-files/string-split.tar.gz b/tests/resources/source-files/string-split.tar.gz
new file mode 100644
index 0000000..098b82f
Binary files /dev/null and b/tests/resources/source-files/string-split.tar.gz differ
diff --git a/tests/resources/source-files/string-split.zip b/tests/resources/source-files/string-split.zip
new file mode 100644
index 0000000..fc90383
Binary files /dev/null and b/tests/resources/source-files/string-split.zip differ
diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py
index 04aeb67..273f71d 100644
--- a/tests/unit/test_config.py
+++ b/tests/unit/test_config.py
@@ -3,25 +3,10 @@
"""
-from pathlib import Path
-
import pytest
-import yaml
-
-from koza.model.config.source_config import PrimaryFileConfig
-
-base_config = Path(__file__).parent / 'resources' / 'primary-source.yaml'
-
-
-def test_source_primary_config():
- with open(base_config, 'r') as config:
- PrimaryFileConfig(**yaml.safe_load(config))
-
-
-def test_inline_local_table():
- with open(base_config, 'r') as config:
- config = PrimaryFileConfig(**yaml.safe_load(config))
- assert config.local_table["is just a little to the left of"] == "RO:123"
+from koza.model.koza import KozaConfig
+from koza.model.transform import TransformConfig
+from pydantic import TypeAdapter, ValidationError
@pytest.mark.parametrize(
@@ -36,8 +21,7 @@ def test_inline_local_table():
('exclude', 'combined_score', 'in', 0.7),
('include', 'combined_score', 'eq', ['goat', 'sheep']),
('include', 'combined_score', 'lt', ['goat', 'sheep']),
- ('include', 'combined_score', 'gte', ['goat', 'sheep']),
- ('exclude', 'is_ungulate', 'eq', 'T'),
+ ('include', 'combined_score', 'ge', ['goat', 'sheep']),
]
),
)
@@ -47,23 +31,84 @@ def test_wrong_filter_type_raises_exception(inclusion, column, filter_code, valu
value error when handed an incompatible type,
eg a string when using the lt operator
"""
- with open(base_config, 'r') as config:
- source_config = yaml.safe_load(config)
- del source_config['filters']
+ config = {
+ "filters": [
+ {
+ 'column': column,
+ 'inclusion': inclusion,
+ 'filter_code': filter_code,
+ 'value': value,
+ }
+ ],
+ }
+ with pytest.raises(ValidationError) as e:
+ TypeAdapter(TransformConfig).validate_python(config)
- source_config['filters'] = [
- {'column': column, 'inclusion': inclusion, 'filter_code': filter_code, 'value': value}
- ]
- with pytest.raises(ValueError):
- PrimaryFileConfig(**source_config)
+ for error in e.value.errors():
+ assert error["msg"].startswith("Input should be a")
-@pytest.mark.parametrize("inclusion, code", [('include', 'lgt'), ('exclude', 'ngte')])
-def test_wrong_filter_code_raises_exception(inclusion, code):
- with open(base_config, 'r') as config:
- source_config = yaml.safe_load(config)
- source_config['filters'] = [
- {'column': 'combined_score', 'inclusion': inclusion, 'filter_code': code, 'value': 70}
+@pytest.mark.parametrize(
+ "inclusion, code",
+ (
+ [
+ ('include', 'lgt'),
+ ('exclude', 'ngte'),
]
- with pytest.raises(ValueError):
- PrimaryFileConfig(**source_config)
+ ),
+)
+def test_wrong_filter_code_raises_exception(inclusion, code):
+ config = {
+ "filters": [
+ {
+ "inclusion": inclusion,
+ "filter_code": code,
+ }
+ ],
+ }
+ with pytest.raises(ValidationError) as e:
+ TypeAdapter(TransformConfig).validate_python(config)
+
+ assert e.value.error_count() == 1
+ assert e.value.errors()[0]["msg"].startswith(
+ f"Input tag '{code}' found using 'filter_code' does not match any of the expected tags:"
+ )
+
+
+def test_filter_on_nonexistent_column():
+ config = {
+ "name": "test_config",
+ "reader": {
+ "columns": ["a", "b", "c"],
+ },
+ "transform": {
+ "filters": [
+ {
+ "column": "a",
+ "inclusion": "include",
+ "filter_code": "gt",
+ "value": 0,
+ },
+ {
+ "column": "d",
+ "inclusion": "include",
+ "filter_code": "gt",
+ "value": 0,
+ },
+ {
+ "column": "e",
+ "inclusion": "include",
+ "filter_code": "gt",
+ "value": 0,
+ },
+ ],
+ },
+ }
+
+ with pytest.raises(ValidationError) as e:
+ TypeAdapter(KozaConfig).validate_python(config)
+
+ assert e.value.error_count() == 1
+ assert e.value.errors()[0]["msg"].startswith(
+ f"Value error, One or more filter columns not present in designated CSV columns: 'd', 'e'"
+ )
diff --git a/tests/unit/test_csvreader.py b/tests/unit/test_csvreader.py
index f994f20..baddcd2 100644
--- a/tests/unit/test_csvreader.py
+++ b/tests/unit/test_csvreader.py
@@ -1,9 +1,10 @@
+from io import StringIO
from pathlib import Path
import pytest
-
from koza.io.reader.csv_reader import CSVReader
-from koza.model.config.source_config import FieldType
+from koza.model.formats import InputFormat
+from koza.model.reader import CSVReaderConfig, FieldType
test_file = Path(__file__).parent.parent / 'resources' / 'source-files' / 'string.tsv'
tsv_with_footer = Path(__file__).parent.parent / 'resources' / 'source-files' / 'tsv-with-footer.tsv'
@@ -25,7 +26,12 @@
def test_no_exceptions_in_normal_case():
with open(test_file, 'r') as string_file:
- reader = CSVReader(string_file, field_type_map, delimiter=' ')
+ config = CSVReaderConfig(
+ format=InputFormat.csv,
+ field_type_map=field_type_map,
+ delimiter=' ',
+ )
+ reader = CSVReader(string_file, config)
# TODO actually test something
for _ in reader:
pass
@@ -33,8 +39,13 @@ def test_no_exceptions_in_normal_case():
def test_type_conversion():
with open(test_file, 'r') as string_file:
- reader = CSVReader(string_file, field_type_map, delimiter=' ')
- row = next(reader)
+ config = CSVReaderConfig(
+ format=InputFormat.csv,
+ field_type_map=field_type_map,
+ delimiter=' ',
+ )
+ reader = CSVReader(string_file, config)
+ row = next(iter(reader))
assert isinstance(row['protein1'], str)
assert isinstance(row['textmining'], float)
assert isinstance(row['combined_score'], int)
@@ -42,11 +53,15 @@ def test_type_conversion():
def test_field_doesnt_exist_in_file_raises_exception():
with open(test_file, 'r') as string_file:
- field_map = field_type_map.copy()
- field_map['some_field_that_doesnt_exist'] = FieldType.str
- reader = CSVReader(string_file, field_map, delimiter=' ')
+ invalid_field_type_map = field_type_map.copy()
+ invalid_field_type_map['some_field_that_doesnt_exist'] = FieldType.str
+ config = CSVReaderConfig(
+ field_type_map=invalid_field_type_map,
+ delimiter=' ',
+ )
+ reader = CSVReader(string_file, config)
with pytest.raises(ValueError):
- next(reader)
+ next(iter(reader))
def test_field_in_file_but_not_in_config_logs_warning(caplog):
@@ -55,34 +70,140 @@ def test_field_in_file_but_not_in_config_logs_warning(caplog):
:return:
"""
with open(test_file, 'r') as string_file:
- field_map = field_type_map.copy()
- del field_map['combined_score']
- reader = CSVReader(string_file, field_map, delimiter=' ')
- next(reader)
- assert caplog.records[1].levelname == 'WARNING'
- assert caplog.records[1].msg.startswith('Additional column(s) in source file')
+ missing_field_field_type_map = field_type_map.copy()
+ del missing_field_field_type_map['combined_score']
+ config = CSVReaderConfig(
+ field_type_map=missing_field_field_type_map,
+ delimiter=' ',
+ )
+ reader = CSVReader(string_file, config)
+ next(iter(reader))
+ assert caplog.records[0].levelname == 'WARNING'
+ assert caplog.records[0].msg.startswith('Additional column(s) in source file')
def test_middle_field_in_file_but_not_in_config_logs_warning(caplog):
with open(test_file, 'r') as string_file:
- field_map = field_type_map.copy()
- del field_map['cooccurence']
- reader = CSVReader(string_file, field_map, delimiter=' ')
- next(reader)
- assert caplog.records[1].levelname == 'WARNING'
+ missing_field_field_type_map = field_type_map.copy()
+ del missing_field_field_type_map['cooccurence']
+ config = CSVReaderConfig(
+ field_type_map=missing_field_field_type_map,
+ delimiter=' ',
+ )
+ reader = CSVReader(string_file, config)
+ next(iter(reader))
+ assert caplog.records[0].levelname == 'WARNING'
# assert caplog.records[1].msg.startswith('Additional columns located within configured fields')
- assert caplog.records[1].msg.startswith('Additional column(s) in source file')
+ assert caplog.records[0].msg.startswith('Additional column(s) in source file')
def test_no_field_map(caplog):
with open(test_file, 'r') as string_file:
- reader = CSVReader(string_file, delimiter=' ')
- next(reader)
+ config = CSVReaderConfig(
+ delimiter=' ',
+ )
+ reader = CSVReader(string_file, config)
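+ # Accessing reader.header consumes the header line; with no configured
+ # field_type_map, the reader then derives one from the header columns.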
+ assert reader.field_type_map is None
+ header = reader.header
+ assert len(header) == 10
+ assert reader.field_type_map is not None
+ assert header == list(reader.field_type_map.keys())
def test_no_exceptions_with_footer():
with open(tsv_with_footer, 'r') as footer_file:
- reader = CSVReader(footer_file, field_type_map, delimiter=' ', comment_char='!!')
+ config = CSVReaderConfig(
+ format=InputFormat.csv,
+ field_type_map=field_type_map,
+ delimiter=' ',
+ comment_char='!!',
+ )
+ reader = CSVReader(footer_file, config)
# TODO actually test something
for _ in reader:
pass
+
+
+def test_header_delimiter():
+ test_buffer = StringIO('a/b/c\n1,2,3\n4,5,6')
+ test_buffer.name = 'teststring'
+ config = CSVReaderConfig(
+ delimiter=',',
+ header_delimiter='/',
+ )
+ reader = CSVReader(test_buffer, config)
+ assert reader.header == ["a", "b", "c"]
+ assert [row for row in reader] == [
+ {
+ "a": "1",
+ "b": "2",
+ "c": "3",
+ },
+ {
+ "a": "4",
+ "b": "5",
+ "c": "6",
+ },
+ ]
+
+
+def test_header_prefix():
+ test_buffer = StringIO("# a|b|c")
+ test_buffer.name = 'teststring'
+ config = CSVReaderConfig(
+ header_delimiter='|',
+ header_prefix="# ",
+ )
+ reader = CSVReader(test_buffer, config)
+ assert reader.header == ["a", "b", "c"]
+
+
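+# An integer header_mode appears to give the line index of the header row,
+# i.e. how many lines to skip before reading it.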
+def test_header_skip_lines():
+ test_buffer = StringIO("skipped line 1\nskipped line 2\na,b,c")
+ test_buffer.name = 'teststring'
+ config = CSVReaderConfig(
+ header_mode=2,
+ delimiter=',',
+ )
+ reader = CSVReader(test_buffer, config)
+ assert reader.header == ["a", "b", "c"]
+
+
+def test_default_config():
+ test_buffer = StringIO("a\tb\tc\n1\t2\t3\n4\t5\t6")
+ test_buffer.name = 'teststring'
+ config = CSVReaderConfig()
+ reader = CSVReader(test_buffer, config)
+ assert [row for row in reader] == [
+ {
+ "a": "1",
+ "b": "2",
+ "c": "3",
+ },
+ {
+ "a": "4",
+ "b": "5",
+ "c": "6",
+ },
+ ]
+
+
+def test_header_with_leading_comments():
+ test_buffer = StringIO("# comment 1\n#comment 2\na,b,c")
+ test_buffer.name = 'teststring'
+ config = CSVReaderConfig(
+ comment_char='#',
+ delimiter=',',
+ )
+ reader = CSVReader(test_buffer, config)
+ assert reader.header == ["a", "b", "c"]
+
+
+def test_header_with_blank_lines():
+ test_buffer = StringIO("\n\n\n\na,b,c")
+ test_buffer.name = 'teststring'
+ config = CSVReaderConfig(
+ delimiter=',',
+ )
+ reader = CSVReader(test_buffer, config)
+ assert reader.header == ["a", "b", "c"]
diff --git a/tests/unit/test_custom_dict.py b/tests/unit/test_custom_dict.py
deleted file mode 100644
index d815c1f..0000000
--- a/tests/unit/test_custom_dict.py
+++ /dev/null
@@ -1,19 +0,0 @@
-"""
-Testing custom dictionary
-"""
-
-import pytest
-
-from koza.utils.exceptions import MapItemException
-from koza.model.map_dict import MapDict
-
-
-def test_custom_dict_exception():
- map_dict = MapDict(foo='bar')
- with pytest.raises(MapItemException):
- map_dict['bad_key']
-
-
-def test_custom_dict_get_item():
- map_dict = MapDict(foo='bar')
- assert map_dict['foo'] == 'bar'
diff --git a/tests/unit/test_filter.py b/tests/unit/test_filter.py
index 0eb446e..faa510e 100644
--- a/tests/unit/test_filter.py
+++ b/tests/unit/test_filter.py
@@ -3,11 +3,17 @@
"""
+import pydantic
import pytest
-from koza.model.config.source_config import ColumnFilter, FilterCode, FilterInclusion
+from koza.model.filters import ColumnFilter, FilterCode, FilterInclusion
from koza.utils.row_filter import RowFilter
+class Filter(pydantic.BaseModel):
+ filter: ColumnFilter
+
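+# ColumnFilter instances are built through a small wrapper model here, on the
+# assumption that ColumnFilter is only fully validated (e.g. as a tagged union)
+# when it appears as a field of another pydantic model.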
+def get_filter(**kwargs):
+ return Filter.model_validate({"filter": kwargs}).filter
@pytest.mark.parametrize(
"column, inclusion, code, value, result",
@@ -39,7 +45,8 @@
)
def test_filter(column, inclusion, code, value, result):
row = {'a': 0.3, 'b': 10, 'c': 'llama'}
- column_filter = ColumnFilter(
+
+ column_filter = get_filter(
column=column,
inclusion=FilterInclusion(inclusion),
filter_code=FilterCode(code),
@@ -56,13 +63,13 @@ def test_filter(column, inclusion, code, value, result):
[
(
[
- ColumnFilter(
+ get_filter(
column='a',
inclusion=FilterInclusion('include'),
filter_code=FilterCode('lt'),
value=0.4,
),
- ColumnFilter(
+ get_filter(
column='a',
inclusion=FilterInclusion('include'),
filter_code=FilterCode('gt'),
@@ -73,13 +80,13 @@ def test_filter(column, inclusion, code, value, result):
),
(
[
- ColumnFilter(
+ get_filter(
column='a',
inclusion=FilterInclusion('include'),
filter_code=FilterCode('lt'),
value=0.4,
),
- ColumnFilter(
+ get_filter(
column='a',
inclusion=FilterInclusion('exclude'),
filter_code=FilterCode('gt'),
@@ -90,13 +97,13 @@ def test_filter(column, inclusion, code, value, result):
),
(
[
- ColumnFilter(
+ get_filter(
column='a',
inclusion=FilterInclusion('include'),
filter_code=FilterCode('in'),
value=[0.2, 0.3, 0.4],
),
- ColumnFilter(
+ get_filter(
column='b',
inclusion=FilterInclusion('exclude'),
filter_code=FilterCode('lt'),
@@ -107,13 +114,13 @@ def test_filter(column, inclusion, code, value, result):
),
(
[
- ColumnFilter(
+ get_filter(
column='a',
inclusion=FilterInclusion('include'),
filter_code=FilterCode('in'),
value=[0.2, 0.3, 0.4],
),
- ColumnFilter(
+ get_filter(
column='b',
inclusion=FilterInclusion('exclude'),
filter_code=FilterCode('gt'),
diff --git a/tests/unit/test_io_utils.py b/tests/unit/test_io_utils.py
index fe63168..913f018 100644
--- a/tests/unit/test_io_utils.py
+++ b/tests/unit/test_io_utils.py
@@ -9,17 +9,81 @@
https://github.com/monarch-initiative/dipper/blob/682560f/tests/test_udp.py#L85
"""
-import pytest
+from pathlib import Path
+from tarfile import TarFile
+from zipfile import ZipFile
-from koza.io.utils import *
+import pytest
+from koza.io import utils as io_utils
from koza.io.utils import _sanitize_export_property
def test_404():
resource = "http://httpstat.us/404"
with pytest.raises(ValueError):
- with open_resource(resource) as _:
- pass
+ io_utils.open_resource(resource)
+
+
+def test_http():
+ resource = "https://github.com/monarch-initiative/koza/blob/8a3bab998958ecbd406c6a150cbd5c009f3f2510/tests/resources/source-files/string.tsv?raw=true"
+ resource = io_utils.open_resource(resource)
+ assert not isinstance(resource, tuple)
+
+
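+# Helper: read a SizedResource to the end, check that the reader's position
+# matches its reported size, and return the lines that were read.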
+def check_resource_completion(resource: io_utils.SizedResource):
+ assert resource.reader.tell() == 0
+ contents = [line for line in resource.reader]
+ assert resource.tell() == resource.size
+ return contents
+
+
+def test_open_zipfile():
+ resource = io_utils.open_resource("tests/resources/source-files/string-split.zip")
+ assert isinstance(resource, tuple)
+ zip_fh, resources = resource
+ assert isinstance(zip_fh, ZipFile)
+ assert zip_fh.filename == "tests/resources/source-files/string-split.zip"
+
+ resource_1 = next(resources)
+ assert resource_1.name == "string-a.tsv"
+ contents = check_resource_completion(resource_1)
+ assert len(contents) == 9
+
+ resource_2 = next(resources)
+ assert resource_2.name == "string-b.tsv"
+ contents = check_resource_completion(resource_2)
+ assert len(contents) == 11
+
+ zip_fh.close()
+
+
+def test_open_tarfile():
+ resource = io_utils.open_resource("tests/resources/source-files/string-split.tar.gz")
+ assert isinstance(resource, tuple)
+ tar_fh, resources = resource
+ assert isinstance(tar_fh, TarFile)
+ assert tar_fh.name == str(Path("tests/resources/source-files/string-split.tar.gz").absolute())
+
+ resource_1 = next(resources)
+ assert resource_1.name == "string-a.tsv"
+ contents = check_resource_completion(resource_1)
+ assert len(contents) == 9
+
+ resource_2 = next(resources)
+ assert resource_2.name == "string-b.tsv"
+ contents = check_resource_completion(resource_2)
+ assert len(contents) == 11
+
+ tar_fh.close()
+
+
+def test_open_gzip():
+ resource = io_utils.open_resource("tests/resources/source-files/ZFIN_PHENOTYPE_0.jsonl.gz")
+ assert not isinstance(resource, tuple)
+ contents = check_resource_completion(resource)
+ assert len(contents) == 10
+
+ resource.reader.close()
@pytest.mark.parametrize(
@@ -57,7 +121,7 @@ def test_build_export_row(query):
"""
Test build_export_row method.
"""
- d = build_export_row(query[0], list_delimiter="|")
+ d = io_utils.build_export_row(query[0], list_delimiter="|")
for k, v in query[1].items():
assert k in d
assert d[k] == v
@@ -117,7 +181,7 @@ def test_sanitize_export_property(query):
],
)
def test_remove_null(input, expected):
- assert remove_null(input) == expected
+ assert io_utils.remove_null(input) == expected
@pytest.mark.parametrize(
@@ -133,4 +197,4 @@ def test_remove_null(input, expected):
],
)
def test_is_null(input, expected):
- assert is_null(input) == expected
+ assert io_utils.is_null(input) == expected
diff --git a/tests/unit/test_jsonlreader.py b/tests/unit/test_jsonlreader.py
index 8e83d9a..bff142f 100644
--- a/tests/unit/test_jsonlreader.py
+++ b/tests/unit/test_jsonlreader.py
@@ -4,22 +4,33 @@
import pytest
from koza.io.reader.jsonl_reader import JSONLReader
+from koza.model.formats import InputFormat
+from koza.model.reader import JSONLReaderConfig
test_zfin_data = Path(__file__).parents[1] / 'resources' / 'source-files' / 'ZFIN_PHENOTYPE_0.jsonl.gz'
def test_normal_case():
+ config = JSONLReaderConfig(
+ format=InputFormat.jsonl,
+ files=[],
+ )
with gzip.open(test_zfin_data, 'rt') as zfin:
- jsonl_reader = JSONLReader(zfin)
- row = next(jsonl_reader)
+ jsonl_reader = JSONLReader(zfin, config)
+ row = next(iter(jsonl_reader))
assert len(row) == 6
assert row['objectId'] == 'ZFIN:ZDB-GENE-011026-1'
def test_required_property():
+ config = JSONLReaderConfig(
+ format=InputFormat.jsonl,
+ required_properties=["objectId", "evidence.publicationId"],
+ files=[],
+ )
with gzip.open(test_zfin_data, 'rt') as zfin:
- jsonl_reader = JSONLReader(zfin, required_properties=['objectId', 'evidence.publicationId'])
+ jsonl_reader = JSONLReader(zfin, config)
for row in jsonl_reader:
# assert len(row) == 1 # removed subsetter
print(row)
@@ -28,7 +39,12 @@ def test_required_property():
def test_missing_req_property_raises_exception():
+ config = JSONLReaderConfig(
+ format=InputFormat.jsonl,
+ required_properties=["objectId", "foobar"],
+ files=[],
+ )
with gzip.open(test_zfin_data, 'rt') as zfin:
- jsonl_reader = JSONLReader(zfin, ['objectId', 'foobar'])
+ jsonl_reader = JSONLReader(zfin, config)
with pytest.raises(ValueError):
- next(jsonl_reader)
+ next(iter(jsonl_reader))
diff --git a/tests/unit/test_jsonreader.py b/tests/unit/test_jsonreader.py
index f57fd27..a53e76c 100644
--- a/tests/unit/test_jsonreader.py
+++ b/tests/unit/test_jsonreader.py
@@ -2,35 +2,53 @@
from pathlib import Path
import pytest
-
from koza.io.reader.json_reader import JSONReader
+from koza.model.formats import InputFormat
+from koza.model.reader import JSONReaderConfig
-test_zfin_data = Path(__file__).parents[1] / 'resources' / 'source-files' / 'test_BGI_ZFIN.json.gz'
+test_zfin_data = Path(__file__).parents[1] / "resources" / "source-files" / "test_BGI_ZFIN.json.gz"
json_path = [
- 'data',
+ "data",
0,
]
def test_normal_case():
- with gzip.open(test_zfin_data, 'rt') as zfin:
- json_reader = JSONReader(zfin, json_path=json_path)
- row = next(json_reader)
- assert row['symbol'] == 'gdnfa'
+ config = JSONReaderConfig(
+ format=InputFormat.json,
+ json_path=json_path,
+ files=[],
+ )
+ with gzip.open(test_zfin_data, "rt") as zfin:
+ json_reader = JSONReader(zfin, config)
+ row = next(iter(json_reader))
+ assert row["symbol"] == "gdnfa"
def test_required_properties():
- with gzip.open(test_zfin_data, 'rt') as zfin:
- json_reader = JSONReader(zfin, ['name', 'basicGeneticEntity.primaryId'], json_path=json_path)
+ config = JSONReaderConfig(
+ format=InputFormat.json,
+ json_path=json_path,
+ required_properties=["name", "basicGeneticEntity.primaryId"],
+ files=[],
+ )
+ with gzip.open(test_zfin_data, "rt") as zfin:
+ json_reader = JSONReader(zfin, config)
for row in json_reader:
print(row)
- assert row['name']
- assert row['basicGeneticEntity']['primaryId']
+ assert row["name"]
+ assert row["basicGeneticEntity"]["primaryId"]
def test_missing_req_property_raises_exception():
- with gzip.open(test_zfin_data, 'rt') as zfin:
- json_reader = JSONReader(zfin, ['fake_prop'], json_path=json_path)
+ config = JSONReaderConfig(
+ format=InputFormat.json,
+ json_path=json_path,
+ required_properties=["fake_prop"],
+ files=[],
+ )
+ with gzip.open(test_zfin_data, "rt") as zfin:
+ json_reader = JSONReader(zfin, config)
with pytest.raises(ValueError):
- next(json_reader)
+ next(iter(json_reader))
diff --git a/tests/unit/test_jsonreader_row_limit.py b/tests/unit/test_jsonreader_row_limit.py
index 4645c37..e49e32c 100644
--- a/tests/unit/test_jsonreader_row_limit.py
+++ b/tests/unit/test_jsonreader_row_limit.py
@@ -4,6 +4,8 @@
import pytest
from koza.io.reader.json_reader import JSONReader
+from koza.model.formats import InputFormat
+from koza.model.reader import JSONReaderConfig
test_ddpheno = Path(__file__).parents[1] / 'resources' / 'source-files' / 'ddpheno.json.gz'
@@ -11,17 +13,28 @@
def test_normal_case():
+ config = JSONReaderConfig(
+ format=InputFormat.json,
+ json_path=json_path,
+ files=[],
+ )
with gzip.open(test_ddpheno, 'rt') as ddpheno:
- json_reader = JSONReader(ddpheno, json_path=json_path, row_limit=3)
- row = next(json_reader)
+ json_reader = JSONReader(ddpheno, config=config, row_limit=3)
+ row = next(iter(json_reader))
assert row['id'] == 'http://purl.obolibrary.org/obo/DDPHENO_0001198'
def test_required_properties():
+ config = JSONReaderConfig(
+ format=InputFormat.json,
+ required_properties=["id"],
+ json_path=json_path,
+ files=[],
+ )
with gzip.open(test_ddpheno, 'rt') as ddpheno:
row_limit = 3
row_count = 0
- json_reader = JSONReader(ddpheno, ['id'], json_path=json_path, row_limit=row_limit)
+ json_reader = JSONReader(ddpheno, config=config, row_limit=row_limit)
for row in json_reader:
row_count += 1
assert 'id' in row
@@ -29,7 +42,13 @@ def test_required_properties():
def test_missing_req_property_raises_exception():
+ config = JSONReaderConfig(
+ format=InputFormat.json,
+ required_properties=["fake_prop"],
+ json_path=json_path,
+ files=[],
+ )
with gzip.open(test_ddpheno, 'rt') as ddpheno:
- json_reader = JSONReader(ddpheno, ['fake_prop'], json_path=json_path, row_limit=3)
+ json_reader = JSONReader(ddpheno, config, row_limit=3)
with pytest.raises(ValueError):
- next(json_reader)
+ next(iter(json_reader))
diff --git a/tests/unit/test_multifile.py b/tests/unit/test_multifile.py
index aa0099e..1de5b5d 100644
--- a/tests/unit/test_multifile.py
+++ b/tests/unit/test_multifile.py
@@ -1,22 +1,14 @@
-import yaml
-from pathlib import Path
-
from koza.model.source import Source
-from koza.model.config.source_config import PrimaryFileConfig
-from koza.io.yaml_loader import UniqueIncludeLoader
-
+from koza.runner import KozaRunner
-def test_multiple_files():
- source_file = Path(__file__).parent.parent / 'resources' / 'multifile.yaml'
- row_limit = None
- with open(source_file, 'r') as source_fh:
- source_config = PrimaryFileConfig(**yaml.load(source_fh, Loader=UniqueIncludeLoader))
- if not source_config.name:
- source_config.name = Path(source_file).stem
+def test_source_with_multiple_files():
+ source_file = f"examples/string/protein-links-detailed.yaml"
+ config, runner = KozaRunner.from_config_file(source_file)
- source = Source(source_config, row_limit)
+ assert len(config.reader.files) == 2
- row_count = sum(1 for row in source)
+ source = Source(config)
+ row_count = len(list(source))
assert row_count == 15
diff --git a/tests/unit/test_runner.py b/tests/unit/test_runner.py
new file mode 100644
index 0000000..52ab495
--- /dev/null
+++ b/tests/unit/test_runner.py
@@ -0,0 +1,92 @@
+from typing import Any, Dict
+
+import pytest
+from koza.io.writer.writer import KozaWriter
+from koza.model.koza import KozaConfig
+from koza.runner import KozaRunner, KozaTransform
+from pydantic import TypeAdapter
+
+from koza.utils.exceptions import NoTransformException
+
+
+class MockWriter(KozaWriter):
+ def __init__(self):
+ self.items = []
+
+ def write(self, entities):
+ self.items += entities
+
+ def finalize(self):
+ pass
+
+
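+# KozaRunner accepts either a whole-dataset `transform` (which iterates
+# koza.data itself) or a per-record `transform_record`, but not both; see
+# test_fn_required and test_exactly_one_fn_required below.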
+def test_run_single():
+ data = iter([{"a": 1, "b": 2}])
+ writer = MockWriter()
+
+ def transform(koza: KozaTransform):
+ for record in koza.data:
+ koza.write(record)
+
+ runner = KozaRunner(data=data, writer=writer, transform=transform)
+ runner.run()
+
+ assert writer.items == [{"a": 1, "b": 2}]
+
+
+def test_run_serial():
+ data = iter([{"a": 1, "b": 2}])
+ writer = MockWriter()
+
+ def transform_record(koza: KozaTransform, record: Dict[str, Any]):
+ koza.write(record)
+
+ runner = KozaRunner(data=data, writer=writer, transform_record=transform_record)
+ runner.run()
+
+ assert writer.items == [{"a": 1, "b": 2}]
+
+
+def test_fn_required():
+ data = iter([])
+ writer = MockWriter()
+
+ with pytest.raises(NoTransformException):
+ runner = KozaRunner(data=data, writer=writer)
+ runner.run()
+
+
+def test_exactly_one_fn_required():
+ data = iter([])
+ writer = MockWriter()
+
+ def transform(koza: KozaTransform):
+ for record in koza.data:
+ koza.write(record)
+
+ def transform_record(koza: KozaTransform, record: Dict[str, Any]):
+ koza.write(record)
+
+ with pytest.raises(ValueError):
+ runner = KozaRunner(data=data, writer=writer, transform=transform, transform_record=transform_record)
+ runner.run()
+
+
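+# KozaRunner.from_config presumably imports the module named by transform.code
+# and picks up whichever hook it defines; here that is a whole-dataset
+# transform, so transform_record stays None.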
+def test_load_config():
+ config = TypeAdapter(KozaConfig).validate_python({
+ "name": "my-transform",
+ "reader": {
+ "format": "csv",
+ "files": [],
+ },
+ "transform": {
+ "code": "examples/minimal.py"
+ },
+ "writer": {
+ },
+ })
+
+ runner = KozaRunner.from_config(config)
+ assert callable(runner.transform)
+ assert runner.transform_record is None
+ assert callable(runner.run)
diff --git a/tests/unit/test_tsvwriter_node_and_edge.py b/tests/unit/test_tsvwriter_node_and_edge.py
index 9456653..fe81025 100644
--- a/tests/unit/test_tsvwriter_node_and_edge.py
+++ b/tests/unit/test_tsvwriter_node_and_edge.py
@@ -1,62 +1,75 @@
-import os
-
-from biolink_model.datamodel.pydanticmodel_v2 import Disease, Gene, GeneToDiseaseAssociation
-
-from koza.io.writer.tsv_writer import TSVWriter
-
-
-def test_tsv_writer():
- """
- Writes a test tsv file
- """
- g = Gene(id="HGNC:11603", in_taxon=["NCBITaxon:9606"], symbol="TBX4")
- d = Disease(id="MONDO:0005002", name="chronic obstructive pulmonary disease")
- a = GeneToDiseaseAssociation(
- id="uuid:5b06e86f-d768-4cd9-ac27-abe31e95ab1e",
- subject=g.id,
- object=d.id,
- predicate="biolink:contributes_to",
- knowledge_level="not_provided",
- agent_type="not_provided",
- has_count=0,
- has_total=20,
- )
- ent = [g, d, a]
-
- node_properties = ["id", "category", "symbol", "in_taxon", "provided_by", "source"]
- edge_properties = [
- "id",
- "subject",
- "predicate",
- "object",
- "category" "qualifiers",
- "has_count",
- "has_total",
- "publications",
- "provided_by",
- ]
-
- outdir = "output/tests"
- outfile = "tsvwriter-node-and-edge"
-
- t = TSVWriter(outdir, outfile, node_properties, edge_properties)
- t.write(ent)
- t.finalize()
-
- assert os.path.exists("{}/{}_nodes.tsv".format(outdir, outfile)) and os.path.exists(
- "{}/{}_edges.tsv".format(outdir, outfile)
- )
-
- # read the node and edges tsv files and confirm the expected values
- with open("{}/{}_nodes.tsv".format(outdir, outfile), "r") as f:
- lines = f.readlines()
- assert lines[1] == "HGNC:11603\tbiolink:Gene\t\tNCBITaxon:9606\t\tTBX4\n"
- assert len(lines) == 3
-
- with open("{}/{}_edges.tsv".format(outdir, outfile), "r") as f:
- lines = f.readlines()
- assert (
- lines[1].strip()
- == "uuid:5b06e86f-d768-4cd9-ac27-abe31e95ab1e\tHGNC:11603\tbiolink:contributes_to\tMONDO:0005002\t\t\t0\t20"
- )
- assert len(lines) == 2
+from pathlib import Path
+
+from biolink_model.datamodel.pydanticmodel_v2 import (Disease, Gene,
+ GeneToDiseaseAssociation)
+from koza.io.writer.tsv_writer import TSVWriter
+from koza.model.writer import WriterConfig
+
+
+def test_tsv_writer():
+ """
+ Writes a test tsv file
+ """
+ node_properties = [
+ "id",
+ "category",
+ "symbol",
+ "in_taxon",
+ "provided_by",
+ "source",
+ ]
+
+ edge_properties = [
+ "id",
+ "subject",
+ "predicate",
+ "object",
+ "category" "qualifiers",
+ "has_count",
+ "has_total",
+ "publications",
+ "provided_by",
+ ]
+
+ config = WriterConfig(node_properties=node_properties, edge_properties=edge_properties)
+
+ gene = Gene(id="HGNC:11603", in_taxon=["NCBITaxon:9606"], symbol="TBX4")
+ disease = Disease(id="MONDO:0005002", name="chronic obstructive pulmonary disease")
+ association = GeneToDiseaseAssociation(
+ id="uuid:5b06e86f-d768-4cd9-ac27-abe31e95ab1e",
+ subject=gene.id,
+ object=disease.id,
+ predicate="biolink:contributes_to",
+ knowledge_level="not_provided",
+ agent_type="not_provided",
+ has_count=0,
+ has_total=20,
+ )
+ entities = [gene, disease, association]
+
+ outdir = "output/tests"
+ source_name = "tsvwriter-node-and-edge"
+
+ t = TSVWriter(outdir, source_name, config=config)
+ t.write(entities)
+ t.finalize()
+
+ nodes_path = Path(f"{outdir}/{source_name}_nodes.tsv")
+ edges_path = Path(f"{outdir}/{source_name}_edges.tsv")
+
+ assert nodes_path.exists()
+ assert edges_path.exists()
+
+ # read the node and edges tsv files and confirm the expected values
+ with nodes_path.open("r") as fh:
+ lines = fh.readlines()
+ assert lines[1] == "HGNC:11603\tbiolink:Gene\t\tNCBITaxon:9606\t\tTBX4\n"
+ assert len(lines) == 3
+
+ with edges_path.open("r") as fh:
+ lines = fh.readlines()
+ assert (
+ lines[1].strip()
+ == "uuid:5b06e86f-d768-4cd9-ac27-abe31e95ab1e\tHGNC:11603\tbiolink:contributes_to\tMONDO:0005002\t\t\t0\t20"
+ )
+ assert len(lines) == 2
diff --git a/tests/unit/test_tsvwriter_node_only.py b/tests/unit/test_tsvwriter_node_only.py
index 0fa8eb8..43fdbf2 100644
--- a/tests/unit/test_tsvwriter_node_only.py
+++ b/tests/unit/test_tsvwriter_node_only.py
@@ -1,26 +1,30 @@
-import os
-
-from biolink_model.datamodel.pydanticmodel_v2 import Disease, Gene
-
-from koza.io.writer.tsv_writer import TSVWriter
-
-
-def test_tsv_writer():
- """
- Writes a test tsv file
- """
- g = Gene(id="HGNC:11603", name="TBX4")
- d = Disease(id="MONDO:0005002", name="chronic obstructive pulmonary disease")
-
- ent = [g, d]
-
- node_properties = ['id', 'category', 'symbol', 'in_taxon', 'provided_by', 'source']
-
- outdir = "output/tests"
- outfile = "tsvwriter-node-only"
-
- t = TSVWriter(outdir, outfile, node_properties)
- t.write(ent)
- t.finalize()
-
- assert os.path.exists("{}/{}_nodes.tsv".format(outdir, outfile))
+import os
+
+from biolink_model.datamodel.pydanticmodel_v2 import Disease, Gene
+
+from koza.io.writer.tsv_writer import TSVWriter
+from koza.model.writer import WriterConfig
+
+
+def test_tsv_writer():
+ """
+ Writes a test tsv file
+ """
+ gene = Gene(id="HGNC:11603", name="TBX4")
+ disease = Disease(id="MONDO:0005002", name="chronic obstructive pulmonary disease")
+
+ entities = [gene, disease]
+
+ node_properties = ['id', 'category', 'symbol', 'in_taxon', 'provided_by', 'source']
+
+ config = WriterConfig(node_properties=node_properties)
+
+ outdir = "output/tests"
+ source_name = "tsvwriter-node-only"
+
+ t = TSVWriter(outdir, source_name, config=config)
+ t.write(entities)
+ t.finalize()
+
+ assert os.path.exists(f"{outdir}/{source_name}_nodes.tsv")
+ assert not os.path.exists(f"{outdir}/{source_name}_edges.tsv")