diff --git a/official/projects/pix2seq/configs/pix2seq.py b/official/projects/pix2seq/configs/pix2seq.py index 0cc8b595269..bb711b75808 100644 --- a/official/projects/pix2seq/configs/pix2seq.py +++ b/official/projects/pix2seq/configs/pix2seq.py @@ -99,6 +99,13 @@ class Backbone(backbones.Backbone): resnet: backbones.ResNet = dataclasses.field(default_factory=backbones.ResNet) uvit: uvit_backbones.VisionTransformer = dataclasses.field( default_factory=uvit_backbones.VisionTransformer) + + +@dataclasses.dataclass +class BackboneConfig(hyperparams.Config): + """Configuration for backbones.""" + + backbone: Backbone = dataclasses.field(default_factory=Backbone) # Whether to freeze this backbone during training. freeze: bool = False # The endpoint name of the features to extract from the backbone. @@ -126,11 +133,13 @@ class Pix2Seq(hyperparams.Config): input_size: List[int] = dataclasses.field(default_factory=list) # Backbones for each image modality. If just using RGB, you should only set # one backbone. - backbones: List[Backbone] = dataclasses.field( - default_factory=lambda: [ - Backbone( # pylint: disable=g-long-lambda - type='resnet', - resnet=backbones.ResNet(model_id=50, bn_trainable=False), + backbones: List[BackboneConfig] = dataclasses.field( + default_factory=lambda: [ # pylint: disable=g-long-lambda + BackboneConfig( + backbone=Backbone( + type='resnet', + resnet=backbones.ResNet(model_id=50, bn_trainable=False), + ) ) ] ) @@ -182,9 +191,11 @@ def pix2seq_r50_coco() -> cfg.ExperimentConfig: model=Pix2Seq( input_size=[640, 640, 3], backbones=[ - Backbone( - type='resnet', - resnet=backbones.ResNet(model_id=50), + BackboneConfig( + backbone=Backbone( + type='resnet', + resnet=backbones.ResNet(model_id=50), + ), norm_activation=common.NormActivation( norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=True ), diff --git a/official/projects/pix2seq/tasks/pix2seq_task.py b/official/projects/pix2seq/tasks/pix2seq_task.py index 41b8a38a1cb..212e33f18b0 100644 --- a/official/projects/pix2seq/tasks/pix2seq_task.py +++ b/official/projects/pix2seq/tasks/pix2seq_task.py @@ -57,7 +57,7 @@ def _build_backbones_and_endpoint_names( for backbone_config in config.backbones: backbone = backbones_lib.factory.build_backbone( input_specs=input_specs, - backbone_config=backbone_config, + backbone_config=backbone_config.backbone, norm_activation_config=backbone_config.norm_activation, ) backbone.trainable = not backbone_config.freeze @@ -134,7 +134,7 @@ def initialize(self, model: tf_keras.Model): continue backbone_init_ckpt = self._get_ckpt(backbone_config.init_checkpoint) - if backbone_config.type == 'uvit': + if backbone_config.backbone.type == 'uvit': # The UVit object has a special function called load_checkpoint. # The other backbones do not. backbone.load_checkpoint(ckpt_filepath=backbone_init_ckpt)