diff --git a/databricks/koalas/frame.py b/databricks/koalas/frame.py index dd80728230..76ab6511c1 100644 --- a/databricks/koalas/frame.py +++ b/databricks/koalas/frame.py @@ -53,6 +53,7 @@ else: from pandas.core.dtypes.common import _get_dtype_from_object as infer_dtype_from_object from pandas.core.accessor import CachedAccessor +import pandas.core.common as com from pandas.core.dtypes.inference import is_sequence import pyspark from pyspark import StorageLevel @@ -7194,7 +7195,9 @@ def sample( n: Optional[int] = None, frac: Optional[float] = None, replace: bool = False, + weights: Optional[Any] = None, random_state: Optional[int] = None, + axis: Optional[Any] = None, ) -> "DataFrame": """ Return a random sample of items from an axis of object. @@ -7215,14 +7218,35 @@ def sample( Fraction of axis items to return. replace : bool, default False Sample with or without replacement. + weights : ndarray-like, optional + Currently does not support Series and str. + Default 'None' results in equal probability weighting. + If passed a Series, will align with target object on index. Index + values in weights not found in sampled object will be ignored and + index values in sampled object not in weights will be assigned + weights of zero. + If called on a DataFrame, will accept the name of a column + when axis = 0. + Unless weights are a Series, weights must be same length as axis + being sampled. + If weights do not sum to 1, they will be normalized to sum to 1. + Missing values in the weights column will be treated as zero. + Infinite values not allowed. random_state : int, optional Seed for the random number generator (if int). + axis : {0 or ‘index’, 1 or ‘columns’, None}, default None + Axis to sample. Accepts axis number or name. Default is stat axis + for given data type (0 for Series and DataFrames). Returns ------- Series or DataFrame A new object of same type as caller containing the sampled items. 
+ Notes + ----- + If `frac` > 1, `replace` should be set to `True`. + Examples -------- >>> df = ks.DataFrame({'num_legs': [2, 4, 8, 0], @@ -7237,46 +7261,97 @@ def sample( spider 8 0 1 fish 0 0 8 - A random 25% sample of the ``DataFrame``. + Extract 3 random elements from the ``Series`` ``df['num_legs']``: Note that we use `random_state` to ensure the reproducibility of the examples. - >>> df.sample(frac=0.25, random_state=1) # doctest: +SKIP - num_legs num_wings num_specimen_seen - falcon 2 2 10 - fish 0 0 8 - - Extract 25% random elements from the ``Series`` ``df['num_legs']``, with replacement, - so the same items could appear more than once. - - >>> df['num_legs'].sample(frac=0.4, replace=True, random_state=1) # doctest: +SKIP + >>> df['num_legs'].sample(n=3, random_state=1).sort_index() falcon 2 - spider 8 + fish 0 spider 8 Name: num_legs, dtype: int64 - Specifying the exact number of items to return is not supported at the moment. + A random 50% sample of the ``DataFrame`` with replacement: - >>> df.sample(n=5) # doctest: +ELLIPSIS - Traceback (most recent call last): - ... - NotImplementedError: Function sample currently does not support specifying ... + >>> df.sample(frac=0.5, replace=True, random_state=1).sort_index() + num_legs num_wings num_specimen_seen + dog 4 0 2 + fish 0 0 8 """ - # Note: we don't run any of the doctests because the result can change depending on the - # system's core count. - if n is not None: - raise NotImplementedError( - "Function sample currently does not support specifying " - "exact number of items to return. Use frac instead."
- ) + axis = validate_axis(axis) + if axis == 1: + raise NotImplementedError("Function sample currently does not support axis=1.") - if frac is None: - raise ValueError("frac must be specified.") + axis_length = self.shape[axis] - sdf = self._internal.resolved_copy.spark_frame.sample( - withReplacement=replace, fraction=frac, seed=random_state - ) - return DataFrame(self._internal.with_new_sdf(sdf)) + # Process random_state argument + if LooseVersion(pd.__version__) >= LooseVersion("0.24"): + rs = com.random_state(random_state) + else: + rs = com._random_state(random_state) + + # Check weights for compliance + if weights is not None: + + # Reject a ks.Series or str: ks.Series does not support the Series.__iter__ method, + # so it cannot be converted to a pandas Series without calling to_pandas. + # Series weights are unsupported for now since they could degrade performance. + if isinstance(weights, (ks.Series, str)): + raise NotImplementedError( + "The weights parameter does not currently support the Series and str." + ) + + weights = pd.Series(weights, dtype="float64") + + if len(weights) != axis_length: + raise ValueError("Weights and axis to be sampled must be of same length") + + if (weights == np.inf).any() or (weights == -np.inf).any(): + raise ValueError("weight vector may not include `inf` values") + + if (weights < 0).any(): + raise ValueError("weight vector may not include negative values") + + # If has nan, set to zero. + weights = weights.fillna(0) + + # Renormalize if don't sum to 1 + weights_sum = weights.sum() + if weights_sum != 1: + if weights_sum != 0: + weights = weights / weights_sum + else: + raise ValueError("Invalid weights: weights sum to zero") + + weights = weights._values + + # If no frac or n, default to n=1. + if n is None and frac is None: + n = 1 + elif frac is not None and frac > 1 and not replace: + raise ValueError( + "Replace has to be set to `True` when " "upsampling the population `frac` > 1."
+ ) + elif n is not None and frac is None and n % 1 != 0: + raise ValueError("Only integers accepted as `n` values") + elif n is None and frac is not None: + n = int(round(frac * axis_length)) + elif n is not None and frac is not None: + raise ValueError("Please enter a value for `frac` OR `n`, not both") + + # Check for negative sizes + if n < 0: + raise ValueError("A negative number of rows requested. Please provide positive value.") + + # Because duplicated row selection is not currently supported. + # So if frac > 1, use the pyspark implementation. + if frac is not None and frac > 1: + sdf = self._internal.resolved_copy.spark_frame.sample( + withReplacement=replace, fraction=float(frac), seed=random_state + ) + return DataFrame(self._internal.with_new_sdf(sdf)) + locs = rs.choice(axis_length, size=n, replace=replace, p=weights) + return self.take(locs, axis=axis) def astype(self, dtype) -> "DataFrame": """ diff --git a/databricks/koalas/series.py b/databricks/koalas/series.py index 8f9d03abaf..232efdec4f 100644 --- a/databricks/koalas/series.py +++ b/databricks/koalas/series.py @@ -2855,10 +2855,19 @@ def sample( n: Optional[int] = None, frac: Optional[float] = None, replace: bool = False, + weights: Optional[Any] = None, random_state: Optional[int] = None, + axis: Optional[Any] = None, ) -> "Series": return first_series( - self.to_frame().sample(n=n, frac=frac, replace=replace, random_state=random_state) + self.to_frame().sample( + n=n, + frac=frac, + replace=replace, + weights=weights, + random_state=random_state, + axis=axis, + ) ).rename(self.name) sample.__doc__ = DataFrame.sample.__doc__ diff --git a/databricks/koalas/tests/test_dataframe.py b/databricks/koalas/tests/test_dataframe.py index b55b5c4265..076f94b5a8 100644 --- a/databricks/koalas/tests/test_dataframe.py +++ b/databricks/koalas/tests/test_dataframe.py @@ -2028,21 +2028,162 @@ def test_binary_operator_multiply(self): self.assertRaisesRegex(TypeError, ks_err_msg, lambda: 0.1 * kdf["a"]) def 
test_sample(self): - pdf = pd.DataFrame({"A": [0, 2, 4]}) + pdf = pd.DataFrame({"col1": [5, 6, 7], "col2": ["a", "b", "c"]}, index=[9, 5, 3]) kdf = ks.from_pandas(pdf) - # Make sure the tests run, but we can't check the result because they are non-deterministic. - kdf.sample(frac=0.1) - kdf.sample(frac=0.2, replace=True) - kdf.sample(frac=0.2, random_state=5) - kdf["A"].sample(frac=0.2) - kdf["A"].sample(frac=0.2, replace=True) - kdf["A"].sample(frac=0.2, random_state=5) + ### + # Check behavior of random_state argument + ### + + # Check for stability when receives seed or random state -- run 10 + # times. + for test in range(10): + seed = np.random.randint(0, 100) + self.assert_eq(kdf.sample(n=2, random_state=seed), kdf.sample(n=2, random_state=seed)) + self.assert_eq( + kdf.sample(frac=0.7, random_state=seed), kdf.sample(frac=0.7, random_state=seed) + ) + self.assert_eq( + kdf.sample(n=2, random_state=np.random.RandomState(test)), + kdf.sample(n=2, random_state=np.random.RandomState(test)), + ) + self.assert_eq( + kdf.sample(frac=0.7, random_state=np.random.RandomState(test)), + kdf.sample(frac=0.7, random_state=np.random.RandomState(test)), + ) + + # Check for error when random_state argument invalid. 
+ with self.assertRaises(ValueError): + kdf.sample(random_state="astring!") + + ### + # Check behavior of `frac` and `N` + ### + + # Giving both frac and N throws error + with self.assertRaises(ValueError): + kdf.sample(n=3, frac=0.3) + + # Check that raises right error for negative lengths + with self.assertRaises(ValueError): + kdf.sample(n=-3) + with self.assertRaises(ValueError): + kdf.sample(frac=-0.3) + # Make sure float values of `n` give error with self.assertRaises(ValueError): - kdf.sample() + kdf.sample(n=3.2) + + # Check lengths are right: with 3 rows, n = int(round(frac * 3)) + self.assertEqual(len(kdf.sample(n=2)), 2) + self.assertEqual(len(kdf.sample(frac=0.34)), 1) + self.assertEqual(len(kdf.sample(frac=0.48)), 1) + + ### + # Check weights + ### + + # Weight length must be right + with self.assertRaises(ValueError): + kdf.sample(n=3, weights=[0, 1]) + + with self.assertRaises(ValueError): + bad_weights = [0.5] * 11 + kdf.sample(n=3, weights=bad_weights) + + # Weight do not support a Series or str with self.assertRaises(NotImplementedError): - kdf.sample(n=1) + weight_series = ks.Series([0, 0.2]) + kdf.sample(n=4, weights=weight_series) + + with self.assertRaises(NotImplementedError): + kdf.sample(n=4, weights="col1") + + # Check won't accept negative weights + with self.assertRaises(ValueError): + bad_weights = [-0.1] * 3 + kdf.sample(n=3, weights=bad_weights) + + # Check inf and -inf throw errors: + with self.assertRaises(ValueError): + weights_with_inf = [0.1] * 3 + weights_with_inf[0] = np.inf + kdf.sample(n=3, weights=weights_with_inf) + + with self.assertRaises(ValueError): + weights_with_ninf = [0.1] * 3 + weights_with_ninf[0] = -np.inf + kdf.sample(n=3, weights=weights_with_ninf) + + # All zeros raises errors + zero_weights = [0] * 3 + with self.assertRaises(ValueError): + kdf.sample(n=3, weights=zero_weights) + + # All missing weights + nan_weights = [np.nan] * 3 + with self.assertRaises(ValueError): + kdf.sample(n=3, weights=nan_weights) + + # Check np.nan are replaced by zeros.
+ weights_with_nan = [np.nan] * 3 + weights_with_nan[2] = 0.5 + self.assert_eq( + kdf.sample(n=1, axis=0, weights=weights_with_nan), + pdf.sample(n=1, axis=0, weights=weights_with_nan), + ) + + # Check None are also replaced by zeros. + weights_with_None = [None] * 3 + weights_with_None[2] = 0.5 + self.assert_eq( + kdf.sample(n=1, axis=0, weights=weights_with_None), + pdf.sample(n=1, axis=0, weights=weights_with_None), + ) + + ### + # Test axis argument + ### + + # Test axis argument + pdf = pd.DataFrame({"col1": range(10), "col2": ["a"] * 10}) + kdf = ks.from_pandas(pdf) + second_column_weight = [0, 1] + + weight = [0] * 10 + weight[5] = 0.5 + self.assert_eq( + kdf.sample(n=1, axis="rows", weights=weight), + pdf.sample(n=1, axis="rows", weights=weight), + ) + self.assert_eq( + kdf.sample(n=1, axis="index", weights=weight), + pdf.sample(n=1, axis="index", weights=weight), + ) + + # Check out of range axis values + with self.assertRaises(ValueError): + kdf.sample(n=1, axis=2) + + with self.assertRaises(ValueError): + kdf.sample(n=1, axis="not_a_name") + + # Check for axis=1 raise NotImplementedError + with self.assertRaises(NotImplementedError): + kdf.sample(n=1, axis=1) + + with self.assertRaises(NotImplementedError): + kdf.sample(n=1, axis="columns") + + # Check for frac > 1 and replace + kdf = ks.DataFrame({"A": list("abc")}) + msg = "Replace has to be set to `True` when " "upsampling the population `frac` > 1." + with self.assertRaisesRegex(ValueError, msg): + kdf.sample(frac=2, replace=False) + + # Check for frac > 1 and replace + # Make sure the tests run, but we can't check the result because they are non-deterministic. 
+ kdf.sample(frac=2, replace=True) def test_add_prefix(self): pdf = pd.DataFrame({"A": [1, 2, 3, 4], "B": [3, 4, 5, 6]}, index=np.random.rand(4)) diff --git a/databricks/koalas/utils.py b/databricks/koalas/utils.py index 4c353928e8..776aa8db8e 100644 --- a/databricks/koalas/utils.py +++ b/databricks/koalas/utils.py @@ -652,7 +652,7 @@ def is_name_like_value( def validate_axis(axis=0, none_axis=0): """ Check the given axis is valid. """ # convert to numeric axis - axis = {None: none_axis, "index": 0, "columns": 1}.get(axis, axis) + axis = {None: none_axis, "index": 0, "rows": 0, "columns": 1}.get(axis, axis) if axis not in (none_axis, 0, 1): raise ValueError("No axis named {0}".format(axis)) return axis