diff --git a/tableone/tableone.py b/tableone/tableone.py index 948d49a..2a1a062 100644 --- a/tableone/tableone.py +++ b/tableone/tableone.py @@ -169,9 +169,9 @@ class TableOne: Run Tukey's test for far outliers. If variables are found to have far outliers, a remark will be added below the Table 1. (default: False) - auto_fill_nulls : bool, optional - Attempt to automatically handle None/Null values in categorical columns - by treating them as a category named 'None'. (default: True) + include_null : bool, optional + Include None/Null values for categorical variables by treating them as a + category level. (default: True) Attributes @@ -225,7 +225,7 @@ def __init__(self, data: pd.DataFrame, dip_test: bool = False, normal_test: bool = False, tukey_test: bool = False, pval_threshold: Optional[float] = None, - auto_fill_nulls: Optional[bool] = True) -> None: + include_null: Optional[bool] = True) -> None: # Warn about deprecated parameters handle_deprecated_parameters(labels, isnull, pval_test_name, remarks) @@ -240,7 +240,7 @@ def __init__(self, data: pd.DataFrame, htest, missing, ddof, rename, sort, limit, order, label_suffix, decimals, smd, overall, row_percent, dip_test, normal_test, tukey_test, pval_threshold, - auto_fill_nulls) + include_null) # Initialize intermediate tables self.initialize_intermediate_tables() @@ -282,12 +282,12 @@ def initialize_core_attributes(self, data, columns, categorical, continuous, gro htest, missing, ddof, rename, sort, limit, order, label_suffix, decimals, smd, overall, row_percent, dip_test, normal_test, tukey_test, pval_threshold, - auto_fill_nulls): + include_null): """ Initialize attributes. """ self._alt_labels = rename - self._auto_fill_nulls = auto_fill_nulls + self._include_null = include_null self._columns = columns if columns else data.columns.to_list() # type: ignore self._categorical = detect_categorical(data[self._columns], groupby) if categorical is None else categorical if continuous: @@ -318,7 +318,7 @@ def initialize_core_attributes(self, data, columns, categorical, continuous, gro self._tukey_test = tukey_test self._warnings = {} - if self._categorical and self._auto_fill_nulls: + if self._categorical and self._include_null: data[self._categorical] = handle_categorical_nulls(data[self._categorical]) self._groupbylvls = get_groups(data, self._groupby, self._order, self._reserved_columns) @@ -347,7 +347,7 @@ def validate_data(self, data): self.input_validator.validate(self._groupby, self._nonnormal, self._min_max, # type: ignore self._pval_adjust, self._order, self._pval, # type: ignore self._columns, self._categorical, self._continuous) # type: ignore - self.data_validator.validate(data, self._columns, self._categorical, self._auto_fill_nulls) # type: ignore + self.data_validator.validate(data, self._columns, self._categorical, self._include_null) # type: ignore def create_intermediate_tables(self, data): """ @@ -366,6 +366,7 @@ def create_intermediate_tables(self, data): self._categorical, self._decimals, self._row_percent, + self._include_null, groupby=None, groupbylvls=['Overall']) @@ -385,6 +386,7 @@ def create_intermediate_tables(self, data): self._categorical, self._decimals, self._row_percent, + self._include_null, groupby=self._groupby, groupbylvls=self._groupbylvls) @@ -413,6 +415,7 @@ def create_intermediate_tables(self, data): self._overall, self.cat_describe, self._categorical, + self._include_null, self._pval, self._pval_adjust, self.htest_table, diff --git a/tableone/tables.py b/tableone/tables.py index e35c743..b15ba65 100644 --- a/tableone/tables.py +++ b/tableone/tables.py @@ -193,6 +193,7 @@ def create_cat_describe(self, categorical, decimals, row_percent, + include_null, groupby: Optional[str] = None, groupbylvls: Optional[list] = None ) -> pd.DataFrame: @@ -223,12 +224,19 @@ def create_cat_describe(self, else: df = cat_slice.copy() - # create n column and null count column + # create n column # must be done before converting values to strings ct = df.count().to_frame(name='n') ct.index.name = 'variable' - nulls = df.isnull().sum().to_frame(name='Missing') - nulls.index.name = 'variable' + + if include_null: + # create an empty Missing column for display purposes + nulls = pd.DataFrame('', index=df.columns, columns=['Missing']) + nulls.index.name = 'variable' + else: + # Count and display null count + nulls = df.isnull().sum().to_frame(name='Missing') + nulls.index.name = 'variable' # Convert to str to handle int converted to boolean in the index. # Also avoid nans. @@ -445,6 +453,7 @@ def create_cat_table(self, overall, cat_describe, categorical, + include_null, pval, pval_adjust, htest_table, @@ -462,9 +471,14 @@ def create_cat_table(self, """ table = cat_describe['t1_summary'].copy() - # add the total count of null values across all levels - isnull = data[categorical].isnull().sum().to_frame(name='Missing') - isnull.index = isnull.index.rename('variable') + if include_null: + isnull = pd.DataFrame(index=categorical, columns=['Missing']) + isnull['Missing'] = '' + isnull.index.rename('variable', inplace=True) + else: + # add the total count of null values across all levels + isnull = data[categorical].isnull().sum().to_frame(name='Missing') + isnull.index = isnull.index.rename('variable') try: table = table.join(isnull) diff --git a/tableone/validators.py b/tableone/validators.py index d738293..879da63 100644 --- a/tableone/validators.py +++ b/tableone/validators.py @@ -12,7 +12,7 @@ def __init__(self): def validate(self, data: pd.DataFrame, columns: list, categorical: list, - auto_fill_nulls: bool) -> None: + include_null: bool) -> None: """ Check the input dataset for obvious issues. @@ -24,23 +24,6 @@ def validate(self, data: pd.DataFrame, columns: list, self.check_unique_index(data) self.check_columns_exist(data, columns) self.check_duplicate_columns(data, columns) - if categorical and not auto_fill_nulls: - self.check_categorical_none(data, categorical) - - def check_categorical_none(self, data: pd.DataFrame, categorical: List[str]): - """ - Ensure that categorical columns do not contain None values. - - Parameters: - data (pd.DataFrame): The DataFrame to check. - categorical (List[str]): The list of categorical columns to validate. - """ - contains_none = [col for col in categorical if data[col].isnull().any()] - if contains_none: - raise InputError(f"The following categorical columns contains one or more null values: {contains_none}. " - f"These must be converted to strings before processing. Either set " - f"`auto_fill_nulls = True` or manually convert nulls to strings with: " - f"data[categorical_columns] = data[categorical_columns].fillna('None')") def validate_input(self, data: pd.DataFrame): if not isinstance(data, pd.DataFrame): diff --git a/tests/unit/test_tableone.py b/tests/unit/test_tableone.py index e205528..601529b 100644 --- a/tests/unit/test_tableone.py +++ b/tests/unit/test_tableone.py @@ -216,7 +216,7 @@ def test_overall_n_and_percent_for_binary_cat_var_with_nan( """ categorical = ['likeshoney'] table = TableOne(data_sample, columns=categorical, - categorical=categorical) + categorical=categorical, include_null=False) lh = table.cat_describe.loc['likeshoney'] @@ -796,7 +796,8 @@ def test_nan_rows_not_deleted_in_categorical_columns(self): # create tableone t1 = TableOne(df, label_suffix=False, - categorical=['basket1', 'basket2', 'basket3', 'basket4']) + categorical=['basket1', 'basket2', 'basket3', 'basket4'], + include_null=False) assert all(t1.tableone.loc['basket1'].index == ['apple', 'banana', 'durian', 'lemon', @@ -1028,7 +1029,7 @@ def test_order_of_order_categorical_columns(self): # if a custom order is not specified, the categorical order # specified above should apply - t1 = TableOne(data, label_suffix=False) + t1 = TableOne(data, label_suffix=False, include_null=False) t1_expected_order = {'month': ["feb", "jan", "mar", "apr"], 'day': ["wed", "thu", "mon", "tue"]} @@ -1039,7 +1040,7 @@ def test_order_of_order_categorical_columns(self): t1_expected_order[k]) # if a desired order is set, it should override the order - t2 = TableOne(data, order=order, label_suffix=False) + t2 = TableOne(data, order=order, label_suffix=False, include_null=False) t2_expected_order = {'month': ["jan", "feb", "mar", "apr"], 'day': ["mon", "tue", "wed", "thu"]} @@ -1104,7 +1105,7 @@ def test_row_percent_false(self, data_pn): t1 = TableOne(data_pn, columns=columns, categorical=categorical, groupby=groupby, nonnormal=nonnormal, decimals=decimals, - row_percent=False) + row_percent=False, include_null=False) row1 = list(t1.tableone.loc["MechVent, n (%)"][group].values[0]) row1_expect = [0, '540 (54.0)', '468 (54.2)', '72 (52.9)'] @@ -1154,7 +1155,7 @@ def test_row_percent_true(self, data_pn): t2 = TableOne(data_pn, columns=columns, categorical=categorical, groupby=groupby, nonnormal=nonnormal, decimals=decimals, - row_percent=True) + row_percent=True, include_null=False) row1 = list(t2.tableone.loc["MechVent, n (%)"][group].values[0]) row1_expect = [0, '540 (100.0)', '468 (86.7)', '72 (13.3)'] @@ -1204,7 +1205,7 @@ def test_row_percent_true_and_overall_false(self, data_pn): t1 = TableOne(data_pn, columns=columns, overall=False, categorical=categorical, groupby=groupby, nonnormal=nonnormal, decimals=decimals, - row_percent=True) + row_percent=True, include_null=False) row1 = list(t1.tableone.loc["MechVent, n (%)"][group].values[0]) row1_expect = [0, '468 (86.7)', '72 (13.3)']