Attention
Base class for all multi-head self-attention modules.
Attention(dim=768, num_heads=8, head_dim=64, plugins=[])

Parameters
dim: The input and output feature dimension.
num_heads: The number of attention heads.
head_dim: The dimension size of each attention head.
plugins: A list of AttentionPlugins to use.
Forward
(x: jaxtyping.Float[Tensor, '... n d']) -> jaxtyping.Float[Tensor, '... n d']
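A minimal usage sketch, assuming Attention is a PyTorch nn.Module exported by this package; the import path below is hypothetical and should be adjusted to wherever Attention actually lives:

```python
import torch
# Hypothetical import path; replace with this package's actual module.
from mypackage.attention import Attention

# 8 attention heads, each of dimension 64, over 768-dim inputs.
attn = Attention(dim=768, num_heads=8, head_dim=64, plugins=[])

x = torch.randn(2, 10, 768)  # (batch, sequence length n, dim d)
y = attn(x)                  # per the forward signature, shape is preserved
assert y.shape == x.shape    # (2, 10, 768)
```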