Component Class
TorchMix provides the Component class, which is a drop-in replacement for nn.Module with enhanced configuration support.
Consider the following example:
from torchmix import Component

class Model(Component):
    def __init__(self, a: int, b: str, c: list[str]):
        pass
Without Component, you would have to write the configurations manually, like this:
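from dataclasses import dataclass, field

@dataclass
class ModelA:
    a: int = 1
    b: str = "x"
    # Mutable defaults must go through default_factory in a dataclass.
    c: list[str] = field(default_factory=lambda: ["foo"])

@dataclass
class ModelB:
    a: int = 2
    b: str = "y"
    c: list[str] = field(default_factory=lambda: ["foo", "bar"])

@dataclass
class ModelC:
    a: int = 3
    b: str = "z"
    c: list[str] = field(default_factory=lambda: ["foo", "bar", "baz"])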
Dataclasses used as configurations in this way are called Structured Configs. They enable both runtime and static type checking for more robust configuration management. See the hydra docs or the omegaconf docs for more detail.
However, writing these configurations manually is time-consuming and often results in redundant code. With Component, you can simply instantiate the desired state, and the configs are there for you:
Model(a=1, b="x", c=["foo"]).config
Model(a=2, b="y", c=["foo", "bar"]).config
Model(a=3, b="z", c=["foo", "bar", "baz"]).config
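To make the idea concrete, here is a minimal standard-library sketch of how constructor arguments could be captured into a config. It is not TorchMix's implementation (which builds on hydra-zen); the RecordsConfig base class is a hypothetical stand-in for Component, and the plain dict it produces only imitates the shape of a structured config.

```python
import inspect


class RecordsConfig:
    """Hypothetical stand-in for Component: remembers constructor arguments."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        if "__init__" not in cls.__dict__:
            return  # the subclass inherits an already wrapped __init__
        original_init = cls.__init__
        signature = inspect.signature(original_init)
        self_name = next(iter(signature.parameters))

        def wrapped_init(self, *args, **kw):
            # Map positional and keyword arguments onto parameter names.
            bound = signature.bind(self, *args, **kw)
            bound.apply_defaults()
            arguments = dict(bound.arguments)
            arguments.pop(self_name)
            # Hydra-style configs name the class to build under `_target_`.
            target = f"{type(self).__module__}.{type(self).__qualname__}"
            self.config = {"_target_": target, **arguments}
            original_init(self, *args, **kw)

        cls.__init__ = wrapped_init


class Model(RecordsConfig):
    def __init__(self, a: int, b: str, c: list[str]):
        pass


config = Model(a=1, b="x", c=["foo"]).config
print(config)
```

The point of the sketch is that the call site already contains everything a config needs, so no separate dataclass has to be written by hand.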
You can then register these structured configs directly into hydra's ConfigStore via the store method:
Model(a=1, b="x", c=["foo"]).store(group="model", name="a")
Model(a=2, b="y", c=["foo", "bar"]).store(group="model", name="b")
Model(a=3, b="z", c=["foo", "bar", "baz"]).store(group="model", name="c")
This feature is built on top of hydra-zen. Check their docs for more information!
Nested Components
In deep learning, models are usually composed of multiple sub-modules. Writing configurations for those nested modules by hand is even more confusing and repetitive. Let's say we have these models:
from torchmix import Component

class Model(Component):
    def __init__(self, a: Component, b: str, c: str):
        pass

class SubModel(Component):
    def __init__(self, i: Component, j: float, k: int):
        pass

class SubSubModel(Component):
    def __init__(self, p: int, q: int):
        pass
and use them like this:
model = Model(
    SubModel(
        SubSubModel(1, 2),
        3e-4,
        32,
    ),
    "adam",
    "imagenet",
)
Here are the hydra-compatible configurations for this setup, written by hand:
from dataclasses import dataclass, field

@dataclass
class SubSubModelConfig:
    _target_: str = "your_library.SubSubModel"
    p: int = 1
    q: int = 2

@dataclass
class SubModelConfig:
    _target_: str = "your_library.SubModel"
    # Nested dataclass defaults need default_factory (instances are mutable).
    i: SubSubModelConfig = field(default_factory=SubSubModelConfig)
    j: float = 3e-4
    k: int = 32

@dataclass
class ModelConfig:
    _target_: str = "your_library.Model"
    a: SubModelConfig = field(default_factory=SubModelConfig)
    b: str = "adam"
    c: str = "imagenet"
This does not look good, right? Component comes to the rescue.
The truth is, the configuration was already there!
model.config # 🤯
This would result in the following configuration:
_target_: your_library.Model
a:
  _target_: your_library.SubModel
  i:
    _target_: your_library.SubSubModel
    p: 1
    q: 2
  j: 0.0003
  k: 32
b: adam
c: imagenet
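The nesting in that output follows from a simple recursion: a component's config embeds the configs of the components passed to it. The sketch below illustrates this with a hypothetical Node base class that takes keyword arguments only; it is a simplified stand-in and not how Component is actually implemented. The your_library prefix is the same placeholder used above.

```python
class Node:
    """Hypothetical stand-in for Component that stores keyword arguments."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    @property
    def config(self):
        cfg = {"_target_": f"your_library.{type(self).__name__}"}
        for name, value in self.kwargs.items():
            # A nested component contributes its own config, recursively.
            cfg[name] = value.config if isinstance(value, Node) else value
        return cfg


class Model(Node):
    pass


class SubModel(Node):
    pass


class SubSubModel(Node):
    pass


model = Model(
    a=SubModel(i=SubSubModel(p=1, q=2), j=3e-4, k=32),
    b="adam",
    c="imagenet",
)
nested = model.config
print(nested)
```

The resulting dict has the same shape as the configuration listed above, with each level carrying its own _target_ entry.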
Component allows you to confidently write arbitrarily nested modules and easily integrate them into the hydra ecosystem, as demonstrated below:
nn.Sequential(
    Attach(
        Token(dim=768),
        Add(
            PatchEmbed(patch_size=16),
            PositionEmbed(seq_length=196, dim=768),
        ),
    ),
    Repeat(
        nn.Sequential(
            PreNorm(
                Attention(
                    dim=768,
                    num_heads=12,
                    head_dim=64,
                ),
                dim=768,
            ),
            PreNorm(
                MLP(
                    dim=768,
                    act_layer=nn.GELU(),
                    expansion_factor=4,
                ),
                dim=768,
            ),
        ),
        depth=12,
    ),
    Extract(0),
    nn.Linear(768, 1000),
).store(group="model", name="vit")
All components provided by TorchMix are subclasses of the Component class.
We also offer Component versions of the modules in PyTorch's nn module.
Just import nn from TorchMix!
from torchmix import nn
Duplicating Components
This automatic configuration generation has a handy side effect: you can easily duplicate components with exactly the same configuration:
from hydra.utils import instantiate
from torchmix import Component
model: Component = ... # some complex component
model_new = instantiate(model.config)
Alternatively, you can use the instantiate method of the Component class:
model_new = model.instantiate()
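For intuition about what instantiating from a config involves, the following is a simplified standard-library sketch that resolves _target_ paths and rebuilds objects recursively. It is an illustration only: hydra's real instantiate handles more cases, and this helper merely shares its name. The targets used (fractions.Fraction and types.SimpleNamespace) are arbitrary stdlib classes chosen so the example runs without TorchMix.

```python
import importlib


def instantiate(config):
    """Rebuild an object from a nested dict with `_target_` keys (simplified)."""
    if isinstance(config, list):
        return [instantiate(item) for item in config]
    if not isinstance(config, dict):
        return config  # plain value such as int, float or str
    # Build nested objects first, then the object that receives them.
    kwargs = {k: instantiate(v) for k, v in config.items() if k != "_target_"}
    if "_target_" not in config:
        return kwargs
    module_name, _, attr = config["_target_"].rpartition(".")
    target = getattr(importlib.import_module(module_name), attr)
    return target(**kwargs)


config = {
    "_target_": "types.SimpleNamespace",
    "name": "demo",
    "ratio": {
        "_target_": "fractions.Fraction",
        "numerator": 3,
        "denominator": 4,
    },
}

first = instantiate(config)
second = instantiate(config)
print(first == second, first is second)
```

Because every call rebuilds the whole tree from the config, the two results are equal in value but are separate objects, which is what makes duplication safe.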
A direct application of this feature is the Repeat container, which creates an nn.Sequential with multiple instances of the same component.
In the example below, not only the Attention component but also the plugins passed as arguments, such as CausalMask, RelativePositionBias, DropAttention, and DropProjection, are duplicated 12 times:
Repeat(
    Attention(
        dim=768,
        num_heads=12,
        head_dim=64,
        plugins=[
            CausalMask(),
            RelativePositionBias(
                num_buckets=256,
                num_heads=12,
                causal=True,
            ),
            DropAttention(p=0.1),
            DropProjection(p=0.1),
        ],
    ),
    depth=12,
)
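The behaviour described above, where nested plugins are duplicated along with their parent, can be sketched in plain Python. Everything here is hypothetical: Node stands in for Component, repeat returns a plain list rather than an nn.Sequential, and Attention and DropAttention are empty placeholders that only share their names with the TorchMix components.

```python
class Node:
    """Hypothetical stand-in for Component: keeps its keyword arguments."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def instantiate(self):
        # Rebuild nested components first so the copy shares nothing.
        fresh = {name: rebuild(value) for name, value in self.kwargs.items()}
        return type(self)(**fresh)


def rebuild(value):
    """Recreate components (also inside lists); pass plain values through."""
    if isinstance(value, Node):
        return value.instantiate()
    if isinstance(value, list):
        return [rebuild(item) for item in value]
    return value


def repeat(component, depth):
    """Return `depth` fresh instances built from the same arguments."""
    return [component.instantiate() for _ in range(depth)]


class Attention(Node):
    pass


class DropAttention(Node):
    pass


layers = repeat(
    Attention(dim=768, plugins=[DropAttention(p=0.1)]),
    depth=12,
)
plugins = [layer.kwargs["plugins"][0] for layer in layers]
print(len(layers), len({id(plugin) for plugin in plugins}))
```

Each of the 12 layers ends up with its own plugin instance, so state held by one layer's plugin is never shared with another layer.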