Skip to content

Commit

Permalink
Nest the backbone structure config inside a BackboneConfig class
Browse files Browse the repository at this point in the history
The backbone configuration was previously a OneOfConfig, which prevented extra fields from being saved in Tensorboard. This change nests the backbone configuration within a BackboneConfig class, ensuring that all configuration fields are saved in Tensorboard for complete tracking and analysis.

PiperOrigin-RevId: 704758475
  • Loading branch information
LouYu2015 authored and tensorflower-gardener committed Dec 10, 2024
1 parent 564c480 commit 8c5c79c
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 10 deletions.
27 changes: 19 additions & 8 deletions official/projects/pix2seq/configs/pix2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,13 @@ class Backbone(backbones.Backbone):
resnet: backbones.ResNet = dataclasses.field(default_factory=backbones.ResNet)
uvit: uvit_backbones.VisionTransformer = dataclasses.field(
default_factory=uvit_backbones.VisionTransformer)


@dataclasses.dataclass
class BackboneConfig(hyperparams.Config):
"""Configuration for backbones."""

backbone: Backbone = dataclasses.field(default_factory=Backbone)
# Whether to freeze this backbone during training.
freeze: bool = False
# The endpoint name of the features to extract from the backbone.
Expand Down Expand Up @@ -126,11 +133,13 @@ class Pix2Seq(hyperparams.Config):
input_size: List[int] = dataclasses.field(default_factory=list)
# Backbones for each image modality. If just using RGB, you should only set
# one backbone.
backbones: List[Backbone] = dataclasses.field(
default_factory=lambda: [
Backbone( # pylint: disable=g-long-lambda
type='resnet',
resnet=backbones.ResNet(model_id=50, bn_trainable=False),
backbones: List[BackboneConfig] = dataclasses.field(
default_factory=lambda: [ # pylint: disable=g-long-lambda
BackboneConfig(
backbone=Backbone(
type='resnet',
resnet=backbones.ResNet(model_id=50, bn_trainable=False),
)
)
]
)
Expand Down Expand Up @@ -182,9 +191,11 @@ def pix2seq_r50_coco() -> cfg.ExperimentConfig:
model=Pix2Seq(
input_size=[640, 640, 3],
backbones=[
Backbone(
type='resnet',
resnet=backbones.ResNet(model_id=50),
BackboneConfig(
backbone=Backbone(
type='resnet',
resnet=backbones.ResNet(model_id=50),
),
norm_activation=common.NormActivation(
norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=True
),
Expand Down
4 changes: 2 additions & 2 deletions official/projects/pix2seq/tasks/pix2seq_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def _build_backbones_and_endpoint_names(
for backbone_config in config.backbones:
backbone = backbones_lib.factory.build_backbone(
input_specs=input_specs,
backbone_config=backbone_config,
backbone_config=backbone_config.backbone,
norm_activation_config=backbone_config.norm_activation,
)
backbone.trainable = not backbone_config.freeze
Expand Down Expand Up @@ -134,7 +134,7 @@ def initialize(self, model: tf_keras.Model):
continue

backbone_init_ckpt = self._get_ckpt(backbone_config.init_checkpoint)
if backbone_config.type == 'uvit':
if backbone_config.backbone.type == 'uvit':
# The UVit object has a special function called load_checkpoint.
# The other backbones do not.
backbone.load_checkpoint(ckpt_filepath=backbone_init_ckpt)
Expand Down

0 comments on commit 8c5c79c

Please sign in to comment.