diff --git a/models/learning_rules.py b/models/learning_rules.py
index db47e65..7d5783b 100644
--- a/models/learning_rules.py
+++ b/models/learning_rules.py
@@ -3,40 +3,41 @@ class HebbLayer(nn.Module):
-    def __init__(self, input_dim, output_dim, lr, require_hebb=True, activation=True, update_rule='hebb',p=None):
-        super(HebbLayer, self).__init__()
-        # Init weights
-        self.weights = nn.Parameter(torch.randn(input_dim, output_dim) * 0.01, requires_grad=False) # ban bp for hebb layers
-        self.require_hebb = require_hebb
-        self.past_product_sum = torch.zeros_like(self.weights.data)
-        self.stimuli_times = 0
-        self.relu = nn.ReLU()
-        self.activation = activation
-        self.lr = lr
-        self.update_rule = update_rule
-        self.p = p
+    def __init__(self, input_dim, output_dim, lr, require_hebb=True, activation=True, update_rule='hebb', p=None):
+        super(HebbLayer, self).__init__()
+        # Initialize weights; requires_grad=False keeps backprop away from Hebbian layers
+        self.weights = nn.Parameter(torch.randn(input_dim, output_dim) * 0.01, requires_grad=False)
+        self.require_hebb = require_hebb
+        self.past_product_sum = torch.zeros_like(self.weights.data)
+        self.stimuli_times = 0
+        self.relu = nn.ReLU()
+        self.activation = activation
+        self.lr = lr
+        self.update_rule = update_rule
+        self.p = p
+
+        # Map rule names to their update methods
+        self.update_methods = {
+            'hebb': self.hebb_update,
+            'oja': self.oja_update,
+            'gupta': self.gupta_update,
+            'modified_gupta': self.modified_gupta_update
+        }
+
+        # Validate the configuration up front, at construction time
+        if update_rule not in self.update_methods:
+            raise ValueError("Invalid update rule specified")
+        if update_rule in ('gupta', 'modified_gupta') and p is None:
+            raise ValueError("Percentile 'p' must be provided for Gupta-style update rules")
 
-    def forward(self, x): # Forward call triggers the update!
-        z = self.get_product(x)
-        # perform update during forward pass if required
-        if self.require_hebb:
-            # Select update rule
-            if self.update_rule == 'hebb':
-                self.hebb_update(x)
-            elif self.update_rule == 'oja':
-                self.oja_update(x)
-            elif self.update_rule == 'gupta':
-                if self.p is None:
-                    raise ValueError("Percentile 'p' must be provided for Gupta update rule")
-                self.gupta_update(x)
-            elif self.update_rule == 'modified_gupta':
-                if self.p is None:
-                    raise ValueError("Percentile 'p' must be provided for modified Gupta update rule")
-                self.modified_gupta_update(x)
-            else:
-                raise ValueError("Invalid update rule specified")
-        return z
+    def forward(self, x):
+        # Forward call also triggers the Hebbian update when enabled
+        z = self.get_product(x)
+        if self.require_hebb:
+            update_method = self.update_methods.get(self.update_rule)
+            if update_method is None:
+                raise ValueError(f"Update rule {self.update_rule} not implemented")
+            update_method(x)
+        return z
 
     # 1.naive hebbian learning rule
     def hebb_update(self,x):
@@ -113,4 +114,4 @@ def get_unsupervised_weights(X, n_hidden, n_epochs, batch_size,
         nc = torch.max(torch.abs(ds))
         if nc < precision: nc = precision
         weights += eps*(ds/nc)
-    return weights
\ No newline at end of file
+    return weights
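
For reviewers, a minimal usage sketch of the refactored layer. It assumes the imports and update-rule methods defined elsewhere in `models/learning_rules.py`; the dimensions and hyperparameters below are illustrative, not taken from the repo:

```python
import torch

# Hypothetical sizes/hyperparameters, for illustration only
layer = HebbLayer(input_dim=784, output_dim=128, lr=1e-3, update_rule='oja')

x = torch.randn(32, 784)   # a batch of 32 inputs
z = layer(x)               # the forward pass also applies the Oja update

# Misconfiguration now fails at construction time rather than on the
# first forward call:
try:
    HebbLayer(784, 128, lr=1e-3, update_rule='anti_hebb')
except ValueError as err:
    print(err)             # Invalid update rule specified
```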
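The dispatch table relies on every rule sharing the `method(x)` signature. The bodies of `hebb_update` and `oja_update` sit outside this hunk; for context, the standard textbook forms those names refer to look roughly like the sketch below (a hedged reconstruction, not code from the file):

```python
import torch

def hebb_step(weights, x, lr):
    # Naive Hebb: dW = lr * x^T y with y = x W. Unbounded -- weight
    # norms can grow without limit under repeated stimulation.
    y = x @ weights                                   # (batch, out)
    return weights + lr * (x.t() @ y)                 # (in, out)

def oja_step(weights, x, lr):
    # Oja: dW_ij = lr * (x_i y_j - y_j^2 W_ij), summed over the batch.
    # The decay term keeps the weight norm bounded.
    y = x @ weights
    return weights + lr * (x.t() @ y - weights * (y * y).sum(dim=0))
```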
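The second hunk only adds the missing trailing newline, but the surrounding context is worth a note: the raw update `ds` is divided by its largest absolute entry (floored at `precision` to avoid division by zero), so every step has magnitude at most `eps`. A standalone sketch of that pattern, with illustrative values:

```python
import torch

eps, precision = 2e-2, 1e-30       # illustrative hyperparameters
weights = torch.randn(100, 10)
ds = torch.randn(100, 10)          # stand-in for the computed weight update

nc = torch.max(torch.abs(ds))      # largest-magnitude entry of the update
if nc < precision:
    nc = precision                 # guard against dividing by ~0
weights += eps * (ds / nc)         # effective step size capped at eps
```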