Skip to content

Commit e94a480

Browse files
Add usage example for DINOv2 (huggingface#37398)
* Add usage example for DINOv2 * More explicit shape names * More verbose text * Moved example to Notes section * Indentation
1 parent d20aa68 commit e94a480

File tree

1 file changed

+62
-27
lines changed

1 file changed

+62
-27
lines changed

docs/source/en/model_doc/dinov2.md

Lines changed: 62 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -111,33 +111,68 @@ print("Predicted class:", model.config.id2label[predicted_class_idx])
111111

112112
## Notes
113113

114-
- Use [torch.jit.trace](https://pytorch.org/docs/stable/generated/torch.jit.trace.html) to speedup inference. However, it will produce some mismatched elements. The difference between the original and traced model is 1e-4.
115-
116-
```py
117-
import torch
118-
from transformers import AutoImageProcessor, AutoModel
119-
from PIL import Image
120-
import requests
121-
122-
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
123-
image = Image.open(requests.get(url, stream=True).raw)
124-
125-
processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
126-
model = AutoModel.from_pretrained('facebook/dinov2-base')
127-
128-
inputs = processor(images=image, return_tensors="pt")
129-
outputs = model(**inputs)
130-
last_hidden_states = outputs[0]
131-
132-
# We have to force return_dict=False for tracing
133-
model.config.return_dict = False
134-
135-
with torch.no_grad():
136-
traced_model = torch.jit.trace(model, [inputs.pixel_values])
137-
traced_outputs = traced_model(inputs.pixel_values)
138-
139-
print((last_hidden_states - traced_outputs[0]).abs().max())
140-
```
114+
- The example below shows how to split the output tensor into:
115+
- one embedding for the whole image, commonly referred to as a `CLS` token,
116+
useful for classification and retrieval
117+
- a set of local embeddings, one for each `14x14` patch of the input image,
118+
useful for dense tasks, such as semantic segmentation
119+
120+
```py
121+
from transformers import AutoImageProcessor, AutoModel
122+
from PIL import Image
123+
import requests
124+
125+
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
126+
image = Image.open(requests.get(url, stream=True).raw)
127+
print(image.height, image.width) # [480, 640]
128+
129+
processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
130+
model = AutoModel.from_pretrained('facebook/dinov2-base')
131+
patch_size = model.config.patch_size
132+
133+
inputs = processor(images=image, return_tensors="pt")
134+
print(inputs.pixel_values.shape) # [1, 3, 224, 224]
135+
batch_size, rgb, img_height, img_width = inputs.pixel_values.shape
136+
num_patches_height, num_patches_width = img_height // patch_size, img_width // patch_size
137+
num_patches_flat = num_patches_height * num_patches_width
138+
139+
outputs = model(**inputs)
140+
last_hidden_states = outputs[0]
141+
print(last_hidden_states.shape) # [1, 1 + 256, 768]
142+
assert last_hidden_states.shape == (batch_size, 1 + num_patches_flat, model.config.hidden_size)
143+
144+
cls_token = last_hidden_states[:, 0, :]
145+
patch_features = last_hidden_states[:, 1:, :].unflatten(1, (num_patches_height, num_patches_width))
146+
```
147+
148+
- Use [torch.jit.trace](https://pytorch.org/docs/stable/generated/torch.jit.trace.html) to speedup inference.
149+
However, it will produce some mismatched elements. The difference between the original and traced model is 1e-4.
150+
151+
```py
152+
import torch
153+
from transformers import AutoImageProcessor, AutoModel
154+
from PIL import Image
155+
import requests
156+
157+
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
158+
image = Image.open(requests.get(url, stream=True).raw)
159+
160+
processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
161+
model = AutoModel.from_pretrained('facebook/dinov2-base')
162+
163+
inputs = processor(images=image, return_tensors="pt")
164+
outputs = model(**inputs)
165+
last_hidden_states = outputs[0]
166+
167+
# We have to force return_dict=False for tracing
168+
model.config.return_dict = False
169+
170+
with torch.no_grad():
171+
traced_model = torch.jit.trace(model, [inputs.pixel_values])
172+
traced_outputs = traced_model(inputs.pixel_values)
173+
174+
print((last_hidden_states - traced_outputs[0]).abs().max())
175+
```
141176

142177
## Dinov2Config
143178

0 commit comments

Comments
 (0)