
Commit 3674aa0

1 parent baf48d6 commit 3674aa0

18 files changed: +415 −17 lines

cifar/evaluate.sh

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 python -u main.py \
 --gpus 0 \
 -e [best_model_path] \
---model resnet20_1w1a \
+--model resnet20_bireal_1w1a \
 --data_path [DATA_PATH] \
 --dataset cifar10 \
 -bt 128 \

cifar/models_bnn/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
 from .resnet import *
+from .resnet_bireal import *
 from .resnet2 import *
 from .vgg import *

cifar/models_bnn/resnet.py

Lines changed: 1 addition & 3 deletions
@@ -104,11 +104,9 @@ def __init__(self, in_planes, planes, stride=1, option='A'):
 
     def forward(self, x):
         out = self.bn1(self.conv1(x))
-        out += self.shortcut(x)
         out = F.hardtanh(out)
-        x1 = out
         out = self.bn2(self.conv2(out))
-        out += x1
+        out += self.shortcut(x)
         out = F.hardtanh(out)
         return out
 
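For context on this change: after the commit, BasicBlock_1w1a in resnet.py uses a single residual shortcut spanning the whole block, while the Bi-Real-Net-style variant with a shortcut around each binary convolution moves into the new resnet_bireal.py below. A minimal sketch of the two forward paths, reusing the member names from the diff (conv1/bn1, conv2/bn2, shortcut):

# Plain binary block (resnet.py after this commit): one shortcut per block.
def forward_plain(self, x):
    out = F.hardtanh(self.bn1(self.conv1(x)))
    out = self.bn2(self.conv2(out))
    out += self.shortcut(x)              # single residual connection
    return F.hardtanh(out)

# Bi-Real-style block (new resnet_bireal.py): a shortcut around each binary conv,
# so a real-valued path bypasses every binarized layer.
def forward_bireal(self, x):
    out = self.bn1(self.conv1(x))
    out += self.shortcut(x)              # shortcut around conv1
    out = F.hardtanh(out)
    x1 = out
    out = self.bn2(self.conv2(out))
    out += x1                            # identity shortcut around conv2
    return F.hardtanh(out)
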
cifar/models_bnn/resnet_bireal.py

Lines changed: 200 additions & 0 deletions
@@ -0,0 +1,200 @@
'''
Properly implemented ResNet-s for CIFAR10 as described in paper [1].

The implementation and structure of this file is hugely influenced by [2],
which is implemented for ImageNet and doesn't have option A for identity.
Moreover, most of the implementations on the web are copy-paste from
torchvision's resnet and have the wrong number of params.

Proper ResNet-s for CIFAR10 (for fair comparison, etc.) have the following
numbers of layers and parameters:

name      | layers | params
ResNet20  |     20 | 0.27M
ResNet32  |     32 | 0.46M
ResNet44  |     44 | 0.66M
ResNet56  |     56 | 0.85M
ResNet110 |    110 | 1.7M
ResNet1202|   1202 | 19.4M

which this implementation indeed has.

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
[2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py

If you use this implementation in your work, please don't forget to mention the
author, Yerlan Idelbayev.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from modules import *

from torch.autograd import Variable

__all__ = ['resnet20_bireal_1w1a']


class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                """
                For CIFAR10 ResNet paper uses option A.
                """
                self.shortcut = LambdaLayer(lambda x:
                    F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
            elif option == 'B':
                self.shortcut = nn.Sequential(
                    nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(self.expansion * planes)
                )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class BasicBlock_1w1a(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super(BasicBlock_1w1a, self).__init__()
        self.conv1 = BinarizeConv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = BinarizeConv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                """
                For CIFAR10 ResNet paper uses option A.
                """
                self.shortcut = LambdaLayer(lambda x:
                    F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
            elif option == 'B':
                self.shortcut = nn.Sequential(
                    BinarizeConv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(self.expansion * planes)
                )

    def forward(self, x):
        out = self.bn1(self.conv1(x))
        out += self.shortcut(x)
        out = F.hardtanh(out)
        x1 = out
        out = self.bn2(self.conv2(out))
        out += x1
        out = F.hardtanh(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.bn2 = nn.BatchNorm1d(64)
        self.linear = nn.Linear(64, num_classes)

        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1e-8)
                m.bias.data.zero_()
            if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.hardtanh(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.avg_pool2d(out, out.size()[3])
        out = out.view(out.size(0), -1)
        out = self.bn2(out)
        out = self.linear(out)

        return out


def resnet20_bireal_1w1a(**kwargs):
    return ResNet(BasicBlock_1w1a, [3, 3, 3], **kwargs)


def resnet20():
    return ResNet(BasicBlock, [3, 3, 3])


def resnet32():
    return ResNet(BasicBlock, [5, 5, 5])


def resnet44():
    return ResNet(BasicBlock, [7, 7, 7])


def resnet56():
    return ResNet(BasicBlock, [9, 9, 9])


def resnet110():
    return ResNet(BasicBlock, [18, 18, 18])


def resnet1202():
    return ResNet(BasicBlock, [200, 200, 200])


def test(net):
    import numpy as np
    total_params = 0

    for x in filter(lambda p: p.requires_grad, net.parameters()):
        total_params += np.prod(x.data.numpy().shape)
    print("Total number of params", total_params)
    print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size()) > 1, net.parameters()))))


if __name__ == "__main__":
    for net_name in __all__:
        if net_name.startswith('resnet'):
            print(net_name)
            test(globals()[net_name]())
            print()

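A quick way to sanity-check the new entry point is a dummy forward pass. The snippet below is a hypothetical smoke test, not part of the commit; it assumes you run it from the cifar/ directory so that the repo-local modules package (which provides BinarizeConv2d) and models_bnn are both importable:

import torch
from models_bnn import resnet20_bireal_1w1a

model = resnet20_bireal_1w1a(num_classes=10)   # ResNet-20 built from Bi-Real-style 1w1a blocks
x = torch.randn(2, 3, 32, 32)                  # CIFAR-10-sized dummy batch
logits = model(x)
print(logits.shape)                            # expected: torch.Size([2, 10])
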
cifar/run.sh

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 python -u main.py \
 --gpus 0 \
---model resnet20_1w1a \
+--model resnet20_bireal_1w1a \
 --results_dir [DIR] \
 --data_path [DATA_PATH] \
 --dataset cifar10 \

cifar/utils/options.py

Lines changed: 2 additions & 2 deletions
@@ -43,7 +43,7 @@
     '--model',
     '-a',
     metavar='MODEL',
-    default='resnet20_1w1a',
+    default='resnet20_bireal_1w1a',
     help='model architecture ')
 
 parser.add_argument(
@@ -55,7 +55,7 @@
 parser.add_argument(
     '--data_path',
     type=str,
-    default='/home/xuzihan/data',
+    default='/home/data',
     help='The dictionary where the dataset is stored.')
 
 parser.add_argument(
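main.py is not part of this diff, so the exact lookup of the --model string is not shown here. A common pattern in repos laid out like this one (and presumably the reason for the wildcard import added to models_bnn/__init__.py) is to resolve the architecture name via an attribute lookup on the package; a hedged sketch with a hypothetical build_model helper:

import models_bnn

def build_model(name, num_classes=10):
    # Hypothetical helper: the wildcard import in models_bnn/__init__.py re-exports
    # everything listed in resnet_bireal.__all__, so the constructor is reachable
    # as an attribute of the package. The real main.py may do this differently.
    return getattr(models_bnn, name)(num_classes=num_classes)

model = build_model('resnet20_bireal_1w1a')    # the new default from options.py
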

imagenet/models_bnn/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
 from .resnet import *
+from .resnet_bireal import *
 from .resnet2 import *
 from .vgg import *
3 binary files changed (not shown).
