Skip to content

Commit

Permalink
Suggested version, aggregating logits directly
Browse files Browse the repository at this point in the history
  • Loading branch information
peiva-git committed Nov 30, 2023
1 parent c64f5c0 commit f9b910e
Showing 1 changed file with 10 additions and 35 deletions.
45 changes: 10 additions & 35 deletions basketballtrainer/models/pp_liteseg_rancrops.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,48 +91,23 @@ def forward(self, x):
)
for random_crop, crop_x, crop_y in random_crops
]
# discard background channel and apply softmax
softmax_tensors = [
(
pp.nn.functional.softmax(pp.unsqueeze(logit[:, 1, :, :], axis=1)),
logit_x,
logit_y
)
for logit, logit_x, logit_y in logit_tensors
]
# 3. pad and aggregate
image_height, image_width = pp.shape(x)[2:]
softmax_tensors_padded = [
logits_padded = [
pp.nn.functional.pad(
softmax,
logit,
pad=(
softmax_x,
image_width - softmax_x - pp.shape(softmax).numpy()[3],
softmax_y,
image_height - softmax_y - pp.shape(softmax).numpy()[2]
logit_x,
image_width - logit_x - pp.shape(logit).numpy()[3],
logit_y,
image_height - logit_y - pp.shape(logit).numpy()[2]
),
value=float(pp.min(softmax))
value=float(pp.min(logit))
)
for softmax, softmax_x, softmax_y in softmax_tensors
for logit, logit_x, logit_y in logit_tensors
]
softmax_aggregation = pp.mean(pp.to_tensor(softmax_tensors_padded), axis=0)
# 4. apply detection rule
# Since a two-channel output is expected,
# a properly generated background channel needs to be re-introduced
foreground = pp.where(softmax_aggregation > self.__detection_threshold, 1, 0)
background = pp.where(softmax_aggregation <= self.__detection_threshold, 1, 0)
softmax_aggregation = pp.expand(
softmax_aggregation,
shape=(
softmax_aggregation.shape[0],
self.__num_classes,
softmax_aggregation.shape[2],
softmax_aggregation.shape[3]
)
)
softmax_aggregation[:, 0, :, :] = background.astype('float32')
softmax_aggregation[:, 1, :, :] = foreground.astype('float32')
return [softmax_aggregation]
logit_aggregation = pp.mean(pp.to_tensor(logits_padded), axis=0)
return [logit_aggregation]
else:
return super(PPLiteSegRandomCrops, self).forward(x)

Expand Down

0 comments on commit f9b910e

Please sign in to comment.