diff --git a/bionty/base/entities/_gene.py b/bionty/base/entities/_gene.py index c388792..54ee6b6 100644 --- a/bionty/base/entities/_gene.py +++ b/bionty/base/entities/_gene.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import Iterable from typing import TYPE_CHECKING, Literal, NamedTuple import pandas as pd @@ -96,34 +97,23 @@ def map_legacy_ids(self, values: Iterable) -> MappingResult: class EnsemblGene: - def __init__( - self, - organism: str, - version: str, - taxa: Literal[ - "vertebrates", "bacteria", "fungi", "metazoa", "plants", "all" - ] = "vertebrates", - ) -> None: + def __init__(self, organism: str, version: str) -> None: """Ensembl Gene mysql. Args: - organism: Name of the organism - version: Name of the ensembl DB version, e.g. "release-110" - taxa: The taxa of the organism to fetch genes for. + organism: a bionty.Organism object + version: name of the ensembl DB version, e.g. "release-110" """ self._import() import mysql.connector as sql from sqlalchemy import create_engine self._organism = ( - Organism(version=version, taxa=taxa).lookup().dict().get(organism) # type:ignore + Organism(version=version).lookup().dict().get(organism) # type:ignore + ) + self._url = ( + f"mysql+mysqldb://anonymous:@ensembldb.ensembl.org/{self._organism.core_db}" ) - # vertebrates and plants use different ports - if taxa == "plants": - port = 4157 - else: - port = 3306 - self._url = f"mysql+mysqldb://anonymous:@ensembldb.ensembl.org:{port}/{self._organism.core_db}" self._engine = create_engine(url=self._url) def _import(self): @@ -242,10 +232,8 @@ def add_external_db_column(df: pd.DataFrame, ext_db: str, df_col: str): df_res = df_res[~df_res["ensembl_gene_id"].isna()] # if stable_id is not ensembl_gene_id, keep a stable_id column - if not all(df_res["ensembl_gene_id"].str.startswith("ENS")): - logger.warning( - "ensembl_gene_id column not all ENS-prefixed, writing to stable_id column." - ) + if not any(df_res["ensembl_gene_id"].str.startswith("ENS")): + logger.warning("no ensembl_gene_id found, writing to table_id column.") df_res.insert(0, "stable_id", df_res.pop("ensembl_gene_id")) df_res = df_res.sort_values("stable_id").reset_index(drop=True) else: diff --git a/bionty/base/entities/_organism.py b/bionty/base/entities/_organism.py index c577361..6327364 100644 --- a/bionty/base/entities/_organism.py +++ b/bionty/base/entities/_organism.py @@ -22,7 +22,9 @@ class Organism(PublicOntology): def __init__( self, - taxa: Literal["vertebrates", "bacteria", "fungi", "metazoa", "plants", "all"] + organism: Literal[ + "vertebrates", "bacteria", "fungi", "metazoa", "plants", "all" + ] | None = None, source: Literal["ensembl", "ncbitaxon"] | None = None, version: Literal[ @@ -37,11 +39,7 @@ def __init__( | None = None, **kwargs, ): - # To support the organism kwarg being passed in getattr access in other parts of the code - # https://github.com/laminlabs/bionty/issues/163 - if kwargs.get("organism") is not None: - taxa = kwargs.pop("organism") - super().__init__(organism=taxa, source=source, version=version, **kwargs) + super().__init__(organism=organism, source=source, version=version, **kwargs) def _load_df(self) -> pd.DataFrame: if self.source == "ensembl":