Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding functions for checking masks, minor train_on_gpu changes #41

Open
wants to merge 8 commits into
base: integration_tpu_resnet
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
Binary file not shown.
25 changes: 21 additions & 4 deletions models/official/projects/maskformer/configs/maskformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ class MaskFormerTask(cfg.TaskConfig):
SET_MODEL_BFLOAT16 = False
SET_DATA_BFLOAT16 = True

if not os.environ.get('USE_BFLOAT16_DATA'):
SET_DATA_BFLOAT16 = False


@exp_factory.register_config_factory('maskformer_coco_panoptic')
def maskformer_coco_panoptic() -> cfg.ExperimentConfig:
"""Config to get results that matches the paper."""
Expand All @@ -124,9 +128,22 @@ def maskformer_coco_panoptic() -> cfg.ExperimentConfig:
ckpt_interval = (COCO_TRAIN_EXAMPLES // train_batch_size) * 10 # Don't write checkpoints too frequently; it slows down training.
image_size = int(os.environ.get('IMG_SIZE'))

steps_per_epoch = COCO_TRAIN_EXAMPLES // train_batch_size
train_steps = 300 * steps_per_epoch # 300 epochs
decay_at = train_steps - 100 * steps_per_epoch # 200 epochs
if os.environ.get('STEPS_PER_EPOCH'):
steps_per_epoch = int(os.environ.get('STEPS_PER_EPOCH'))
else:
steps_per_epoch = COCO_TRAIN_EXAMPLES // train_batch_size

if os.environ.get('NUM_EPOCH'):
train_steps = int(os.environ.get('NUM_EPOCH')) * steps_per_epoch
decay_at = int(2/3 * train_steps)
else:
train_steps = 300 * steps_per_epoch # 300 epochs
decay_at = train_steps - 100 * steps_per_epoch # 200 epochs

# steps_per_epoch = COCO_TRAIN_EXAMPLES // train_batch_size
# train_steps = 300 * steps_per_epoch # 300 epochs
# decay_at = train_steps - 100 * steps_per_epoch # 200 epochs

config = cfg.ExperimentConfig(
task = MaskFormerTask(
init_checkpoint="",
Expand Down Expand Up @@ -179,7 +196,7 @@ def maskformer_coco_panoptic() -> cfg.ExperimentConfig:
)),
trainer=cfg.TrainerConfig(
train_steps=train_steps,
validation_steps=COCO_VAL_EXAMPLES // eval_batch_size,
validation_steps=COCO_VAL_EXAMPLES // eval_batch_size if not os.environ.get('VAL_STEPS') else int(os.environ.get('VAL_STEPS')),
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
Expand Down
8 changes: 4 additions & 4 deletions models/official/projects/maskformer/eval_cpu.sh
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#!/bin/bash
train_bsize=1
eval_bsize=1
export PYTHONPATH=$PYTHONPATH:~/tf-maskformer/models
export MODEL_DIR="gs://cam2-models/maskformer_vishal_exps/EXP20_v8_eval"
export MASKFORMER_CKPT="gs://cam2-models/maskformer_vishal_exps/EXP20_v8/ckpt-18480"
export PYTHONPATH=$PYTHONPATH:/depot/davisjam/data/akshath/MaskFormer_tf/tf-maskformer/models
export MODEL_DIR="gs://cam2-models/maskformer_vishal_exps/EXP26_v8_eval"
export MASKFORMER_CKPT="gs://cam2-models/maskformer_vishal_exps/EXP26_v8/ckpt-482328"
export RESNET_CKPT="gs://cam2-models/maskformer_vishal_exps/resnet50_pretrained/tfmg/ckpt-62400"
export TFRECORDS_DIR="gs://cam2-datasets/coco_panoptic/tfrecords"
export TRAIN_BATCH_SIZE=$train_bsize
Expand All @@ -16,7 +16,7 @@ export OVERRIDES="runtime.distribution_strategy=one_device,runtime.mixed_precisi
task.validation_data.global_batch_size=$EVAL_BATCH_SIZE,task.model.which_pixel_decoder=transformer_fpn,\
task.init_checkpoint_modules=all,\
task.init_checkpoint=$MASKFORMER_CKPT"
python3 models/official/projects/maskformer/train.py \
python3 train.py \
--experiment maskformer_coco_panoptic \
--mode eval \
--model_dir $MODEL_DIR \
Expand Down
2 changes: 1 addition & 1 deletion models/official/projects/maskformer/eval_gpu.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ task.validation_data.global_batch_size=$EVAL_BATCH_SIZE,\
task.model.which_pixel_decoder=transformer_fpn,\
task.init_checkpoint_modules=all,\
task.init_checkpoint=$MASKFORMER_CKPT"
python3 models/official/projects/maskformer/train.py \
python3 train.py \
--experiment maskformer_coco_panoptic \
--mode eval \
--model_dir $MODEL_DIR \
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
36 changes: 27 additions & 9 deletions models/official/projects/maskformer/losses/maskformer_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def batch(self, y_true, y_pred):
loss = tf.einsum("bnc,bmc->bnm",focal_pos,y_true) + tf.einsum(
"bnc,bmc->bnm", focal_neg,(1 - y_true)
)

return loss/hw


Expand Down Expand Up @@ -88,7 +89,7 @@ def batch(self, y_true, y_pred):
return loss

class Loss:
def __init__(self, num_classes, matcher, eos_coef, cost_class = 1.0, cost_focal = 20.0, cost_dice = 1.0, ignore_label =0):
def __init__(self, num_classes, matcher, eos_coef, cost_class = 1.0, cost_focal = 1.0, cost_dice = 1.0, ignore_label =0):

self.num_classes = num_classes
self.matcher = matcher
Expand Down Expand Up @@ -120,13 +121,13 @@ def memory_efficient_matcher(self, outputs, y_true):
tgt_mask_permuted = tf.reshape(tgt_mask_permuted, [tf.shape(tgt_mask_permuted)[0],tf.shape(tgt_mask_permuted)[1], -1]) # [b, 100, h*w]

cost_focal = FocalLossMod().batch(tgt_mask_permuted, out_mask)
cost_dice = DiceLoss().batch(tgt_mask_permuted, out_mask)
cost_dice = DiceLoss().batch(tgt_mask_permuted, out_mask)


total_cost = (
self.cost_focal * cost_focal
+ self.cost_class * cost_class
+ self.cost_dice * cost_dice
self.cost_focal * cost_focal
+ self.cost_class * cost_class
+ self.cost_dice * cost_dice
)

max_cost = (
Expand All @@ -135,14 +136,26 @@ def memory_efficient_matcher(self, outputs, y_true):
self.cost_dice * 0.0
)


# print('cost_focal')
# print(cost_focal, cost_class, cost_dice)
# print('total_cost')
# print(total_cost)
# print('max_cost')
# print(max_cost)

# Assign the highest cost to targets that contain no object; the no-object class is self.ignore_label (0 by default).
valid = tf.expand_dims(tf.cast(tf.not_equal(tgt_ids, self.ignore_label), dtype=total_cost.dtype), axis=1)
# print('max_cost - ', max_cost)
# print('total_cost before - ', total_cost)

total_cost = (1 - valid) * max_cost + valid * total_cost
# print('total_cost after - ', total_cost)

total_cost = tf.where(
tf.logical_or(tf.math.is_nan(total_cost), tf.math.is_inf(total_cost)),
max_cost * tf.ones_like(total_cost, dtype=total_cost.dtype),
total_cost)


_, inds = matchers.hungarian_matching(total_cost)
indices = tf.stop_gradient(inds)
Expand All @@ -168,10 +181,10 @@ def get_loss(self, outputs, y_true, indices):
num_masks = tf.reduce_sum(tf.cast(tf.logical_not(background), tf.float32), axis=-1)

xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target_classes, logits=cls_assigned)

cls_loss = tf.where(background, self.eos_coef * xentropy, xentropy)

cls_weights = tf.where(background, self.eos_coef * tf.ones_like(cls_loss), tf.ones_like(cls_loss))
# print('Weights: ', cls_weights)

num_masks_per_replica = tf.reduce_sum(num_masks)

Expand All @@ -181,7 +194,10 @@ def get_loss(self, outputs, y_true, indices):
num_masks_sum, cls_weights_sum = replica_context.all_reduce(tf.distribute.ReduceOp.SUM,[num_masks_per_replica, cls_weights_per_replica])

# Final losses
# print('Losses: ', cls_loss)

cls_loss = tf.math.divide_no_nan(tf.reduce_sum(cls_loss), cls_weights_sum)
# print('Final loss given for changing the tvars - ', cls_loss)

out_mask = mask_assigned
tgt_mask = individual_masks
Expand All @@ -206,10 +222,12 @@ def get_loss(self, outputs, y_true, indices):
focal_loss = FocalLossMod(alpha=0.25, gamma=2)(tgt_mask, out_mask)
focal_loss_weighted = tf.where(background, tf.zeros_like(focal_loss), focal_loss)
focal_loss_final = tf.math.divide_no_nan(tf.math.reduce_sum(tf.math.reduce_sum(focal_loss_weighted, axis=-1)), num_masks_sum)
# print(focal_loss_weighted)
dice_loss = DiceLoss()(tgt_mask, out_mask)
dice_loss_weighted = tf.where(background, tf.zeros_like(dice_loss), dice_loss)
dice_loss_final = tf.math.divide_no_nan(tf.math.reduce_sum(tf.math.reduce_sum(dice_loss_weighted, axis=-1)), num_masks_sum)
# print(dice_loss_weighted)
# raise ValueError('2')

return cls_loss, focal_loss_final, dice_loss_final

Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def call(self, inputs):
target_shape = tf.shape(targets)

if mask is not None:

cross_attention_mask = tf.tile(
tf.expand_dims(mask, axis=1), [1, target_shape[1], 1])
self_attention_mask=tf.ones(
Expand Down
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling import layers
from official.nlp.modeling import models

'''
Transformer Parameters:
Expand Down Expand Up @@ -74,6 +77,8 @@ def build(self, input_shape):
# Final Layer
self._layers.append(
tf.keras.layers.Dense(dim[1], activation=None))
# kernel_initializer=tf_utils.clone_initializer(tf.keras.initializers.get('glorot_uniform')),
# bias_initializer=tf_utils.clone_initializer(tf.keras.initializers.get('glorot_uniform')))

def call(self, x):
for layer in self._layers:
Expand Down
6 changes: 4 additions & 2 deletions models/official/projects/maskformer/modeling/maskformer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import tensorflow as tf

import numpy as np
import os
from official.projects.maskformer.modeling.decoder.transformer_decoder import MaskFormerTransformer
from official.projects.maskformer.modeling.layers.nn_block import MLPHead
from official.projects.maskformer.modeling.decoder.transformer_pixel_decoder import TransformerFPN
Expand Down Expand Up @@ -150,6 +151,7 @@ def process_feature_maps(self, maps):
def call(self, image, training = False):
backbone_feature_maps = self._backbone(image)
backbone_feature_maps_procesed = self.process_feature_maps(backbone_feature_maps)

if self._pixel_decoder == 'fpn':
mask_features = self.pixel_decoder(backbone_feature_maps_procesed)
transformer_enc_feat = backbone_feature_maps_procesed['5']
Expand All @@ -158,4 +160,4 @@ def call(self, image, training = False):
transformer_features = self.transformer({"features": transformer_enc_feat})
seg_pred = self.head({"per_pixel_embeddings" : mask_features,
"per_segment_embeddings": transformer_features})
return seg_pred
return seg_pred
Loading