GitHub: https://github.com/enajx/HebbianMetaLearning
Blog: https://www.cnblogs.com/lucifer1997/p/14656603.html
Paper: Meta-Learning through Hebbian Plasticity in Random Networks
ES reference: Evolution Strategies as a Scalable Alternative to Reinforcement Learning (Algorithm 2; Salimans et al., 2017: https://arxiv.org/abs/1703.03864)
HebbianMetaLearning/evolution_strategy_hebb.py:
- class EvolutionStrategyHebb(object):
Evolves the Hebbian coefficients alone, or coevolves the Hebbian coefficients together with the CNN parameters and/or the initial weights.
- run:
def run(self, iterations, print_step=10, path='heb_coeffs'):
    id_ = str(int(time.time()))
    if not exists(path + '/' + id_):
        mkdir(path + '/' + id_)
    print('Run: ' + id_ + ' ........................................................................ ')

    pool = mp.Pool(self.num_threads) if self.num_threads > 1 else None
    generations_rewards = []

    # Algorithm 2. Salimans, 2017: https://arxiv.org/abs/1703.03864
    for iteration in range(iterations):

        # Evolution of Hebbian coefficients & coevolution of cnn parameters and/or initial weights
        if self.pixel_env or self.coevolve_init:
            population = self._get_population()                                            # Sample normal noise: Step 5
            population_coevolved = self._get_population(coevolved_param=True)              # Sample normal noise: Step 5
            rewards = self._get_rewards_coevolved(pool, population, population_coevolved)  # Compute population fitness: Step 6
            self._update_coeffs(rewards, population)                                       # Update coefficients: Steps 8->12
            self._update_coevolved_param(rewards, population_coevolved)                    # Update coevolved parameters: Steps 8->12

        # Evolution of Hebbian coefficients
        else:
            population = self._get_population()             # Sample normal noise: Step 5
            rewards = self._get_rewards(pool, population)   # Compute population fitness: Step 6
            self._update_coeffs(rewards, population)        # Update coefficients: Steps 8->12

        # Print fitness and save Hebbian coefficients and/or Coevolved / CNNs parameters
        if (iteration + 1) % print_step == 0:
            rew_ = rewards.mean()
            print('iter %4i | reward: %3i | update_factor: %f lr: %f | sum_coeffs: %i sum_abs_coeffs: %4i' % (
                iteration + 1, rew_, self.update_factor, self.learning_rate,
                int(np.sum(self.coeffs)), int(np.sum(abs(self.coeffs)))), flush=True)

            if rew_ > 100:
                torch.save(self.get_coeffs(), path + "/" + id_ + '/HEBcoeffs__' + self.environment
                           + "__rew_" + str(int(rew_)) + '__' + self.hebb_rule
                           + "__init_" + str(self.init_weights) + "__pop_" + str(self.POPULATION_SIZE)
                           + '__coeffs' + "__{}.dat".format(iteration))

                if self.coevolve_init:
                    torch.save(self.get_coevolved_parameters(), path + "/" + id_ + '/HEBcoeffs__' + self.environment
                               + "__rew_" + str(int(rew_)) + '__' + self.hebb_rule
                               + "__init_" + str(self.init_weights) + "__pop_" + str(self.POPULATION_SIZE)
                               + '__coevolved_initial_weights' + "__{}.dat".format(iteration))
                elif self.pixel_env:
                    torch.save(self.get_coevolved_parameters(), path + "/" + id_ + '/HEBcoeffs__' + self.environment
                               + "__rew_" + str(int(rew_)) + '__' + self.hebb_rule
                               + "__init_" + str(self.init_weights) + "__pop_" + str(self.POPULATION_SIZE)
                               + '__CNN_parameters' + "__{}.dat".format(iteration))

            generations_rewards.append(rew_)
            np.save(path + "/" + id_ + '/Fitness_values_' + id_ + '_' + self.environment + '.npy',
                    np.array(generations_rewards))

    if pool is not None:
        pool.close()
        pool.join()
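The Step 5 / Step 6 / Steps 8->12 comments refer to Algorithm 2 of Salimans et al. (2017). Stripped of the per-layer bookkeeping, logging and checkpointing, one iteration amounts to the following minimal NumPy sketch (an illustration, not the repo's code; fitness_fn stands in for a full episode rollout, and the repo additionally applies centered-rank fitness shaping and decays lr and sigma over time):

import numpy as np

def es_iteration(theta, fitness_fn, pop_size=10, sigma=0.1, lr=0.2):
    """One iteration of Algorithm 2 (Salimans et al., 2017) with mirrored sampling."""
    eps = np.random.randn(pop_size // 2, theta.size)
    eps = np.concatenate([eps, -eps])                                  # antithetic pairs, as in _get_population
    fitness = np.array([fitness_fn(theta + sigma * e) for e in eps])   # Step 6: evaluate perturbed parameters
    shaped = (fitness - fitness.mean()) / (fitness.std() + 1e-8)       # fitness shaping
    return theta + lr / (pop_size * sigma) * eps.T @ shaped            # Steps 8-12: fitness-weighted update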
- _get_population:
def _get_population(self, coevolved_param=False):
    population = []

    if coevolved_param == False:
        for i in range(int(self.POPULATION_SIZE / 2)):
            x1 = []
            x2 = []
            for w in self.coeffs:
                # j: (coefficients_per_synapse,) eg. (5,)
                j = np.random.randn(*w.shape)
                # x1, x2: (number of synapses, coefficients_per_synapse) eg. (92690, 5)
                x1.append(j)
                x2.append(-j)
            # population: (population size, number of synapses, coefficients_per_synapse), eg. (10, 92690, 5)
            population.append(x1)
            population.append(x2)

    elif coevolved_param == True:
        for i in range(int(self.POPULATION_SIZE / 2)):
            x1 = []
            x2 = []
            for w in self.initial_weights_co:
                j = np.random.randn(*w.shape)
                x1.append(j)
                x2.append(-j)
            population.append(x1)
            population.append(x2)

    return np.array(population).astype(np.float32)
- _get_rewards_coevolved:
def _get_rewards_coevolved(self, pool, population, population_coevolved):
    if pool is not None:
        worker_args = []
        for z in range(len(population)):

            heb_coeffs_try1 = []
            for index, i in enumerate(population[z]):
                jittered = self.SIGMA * i
                heb_coeffs_try1.append(self.coeffs[index] + jittered)
            heb_coeffs_try = np.array(heb_coeffs_try1).astype(np.float32)

            coevolved_parameters_try1 = []
            for index, i in enumerate(population_coevolved[z]):
                jittered = self.SIGMA * i
                coevolved_parameters_try1.append(self.initial_weights_co[index] + jittered)
            coevolved_parameters_try = np.array(coevolved_parameters_try1).astype(np.float32)

            worker_args.append((self.get_reward, self.hebb_rule, self.environment, self.init_weights,
                                heb_coeffs_try, coevolved_parameters_try))

        rewards = pool.map(worker_process_hebb_coevo, worker_args)

    else:
        rewards = []
        for z in range(len(population)):
            heb_coeffs_try = np.array(self._get_params_try(self.coeffs, population[z]))
            coevolved_parameters_try = np.array(self._get_params_try(self.initial_weights_co, population_coevolved[z]))
            rewards.append(self.get_reward(self.hebb_rule, self.environment, self.init_weights,
                                           heb_coeffs_try, coevolved_parameters_try))

    rewards = np.array(rewards).astype(np.float32)
    return rewards
- _update_coeffs / _update_coevolved_param:
def _update_coeffs(self, rewards, population):
    rewards = compute_centered_ranks(rewards)

    std = rewards.std()
    if std == 0:
        raise ValueError('Variance should not be zero')
    rewards = (rewards - rewards.mean()) / std

    for index, c in enumerate(self.coeffs):
        layer_population = np.array([p[index] for p in population])

        self.update_factor = self.learning_rate / (self.POPULATION_SIZE * self.SIGMA)
        self.coeffs[index] = c + self.update_factor * np.dot(layer_population.T, rewards).T

    if self.learning_rate > 0.001:
        self.learning_rate *= self.decay

    # Decay sigma
    if self.SIGMA > 0.01:
        self.SIGMA *= 0.999

def _update_coevolved_param(self, rewards, population):
    rewards = compute_centered_ranks(rewards)

    std = rewards.std()
    if std == 0:
        raise ValueError('Variance should not be zero')
    rewards = (rewards - rewards.mean()) / std

    for index, w in enumerate(self.initial_weights_co):
        layer_population = np.array([p[index] for p in population])

        self.update_factor = self.learning_rate / (self.POPULATION_SIZE * self.SIGMA)
        self.initial_weights_co[index] = w + self.update_factor * np.dot(layer_population.T, rewards).T
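compute_centered_ranks is imported from the repo's ES utilities and is not shown in these notes. A common implementation (the fitness shaping from OpenAI's evolution-strategies-starter, which this code very likely mirrors) maps raw rewards to ranks rescaled to [-0.5, 0.5], so the update only sees the relative ordering of candidates rather than raw reward magnitudes:

import numpy as np

def compute_ranks(x):
    """Rank-transform a 1-D array: smallest value -> 0, largest -> len(x) - 1."""
    ranks = np.empty(len(x), dtype=int)
    ranks[x.argsort()] = np.arange(len(x))
    return ranks

def compute_centered_ranks(x):
    """Sketch of the unshown utility: map rewards to centered ranks in [-0.5, 0.5]."""
    y = compute_ranks(x.ravel()).reshape(x.shape).astype(np.float32)
    y /= (x.size - 1)
    y -= 0.5
    return y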
- _get_rewards:
def _get_rewards(self, pool, population):
    if pool is not None:
        worker_args = []
        for p in population:

            heb_coeffs_try1 = []
            for index, i in enumerate(p):
                jittered = self.SIGMA * i
                heb_coeffs_try1.append(self.coeffs[index] + jittered)
            heb_coeffs_try = np.array(heb_coeffs_try1).astype(np.float32)

            worker_args.append((self.get_reward, self.hebb_rule, self.environment, self.init_weights, heb_coeffs_try))

        rewards = pool.map(worker_process_hebb, worker_args)

    else:
        rewards = []
        for p in population:
            heb_coeffs_try = np.array(self._get_params_try(self.coeffs, p))
            rewards.append(self.get_reward(self.hebb_rule, self.environment, self.init_weights, heb_coeffs_try))

    rewards = np.array(rewards).astype(np.float32)
    return rewards
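worker_process_hebb and worker_process_hebb_coevo are module-level helpers (not shown here) so that pool.map can pickle them across processes; presumably each one just unpacks its argument tuple and calls the reward function, roughly along these lines (hypothetical sketch, not the repo's code):

def worker_process_hebb(arg):
    # Hypothetical sketch of the unshown helper: unpack one candidate and evaluate it.
    get_reward_func, hebb_rule, environment, init_weights, heb_coeffs_try = arg
    return get_reward_func(hebb_rule, environment, init_weights, heb_coeffs_try)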
- get_coeffs:
def get_coeffs(self):
    return self.coeffs.astype(np.float32)
- get_coevolved_parameters:
def get_coevolved_parameters(self):
    return self.initial_weights_co.astype(np.float32)
HebbianMetaLearning/fitness_functions.py:
def fitness_hebb(hebb_rule : str, environment : str, init_weights = 'uni', *evolved_parameters: List[np.array]) -> float:
    """
    Evaluate an agent 'evolved_parameters' controlled by a Hebbian network in an environment 'environment' during a lifetime.
    The initial weights are either co-evolved (if 'init_weights' == 'coevolve') along with the Hebbian coefficients,
    or randomly sampled at each episode from the 'init_weights' distribution.
    Subsequently the weights are updated following the Hebbian update mechanism 'hebb_rule'.
    Returns the episodic fitness of the agent.
    """

    def weights_init(m):
        if isinstance(m, torch.nn.Linear):
            if init_weights == 'xa_uni':
                torch.nn.init.xavier_uniform(m.weight.data, 0.3)
            elif init_weights == 'sparse':
                torch.nn.init.sparse_(m.weight.data, 0.8)
            elif init_weights == 'uni':
                torch.nn.init.uniform_(m.weight.data, -0.1, 0.1)
            elif init_weights == 'normal':
                torch.nn.init.normal_(m.weight.data, 0, 0.024)
            elif init_weights == 'ka_uni':
                torch.nn.init.kaiming_uniform_(m.weight.data, 3)
            elif init_weights == 'uni_big':
                torch.nn.init.uniform_(m.weight.data, -1, 1)
            elif init_weights == 'xa_uni_big':
                torch.nn.init.xavier_uniform(m.weight.data)
            elif init_weights == 'ones':
                torch.nn.init.ones_(m.weight.data)
            elif init_weights == 'zeros':
                torch.nn.init.zeros_(m.weight.data)
            elif init_weights == 'default':
                pass

    # Unpack evolved parameters
    try:
        hebb_coeffs, initial_weights_co = evolved_parameters
    except:
        hebb_coeffs = evolved_parameters[0]

    # Initial weights co-evolution flag
    coevolve_init = True if init_weights == 'coevolve' else False

    with torch.no_grad():

        # Load environment
        try:
            env = gym.make(environment, verbose=0)
        except:
            env = gym.make(environment)
        # env.render()  # bullet envs

        # For environments with several intra-episode lives -eg. Breakout-
        try:
            if 'FIRE' in env.unwrapped.get_action_meanings():
                env = FireEpisodicLifeEnv(env)
        except:
            pass

        # Check if selected env is pixel or state-vector
        if len(env.observation_space.shape) == 3:     # Pixel-based environment
            pixel_env = True
            env = w.ResizeObservation(env, 84)        # Resize and normalise input
            env = ScaledFloatFrame(env)
            input_channels = 3
        elif len(env.observation_space.shape) == 1:
            pixel_env = False
            input_dim = env.observation_space.shape[0]
        elif len(env.observation_space.shape) == 0:
            pixel_env = False
            input_dim = env.observation_space.n

        # Determine action space dimension
        if isinstance(env.action_space, Box):
            action_dim = env.action_space.shape[0]
        elif isinstance(env.action_space, Discrete):
            action_dim = env.action_space.n
        else:
            raise ValueError('Only Box and Discrete action spaces supported')

        # Initialise policy network: with CNN layer for pixel envs and simple MLP for state-vector envs
        if pixel_env == True:
            p = CNN_heb(input_channels, action_dim)
        else:
            p = MLP_heb(input_dim, action_dim)

        # Initialise weights of the policy network with a specific distribution or with the co-evolved weights
        if coevolve_init:
            nn.utils.vector_to_parameters(torch.tensor(initial_weights_co, dtype=torch.float32), p.parameters())
        else:
            # Randomly sample initial weights from chosen distribution
            p.apply(weights_init)

            # Load CNN parameters
            if pixel_env:
                cnn_weights1 = initial_weights_co[:162]
                cnn_weights2 = initial_weights_co[162:]
                list(p.parameters())[0].data = torch.tensor(cnn_weights1.reshape((6, 3, 3, 3))).float()
                list(p.parameters())[1].data = torch.tensor(cnn_weights2.reshape((8, 6, 5, 5))).float()

        p = p.float()

        # Unpack network's weights
        if pixel_env:
            weightsCNN1, weightsCNN2, weights1_2, weights2_3, weights3_4 = list(p.parameters())
        else:
            weights1_2, weights2_3, weights3_4 = list(p.parameters())

        # Convert weights to numpy so we can JIT them with Numba
        weights1_2 = weights1_2.detach().numpy()
        weights2_3 = weights2_3.detach().numpy()
        weights3_4 = weights3_4.detach().numpy()

        observation = env.reset()
        if pixel_env:
            observation = np.swapaxes(observation, 0, 2)  # (3, 84, 84)

        # Burnout phase for the bullet quadruped so it starts off from the floor
        if environment == 'AntBulletEnv-v0':
            action = np.zeros(8)
            for _ in range(40):
                __ = env.step(action)

        # Normalize weights flag for non-bullet envs
        normalised_weights = False if environment[-12:-6] == 'Bullet' else True

        # Inner loop
        neg_count = 0
        rew_ep = 0
        t = 0
        while True:

            # For observation ∈ gym.spaces.Discrete, we one-hot encode the observation
            if isinstance(env.observation_space, Discrete):
                observation = (observation == torch.arange(env.observation_space.n)).float()

            o0, o1, o2, o3 = p([observation])
            o0 = o0.numpy()
            o1 = o1.numpy()
            o2 = o2.numpy()

            # Bounding the action space
            if environment == 'CarRacing-v0':
                action = np.array([torch.tanh(o3[0]), torch.sigmoid(o3[1]), torch.sigmoid(o3[2])])
                o3 = o3.numpy()
            elif environment[-12:-6] == 'Bullet':
                o3 = torch.tanh(o3).numpy()
                action = o3
            else:
                if isinstance(env.action_space, Box):
                    action = o3.numpy()
                    action = np.clip(action, env.action_space.low, env.action_space.high)
                elif isinstance(env.action_space, Discrete):
                    action = np.argmax(o3).numpy()
                o3 = o3.numpy()

            # Environment simulation step
            observation, reward, done, info = env.step(action)
            if environment == 'AntBulletEnv-v0':
                reward = env.unwrapped.rewards[1]  # Distance walked
            rew_ep += reward

            # env.render('human')  # Gym envs

            if pixel_env:
                observation = np.swapaxes(observation, 0, 2)  # (3, 84, 84)

            # Early stopping conditions
            if environment == 'CarRacing-v0':
                neg_count = neg_count + 1 if reward < 0.0 else 0
                if (done or neg_count > 20):
                    break
            elif environment[-12:-6] == 'Bullet':
                if t > 200:
                    neg_count = neg_count + 1 if reward < 0.0 else 0
                if (done or neg_count > 30):
                    break
            else:
                if done:
                    break
            t += 1

            #### Episodic/Intra-life Hebbian update of the weights
            if hebb_rule == 'A':
                weights1_2, weights2_3, weights3_4 = hebbian_update_A(hebb_coeffs, weights1_2, weights2_3, weights3_4, o0, o1, o2, o3)
            elif hebb_rule == 'AD':
                weights1_2, weights2_3, weights3_4 = hebbian_update_AD(hebb_coeffs, weights1_2, weights2_3, weights3_4, o0, o1, o2, o3)
            elif hebb_rule == 'AD_lr':
                weights1_2, weights2_3, weights3_4 = hebbian_update_AD_lr(hebb_coeffs, weights1_2, weights2_3, weights3_4, o0, o1, o2, o3)
            elif hebb_rule == 'ABC':
                weights1_2, weights2_3, weights3_4 = hebbian_update_ABC(hebb_coeffs, weights1_2, weights2_3, weights3_4, o0, o1, o2, o3)
            elif hebb_rule == 'ABC_lr':
                weights1_2, weights2_3, weights3_4 = hebbian_update_ABC_lr(hebb_coeffs, weights1_2, weights2_3, weights3_4, o0, o1, o2, o3)
            elif hebb_rule == 'ABCD':
                weights1_2, weights2_3, weights3_4 = hebbian_update_ABCD(hebb_coeffs, weights1_2, weights2_3, weights3_4, o0, o1, o2, o3)
            elif hebb_rule == 'ABCD_lr':
                weights1_2, weights2_3, weights3_4 = hebbian_update_ABCD_lr_D_in(hebb_coeffs, weights1_2, weights2_3, weights3_4, o0, o1, o2, o3)
            elif hebb_rule == 'ABCD_lr_D_out':
                weights1_2, weights2_3, weights3_4 = hebbian_update_ABCD_lr_D_out(hebb_coeffs, weights1_2, weights2_3, weights3_4, o0, o1, o2, o3)
            elif hebb_rule == 'ABCD_lr_D_in_and_out':
                weights1_2, weights2_3, weights3_4 = hebbian_update_ABCD_lr_D_in_and_out(hebb_coeffs, weights1_2, weights2_3, weights3_4, o0, o1, o2, o3)
            else:
                raise ValueError('The provided Hebbian rule is not valid')

            # Normalise weights per layer
            if normalised_weights == True:
                (a, b, c) = (0, 1, 2) if not pixel_env else (2, 3, 4)
                list(p.parameters())[a].data /= list(p.parameters())[a].__abs__().max()
                list(p.parameters())[b].data /= list(p.parameters())[b].__abs__().max()
                list(p.parameters())[c].data /= list(p.parameters())[c].__abs__().max()

        env.close()

    return rew_ep
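A hedged usage sketch (not from the repo): with a state-vector Box environment such as LunarLanderContinuous-v2 (8-dimensional observation, 2-dimensional action; requires gym's Box2D extras and the older 4-tuple step API that this code targets), MLP_heb has 8·128 + 128·64 + 64·2 = 9344 synapses, and the 'ABCD_lr' rule expects 5 coefficients per synapse:

import numpy as np

# Random Hebbian coefficients of the right shape for LunarLanderContinuous-v2.
n_synapses = 8 * 128 + 128 * 64 + 64 * 2           # 9344
hebb_coeffs = np.random.randn(n_synapses, 5).astype(np.float32)

episodic_reward = fitness_hebb('ABCD_lr', 'LunarLanderContinuous-v2', 'uni', hebb_coeffs)
print(episodic_reward)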
HebbianMetaLearning/policies.py:
- MLP:
class MLP_heb(nn.Module):
    "MLP, no bias"
    def __init__(self, input_space, action_space):
        super(MLP_heb, self).__init__()
        self.fc1 = nn.Linear(input_space, 128, bias=False)
        self.fc2 = nn.Linear(128, 64, bias=False)
        self.fc3 = nn.Linear(64, action_space, bias=False)

    def forward(self, ob):
        state = torch.as_tensor(ob[0]).float().detach()
        x1 = torch.tanh(self.fc1(state))
        x2 = torch.tanh(self.fc2(x1))
        o = self.fc3(x2)
        return state, x1, x2, o
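The four returned tensors are exactly the per-layer activations that fitness_hebb later passes to the Hebbian rules as (o0, o1, o2, o3). A quick shape check, assuming an 8-dimensional observation and a 2-dimensional action space (the dimensions are arbitrary examples):

import numpy as np

net = MLP_heb(8, 2)
o0, o1, o2, o3 = net([np.zeros(8, dtype=np.float32)])
print(o0.shape, o1.shape, o2.shape, o3.shape)
# torch.Size([8]) torch.Size([128]) torch.Size([64]) torch.Size([2])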
- CNN:
class CNN_heb(nn.Module):
    "CNN+MLP with n=input_channels frames as input. Non-activated last layer's output"
    def __init__(self, input_channels, action_space_dim):
        super(CNN_heb, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=input_channels, out_channels=6, kernel_size=3, stride=1, bias=False)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=8, kernel_size=5, stride=2, bias=False)
        self.linear1 = nn.Linear(648, 128, bias=False)
        self.linear2 = nn.Linear(128, 64, bias=False)
        self.out = nn.Linear(64, action_space_dim, bias=False)

    def forward(self, ob):
        state = torch.as_tensor(ob.copy())
        state = state.float()
        x1 = self.pool(torch.tanh(self.conv1(state)))
        x2 = self.pool(torch.tanh(self.conv2(x1)))
        x3 = x2.view(-1)
        x4 = torch.tanh(self.linear1(x3))
        x5 = torch.tanh(self.linear2(x4))
        o = self.out(x5)
        return x3, x4, x5, o
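A similar shape check explains the hard-coded 648: with the 84×84 RGB observations produced in fitness_hebb, conv1 + pool gives 6×41×41, conv2 + pool gives 8×9×9, and flattening yields 8·9·9 = 648 features. The action dimension 3 below is just an example (e.g. CarRacing-v0):

import numpy as np

net = CNN_heb(input_channels=3, action_space_dim=3)
x3, x4, x5, o = net([np.zeros((3, 84, 84), dtype=np.float32)])
print(x3.shape, x4.shape, x5.shape, o.shape)
# torch.Size([648]) torch.Size([128]) torch.Size([64]) torch.Size([3])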
HebbianMetaLearning/hebbian_weights_update.py:
@njit
def hebbian_update_A(heb_coeffs, weights1_2, weights2_3, weights3_4, o0, o1, o2, o3):
    heb_offset = 0
    # Layer 1
    for i in range(weights1_2.shape[1]):
        for j in range(weights1_2.shape[0]):
            idx = (weights1_2.shape[0] - 1) * i + i + j
            weights1_2[:, i][j] += heb_coeffs[idx] * o0[i] * o1[j]

    heb_offset = weights1_2.shape[1] * weights1_2.shape[0]
    # Layer 2
    for i in range(weights2_3.shape[1]):
        for j in range(weights2_3.shape[0]):
            idx = heb_offset + (weights2_3.shape[0] - 1) * i + i + j
            weights2_3[:, i][j] += heb_coeffs[idx] * o1[i] * o2[j]

    heb_offset += weights2_3.shape[1] * weights2_3.shape[0]
    # Layer 3
    for i in range(weights3_4.shape[1]):
        for j in range(weights3_4.shape[0]):
            idx = heb_offset + (weights3_4.shape[0] - 1) * i + i + j
            weights3_4[:, i][j] += heb_coeffs[idx] * o2[i] * o3[j]

    return weights1_2, weights2_3, weights3_4


@njit
def hebbian_update_AD(heb_coeffs, weights1_2, weights2_3, weights3_4, o0, o1, o2, o3):
    heb_offset = 0
    # Layer 1
    for i in range(weights1_2.shape[1]):
        for j in range(weights1_2.shape[0]):
            idx = (weights1_2.shape[0] - 1) * i + i + j
            weights1_2[:, i][j] += heb_coeffs[idx][0] * o0[i] * o1[j] + heb_coeffs[idx][1]

    heb_offset = weights1_2.shape[1] * weights1_2.shape[0]
    # Layer 2
    for i in range(weights2_3.shape[1]):
        for j in range(weights2_3.shape[0]):
            idx = heb_offset + (weights2_3.shape[0] - 1) * i + i + j
            weights2_3[:, i][j] += heb_coeffs[idx][0] * o1[i] * o2[j] + heb_coeffs[idx][1]

    heb_offset += weights2_3.shape[1] * weights2_3.shape[0]
    # Layer 3
    for i in range(weights3_4.shape[1]):
        for j in range(weights3_4.shape[0]):
            idx = heb_offset + (weights3_4.shape[0] - 1) * i + i + j
            weights3_4[:, i][j] += heb_coeffs[idx][0] * o2[i] * o3[j] + heb_coeffs[idx][1]

    return weights1_2, weights2_3, weights3_4


@njit
def hebbian_update_AD_lr(heb_coeffs, weights1_2, weights2_3, weights3_4, o0, o1, o2, o3):
    heb_offset = 0
    # Layer 1
    for i in range(weights1_2.shape[1]):
        for j in range(weights1_2.shape[0]):
            idx = (weights1_2.shape[0] - 1) * i + i + j
            weights1_2[:, i][j] += (heb_coeffs[idx][0] * o0[i] * o1[j] + heb_coeffs[idx][1]) * heb_coeffs[idx][2]

    heb_offset = weights1_2.shape[1] * weights1_2.shape[0]
    # Layer 2
    for i in range(weights2_3.shape[1]):
        for j in range(weights2_3.shape[0]):
            idx = heb_offset + (weights2_3.shape[0] - 1) * i + i + j
            weights2_3[:, i][j] += (heb_coeffs[idx][0] * o1[i] * o2[j] + heb_coeffs[idx][1]) * heb_coeffs[idx][2]

    heb_offset += weights2_3.shape[1] * weights2_3.shape[0]
    # Layer 3
    for i in range(weights3_4.shape[1]):
        for j in range(weights3_4.shape[0]):
            idx = heb_offset + (weights3_4.shape[0] - 1) * i + i + j
            weights3_4[:, i][j] += (heb_coeffs[idx][0] * o2[i] * o3[j] + heb_coeffs[idx][1]) * heb_coeffs[idx][2]

    return weights1_2, weights2_3, weights3_4


@njit
def hebbian_update_ABC(heb_coeffs, weights1_2, weights2_3, weights3_4, o0, o1, o2, o3):
    heb_offset = 0
    # Layer 1
    for i in range(weights1_2.shape[1]):
        for j in range(weights1_2.shape[0]):
            idx = (weights1_2.shape[0] - 1) * i + i + j
            weights1_2[:, i][j] += (heb_coeffs[idx][0] * o0[i] * o1[j]
                                    + heb_coeffs[idx][1] * o0[i] + heb_coeffs[idx][2] * o1[j])

    heb_offset += weights1_2.shape[1] * weights1_2.shape[0]
    # Layer 2
    for i in range(weights2_3.shape[1]):
        for j in range(weights2_3.shape[0]):
            idx = heb_offset + (weights2_3.shape[0] - 1) * i + i + j
            weights2_3[:, i][j] += (heb_coeffs[idx][0] * o1[i] * o2[j]
                                    + heb_coeffs[idx][1] * o1[i] + heb_coeffs[idx][2] * o2[j])

    heb_offset += weights2_3.shape[1] * weights2_3.shape[0]
    # Layer 3
    for i in range(weights3_4.shape[1]):
        for j in range(weights3_4.shape[0]):
            idx = heb_offset + (weights3_4.shape[0] - 1) * i + i + j
            weights3_4[:, i][j] += (heb_coeffs[idx][0] * o2[i] * o3[j]
                                    + heb_coeffs[idx][1] * o2[i] + heb_coeffs[idx][2] * o3[j])

    return weights1_2, weights2_3, weights3_4


@njit
def hebbian_update_ABC_lr(heb_coeffs, weights1_2, weights2_3, weights3_4, o0, o1, o2, o3):
    heb_offset = 0
    # Layer 1
    for i in range(weights1_2.shape[1]):
        for j in range(weights1_2.shape[0]):
            idx = (weights1_2.shape[0] - 1) * i + i + j
            weights1_2[:, i][j] += heb_coeffs[idx][3] * (heb_coeffs[idx][0] * o0[i] * o1[j]
                                                         + heb_coeffs[idx][1] * o0[i] + heb_coeffs[idx][2] * o1[j])

    heb_offset += weights1_2.shape[1] * weights1_2.shape[0]
    # Layer 2
    for i in range(weights2_3.shape[1]):
        for j in range(weights2_3.shape[0]):
            idx = heb_offset + (weights2_3.shape[0] - 1) * i + i + j
            weights2_3[:, i][j] += heb_coeffs[idx][3] * (heb_coeffs[idx][0] * o1[i] * o2[j]
                                                         + heb_coeffs[idx][1] * o1[i] + heb_coeffs[idx][2] * o2[j])

    heb_offset += weights2_3.shape[1] * weights2_3.shape[0]
    # Layer 3
    for i in range(weights3_4.shape[1]):
        for j in range(weights3_4.shape[0]):
            idx = heb_offset + (weights3_4.shape[0] - 1) * i + i + j
            weights3_4[:, i][j] += heb_coeffs[idx][3] * (heb_coeffs[idx][0] * o2[i] * o3[j]
                                                         + heb_coeffs[idx][1] * o2[i] + heb_coeffs[idx][2] * o3[j])

    return weights1_2, weights2_3, weights3_4


@njit
def hebbian_update_ABCD(heb_coeffs, weights1_2, weights2_3, weights3_4, o0, o1, o2, o3):
    heb_offset = 0
    # Layer 1
    for i in range(weights1_2.shape[1]):
        for j in range(weights1_2.shape[0]):
            idx = (weights1_2.shape[0] - 1) * i + i + j
            weights1_2[:, i][j] += heb_coeffs[idx][3] + (heb_coeffs[idx][0] * o0[i] * o1[j]
                                                         + heb_coeffs[idx][1] * o0[i] + heb_coeffs[idx][2] * o1[j])

    heb_offset += weights1_2.shape[1] * weights1_2.shape[0]
    # Layer 2
    for i in range(weights2_3.shape[1]):
        for j in range(weights2_3.shape[0]):
            idx = heb_offset + (weights2_3.shape[0] - 1) * i + i + j
            weights2_3[:, i][j] += heb_coeffs[idx][3] + (heb_coeffs[idx][0] * o1[i] * o2[j]
                                                         + heb_coeffs[idx][1] * o1[i] + heb_coeffs[idx][2] * o2[j])

    heb_offset += weights2_3.shape[1] * weights2_3.shape[0]
    # Layer 3
    for i in range(weights3_4.shape[1]):
        for j in range(weights3_4.shape[0]):
            idx = heb_offset + (weights3_4.shape[0] - 1) * i + i + j
            weights3_4[:, i][j] += heb_coeffs[idx][3] + (heb_coeffs[idx][0] * o2[i] * o3[j]
                                                         + heb_coeffs[idx][1] * o2[i] + heb_coeffs[idx][2] * o3[j])

    return weights1_2, weights2_3, weights3_4


@njit
def hebbian_update_ABCD_lr_D_in(heb_coeffs, weights1_2, weights2_3, weights3_4, o0, o1, o2, o3):
    heb_offset = 0
    # Layer 1
    for i in range(weights1_2.shape[1]):
        for j in range(weights1_2.shape[0]):
            idx = (weights1_2.shape[0] - 1) * i + i + j
            weights1_2[:, i][j] += heb_coeffs[idx][3] * (heb_coeffs[idx][0] * o0[i] * o1[j]
                                                         + heb_coeffs[idx][1] * o0[i] + heb_coeffs[idx][2] * o1[j]
                                                         + heb_coeffs[idx][4])

    heb_offset += weights1_2.shape[1] * weights1_2.shape[0]
    # Layer 2
    for i in range(weights2_3.shape[1]):
        for j in range(weights2_3.shape[0]):
            idx = heb_offset + (weights2_3.shape[0] - 1) * i + i + j
            weights2_3[:, i][j] += heb_coeffs[idx][3] * (heb_coeffs[idx][0] * o1[i] * o2[j]
                                                         + heb_coeffs[idx][1] * o1[i] + heb_coeffs[idx][2] * o2[j]
                                                         + heb_coeffs[idx][4])

    heb_offset += weights2_3.shape[1] * weights2_3.shape[0]
    # Layer 3
    for i in range(weights3_4.shape[1]):
        for j in range(weights3_4.shape[0]):
            idx = heb_offset + (weights3_4.shape[0] - 1) * i + i + j
            weights3_4[:, i][j] += heb_coeffs[idx][3] * (heb_coeffs[idx][0] * o2[i] * o3[j]
                                                         + heb_coeffs[idx][1] * o2[i] + heb_coeffs[idx][2] * o3[j]
                                                         + heb_coeffs[idx][4])

    return weights1_2, weights2_3, weights3_4


@njit
def hebbian_update_ABCD_lr_D_out(heb_coeffs, weights1_2, weights2_3, weights3_4, o0, o1, o2, o3):
    heb_offset = 0
    # Layer 1
    for i in range(weights1_2.shape[1]):
        for j in range(weights1_2.shape[0]):
            idx = (weights1_2.shape[0] - 1) * i + i + j
            weights1_2[:, i][j] += heb_coeffs[idx][3] * (heb_coeffs[idx][0] * o0[i] * o1[j]
                                                         + heb_coeffs[idx][1] * o0[i] + heb_coeffs[idx][2] * o1[j]) + heb_coeffs[idx][4]

    heb_offset += weights1_2.shape[1] * weights1_2.shape[0]
    # Layer 2
    for i in range(weights2_3.shape[1]):
        for j in range(weights2_3.shape[0]):
            idx = heb_offset + (weights2_3.shape[0] - 1) * i + i + j
            weights2_3[:, i][j] += heb_coeffs[idx][3] * (heb_coeffs[idx][0] * o1[i] * o2[j]
                                                         + heb_coeffs[idx][1] * o1[i] + heb_coeffs[idx][2] * o2[j]) + heb_coeffs[idx][4]

    heb_offset += weights2_3.shape[1] * weights2_3.shape[0]
    # Layer 3
    for i in range(weights3_4.shape[1]):
        for j in range(weights3_4.shape[0]):
            idx = heb_offset + (weights3_4.shape[0] - 1) * i + i + j
            weights3_4[:, i][j] += heb_coeffs[idx][3] * (heb_coeffs[idx][0] * o2[i] * o3[j]
                                                         + heb_coeffs[idx][1] * o2[i] + heb_coeffs[idx][2] * o3[j]) + heb_coeffs[idx][4]

    return weights1_2, weights2_3, weights3_4


@njit
def hebbian_update_ABCD_lr_D_in_and_out(heb_coeffs, weights1_2, weights2_3, weights3_4, o0, o1, o2, o3):
    heb_offset = 0
    # Layer 1
    for i in range(weights1_2.shape[1]):
        for j in range(weights1_2.shape[0]):
            idx = (weights1_2.shape[0] - 1) * i + i + j
            weights1_2[:, i][j] += heb_coeffs[idx][3] * (heb_coeffs[idx][0] * o0[i] * o1[j]
                                                         + heb_coeffs[idx][1] * o0[i] + heb_coeffs[idx][2] * o1[j]
                                                         + heb_coeffs[idx][4]) + heb_coeffs[idx][5]

    heb_offset += weights1_2.shape[1] * weights1_2.shape[0]
    # Layer 2
    for i in range(weights2_3.shape[1]):
        for j in range(weights2_3.shape[0]):
            idx = heb_offset + (weights2_3.shape[0] - 1) * i + i + j
            weights2_3[:, i][j] += heb_coeffs[idx][3] * (heb_coeffs[idx][0] * o1[i] * o2[j]
                                                         + heb_coeffs[idx][1] * o1[i] + heb_coeffs[idx][2] * o2[j]
                                                         + heb_coeffs[idx][4]) + heb_coeffs[idx][5]

    heb_offset += weights2_3.shape[1] * weights2_3.shape[0]
    # Layer 3
    for i in range(weights3_4.shape[1]):
        for j in range(weights3_4.shape[0]):
            idx = heb_offset + (weights3_4.shape[0] - 1) * i + i + j
            weights3_4[:, i][j] += heb_coeffs[idx][3] * (heb_coeffs[idx][0] * o2[i] * o3[j]
                                                         + heb_coeffs[idx][1] * o2[i] + heb_coeffs[idx][2] * o3[j]
                                                         + heb_coeffs[idx][4]) + heb_coeffs[idx][5]

    return weights1_2, weights2_3, weights3_4
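All of these variants are special cases of the generalised Hebbian rule from the paper, Δw_ij = η · (A·o_i·o_j + B·o_i + C·o_j + D), applied per synapse with that synapse's own evolved coefficients: dropping terms gives the A, AD and ABC rules, and the _lr / _D_out suffixes move the learning rate η or the correlation-independent term D inside or outside the parentheses. For reference, a vectorised NumPy sketch of what hebbian_update_ABCD_lr_D_in computes for a single layer (documentation only, not a drop-in replacement for the Numba loops, which update the weight arrays in place):

import numpy as np

def abcd_lr_layer_update(w, coeffs, o_pre, o_post):
    # w:      (n_out, n_in) weight matrix of one layer
    # coeffs: (n_in * n_out, 5) slice of hebb_coeffs for this layer, laid out with the
    #         presynaptic index outermost (idx = n_out * i + j), matching the loops above
    # columns: 0 = A, 1 = B, 2 = C, 3 = eta (learning rate), 4 = D
    n_out, n_in = w.shape
    A, B, C, eta, D = [coeffs[:, k].reshape(n_in, n_out).T for k in range(5)]
    dw = eta * (A * np.outer(o_post, o_pre) + B * o_pre[None, :] + C * o_post[:, None] + D)
    return w + dw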