Commit 5c47d08

Add Swin2SR ImageProcessorFast (huggingface#37169)
* Add fast image processor support for Swin2SR
* Add Swin2SR tests of fast image processing
* Update docs and remove unnecessary test func
* Fix docstring formatting
* Skip fast vs slow processing test

---------

Co-authored-by: Yoni Gozlan <[email protected]>
1 parent 17742bd commit 5c47d08

File tree: 5 files changed (+171, -7 lines)

docs/source/en/model_doc/swin2sr.md
Lines changed: 5 additions & 0 deletions

@@ -50,6 +50,11 @@ A demo Space for image super-resolution with SwinSR can be found [here](https://
 [[autodoc]] Swin2SRImageProcessor
     - preprocess
 
+## Swin2SRImageProcessorFast
+
+[[autodoc]] Swin2SRImageProcessorFast
+    - preprocess
+
 ## Swin2SRConfig
 
 [[autodoc]] Swin2SRConfig

src/transformers/models/auto/image_processing_auto.py
Lines changed: 1 addition & 1 deletion

@@ -150,7 +150,7 @@
         ("superglue", ("SuperGlueImageProcessor",)),
         ("swiftformer", ("ViTImageProcessor", "ViTImageProcessorFast")),
         ("swin", ("ViTImageProcessor", "ViTImageProcessorFast")),
-        ("swin2sr", ("Swin2SRImageProcessor",)),
+        ("swin2sr", ("Swin2SRImageProcessor", "Swin2SRImageProcessorFast")),
         ("swinv2", ("ViTImageProcessor", "ViTImageProcessorFast")),
         ("table-transformer", ("DetrImageProcessor",)),
         ("timesformer", ("VideoMAEImageProcessor",)),

src/transformers/models/swin2sr/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -20,6 +20,7 @@
 if TYPE_CHECKING:
     from .configuration_swin2sr import *
     from .image_processing_swin2sr import *
+    from .image_processing_swin2sr_fast import *
     from .modeling_swin2sr import *
 else:
     import sys
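
Because the new module is re-exported from the package `__init__` and lists the class in its `__all__` (see the new file below), the fast processor becomes importable from the top-level `transformers` namespace. A short sketch, assuming torch and torchvision are installed as required by the module's availability checks:

# Top-level import enabled by the re-export above.
from transformers import Swin2SRImageProcessorFast

# Processor-specific kwargs from the new module; the values shown are its defaults.
processor = Swin2SRImageProcessorFast(do_pad=True, pad_size=8)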
src/transformers/models/swin2sr/image_processing_swin2sr_fast.py (new file)
Lines changed: 138 additions & 0 deletions

# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for Swin2SR."""

from typing import List, Optional, Union

from ...image_processing_utils import (
    BatchFeature,
    ChannelDimension,
    get_image_size,
)
from ...image_processing_utils_fast import (
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
    BaseImageProcessorFast,
    DefaultFastImageProcessorKwargs,
    group_images_by_shape,
    reorder_images,
)
from ...image_utils import ImageInput
from ...processing_utils import Unpack
from ...utils import (
    TensorType,
    add_start_docstrings,
    is_torch_available,
    is_torchvision_available,
    is_torchvision_v2_available,
)


if is_torch_available():
    import torch

if is_torchvision_available():
    if is_torchvision_v2_available():
        from torchvision.transforms.v2 import functional as F
    else:
        from torchvision.transforms import functional as F


class Swin2SRFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
    do_pad: Optional[bool]
    pad_size: Optional[int]


@add_start_docstrings(
    "Constructs a fast Swin2SR image processor.",
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
    """
        do_pad (`bool`, *optional*, defaults to `True`):
            Whether to pad the image to make the height and width divisible by `window_size`.
        pad_size (`int`, *optional*, defaults to `8`):
            The size of the sliding window for the local attention.
    """,
)
class Swin2SRImageProcessorFast(BaseImageProcessorFast):
    do_rescale = True
    rescale_factor = 1 / 255
    do_pad = True
    pad_size = 8
    valid_kwargs = Swin2SRFastImageProcessorKwargs

    def __init__(self, **kwargs: Unpack[Swin2SRFastImageProcessorKwargs]):
        super().__init__(**kwargs)

    @add_start_docstrings(
        BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
        """
        do_pad (`bool`, *optional*, defaults to `True`):
            Whether to pad the image to make the height and width divisible by `window_size`.
        pad_size (`int`, *optional*, defaults to `8`):
            The size of the sliding window for the local attention.
        """,
    )
    def preprocess(self, images: ImageInput, **kwargs: Unpack[Swin2SRFastImageProcessorKwargs]) -> BatchFeature:
        return super().preprocess(images, **kwargs)

    def pad(self, images: "torch.Tensor", size: int) -> "torch.Tensor":
        """
        Pad an image to make the height and width divisible by `size`.

        Args:
            images (`torch.Tensor`):
                Images to pad.
            size (`int`):
                The size to make the height and width divisible by.

        Returns:
            `torch.Tensor`: The padded images.
        """
        height, width = get_image_size(images, ChannelDimension.FIRST)
        pad_height = (height // size + 1) * size - height
        pad_width = (width // size + 1) * size - width

        return F.pad(
            images,
            (0, 0, pad_width, pad_height),
            padding_mode="symmetric",
        )

    def _preprocess(
        self,
        images: List["torch.Tensor"],
        do_rescale: bool,
        rescale_factor: float,
        do_pad: bool,
        pad_size: int,
        return_tensors: Optional[Union[str, TensorType]],
        interpolation: Optional["F.InterpolationMode"],
        **kwargs,
    ) -> BatchFeature:
        grouped_images, grouped_images_index = group_images_by_shape(images)
        processed_image_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_rescale:
                stacked_images = self.rescale(stacked_images, scale=rescale_factor)
            if do_pad:
                stacked_images = self.pad(stacked_images, size=pad_size)
            processed_image_grouped[shape] = stacked_images
        processed_images = reorder_images(processed_image_grouped, grouped_images_index)
        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)


__all__ = ["Swin2SRImageProcessorFast"]

tests/models/swin2sr/test_image_processing_swin2sr.py
Lines changed: 26 additions & 6 deletions

@@ -18,7 +18,7 @@
 import numpy as np
 
 from transformers.testing_utils import require_torch, require_vision
-from transformers.utils import is_torch_available, is_vision_available
+from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
 
 from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
 
@@ -30,6 +30,9 @@
     from PIL import Image
 
     from transformers import Swin2SRImageProcessor
+
+    if is_torchvision_available():
+        from transformers import Swin2SRImageProcessorFast
     from transformers.image_transforms import get_image_size
 
 
@@ -97,6 +100,7 @@ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=F
 @require_vision
 class Swin2SRImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
     image_processing_class = Swin2SRImageProcessor if is_vision_available() else None
+    fast_image_processing_class = Swin2SRImageProcessorFast if is_torchvision_available() else None
 
     def setUp(self):
         super().setUp()
@@ -107,11 +111,12 @@ def image_processor_dict(self):
         return self.image_processor_tester.prepare_image_processor_dict()
 
     def test_image_processor_properties(self):
-        image_processor = self.image_processing_class(**self.image_processor_dict)
-        self.assertTrue(hasattr(image_processor, "do_rescale"))
-        self.assertTrue(hasattr(image_processor, "rescale_factor"))
-        self.assertTrue(hasattr(image_processor, "do_pad"))
-        self.assertTrue(hasattr(image_processor, "pad_size"))
+        for image_processing_class in self.image_processor_list:
+            image_processing = image_processing_class(**self.image_processor_dict)
+            self.assertTrue(hasattr(image_processing, "do_rescale"))
+            self.assertTrue(hasattr(image_processing, "rescale_factor"))
+            self.assertTrue(hasattr(image_processing, "do_pad"))
+            self.assertTrue(hasattr(image_processing, "pad_size"))
 
     def calculate_expected_size(self, image):
         old_height, old_width = get_image_size(image)
@@ -181,3 +186,18 @@ def test_call_pytorch(self):
         encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
         expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
         self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
+
+    @unittest.skip(reason="No speed gain on CPU due to minimal processing.")
+    def test_fast_is_faster_than_slow(self):
+        pass
+
+    def test_slow_fast_equivalence_batched(self):
+        image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)
+
+        image_processor_slow = self.image_processing_class(**self.image_processor_dict)
+        image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
+
+        encoded_slow = image_processor_slow(image_inputs, return_tensors="pt").pixel_values
+        encoded_fast = image_processor_fast(image_inputs, return_tensors="pt").pixel_values
+
+        self.assertTrue(torch.allclose(encoded_slow, encoded_fast, atol=1e-1))
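
The batched equivalence test above runs through the test harness; the same comparison can be sketched standalone. Everything here (dummy input shapes, the `atol` carried over from the test) is illustrative rather than part of the commit:

import torch

from transformers import Swin2SRImageProcessor, Swin2SRImageProcessorFast

# Four dummy channels-first images of equal resolution (shapes are arbitrary).
images = [torch.randint(0, 256, (3, 18, 20), dtype=torch.uint8) for _ in range(4)]

slow_processor = Swin2SRImageProcessor()
fast_processor = Swin2SRImageProcessorFast()

pixel_values_slow = slow_processor(images, return_tensors="pt").pixel_values
pixel_values_fast = fast_processor(images, return_tensors="pt").pixel_values

# Small numerical differences between the numpy and torchvision code paths are tolerated.
assert torch.allclose(pixel_values_slow, pixel_values_fast, atol=1e-1)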
