WindowAttention

Local window attention layer from Swin Transformer.

WindowAttention(dim=96, window_size=8, num_heads=8, head_dim=64)

Parameters

  • dim: The input and output embedding dimension.
  • num_heads: The number of attention heads.
  • head_dim: The dimension of each attention head.
  • window_size: The window size for local attention (see the sketch after this list).
  • plugins: A list of AttentionPlugins to use.
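
For intuition, "local window attention" means each token attends only to the other tokens in its own fixed-size window rather than to the full sequence. A minimal sketch of the partitioning step (not this library's internals), assuming a (batch, tokens, dim) layout:

import torch

def partition_windows(x: torch.Tensor, window_size: int) -> torch.Tensor:
    """Split (batch, n, d) into (batch * n_windows, window_size, d).

    Attention is then computed independently within each window, so the
    cost grows linearly with sequence length instead of quadratically.
    """
    b, n, d = x.shape
    assert n % window_size == 0, "sequence length must be divisible by window_size"
    return x.view(b, n // window_size, window_size, d).reshape(-1, window_size, d)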

Forward

(x: jaxtyping.Float[Tensor, '... n d']) -> jaxtyping.Float[Tensor, '... n d']
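
A minimal usage sketch, assuming standard PyTorch module semantics (the import path below is hypothetical; adjust it to the actual package):

import torch
# Hypothetical import path.
from components import WindowAttention

attn = WindowAttention(dim=96, window_size=8, num_heads=8, head_dim=64)

# The token count should partition evenly into windows (here 8 windows of 8 tokens).
x = torch.randn(2, 64, 96)  # (batch, tokens, dim)
y = attn(x)                 # same '... n d' shape as the input
assert y.shape == x.shape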