• BindsNET学习系列 —— LearningRule


    相关源码:bindsnet/bindsnet/learning/learning.py

    1、LearningRule

    class LearningRule(ABC):
        # language=rst
        """
        Common base class for every learning rule.

        The constructor records the connection being trained, normalises the
        learning rates into a two-element tensor, chooses how parameter updates are
        reduced over the batch dimension and stores the weight-decay factor.
        ``update`` applies that factor and clamps the weights to ``[wmin, wmax]``.
        """

        def __init__(
            self,
            connection: AbstractConnection,
            nu: Optional[Union[float, Sequence[float]]] = None,
            reduction: Optional[callable] = None,
            weight_decay: float = 0.0,
            **kwargs
        ) -> None:
            # language=rst
            """
            Set up the state shared by all learning rules.

            :param connection: The ``AbstractConnection`` whose weights are trained.
                Its ``source``, ``target``, ``wmin`` and ``wmax`` attributes are
                copied onto the rule.
            :param nu: One learning rate (used for both entries) or a pair of rates
                for pre- and post-synaptic events. ``None`` means ``[0.0, 0.0]``.
            :param reduction: Callable used to reduce parameter updates along the
                batch dimension. When omitted, ``torch.squeeze`` is used for a
                source batch size of 1 and ``torch.sum`` otherwise.
            :param weight_decay: Fraction of the weights removed on each call to
                ``update``; stored as the multiplicative factor
                ``1.0 - weight_decay`` (``1.0`` when ``weight_decay`` is falsy).
            """
            # Handles onto the connection and the two populations it links.
            self.connection = connection
            self.source = connection.source
            self.target = connection.target

            # Weight bounds are mirrored from the connection.
            self.wmin = connection.wmin
            self.wmax = connection.wmax

            # Bring the learning rate(s) into a (pre, post) pair.
            if nu is None:
                rates = [0.0, 0.0]
            elif isinstance(nu, (float, int)):
                rates = [nu, nu]
            else:
                rates = nu

            self.nu = torch.zeros(2, dtype=torch.float)
            self.nu[0] = rates[0]
            self.nu[1] = rates[1]

            # Pick the batch reduction: an explicit callable wins, otherwise the
            # choice depends on the source population's batch size.
            if reduction is not None:
                self.reduction = reduction
            elif self.source.batch_size == 1:
                self.reduction = torch.squeeze
            else:
                self.reduction = torch.sum

            # Store decay as the factor the weights get multiplied by.
            if weight_decay:
                self.weight_decay = 1.0 - weight_decay
            else:
                self.weight_decay = 1.0

        def update(self) -> None:
            # language=rst
            """
            Apply the post-processing shared by all rules: weight decay followed by
            clamping.

            The multiplication is skipped only when the stored factor is falsy
            (i.e. exactly ``0.0``). Clamping happens when at least one bound is
            finite and the rule is not a ``NoOp``.
            """
            conn = self.connection

            # Multiplicative weight decay.
            if self.weight_decay:
                conn.w *= self.weight_decay

            # Clamp into [wmin, wmax] unless both bounds are infinite.
            has_bounds = conn.wmin != -np.inf or conn.wmax != np.inf
            if has_bounds and not isinstance(self, NoOp):
                conn.w.clamp_(conn.wmin, conn.wmax)

    学习规则的抽象基类,可以包括权重衰减和权重截断。

    2、NoOp(没有效果的学习规则)

    class NoOp(LearningRule):
        # language=rst
        """
        Learning rule that computes no weight change of its own.
        """

        def __init__(
            self,
            connection: AbstractConnection,
            nu: Optional[Union[float, Sequence[float]]] = None,
            reduction: Optional[callable] = None,
            weight_decay: float = 0.0,
            **kwargs
        ) -> None:
            # language=rst
            """
            Constructor for the ``NoOp`` rule; every argument is forwarded unchanged
            to ``LearningRule.__init__``.

            :param connection: An ``AbstractConnection`` object.
            :param nu: Single or pair of learning rates for pre- and post-synaptic
                events.
            :param reduction: Method for reducing parameter updates along the batch
                dimension.
            :param weight_decay: Constant multiple to decay weights by on each
                iteration.
            """
            super().__init__(connection, nu, reduction, weight_decay, **kwargs)

        def update(self, **kwargs) -> None:
            # language=rst
            """
            Delegate to the base-class ``update``; keyword arguments are accepted
            for interface compatibility and ignored.
            """
            super().update()

    没有效果的学习规则(默认),直接继承自学习规则的抽象基类;其 update 仅调用基类的 update,因此仍会执行权重衰减(若设置了 weight_decay),但不会进行权重截断。

    3、PostPre(STDP的在线实现)

    class PostPre(LearningRule):
        # language=rst
        """
        Basic STDP rule driven by both pre- and post-synaptic spikes. With the
        default sign convention a pre-synaptic spike lowers the weight and a
        post-synaptic spike raises it.
        """

        def __init__(
            self,
            connection: AbstractConnection,
            nu: Optional[Union[float, Sequence[float]]] = None,
            reduction: Optional[callable] = None,
            weight_decay: float = 0.0,
            **kwargs
        ) -> None:
            # language=rst
            """
            Build a ``PostPre`` rule and bind ``update`` to the implementation that
            matches the connection type.

            :param connection: The ``AbstractConnection`` whose weights this rule
                modifies.
            :param nu: Single or pair of learning rates for pre- and post-synaptic
                events.
            :param reduction: Method for reducing parameter updates along the batch
                dimension.
            :param weight_decay: Constant multiple to decay weights by on each
                iteration.
            :raises AssertionError: If either population does not record traces.
            :raises NotImplementedError: If the connection type is unsupported.
            """
            super().__init__(connection, nu, reduction, weight_decay, **kwargs)

            both_traced = self.source.traces and self.target.traces
            assert (
                both_traced
            ), "Both pre- and post-synaptic nodes must record spike traces."

            # Choose the update routine for this connection type.
            handler = None
            if isinstance(connection, (Connection, LocalConnection)):
                handler = self._connection_update
            elif isinstance(connection, Conv2dConnection):
                handler = self._conv2d_connection_update

            if handler is None:
                raise NotImplementedError(
                    "This learning rule is not supported for this Connection type."
                )
            self.update = handler

        def _connection_update(self, **kwargs) -> None:
            # language=rst
            """
            Post-pre update for ``Connection`` / ``LocalConnection`` objects.

            Each enabled half builds a per-sample outer product with ``torch.bmm``,
            reduces it over the batch dimension and applies it to the weights; the
            base-class ``update`` then runs.
            """
            n_batch = self.source.batch_size
            pre_lr, post_lr = self.nu[0], self.nu[1]

            # Pre-synaptic spike: decrease the weight in proportion to the
            # post-synaptic trace.
            if pre_lr:
                pre_spikes = self.source.s.view(n_batch, -1).unsqueeze(2).float()
                post_trace = self.target.x.view(n_batch, -1).unsqueeze(1) * pre_lr
                depression = torch.bmm(pre_spikes, post_trace)
                self.connection.w -= self.reduction(depression, dim=0)
                del pre_spikes, post_trace, depression

            # Post-synaptic spike: increase the weight in proportion to the
            # pre-synaptic trace.
            if post_lr:
                post_spikes = (
                    self.target.s.view(n_batch, -1).unsqueeze(1).float() * post_lr
                )
                pre_trace = self.source.x.view(n_batch, -1).unsqueeze(2)
                potentiation = torch.bmm(pre_trace, post_spikes)
                self.connection.w += self.reduction(potentiation, dim=0)
                del pre_trace, post_spikes, potentiation

            super().update()

        def _conv2d_connection_update(self, **kwargs) -> None:
            # language=rst
            """
            Post-pre update for ``Conv2dConnection`` objects.

            Source traces and spikes are passed through ``im2col_indices`` and the
            target ones are flattened per output channel so both halves can be
            computed with ``torch.bmm``.
            """
            weights = self.connection.w
            out_channels, _, kernel_height, kernel_width = weights.size()
            padding = self.connection.padding
            stride = self.connection.stride
            n_batch = self.source.batch_size

            def to_columns(tensor):
                # Same im2col configuration for traces and spikes.
                return im2col_indices(
                    tensor, kernel_height, kernel_width, padding=padding, stride=stride
                )

            # Bring spike traces and spike occurrences into matrix form.
            pre_trace = to_columns(self.source.x)
            post_trace = self.target.x.view(n_batch, out_channels, -1)
            pre_spikes = to_columns(self.source.s.float())
            post_spikes = self.target.s.view(n_batch, out_channels, -1).float()

            # Pre-synaptic half (subtracted).
            if self.nu[0]:
                depression = self.reduction(
                    torch.bmm(post_trace, pre_spikes.permute(0, 2, 1)), dim=0
                )
                self.connection.w -= self.nu[0] * depression.view(
                    self.connection.w.size()
                )

            # Post-synaptic half (added).
            if self.nu[1]:
                potentiation = self.reduction(
                    torch.bmm(post_spikes, pre_trace.permute(0, 2, 1)), dim=0
                )
                self.connection.w += self.nu[1] * potentiation.view(
                    self.connection.w.size()
                )

            super().update()

    简单的STDP规则,包括突触前和突触后的脉冲活动。默认情况下,突触前更新为负(-学习率1 × 突触前脉冲 × 突触后发放迹),突触后更新为正(学习率2 × 突触前发放迹 × 突触后脉冲)。

    4、WeightDependentPostPre

    class WeightDependentPostPre(LearningRule):
        # language=rst
        """
        STDP rule involving both pre- and post-synaptic spiking activity. The
        post-synaptic update is positive and scaled by ``wmax - w``; the pre-synaptic
        update is negative and scaled by ``w - wmin``, so both depend on the current
        magnitude of the synaptic weights.
        """

        def __init__(
            self,
            connection: AbstractConnection,
            nu: Optional[Union[float, Sequence[float]]] = None,
            reduction: Optional[callable] = None,
            weight_decay: float = 0.0,
            **kwargs
        ) -> None:
            # language=rst
            """
            Constructor for ``WeightDependentPostPre`` learning rule.

            :param connection: An ``AbstractConnection`` object whose weights the
                ``WeightDependentPostPre`` learning rule will modify.
            :param nu: Single or pair of learning rates for pre- and post-synaptic
                events.
            :param reduction: Method for reducing parameter updates along the batch
                dimension.
            :param weight_decay: Constant multiple to decay weights by on each
                iteration.
            :raises AssertionError: If the source population records no traces or
                either weight bound is infinite.
            :raises NotImplementedError: If the connection type is unsupported.
            """
            super().__init__(
                connection=connection,
                nu=nu,
                reduction=reduction,
                weight_decay=weight_decay,
                **kwargs
            )

            assert self.source.traces, "Pre-synaptic nodes must record spike traces."
            assert (
                connection.wmin != -np.inf and connection.wmax != np.inf
            ), "Connection must define finite wmin and wmax."

            # The weight-dependent factors below read these bounds.
            self.wmin = connection.wmin
            self.wmax = connection.wmax

            # Bind ``update`` to the routine matching the connection type.
            if isinstance(connection, (Connection, LocalConnection)):
                self.update = self._connection_update
            elif isinstance(connection, Conv2dConnection):
                self.update = self._conv2d_connection_update
            else:
                raise NotImplementedError(
                    "This learning rule is not supported for this Connection type."
                )

        def _connection_update(self, **kwargs) -> None:
            # language=rst
            """
            Weight-dependent post-pre learning rule for ``Connection`` subclass of
            ``AbstractConnection`` class.

            Both halves are accumulated into one ``update`` computed from the
            weights as they were at the start of the call, then applied at once.
            """
            batch_size = self.source.batch_size

            source_s = self.source.s.view(batch_size, -1).unsqueeze(2).float()
            source_x = self.source.x.view(batch_size, -1).unsqueeze(2)
            target_s = self.target.s.view(batch_size, -1).unsqueeze(1).float()
            target_x = self.target.x.view(batch_size, -1).unsqueeze(1)

            update = 0

            # Pre-synaptic update: depression scaled by the distance to wmin.
            if self.nu[0]:
                outer_product = self.reduction(torch.bmm(source_s, target_x), dim=0)
                update -= self.nu[0] * outer_product * (self.connection.w - self.wmin)

            # Post-synaptic update: potentiation scaled by the distance to wmax.
            if self.nu[1]:
                outer_product = self.reduction(torch.bmm(source_x, target_s), dim=0)
                update += self.nu[1] * outer_product * (self.wmax - self.connection.w)

            self.connection.w += update

            super().update()

        def _conv2d_connection_update(self, **kwargs) -> None:
            # language=rst
            """
            Weight-dependent post-pre learning rule for ``Conv2dConnection`` subclass
            of ``AbstractConnection`` class.

            Uses the same weight-dependent factors as ``_connection_update``:
            ``w - wmin`` for the pre-synaptic half and ``wmax - w`` for the
            post-synaptic half.
            """
            # Get convolutional layer parameters.
            out_channels, _, kernel_height, kernel_width = self.connection.w.size()
            padding, stride = self.connection.padding, self.connection.stride
            batch_size = self.source.batch_size

            # Reshaping spike traces and spike occurrences.
            source_x = im2col_indices(
                self.source.x, kernel_height, kernel_width, padding=padding, stride=stride
            )
            target_x = self.target.x.view(batch_size, out_channels, -1)
            source_s = im2col_indices(
                self.source.s.float(),
                kernel_height,
                kernel_width,
                padding=padding,
                stride=stride,
            )
            target_s = self.target.s.view(batch_size, out_channels, -1).float()

            update = 0

            # Pre-synaptic update: depression scaled by the distance to wmin.
            if self.nu[0]:
                pre = self.reduction(
                    torch.bmm(target_x, source_s.permute((0, 2, 1))), dim=0
                )
                update -= (
                    self.nu[0]
                    * pre.view(self.connection.w.size())
                    * (self.connection.w - self.wmin)
                )

            # Post-synaptic update: potentiation scaled by the distance to wmax.
            # The factor must use the current weights (``wmax - w``); using
            # ``wmax - wmin`` would make this half independent of the weights.
            if self.nu[1]:
                post = self.reduction(
                    torch.bmm(target_s, source_x.permute((0, 2, 1))), dim=0
                )
                update += (
                    self.nu[1]
                    * post.view(self.connection.w.size())
                    * (self.wmax - self.connection.w)
                )

            self.connection.w += update

            super().update()

    STDP规则涉及突触前和突触后的脉冲活动。突触前更新为负(-学习率1 × 突触前脉冲 × 突触后发放迹 × (weight - weightmin)),突触后更新为正(学习率2 × 突触前发放迹 × 突触后脉冲 × (weightmax - weight)),两者都依赖于突触权重的大小。

    5、Hebbian

    class Hebbian(LearningRule):
        # language=rst
        """
        Plain Hebbian rule: the pre- and the post-synaptic contribution are both
        added to the weights.
        """

        def __init__(
            self,
            connection: AbstractConnection,
            nu: Optional[Union[float, Sequence[float]]] = None,
            reduction: Optional[callable] = None,
            weight_decay: float = 0.0,
            **kwargs
        ) -> None:
            # language=rst
            """
            Build a ``Hebbian`` rule and bind ``update`` to the implementation that
            matches the connection type.

            :param connection: The ``AbstractConnection`` whose weights this rule
                modifies.
            :param nu: Single or pair of learning rates for pre- and post-synaptic
                events.
            :param reduction: Method for reducing parameter updates along the batch
                dimension.
            :param weight_decay: Constant multiple to decay weights by on each
                iteration.
            :raises AssertionError: If either population does not record traces.
            :raises NotImplementedError: If the connection type is unsupported.
            """
            super().__init__(connection, nu, reduction, weight_decay, **kwargs)

            both_traced = self.source.traces and self.target.traces
            assert (
                both_traced
            ), "Both pre- and post-synaptic nodes must record spike traces."

            # Choose the update routine for this connection type.
            handler = None
            if isinstance(connection, (Connection, LocalConnection)):
                handler = self._connection_update
            elif isinstance(connection, Conv2dConnection):
                handler = self._conv2d_connection_update

            if handler is None:
                raise NotImplementedError(
                    "This learning rule is not supported for this Connection type."
                )
            self.update = handler

        def _connection_update(self, **kwargs) -> None:
            # language=rst
            """
            Hebbian update for ``Connection`` / ``LocalConnection`` objects.

            Adds ``nu[0]`` times the (pre-spike x post-trace) product and ``nu[1]``
            times the (pre-trace x post-spike) product, each reduced over the batch
            dimension, then runs the base-class ``update``.
            """
            n_batch = self.source.batch_size

            # Column vectors for the source side, row vectors for the target side,
            # so ``bmm`` yields one outer product per sample.
            pre_spikes = self.source.s.view(n_batch, -1).unsqueeze(2).float()
            pre_trace = self.source.x.view(n_batch, -1).unsqueeze(2)
            post_spikes = self.target.s.view(n_batch, -1).unsqueeze(1).float()
            post_trace = self.target.x.view(n_batch, -1).unsqueeze(1)

            # Pre-synaptic contribution.
            pre_term = self.reduction(torch.bmm(pre_spikes, post_trace), dim=0)
            self.connection.w += self.nu[0] * pre_term

            # Post-synaptic contribution.
            post_term = self.reduction(torch.bmm(pre_trace, post_spikes), dim=0)
            self.connection.w += self.nu[1] * post_term

            super().update()

        def _conv2d_connection_update(self, **kwargs) -> None:
            # language=rst
            """
            Hebbian update for ``Conv2dConnection`` objects.

            Source traces and spikes are passed through ``im2col_indices`` and the
            target ones are flattened per output channel; both contributions are
            added to the weights before the base-class ``update`` runs.
            """
            weights = self.connection.w
            out_channels, _, kernel_height, kernel_width = weights.size()
            padding = self.connection.padding
            stride = self.connection.stride
            n_batch = self.source.batch_size

            def to_columns(tensor):
                # Same im2col configuration for traces and spikes.
                return im2col_indices(
                    tensor, kernel_height, kernel_width, padding=padding, stride=stride
                )

            # Bring spike traces and spike occurrences into matrix form.
            pre_trace = to_columns(self.source.x)
            post_trace = self.target.x.view(n_batch, out_channels, -1)
            pre_spikes = to_columns(self.source.s.float())
            post_spikes = self.target.s.view(n_batch, out_channels, -1).float()

            # Pre-synaptic contribution.
            pre_term = self.reduction(
                torch.bmm(post_trace, pre_spikes.permute(0, 2, 1)), dim=0
            )
            self.connection.w += self.nu[0] * pre_term.view(self.connection.w.size())

            # Post-synaptic contribution.
            post_term = self.reduction(
                torch.bmm(post_spikes, pre_trace.permute(0, 2, 1)), dim=0
            )
            self.connection.w += self.nu[1] * post_term.view(self.connection.w.size())

            super().update()

    简单的Hebbian学习规则。突触前后的更新都为正(分别为学习率1 × 突触前脉冲 × 突触后发放迹,以及学习率2 × 突触前发放迹 × 突触后脉冲)。

    6、MSTDP

    class MSTDP(LearningRule):
        # language=rst
        """
        Reward-modulated STDP. Adapted from `(Florian 2007)
        <https://florian.io/papers/2007_Florian_Modulated_STDP.pdf>`_.
        """

        def __init__(
            self,
            connection: AbstractConnection,
            nu: Optional[Union[float, Sequence[float]]] = None,
            reduction: Optional[callable] = None,
            weight_decay: float = 0.0,
            **kwargs
        ) -> None:
            # language=rst
            """
            Constructor for ``MSTDP`` learning rule.

            :param connection: An ``AbstractConnection`` object whose weights the
                ``MSTDP`` learning rule will modify.
            :param nu: Single or pair of learning rates for pre- and post-synaptic
                events, respectively. Only ``nu[0]`` is used by the update methods.
            :param reduction: Method for reducing parameter updates along the
                minibatch dimension.
            :param weight_decay: Constant multiple to decay weights by on each
                iteration.
            :raises NotImplementedError: If the connection type is unsupported.

            Keyword arguments:

            :param tc_plus: Time constant for pre-synaptic firing trace
                (default ``20.0``).
            :param tc_minus: Time constant for post-synaptic firing trace
                (default ``20.0``).
            """
            super().__init__(
                connection=connection,
                nu=nu,
                reduction=reduction,
                weight_decay=weight_decay,
                **kwargs
            )

            # Bind ``update`` to the routine matching the connection type.
            if isinstance(connection, (Connection, LocalConnection)):
                self.update = self._connection_update
            elif isinstance(connection, Conv2dConnection):
                self.update = self._conv2d_connection_update
            else:
                raise NotImplementedError(
                    "This learning rule is not supported for this Connection type."
                )

            self.tc_plus = torch.tensor(kwargs.get("tc_plus", 20.0))
            self.tc_minus = torch.tensor(kwargs.get("tc_minus", 20.0))

        def _connection_update(self, **kwargs) -> None:
            # language=rst
            """
            MSTDP learning rule for ``Connection`` subclass of ``AbstractConnection``
            class.

            Keyword arguments:

            :param Union[float, torch.Tensor] reward: Reward signal from
                reinforcement learning task (required).
            :param float a_plus: Amount added to ``P^+`` per pre-synaptic spike
                (default ``1.0``).
            :param float a_minus: Amount added to ``P^-`` per post-synaptic spike
                (default ``-1.0``).
            :raises KeyError: If ``reward`` is not supplied.
            """
            batch_size = self.source.batch_size

            # Lazily initialize eligibility, P^+, and P^- on the first call.
            if not hasattr(self, "p_plus"):
                self.p_plus = torch.zeros(
                    batch_size, *self.source.shape, device=self.source.s.device
                )
            if not hasattr(self, "p_minus"):
                self.p_minus = torch.zeros(
                    batch_size, *self.target.shape, device=self.target.s.device
                )
            if not hasattr(self, "eligibility"):
                self.eligibility = torch.zeros(
                    batch_size, *self.connection.w.shape, device=self.connection.w.device
                )

            # Reshape pre- and post-synaptic spikes.
            source_s = self.source.s.view(batch_size, -1).float()
            target_s = self.target.s.view(batch_size, -1).float()

            # Parse keyword arguments.
            reward = kwargs["reward"]
            a_plus = torch.tensor(
                kwargs.get("a_plus", 1.0), device=self.connection.w.device
            )
            a_minus = torch.tensor(
                kwargs.get("a_minus", -1.0), device=self.connection.w.device
            )

            # Compute weight update based on the eligibility value of the past
            # timestep.
            update = reward * self.eligibility
            self.connection.w += self.nu[0] * self.reduction(update, dim=0)

            # Update P^+ and P^- values: exponential decay plus spike increments.
            self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus)
            self.p_plus += a_plus * source_s
            self.p_minus *= torch.exp(-self.connection.dt / self.tc_minus)
            self.p_minus += a_minus * target_s

            # Calculate point eligibility value (one outer-product sum per sample).
            self.eligibility = torch.bmm(
                self.p_plus.unsqueeze(2), target_s.unsqueeze(1)
            ) + torch.bmm(source_s.unsqueeze(2), self.p_minus.unsqueeze(1))

            super().update()

        def _conv2d_connection_update(self, **kwargs) -> None:
            # language=rst
            """
            MSTDP learning rule for ``Conv2dConnection`` subclass of
            ``AbstractConnection`` class.

            Keyword arguments:

            :param Union[float, torch.Tensor] reward: Reward signal from
                reinforcement learning task (required).
            :param float a_plus: Amount added to ``P^+`` per pre-synaptic spike
                (default ``1.0``).
            :param float a_minus: Amount added to ``P^-`` per post-synaptic spike
                (default ``-1.0``).
            :raises KeyError: If ``reward`` is not supplied.
            """
            batch_size = self.source.batch_size

            # Lazily initialize eligibility with a leading batch dimension.
            if not hasattr(self, "eligibility"):
                self.eligibility = torch.zeros(
                    batch_size, *self.connection.w.shape, device=self.connection.w.device
                )

            # Parse keyword arguments.
            reward = kwargs["reward"]
            a_plus = torch.tensor(
                kwargs.get("a_plus", 1.0), device=self.connection.w.device
            )
            a_minus = torch.tensor(
                kwargs.get("a_minus", -1.0), device=self.connection.w.device
            )

            # Compute weight update based on the eligibility value of the past
            # timestep; dim 0 is the batch dimension.
            update = reward * self.eligibility
            self.connection.w += self.nu[0] * torch.sum(update, dim=0)

            out_channels, _, kernel_height, kernel_width = self.connection.w.size()
            padding, stride = self.connection.padding, self.connection.stride

            # Lazily initialize P^+ and P^- in the same layout as the reshaped
            # spikes below.
            if not hasattr(self, "p_plus"):
                self.p_plus = torch.zeros(
                    batch_size, *self.source.shape, device=self.connection.w.device
                )
                self.p_plus = im2col_indices(
                    self.p_plus, kernel_height, kernel_width, padding=padding, stride=stride
                )
            if not hasattr(self, "p_minus"):
                self.p_minus = torch.zeros(
                    batch_size, *self.target.shape, device=self.connection.w.device
                )
                self.p_minus = self.p_minus.view(batch_size, out_channels, -1).float()

            # Reshaping spike occurrences.
            source_s = im2col_indices(
                self.source.s.float(),
                kernel_height,
                kernel_width,
                padding=padding,
                stride=stride,
            )
            target_s = self.target.s.view(batch_size, out_channels, -1).float()

            # Update P^+ and P^- values. P^+ tracks the influence of pre-synaptic
            # spikes, P^- tracks the influence of post-synaptic spikes.
            self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus)
            self.p_plus += a_plus * source_s
            self.p_minus *= torch.exp(-self.connection.dt / self.tc_minus)
            self.p_minus += a_minus * target_s

            # Calculate point eligibility value.
            self.eligibility = torch.bmm(
                target_s, self.p_plus.permute((0, 2, 1))
            ) + torch.bmm(self.p_minus, source_s.permute((0, 2, 1)))
            # Keep the batch dimension so that the ``torch.sum(..., dim=0)`` on the
            # next call reduces over samples. Viewing as ``w.size()`` alone would
            # drop it, making the next sum run over output channels, and would
            # fail for batch sizes above 1.
            self.eligibility = self.eligibility.view(
                batch_size, *self.connection.w.size()
            )

            super().update()

    奖励调节STDP (R-STDP),改编自(Florian 2007)<https://florian.io/papers/2007_Florian_Modulated_STDP.pdf>

    论文参见:Reinforcement Learning Through Modulation of Spike-Timing-Dependent Synaptic Plasticity - 穷酸秀才大艹包 - 博客园 (cnblogs.com)

    7、MSTDPET

    class MSTDPET(LearningRule):
        # language=rst
        """
        Reward-modulated STDP with eligibility trace. Adapted from
        `(Florian 2007) <https://florian.io/papers/2007_Florian_Modulated_STDP.pdf>`_.
        """
    
        def __init__(
            self,
            connection: AbstractConnection,
            nu: Optional[Union[float, Sequence[float]]] = None,
            reduction: Optional[callable] = None,
            weight_decay: float = 0.0,
            **kwargs
        ) -> None:
            # language=rst
            """
            Build an ``MSTDPET`` rule and bind ``update`` to the implementation that
            matches the connection type.

            :param connection: The ``AbstractConnection`` whose weights this rule
                modifies.
            :param nu: Single or pair of learning rates for pre- and post-synaptic
                events, respectively.
            :param reduction: Method for reducing parameter updates along the
                minibatch dimension.
            :param weight_decay: Constant multiple to decay weights by on each
                iteration.
            :raises NotImplementedError: If the connection type is unsupported.

            Keyword arguments:

            :param float tc_plus: Time constant for pre-synaptic firing trace
                (default ``20.0``).
            :param float tc_minus: Time constant for post-synaptic firing trace
                (default ``20.0``).
            :param float tc_e_trace: Time constant for the eligibility trace
                (default ``25.0``).
            """
            super().__init__(connection, nu, reduction, weight_decay, **kwargs)

            # Choose the update routine for this connection type.
            handler = None
            if isinstance(connection, (Connection, LocalConnection)):
                handler = self._connection_update
            elif isinstance(connection, Conv2dConnection):
                handler = self._conv2d_connection_update

            if handler is None:
                raise NotImplementedError(
                    "This learning rule is not supported for this Connection type."
                )
            self.update = handler

            # Time constants are stored as tensors, read from kwargs with defaults.
            time_constants = (
                ("tc_plus", 20.0),
                ("tc_minus", 20.0),
                ("tc_e_trace", 25.0),
            )
            for attr, default in time_constants:
                setattr(self, attr, torch.tensor(kwargs.get(attr, default)))
    
        def _connection_update(self, **kwargs) -> None:
            # language=rst
            """
            MSTDPET update for ``Connection`` / ``LocalConnection`` objects.

            Spikes are flattened to 1-D vectors and the state tensors carry no batch
            dimension. Steps, in order: decay the eligibility trace and add the
            previous point eligibility, apply the reward-scaled weight change,
            update ``P^+`` / ``P^-``, then recompute the point eligibility.

            Keyword arguments:

            :param Union[float, torch.Tensor] reward: Reward signal from
                reinforcement learning task (required).
            :param float a_plus: Amount added to ``P^+`` per pre-synaptic spike
                (default ``1.0``).
            :param float a_minus: Amount added to ``P^-`` per post-synaptic spike
                (default ``-1.0``).
            """
            weights = self.connection.w

            # Create state tensors on first use.
            if not hasattr(self, "p_plus"):
                self.p_plus = torch.zeros(self.source.n, device=self.source.s.device)
            if not hasattr(self, "p_minus"):
                self.p_minus = torch.zeros(self.target.n, device=self.target.s.device)
            if not hasattr(self, "eligibility"):
                self.eligibility = torch.zeros(weights.shape, device=weights.device)
            if not hasattr(self, "eligibility_trace"):
                self.eligibility_trace = torch.zeros(
                    weights.shape, device=weights.device
                )

            # Flatten pre- and post-synaptic spikes.
            pre_spikes = self.source.s.view(-1).float()
            post_spikes = self.target.s.view(-1).float()

            # Read keyword arguments.
            reward = kwargs["reward"]
            a_plus = torch.tensor(kwargs.get("a_plus", 1.0), device=weights.device)
            a_minus = torch.tensor(kwargs.get("a_minus", -1.0), device=weights.device)

            dt = self.connection.dt

            # Eligibility trace: exponential decay plus the point eligibility from
            # the previous timestep.
            trace_decay = torch.exp(-dt / self.tc_e_trace)
            self.eligibility_trace *= trace_decay
            self.eligibility_trace += self.eligibility / self.tc_e_trace

            # Reward-modulated weight change.
            self.connection.w += self.nu[0] * dt * reward * self.eligibility_trace

            # P^+ and P^-: exponential decay plus spike increments.
            self.p_plus *= torch.exp(-dt / self.tc_plus)
            self.p_plus += a_plus * pre_spikes
            self.p_minus *= torch.exp(-dt / self.tc_minus)
            self.p_minus += a_minus * post_spikes

            # Point eligibility for the next call.
            potentiation = torch.ger(self.p_plus, post_spikes)
            depression = torch.ger(pre_spikes, self.p_minus)
            self.eligibility = potentiation + depression

            super().update()
    
        def _conv2d_connection_update(self, **kwargs) -> None:
            # language=rst
            """
            MSTDPET learning rule for ``Conv2dConnection`` subclass of
            ``AbstractConnection`` class.

            Each call performs one step, in this order: decay the eligibility trace
            and add the point eligibility computed on the previous call, apply the
            reward-modulated weight update, update ``P^+`` / ``P^-`` from the current
            spikes, and recompute the point eligibility for the next call.

            Keyword arguments:

            :param Union[float, torch.Tensor] reward: Reward signal from reinforcement
                learning task. Required; a missing key raises ``KeyError``.
            :param float a_plus: Amount added to ``P^+`` for each pre-synaptic (source)
                spike. Defaults to ``1.0``.
            :param float a_minus: Amount added to ``P^-`` for each post-synaptic
                (target) spike. Defaults to ``-1.0``.
            """
            batch_size = self.source.batch_size

            # Lazily create eligibility and eligibility trace on the first call,
            # on the same device as the weights.
            if not hasattr(self, "eligibility"):
                self.eligibility = torch.zeros(
                    batch_size, *self.connection.w.shape, device=self.connection.w.device
                )
            if not hasattr(self, "eligibility_trace"):
                self.eligibility_trace = torch.zeros(
                    batch_size, *self.connection.w.shape, device=self.connection.w.device
                )

            # Parse keyword arguments.
            reward = kwargs["reward"]
            a_plus = torch.tensor(
                kwargs.get("a_plus", 1.0), device=self.connection.w.device
            )
            a_minus = torch.tensor(
                kwargs.get("a_minus", -1.0), device=self.connection.w.device
            )

            # Calculate value of eligibility trace based on the value
            # of the point eligibility value of the past timestep.
            self.eligibility_trace *= torch.exp(-self.connection.dt / self.tc_e_trace)
            # Accumulate the previous step's point eligibility, exactly as the
            # fully connected update does. Without this the trace starts at zero,
            # is only ever scaled, and therefore the weights never change.
            self.eligibility_trace += self.eligibility / self.tc_e_trace

            # Compute weight update (summed over the leading batch dimension).
            update = reward * self.eligibility_trace
            self.connection.w += self.nu[0] * self.connection.dt * torch.sum(update, dim=0)

            out_channels, _, kernel_height, kernel_width = self.connection.w.size()
            padding, stride = self.connection.padding, self.connection.stride

            # Initialize P^+ and P^- on the first call. P^+ is stored in the same
            # layout ``im2col_indices`` produces for the source spikes, so the two
            # can be combined element-wise below.
            if not hasattr(self, "p_plus"):
                self.p_plus = torch.zeros(
                    batch_size, *self.source.shape, device=self.connection.w.device
                )
                self.p_plus = im2col_indices(
                    self.p_plus, kernel_height, kernel_width, padding=padding, stride=stride
                )
            if not hasattr(self, "p_minus"):
                self.p_minus = torch.zeros(
                    batch_size, *self.target.shape, device=self.connection.w.device
                )
                self.p_minus = self.p_minus.view(batch_size, out_channels, -1).float()

            # Reshaping spike occurrences.
            source_s = im2col_indices(
                self.source.s.float(),
                kernel_height,
                kernel_width,
                padding=padding,
                stride=stride,
            )
            target_s = (
                self.target.s.permute(1, 2, 3, 0).view(batch_size, out_channels, -1).float()
            )

            # Update P^+ and P^- values: exponential decay plus the contribution
            # of the spikes emitted on this step.
            self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus)
            self.p_plus += a_plus * source_s
            self.p_minus *= torch.exp(-self.connection.dt / self.tc_minus)
            self.p_minus += a_minus * target_s

            # Calculate point eligibility value.
            self.eligibility = torch.bmm(
                target_s, self.p_plus.permute((0, 2, 1))
            ) + torch.bmm(self.p_minus, source_s.permute((0, 2, 1)))
            # NOTE(review): this view to the weight shape presumably only succeeds
            # when the batch dimension is 1 -- confirm batched use is not expected.
            self.eligibility = self.eligibility.view(self.connection.w.size())

            super().update()

    带资格迹的奖励调节STDP(R-STDP with eligibility trace),改编自 Florian (2007):<https://florian.io/papers/2007_Florian_Modulated_STDP.pdf>。

    论文参见:Reinforcement Learning Through Modulation of Spike-Timing-Dependent Synaptic Plasticity - 穷酸秀才大艹包 - 博客园 (cnblogs.com)

    8、Rmax

    class Rmax(LearningRule):
        # language=rst
        """
        Reward-modulated plasticity rule obtained from a reward-maximization
        objective. Adapted from `(Vasilaki et al., 2009)
        <https://intranet.physio.unibe.ch/Publikationen/Dokumente/Vasilaki2009PloSComputBio_1.pdf>`_.
        """

        def __init__(
            self,
            connection: AbstractConnection,
            nu: Optional[Union[float, Sequence[float]]] = None,
            reduction: Optional[callable] = None,
            weight_decay: float = 0.0,
            **kwargs
        ) -> None:
            # language=rst
            """
            Set up the ``R-max`` learning rule for one connection.

            :param connection: The ``AbstractConnection`` whose weights this rule
                updates. Only ``Connection`` and ``LocalConnection`` are accepted.
            :param nu: One learning rate, or a pair for pre- and post-synaptic
                events, respectively.
            :param reduction: Callable used to reduce parameter updates over the
                minibatch dimension.
            :param weight_decay: Constant multiple to decay weights by on each
                iteration.

            Keyword arguments:

            :param float tc_c: Time constant trading off naive Hebbian against
                policy gradient learning. Defaults to ``5.0``.
            :param float tc_e_trace: Time constant of the eligibility trace.
                Defaults to ``25.0``.

            :raises AssertionError: If the source nodes do not use additive spike
                traces, or the target nodes are not ``SRM0Nodes``.
            :raises NotImplementedError: If ``connection`` is of an unsupported type.
            """
            super().__init__(
                connection=connection, nu=nu, reduction=reduction,
                weight_decay=weight_decay, **kwargs
            )

            # The update reads the pre-synaptic trace, so it must be recorded.
            assert self.source.traces and self.source.traces_additive, (
                "Pre-synaptic nodes must use additive spike traces."
            )

            # The rule is derived for stochastic SRM neurons only.
            assert isinstance(self.target, SRM0Nodes), (
                "R-max needs stochastically firing neurons, use SRM0Nodes."
            )

            # Bind the update routine; reject every other connection type.
            supported_types = (Connection, LocalConnection)
            if not isinstance(connection, supported_types):
                raise NotImplementedError(
                    "This learning rule is not supported for this Connection type."
                )
            self.update = self._connection_update

            # tc_c: 0 for pure naive Hebbian, inf for pure policy gradient.
            balance_tc = kwargs.get("tc_c", 5.0)
            self.tc_c = torch.tensor(balance_tc)
            trace_tc = kwargs.get("tc_e_trace", 25.0)
            self.tc_e_trace = torch.tensor(trace_tc)

        def _connection_update(self, **kwargs) -> None:
            # language=rst
            """
            Perform one R-max step for a ``Connection`` subclass of
            ``AbstractConnection``: decay the eligibility trace, add the new
            contribution from the current spikes, and apply the reward-scaled
            weight change.

            Keyword arguments:

            :param Union[float, torch.Tensor] reward: Reward signal from the
                reinforcement learning task. Required; a missing key raises
                ``KeyError``.
            """
            # Allocate the trace on first use, shaped like the weights.
            if not hasattr(self, "eligibility_trace"):
                w = self.connection.w
                self.eligibility_trace = torch.zeros(*w.shape, device=w.device)

            # Flatten node state to one dimension.
            post_spikes = self.target.s.view(-1).float()
            post_prob = self.target.s_prob.view(-1)
            pre_trace = self.source.x.view(-1)

            reward = kwargs["reward"]
            dt = self.connection.dt

            # Linear decay of the existing trace.
            leak = 1 - dt / self.tc_e_trace
            self.eligibility_trace *= leak

            # New contribution: post-synaptic spikes minus a term derived from
            # ``s_prob``, scaled per pre-synaptic trace value (one row per
            # source entry).
            baseline = post_prob / (1.0 + self.tc_c / dt * post_prob)
            self.eligibility_trace += (post_spikes - baseline) * pre_trace.unsqueeze(1)

            # Reward-modulated weight change.
            self.connection.w += self.nu[0] * reward * self.eligibility_trace

            super().update()

    基于奖励最大化原理的奖励调节学习规则,改编自 Vasilaki et al. (2009):<https://intranet.physio.unibe.ch/Publikationen/Dokumente/Vasilaki2009PloSComputBio_1.pdf>。

  • 相关阅读:
    html input type=file 选择图片,图片预览 纯html js实现图片预览
    asp.net mvc Controller控制器返回类型
    webrequest HttpWebRequest webclient/HttpClient
    js中__proto__和prototype constructor 的区别和关系
    JQuery的ajaxFileUpload的使用
    cuda中当数组数大于线程数的处理方法
    cuda中threadIdx、blockIdx、blockDim和gridDim的使用
    cuda和gcc版本不兼容
    【转】CentOS 6.6 升级GCC G++ (当前最新版本为v6.1.0) (完整)
    matlab练习程序(地图上画经纬度)
  • 原文地址:https://www.cnblogs.com/lucifer1997/p/14313716.html
Copyright © 2020-2023  润新知