From 3dbbb76c62708d7ccf1594c4b3eb522117056459 Mon Sep 17 00:00:00 2001 From: Koncopd Date: Thu, 12 Oct 2023 12:53:18 +0200 Subject: [PATCH] fix scpoli load_query_data --- scarches/models/scpoli/scpoli_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scarches/models/scpoli/scpoli_model.py b/scarches/models/scpoli/scpoli_model.py index 04135750..f49b45fb 100644 --- a/scarches/models/scpoli/scpoli_model.py +++ b/scarches/models/scpoli/scpoli_model.py @@ -740,7 +740,7 @@ def load_query_data( attr_dict, model_state_dict, var_names = cls._load_params(reference_model) adata = _validate_var_names(adata, var_names) else: - attr_dict = reference_model._get_public_attributes() + attr_dict = deepcopy(reference_model._get_public_attributes()) model_state_dict = reference_model.model.state_dict() init_params = cls._get_init_params_from_dict(attr_dict) @@ -767,7 +767,7 @@ def load_query_data( if len(condition_keys) > 1: adata.obs['conditions_combined'] = adata.obs[condition_keys].apply(lambda x: '_'.join(x), axis=1) else: - adata.obs['conditions_combined'] = adata.obs[condition_keys].apply(lambda x: '_'.join(x)) + adata.obs['conditions_combined'] = adata.obs[condition_keys] new_conditions_combined = adata.obs['conditions_combined'].unique().tolist() for item in new_conditions_combined: if item not in conditions_combined: