WindowAttention
Local window attention layer from Swin Transformer.
WindowAttention(dim=96, window_size=8, num_heads=8, head_dim=64)
Parameters
dim: The dimension size.
num_heads: The number of attention heads.
head_dim: The dimension size for each attention head.
window_size: The window size for local attention.
plugins: A list of AttentionPlugins to use.
Forward
(x: jaxtyping.Float[Tensor, '... n d']) -> jaxtyping.Float[Tensor, '... n d']
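The reference above does not say how the n tokens are grouped into windows. As a rough illustration only, the sketch below assumes the tokens are a row-major flattening of a square grid that is split into non-overlapping window_size x window_size blocks, as in the Swin Transformer paper. The helpers window_ids and can_attend are hypothetical names introduced for this example and are not part of the library.

```python
import math


def window_ids(n: int, window_size: int) -> list[int]:
    """Return, for each of n tokens, the index of the local window it falls in.

    Assumption (not confirmed by the docs): the n tokens are a row-major
    flattening of a square grid whose side is divisible by window_size.
    """
    side = math.isqrt(n)
    if side * side != n:
        raise ValueError("n must be a perfect square under this assumption")
    if side % window_size != 0:
        raise ValueError("grid side must be divisible by window_size")

    windows_per_row = side // window_size
    ids = []
    for token in range(n):
        # Recover the 2D grid position of the flattened token index.
        row, col = divmod(token, side)
        # Windows are numbered row by row, left to right.
        ids.append((row // window_size) * windows_per_row + col // window_size)
    return ids


def can_attend(i: int, j: int, ids: list[int]) -> bool:
    """Under local window attention, tokens interact only within one window."""
    return ids[i] == ids[j]


# A 4 x 4 grid (n = 16) with 2 x 2 windows gives four windows of four tokens.
ids = window_ids(n=16, window_size=2)
```

With these assumptions, token 0 and token 5 share the top-left window and could attend to each other, while token 0 and token 2 fall in different windows and could not. The output keeps the same '... n d' shape as the input either way, since windowing only restricts which token pairs interact.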