
ViT (Vision Transformer)

The ViT (Vision Transformer) is a transformer-based neural network architecture for image classification. It divides an image into fixed-size patches, linearly embeds each patch, adds position embeddings, and processes the resulting sequence of vectors through a standard transformer encoder.

The ViT model was introduced in the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" and has shown strong performance on image classification benchmarks.
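
Below is a minimal usage sketch (not taken from the library's own examples): it builds a VisionTransformer with the default ImageNet-style configuration and runs a forward pass on a dummy batch. The import path follows the module path documented below; calling eval() to disable dropout follows the usual flax.nnx convention and is an assumption about how the model is meant to be used for inference.

import jax.numpy as jnp
from flax import nnx
from jimm.models.vit import VisionTransformer

# Default configuration: 224x224 images, 16x16 patches, 12 layers, 1000 classes.
model = VisionTransformer(rngs=nnx.Rngs(0))
model.eval()  # put dropout into deterministic mode for inference

images = jnp.zeros((2, 224, 224, 3), dtype=jnp.float32)  # [batch, height, width, channels]
logits = model(images)
print(logits.shape)  # (2, 1000)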

jimm.models.vit.VisionTransformer

Bases: Module

Vision Transformer (ViT) model for image classification.

This implements the Vision Transformer as described in the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale"

Source code in src/jimm/models/vit/vit_model.py
class VisionTransformer(nnx.Module):
    """Vision Transformer (ViT) model for image classification.

    This implements the Vision Transformer as described in the paper
    "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale"
    """

    def __init__(
        self,
        num_classes: int = 1000,
        in_channels: int = 3,
        img_size: int = 224,
        patch_size: int = 16,
        num_layers: int = 12,
        num_heads: int = 12,
        mlp_dim: int = 3072,
        hidden_size: int = 768,
        dropout_rate: float = 0.1,
        use_quick_gelu: bool = False,
        use_gradient_checkpointing: bool = False,
        do_classification: bool = True,
        dtype: DTypeLike = jnp.float32,
        param_dtype: DTypeLike = jnp.float32,
        rngs: rnglib.Rngs = nnx.Rngs(0),
        mesh: Mesh | None = None,
        mesh_rules: MeshRules = DEFAULT_SHARDING,
    ) -> None:
        """Initialize a Vision Transformer.

        Args:
            num_classes (int, optional): Number of output classes. Defaults to 1000.
            in_channels (int, optional): Number of input channels. Defaults to 3.
            img_size (int, optional): Size of the input image (assumed square). Defaults to 224.
            patch_size (int, optional): Size of each patch (assumed square). Defaults to 16.
            num_layers (int, optional): Number of transformer layers. Defaults to 12.
            num_heads (int, optional): Number of attention heads. Defaults to 12.
            mlp_dim (int, optional): Size of the MLP dimension. Defaults to 3072.
            hidden_size (int, optional): Size of the hidden dimension. Defaults to 768.
            dropout_rate (float, optional): Dropout rate. Defaults to 0.1.
            use_quick_gelu (bool, optional): Whether to use quickgelu instead of gelu. Defaults to False.
            use_gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
            do_classification (bool, optional): Whether to include the final classification head. Defaults to True.
            dtype (DTypeLike, optional): Data type for computations. Defaults to jnp.float32.
            param_dtype (DTypeLike, optional): Data type for parameters. Defaults to jnp.float32.
            rngs (rnglib.Rngs, optional): Random number generator keys. Defaults to nnx.Rngs(0).
            mesh (Mesh | None, optional): Optional JAX device mesh for parameter sharding. Defaults to None.
            mesh_rules (MeshRules, optional): Logical axis sharding rules. Defaults to DEFAULT_SHARDING.
        """
        self.do_classification = do_classification
        self._original_config = None
        self.encoder = VisionTransformerBase(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            hidden_size=hidden_size,
            num_layers=num_layers,
            num_heads=num_heads,
            mlp_dim=mlp_dim,
            dropout_rate=dropout_rate,
            use_quick_gelu=use_quick_gelu,
            use_gradient_checkpointing=use_gradient_checkpointing,
            use_pre_norm=False,
            use_patch_bias=True,
            layernorm_epsilon=1e-12,
            rngs=rngs,
            dtype=dtype,
            param_dtype=param_dtype,
            mesh=mesh,
            mesh_rules=mesh_rules,
        )

        if self.do_classification:
            self.classifier = nnx.Linear(
                hidden_size,
                num_classes,
                dtype=dtype,
                param_dtype=param_dtype,
                rngs=rngs,
                kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), mesh_rules("classifier_in", "classifier_out")),
                bias_init=nnx.with_partitioning(
                    nnx.initializers.zeros_init(),
                    mesh_rules(
                        "classifier_out",
                    ),
                ),
            )

    def __call__(self, x: Float[Array, "batch height width channels"]) -> Float[Array, "batch num_classes"]:
        """Forward pass of the Vision Transformer.

        Args:
            x (Float[Array, "batch height width channels"]): Input tensor with shape [batch, height, width, channels]

        Returns:
            Float[Array, "batch num_classes"]: Output logits with shape [batch, num_classes]
        """
        x = self.encoder(x)
        if self.do_classification:
            return self.classifier(x)
        return x

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str,
        use_pytorch: bool = False,
        mesh: Mesh | None = None,
        dtype: DTypeLike = jnp.float32,
        param_dtype: DTypeLike = jnp.float32,
        use_gradient_checkpointing: bool = False,
        rngs: rnglib.Rngs = nnx.Rngs(0),
    ) -> "VisionTransformer":
        """Load a pretrained Vision Transformer from a local path or HuggingFace Hub.

        Args:
            model_name_or_path (str): Path to local weights or HuggingFace model ID.
            use_pytorch (bool): Whether to load from PyTorch weights. Defaults to False.
            mesh (Mesh | None): Optional device mesh for parameter sharding. Defaults to None.
            dtype (DTypeLike): Data type for computations. Defaults to jnp.float32.
            param_dtype (DTypeLike): Data type for parameters. Defaults to jnp.float32.
            use_gradient_checkpointing (bool): Whether to use gradient checkpointing. Defaults to False.
            rngs (rnglib.Rngs): Random number generator keys. Defaults to nnx.Rngs(0).

        Returns:
            VisionTransformer: Initialized Vision Transformer with pretrained weights
        """
        from .params import load_from_pretrained

        return load_from_pretrained(cls, model_name_or_path, use_pytorch, mesh, dtype, param_dtype, use_gradient_checkpointing, rngs)

    def save_pretrained(self, save_directory: str):
        """Save the model weights and config in HuggingFace format.

        Args:
            save_directory (str): Directory path where the model will be saved.
        """
        from .params import save_pretrained

        save_pretrained(self, save_directory)

__call__(x)

Forward pass of the Vision Transformer.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `x` | `Float[Array, "batch height width channels"]` | Input tensor with shape [batch, height, width, channels] | required |

Returns:

| Type | Description |
|------|-------------|
| `Float[Array, "batch num_classes"]` | Output logits with shape [batch, num_classes] |

Source code in src/jimm/models/vit/vit_model.py
def __call__(self, x: Float[Array, "batch height width channels"]) -> Float[Array, "batch num_classes"]:
    """Forward pass of the Vision Transformer.

    Args:
        x (Float[Array, "batch height width channels"]): Input tensor with shape [batch, height, width, channels]

    Returns:
        Float[Array, "batch num_classes"]: Output logits with shape [batch, num_classes]
    """
    x = self.encoder(x)
    if self.do_classification:
        return self.classifier(x)
    return x
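
As a sketch of the do_classification=False path described above: with the classification head disabled, __call__ returns the encoder output directly. The exact output shape depends on VisionTransformerBase, so the shape noted in the comment is an assumption rather than a guarantee.

import jax.numpy as jnp
from flax import nnx
from jimm.models.vit import VisionTransformer

backbone = VisionTransformer(do_classification=False, rngs=nnx.Rngs(0))
backbone.eval()

images = jnp.ones((4, 224, 224, 3), dtype=jnp.float32)
features = backbone(images)  # encoder output; e.g. (4, 768) if a pooled representation is returned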

__init__(num_classes=1000, in_channels=3, img_size=224, patch_size=16, num_layers=12, num_heads=12, mlp_dim=3072, hidden_size=768, dropout_rate=0.1, use_quick_gelu=False, use_gradient_checkpointing=False, do_classification=True, dtype=jnp.float32, param_dtype=jnp.float32, rngs=nnx.Rngs(0), mesh=None, mesh_rules=DEFAULT_SHARDING)

Initialize a Vision Transformer.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `num_classes` | `int` | Number of output classes. | `1000` |
| `in_channels` | `int` | Number of input channels. | `3` |
| `img_size` | `int` | Size of the input image (assumed square). | `224` |
| `patch_size` | `int` | Size of each patch (assumed square). | `16` |
| `num_layers` | `int` | Number of transformer layers. | `12` |
| `num_heads` | `int` | Number of attention heads. | `12` |
| `mlp_dim` | `int` | Size of the MLP dimension. | `3072` |
| `hidden_size` | `int` | Size of the hidden dimension. | `768` |
| `dropout_rate` | `float` | Dropout rate. | `0.1` |
| `use_quick_gelu` | `bool` | Whether to use quickgelu instead of gelu. | `False` |
| `use_gradient_checkpointing` | `bool` | Whether to use gradient checkpointing. | `False` |
| `do_classification` | `bool` | Whether to include the final classification head. | `True` |
| `dtype` | `DTypeLike` | Data type for computations. | `jnp.float32` |
| `param_dtype` | `DTypeLike` | Data type for parameters. | `jnp.float32` |
| `rngs` | `Rngs` | Random number generator keys. | `nnx.Rngs(0)` |
| `mesh` | `Mesh \| None` | Optional JAX device mesh for parameter sharding. | `None` |
| `mesh_rules` | `MeshRules` | Logical axis sharding rules. | `DEFAULT_SHARDING` |

Source code in src/jimm/models/vit/vit_model.py
def __init__(
    self,
    num_classes: int = 1000,
    in_channels: int = 3,
    img_size: int = 224,
    patch_size: int = 16,
    num_layers: int = 12,
    num_heads: int = 12,
    mlp_dim: int = 3072,
    hidden_size: int = 768,
    dropout_rate: float = 0.1,
    use_quick_gelu: bool = False,
    use_gradient_checkpointing: bool = False,
    do_classification: bool = True,
    dtype: DTypeLike = jnp.float32,
    param_dtype: DTypeLike = jnp.float32,
    rngs: rnglib.Rngs = nnx.Rngs(0),
    mesh: Mesh | None = None,
    mesh_rules: MeshRules = DEFAULT_SHARDING,
) -> None:
    """Initialize a Vision Transformer.

    Args:
        num_classes (int, optional): Number of output classes. Defaults to 1000.
        in_channels (int, optional): Number of input channels. Defaults to 3.
        img_size (int, optional): Size of the input image (assumed square). Defaults to 224.
        patch_size (int, optional): Size of each patch (assumed square). Defaults to 16.
        num_layers (int, optional): Number of transformer layers. Defaults to 12.
        num_heads (int, optional): Number of attention heads. Defaults to 12.
        mlp_dim (int, optional): Size of the MLP dimension. Defaults to 3072.
        hidden_size (int, optional): Size of the hidden dimension. Defaults to 768.
        dropout_rate (float, optional): Dropout rate. Defaults to 0.1.
        use_quick_gelu (bool, optional): Whether to use quickgelu instead of gelu. Defaults to False.
        use_gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
        do_classification (bool, optional): Whether to include the final classification head. Defaults to True.
        dtype (DTypeLike, optional): Data type for computations. Defaults to jnp.float32.
        param_dtype (DTypeLike, optional): Data type for parameters. Defaults to jnp.float32.
        rngs (rnglib.Rngs, optional): Random number generator keys. Defaults to nnx.Rngs(0).
        mesh (Mesh | None, optional): Optional JAX device mesh for parameter sharding. Defaults to None.
        mesh_rules (MeshRules, optional): Logical axis sharding rules. Defaults to DEFAULT_SHARDING.
    """
    self.do_classification = do_classification
    self._original_config = None
    self.encoder = VisionTransformerBase(
        img_size=img_size,
        patch_size=patch_size,
        in_channels=in_channels,
        hidden_size=hidden_size,
        num_layers=num_layers,
        num_heads=num_heads,
        mlp_dim=mlp_dim,
        dropout_rate=dropout_rate,
        use_quick_gelu=use_quick_gelu,
        use_gradient_checkpointing=use_gradient_checkpointing,
        use_pre_norm=False,
        use_patch_bias=True,
        layernorm_epsilon=1e-12,
        rngs=rngs,
        dtype=dtype,
        param_dtype=param_dtype,
        mesh=mesh,
        mesh_rules=mesh_rules,
    )

    if self.do_classification:
        self.classifier = nnx.Linear(
            hidden_size,
            num_classes,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
            kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), mesh_rules("classifier_in", "classifier_out")),
            bias_init=nnx.with_partitioning(
                nnx.initializers.zeros_init(),
                mesh_rules(
                    "classifier_out",
                ),
            ),
        )
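
To illustrate the constructor arguments, the sketch below builds a deliberately small ViT for 32x32 inputs; these hyperparameter values are illustrative choices, not a preset shipped with the library.

import jax.numpy as jnp
from flax import nnx
from jimm.models.vit import VisionTransformer

small_vit = VisionTransformer(
    num_classes=10,   # e.g. a CIFAR-10-sized label space
    img_size=32,
    patch_size=4,     # (32 / 4)^2 = 64 patches per image
    num_layers=6,
    num_heads=4,
    hidden_size=256,  # 256 / 4 heads = 64 dims per head
    mlp_dim=1024,
    dropout_rate=0.0,
    rngs=nnx.Rngs(42),
)
logits = small_vit(jnp.zeros((8, 32, 32, 3)))
print(logits.shape)  # (8, 10)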

from_pretrained(model_name_or_path, use_pytorch=False, mesh=None, dtype=jnp.float32, param_dtype=jnp.float32, use_gradient_checkpointing=False, rngs=nnx.Rngs(0)) classmethod

Load a pretrained Vision Transformer from a local path or HuggingFace Hub.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `model_name_or_path` | `str` | Path to local weights or HuggingFace model ID. | required |
| `use_pytorch` | `bool` | Whether to load from PyTorch weights. | `False` |
| `mesh` | `Mesh \| None` | Optional device mesh for parameter sharding. | `None` |
| `dtype` | `DTypeLike` | Data type for computations. | `jnp.float32` |
| `param_dtype` | `DTypeLike` | Data type for parameters. | `jnp.float32` |
| `use_gradient_checkpointing` | `bool` | Whether to use gradient checkpointing. | `False` |
| `rngs` | `Rngs` | Random number generator keys. | `nnx.Rngs(0)` |

Returns:

| Type | Description |
|------|-------------|
| `VisionTransformer` | Initialized Vision Transformer with pretrained weights |

Source code in src/jimm/models/vit/vit_model.py
@classmethod
def from_pretrained(
    cls,
    model_name_or_path: str,
    use_pytorch: bool = False,
    mesh: Mesh | None = None,
    dtype: DTypeLike = jnp.float32,
    param_dtype: DTypeLike = jnp.float32,
    use_gradient_checkpointing: bool = False,
    rngs: rnglib.Rngs = nnx.Rngs(0),
) -> "VisionTransformer":
    """Load a pretrained Vision Transformer from a local path or HuggingFace Hub.

    Args:
        model_name_or_path (str): Path to local weights or HuggingFace model ID.
        use_pytorch (bool): Whether to load from PyTorch weights. Defaults to False.
        mesh (Mesh | None): Optional device mesh for parameter sharding. Defaults to None.
        dtype (DTypeLike): Data type for computations. Defaults to jnp.float32.
        param_dtype (DTypeLike): Data type for parameters. Defaults to jnp.float32.
        use_gradient_checkpointing (bool): Whether to use gradient checkpointing. Defaults to False.
        rngs (rnglib.Rngs): Random number generator keys. Defaults to nnx.Rngs(0).

    Returns:
        VisionTransformer: Initialized Vision Transformer with pretrained weights
    """
    from .params import load_from_pretrained

    return load_from_pretrained(cls, model_name_or_path, use_pytorch, mesh, dtype, param_dtype, use_gradient_checkpointing, rngs)
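
A hedged loading example: the checkpoint name below is only a placeholder for a HuggingFace ViT checkpoint, and whether any particular checkpoint is compatible depends on the library's weight-conversion code (use_pytorch=True selects conversion from PyTorch weights).

import jax.numpy as jnp
from jimm.models.vit import VisionTransformer

# "google/vit-base-patch16-224" is used here purely as an example model ID.
model = VisionTransformer.from_pretrained("google/vit-base-patch16-224")
model.eval()

logits = model(jnp.zeros((1, 224, 224, 3)))  # [batch, num_classes] for the checkpoint's label set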

save_pretrained(save_directory)

Save the model weights and config in HuggingFace format.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `save_directory` | `str` | Directory path where the model will be saved. | required |

Source code in src/jimm/models/vit/vit_model.py
def save_pretrained(self, save_directory: str):
    """Save the model weights and config in HuggingFace format.

    Args:
        save_directory (str): Directory path where the model will be saved.
    """
    from .params import save_pretrained

    save_pretrained(self, save_directory)
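
A sketch of a save-and-reload round trip using the two methods above; the directory path is arbitrary.

from flax import nnx
from jimm.models.vit import VisionTransformer

model = VisionTransformer(num_classes=1000, rngs=nnx.Rngs(0))
model.save_pretrained("./vit-checkpoint")  # writes weights and config in HuggingFace format

# Reload from the local directory written above.
restored = VisionTransformer.from_pretrained("./vit-checkpoint")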