Skip to content

Commit

Permalink
structured optimize
Browse files Browse the repository at this point in the history
  • Loading branch information
wrencanfly authored Dec 12, 2023
1 parent 1afa120 commit 6c2b3fb
Showing 1 changed file with 34 additions and 33 deletions.
67 changes: 34 additions & 33 deletions models/learning_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,40 +3,41 @@


class HebbLayer(nn.Module):
def __init__(self, input_dim, output_dim, lr, require_hebb=True, activation=True, update_rule='hebb',p=None):
super(HebbLayer, self).__init__()
# Init weights
self.weights = nn.Parameter(torch.randn(input_dim, output_dim) * 0.01, requires_grad=False) # ban bp for hebb layers
self.require_hebb = require_hebb
self.past_product_sum = torch.zeros_like(self.weights.data)
self.stimuli_times = 0
self.relu = nn.ReLU()
self.activation = activation
self.lr = lr
self.update_rule = update_rule
self.p = p
def __init__(self, input_dim, output_dim, lr, require_hebb=True, activation=True, update_rule='hebb', p=None):
super(HebbLayer, self).__init__()
# Initialize weights
self.weights = nn.Parameter(torch.randn(input_dim, output_dim) * 0.01, requires_grad=False)
self.require_hebb = require_hebb
self.past_product_sum = torch.zeros_like(self.weights.data)
self.stimuli_times = 0
self.relu = nn.ReLU()
self.activation = activation
self.lr = lr
self.update_rule = update_rule
self.p = p

# mapping
self.update_methods = {
'hebb': self.hebb_update,
'oja': self.oja_update,
'gupta': self.gupta_update,
'modified_gupta': self.modified_gupta_update
}

# Validate the selected update rule
if update_rule not in self.update_methods:
raise ValueError("Invalid update rule specified")


def forward(self, x): # Forward call triggers the update!
z = self.get_product(x)
# perform update during forward pass if required
if self.require_hebb:
# Select update rule
if self.update_rule == 'hebb':
self.hebb_update(x)
elif self.update_rule == 'oja':
self.oja_update(x)
elif self.update_rule == 'gupta':
if self.p is None:
raise ValueError("Percentile 'p' must be provided for Gupta update rule")
self.gupta_update(x)
elif self.update_rule == 'modified_gupta':
if self.p is None:
raise ValueError("Percentile 'p' must be provided for modified Gupta update rule")
self.modified_gupta_update(x)
else:
raise ValueError("Invalid update rule specified")
return z
def forward(self, x):
z = self.get_product(x)
if self.require_hebb:
update_method = self.update_methods.get(self.update_rule)
if update_method:
update_method(x)
else:
raise ValueError(f"Update rule {self.update_rule} not implemented")
return z

# 1.naive hebbian learning rule
def hebb_update(self,x):
Expand Down Expand Up @@ -113,4 +114,4 @@ def get_unsupervised_weights(X, n_hidden, n_epochs, batch_size,
nc = torch.max(torch.abs(ds))
if nc < precision: nc = precision
weights += eps*(ds/nc)
return weights
return weights

0 comments on commit 6c2b3fb

Please sign in to comment.