WindowAttention
Local window attention layer from the Swin Transformer.
WindowAttention(dim=96, window_size=8, num_heads=8, head_dim=64)
Parameters
dim
: The dimension size of the input and output features.

num_heads
: The number of attention heads.

head_dim
: The dimension size of each attention head.

window_size
: The window size for local attention.

plugins
: A list of AttentionPlugins to use.
Forward
(x: jaxtyping.Float[Tensor, '... n d']) -> jaxtyping.Float[Tensor, '... n d']
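A minimal usage sketch, assuming the constructor and forward signatures above; the import path is a placeholder and should be replaced with the package this layer actually ships in.

```python
import torch

# Placeholder import path -- substitute the actual module providing WindowAttention.
from your_package import WindowAttention

attn = WindowAttention(dim=96, window_size=8, num_heads=8, head_dim=64)

# The sequence length is assumed to be a multiple of window_size (64 = 8 windows of 8 tokens).
x = torch.randn(2, 64, 96)   # Float[Tensor, '... n d'] with d == dim
out = attn(x)                # same shape as the input
assert out.shape == x.shape
```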