Skip to content

Commit 31e05df

Browse files
Tiffany VlaarTiffany Vlaar
authored andcommitted
code
1 parent 895eab6 commit 31e05df

15 files changed

+825
-0
lines changed

.DS_Store

8 KB
Binary file not shown.

OGconstraint_CIFAR10_resnet34.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import torch
2+
import torch.nn as nn
3+
import numpy as np
4+
from models import *
5+
from Optimizers import OGconstraint_ud
6+
from Optimizers import initOGconstraint
7+
from datasets import CIFAR10data
8+
from train import train
9+
from test import test
10+
11+
torch.cuda.set_device(2)
12+
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
13+
print(f"Running on {device}.")
14+
torch.manual_seed(5) #optional
15+
16+
#Hyperparameters
17+
h = 0.1
18+
T = 0
19+
dt1 = h/3 #for warm-up
20+
cgamma = 0.9
21+
WD = 0
22+
dgamma = 0
23+
num_runs = 3
24+
num_epochs = 150
25+
batchsize = 128
26+
27+
loader_train,loader_test = CIFAR10data.generatedata(batchsize=batchsize)
28+
29+
RES_train_loss_allruns = []
30+
RES_test_loss_allruns = []
31+
RES_test_acc_allruns = []
32+
RES_train_acc_allruns = []
33+
34+
for run in range(num_runs):
35+
print("run = ", run)
36+
net = ResNet34()
37+
Constrainedlist, net = initOGconstraint.initOG(net)
38+
net = net.to(device)
39+
40+
criterion = nn.CrossEntropyLoss()
41+
optimizer = OGconstraint_ud.oCoLAud(net.parameters(),device,Constrainedlist=Constrainedlist,lr=dt1,cgamma=cgamma,dgamma=dgamma,weight_decay=WD)
42+
43+
RES_train_loss = []
44+
RES_train_acc = []
45+
RES_test_loss = []
46+
RES_test_acc = []
47+
48+
for epoch in range(num_epochs):
49+
50+
net, optimizer, loss_train,acc_train = train(epoch,loader_train,net,optimizer,criterion,device)
51+
loss_test,acc_test = test(loader_test,net,criterion,device)
52+
53+
RES_train_loss.append(loss_train)
54+
RES_train_acc.append(acc_train)
55+
RES_test_loss.append(loss_test)
56+
RES_test_acc.append(acc_test)
57+
58+
#warmup
59+
if epoch < 2:
60+
dt1 += (h/3)
61+
optimizer.param_groups[0]['lr'] = dt1
62+
#learning rate decay
63+
elif epoch == 50:
64+
optimizer.param_groups[0]['lr'] = 0.01
65+
elif epoch == 100:
66+
optimizer.param_groups[0]['lr'] = 0.001
67+
68+
69+
RES_train_loss_allruns.append(RES_train_loss)
70+
RES_train_acc_allruns.append(RES_train_acc)
71+
RES_test_loss_allruns.append(RES_test_loss)
72+
RES_test_acc_allruns.append(RES_test_acc)
73+
74+
75+
with open(f'OGconstraint_Resnet34_CIFAR10_batchsize_{batchsize}_WD_{WD}_cgam_{cgamma}_h_{h}_T_{T}_{num_runs}runs_{num_epochs}epochs.txt', 'w+') as f:
76+
f.write(f'Training loss min: {np.min(RES_train_loss_allruns,0)}\n')
77+
f.write(f'Test loss min: {np.min(RES_test_loss_allruns,0)}\n')
78+
f.write(f'Training accuracy min: {np.min(RES_train_acc_allruns,0)}\n')
79+
f.write(f'Test accuracy min: {np.min(RES_test_acc_allruns,0)}\n')
80+
f.write(f'Training loss max: {np.max(RES_train_loss_allruns,0)}\n')
81+
f.write(f'Test loss max: {np.max(RES_test_loss_allruns,0)}\n')
82+
f.write(f'Training accuracy max: {np.max(RES_train_acc_allruns,0)}\n')
83+
f.write(f'Test accuracy max: {np.max(RES_test_acc_allruns,0)}\n')
84+
f.write(f'Training loss std: {np.std(RES_train_loss_allruns,0)}\n')
85+
f.write(f'Test loss std: {np.std(RES_test_loss_allruns,0)}\n')
86+
f.write(f'Training accuracy std: {np.std(RES_train_acc_allruns,0)}\n')
87+
f.write(f'Test accuracy std: {np.std(RES_test_acc_allruns,0)}\n')
88+
f.write(f'Training loss mean: {np.mean(RES_train_loss_allruns,0)}\n')
89+
f.write(f'Test loss mean: {np.mean(RES_test_loss_allruns,0)}\n')
90+
f.write(f'Training accuracy mean: {np.mean(RES_train_acc_allruns,0)}\n')
91+
f.write(f'Test accuracy mean: {np.mean(RES_test_acc_allruns,0)}\n')
92+
93+

