TORCHMIX 🧩


Attention

Base class for all multi-head self-attention modules.

Attention(dim=768, num_heads=8, head_dim=64, plugins=[])

Parameters

  • dim: The embedding dimension of the input and output (the d in the forward signature).
  • num_heads: The number of attention heads.
  • head_dim: The dimension of each attention head.
  • plugins: A list of AttentionPlugins to apply.
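
To make the relationship between these parameters concrete: the query/key/value projections have an inner width of num_heads × head_dim (8 × 64 = 512 with the defaults), which is projected back to dim at the output. The sketch below is an illustrative reimplementation of plain multi-head self-attention with the same hyperparameters; it is not torchmix's actual internals, and the class name NaiveSelfAttention is hypothetical.

import torch
import torch.nn as nn

class NaiveSelfAttention(nn.Module):
    # Illustrative sketch only -- NOT torchmix's implementation.
    def __init__(self, dim=768, num_heads=8, head_dim=64):
        super().__init__()
        inner_dim = num_heads * head_dim            # 8 * 64 = 512 by default
        self.num_heads, self.head_dim = num_heads, head_dim
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)     # project back to `dim`

    def forward(self, x):                           # x: (..., n, d)
        *lead, n, _ = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = (t.reshape(*lead, n, self.num_heads, self.head_dim)
                    .transpose(-3, -2) for t in qkv)  # (..., heads, n, head_dim)
        scores = (q @ k.transpose(-2, -1)) * self.head_dim ** -0.5
        out = scores.softmax(dim=-1) @ v            # (..., heads, n, head_dim)
        out = out.transpose(-3, -2).reshape(*lead, n, -1)
        return self.to_out(out)                     # (..., n, d)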

Forward

(x: jaxtyping.Float[Tensor, '... n d']) -> jaxtyping.Float[Tensor, '... n d']
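
That is, forward takes a tensor of shape (..., n, d) and returns a tensor of the same shape. A minimal usage sketch follows; the top-level import path is an assumption about the package layout (see the Introduction for the exact one):

import torch
from torchmix import Attention  # assumed import path

attn = Attention(dim=768, num_heads=8, head_dim=64, plugins=[])
x = torch.randn(2, 196, 768)    # (batch, n=196 tokens, d=768)
y = attn(x)
assert y.shape == x.shape       # forward maps (..., n, d) -> (..., n, d)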