Feedforward
Base class for all Feedforward layers.
Feedforward(act_layer=nn.GELU(), dim=768, expansion_factor=4, plugins=[])
Parameters
dim: The dimension size.
act_layer: Activation layer to be inserted between the two Linear layers.
expansion_factor: Factor by which to expand dim in the first Linear layer.
plugins: A list of FeedforwardPlugins to use.
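Read together, these parameters describe a two-layer MLP: a Linear layer from dim to dim * expansion_factor, then act_layer, then a Linear layer back to dim. Below is a minimal sketch of that reading in PyTorch. The name FeedforwardSketch is hypothetical, and plugins are left out because the FeedforwardPlugin interface is not documented here. It is not the library's actual implementation.

```python
import torch
from torch import nn


class FeedforwardSketch(nn.Module):
    """Hypothetical two-layer feedforward block (plugins not modeled)."""

    def __init__(self, act_layer: nn.Module = nn.GELU(), dim: int = 768, expansion_factor: int = 4):
        super().__init__()
        hidden = dim * expansion_factor  # first Linear expands dim by expansion_factor
        self.fc1 = nn.Linear(dim, hidden)
        self.act = act_layer  # activation between the two Linear layers
        self.fc2 = nn.Linear(hidden, dim)  # second Linear projects back to dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(x)))


ff = FeedforwardSketch(dim=16, expansion_factor=4)
y = ff(torch.randn(2, 5, 16))
print(y.shape)  # torch.Size([2, 5, 16])
```

With dim=16 and expansion_factor=4, the hidden layer is 64 wide. The output keeps the input's trailing size of 16.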
Forward
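(x: jaxtyping.Float[Tensor, '... d']) -> jaxtyping.Float[Tensor, '... d']
Takes a tensor with any number of leading dimensions and a trailing dimension of size d, and returns a tensor of the same shape.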
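The '... d' annotation means that only the last axis is constrained. nn.Linear operates on the last axis alone, so a Linear, activation, Linear stack accepts any leading shape. The check below uses a plain nn.Sequential as a hypothetical stand-in for a Feedforward instance, assuming dim=8 and expansion_factor=4.

```python
import torch
from torch import nn

# Hypothetical stand-in for Feedforward: Linear -> GELU -> Linear, dim=8, expansion_factor=4.
dim = 8
block = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

# Any number of leading axes ("...") is accepted; only the last axis must equal dim.
for shape in [(dim,), (3, dim), (2, 7, dim)]:
    out = block(torch.randn(*shape))
    assert out.shape == shape  # output keeps the input shape (..., d)
```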