Skip to content

Commit

Permalink
Merge pull request #1610 from mikel-brostrom/move-per-class-decorator…
Browse files Browse the repository at this point in the history
…-into-basetracker

per class decorator moved into basetracker
  • Loading branch information
mikel-brostrom authored Sep 5, 2024
2 parents 69f8483 + 05a0276 commit 7662c5f
Show file tree
Hide file tree
Showing 9 changed files with 93 additions and 109 deletions.
82 changes: 79 additions & 3 deletions boxmot/trackers/basetracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ def __init__(
max_age: int = 30,
min_hits: int = 3,
iou_threshold: float = 0.3,
max_obs: int = 50
max_obs: int = 50,
nr_classes: int = 80,
per_class: bool = False
):
"""
Initialize the BaseTracker object with detection threshold, maximum age, minimum hits,
Expand All @@ -33,11 +35,20 @@ def __init__(
self.max_age = max_age
self.max_obs = max_obs
self.min_hits = min_hits
self.per_class = per_class # Track per class or not
self.nr_classes = nr_classes
self.iou_threshold = iou_threshold
self.per_class_active_tracks = {}
self.last_emb_size = None

self.frame_count = 0
self.active_tracks = [] # This might be handled differently in derived classes
self.per_class_active_tracks = None

# Initialize per-class active tracks
if self.per_class:
self.per_class_active_tracks = {}
for i in range(self.nr_classes):
self.per_class_active_tracks[i] = []

if self.max_age >= self.max_obs:
LOGGER.warning("Max age > max observations, increasing size of max observations...")
Expand All @@ -59,6 +70,71 @@ def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) ->
- NotImplementedError: If the subclass does not implement this method.
"""
raise NotImplementedError("The update method needs to be implemented by the subclass.")

def get_class_dets_n_embs(self, dets, embs, cls_id):
# Initialize empty arrays for detections and embeddings
class_dets = np.empty((0, 6))
class_embs = np.empty((0, self.last_emb_size)) if self.last_emb_size is not None else None

# Check if there are detections
if dets.size > 0:
class_indices = np.where(dets[:, 5] == cls_id)[0]
class_dets = dets[class_indices]

if embs is not None:
# Assert that if embeddings are provided, they have the same number of elements as detections
assert dets.shape[0] == embs.shape[0], "Detections and embeddings must have the same number of elements when both are provided"

if embs.size > 0:
class_embs = embs[class_indices]
self.last_emb_size = class_embs.shape[1] # Update the last known embedding size
else:
class_embs = None
return class_dets, class_embs

@staticmethod
def per_class_decorator(update_method):
"""
Decorator for the update method to handle per-class processing.
"""
def wrapper(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None):
if self.per_class:
# Initialize an array to store the tracks for each class
per_class_tracks = []

# same frame count for all classes
frame_count = self.frame_count

for cls_id in range(self.nr_classes):
# Get detections and embeddings for the current class
class_dets, class_embs = self.get_class_dets_n_embs(dets, embs, cls_id)

LOGGER.debug(f"Processing class {int(cls_id)}: {class_dets.shape} with embeddings {class_embs.shape if class_embs is not None else None}")

# Activate the specific active tracks for this class id
self.active_tracks = self.per_class_active_tracks[cls_id]

# Reset frame count for every class
self.frame_count = frame_count

# Update detections using the decorated method
tracks = update_method(self, dets=class_dets, img=img, embs=class_embs)

# Save the updated active tracks
self.per_class_active_tracks[cls_id] = self.active_tracks

if tracks.size > 0:
per_class_tracks.append(tracks)

# Increase frame count by 1
self.frame_count = frame_count + 1

return np.vstack(per_class_tracks) if per_class_tracks else np.empty((0, 8))
else:
# Process all detections at once if per_class is False
return update_method(self, dets=dets, img=img, embs=embs)
return wrapper


