# ViT (Vision Transformer)
The ViT (Vision Transformer) is a transformer-based neural network architecture for image classification. It divides an image into fixed-size patches, linearly embeds each patch, adds position embeddings, and processes the resulting sequence of vectors through a standard transformer encoder.
The ViT model was introduced in the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" and has shown strong performance on image classification benchmarks.
## Flash / Splash Attention
ViT supports hardware-accelerated attention via Tokamax:
| Backend | Hardware | Notes |
|---|---|---|
| `"mosaic"` | NVIDIA H100 (SM90) / B100 (SM100) | Pallas Mosaic GPU kernel |
| `"triton"` | Any NVIDIA GPU | Pallas Triton kernel |
| `"cudnn"` | NVIDIA GPU | Via JAX-NN / cuDNN |
| `"mosaic_tpu"` | TPU v5 / v7 | Splash attention (block-sparse) |
| `"xla_chunked"` | GPU / TPU | Flash-style chunked XLA |
| `"xla"` | Any | Standard XLA fallback |
```python
import jimm

# GPU: try H100 Mosaic kernel, fall back to Triton, then XLA
model = jimm.VisionTransformer.from_pretrained(
    "google/vit-base-patch16-224",
    attention_fn=jimm.make_tokamax_attention(["mosaic", "triton", "xla"]),
)

# TPU: try Splash attention, fall back to chunked XLA
model = jimm.VisionTransformer.from_pretrained(
    "google/vit-base-patch16-224",
    attention_fn=jimm.make_tokamax_attention(["mosaic_tpu", "xla_chunked"]),
)
```
Note: Flash/Splash attention does not provide a speedup at typical ViT context lengths (e.g. 196 tokens for 224px/16px). The primary benefit is memory reduction at longer sequence lengths.
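The crossover point follows from the patch grid: a ViT sequence has (img_size / patch_size)² patch tokens, and attention scores grow with the square of that count. A quick back-of-envelope sketch (plain Python, no jimm required):

```python
def vit_patch_tokens(img_size: int, patch_size: int = 16) -> int:
    """Patch tokens per image: one per (patch_size x patch_size) tile."""
    return (img_size // patch_size) ** 2

# Attention scores are O(n^2) in token count, so flash/splash kernels
# pay off mainly as resolution grows (a [CLS] token adds one more token
# in practice). At 224px the score matrix is tiny: 196 x 196.
for size in (224, 384, 1024):
    n = vit_patch_tokens(size)
    print(f"{size}px -> {n} patch tokens, {n * n:,} attention scores per head")
```

At 224px there are only 196 patch tokens, so the standard XLA path is usually just as fast; the fused kernels matter once resolution (and hence sequence length) grows.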
## FSDP / Explicit Sharding
ViT supports JAX explicit sharding (FSDP-style) via ViTSharding. Large weight matrices are sharded on the contracting (in_features) dimension so that activations carry only the batch-axis sharding.
```python
import jax
from jax.experimental import mesh_utils
from jax.sharding import AxisType, Mesh

import jimm

n_devices = jax.device_count()
mesh = Mesh(
    mesh_utils.create_device_mesh((1, n_devices)),
    ("data", "fsdp"),
    axis_types=(AxisType.Explicit, AxisType.Explicit),
)
jax.set_mesh(mesh)

model = jimm.VisionTransformer.from_pretrained("google/vit-base-patch16-224")
```
`ViTSharding` specs describe per-layer shapes. After `nnx.vmap`, the Transformer stack prepends `None` for the scan axis to each Variable's sharding metadata, so the optimizer receives the correct stacked spec natively.

To disable sharding, pass `sharding=jimm.common.sharding.NoSharding()`.
## `jimm.models.vit.VisionTransformer`
Bases: Module
Vision Transformer (ViT) model for image classification.
This implements the Vision Transformer as described in the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale".
Source code in src/jimm/models/vit/vit_model.py
### `__call__(x)`
Forward pass of the Vision Transformer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Float[Array, 'batch height width channels']` | Input tensor with shape [batch, height, width, channels]. | required |

Returns:

| Type | Description |
|---|---|
| `Float[Array, 'batch num_classes']` | Output logits with shape [batch, num_classes]. |
Source code in src/jimm/models/vit/vit_model.py
### `__init__(num_classes=1000, in_channels=3, img_size=224, patch_size=16, num_layers=12, num_heads=12, mlp_dim=3072, hidden_size=768, dropout_rate=0.1, use_quick_gelu=False, use_gradient_checkpointing=False, attention_fn=None, do_classification=True, rngs=None, dtype=jnp.float32, param_dtype=jnp.float32, sharding=ViTSharding)`
Initialize a Vision Transformer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `num_classes` | `int` | Number of output classes. | `1000` |
| `in_channels` | `int` | Number of input channels. | `3` |
| `img_size` | `int` | Size of the input image (assumed square). | `224` |
| `patch_size` | `int` | Size of each patch (assumed square). | `16` |
| `num_layers` | `int` | Number of transformer layers. | `12` |
| `num_heads` | `int` | Number of attention heads. | `12` |
| `mlp_dim` | `int` | Size of the MLP dimension. | `3072` |
| `hidden_size` | `int` | Size of the hidden dimension. | `768` |
| `dropout_rate` | `float` | Dropout rate. | `0.1` |
| `use_quick_gelu` | `bool` | Whether to use quickgelu instead of gelu. | `False` |
| `use_gradient_checkpointing` | `bool` | Whether to use gradient checkpointing. | `False` |
| `attention_fn` | `Callable[..., Any] \| None` | Custom attention function (e.g. `jimm.tokamax_attention`). | `None` |
| `do_classification` | `bool` | Whether to include the final classification head. | `True` |
| `rngs` | `Rngs \| None` | Random number generator keys. If None, initializes to `nnx.Rngs(0)`. | `None` |
| `dtype` | `DTypeLike` | Data type for computations. | `jnp.float32` |
| `param_dtype` | `DTypeLike` | Data type for parameters. | `jnp.float32` |
| `sharding` | `ShardingSpec` | Sharding specification for parameters. | `ViTSharding` |
Source code in src/jimm/models/vit/vit_model.py
### `from_config(config, *, rngs=None, dtype=jnp.float32, param_dtype=jnp.float32, sharding=ViTSharding, use_gradient_checkpointing=False, attention_fn=None)` `classmethod`
Create model from HuggingFace-compatible config dict.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `dict[str, Any]` | Configuration dictionary in HuggingFace ViT format. | required |
| `rngs` | `Rngs \| None` | Random number generator state. If None, initializes to `nnx.Rngs(0)`. | `None` |
| `dtype` | `DTypeLike` | Data type for computations. | `jnp.float32` |
| `param_dtype` | `DTypeLike` | Data type for parameters. | `jnp.float32` |
| `sharding` | `ShardingSpec` | Sharding specification for parameters. | `ViTSharding` |
| `use_gradient_checkpointing` | `bool` | Enable gradient checkpointing. | `False` |
| `attention_fn` | `Callable[..., Any] \| None` | Custom attention function. | `None` |

Returns:

| Type | Description |
|---|---|
| `VisionTransformer` | VisionTransformer with random weights. |
Source code in src/jimm/models/vit/vit_model.py
### `from_pretrained(model_name_or_path, use_pytorch=False, rngs=None, dtype=jnp.float32, param_dtype=jnp.float32, sharding=ViTSharding, use_gradient_checkpointing=False, attention_fn=None)` `classmethod`
Load a pretrained Vision Transformer from a local path or HuggingFace Hub.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model_name_or_path` | `str` | Path to local weights or HuggingFace model ID. | required |
| `use_pytorch` | `bool` | Whether to load from PyTorch weights. | `False` |
| `rngs` | `Rngs \| None` | Random number generator keys. If None, initializes to `nnx.Rngs(0)`. | `None` |
| `dtype` | `DTypeLike` | Data type for computations. | `jnp.float32` |
| `param_dtype` | `DTypeLike` | Data type for parameters. | `jnp.float32` |
| `sharding` | `ShardingSpec` | Sharding specification for parameters. | `ViTSharding` |
| `use_gradient_checkpointing` | `bool` | Whether to use gradient checkpointing. | `False` |
| `attention_fn` | `Callable[..., Any] \| None` | Custom attention function. | `None` |

Returns:

| Name | Type | Description |
|---|---|---|
| `VisionTransformer` | `VisionTransformer` | Initialized Vision Transformer with pretrained weights. |
Source code in src/jimm/models/vit/vit_model.py
### `save_pretrained(save_directory)`
Save the model weights and config in HuggingFace format.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `save_directory` | `str` | Directory path where the model will be saved. | required |
Source code in src/jimm/models/vit/vit_model.py