diff --git a/.github/workflows/codespeed.yml b/.github/workflows/codespeed.yml index 15db7e6..1dd99f7 100644 --- a/.github/workflows/codespeed.yml +++ b/.github/workflows/codespeed.yml @@ -40,7 +40,7 @@ jobs: run: cd powerboxesrs && cargo codspeed build - name: Run benchmarks - uses: CodSpeedHQ/action@v1 + uses: CodSpeedHQ/action@v2 with: token: ${{ secrets.CODSPEED_TOKEN }} run: pytest bindings/tests/ --codspeed && cd powerboxesrs && cargo codspeed run diff --git a/bindings/python/powerboxes/__init__.py b/bindings/python/powerboxes/__init__.py index db94955..3b6f155 100644 --- a/bindings/python/powerboxes/__init__.py +++ b/bindings/python/powerboxes/__init__.py @@ -136,7 +136,7 @@ def parallel_iou_distance( raise TypeError(_BOXES_NOT_NP_ARRAY) if boxes1.dtype == boxes2.dtype: try: - return _dtype_to_func_parallel_iou_distance[boxes1.dtype](boxes1, boxes2) + return _dtype_to_func_parallel_iou_distance[boxes1.dtype](boxes1, boxes2) # type: ignore except KeyError: raise TypeError( f"Box dtype: {boxes1.dtype} not in supported dtypes {supported_dtypes}" @@ -167,7 +167,7 @@ def parallel_giou_distance( raise TypeError(_BOXES_NOT_NP_ARRAY) if boxes1.dtype == boxes2.dtype: try: - return _dtype_to_func_parallel_giou_distance[boxes1.dtype](boxes1, boxes2) + return _dtype_to_func_parallel_giou_distance[boxes1.dtype](boxes1, boxes2) # type: ignore except KeyError: raise TypeError( f"Box dtype: {boxes1.dtype} not in supported dtypes {supported_dtypes}" @@ -198,7 +198,7 @@ def giou_distance( raise TypeError(_BOXES_NOT_NP_ARRAY) if boxes1.dtype == boxes2.dtype: try: - return _dtype_to_func_giou_distance[boxes1.dtype](boxes1, boxes2) + return _dtype_to_func_giou_distance[boxes1.dtype](boxes1, boxes2) # type: ignore except KeyError: raise TypeError( f"Box dtype: {boxes1.dtype} not in supported dtypes {supported_dtypes}" @@ -229,7 +229,7 @@ def tiou_distance( raise TypeError(_BOXES_NOT_NP_ARRAY) if boxes1.dtype == boxes2.dtype: try: - return _dtype_to_func_tiou_distance[boxes1.dtype](boxes1, boxes2) + return _dtype_to_func_tiou_distance[boxes1.dtype](boxes1, boxes2) # type: ignore except KeyError: raise TypeError( f"Box dtype: {boxes1.dtype} not in supported dtypes {supported_dtypes}" @@ -346,7 +346,7 @@ def remove_small_boxes(boxes: npt.NDArray[T], min_size: float) -> npt.NDArray[T] if not isinstance(boxes, np.ndarray): raise TypeError(_BOXES_NOT_NP_ARRAY) try: - return _dtype_to_func_remove_small_boxes[boxes.dtype](boxes, min_size) + return _dtype_to_func_remove_small_boxes[boxes.dtype](boxes, min_size) # type: ignore except KeyError: raise TypeError( f"Box dtype: {boxes.dtype} not in supported dtypes {supported_dtypes}" @@ -365,7 +365,7 @@ def boxes_areas(boxes: npt.NDArray[T]) -> npt.NDArray[np.float64]: if not isinstance(boxes, np.ndarray): raise TypeError(_BOXES_NOT_NP_ARRAY) try: - return _dtype_to_func_box_areas[boxes.dtype](boxes) + return _dtype_to_func_box_areas[boxes.dtype](boxes) # type: ignore except KeyError: raise TypeError( f"Box dtype: {boxes.dtype} not in supported dtypes {supported_dtypes}" @@ -391,7 +391,7 @@ def box_convert(boxes: npt.NDArray[T], in_fmt: str, out_fmt: str) -> npt.NDArray if not isinstance(boxes, np.ndarray): raise TypeError(_BOXES_NOT_NP_ARRAY) try: - return _dtype_to_func_box_convert[boxes.dtype](boxes, in_fmt, out_fmt) + return _dtype_to_func_box_convert[boxes.dtype](boxes, in_fmt, out_fmt) # type: ignore except KeyError: raise TypeError( f"Box dtype: {boxes.dtype} not in supported dtypes {supported_dtypes}" @@ -438,7 +438,7 @@ def nms( if not isinstance(boxes, np.ndarray) or not isinstance(scores, np.ndarray): raise TypeError("Boxes and scores must be numpy arrays") try: - return _dtype_to_func_nms[boxes.dtype]( + return _dtype_to_func_nms[boxes.dtype]( # type: ignore boxes, scores, iou_threshold, score_threshold ) except KeyError: @@ -473,7 +473,7 @@ def rtree_nms( if not isinstance(boxes, np.ndarray) or not isinstance(scores, np.ndarray): raise TypeError("Boxes and scores must be numpy arrays") try: - return _dtype_to_func_rtree_nms[boxes.dtype]( + return _dtype_to_func_rtree_nms[boxes.dtype]( # type: ignore boxes, scores, iou_threshold, score_threshold ) except KeyError: diff --git a/bindings/src/lib.rs b/bindings/src/lib.rs index 8e6c1ea..c54b234 100644 --- a/bindings/src/lib.rs +++ b/bindings/src/lib.rs @@ -2,6 +2,7 @@ mod utils; use std::fmt::Debug; +use ndarray::Array1; use num_traits::{Bounded, Float, Num, Signed, ToPrimitive}; use numpy::{PyArray1, PyArray2, PyArray3}; use powerboxesrs::{boxes, diou, giou, iou, nms, tiou}; @@ -123,7 +124,7 @@ fn _powerboxes(_py: Python, m: &PyModule) -> PyResult<()> { #[pyfunction] fn masks_to_boxes(_py: Python, masks: &PyArray3) -> PyResult>> { let masks = preprocess_array3(masks); - let boxes = boxes::masks_to_boxes(&masks); + let boxes = boxes::masks_to_boxes(masks); let boxes_as_numpy = utils::array_to_numpy(_py, boxes).unwrap(); return Ok(boxes_as_numpy.to_owned()); } @@ -179,7 +180,7 @@ fn diou_distance_generic( boxes2: &PyArray2, ) -> PyResult>> where - T: Float + numpy::Element, + T: Num + Float + numpy::Element, { let boxes1 = preprocess_boxes(boxes1).unwrap(); let boxes2 = preprocess_boxes(boxes2).unwrap(); @@ -216,7 +217,7 @@ where { let boxes1 = preprocess_boxes(boxes1).unwrap(); let boxes2 = preprocess_boxes(boxes2).unwrap(); - let iou = iou::iou_distance(&boxes1, &boxes2); + let iou = iou::iou_distance(boxes1, boxes2); let iou_as_numpy = utils::array_to_numpy(_py, iou).unwrap(); return Ok(iou_as_numpy.to_owned()); } @@ -806,7 +807,7 @@ where )) } }; - let converted_boxes = boxes::box_convert(&boxes, &in_fmt, &out_fmt); + let converted_boxes = boxes::box_convert(&boxes, in_fmt, out_fmt); let converted_boxes_as_numpy = utils::array_to_numpy(_py, converted_boxes).unwrap(); return Ok(converted_boxes_as_numpy.to_owned()); } @@ -902,12 +903,13 @@ fn nms_generic( score_threshold: f64, ) -> PyResult>> where - T: Num + numpy::Element + PartialOrd + ToPrimitive + Copy, + T: numpy::Element + Num + PartialEq + PartialOrd + ToPrimitive + Copy, { let boxes = preprocess_boxes(boxes).unwrap(); let scores = preprocess_array1(scores); let keep = nms::nms(&boxes, &scores, iou_threshold, score_threshold); - let keep_as_numpy = utils::array_to_numpy(_py, keep).unwrap(); + let keep_as_ndarray = Array1::from(keep); + let keep_as_numpy = utils::array_to_numpy(_py, keep_as_ndarray).unwrap(); return Ok(keep_as_numpy.to_owned()); } #[pyfunction] @@ -1064,21 +1066,23 @@ fn rtree_nms_generic( score_threshold: f64, ) -> PyResult>> where - T: Num - + numpy::Element - + PartialOrd - + ToPrimitive - + Copy + T: numpy::Element + + Num + Signed + Bounded + Debug + + PartialEq + + PartialOrd + + ToPrimitive + + Copy + Sync + Send, { let boxes = preprocess_boxes(boxes).unwrap(); let scores = preprocess_array1(scores); let keep = nms::rtree_nms(&boxes, &scores, iou_threshold, score_threshold); - let keep_as_numpy = utils::array_to_numpy(_py, keep).unwrap(); + let keep_as_ndarray = Array1::from(keep); + let keep_as_numpy = utils::array_to_numpy(_py, keep_as_ndarray).unwrap(); return Ok(keep_as_numpy.to_owned()); } #[pyfunction] diff --git a/bindings/src/utils.rs b/bindings/src/utils.rs index 8ff2b6f..ae3c5a3 100644 --- a/bindings/src/utils.rs +++ b/bindings/src/utils.rs @@ -1,19 +1,44 @@ -use ndarray::{Array1, Array2, Array3, ArrayBase, OwnedRepr}; +use ndarray::{ArrayBase, Dim, OwnedRepr, ViewRepr}; use num_traits::Num; use numpy::{IntoPyArray, PyArray, PyArray1, PyArray2, PyArray3}; use pyo3::prelude::*; -pub fn array_to_numpy( +/// Converts a 2-dimensional Rust ndarray to a NumPy array. +/// +/// # Arguments +/// +/// * `py` - The Python interpreter context. +/// * `array` - The 2-dimensional Rust ndarray to convert. +/// +/// # Returns +/// +/// A reference to the converted NumPy array. +/// +/// # Example +/// +/// ```rust +/// let py = Python::acquire_gil().python(); +/// let array_2d: Array2 = Array2::ones((3, 3)); +/// let numpy_array_2d = array2_to_numpy(py, array_2d).unwrap(); +/// ``` +pub fn array_to_numpy( py: Python, array: ArrayBase, D>, -) -> PyResult<&PyArray> { - let numpy_array: &PyArray = array.into_pyarray(py); +) -> PyResult<&PyArray> +where + T: numpy::Element, + D: ndarray::Dimension, +{ + let numpy_array = array.into_pyarray(py); + return Ok(numpy_array); } -pub fn preprocess_boxes(array: &PyArray2) -> Result, PyErr> +pub fn preprocess_boxes( + array: &PyArray2, +) -> Result, Dim<[usize; 2]>>, PyErr> where - N: Num + numpy::Element + Send, + N: numpy::Element, { let array = unsafe { array.as_array() }; let array_shape = array.shape(); @@ -32,16 +57,14 @@ where } } - let array = array - .to_owned() - .into_shape((array_shape[0], array_shape[1])) - .unwrap(); return Ok(array); } -pub fn preprocess_rotated_boxes(array: &PyArray2) -> Result, PyErr> +pub fn preprocess_rotated_boxes<'a, N>( + array: &PyArray2, +) -> Result, Dim<[usize; 2]>>, PyErr> where - N: Num + numpy::Element + Send, + N: Num + numpy::Element + Send + 'a, { let array = unsafe { array.as_array() }; let array_shape = array.shape(); @@ -60,42 +83,37 @@ where } } - let array = array - .to_owned() - .into_shape((array_shape[0], array_shape[1])) - .unwrap(); return Ok(array); } -pub fn preprocess_array3(array: &PyArray3) -> Array3 +pub fn preprocess_array3<'a, N>(array: &PyArray3) -> ArrayBase, Dim<[usize; 3]>> where - N: numpy::Element, + N: numpy::Element + 'a, { - let array = unsafe { array.as_array().to_owned() }; + let array = unsafe { array.as_array() }; return array; } -pub fn preprocess_array1(array: &PyArray1) -> Array1 +pub fn preprocess_array1<'a, N>(array: &PyArray1) -> ArrayBase, Dim<[usize; 1]>> where - N: numpy::Element, + N: numpy::Element + 'a, { - let array = unsafe { array.as_array().to_owned() }; + let array: ArrayBase, ndarray::prelude::Dim<[usize; 1]>> = + unsafe { array.as_array() }; return array; } #[cfg(test)] mod tests { use super::*; - use ndarray::ArrayBase; + use ndarray::Array1; #[test] fn test_array_to_numpy() { - let data = vec![1., 2., 3., 4.]; - let array = ArrayBase::from_shape_vec((1, 4), data).unwrap(); + let array = Array1::from(vec![1., 2., 3., 4.]); Python::with_gil(|py| { let result = array_to_numpy(py, array).unwrap(); - assert_eq!(result.readonly().shape(), &[1, 4]); - assert_eq!(result.readonly().shape(), &[1, 4]); + assert_eq!(result.readonly().shape(), &[4]); }); } diff --git a/bindings/tests/test_boxes.py b/bindings/tests/test_boxes.py index 1fe38d5..fe5b27f 100644 --- a/bindings/tests/test_boxes.py +++ b/bindings/tests/test_boxes.py @@ -1,7 +1,8 @@ -from powerboxes import masks_to_boxes -import numpy as np import os + +import numpy as np from PIL import Image +from powerboxes import masks_to_boxes def test_masks_box(): diff --git a/powerboxesrs/src/boxes.rs b/powerboxesrs/src/boxes.rs index 649e5d7..bf5bcfe 100644 --- a/powerboxesrs/src/boxes.rs +++ b/powerboxesrs/src/boxes.rs @@ -1,5 +1,7 @@ -use ndarray::{Array1, Array2, Array3, Axis, Zip}; -use num_traits::{Num, ToPrimitive}; +use ndarray::{Array1, Array2, ArrayView2, ArrayView3, ArrayViewMut2, Axis, Zip}; +use num_traits::{real::Real, Num, ToPrimitive}; + +#[derive(Copy, Clone)] pub enum BoxFormat { XYXY, XYWH, @@ -10,7 +12,7 @@ pub enum BoxFormat { /// /// # Arguments /// -/// * `boxes` - A 2D array of boxes represented as an `Array2` in xyxy format. +/// * `boxes` - A 2D array of boxes represented as an `ArrayView2` in xyxy format. /// /// # Returns /// @@ -28,13 +30,14 @@ pub enum BoxFormat { /// /// assert_eq!(areas, array![4., 100.]); /// ``` -pub fn box_areas(boxes: &Array2) -> Array1 +pub fn box_areas<'a, N, BA>(boxes: BA) -> Array1 where - N: Num + PartialEq + ToPrimitive + Copy, + N: Num + PartialEq + ToPrimitive + Copy + 'a, + BA: Into>, { + let boxes = boxes.into(); let num_boxes = boxes.nrows(); let mut areas = Array1::::zeros(num_boxes); - Zip::indexed(&mut areas).for_each(|i, area| { let box1 = boxes.row(i); let area_ = (box1[2] - box1[0]) * (box1[3] - box1[1]); @@ -67,19 +70,21 @@ where /// /// assert_eq!(areas, array![4., 100.]); /// ``` -pub fn parallel_box_areas(boxes: &Array2) -> Array1 +pub fn parallel_box_areas<'a, N, BA>(boxes: BA) -> Array1 where - N: Num + PartialEq + ToPrimitive + Clone + Send + Sync + Copy, + N: Real + Send + Sync + 'a, + BA: Into>, { + let boxes = boxes.into(); let num_boxes = boxes.nrows(); let mut areas = Array1::::zeros(num_boxes); Zip::indexed(&mut areas).par_for_each(|i, area| { let box1 = boxes.row(i); - let x1 = box1[0]; - let y1 = box1[1]; - let x2 = box1[2]; - let y2 = box1[3]; + let x1: N = box1[0]; + let y1: N = box1[1]; + let x2: N = box1[2]; + let y2: N = box1[3]; let _area = (x2 - x1) * (y2 - y1); *area = _area.to_f64().unwrap(); }); @@ -110,10 +115,12 @@ where /// /// assert_eq!(result, array![[0., 0., 10., 10.]]); /// ``` -pub fn remove_small_boxes(boxes: &Array2, min_size: f64) -> Array2 +pub fn remove_small_boxes<'a, N, BA>(boxes: BA, min_size: f64) -> Array2 where - N: Num + PartialEq + ToPrimitive + Clone + Copy, + N: Num + PartialEq + Clone + PartialOrd + ToPrimitive + Copy + 'a, + BA: Into>, { + let boxes = boxes.into(); let areas = box_areas(boxes); let keep: Vec = areas .indexed_iter() @@ -123,7 +130,8 @@ where return boxes.select(Axis(0), &keep); } -/// Converts a 2D array of boxes from one format to another. +/// Converts a 2D array of boxes from one format to another, in-place. +/// This works because all box formats use 4 values in their representations. /// /// # Arguments /// @@ -131,108 +139,126 @@ where /// * `in_fmt` - The input format of the boxes. /// * `out_fmt` - The desired output format of the boxes. /// -/// # Returns -/// -/// A 2D array of boxes in the output format. -/// /// # Example /// /// ``` /// use ndarray::arr2; -/// use powerboxesrs::boxes::{BoxFormat, box_convert}; +/// use powerboxesrs::boxes::{BoxFormat, box_convert_inplace}; /// -/// let boxes = arr2(&[ +/// let mut boxes = arr2(&[ /// [10.0, 20.0, 30.0, 40.0], /// [75.0, 25.0, 100.0, 200.0], /// [100.0, 100.0, 101.0, 101.0], /// ]); -/// let in_fmt = BoxFormat::XYXY; -/// let out_fmt = BoxFormat::CXCYWH; /// let expected_output = arr2(&[ /// [20.0, 30.0, 20.0, 20.0], /// [87.5, 112.5, 25.0, 175.0], /// [100.5, 100.5, 1.0, 1.0], /// ]); -/// let output = box_convert(&boxes, &in_fmt, &out_fmt); -/// assert_eq!(output, expected_output); +/// box_convert_inplace(&mut boxes, BoxFormat::XYXY, BoxFormat::CXCYWH); +/// assert_eq!(boxes, expected_output); /// ``` -pub fn box_convert(boxes: &Array2, in_fmt: &BoxFormat, out_fmt: &BoxFormat) -> Array2 +pub fn box_convert_inplace<'a, N, BA>(boxes: BA, in_fmt: BoxFormat, out_fmt: BoxFormat) where - N: Num + PartialEq + ToPrimitive + Clone + Copy, + N: Num + PartialEq + PartialOrd + ToPrimitive + Clone + Copy + 'a, + BA: Into>, { - let num_boxes: usize = boxes.nrows(); - let mut converted_boxes = Array2::::zeros((num_boxes, 4)); - - Zip::indexed(converted_boxes.rows_mut()).for_each(|i, mut box1| { - let box2 = boxes.row(i); - match (in_fmt, out_fmt) { + boxes + .into() + .rows_mut() + .into_iter() + .for_each(|mut bx| match (in_fmt, out_fmt) { (BoxFormat::XYXY, BoxFormat::XYWH) => { - let x1 = box2[0]; - let y1 = box2[1]; - let x2 = box2[2]; - let y2 = box2[3]; - box1[0] = x1; - box1[1] = y1; - box1[2] = x2 - x1; - box1[3] = y2 - y1; + bx[2] = bx[2] - bx[0]; + bx[3] = bx[3] - bx[1]; } (BoxFormat::XYXY, BoxFormat::CXCYWH) => { - let x1 = box2[0]; - let y1 = box2[1]; - let x2 = box2[2]; - let y2 = box2[3]; - box1[0] = (x1 + x2) / (N::one() + N::one()); - box1[1] = (y1 + y2) / (N::one() + N::one()); - box1[2] = x2 - x1; - box1[3] = y2 - y1; + let x1 = bx[0]; + let y1 = bx[1]; + let x2 = bx[2]; + let y2 = bx[3]; + bx[0] = (x1 + x2) / (N::one() + N::one()); + bx[1] = (y1 + y2) / (N::one() + N::one()); + bx[2] = x2 - x1; + bx[3] = y2 - y1; } (BoxFormat::XYWH, BoxFormat::XYXY) => { - let x1 = box2[0]; - let y1 = box2[1]; - let w = box2[2]; - let h = box2[3]; - box1[0] = x1; - box1[1] = y1; - box1[2] = x1 + w; - box1[3] = y1 + h; + bx[2] = bx[0] + bx[2]; + bx[3] = bx[1] + bx[3]; } (BoxFormat::XYWH, BoxFormat::CXCYWH) => { - let x1 = box2[0]; - let y1 = box2[1]; - let w = box2[2]; - let h = box2[3]; - box1[0] = x1 + w / (N::one() + N::one()); - box1[1] = y1 + h / (N::one() + N::one()); - box1[2] = w; - box1[3] = h; + let w = bx[2]; + let h = bx[3]; + bx[0] = bx[0] + w / (N::one() + N::one()); + bx[1] = bx[1] + h / (N::one() + N::one()); + bx[2] = w; + bx[3] = h; } (BoxFormat::CXCYWH, BoxFormat::XYXY) => { - let cx = box2[0]; - let cy = box2[1]; - let w = box2[2]; - let h = box2[3]; - box1[0] = cx - w / (N::one() + N::one()); - box1[1] = cy - h / (N::one() + N::one()); - box1[2] = cx + w / (N::one() + N::one()); - box1[3] = cy + h / (N::one() + N::one()); + let cx = bx[0]; + let cy = bx[1]; + let wd2 = bx[2] / (N::one() + N::one()); + let hd2 = bx[3] / (N::one() + N::one()); + bx[0] = cx - wd2; + bx[1] = cy - hd2; + bx[2] = cx + wd2; + bx[3] = cy + hd2; } (BoxFormat::CXCYWH, BoxFormat::XYWH) => { - let cx = box2[0]; - let cy = box2[1]; - let w = box2[2]; - let h = box2[3]; - box1[0] = cx - w / (N::one() + N::one()); - box1[1] = cy - h / (N::one() + N::one()); - box1[2] = w; - box1[3] = h; + let w = bx[2]; + let h = bx[3]; + bx[0] = bx[0] - w / (N::one() + N::one()); + bx[1] = bx[1] - h / (N::one() + N::one()); + bx[2] = w; + bx[3] = h; } (BoxFormat::XYXY, BoxFormat::XYXY) => (), (BoxFormat::XYWH, BoxFormat::XYWH) => (), (BoxFormat::CXCYWH, BoxFormat::CXCYWH) => (), - } - }); - return converted_boxes; + }); +} + +/// Converts a 2D array of boxes from one format to another. +/// +/// # Arguments +/// +/// * `boxes` - A 2D array of boxes in the input format. +/// * `in_fmt` - The input format of the boxes. +/// * `out_fmt` - The desired output format of the boxes. +/// +/// # Returns +/// +/// A 2D array of boxes in the output format. +/// +/// # Example +/// +/// ``` +/// use ndarray::arr2; +/// use powerboxesrs::boxes::{BoxFormat, box_convert}; +/// +/// let boxes = arr2(&[ +/// [10.0, 20.0, 30.0, 40.0], +/// [75.0, 25.0, 100.0, 200.0], +/// [100.0, 100.0, 101.0, 101.0], +/// ]); +/// let expected_output = arr2(&[ +/// [20.0, 30.0, 20.0, 20.0], +/// [87.5, 112.5, 25.0, 175.0], +/// [100.5, 100.5, 1.0, 1.0], +/// ]); +/// let output = box_convert(&boxes, BoxFormat::XYXY, BoxFormat::CXCYWH); +/// assert_eq!(output, expected_output); +/// ``` +pub fn box_convert<'a, N, BA>(boxes: BA, in_fmt: BoxFormat, out_fmt: BoxFormat) -> Array2 +where + N: Num + PartialEq + PartialOrd + ToPrimitive + Clone + Copy + 'a, + BA: Into>, +{ + let mut converted_boxes = boxes.into().to_owned(); + box_convert_inplace(&mut converted_boxes, in_fmt, out_fmt); + converted_boxes } + /// Converts a 2D array of boxes from one format to another, in parallel. /// This function is only faster than `box_convert` for large arrays /// @@ -257,23 +283,21 @@ where /// [75.0, 25.0, 100.0, 200.0], /// [100.0, 100.0, 101.0, 101.0], /// ]); -/// let in_fmt = BoxFormat::XYXY; -/// let out_fmt = BoxFormat::CXCYWH; /// let expected_output = arr2(&[ /// [20.0, 30.0, 20.0, 20.0], /// [87.5, 112.5, 25.0, 175.0], /// [100.5, 100.5, 1.0, 1.0], /// ]); -/// let output = parallel_box_convert(&boxes, &in_fmt, &out_fmt); +/// let output = parallel_box_convert(&boxes, BoxFormat::XYXY, BoxFormat::CXCYWH); /// assert_eq!(expected_output, output); /// ``` pub fn parallel_box_convert( boxes: &Array2, - in_fmt: &BoxFormat, - out_fmt: &BoxFormat, + in_fmt: BoxFormat, + out_fmt: BoxFormat, ) -> Array2 where - N: Num + PartialEq + ToPrimitive + Clone + Sync + Send + Copy, + N: Num + PartialEq + PartialOrd + ToPrimitive + Clone + Sync + Send + Copy, { let mut converted_boxes = boxes.clone(); @@ -361,7 +385,11 @@ where /// ]); /// let boxes = masks_to_boxes(&masks); /// assert_eq!(boxes, array![[0, 0, 2, 0], [0, 1, 2, 1], [2, 1, 2, 1]]); -pub fn masks_to_boxes(masks: &Array3) -> Array2 { +pub fn masks_to_boxes<'a, MA>(masks: MA) -> Array2 +where + MA: Into>, +{ + let masks = masks.into(); let num_masks = masks.shape()[0]; let height = masks.shape()[1]; let width = masks.shape()[2]; @@ -415,7 +443,11 @@ pub fn masks_to_boxes(masks: &Array3) -> Array2 { /// /// A 1D array containing the computed areas of each rotated box. /// -pub fn rotated_box_areas(boxes: &Array2) -> Array1 { +pub fn rotated_box_areas<'a, BA>(boxes: BA) -> Array1 +where + BA: Into>, +{ + let boxes = boxes.into(); let n_boxes = boxes.nrows(); let mut areas = Array1::zeros(n_boxes); @@ -430,7 +462,7 @@ pub fn rotated_box_areas(boxes: &Array2) -> Array1 { #[cfg(test)] mod tests { use super::*; - use ndarray::{arr2, arr3, array}; + use ndarray::{arr2, arr3, array, Array3}; #[test] fn test_box_convert_xyxy_to_xywh() { let boxes = arr2(&[ @@ -445,8 +477,8 @@ mod tests { [75.0, 25.0, 25.0, 175.0], [100.0, 100.0, 1.0, 1.0], ]); - let output = box_convert(&boxes, &in_fmt, &out_fmt); - let parallel_output = parallel_box_convert(&boxes, &in_fmt, &out_fmt); + let output = box_convert(&boxes, in_fmt, out_fmt); + let parallel_output = parallel_box_convert(&boxes, in_fmt, out_fmt); assert_eq!(output, expected_output); assert_eq!(output, parallel_output); } @@ -465,8 +497,8 @@ mod tests { [87.5, 112.5, 25.0, 175.0], [100.5, 100.5, 1.0, 1.0], ]); - let output = box_convert(&boxes, &in_fmt, &out_fmt); - let parallel_output = parallel_box_convert(&boxes, &in_fmt, &out_fmt); + let output = box_convert(&boxes, in_fmt, out_fmt); + let parallel_output = parallel_box_convert(&boxes, in_fmt, out_fmt); assert_eq!(output, expected_output); assert_eq!(output, parallel_output); } @@ -485,8 +517,8 @@ mod tests { [75.0, 25.0, 100.0, 200.0], [100.0, 100.0, 101.0, 101.0], ]); - let output = box_convert(&boxes, &in_fmt, &out_fmt); - let parallel_output = parallel_box_convert(&boxes, &in_fmt, &out_fmt); + let output = box_convert(&boxes, in_fmt, out_fmt); + let parallel_output = parallel_box_convert(&boxes, in_fmt, out_fmt); assert_eq!(output, expected_output); assert_eq!(output, parallel_output); } @@ -505,8 +537,8 @@ mod tests { [87.5, 112.5, 25.0, 175.0], [100.5, 100.5, 1.0, 1.0], ]); - let output = box_convert(&boxes, &in_fmt, &out_fmt); - let parallel_output = parallel_box_convert(&boxes, &in_fmt, &out_fmt); + let output = box_convert(&boxes, in_fmt, out_fmt); + let parallel_output = parallel_box_convert(&boxes, in_fmt, out_fmt); assert_eq!(output, expected_output); assert_eq!(output, parallel_output); } @@ -525,8 +557,8 @@ mod tests { [75., 25., 100., 200.], [100., 100., 101., 101.], ]); - let output = box_convert(&boxes, &in_fmt, &out_fmt); - let parallel_output = parallel_box_convert(&boxes, &in_fmt, &out_fmt); + let output = box_convert(&boxes, in_fmt, out_fmt); + let parallel_output = parallel_box_convert(&boxes, in_fmt, out_fmt); assert_eq!(output, expected_output); assert_eq!(output, parallel_output); } @@ -545,8 +577,8 @@ mod tests { [75.0, 25.0, 25.0, 175.0], [100.0, 100.0, 1.0, 1.0], ]); - let output = box_convert(&boxes, &in_fmt, &out_fmt); - let parallel_output = parallel_box_convert(&boxes, &in_fmt, &out_fmt); + let output = box_convert(&boxes, in_fmt, out_fmt); + let parallel_output = parallel_box_convert(&boxes, in_fmt, out_fmt); assert_eq!(output, expected_output); assert_eq!(output, parallel_output); } @@ -558,14 +590,14 @@ mod tests { [75., 25., 100., 200.], [100., 100., 101., 101.], ]); - let xywh = parallel_box_convert(&boxes, &BoxFormat::XYXY, &BoxFormat::XYWH); - let cxcywh = parallel_box_convert(&xywh, &BoxFormat::XYWH, &BoxFormat::CXCYWH); + let xywh = parallel_box_convert(&boxes, BoxFormat::XYXY, BoxFormat::XYWH); + let cxcywh = parallel_box_convert(&xywh, BoxFormat::XYWH, BoxFormat::CXCYWH); assert_eq!( - parallel_box_convert(&cxcywh, &BoxFormat::CXCYWH, &BoxFormat::XYXY), + parallel_box_convert(&cxcywh, BoxFormat::CXCYWH, BoxFormat::XYXY), boxes ); assert_eq!( - parallel_box_convert(&xywh, &BoxFormat::XYWH, &BoxFormat::XYXY), + parallel_box_convert(&xywh, BoxFormat::XYWH, BoxFormat::XYXY), boxes ); } diff --git a/powerboxesrs/src/diou.rs b/powerboxesrs/src/diou.rs index f01b7b0..a2f0418 100644 --- a/powerboxesrs/src/diou.rs +++ b/powerboxesrs/src/diou.rs @@ -1,6 +1,6 @@ use crate::{boxes, utils}; -use ndarray::Array2; -use num_traits::{real::Real, Float, Num, ToPrimitive}; +use ndarray::{Array2, ArrayView2}; +use num_traits::{Float, Num, ToPrimitive}; /// Calculates the intersection over union (DIoU) distance between two sets of bounding boxes. /// https://arxiv.org/pdf/1911.08287.pdf @@ -15,13 +15,16 @@ use num_traits::{real::Real, Float, Num, ToPrimitive}; /// /// A 2D array of shape (N, M) representing the DIoU distance between each pair of bounding boxes /// ``` -pub fn diou_distance(boxes1: &Array2, boxes2: &Array2) -> Array2 +pub fn diou_distance<'a, BA, N>(boxes1: BA, boxes2: BA) -> Array2 where - N: Num + PartialOrd + ToPrimitive + Copy + Float + Real, + N: Num + PartialOrd + ToPrimitive + Float + 'a, + BA: Into>, { + let boxes1 = boxes1.into(); + let boxes2 = boxes2.into(); let num_boxes1 = boxes1.nrows(); let num_boxes2 = boxes2.nrows(); - let two = N::from(2).unwrap(); + let two = N::one() + N::one(); let mut diou_matrix = Array2::::zeros((num_boxes1, num_boxes2)); let areas_boxes1 = boxes::box_areas(&boxes1); let areas_boxes2 = boxes::box_areas(&boxes2); @@ -49,7 +52,7 @@ where let intersection = (x2 - x1) * (y2 - y1); let intersection = intersection.to_f64().unwrap(); let intersection = utils::min(intersection, utils::min(area1, area2)); - let iou = intersection / (area1 + area2 - intersection + utils::EPS); + let iou = intersection / (area1 + area2 - intersection); let center_box1 = [(a1_x1 + a1_x2) / two, (a1_y1 + a1_y2) / two]; let center_box2 = [(a2_x1 + a2_x2) / two, (a2_y1 + a2_y2) / two]; diff --git a/powerboxesrs/src/giou.rs b/powerboxesrs/src/giou.rs index 3e02068..449207b 100644 --- a/powerboxesrs/src/giou.rs +++ b/powerboxesrs/src/giou.rs @@ -1,5 +1,5 @@ -use ndarray::{Array2, Zip}; -use num_traits::{Num, ToPrimitive}; +use ndarray::{Array2, ArrayView2, Zip}; +use num_traits::{real::Real, Num, ToPrimitive}; use rstar::RTree; use crate::{ @@ -31,10 +31,13 @@ use crate::{ /// assert_eq!(giou.shape(), &[2, 3]); /// assert_eq!(giou, array![[0., 1.6800000000000002, 1.7777777777777777], [1.7777777777777777, 1.0793650793650793, 0.]]); /// ``` -pub fn giou_distance(boxes1: &Array2, boxes2: &Array2) -> Array2 +pub fn giou_distance<'a, N, BA>(boxes1: BA, boxes2: BA) -> Array2 where - N: Num + PartialOrd + ToPrimitive + Copy, + N: Num + PartialEq + PartialOrd + ToPrimitive + Copy + 'a, + BA: Into>, { + let boxes1 = boxes1.into(); + let boxes2 = boxes2.into(); let num_boxes1 = boxes1.nrows(); let num_boxes2 = boxes2.nrows(); @@ -66,7 +69,7 @@ where let intersection = (x2 - x1) * (y2 - y1); let intersection = intersection.to_f64().unwrap(); let intersection = utils::min(intersection, utils::min(area1, area2)); - let union = area1 + area2 - intersection + utils::EPS; + let union = area1 + area2 - intersection; (intersection / union, union) }; // Calculate the enclosing box (C) coordinates @@ -110,16 +113,19 @@ where /// assert_eq!(giou.shape(), &[2, 3]); /// assert_eq!(giou, array![[0., 1.6800000000000002, 1.7777777777777777], [1.7777777777777777, 1.0793650793650793, 0.]]); /// ``` -pub fn parallel_giou_distance(boxes1: &Array2, boxes2: &Array2) -> Array2 +pub fn parallel_giou_distance<'a, N, BA>(boxes1: BA, boxes2: BA) -> Array2 where - N: Num + PartialOrd + ToPrimitive + Copy + Sync + Send, + N: Real + Sync + Send + 'a, + BA: Into>, { + let boxes1 = boxes1.into(); + let boxes2 = boxes2.into(); let num_boxes1 = boxes1.nrows(); let num_boxes2 = boxes2.nrows(); let mut giou_matrix = Array2::::zeros((num_boxes1, num_boxes2)); - let areas_boxes1 = boxes::parallel_box_areas(&boxes1); - let areas_boxes2 = boxes::parallel_box_areas(&boxes2); + let areas_boxes1 = boxes::parallel_box_areas(boxes1); + let areas_boxes2 = boxes::parallel_box_areas(boxes2); Zip::indexed(giou_matrix.rows_mut()).par_for_each(|i, mut row| { let a1 = boxes1.row(i); let a1_x1 = a1[0]; @@ -146,7 +152,7 @@ where let intersection = (x2 - x1) * (y2 - y1); let intersection = intersection.to_f64().unwrap(); let intersection = utils::min(intersection, utils::min(area1, area2)); - let union = area1 + area2 - intersection + utils::EPS; + let union = area1 + area2 - intersection; (intersection / union, union) }; // Calculate the enclosing box (C) coordinates @@ -187,7 +193,12 @@ where /// The element at position (i, j) in the matrix represents the rotated Giou distance between the i-th box in `boxes1` and /// the j-th box in `boxes2`. /// -pub fn rotated_giou_distance(boxes1: &Array2, boxes2: &Array2) -> Array2 { +pub fn rotated_giou_distance<'a, BA>(boxes1: BA, boxes2: BA) -> Array2 +where + BA: Into>, +{ + let boxes1 = boxes1.into(); + let boxes2 = boxes2.into(); let num_boxes1 = boxes1.nrows(); let num_boxes2 = boxes2.nrows(); @@ -243,7 +254,7 @@ pub fn rotated_giou_distance(boxes1: &Array2, boxes2: &Array2) -> Arra let rect1 = boxes1_rects[box1.index]; let rect2 = boxes2_rects[box2.index]; let intersection = intersection_area(&rect1, &rect2); - let union = area1 + area2 - intersection + utils::EPS; + let union = area1 + area2 - intersection; // Calculate the enclosing box (C) coordinates let c_x1 = utils::min(box1.x1, box2.x1); let c_y1 = utils::min(box1.y1, box2.y1); diff --git a/powerboxesrs/src/iou.rs b/powerboxesrs/src/iou.rs index b467e5a..2547ae8 100644 --- a/powerboxesrs/src/iou.rs +++ b/powerboxesrs/src/iou.rs @@ -3,7 +3,7 @@ use crate::{ rotation::{intersection_area, minimal_bounding_rect, Rect}, utils, }; -use ndarray::{Array2, Zip}; +use ndarray::{Array2, ArrayView2, Zip}; use num_traits::{Num, ToPrimitive}; use rstar::RTree; @@ -29,10 +29,13 @@ use rstar::RTree; /// let iou = iou_distance(&boxes1, &boxes2); /// assert_eq!(iou, array![[0.8571428571428572, 1.],[1., 0.8571428571428572]]); /// ``` -pub fn iou_distance(boxes1: &Array2, boxes2: &Array2) -> Array2 +pub fn iou_distance<'a, N, BA>(boxes1: BA, boxes2: BA) -> Array2 where - N: Num + PartialOrd + ToPrimitive + Copy, + N: Num + PartialEq + PartialOrd + ToPrimitive + Copy + 'a, + BA: Into>, { + let boxes1 = boxes1.into(); + let boxes2 = boxes2.into(); let num_boxes1 = boxes1.nrows(); let num_boxes2 = boxes2.nrows(); @@ -63,8 +66,7 @@ where let intersection = (x2 - x1) * (y2 - y1); let intersection = intersection.to_f64().unwrap(); let intersection = utils::min(intersection, utils::min(area1, area2)); - iou_matrix[[i, j]] = - utils::ONE - (intersection / (area1 + area2 - intersection + utils::EPS)); + iou_matrix[[i, j]] = utils::ONE - (intersection / (area1 + area2 - intersection)); } } @@ -95,10 +97,13 @@ where /// let iou = parallel_iou_distance(&boxes1, &boxes2); /// assert_eq!(iou, array![[0.8571428571428572, 1.],[1., 0.8571428571428572]]); /// ``` -pub fn parallel_iou_distance(boxes1: &Array2, boxes2: &Array2) -> Array2 +pub fn parallel_iou_distance<'a, N, BA>(boxes1: BA, boxes2: BA) -> Array2 where - N: Num + PartialOrd + ToPrimitive + Copy + Clone + Sync + Send, + N: Num + PartialEq + PartialOrd + ToPrimitive + Send + Sync + Copy + 'a, + BA: Into>, { + let boxes1 = boxes1.into(); + let boxes2 = boxes2.into(); let num_boxes1 = boxes1.nrows(); let num_boxes2 = boxes2.nrows(); @@ -131,7 +136,7 @@ where let intersection = (x2 - x1) * (y2 - y1); let intersection = intersection.to_f64().unwrap(); let intersection = utils::min(intersection, utils::min(area1, area2)); - *d = 1. - (intersection / (area1 + area2 - intersection + utils::EPS)); + *d = 1. - (intersection / (area1 + area2 - intersection)); } }); }); @@ -151,7 +156,13 @@ where /// # Returns /// A 2D array containing the Rotated IoU distance matrix. The element at position (i, j) represents /// the Rotated IoU distance between the i-th box in `boxes1` and the j-th box in `boxes2`. -pub fn rotated_iou_distance(boxes1: &Array2, boxes2: &Array2) -> Array2 { +pub fn rotated_iou_distance<'a, BA>(boxes1: BA, boxes2: BA) -> Array2 +where + BA: Into>, +{ + let boxes1 = boxes1.into(); + let boxes2 = boxes2.into(); + let num_boxes1 = boxes1.nrows(); let num_boxes2 = boxes2.nrows(); @@ -205,7 +216,7 @@ pub fn rotated_iou_distance(boxes1: &Array2, boxes2: &Array2) -> Array let area1 = areas1[box1.index]; let area2 = areas2[box2.index]; let intersection = intersection_area(&boxes1_rects[box1.index], &boxes2_rects[box2.index]); - let union = area1 + area2 - intersection + utils::EPS; + let union = area1 + area2 - intersection; iou_matrix[[box1.index, box2.index]] = utils::ONE - intersection / union; } diff --git a/powerboxesrs/src/lib.rs b/powerboxesrs/src/lib.rs index 6f697a2..9d25a0c 100644 --- a/powerboxesrs/src/lib.rs +++ b/powerboxesrs/src/lib.rs @@ -46,10 +46,10 @@ //! - `rtree_nms`: Non-maximum suppression, returns the indices of the boxes to keep, uses a r-tree internally to avoid quadratic complexity, useful when having many boxes. //! pub mod boxes; +pub mod diou; pub mod giou; pub mod iou; pub mod nms; pub mod rotation; pub mod tiou; mod utils; -pub mod diou; diff --git a/powerboxesrs/src/nms.rs b/powerboxesrs/src/nms.rs index b19327a..1506be7 100644 --- a/powerboxesrs/src/nms.rs +++ b/powerboxesrs/src/nms.rs @@ -1,11 +1,19 @@ // Largely inspired by lsnms: https://github.com/remydubois/lsnms use std::cmp::Ordering; -use crate::{boxes, utils}; -use ndarray::{Array1, Array2, Axis}; +use crate::utils; +use ndarray::{Array1, ArrayView1, ArrayView2, Axis}; use num_traits::{Num, ToPrimitive}; use rstar::{RTree, RTreeNum, AABB}; +#[inline(always)] +pub fn area(bx: N, by: N, bxx: N, byy: N) -> N +where + N: Num + PartialEq + PartialOrd + ToPrimitive, +{ + (bxx - bx) * (byy - by) +} + /// Performs non-maximum suppression (NMS) on a set of bounding boxes using their scores and IoU. /// # Arguments /// @@ -27,71 +35,86 @@ use rstar::{RTree, RTreeNum, AABB}; /// let boxes = arr2(&[[0.0, 0.0, 2.0, 2.0], [1.0, 1.0, 3.0, 3.0]]); /// let scores = Array1::from(vec![1.0, 1.0]); /// let keep = nms(&boxes, &scores, 0.8, 0.0); -/// assert_eq!(keep, Array1::from(vec![0, 1])); +/// assert_eq!(keep, vec![0, 1]); /// ``` -pub fn nms( - boxes: &Array2, - scores: &Array1, +pub fn nms<'a, N, BA, SA>( + boxes: BA, + scores: SA, iou_threshold: f64, score_threshold: f64, -) -> Array1 +) -> Vec where - N: Num + PartialOrd + ToPrimitive + Copy, + N: Num + PartialEq + PartialOrd + ToPrimitive + Copy + PartialEq + 'a, + BA: Into>, + SA: Into>, { - let mut above_score_threshold: Vec = (0..scores.len()).collect(); - if score_threshold > utils::EPS { - // filter out boxes lower than score threshold - above_score_threshold = scores - .iter() - .enumerate() - .filter(|(_, &score)| score >= score_threshold) - .map(|(idx, _)| idx) - .collect(); - } + let boxes = boxes.into(); + let scores = scores.into(); + assert_eq!(boxes.nrows(), scores.len_of(Axis(0))); + + let order: Vec = { + let mut indices: Vec<_> = if score_threshold > utils::ZERO { + // filter out boxes lower than score threshold + scores + .iter() + .enumerate() + .filter(|(_, &score)| score >= score_threshold) + .map(|(idx, _)| idx) + .collect() + } else { + (0..scores.len()).collect() + }; + // sort box indices by scores + indices.sort_unstable_by(|&a, &b| { + scores[b].partial_cmp(&scores[a]).unwrap_or(Ordering::Equal) + }); + indices + }; - let filtered_boxes = boxes.select(Axis(0), &above_score_threshold); - // Compute areas once - let areas = boxes::box_areas(&filtered_boxes); - // sort box indices by scores - above_score_threshold - .sort_unstable_by(|&a, &b| scores[b].partial_cmp(&scores[a]).unwrap_or(Ordering::Equal)); - let order = Array1::from(above_score_threshold); let mut keep: Vec = Vec::new(); - let mut suppress = Array1::from_elem(order.len(), false); + let mut suppress = vec![false; order.len()]; - for i in 0..order.len() { - let idx = order[i]; + for (i, &idx) in order.iter().enumerate() { if suppress[i] { continue; } keep.push(idx); - let area1 = areas[i]; let box1 = boxes.row(idx); + let b1x = box1[0]; + let b1y = box1[1]; + let b1xx = box1[2]; + let b1yy = box1[3]; + let area1 = area(b1x, b1y, b1xx, b1yy); for j in (i + 1)..order.len() { - let idx_j = order[j]; if suppress[j] { continue; } - let area2 = areas[j]; - let box2 = boxes.row(idx_j); + let box2 = boxes.row(order[j]); + let b2x = box2[0]; + let b2y = box2[1]; + let b2xx = box2[2]; + let b2yy = box2[3]; - let mut iou = 0.0; - let x1 = utils::max(box1[0], box2[0]); - let x2 = utils::min(box1[2], box2[2]); - let y1 = utils::max(box1[1], box2[1]); - let y2 = utils::min(box1[3], box2[3]); - if y2 > y1 && x2 > x1 { - let intersection = (x2 - x1) * (y2 - y1); - let intersection = intersection.to_f64().unwrap(); - let intersection = utils::min(intersection, utils::min(area1, area2)); - iou = intersection / (area1 + area2 - intersection + utils::EPS); - } + // Intersection-over-union + let x = utils::max(b1x, b2x); + let y = utils::max(b1y, b2y); + let xx = utils::min(b1xx, b2xx); + let yy = utils::min(b1yy, b2yy); + if x > xx || y > yy { + // Boxes are not intersecting at all + continue; + }; + // Boxes are intersecting + let intersection: N = area(x, y, xx, yy); + let area2: N = area(b2x, b2y, b2xx, b2yy); + let union: N = area1 + area2 - intersection; + let iou: f64 = intersection.to_f64().unwrap() / union.to_f64().unwrap(); if iou > iou_threshold { suppress[j] = true; } } } - return Array1::from(keep); + keep } /// Performs non-maximum suppression (NMS) on a set of bounding using their score and IoU. @@ -119,37 +142,44 @@ where /// let boxes = arr2(&[[0.0, 0.0, 2.0, 2.0], [1.0, 1.0, 3.0, 3.0]]); /// let scores = Array1::from(vec![1.0, 1.0]); /// let keep = rtree_nms(&boxes, &scores, 0.8, 0.0); -/// assert_eq!(keep, Array1::from(vec![0, 1])); +/// assert_eq!(keep, vec![0, 1]); /// ``` -pub fn rtree_nms( - boxes: &Array2, - scores: &Array1, +pub fn rtree_nms<'a, N, BA, SA>( + boxes: BA, + scores: SA, iou_threshold: f64, score_threshold: f64, -) -> Array1 +) -> Vec where - N: RTreeNum + ToPrimitive + Send + Sync, + N: RTreeNum + PartialEq + PartialOrd + ToPrimitive + Copy + PartialEq + Send + Sync + 'a, + BA: Into>, + SA: Into>, { - let mut above_score_threshold: Vec = (0..scores.len()).collect(); - if score_threshold > utils::EPS { - // filter out boxes lower than score threshold - above_score_threshold = scores - .iter() - .enumerate() - .filter(|(_, &score)| score >= score_threshold) - .map(|(idx, _)| idx) - .collect(); - } - // Compute areas once - let areas = boxes::box_areas(&boxes); - // sort box indices by scores - above_score_threshold - .sort_unstable_by(|&a, &b| scores[b].partial_cmp(&scores[a]).unwrap_or(Ordering::Equal)); - let order = Array1::from(above_score_threshold); + let scores = scores.into(); + let boxes = boxes.into(); + let order: Vec = { + let mut indices: Vec<_> = if score_threshold > utils::ZERO { + // filter out boxes lower than score threshold + scores + .iter() + .enumerate() + .filter(|(_, &score)| score >= score_threshold) + .map(|(idx, _)| idx) + .collect() + } else { + (0..scores.len()).collect() + }; + // sort box indices by scores + indices.sort_unstable_by(|&a, &b| { + scores[b].partial_cmp(&scores[a]).unwrap_or(Ordering::Equal) + }); + indices + }; + let mut keep: Vec = Vec::new(); let mut suppress = Array1::from_elem(scores.len(), false); - // build rtree + // build rtree let rtree: RTree> = RTree::bulk_load( order .iter() @@ -171,37 +201,45 @@ where continue; } keep.push(idx); - let area1 = areas[i]; let box1 = boxes.row(idx); - - for bbox in rtree.locate_in_envelope_intersecting(&AABB::from_corners( - [box1[0], box1[1]], - [box1[2], box1[3]], - )) { + let b1x = box1[0]; + let b1y = box1[1]; + let b1xx = box1[2]; + let b1yy = box1[3]; + let area1 = area(b1x, b1y, b1xx, b1yy); + for bbox in + rtree.locate_in_envelope_intersecting(&AABB::from_corners([b1x, b1y], [b1xx, b1yy])) + { let idx_j = bbox.index; if suppress[idx_j] { continue; } - let area2 = areas[idx_j]; let box2 = boxes.row(idx_j); + let b2x = box2[0]; + let b2y = box2[1]; + let b2xx = box2[2]; + let b2yy = box2[3]; - let mut iou = 0.0; - let x1 = utils::max(box1[0], box2[0]); - let x2 = utils::min(box1[2], box2[2]); - let y1 = utils::max(box1[1], box2[1]); - let y2 = utils::min(box1[3], box2[3]); - if y2 > y1 && x2 > x1 { - let intersection = (x2 - x1) * (y2 - y1); - let intersection = intersection.to_f64().unwrap(); - let intersection = f64::min(intersection, f64::min(area1, area2)); - iou = intersection / (area1 + area2 - intersection + utils::EPS); - } + // Intersection-over-union + let x = utils::max(b1x, b2x); + let y = utils::max(b1y, b2y); + let xx = utils::min(b1xx, b2xx); + let yy = utils::min(b1yy, b2yy); + if x > xx || y > yy { + // Boxes are not intersecting at all + continue; + }; + // Boxes are intersecting + let intersection: N = area(x, y, xx, yy); + let area2: N = area(b2x, b2y, b2xx, b2yy); + let union: N = area1 + area2 - intersection; + let iou: f64 = intersection.to_f64().unwrap() / union.to_f64().unwrap(); if iou > iou_threshold { suppress[idx_j] = true; } } } - return Array1::from(keep); + keep } #[cfg(test)] @@ -224,7 +262,7 @@ mod tests { let keep = nms(&boxes, &scores, 0.5, 0.0); let keep_rtree = rtree_nms(&boxes, &scores, 0.5, 0.0); - assert_eq!(keep, Array1::from(vec![0, 2, 4])); + assert_eq!(keep, vec![0, 2, 4]); assert_eq!(keep_rtree, keep); } @@ -236,7 +274,7 @@ mod tests { let keep = nms(&boxes, &scores, 0.5, 1.0); let keep_rtree = rtree_nms(&boxes, &scores, 0.5, 1.0); - assert_eq!(keep, Array1::from(vec![])); + assert_eq!(keep, vec![]); assert_eq!(keep, keep_rtree) } @@ -247,7 +285,7 @@ mod tests { let scores = Array1::from(vec![0.0, 1.0]); let keep = nms(&boxes, &scores, 0.5, 0.5); let keep_rtree = rtree_nms(&boxes, &scores, 0.5, 0.5); - assert_eq!(keep, Array1::from(vec![1])); + assert_eq!(keep, vec![1]); assert_eq!(keep, keep_rtree) } @@ -258,7 +296,7 @@ mod tests { let scores = Array1::from(vec![1.0, 1.0]); let keep = nms(&boxes, &scores, 0.8, 0.0); let keep_rtree = rtree_nms(&boxes, &scores, 0.8, 0.0); - assert_eq!(keep, Array1::from(vec![0, 1])); + assert_eq!(keep, vec![0, 1]); assert_eq!(keep, keep_rtree) } } diff --git a/powerboxesrs/src/tiou.rs b/powerboxesrs/src/tiou.rs index 21c2efa..929a6f2 100644 --- a/powerboxesrs/src/tiou.rs +++ b/powerboxesrs/src/tiou.rs @@ -1,4 +1,4 @@ -use ndarray::Array2; +use ndarray::{Array2, ArrayView2}; use num_traits::{Num, ToPrimitive}; use crate::{ @@ -31,16 +31,19 @@ use crate::{ /// assert_eq!(tiou.shape(), &[2, 3]); /// assert_eq!(tiou, array![[0., 0.84, 0.8888888888888888], [0.8888888888888888, 0.5555555555555556, 0.]]); /// ``` -pub fn tiou_distance(boxes1: &Array2, boxes2: &Array2) -> Array2 +pub fn tiou_distance<'a, N, BA>(boxes1: BA, boxes2: BA) -> Array2 where - N: Num + PartialOrd + ToPrimitive + Copy, + N: Num + PartialEq + PartialOrd + ToPrimitive + Copy + 'a, + BA: Into>, { + let boxes1 = boxes1.into(); + let boxes2 = boxes2.into(); let num_boxes1 = boxes1.nrows(); let num_boxes2 = boxes2.nrows(); let mut tiou_matrix = Array2::::zeros((num_boxes1, num_boxes2)); - let areas_boxes1 = boxes::box_areas(&boxes1); - let areas_boxes2 = boxes::box_areas(&boxes2); + let areas_boxes1 = boxes::box_areas(boxes1); + let areas_boxes2 = boxes::box_areas(boxes2); let boxes1_vecs: Vec<(N, N, N, N)> = boxes1 .rows() .into_iter() @@ -95,13 +98,18 @@ where /// The element at position (i, j) in the matrix represents the rotated Giou distance between the i-th box in `boxes1` and /// the j-th box in `boxes2`. /// -pub fn rotated_tiou_distance(boxes1: &Array2, boxes2: &Array2) -> Array2 { +pub fn rotated_tiou_distance<'a, BA>(boxes1: BA, boxes2: BA) -> Array2 +where + BA: Into>, +{ + let boxes1 = boxes1.into(); + let boxes2 = boxes2.into(); let num_boxes1 = boxes1.nrows(); let num_boxes2 = boxes2.nrows(); let mut iou_matrix = Array2::::ones((num_boxes1, num_boxes2)); - let areas1 = rotated_box_areas(&boxes1); - let areas2 = rotated_box_areas(&boxes2); + let areas1 = rotated_box_areas(boxes1); + let areas2 = rotated_box_areas(boxes2); let boxes1_rects: Vec<(f64, f64, f64, f64)> = boxes1 .rows() diff --git a/powerboxesrs/src/utils.rs b/powerboxesrs/src/utils.rs index 91c8322..c40c183 100644 --- a/powerboxesrs/src/utils.rs +++ b/powerboxesrs/src/utils.rs @@ -1,7 +1,6 @@ use num_traits::{Num, ToPrimitive}; use rstar::{RStarInsertionStrategy, RTreeNum, RTreeObject, RTreeParams, AABB}; -pub const EPS: f64 = 1e-16; pub const ONE: f64 = 1.0; pub const ZERO: f64 = 0.0;