Attention
Base class for all multi-head self-attention layers.
Attention(dim=768, num_heads=8, head_dim=64, plugins=[])
Parameters
dim
: The dimension size of the input and output features.
num_heads
: The number of attention heads.
head_dim
: The dimension size of each attention head.
plugins
: A list of AttentionPlugin instances to use.
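The three size parameters are set independently. In the common multi-head layout, the input of width dim is projected to num_heads * head_dim for queries, keys and values, and an output projection maps that back to dim. The sketch below assumes that standard layout (this page does not confirm it) and computes the implied projection shapes; the helper `projection_shapes` is hypothetical and not part of this library.

```python
def projection_shapes(dim: int = 768, num_heads: int = 8, head_dim: int = 64) -> dict:
    """Weight shapes of a standard multi-head self-attention layer (assumed layout)."""
    inner_dim = num_heads * head_dim  # combined width of all heads
    return {
        "q": (dim, inner_dim),    # input -> queries for all heads
        "k": (dim, inner_dim),    # input -> keys for all heads
        "v": (dim, inner_dim),    # input -> values for all heads
        "out": (inner_dim, dim),  # concatenated heads -> back to dim
    }

shapes = projection_shapes()
print(shapes["q"])    # (768, 512)
print(shapes["out"])  # (512, 768)
```

With the defaults, the combined head width is 8 × 64 = 512, which differs from dim=768, so under this layout head_dim is not required to equal dim // num_heads.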
Forward
(x: jaxtyping.Float[Tensor, '... n d']) -> jaxtyping.Float[Tensor, '... n d']
Takes an input of shape (..., n, d) and returns an output of the same shape.
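To illustrate the '... n d' -> '... n d' contract, here is a minimal NumPy sketch of plain scaled dot-product multi-head self-attention, assuming the projection layout described for the parameters above and ignoring plugins entirely. It is not this library's implementation: it uses NumPy arrays instead of Tensor, and `mhsa_forward` and its explicit weight arguments are hypothetical.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def mhsa_forward(x, w_q, w_k, w_v, w_out, num_heads, head_dim):
    """Map x of shape (..., n, d) to (..., n, d); no plugins, no masking."""
    *batch, n, _ = x.shape

    def split_heads(t):
        # (..., n, h*e) -> (..., n, h, e) -> (..., h, n, e)
        t = t.reshape(*batch, n, num_heads, head_dim)
        return np.swapaxes(t, -2, -3)

    q = split_heads(x @ w_q)
    k = split_heads(x @ w_k)
    v = split_heads(x @ w_v)
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(head_dim)  # (..., h, n, n)
    out = softmax(scores, axis=-1) @ v                      # (..., h, n, e)
    # Merge heads back: (..., h, n, e) -> (..., n, h*e)
    out = np.swapaxes(out, -2, -3).reshape(*batch, n, num_heads * head_dim)
    return out @ w_out                                      # (..., n, d)

rng = np.random.default_rng(0)
dim, num_heads, head_dim = 768, 8, 64
inner = num_heads * head_dim
w_q, w_k, w_v = (rng.standard_normal((dim, inner)) * 0.02 for _ in range(3))
w_out = rng.standard_normal((inner, dim)) * 0.02

x = rng.standard_normal((2, 10, dim))  # batch of 2 sequences, 10 tokens each
y = mhsa_forward(x, w_q, w_k, w_v, w_out, num_heads, head_dim)
print(y.shape)  # (2, 10, 768)
```

Because the leading '...' dimensions are carried through unchanged, the same function also accepts an unbatched input of shape (n, d).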