def check_inputs(self, dets, img):
assert isinstance(
Expand Down Expand Up @@ -191,7 +267,7 @@ def plot_results(self, img: np.ndarray, show_trajectories: bool, thickness: int
"""

# if values in dict
if self.per_class_active_tracks:
if self.per_class_active_tracks is not None:
for k in self.per_class_active_tracks.keys():
active_tracks = self.per_class_active_tracks[k]
for a in active_tracks:
Expand Down
5 changes: 2 additions & 3 deletions boxmot/trackers/botsort/bot_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
iou_distance, linear_assignment)
from boxmot.utils.ops import xywh2xyxy, xyxy2xywh
from boxmot.trackers.basetracker import BaseTracker
from boxmot.utils import PerClassDecorator


class STrack(BaseTrack):
Expand Down Expand Up @@ -205,7 +204,7 @@ def __init__(
fuse_first_associate: bool = False,
with_reid: bool = True,
):
super().__init__()
super().__init__(per_class=per_class)
self.lost_stracks = [] # type: list[STrack]
self.removed_stracks = [] # type: list[STrack]
BaseTrack.clear_count()
Expand Down Expand Up @@ -233,7 +232,7 @@ def __init__(
self.cmc = SOF()
self.fuse_first_associate = fuse_first_associate

@PerClassDecorator
@BaseTracker.per_class_decorator
def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> np.ndarray:

self.check_inputs(dets, img)
Expand Down
5 changes: 2 additions & 3 deletions boxmot/trackers/bytetrack/byte_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from boxmot.utils.matching import fuse_score, iou_distance, linear_assignment
from boxmot.utils.ops import tlwh2xyah, xywh2tlwh, xywh2xyxy, xyxy2xywh
from boxmot.trackers.basetracker import BaseTracker
from boxmot.utils import PerClassDecorator


class STrack(BaseTrack):
Expand Down Expand Up @@ -125,7 +124,7 @@ def __init__(
frame_rate=30,
per_class=False,
):
super().__init__()
super().__init__(per_class=per_class)
self.active_tracks = [] # type: list[STrack]
self.lost_stracks = [] # type: list[STrack]
self.removed_stracks = [] # type: list[STrack]
Expand All @@ -141,7 +140,7 @@ def __init__(
self.max_time_lost = self.buffer_size
self.kalman_filter = KalmanFilterXYAH()

@PerClassDecorator
@BaseTracker.per_class_decorator
def update(self, dets: np.ndarray, img: np.ndarray = None, embs: np.ndarray = None) -> np.ndarray:

self.check_inputs(dets, img)
Expand Down
5 changes: 2 additions & 3 deletions boxmot/trackers/deepocsort/deep_ocsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from boxmot.utils.association import associate, linear_assignment
from boxmot.utils.iou import get_asso_func
from boxmot.trackers.basetracker import BaseTracker
from boxmot.utils import PerClassDecorator
from boxmot.utils.ops import xyxy2xysr


Expand Down Expand Up @@ -246,7 +245,7 @@ def __init__(
Q_s_scaling=0.0001,
**kwargs
):
super().__init__(max_age=max_age)
super().__init__(max_age=max_age, per_class=per_class)
"""
Sets key parameters for SORT
"""
Expand Down Expand Up @@ -274,7 +273,7 @@ def __init__(
self.cmc_off = cmc_off
self.aw_off = aw_off

@PerClassDecorator
@BaseTracker.per_class_decorator
def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> np.ndarray:
"""
Params:
Expand Down
6 changes: 2 additions & 4 deletions boxmot/trackers/hybridsort/hybridsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@
from boxmot.trackers.hybridsort.association import (
associate_4_points_with_score, associate_4_points_with_score_with_reid,
cal_score_dif_batch_two_score, embedding_distance, linear_assignment)
from boxmot.utils import PerClassDecorator
from boxmot.utils.iou import get_asso_func
from boxmot.trackers.basetracker import BaseTracker
from boxmot.utils import PerClassDecorator


np.random.seed(0)
Expand Down Expand Up @@ -335,7 +333,7 @@ def get_state(self):
class HybridSORT(BaseTracker):
def __init__(self, reid_weights, device, half, det_thresh, per_class=False, max_age=30, min_hits=3,
iou_threshold=0.3, delta_t=3, asso_func="iou", inertia=0.2, longterm_reid_weight=0, TCM_first_step_weight=0, use_byte=False):
super().__init__(max_age=max_age)
super().__init__(max_age=max_age, per_class=per_class)

"""
Sets key parameters for SORT
Expand Down Expand Up @@ -376,7 +374,7 @@ def camera_update(self, trackers, warp_matrix):
for tracker in trackers:
tracker.camera_update(warp_matrix)

@PerClassDecorator
@BaseTracker.per_class_decorator
def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> np.ndarray:
"""
Params:
Expand Down
5 changes: 2 additions & 3 deletions boxmot/trackers/imprassoc/impr_assoc_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
d_iou_distance)
from boxmot.utils.ops import xywh2xyxy, xyxy2xywh
from boxmot.trackers.basetracker import BaseTracker
from boxmot.utils import PerClassDecorator


class STrack(BaseTrack):
Expand Down Expand Up @@ -207,7 +206,7 @@ def __init__(
frame_rate=30,
with_reid: bool = True
):
super().__init__()
super().__init__(per_class=per_class)
self.active_tracks = [] # type: list[STrack]
self.lost_stracks = [] # type: list[STrack]
self.removed_stracks = [] # type: list[STrack]
Expand Down Expand Up @@ -241,7 +240,7 @@ def __init__(
self.cmc = SOF()


@PerClassDecorator
@BaseTracker.per_class_decorator
def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> np.ndarray:
self.check_inputs(dets, img)

Expand Down
5 changes: 2 additions & 3 deletions boxmot/trackers/ocsort/ocsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from boxmot.utils.iou import get_asso_func
from boxmot.utils.iou import run_asso_func
from boxmot.trackers.basetracker import BaseTracker
from boxmot.utils import PerClassDecorator
from boxmot.utils.ops import xyxy2xysr


Expand Down Expand Up @@ -197,7 +196,7 @@ def __init__(
Q_xy_scaling=0.01,
Q_s_scaling=0.0001
):
super().__init__(max_age=max_age)
super().__init__(max_age=max_age, per_class=per_class)
"""
Sets key parameters for SORT
"""
Expand All @@ -215,7 +214,7 @@ def __init__(
self.Q_s_scaling = Q_s_scaling
KalmanBoxTracker.count = 0

@PerClassDecorator
@BaseTracker.per_class_decorator
def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> np.ndarray:
"""
Params:
Expand Down
4 changes: 2 additions & 2 deletions boxmot/trackers/strongsort/strong_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from boxmot.trackers.strongsort.sort.tracker import Tracker
from boxmot.utils.matching import NearestNeighborDistanceMetric
from boxmot.utils.ops import xyxy2tlwh
from boxmot.utils import PerClassDecorator
from boxmot.trackers.basetracker import BaseTracker


class StrongSORT(object):
Expand Down Expand Up @@ -42,7 +42,7 @@ def __init__(
)
self.cmc = get_cmc_method('ecc')()

@PerClassDecorator
@BaseTracker.per_class_decorator
def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> np.ndarray:
assert isinstance(
dets, np.ndarray
Expand Down
85 changes: 0 additions & 85 deletions boxmot/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,88 +23,3 @@

logger.remove()
logger.add(sys.stderr, colorize=True, level="INFO")


class PerClassDecorator:
def __init__(self, method):
# Store the method that will be decorated
self.update = method
self.nr_classes = 80
self.per_class_active_tracks = {}
for i in range(self.nr_classes):
self.per_class_active_tracks[i] = []
self.last_emb_size = None

def get_class_dets_n_embs(self, dets, embs, cls_id):
# Initialize empty arrays for detections and embeddings
class_dets = np.empty((0, 6))
class_embs = np.empty((0, self.last_emb_size)) if self.last_emb_size is not None else None

# Check if there are detections
if dets.size > 0:
class_indices = np.where(dets[:, 5] == cls_id)[0]
class_dets = dets[class_indices]

if embs is not None:
# Assert that if embeddings are provided, they have the same number of elements as detections
assert dets.shape[0] == embs.shape[0], "Detections and embeddings must have the same number of elements when both are provided"

if embs.size > 0:
class_embs = embs[class_indices]
self.last_emb_size = class_embs.shape[1] # Update the last known embedding size
else:
class_embs = None
return class_dets, class_embs

def __get__(self, instance, owner):
# This makes PerClassDecorator a non-data descriptor that binds the method to the instance
def wrapper(*args, **kwargs):
# Unpack arguments for clarity
args = list(args)
dets = args[0]
im = args[1]
embs = args[2] if len(args) > 2 else None

if instance.per_class is True:

# Initialize an array to store the tracks for each class
per_class_tracks = []

# same frame count for all classes
frame_count = instance.frame_count

for i, cls_id in enumerate(range(self.nr_classes)):

class_dets, class_embs = self.get_class_dets_n_embs(dets, embs, cls_id)

logger.debug(f"Processing class {int(cls_id)}: {class_dets.shape} with embeddings {class_embs.shape if class_embs is not None else None}")

# activate the specific active tracks for this class id
instance.active_tracks = self.per_class_active_tracks[cls_id]

# reset frame count for every class
instance.frame_count = frame_count

# Update detections using the decorated method
tracks = self.update(instance, dets=class_dets, img=im, embs=class_embs)

# save the updated active tracks
self.per_class_active_tracks[cls_id] = instance.active_tracks

if tracks.size > 0:
per_class_tracks.append(tracks)

# when all active tracks lists have been updated
instance.per_class_active_tracks = self.per_class_active_tracks

# increase frame count by 1
instance.frame_count = frame_count + 1

tracks = np.vstack(per_class_tracks) if per_class_tracks else np.empty((0, 8))
else:
# Process all detections at once if per_class is False or detections are empty
tracks = self.update(instance, dets=dets, img=im, embs=embs)

return tracks

return wrapper

0 comments on commit 7662c5f

Please sign in to comment.