Component Class

TorchMix provides the Component class, which is a drop-in replacement for nn.Module with enhanced configuration support. Consider the following example:

from torchmix import Component
 
 
class Model(Component):
    def __init__(self, a: int, b: str, c: list[str]):
        pass

Without Component, you would have to write a configuration for each variant by hand, like this:

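from dataclasses import dataclass, field

@dataclass
class ModelA:
    a: int = 1
    b: str = "x"
    # Mutable defaults such as lists must be supplied via default_factory
    c: list[str] = field(default_factory=lambda: ["foo"])

@dataclass
class ModelB:
    a: int = 2
    b: str = "y"
    c: list[str] = field(default_factory=lambda: ["foo", "bar"])

@dataclass
class ModelC:
    a: int = 3
    b: str = "z"
    c: list[str] = field(default_factory=lambda: ["foo", "bar", "baz"])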

Dataclasses used as configurations like these are called Structured Configs. They enable both runtime and static type checking, which makes configuration management more robust. See the Hydra docs or the OmegaConf docs for more detail.

However, writing these configurations by hand is time-consuming and often results in redundant code. With Component, you simply instantiate the model in the state you want, and the configs are generated for you:

Model(a=1, b="x", c=["foo"]).config
Model(a=2, b="y", c=["foo", "bar"]).config
Model(a=3, b="z", c=["foo", "bar", "baz"]).config
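To make the idea concrete, here is a minimal, stdlib-only sketch of how a base class can capture its constructor arguments and expose them as a config. `Recorded` is a hypothetical stand-in for Component, not TorchMix's implementation (which builds on hydra-zen), and it returns a plain dict rather than a structured config.

```python
import inspect
from typing import Any


class Recorded:
    """Hypothetical stand-in for Component: remembers its constructor arguments."""

    def __new__(cls, *args: Any, **kwargs: Any):
        obj = super().__new__(cls)
        # Map positional and keyword arguments onto __init__'s parameter names,
        # then fill in any defaults the caller left out.
        bound = inspect.signature(cls.__init__).bind(obj, *args, **kwargs)
        bound.apply_defaults()
        obj._init_args = {k: v for k, v in bound.arguments.items() if k != "self"}
        return obj

    @property
    def config(self) -> dict[str, Any]:
        # Follow Hydra's convention: `_target_` is the import path of the class.
        target = f"{type(self).__module__}.{type(self).__qualname__}"
        return {"_target_": target, **self._init_args}


class Model(Recorded):
    def __init__(self, a: int, b: str, c: list[str]):
        self.a, self.b, self.c = a, b, c


print(Model(a=1, b="x", c=["foo"]).config)
print(Model(2, "y", ["foo", "bar"]).config)  # positional arguments are captured too
```

Hooking `__new__` means subclasses do not need to call anything special in their own `__init__`; the arguments are recorded before `__init__` even runs.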

You can then register these structured configs directly in Hydra's ConfigStore via the store method:

Model(a=1, b="x", c=["foo"]).store(group="model", name="a")
Model(a=2, b="y", c=["foo", "bar"]).store(group="model", name="b")
Model(a=3, b="z", c=["foo", "bar", "baz"]).store(group="model", name="c")

This feature is built on top of hydra-zen. Check their docs for more information!

Nested Components

In deep learning, models are typically composed of multiple sub-modules, and writing configurations for those nested modules by hand is even more confusing and repetitive. Let's say we have these models:

from torchmix import Component
 
 
class Model(Component):
    def __init__(self, a: Component, b: str, c: str):
        pass
 
class SubModel(Component):
    def __init__(self, i: Component, j: float, k: int):
        pass
 
class SubSubModel(Component):
    def __init__(self, p: int, q: int):
        pass

and use them like this:

model = Model(
    SubModel(
        SubSubModel(1, 2),
        3e-4,
        32,
    ),
    "adam",
    "imagenet",
)

Writing Hydra-compatible configurations for this setup by hand looks like this:

from dataclasses import dataclass, field

@dataclass
class SubSubModelConfig:
    _target_: str = "your_library.SubSubModel"
    p: int = 1
    q: int = 2

@dataclass
class SubModelConfig:
    _target_: str = "your_library.SubModel"
    # Nested dataclass defaults are mutable, so they also need default_factory
    i: SubSubModelConfig = field(default_factory=SubSubModelConfig)
    j: float = 3e-4
    k: int = 32

@dataclass
class ModelConfig:
    _target_: str = "your_library.Model"
    a: SubModelConfig = field(default_factory=SubModelConfig)
    b: str = "adam"
    c: str = "imagenet"

This is verbose and repetitive, isn't it? This is where Component comes to the rescue. In fact, the configuration was already there all along:

model.config  # 🤯

This would result in the following configuration:

_target_: your_library.Model
a:
  _target_: your_library.SubModel
  i:
    _target_: your_library.SubSubModel
    p: 1
    q: 2
  j: 0.0003
  k: 32
b: adam
c: imagenet
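The nesting falls out naturally if each component's config recursively converts any sub-components it received. The sketch below applies that idea with a hypothetical `Recorded` base class (plain Python, not TorchMix's actual code) and reproduces the structure of the YAML above, except that `_target_` comes from whichever module defines the classes rather than `your_library`.

```python
import inspect
import json
from typing import Any


class Recorded:
    """Hypothetical stand-in for Component that remembers its constructor arguments."""

    def __new__(cls, *args: Any, **kwargs: Any):
        obj = super().__new__(cls)
        bound = inspect.signature(cls.__init__).bind(obj, *args, **kwargs)
        bound.apply_defaults()
        obj._init_args = {k: v for k, v in bound.arguments.items() if k != "self"}
        return obj

    @property
    def config(self) -> dict[str, Any]:
        target = f"{type(self).__module__}.{type(self).__qualname__}"
        # Sub-components become nested configs instead of opaque objects.
        return {"_target_": target, **{k: to_config(v) for k, v in self._init_args.items()}}


def to_config(value: Any) -> Any:
    """Recursively replace components (also inside lists/tuples) with their configs."""
    if isinstance(value, Recorded):
        return value.config
    if isinstance(value, (list, tuple)):
        return [to_config(v) for v in value]
    return value


class Model(Recorded):
    def __init__(self, a: Recorded, b: str, c: str):
        self.a, self.b, self.c = a, b, c


class SubModel(Recorded):
    def __init__(self, i: Recorded, j: float, k: int):
        self.i, self.j, self.k = i, j, k


class SubSubModel(Recorded):
    def __init__(self, p: int, q: int):
        self.p, self.q = p, q


model = Model(SubModel(SubSubModel(1, 2), 3e-4, 32), "adam", "imagenet")
print(json.dumps(model.config, indent=2))
```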

Component allows you to confidently write arbitrarily nested modules and easily integrate them into the hydra ecosystem, as demonstrated below:

nn.Sequential(
    Attach(
        Token(dim=768),
        Add(
            PatchEmbed(patch_size=16),
            PositionEmbed(seq_length=196, dim=768),
        ),
    ),
    Repeat(
        nn.Sequential(
            PreNorm(
                Attention(
                    dim=768,
                    num_heads=12,
                    head_dim=64,
                ),
                dim=768,
            ),
            PreNorm(
                MLP(
                    dim=768,
                    act_layer=nn.GELU(),
                    expansion_factor=4,
                ),
                dim=768,
            ),
        ),
        depth=12,
    ),
    Extract(0),
    nn.Linear(768, 1000),
).store(group="model", name="vit")

All components provided by TorchMix are subclasses of the Component class. TorchMix also offers Component versions of PyTorch's nn module; just import nn from TorchMix:

from torchmix import nn

Duplicating Components

This automatic configuration generation has a handy side effect: you can easily duplicate a component with exactly the same configuration:

from hydra.utils import instantiate
from torchmix import Component
 
model: Component = ...  # some complex component
 
model_new = instantiate(model.config)

Alternatively, you can use the Component's own instantiate method:

model_new = model.instantiate()
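Duplication then amounts to walking the config and calling each `_target_` with its arguments, which is what Hydra's `instantiate` does. Here is a rough, self-contained sketch of that round trip. The registry lookup stands in for Hydra's resolution of import paths, and `Recorded`, `build`, `Inner`, and `Outer` are all hypothetical names used only for illustration.

```python
import inspect
from typing import Any

# Hypothetical registry standing in for Hydra's resolution of `_target_` import paths.
_REGISTRY: dict[str, type] = {}


class Recorded:
    """Hypothetical stand-in for Component: records init args and can rebuild itself."""

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        _REGISTRY[f"{cls.__module__}.{cls.__qualname__}"] = cls

    def __new__(cls, *args: Any, **kwargs: Any):
        obj = super().__new__(cls)
        bound = inspect.signature(cls.__init__).bind(obj, *args, **kwargs)
        bound.apply_defaults()
        obj._init_args = {k: v for k, v in bound.arguments.items() if k != "self"}
        return obj

    @property
    def config(self) -> dict[str, Any]:
        target = f"{type(self).__module__}.{type(self).__qualname__}"
        return {"_target_": target, **{k: to_config(v) for k, v in self._init_args.items()}}

    def instantiate(self) -> "Recorded":
        # Rebuild a fresh object tree from the recorded config.
        return build(self.config)


def to_config(value: Any) -> Any:
    if isinstance(value, Recorded):
        return value.config
    if isinstance(value, (list, tuple)):
        return [to_config(v) for v in value]
    return value


def build(node: Any) -> Any:
    """Recursively construct objects from `_target_` dicts, innermost first."""
    if isinstance(node, dict) and "_target_" in node:
        cls = _REGISTRY[node["_target_"]]
        return cls(**{k: build(v) for k, v in node.items() if k != "_target_"})
    if isinstance(node, list):
        return [build(v) for v in node]
    return node


class Inner(Recorded):
    def __init__(self, p: int, q: int = 2):
        self.p, self.q = p, q


class Outer(Recorded):
    def __init__(self, inner: Recorded, name: str):
        self.inner, self.name = inner, name


model = Outer(Inner(1), "demo")
model_new = model.instantiate()
print(model_new is model, model_new.inner is model.inner)  # False False
print(model_new.config == model.config)  # True
```

Because every nested object is rebuilt from its config, the copy shares no sub-modules with the original, which is exactly what you want when duplicating layers that hold their own parameters.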

A direct application of this feature is the Repeat container, which creates an nn.Sequential with multiple instances of the same component.

In the example below, not only the Attention module but also the plugins passed to it as arguments (CausalMask, RelativePositionBias, DropAttention, and DropProjection) are duplicated 12 times:

Repeat(
    Attention(
        dim=768,
        num_heads=12,
        head_dim=64,
        plugins=[
            CausalMask(),
            RelativePositionBias(
                num_buckets=256,
                num_heads=12,
                causal=True,
            ),
            DropAttention(p=0.1),
            DropProjection(p=0.1),
        ],
    ),
    depth=12
)
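Under the same assumptions, a Repeat-style container only needs to rebuild the component from its config `depth` times. The stdlib-only sketch below returns a plain list instead of an nn.Sequential, and its `repeat` function and toy `ToyAttention`/`ToyPlugin` classes are hypothetical stand-ins rather than TorchMix's Repeat, Attention, or plugins. It shows that each copy receives its own plugin objects instead of sharing the originals.

```python
import inspect
from typing import Any

# Hypothetical registry standing in for Hydra's resolution of `_target_` import paths.
_REGISTRY: dict[str, type] = {}


class Recorded:
    """Hypothetical stand-in for Component that records its constructor arguments."""

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        _REGISTRY[f"{cls.__module__}.{cls.__qualname__}"] = cls

    def __new__(cls, *args: Any, **kwargs: Any):
        obj = super().__new__(cls)
        bound = inspect.signature(cls.__init__).bind(obj, *args, **kwargs)
        bound.apply_defaults()
        obj._init_args = {k: v for k, v in bound.arguments.items() if k != "self"}
        return obj

    @property
    def config(self) -> dict[str, Any]:
        target = f"{type(self).__module__}.{type(self).__qualname__}"
        return {"_target_": target, **{k: to_config(v) for k, v in self._init_args.items()}}


def to_config(value: Any) -> Any:
    if isinstance(value, Recorded):
        return value.config
    if isinstance(value, (list, tuple)):
        return [to_config(v) for v in value]
    return value


def build(node: Any) -> Any:
    if isinstance(node, dict) and "_target_" in node:
        cls = _REGISTRY[node["_target_"]]
        return cls(**{k: build(v) for k, v in node.items() if k != "_target_"})
    if isinstance(node, list):
        return [build(v) for v in node]
    return node


def repeat(component: Recorded, depth: int) -> list[Recorded]:
    """Hypothetical Repeat: `depth` independent copies built from one config."""
    config = component.config
    return [build(config) for _ in range(depth)]


class ToyPlugin(Recorded):
    def __init__(self, p: float = 0.1):
        self.p = p


class ToyAttention(Recorded):
    def __init__(self, dim: int, num_heads: int, plugins: list[Recorded]):
        self.dim, self.num_heads, self.plugins = dim, num_heads, plugins


layers = repeat(ToyAttention(768, 12, plugins=[ToyPlugin(0.1), ToyPlugin(0.1)]), depth=12)
print(len(layers))  # 12
# Every copy owns distinct plugin objects, so no state is shared between layers.
print(len({id(plugin) for layer in layers for plugin in layer.plugins}))  # 24
```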