diff --git a/basketballtrainer/models/pp_liteseg_rancrops.py b/basketballtrainer/models/pp_liteseg_rancrops.py index 5fb84ac..d971961 100644 --- a/basketballtrainer/models/pp_liteseg_rancrops.py +++ b/basketballtrainer/models/pp_liteseg_rancrops.py @@ -91,48 +91,23 @@ def forward(self, x): ) for random_crop, crop_x, crop_y in random_crops ] - # discard background channel and apply softmax - softmax_tensors = [ - ( - pp.nn.functional.softmax(pp.unsqueeze(logit[:, 1, :, :], axis=1)), - logit_x, - logit_y - ) - for logit, logit_x, logit_y in logit_tensors - ] # 3. pad and aggregate image_height, image_width = pp.shape(x)[2:] - softmax_tensors_padded = [ + logits_padded = [ pp.nn.functional.pad( - softmax, + logit, pad=( - softmax_x, - image_width - softmax_x - pp.shape(softmax).numpy()[3], - softmax_y, - image_height - softmax_y - pp.shape(softmax).numpy()[2] + logit_x, + image_width - logit_x - pp.shape(logit).numpy()[3], + logit_y, + image_height - logit_y - pp.shape(logit).numpy()[2] ), - value=float(pp.min(softmax)) + value=float(pp.min(logit)) ) - for softmax, softmax_x, softmax_y in softmax_tensors + for logit, logit_x, logit_y in logit_tensors ] - softmax_aggregation = pp.mean(pp.to_tensor(softmax_tensors_padded), axis=0) - # 4. apply detection rule - # Since a two-channel output is expected, - # a properly generated background channel needs to be re-introduced - foreground = pp.where(softmax_aggregation > self.__detection_threshold, 1, 0) - background = pp.where(softmax_aggregation <= self.__detection_threshold, 1, 0) - softmax_aggregation = pp.expand( - softmax_aggregation, - shape=( - softmax_aggregation.shape[0], - self.__num_classes, - softmax_aggregation.shape[2], - softmax_aggregation.shape[3] - ) - ) - softmax_aggregation[:, 0, :, :] = background.astype('float32') - softmax_aggregation[:, 1, :, :] = foreground.astype('float32') - return [softmax_aggregation] + logit_aggregation = pp.mean(pp.to_tensor(logits_padded), axis=0) + return [logit_aggregation] else: return super(PPLiteSegRandomCrops, self).forward(x)