Components
Feedforward

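Base class for all Feedforward layers. A Feedforward layer expands the last (feature) dimension with a first Linear layer, applies act_layer, and projects back to dim with a second Linear layer.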

Feedforward(act_layer=nn.GELU(), dim=768, expansion_factor=4, plugins=[])

Parameters

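  • act_layer: Activation module applied between the two Linear layers. Default: nn.GELU().
  • dim: Size of the feature dimension, i.e. the last axis d of both the input and the output. Default: 768.
  • expansion_factor: The first Linear layer maps dim to dim * expansion_factor hidden features; the second maps them back to dim. Default: 4.
  • plugins: A list of FeedforwardPlugins to use. Default: [] (no plugins).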
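The parameters above describe a two-layer MLP: Linear(dim, dim * expansion_factor), then act_layer, then a Linear back to dim. The following is a minimal PyTorch sketch of that structure, not the library's implementation. It assumes plain nn.Linear layers with bias and no dropout or normalization; FeedforwardSketch is a hypothetical name, and plugins are stored but not applied because the FeedforwardPlugin interface is not described here.

```python
import torch
from torch import nn


class FeedforwardSketch(nn.Module):
    """Hypothetical stand-in for Feedforward: Linear -> activation -> Linear.

    Assumption: biased Linear layers, no dropout/normalization.
    Plugins are stored only; their behavior is not modeled.
    """

    def __init__(self, act_layer=None, dim=768, expansion_factor=4, plugins=None):
        super().__init__()
        hidden = dim * expansion_factor  # width of the expanded hidden layer
        self.fc1 = nn.Linear(dim, hidden)  # expand: dim -> dim * expansion_factor
        # Create a fresh GELU when no activation is given
        self.act = act_layer if act_layer is not None else nn.GELU()
        self.fc2 = nn.Linear(hidden, dim)  # project back: hidden -> dim
        self.plugins = list(plugins) if plugins is not None else []  # stored only

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (..., dim) -> (..., dim * expansion_factor) -> (..., dim)
        return self.fc2(self.act(self.fc1(x)))


ff = FeedforwardSketch(dim=16, expansion_factor=4)
x = torch.randn(2, 5, 16)
y = ff(x)
print(y.shape)              # torch.Size([2, 5, 16])
print(ff.fc1.out_features)  # 64
```

Using None defaults and building nn.GELU() and the plugin list inside __init__ is a choice of this sketch: it gives each instance its own activation module and list, whereas defaults written directly in a signature are evaluated once and shared by every instance.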

Forward

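(x: jaxtyping.Float[Tensor, '... d']) -> jaxtyping.Float[Tensor, '... d']

Takes a float tensor with any number of leading dimensions whose last dimension has size d (equal to dim) and returns a float tensor of the same shape.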
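To illustrate the '... d' shape contract, the sketch below uses a plain nn.Sequential stand-in with a Linear, GELU, Linear layout (an assumption about the internals, not the actual class) and small sizes chosen only for illustration. Because nn.Linear acts on the last axis, any number of leading dimensions passes through unchanged, while an input whose last axis differs from dim is rejected.

```python
import torch
from torch import nn

# Hypothetical stand-in with the documented layout; sizes are illustrative.
dim, expansion_factor = 8, 4
ff = nn.Sequential(
    nn.Linear(dim, dim * expansion_factor),
    nn.GELU(),
    nn.Linear(dim * expansion_factor, dim),
)

# '... d': zero or more leading axes, last axis of size dim.
shapes = [(dim,), (3, dim), (2, 4, dim), (2, 3, 4, dim)]
outputs = [ff(torch.randn(*s)).shape for s in shapes]
for s, out in zip(shapes, outputs):
    print(s, "->", tuple(out))  # output shape equals input shape

# A last axis that does not match dim cannot be multiplied by the first weight.
mismatch_raised = False
try:
    ff(torch.randn(3, dim + 1))
except RuntimeError:
    mismatch_raised = True
print("mismatch rejected:", mismatch_raised)
```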