Optimizers/.DS_Store

6 KB
Binary file not shown.

Optimizers/OGconstraint_ud.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
import math
2+
import torch
3+
import torch.nn as nn
4+
import torch.nn.functional as F
5+
from torch.optim import Optimizer
6+
import numpy as np
7+
8+
class oCoLAud(Optimizer):
9+
def __init__(self,params,device,Constrainedlist,lr=0.1,cgamma=0,dgamma=0,weight_decay=0):
10+
self.device = device
11+
self.Constrainedlist = Constrainedlist
12+
defaults = dict(lr=lr,cgamma=cgamma,dgamma=dgamma,weight_decay=weight_decay)
13+
super(oCoLAud,self).__init__(params,defaults)
14+
15+
def __setstate__(self,state):
16+
super(oCoLAud,self).__setstate__(state)
17+
18+
@torch.no_grad()
19+
def stepMom(self):
20+
for group in self.param_groups:
21+
22+
for i,p in enumerate(group['params']):
23+
24+
if p.grad is None:
25+
continue
26+
27+
param_state =self.state[p]
28+
shapep = p.shape
29+
if self.Constrainedlist[i] == 1:
30+
shapep0 = shapep[0]
31+
if len(shapep) > 2:
32+
shapep1 = shapep[1]*shapep[2]*shapep[3]
33+
else:
34+
shapep1 = shapep[1]
35+
36+
d_p = p.grad
37+
buf = param_state['momentum_buffer'] = -0.01*torch.clone(d_p).detach()
38+
buffy = torch.clone(buf).detach().reshape((shapep0,shapep1))
39+
Weighty = torch.clone(p).detach().reshape((shapep0,shapep1))
40+
41+
if shapep0 >= shapep1:
42+
bufproj = -0.5*torch.matmul(Weighty,(torch.matmul(torch.transpose(buffy,0,1),Weighty)+torch.matmul(torch.transpose(Weighty,0,1),buffy))).reshape(*shapep)
43+
else:
44+
bufproj = -0.5*torch.transpose(torch.matmul(torch.transpose(Weighty,0,1),(torch.matmul(Weighty,torch.transpose(buffy,0,1))+torch.matmul(buffy,torch.transpose(Weighty,0,1)))),0,1).reshape(*shapep)
45+
46+
buf.add_(bufproj)
47+
48+
else:
49+
d_p = p.grad
50+
buf = param_state['momentum_buffer'] = -0.01*torch.clone(d_p).detach()
51+
52+
53+
@torch.no_grad()
54+
def step(self):
55+
56+
for group in self.param_groups:
57+
cgamma = group['cgamma']
58+
dgamma = group['dgamma']
59+
weight_decay = group['weight_decay']
60+
61+
for i,p in enumerate(group['params']):
62+
63+
if p.grad is None:
64+
continue
65+
66+
param_state = self.state[p]
67+
shapep = p.shape
68+
69+
if self.Constrainedlist[i] == 1:
70+
71+
shapep0 = shapep[0]
72+
if len(shapep) > 2:
73+
shapep1 = shapep[1]*shapep[2]*shapep[3]
74+
else:
75+
shapep1 = shapep[1]
76+
77+
if 'OldWeight' not in param_state:
78+
OldWeight = param_state['OldWeight'] = torch.clone(p).detach()
79+
OldWeight = OldWeight.reshape((shapep0,shapep1))
80+
if shapep0 >= shapep1:
81+
prodis = torch.matmul(torch.transpose(OldWeight,0,1),OldWeight)
82+
else:
83+
prodis = torch.matmul(OldWeight,torch.transpose(OldWeight,0,1))
84+
OldWeightT = torch.transpose(OldWeight,0,1)
85+
Id = param_state['Id'] = torch.eye(*prodis.shape).to(self.device)
86+
else:
87+
OldWeight = param_state['OldWeight']
88+
OldWeight = torch.clone(p).detach()
89+
OldWeight = OldWeight.reshape((shapep0,shapep1))
90+
if shapep0 < shapep1:
91+
OldWeightT = torch.transpose(OldWeight,0,1)
92+
Id = param_state['Id']
93+
94+
buf = param_state['momentum_buffer']
95+
96+
# O -step
97+
if dgamma == 0:
98+
buf.mul_(cgamma)
99+
else:
100+
buf.mul_(cgamma).add_(dgamma,torch.cuda.FloatTensor(*shapep).normal_())
101+
buffy = torch.clone(buf).detach().reshape((shapep0,shapep1))
102+
if shapep0 >= shapep1:
103+
bufproj = -0.5*torch.matmul(OldWeight,(torch.matmul(torch.transpose(buffy,0,1),OldWeight)+torch.matmul(torch.transpose(OldWeight,0,1),buffy))).reshape(*shapep)
104+
else:
105+
bufproj = -0.5*torch.transpose(torch.matmul(OldWeightT,(torch.matmul(OldWeight,torch.transpose(buffy,0,1))+torch.matmul(buffy,torch.transpose(OldWeight,0,1)))),0,1).reshape(*shapep)
106+
107+
buf.add_(bufproj)
108+
109+
# B-step
110+
d_p = p.grad
111+
if weight_decay != 0:
112+
d_p = d_p.add(p, alpha=weight_decay)
113+
114+
buf.add_(-d_p)
115+
buffy = torch.clone(buf).detach().reshape((shapep0,shapep1))
116+
if shapep0 >= shapep1:
117+
bufproj = -0.5*torch.matmul(OldWeight,(torch.matmul(torch.transpose(buffy,0,1),OldWeight)+torch.matmul(torch.transpose(OldWeight,0,1),buffy))).reshape(*shapep)
118+
else:
119+
bufproj = -0.5*torch.transpose(torch.matmul(OldWeightT,(torch.matmul(OldWeight,torch.transpose(buffy,0,1))+torch.matmul(buffy,torch.transpose(OldWeight,0,1)))),0,1).reshape(*shapep)
120+
121+
122+
buf.add_(bufproj)
123+
d_p = buf
124+
125+
# A-step
126+
p.data.add_(d_p,alpha=group['lr'])
127+
p.data = p.reshape((shapep0,shapep1))
128+
FirstStep = torch.clone(p).detach()
129+
130+
if shapep0 >= shapep1:
131+
for ks in range(10):
132+
Lambda = torch.matmul(torch.transpose(p,0,1),p)-Id
133+
products = -0.5*torch.matmul(OldWeight,Lambda)
134+
p.add_(products)
135+
136+
bufproj1 = ((p.data-FirstStep)/group['lr']).reshape(*shapep)
137+
buf.add_(bufproj1)
138+
else:
139+
for ks in range(10):
140+
Lambda = torch.matmul(p,torch.transpose(p,0,1))-Id
141+
products = -0.5*torch.transpose(torch.matmul(OldWeightT,Lambda),0,1)
142+
p.add_(products)
143+
144+
bufproj1 = ((p.data-FirstStep)/group['lr']).reshape(*shapep)
145+
buf.add_(bufproj1)
146+
147+
p.data = p.reshape(*shapep)
148+
149+
OldWeight = torch.clone(p).detach()
150+
OldWeight = OldWeight.reshape((shapep0,shapep1))
151+
152+
buffy = torch.clone(buf).detach().reshape((shapep0,shapep1))
153+
if shapep0 >= shapep1:
154+
bufproj = -0.5*torch.matmul(OldWeight,(torch.matmul(torch.transpose(buffy,0,1),OldWeight)+torch.matmul(torch.transpose(OldWeight,0,1),buffy))).reshape(*shapep)
155+
else:
156+
bufproj = -0.5*torch.transpose(torch.matmul(OldWeightT,(torch.matmul(OldWeight,torch.transpose(buffy,0,1))+torch.matmul(buffy,torch.transpose(OldWeight,0,1)))),0,1).reshape(*shapep)
157+
158+
buf.add_(bufproj)
159+
else:
160+
buf = param_state['momentum_buffer']
161+
162+
if dgamma == 0:
163+
buf.mul_(cgamma)
164+
else:
165+
buf.mul_(cgamma).add_(dgamma,torch.cuda.FloatTensor(*shapep).normal_())
166+
167+
d_p = p.grad
168+
if weight_decay != 0:
169+
d_p = d_p.add(p, alpha=weight_decay)
170+
171+
buf.add_(-d_p,alpha=1)
172+
d_p = buf
173+
p.data.add_(d_p,alpha=group['lr'])
174+
175+
176+
177+
178+
179+
180+
181+
182+
183+

Optimizers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .circleconstraint_ud import *
2+
from .OGconstraint_ud import *

0 commit comments

Comments
 (0)