Feedforward
Base class for all Feedforward layers.
Feedforward(act_layer=nn.GELU(), dim=768, expansion_factor=4, plugins=[])
Parameters
dim: The dimension size.
act_layer: Activation layer to be inserted between the two Linear layers.
expansion_factor: Factor by which to expand dim in the first Linear layer.
plugins: A list of FeedforwardPlugins to use.
Forward
(x: jaxtyping.Float[Tensor, '... d']) -> jaxtyping.Float[Tensor, '... d']
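The parameters and forward signature described here can be illustrated with a minimal PyTorch sketch. It is not the library's implementation: SketchFeedforward and its attribute names (fc1, act, fc2) are hypothetical. It assumes the second Linear layer projects back to dim, which is consistent with the '... d' -> '... d' signature. Plugins are left out because their interface is not described here.

```python
import torch
from torch import nn


class SketchFeedforward(nn.Module):
    """Hypothetical stand-in for illustration, not the library's Feedforward."""

    def __init__(self, dim=768, expansion_factor=4, act_layer=None):
        super().__init__()
        hidden = dim * expansion_factor  # width after the first Linear layer
        self.fc1 = nn.Linear(dim, hidden)
        # Create the default activation per instance rather than in the signature.
        self.act = act_layer if act_layer is not None else nn.GELU()
        # Assumption: the second Linear layer maps back to dim.
        self.fc2 = nn.Linear(hidden, dim)

    def forward(self, x):
        # nn.Linear acts on the last axis, so any leading axes pass through.
        return self.fc2(self.act(self.fc1(x)))


ff = SketchFeedforward()
x = torch.randn(2, 16, 768)  # (batch, tokens, dim)
y = ff(x)
out_shape = tuple(y.shape)
```

With the defaults, the hidden width is 768 * 4 = 3072, and the output keeps the input's shape because only the last axis is transformed.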