@@ -111,33 +111,68 @@ print("Predicted class:", model.config.id2label[predicted_class_idx])
111
111
112
112
## Notes
113
113
114
- - Use [ torch.jit.trace] ( https://pytorch.org/docs/stable/generated/torch.jit.trace.html ) to speedup inference. However, it will produce some mismatched elements. The difference between the original and traced model is 1e-4.
115
-
116
- ``` py
117
- import torch
118
- from transformers import AutoImageProcessor, AutoModel
119
- from PIL import Image
120
- import requests
121
-
122
- url = ' http://images.cocodataset.org/val2017/000000039769.jpg'
123
- image = Image.open(requests.get(url, stream = True ).raw)
124
-
125
- processor = AutoImageProcessor.from_pretrained(' facebook/dinov2-base' )
126
- model = AutoModel.from_pretrained(' facebook/dinov2-base' )
127
-
128
- inputs = processor(images = image, return_tensors = " pt" )
129
- outputs = model(** inputs)
130
- last_hidden_states = outputs[0 ]
131
-
132
- # We have to force return_dict=False for tracing
133
- model.config.return_dict = False
134
-
135
- with torch.no_grad():
136
- traced_model = torch.jit.trace(model, [inputs.pixel_values])
137
- traced_outputs = traced_model(inputs.pixel_values)
138
-
139
- print ((last_hidden_states - traced_outputs[0 ]).abs().max())
140
- ```
114
+ - The example below shows how to split the output tensor into:
115
+ - one embedding for the whole image, commonly referred to as a ` CLS ` token,
116
+ useful for classification and retrieval
117
+ - a set of local embeddings, one for each ` 14x14 ` patch of the input image,
118
+ useful for dense tasks, such as semantic segmentation
119
+
120
+ ``` py
121
+ from transformers import AutoImageProcessor, AutoModel
122
+ from PIL import Image
123
+ import requests
124
+
125
+ url = ' http://images.cocodataset.org/val2017/000000039769.jpg'
126
+ image = Image.open(requests.get(url, stream = True ).raw)
127
+ print (image.height, image.width) # [480, 640]
128
+
129
+ processor = AutoImageProcessor.from_pretrained(' facebook/dinov2-base' )
130
+ model = AutoModel.from_pretrained(' facebook/dinov2-base' )
131
+ patch_size = model.config.patch_size
132
+
133
+ inputs = processor(images = image, return_tensors = " pt" )
134
+ print (inputs.pixel_values.shape) # [1, 3, 224, 224]
135
+ batch_size, rgb, img_height, img_width = inputs.pixel_values.shape
136
+ num_patches_height, num_patches_width = img_height // patch_size, img_width // patch_size
137
+ num_patches_flat = num_patches_height * num_patches_width
138
+
139
+ outputs = model(** inputs)
140
+ last_hidden_states = outputs[0 ]
141
+ print (last_hidden_states.shape) # [1, 1 + 256, 768]
142
+ assert last_hidden_states.shape == (batch_size, 1 + num_patches_flat, model.config.hidden_size)
143
+
144
+ cls_token = last_hidden_states[:, 0 , :]
145
+ patch_features = last_hidden_states[:, 1 :, :].unflatten(1 , (num_patches_height, num_patches_width))
146
+ ```
147
+
148
+ - Use [ torch.jit.trace] ( https://pytorch.org/docs/stable/generated/torch.jit.trace.html ) to speedup inference.
149
+ However, it will produce some mismatched elements. The difference between the original and traced model is 1e-4.
150
+
151
+ ``` py
152
+ import torch
153
+ from transformers import AutoImageProcessor, AutoModel
154
+ from PIL import Image
155
+ import requests
156
+
157
+ url = ' http://images.cocodataset.org/val2017/000000039769.jpg'
158
+ image = Image.open(requests.get(url, stream = True ).raw)
159
+
160
+ processor = AutoImageProcessor.from_pretrained(' facebook/dinov2-base' )
161
+ model = AutoModel.from_pretrained(' facebook/dinov2-base' )
162
+
163
+ inputs = processor(images = image, return_tensors = " pt" )
164
+ outputs = model(** inputs)
165
+ last_hidden_states = outputs[0 ]
166
+
167
+ # We have to force return_dict=False for tracing
168
+ model.config.return_dict = False
169
+
170
+ with torch.no_grad():
171
+ traced_model = torch.jit.trace(model, [inputs.pixel_values])
172
+ traced_outputs = traced_model(inputs.pixel_values)
173
+
174
+ print ((last_hidden_states - traced_outputs[0 ]).abs().max())
175
+ ```
141
176
142
177
## Dinov2Config
143
178
0 commit comments