ViT (Vision Transformer)

The ViT (Vision Transformer) is a transformer-based neural network architecture for image classification. It divides an image into fixed-size patches, linearly embeds each patch, adds position embeddings, and processes the resulting sequence of vectors through a standard transformer encoder.
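The patching step described above can be sketched with plain numpy (shapes only; the real model uses a learned convolutional projection and JAX arrays, and the weight names here are illustrative):

```python
import numpy as np

# A 224x224 RGB image split into 16x16 patches, as in ViT-Base.
img_size, patch_size, hidden_size = 224, 16, 768
image = np.zeros((img_size, img_size, 3), dtype=np.float32)

n = img_size // patch_size                      # 14 patches per side
patches = image.reshape(n, patch_size, n, patch_size, 3)
patches = patches.transpose(0, 2, 1, 3, 4).reshape(n * n, -1)
print(patches.shape)                            # (196, 768): 196 patches, 16*16*3 values each

# A (hypothetical) linear projection embeds each flattened patch, then
# learned position embeddings are added before the transformer encoder.
W = np.zeros((patches.shape[1], hidden_size), dtype=np.float32)
pos = np.zeros((n * n, hidden_size), dtype=np.float32)
tokens = patches @ W + pos
print(tokens.shape)                             # (196, 768)
```

The resulting 196-token sequence (plus a class token in the actual model) is what the encoder consumes.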

The ViT model was introduced in the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" and has shown strong performance on image classification benchmarks.

Flash / Splash Attention

ViT supports hardware-accelerated attention via Tokamax:

Backend        Hardware                           Notes
"mosaic"       NVIDIA H100 (SM90) / B100 (SM100)  Pallas Mosaic GPU kernel
"triton"       Any NVIDIA GPU                     Pallas Triton kernel
"cudnn"        NVIDIA GPU                         Via JAX-NN / cuDNN
"mosaic_tpu"   TPU v5 / v7                        Splash attention (block-sparse)
"xla_chunked"  GPU / TPU                          Flash-style chunked XLA
"xla"          Any                                Standard XLA fallback
import jimm

# GPU: try H100 Mosaic kernel, fall back to Triton, then XLA
model = jimm.VisionTransformer.from_pretrained("google/vit-base-patch16-224",
                                                attention_fn=jimm.make_tokamax_attention(["mosaic", "triton", "xla"]))

# TPU: try Splash attention, fall back to chunked XLA
model = jimm.VisionTransformer.from_pretrained("google/vit-base-patch16-224",
                                                attention_fn=jimm.make_tokamax_attention(["mosaic_tpu", "xla_chunked"]))

Note: Flash/Splash attention does not provide a speedup at typical ViT context lengths (e.g. 196 tokens for a 224 px image with 16 px patches). Its primary benefit is reduced memory use at longer sequence lengths.
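A back-of-the-envelope calculation shows why (assuming float32 scores and a single head; flash-style kernels avoid materializing this matrix at all):

```python
def attn_matrix_bytes(seq_len: int, bytes_per_elem: int = 4) -> int:
    # Standard attention materializes a seq_len x seq_len score matrix per head.
    return seq_len * seq_len * bytes_per_elem

# 224 px image / 16 px patches -> 14 * 14 = 196 tokens
print(attn_matrix_bytes(196))    # ~150 KiB: negligible, so kernel overhead dominates
# A much longer sequence, e.g. from higher-resolution inputs
print(attn_matrix_bytes(4096))   # 64 MiB per head: here avoiding materialization pays off
```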

FSDP / Explicit Sharding

ViT supports JAX explicit sharding (FSDP-style) via ViTSharding. Large weight matrices are sharded on the contracting (in_features) dimension so that activations carry only the batch-axis sharding.

from jax.experimental import mesh_utils
from jax.sharding import AxisType, Mesh
import jax

n_devices = jax.device_count()
mesh = Mesh(
    mesh_utils.create_device_mesh((1, n_devices)),
    ("data", "fsdp"),
    axis_types=(AxisType.Explicit, AxisType.Explicit),
)
jax.set_mesh(mesh)

model = jimm.VisionTransformer.from_pretrained("google/vit-base-patch16-224")

ViTSharding specs describe per-layer parameter shapes. After nnx.vmap, the Transformer stack prepends None for the scan axis to the Variable metadata, so the optimizer receives the correct stacked spec without extra handling.
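As an illustration, with plain tuples standing in for PartitionSpec entries (not the library's actual metadata API): a per-layer kernel spec of ("fsdp", None) becomes (None, "fsdp", None) once the layer axis is stacked, so a parameter of shape (num_layers, in_features, out_features) keeps in_features sharded on "fsdp" while the scan axis stays replicated.

```python
def prepend_scan_axis(spec: tuple) -> tuple:
    # The stacked (scan) axis is replicated, so its spec entry is None.
    return (None, *spec)

per_layer_spec = ("fsdp", None)          # (in_features, out_features)
stacked_spec = prepend_scan_axis(per_layer_spec)
print(stacked_spec)                      # (None, 'fsdp', None)
```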

To disable sharding, pass sharding=jimm.common.sharding.NoSharding().

jimm.models.vit.VisionTransformer

Bases: Module

Vision Transformer (ViT) model for image classification.

This implements the Vision Transformer as described in the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale".

Source code in src/jimm/models/vit/vit_model.py
class VisionTransformer(nnx.Module):
    """Vision Transformer (ViT) model for image classification.

    This implements the Vision Transformer as described in the paper
    "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale"
    """

    def __init__(
        self,
        num_classes: int = 1000,
        in_channels: int = 3,
        img_size: int = 224,
        patch_size: int = 16,
        num_layers: int = 12,
        num_heads: int = 12,
        mlp_dim: int = 3072,
        hidden_size: int = 768,
        dropout_rate: float = 0.1,
        use_quick_gelu: bool = False,
        use_gradient_checkpointing: bool = False,
        attention_fn: Callable[..., Any] | None = None,
        do_classification: bool = True,
        rngs: rnglib.Rngs | None = None,
        dtype: DTypeLike = jnp.float32,
        param_dtype: DTypeLike = jnp.float32,
        sharding: ShardingSpec = ViTSharding,
    ) -> None:
        """Initialize a Vision Transformer.

        Args:
            num_classes (int, optional): Number of output classes. Defaults to 1000.
            in_channels (int, optional): Number of input channels. Defaults to 3.
            img_size (int, optional): Size of the input image (assumed square). Defaults to 224.
            patch_size (int, optional): Size of each patch (assumed square). Defaults to 16.
            num_layers (int, optional): Number of transformer layers. Defaults to 12.
            num_heads (int, optional): Number of attention heads. Defaults to 12.
            mlp_dim (int, optional): Size of the MLP dimension. Defaults to 3072.
            hidden_size (int, optional): Size of the hidden dimension. Defaults to 768.
            dropout_rate (float, optional): Dropout rate. Defaults to 0.1.
            use_quick_gelu (bool, optional): Whether to use quickgelu instead of gelu. Defaults to False.
            use_gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
            attention_fn (Callable[..., Any] | None, optional): Custom attention function (e.g. jimm.tokamax_attention). Defaults to None.
            do_classification (bool, optional): Whether to include the final classification head. Defaults to True.
            rngs (rnglib.Rngs | None, optional): Random number generator keys. If None, initializes to nnx.Rngs(0).
            dtype (DTypeLike, optional): Data type for computations. Defaults to jnp.float32.
            param_dtype (DTypeLike, optional): Data type for parameters. Defaults to jnp.float32.
            sharding (ShardingSpec, optional): Sharding specification for parameters. Defaults to ViTSharding.
        """
        if rngs is None:
            rngs = nnx.Rngs(0)
        self.do_classification = do_classification
        self._original_config = None
        self.encoder = VisionTransformerBase(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            hidden_size=hidden_size,
            num_layers=num_layers,
            num_heads=num_heads,
            mlp_dim=mlp_dim,
            dropout_rate=dropout_rate,
            use_quick_gelu=use_quick_gelu,
            use_gradient_checkpointing=use_gradient_checkpointing,
            attention_fn=attention_fn,
            use_pre_norm=False,
            use_patch_bias=True,
            layernorm_epsilon=1e-12,
            rngs=rngs,
            dtype=dtype,
            param_dtype=param_dtype,
            sharding=sharding,
        )

        if self.do_classification:
            self.classifier = nnx.Linear(
                hidden_size,
                num_classes,
                dtype=dtype,
                param_dtype=param_dtype,
                rngs=rngs,
                kernel_init=nnx.with_partitioning(
                    nnx.initializers.xavier_uniform(),
                    sharding.proj_kernel,
                ),
                bias_init=nnx.with_partitioning(
                    nnx.initializers.zeros_init(),
                    sharding.proj_bias,
                ),
            )

    def __call__(self, x: Float[Array, "batch height width channels"]) -> Float[Array, "batch num_classes"]:
        """Forward pass of the Vision Transformer.

        Args:
            x (Float[Array, "batch height width channels"]): Input tensor with shape [batch, height, width, channels]

        Returns:
            Float[Array, "batch num_classes"]: Output logits with shape [batch, num_classes]
        """
        x = self.encoder(x)
        if self.do_classification:
            kernel_spec = sharding_of(self.classifier.kernel[...]).spec
            logits_sharding = named_sharding_like(x, P(sharding_of(x).spec[0], kernel_spec[1]))
            return self.classifier(x, out_sharding=logits_sharding)
        return x

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str,
        use_pytorch: bool = False,
        rngs: rnglib.Rngs | None = None,
        dtype: DTypeLike = jnp.float32,
        param_dtype: DTypeLike = jnp.float32,
        sharding: ShardingSpec = ViTSharding,
        use_gradient_checkpointing: bool = False,
        attention_fn: Callable[..., Any] | None = None,
    ) -> "VisionTransformer":
        """Load a pretrained Vision Transformer from a local path or HuggingFace Hub.

        Args:
            model_name_or_path (str): Path to local weights or HuggingFace model ID.
            use_pytorch (bool): Whether to load from PyTorch weights. Defaults to False.
            rngs (rnglib.Rngs | None): Random number generator keys. If None, initializes to nnx.Rngs(0).
            dtype (DTypeLike): Data type for computations. Defaults to jnp.float32.
            param_dtype (DTypeLike): Data type for parameters. Defaults to jnp.float32.
            sharding (ShardingSpec): Sharding specification for parameters. Defaults to ViTSharding.
            use_gradient_checkpointing (bool): Whether to use gradient checkpointing. Defaults to False.
            attention_fn (Callable[..., Any] | None): Custom attention function. Defaults to None.

        Returns:
            VisionTransformer: Initialized Vision Transformer with pretrained weights
        """
        from .params import load_from_pretrained

        return load_from_pretrained(cls, model_name_or_path, use_pytorch, rngs, dtype, param_dtype, sharding, use_gradient_checkpointing, attention_fn)

    @classmethod
    def from_config(
        cls,
        config: dict[str, Any],
        *,
        rngs: rnglib.Rngs | None = None,
        dtype: DTypeLike = jnp.float32,
        param_dtype: DTypeLike = jnp.float32,
        sharding: ShardingSpec = ViTSharding,
        use_gradient_checkpointing: bool = False,
        attention_fn: Callable[..., Any] | None = None,
    ) -> "VisionTransformer":
        """Create model from HuggingFace-compatible config dict.

        Args:
            config: Configuration dictionary in HuggingFace ViT format.
            rngs: Random number generator state. If None, initializes to nnx.Rngs(0).
            dtype: Data type for computations.
            param_dtype: Data type for parameters.
            sharding: Sharding specification for parameters.
            use_gradient_checkpointing: Enable gradient checkpointing.
            attention_fn: Custom attention function. Defaults to None.

        Returns:
            VisionTransformer with random weights.
        """
        if rngs is None:
            rngs = nnx.Rngs(0)
        num_classes = len(config["id2label"]) if "id2label" in config else config.get("num_labels", 1000)
        use_quick_gelu = config.get("hidden_act") == "quick_gelu"

        return cls(
            num_classes=num_classes,
            img_size=config["image_size"],
            patch_size=config["patch_size"],
            num_layers=config["num_hidden_layers"],
            num_heads=config["num_attention_heads"],
            mlp_dim=config["intermediate_size"],
            hidden_size=config["hidden_size"],
            use_quick_gelu=use_quick_gelu,
            use_gradient_checkpointing=use_gradient_checkpointing,
            attention_fn=attention_fn,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
            sharding=sharding,
        )

    def save_pretrained(self, save_directory: str):
        """Save the model weights and config in HuggingFace format.

        Args:
            save_directory (str): Directory path where the model will be saved.
        """
        from .params import save_pretrained

        save_pretrained(self, save_directory)

__call__(x)

Forward pass of the Vision Transformer.

Parameters:

Name Type Description Default
x Float[Array, 'batch height width channels']

Input tensor with shape [batch, height, width, channels]

required

Returns:

Type Description
Float[Array, 'batch num_classes']

Float[Array, "batch num_classes"]: Output logits with shape [batch, num_classes]

Source code in src/jimm/models/vit/vit_model.py
def __call__(self, x: Float[Array, "batch height width channels"]) -> Float[Array, "batch num_classes"]:
    """Forward pass of the Vision Transformer.

    Args:
        x (Float[Array, "batch height width channels"]): Input tensor with shape [batch, height, width, channels]

    Returns:
        Float[Array, "batch num_classes"]: Output logits with shape [batch, num_classes]
    """
    x = self.encoder(x)
    if self.do_classification:
        kernel_spec = sharding_of(self.classifier.kernel[...]).spec
        logits_sharding = named_sharding_like(x, P(sharding_of(x).spec[0], kernel_spec[1]))
        return self.classifier(x, out_sharding=logits_sharding)
    return x

__init__(num_classes=1000, in_channels=3, img_size=224, patch_size=16, num_layers=12, num_heads=12, mlp_dim=3072, hidden_size=768, dropout_rate=0.1, use_quick_gelu=False, use_gradient_checkpointing=False, attention_fn=None, do_classification=True, rngs=None, dtype=jnp.float32, param_dtype=jnp.float32, sharding=ViTSharding)

Initialize a Vision Transformer.

Parameters:

Name Type Description Default
num_classes int

Number of output classes. Defaults to 1000.

1000
in_channels int

Number of input channels. Defaults to 3.

3
img_size int

Size of the input image (assumed square). Defaults to 224.

224
patch_size int

Size of each patch (assumed square). Defaults to 16.

16
num_layers int

Number of transformer layers. Defaults to 12.

12
num_heads int

Number of attention heads. Defaults to 12.

12
mlp_dim int

Size of the MLP dimension. Defaults to 3072.

3072
hidden_size int

Size of the hidden dimension. Defaults to 768.

768
dropout_rate float

Dropout rate. Defaults to 0.1.

0.1
use_quick_gelu bool

Whether to use quickgelu instead of gelu. Defaults to False.

False
use_gradient_checkpointing bool

Whether to use gradient checkpointing. Defaults to False.

False
attention_fn Callable[..., Any] | None

Custom attention function (e.g. jimm.tokamax_attention). Defaults to None.

None
do_classification bool

Whether to include the final classification head. Defaults to True.

True
rngs Rngs | None

Random number generator keys. If None, initializes to nnx.Rngs(0).

None
dtype DTypeLike

Data type for computations. Defaults to jnp.float32.

float32
param_dtype DTypeLike

Data type for parameters. Defaults to jnp.float32.

float32
sharding ShardingSpec

Sharding specification for parameters. Defaults to ViTSharding.

ViTSharding
Source code in src/jimm/models/vit/vit_model.py
def __init__(
    self,
    num_classes: int = 1000,
    in_channels: int = 3,
    img_size: int = 224,
    patch_size: int = 16,
    num_layers: int = 12,
    num_heads: int = 12,
    mlp_dim: int = 3072,
    hidden_size: int = 768,
    dropout_rate: float = 0.1,
    use_quick_gelu: bool = False,
    use_gradient_checkpointing: bool = False,
    attention_fn: Callable[..., Any] | None = None,
    do_classification: bool = True,
    rngs: rnglib.Rngs | None = None,
    dtype: DTypeLike = jnp.float32,
    param_dtype: DTypeLike = jnp.float32,
    sharding: ShardingSpec = ViTSharding,
) -> None:
    """Initialize a Vision Transformer.

    Args:
        num_classes (int, optional): Number of output classes. Defaults to 1000.
        in_channels (int, optional): Number of input channels. Defaults to 3.
        img_size (int, optional): Size of the input image (assumed square). Defaults to 224.
        patch_size (int, optional): Size of each patch (assumed square). Defaults to 16.
        num_layers (int, optional): Number of transformer layers. Defaults to 12.
        num_heads (int, optional): Number of attention heads. Defaults to 12.
        mlp_dim (int, optional): Size of the MLP dimension. Defaults to 3072.
        hidden_size (int, optional): Size of the hidden dimension. Defaults to 768.
        dropout_rate (float, optional): Dropout rate. Defaults to 0.1.
        use_quick_gelu (bool, optional): Whether to use quickgelu instead of gelu. Defaults to False.
        use_gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
        attention_fn (Callable[..., Any] | None, optional): Custom attention function (e.g. jimm.tokamax_attention). Defaults to None.
        do_classification (bool, optional): Whether to include the final classification head. Defaults to True.
        rngs (rnglib.Rngs | None, optional): Random number generator keys. If None, initializes to nnx.Rngs(0).
        dtype (DTypeLike, optional): Data type for computations. Defaults to jnp.float32.
        param_dtype (DTypeLike, optional): Data type for parameters. Defaults to jnp.float32.
        sharding (ShardingSpec, optional): Sharding specification for parameters. Defaults to ViTSharding.
    """
    if rngs is None:
        rngs = nnx.Rngs(0)
    self.do_classification = do_classification
    self._original_config = None
    self.encoder = VisionTransformerBase(
        img_size=img_size,
        patch_size=patch_size,
        in_channels=in_channels,
        hidden_size=hidden_size,
        num_layers=num_layers,
        num_heads=num_heads,
        mlp_dim=mlp_dim,
        dropout_rate=dropout_rate,
        use_quick_gelu=use_quick_gelu,
        use_gradient_checkpointing=use_gradient_checkpointing,
        attention_fn=attention_fn,
        use_pre_norm=False,
        use_patch_bias=True,
        layernorm_epsilon=1e-12,
        rngs=rngs,
        dtype=dtype,
        param_dtype=param_dtype,
        sharding=sharding,
    )

    if self.do_classification:
        self.classifier = nnx.Linear(
            hidden_size,
            num_classes,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
            kernel_init=nnx.with_partitioning(
                nnx.initializers.xavier_uniform(),
                sharding.proj_kernel,
            ),
            bias_init=nnx.with_partitioning(
                nnx.initializers.zeros_init(),
                sharding.proj_bias,
            ),
        )

from_config(config, *, rngs=None, dtype=jnp.float32, param_dtype=jnp.float32, sharding=ViTSharding, use_gradient_checkpointing=False, attention_fn=None) classmethod

Create model from HuggingFace-compatible config dict.

Parameters:

Name Type Description Default
config dict[str, Any]

Configuration dictionary in HuggingFace ViT format.

required
rngs Rngs | None

Random number generator state. If None, initializes to nnx.Rngs(0).

None
dtype DTypeLike

Data type for computations.

float32
param_dtype DTypeLike

Data type for parameters.

float32
sharding ShardingSpec

Sharding specification for parameters.

ViTSharding
use_gradient_checkpointing bool

Enable gradient checkpointing.

False
attention_fn Callable[..., Any] | None

Custom attention function. Defaults to None.

None

Returns:

Type Description
VisionTransformer

VisionTransformer with random weights.

Source code in src/jimm/models/vit/vit_model.py
@classmethod
def from_config(
    cls,
    config: dict[str, Any],
    *,
    rngs: rnglib.Rngs | None = None,
    dtype: DTypeLike = jnp.float32,
    param_dtype: DTypeLike = jnp.float32,
    sharding: ShardingSpec = ViTSharding,
    use_gradient_checkpointing: bool = False,
    attention_fn: Callable[..., Any] | None = None,
) -> "VisionTransformer":
    """Create model from HuggingFace-compatible config dict.

    Args:
        config: Configuration dictionary in HuggingFace ViT format.
        rngs: Random number generator state. If None, initializes to nnx.Rngs(0).
        dtype: Data type for computations.
        param_dtype: Data type for parameters.
        sharding: Sharding specification for parameters.
        use_gradient_checkpointing: Enable gradient checkpointing.
        attention_fn: Custom attention function. Defaults to None.

    Returns:
        VisionTransformer with random weights.
    """
    if rngs is None:
        rngs = nnx.Rngs(0)
    num_classes = len(config["id2label"]) if "id2label" in config else config.get("num_labels", 1000)
    use_quick_gelu = config.get("hidden_act") == "quick_gelu"

    return cls(
        num_classes=num_classes,
        img_size=config["image_size"],
        patch_size=config["patch_size"],
        num_layers=config["num_hidden_layers"],
        num_heads=config["num_attention_heads"],
        mlp_dim=config["intermediate_size"],
        hidden_size=config["hidden_size"],
        use_quick_gelu=use_quick_gelu,
        use_gradient_checkpointing=use_gradient_checkpointing,
        attention_fn=attention_fn,
        dtype=dtype,
        param_dtype=param_dtype,
        rngs=rngs,
        sharding=sharding,
    )

from_pretrained(model_name_or_path, use_pytorch=False, rngs=None, dtype=jnp.float32, param_dtype=jnp.float32, sharding=ViTSharding, use_gradient_checkpointing=False, attention_fn=None) classmethod

Load a pretrained Vision Transformer from a local path or HuggingFace Hub.

Parameters:

Name Type Description Default
model_name_or_path str

Path to local weights or HuggingFace model ID.

required
use_pytorch bool

Whether to load from PyTorch weights. Defaults to False.

False
rngs Rngs | None

Random number generator keys. If None, initializes to nnx.Rngs(0).

None
dtype DTypeLike

Data type for computations. Defaults to jnp.float32.

float32
param_dtype DTypeLike

Data type for parameters. Defaults to jnp.float32.

float32
sharding ShardingSpec

Sharding specification for parameters. Defaults to ViTSharding.

ViTSharding
use_gradient_checkpointing bool

Whether to use gradient checkpointing. Defaults to False.

False
attention_fn Callable[..., Any] | None

Custom attention function. Defaults to None.

None

Returns:

Name Type Description
VisionTransformer VisionTransformer

Initialized Vision Transformer with pretrained weights

Source code in src/jimm/models/vit/vit_model.py
@classmethod
def from_pretrained(
    cls,
    model_name_or_path: str,
    use_pytorch: bool = False,
    rngs: rnglib.Rngs | None = None,
    dtype: DTypeLike = jnp.float32,
    param_dtype: DTypeLike = jnp.float32,
    sharding: ShardingSpec = ViTSharding,
    use_gradient_checkpointing: bool = False,
    attention_fn: Callable[..., Any] | None = None,
) -> "VisionTransformer":
    """Load a pretrained Vision Transformer from a local path or HuggingFace Hub.

    Args:
        model_name_or_path (str): Path to local weights or HuggingFace model ID.
        use_pytorch (bool): Whether to load from PyTorch weights. Defaults to False.
        rngs (rnglib.Rngs | None): Random number generator keys. If None, initializes to nnx.Rngs(0).
        dtype (DTypeLike): Data type for computations. Defaults to jnp.float32.
        param_dtype (DTypeLike): Data type for parameters. Defaults to jnp.float32.
        sharding (ShardingSpec): Sharding specification for parameters. Defaults to ViTSharding.
        use_gradient_checkpointing (bool): Whether to use gradient checkpointing. Defaults to False.
        attention_fn (Callable[..., Any] | None): Custom attention function. Defaults to None.

    Returns:
        VisionTransformer: Initialized Vision Transformer with pretrained weights
    """
    from .params import load_from_pretrained

    return load_from_pretrained(cls, model_name_or_path, use_pytorch, rngs, dtype, param_dtype, sharding, use_gradient_checkpointing, attention_fn)

save_pretrained(save_directory)

Save the model weights and config in HuggingFace format.

Parameters:

Name Type Description Default
save_directory str

Directory path where the model will be saved.

required
Source code in src/jimm/models/vit/vit_model.py
def save_pretrained(self, save_directory: str):
    """Save the model weights and config in HuggingFace format.

    Args:
        save_directory (str): Directory path where the model will be saved.
    """
    from .params import save_pretrained

    save_pretrained(self, save_directory)