# Source code for torch.nn.modules.activation
import warnings
from typing import Optional, Tuple
import torch
from torch import Tensor
from .linear import _LinearWithBias
from torch.nn.init import xavier_uniform_
from torch.nn.init import constant_
from torch.nn.init import xavier_normal_
from torch.nn.parameter import Parameter
from .module import Module
from .. import functional as F
class Threshold(Module):
    r"""Replace every input element that does not exceed a threshold.

    Elements strictly greater than ``threshold`` pass through unchanged;
    every other element is replaced by ``value``:

    .. math::
        y =
        \begin{cases}
        x, &\text{ if } x > \text{threshold} \\
        \text{value}, &\text{ otherwise }
        \end{cases}

    Args:
        threshold: The value to threshold at
        value: The value to replace with
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    Examples::

        >>> m = nn.Threshold(0.1, 20)
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['threshold', 'value', 'inplace']

    threshold: float
    value: float
    inplace: bool

    def __init__(self, threshold: float, value: float, inplace: bool = False) -> None:
        super().__init__()
        self.threshold, self.value, self.inplace = threshold, value, inplace
        # TODO: check in THNN (if inplace == True, then assert value <= threshold)

    def forward(self, input: Tensor) -> Tensor:
        # The functional form does all of the work.
        return F.threshold(input, self.threshold, self.value, self.inplace)

    def extra_repr(self):
        suffix = ', inplace=True' if self.inplace else ''
        return f'threshold={self.threshold}, value={self.value}{suffix}'
class ReLU(Module):
    r"""Apply the rectified linear unit element-wise.

    .. math::
        \text{ReLU}(x) = (x)^+ = \max(0, x)

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/ReLU.png

    Examples::

        >>> m = nn.ReLU()
        >>> input = torch.randn(2)
        >>> output = m(input)

      An implementation of CReLU - https://arxiv.org/abs/1603.05201

        >>> m = nn.ReLU()
        >>> input = torch.randn(2).unsqueeze(0)
        >>> output = torch.cat((m(input),m(-input)))
    """

    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.relu(input, inplace=self.inplace)

    def extra_repr(self) -> str:
        # An empty string keeps the default repr as plain ``ReLU()``.
        return 'inplace=True' if self.inplace else ''
class RReLU(Module):
    r"""Applies the randomized leaky rectified linear unit function, element-wise,
    as described in the paper:

    `Empirical Evaluation of Rectified Activations in Convolutional Network`_.

    The function is defined as:

    .. math::
        \text{RReLU}(x) =
        \begin{cases}
            x & \text{if } x \geq 0 \\
            ax & \text{ otherwise }
        \end{cases}

    where :math:`a` is randomly sampled from uniform distribution
    :math:`\mathcal{U}(\text{lower}, \text{upper})`. The module's ``training``
    flag is forwarded to :func:`torch.nn.functional.rrelu`, so ``train()`` /
    ``eval()`` select the functional implementation's mode.

    See: https://arxiv.org/pdf/1505.00853.pdf

    Args:
        lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`
        upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    Examples::

        >>> m = nn.RReLU(0.1, 0.3)
        >>> input = torch.randn(2)
        >>> output = m(input)

    .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
        https://arxiv.org/abs/1505.00853
    """

    __constants__ = ['lower', 'upper', 'inplace']

    lower: float
    upper: float
    inplace: bool

    def __init__(
        self,
        lower: float = 1. / 8,
        upper: float = 1. / 3,
        inplace: bool = False
    ) -> None:
        super(RReLU, self).__init__()
        self.lower = lower
        self.upper = upper
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        # ``self.training`` is passed through so the train/eval mode of the
        # module decides how the functional form picks the negative slope.
        return F.rrelu(input, self.lower, self.upper, self.training, self.inplace)

    def extra_repr(self) -> str:
        inplace_str = ', inplace=True' if self.inplace else ''
        return 'lower={}, upper={}{}'.format(self.lower, self.upper, inplace_str)
class Hardtanh(Module):
    r"""Applies the HardTanh function element-wise.

    HardTanh is defined as:

    .. math::
        \text{HardTanh}(x) = \begin{cases}
            1 & \text{ if } x > 1 \\
            -1 & \text{ if } x < -1 \\
            x & \text{ otherwise } \\
        \end{cases}

    The range of the linear region :math:`[-1, 1]` can be adjusted using
    :attr:`min_val` and :attr:`max_val`.

    Args:
        min_val: minimum value of the linear region range. Default: -1
        max_val: maximum value of the linear region range. Default: 1
        inplace: can optionally do the operation in-place. Default: ``False``

    Keyword arguments :attr:`min_value` and :attr:`max_value`
    have been deprecated in favor of :attr:`min_val` and :attr:`max_val`.
    Passing either one emits a ``FutureWarning`` and overrides the
    corresponding new-style argument.

    Raises:
        AssertionError: if ``max_val`` is not strictly greater than ``min_val``.

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/Hardtanh.png

    Examples::

        >>> m = nn.Hardtanh(-2, 2)
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['min_val', 'max_val', 'inplace']

    min_val: float
    max_val: float
    inplace: bool

    def __init__(
        self,
        min_val: float = -1.,
        max_val: float = 1.,
        inplace: bool = False,
        min_value: Optional[float] = None,
        max_value: Optional[float] = None
    ) -> None:
        super(Hardtanh, self).__init__()
        # FutureWarning is Python's documented category for deprecations aimed at
        # end users (shown by default); stacklevel=2 attributes the warning to the
        # line that constructs the module instead of to this __init__.
        if min_value is not None:
            warnings.warn("keyword argument min_value is deprecated and rename to min_val",
                          FutureWarning, stacklevel=2)
            min_val = min_value
        if max_value is not None:
            warnings.warn("keyword argument max_value is deprecated and rename to max_val",
                          FutureWarning, stacklevel=2)
            max_val = max_value

        self.min_val = min_val
        self.max_val = max_val
        self.inplace = inplace
        # An explicit raise (instead of ``assert``) keeps this validation active
        # under ``python -O`` while preserving the AssertionError type.
        if not self.max_val > self.min_val:
            raise AssertionError(
                'Hardtanh requires max_val > min_val, got min_val={} and max_val={}'.format(
                    self.min_val, self.max_val))

    def forward(self, input: Tensor) -> Tensor:
        return F.hardtanh(input, self.min_val, self.max_val, self.inplace)

    def extra_repr(self) -> str:
        inplace_str = ', inplace=True' if self.inplace else ''
        return 'min_val={}, max_val={}{}'.format(
            self.min_val, self.max_val, inplace_str
        )
class ReLU6(Hardtanh):
    r"""Clamp every element of the input to the interval :math:`[0, 6]`.

    .. math::
        \text{ReLU6}(x) = \min(\max(0,x), 6)

    This is :class:`Hardtanh` with its linear region fixed to ``[0, 6]``.

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/ReLU6.png

    Examples::

        >>> m = nn.ReLU6()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def __init__(self, inplace: bool = False):
        super().__init__(min_val=0., max_val=6., inplace=inplace)

    def extra_repr(self) -> str:
        # The bounds are fixed, so only the in-place flag is worth showing.
        return 'inplace=True' if self.inplace else ''
class Sigmoid(Module):
    r"""Apply the logistic sigmoid to every element of the input.

    .. math::
        \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/Sigmoid.png

    Examples::

        >>> m = nn.Sigmoid()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        squashed = torch.sigmoid(input)
        return squashed
class Hardsigmoid(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{Hardsigmoid}(x) = \begin{cases}
            0 & \text{if~} x \le -3, \\
            1 & \text{if~} x \ge +3, \\
            x / 6 + 1 / 2 & \text{otherwise}
        \end{cases}

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    Examples::

        >>> m = nn.Hardsigmoid()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['inplace']

    inplace: bool

    def __init__(self, inplace: bool = False) -> None:
        super(Hardsigmoid, self).__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.hardsigmoid(input, self.inplace)

    def extra_repr(self) -> str:
        # Same convention as ReLU / SiLU / SELU: show the flag only when set,
        # so the default repr stays ``Hardsigmoid()``.
        inplace_str = 'inplace=True' if self.inplace else ''
        return inplace_str
class Tanh(Module):
    r"""Apply the hyperbolic tangent to every element of the input.

    .. math::
        \text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)} {\exp(x) + \exp(-x)}

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/Tanh.png

    Examples::

        >>> m = nn.Tanh()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        squashed = torch.tanh(input)
        return squashed
class SiLU(Module):
    r"""Apply the Sigmoid Linear Unit (SiLU) element-wise.

    .. math::
        \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.}

    .. note::
        See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_
        where the SiLU (Sigmoid Linear Unit) was originally coined, and see
        `Sigmoid-Weighted Linear Units for Neural Network Function Approximation
        in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish:
        a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_
        where the SiLU was experimented with later.

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    Examples::

        >>> m = nn.SiLU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.silu(input, inplace=self.inplace)

    def extra_repr(self) -> str:
        return 'inplace=True' if self.inplace else ''
class Hardswish(Module):
    r"""Applies the hardswish function, element-wise, as described in the paper:

    `Searching for MobileNetV3`_.

    .. math::
        \text{Hardswish}(x) = \begin{cases}
            0 & \text{if~} x \le -3, \\
            x & \text{if~} x \ge +3, \\
            x \cdot (x + 3) /6 & \text{otherwise}
        \end{cases}

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    Examples::

        >>> m = nn.Hardswish()
        >>> input = torch.randn(2)
        >>> output = m(input)

    .. _`Searching for MobileNetV3`:
        https://arxiv.org/abs/1905.02244
    """

    __constants__ = ['inplace']

    inplace: bool

    def __init__(self, inplace: bool = False) -> None:
        super(Hardswish, self).__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.hardswish(input, self.inplace)

    def extra_repr(self) -> str:
        # Same convention as ReLU / SiLU / SELU: show the flag only when set,
        # so the default repr stays ``Hardswish()``.
        inplace_str = 'inplace=True' if self.inplace else ''
        return inplace_str
class ELU(Module):
    r"""Apply the Exponential Linear Unit element-wise.

    .. math::
        \text{ELU}(x) = \begin{cases}
        x, & \text{ if } x > 0\\
        \alpha * (\exp(x) - 1), & \text{ if } x \leq 0
        \end{cases}

    Args:
        alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/ELU.png

    Examples::

        >>> m = nn.ELU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['alpha', 'inplace']
    alpha: float
    inplace: bool

    def __init__(self, alpha: float = 1., inplace: bool = False) -> None:
        super().__init__()
        self.alpha, self.inplace = alpha, inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.elu(input, self.alpha, self.inplace)

    def extra_repr(self) -> str:
        suffix = ', inplace=True' if self.inplace else ''
        return f'alpha={self.alpha}{suffix}'
class CELU(Module):
    r"""Apply the Continuously Differentiable Exponential Linear Unit element-wise.

    .. math::
        \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))

    More details can be found in the paper `Continuously Differentiable Exponential Linear Units`_ .

    Args:
        alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/CELU.png

    Examples::

        >>> m = nn.CELU()
        >>> input = torch.randn(2)
        >>> output = m(input)

    .. _`Continuously Differentiable Exponential Linear Units`:
        https://arxiv.org/abs/1704.07483
    """

    __constants__ = ['alpha', 'inplace']
    alpha: float
    inplace: bool

    def __init__(self, alpha: float = 1., inplace: bool = False) -> None:
        super().__init__()
        self.alpha, self.inplace = alpha, inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.celu(input, self.alpha, self.inplace)

    def extra_repr(self) -> str:
        suffix = ', inplace=True' if self.inplace else ''
        return f'alpha={self.alpha}{suffix}'
class SELU(Module):
    r"""Apply the Scaled Exponential Linear Unit element-wise:

    .. math::
        \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))

    with :math:`\alpha = 1.6732632423543772848170429916717` and
    :math:`\text{scale} = 1.0507009873554804934193349852946`.

    More details can be found in the paper `Self-Normalizing Neural Networks`_ .

    Args:
        inplace (bool, optional): can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/SELU.png

    Examples::

        >>> m = nn.SELU()
        >>> input = torch.randn(2)
        >>> output = m(input)

    .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
    """

    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.selu(input, self.inplace)

    def extra_repr(self) -> str:
        return 'inplace=True' if self.inplace else ''
class GLU(Module):
    r"""Apply the gated linear unit :math:`{GLU}(a, b)= a \otimes \sigma(b)`.

    The input is split in two halves along :attr:`dim`; :math:`a` is the first
    half and :math:`b` is the second half.

    Args:
        dim (int): the dimension on which to split the input. Default: -1

    Shape:
        - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`

    Examples::

        >>> m = nn.GLU()
        >>> input = torch.randn(4, 2)
        >>> output = m(input)
    """

    __constants__ = ['dim']
    dim: int

    def __init__(self, dim: int = -1) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, input: Tensor) -> Tensor:
        return F.glu(input, self.dim)

    def extra_repr(self) -> str:
        return f'dim={self.dim}'
class GELU(Module):
    r"""Apply the Gaussian Error Linear Unit element-wise:

    .. math:: \text{GELU}(x) = x * \Phi(x)

    where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/GELU.png

    Examples::

        >>> m = nn.GELU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        activated = F.gelu(input)
        return activated
class Hardshrink(Module):
    r"""Zero out every element whose magnitude does not exceed :math:`\lambda`.

    .. math::
        \text{HardShrink}(x) =
        \begin{cases}
        x, & \text{ if } x > \lambda \\
        x, & \text{ if } x < -\lambda \\
        0, & \text{ otherwise }
        \end{cases}

    Args:
        lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/Hardshrink.png

    Examples::

        >>> m = nn.Hardshrink()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['lambd']
    lambd: float

    def __init__(self, lambd: float = 0.5) -> None:
        super().__init__()
        self.lambd = lambd

    def forward(self, input: Tensor) -> Tensor:
        return F.hardshrink(input, self.lambd)

    def extra_repr(self) -> str:
        return str(self.lambd)
class LeakyReLU(Module):
    r"""Apply a rectifier that lets a scaled fraction of negative inputs through.

    .. math::
        \text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)

    or

    .. math::
        \text{LeakyRELU}(x) =
        \begin{cases}
        x, & \text{ if } x \geq 0 \\
        \text{negative\_slope} \times x, & \text{ otherwise }
        \end{cases}

    Args:
        negative_slope: Controls the angle of the negative slope. Default: 1e-2
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/LeakyReLU.png

    Examples::

        >>> m = nn.LeakyReLU(0.1)
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['inplace', 'negative_slope']
    inplace: bool
    negative_slope: float

    def __init__(self, negative_slope: float = 1e-2, inplace: bool = False) -> None:
        super().__init__()
        self.negative_slope, self.inplace = negative_slope, inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.leaky_relu(input, self.negative_slope, self.inplace)

    def extra_repr(self) -> str:
        suffix = ', inplace=True' if self.inplace else ''
        return f'negative_slope={self.negative_slope}{suffix}'
class LogSigmoid(Module):
    r"""Apply the logarithm of the logistic sigmoid element-wise.

    .. math::
        \text{LogSigmoid}(x) = \log\left(\frac{ 1 }{ 1 + \exp(-x)}\right)

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/LogSigmoid.png

    Examples::

        >>> m = nn.LogSigmoid()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        result = F.logsigmoid(input)
        return result
class Softplus(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))

    SoftPlus is a smooth approximation to the ReLU function and can be used
    to constrain the output of a machine to always be positive.

    For numerical stability the implementation reverts to the linear function
    when :math:`input \times \beta > threshold`.

    Args:
        beta: the :math:`\beta` value for the Softplus formulation. Default: 1
        threshold: values above this revert to a linear function. Default: 20

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/Softplus.png

    Examples::

        >>> m = nn.Softplus()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['beta', 'threshold']

    # Annotated as float: fractional values (e.g. beta=0.5) are legitimate.
    # The integer defaults are kept so existing reprs are unchanged.
    beta: float
    threshold: float

    def __init__(self, beta: float = 1, threshold: float = 20) -> None:
        super(Softplus, self).__init__()
        self.beta = beta
        self.threshold = threshold

    def forward(self, input: Tensor) -> Tensor:
        return F.softplus(input, self.beta, self.threshold)

    def extra_repr(self) -> str:
        return 'beta={}, threshold={}'.format(self.beta, self.threshold)
class Softshrink(Module):
    r"""Shrink every element toward zero by :math:`\lambda`, clamping small ones to zero.

    .. math::
        \text{SoftShrinkage}(x) =
        \begin{cases}
        x - \lambda, & \text{ if } x > \lambda \\
        x + \lambda, & \text{ if } x < -\lambda \\
        0, & \text{ otherwise }
        \end{cases}

    Args:
        lambd: the :math:`\lambda` (must be no less than zero) value for the Softshrink formulation. Default: 0.5

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/Softshrink.png

    Examples::

        >>> m = nn.Softshrink()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['lambd']
    lambd: float

    def __init__(self, lambd: float = 0.5) -> None:
        super().__init__()
        self.lambd = lambd

    def forward(self, input: Tensor) -> Tensor:
        return F.softshrink(input, self.lambd)

    def extra_repr(self) -> str:
        return f'{self.lambd}'
class MultiheadAttention(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O

    where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads. ``embed_dim`` must be divisible by
            ``num_heads``; each head works on ``embed_dim // num_heads`` features.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True. This controls only
            the input-projection bias (``in_proj_bias``); ``out_proj`` is
            constructed the same way regardless of this flag.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
            value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.

    Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set
    to :attr:`embed_dim` such that query, key, and value have the same
    number of features.

    Raises:
        AssertionError: if ``embed_dim`` is not divisible by ``num_heads``.

    Examples::

        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    """

    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        # One packed (3 * embed_dim, embed_dim) projection is only possible when
        # query, key and value share a feature size. bool() normalises the
        # result (e.g. numpy integers yield numpy.bool_) so every later check
        # on this flag agrees.
        self._qkv_same_embed_dim = bool(self.kdim == embed_dim and self.vdim == embed_dim)

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        # Explicit raise rather than ``assert`` so the check survives
        # ``python -O``; exception type and message are unchanged.
        if self.head_dim * num_heads != self.embed_dim:
            raise AssertionError("embed_dim must be divisible by num_heads")

        if not self._qkv_same_embed_dim:
            # Separate projections, because key/value feature sizes differ from
            # embed_dim. The values are overwritten by _reset_parameters().
            self.q_proj_weight = Parameter(torch.empty(embed_dim, embed_dim))
            self.k_proj_weight = Parameter(torch.empty(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.empty(embed_dim, self.vdim))
            self.register_parameter('in_proj_weight', None)
        else:
            self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = _LinearWithBias(embed_dim, embed_dim)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
            self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialise projection weights (Xavier uniform), zero the biases, and
        draw the optional key/value biases from a Xavier normal distribution."""
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0,
        # which predate the separate-projection code path.
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True

        super(MultiheadAttention, self).__setstate__(state)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
                need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. When given a binary mask and a value is True,
            the corresponding value on the attention layer will be ignored. When given
            a byte mask and a value is non-zero, the corresponding value on the attention
            layer will be ignored
        need_weights: output attn_output_weights.
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
            the batches while a 3D mask allows to specify a different mask for the entries of each batch.

    Shapes for inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
          If a ByteTensor is provided, the non-zero positions will be ignored while the position
          with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
          value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
        - attn_mask: if a 2D mask: :math:`(L, S)` where L is the target sequence length, S is the
          source sequence length.

          If a 3D mask: :math:`(N\cdot\text{num\_heads}, L, S)` where N is the batch size, L is the target sequence
          length, S is the source sequence length. ``attn_mask`` ensure that position i is allowed to attend
          the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
          while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
          is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
          is provided, it will be added to the attention weight.

    Shapes for outputs:
        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
          L is the target sequence length, S is the source sequence length.
        """
        # A single call covers both layouts. When the packed projection is used,
        # the q/k/v weights are None, which equals the functional defaults.
        # When separate projections are used, in_proj_weight is None and
        # use_separate_proj_weight tells the functional form to use the q/k/v weights.
        return F.multi_head_attention_forward(
            query, key, value, self.embed_dim, self.num_heads,
            self.in_proj_weight, self.in_proj_bias,
            self.bias_k, self.bias_v, self.add_zero_attn,
            self.dropout, self.out_proj.weight, self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask, need_weights=need_weights,
            attn_mask=attn_mask,
            use_separate_proj_weight=not self._qkv_same_embed_dim,
            q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
            v_proj_weight=self.v_proj_weight)
class PReLU(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{PReLU}(x) = \max(0,x) + a * \min(0,x)

    or

    .. math::
        \text{PReLU}(x) =
        \begin{cases}
        x, & \text{ if } x \geq 0 \\
        ax, & \text{ otherwise }
        \end{cases}

    Here :math:`a` is a learnable parameter. When called without arguments, `nn.PReLU()` uses a single
    parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`,
    a separate :math:`a` is used for each input channel.

    .. note::
        weight decay should not be used when learning :math:`a` for good performance.

    .. note::
        Channel dim is the 2nd dim of input. When input has dims < 2, then there is
        no channel dim and the number of channels = 1.

    Args:
        num_parameters (int): number of :math:`a` to learn.
            Although it takes an int as input, only two values are legitimate:
            1, or the number of channels at input. Default: 1
        init (float): the initial value of :math:`a`. Default: 0.25

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    Attributes:
        weight (Tensor): the learnable weights of shape (:attr:`num_parameters`).

    .. image:: ../scripts/activation_images/PReLU.png

    Examples::

        >>> m = nn.PReLU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['num_parameters']

    num_parameters: int

    def __init__(self, num_parameters: int = 1, init: float = 0.25) -> None:
        # Initialise the Module machinery before assigning any attribute.
        super(PReLU, self).__init__()
        self.num_parameters = num_parameters
        # torch.empty(...).fill_ replaces the legacy torch.Tensor(size)
        # constructor; every element is overwritten with ``init``.
        self.weight = Parameter(torch.empty(num_parameters).fill_(init))

    def forward(self, input: Tensor) -> Tensor:
        return F.prelu(input, self.weight)

    def extra_repr(self) -> str:
        return 'num_parameters={}'.format(self.num_parameters)
class Softsign(Module):
    r"""Apply the soft sign function element-wise.

    .. math::
        \text{SoftSign}(x) = \frac{x}{ 1 + |x|}

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/Softsign.png

    Examples::

        >>> m = nn.Softsign()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        result = F.softsign(input)
        return result
class Tanhshrink(Module):
    r"""Subtract the hyperbolic tangent of each element from the element itself.

    .. math::
        \text{Tanhshrink}(x) = x - \tanh(x)

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/Tanhshrink.png

    Examples::

        >>> m = nn.Tanhshrink()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        result = F.tanhshrink(input)
        return result
class Softmin(Module):
    r"""Applies the Softmin function to an n-dimensional input Tensor
    rescaling them so that the elements of the n-dimensional output Tensor
    lie in the range `[0, 1]` and sum to 1.

    Softmin is defined as:

    .. math::
        \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}

    Shape:
        - Input: :math:`(*)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(*)`, same shape as the input

    Args:
        dim (int): A dimension along which Softmin will be computed (so every slice
            along dim will sum to 1). Pass it explicitly; when left as ``None``
            the choice of dimension is delegated to
            :func:`torch.nn.functional.softmin`.

    Returns:
        a Tensor of the same dimension and shape as the input, with
        values in the range [0, 1]

    Examples::

        >>> m = nn.Softmin(dim=1)
        >>> input = torch.randn(2, 3)
        >>> output = m(input)
    """

    __constants__ = ['dim']

    dim: Optional[int]

    def __init__(self, dim: Optional[int] = None) -> None:
        super(Softmin, self).__init__()
        self.dim = dim

    def __setstate__(self, state):
        # Pickles from before ``dim`` was stored lack the attribute; default it.
        self.__dict__.update(state)
        if not hasattr(self, 'dim'):
            self.dim = None

    def forward(self, input: Tensor) -> Tensor:
        return F.softmin(input, self.dim, _stacklevel=5)

    def extra_repr(self) -> str:
        return 'dim={dim}'.format(dim=self.dim)
class Softmax(Module):
    r"""Applies the Softmax function to an n-dimensional input Tensor
    rescaling them so that the elements of the n-dimensional output Tensor
    lie in the range [0,1] and sum to 1.

    Softmax is defined as:

    .. math::
        \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

    When the input Tensor is a sparse tensor then the unspecified
    values are treated as ``-inf``.

    Shape:
        - Input: :math:`(*)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(*)`, same shape as the input

    Returns:
        a Tensor of the same dimension and shape as the input with
        values in the range [0, 1]

    Args:
        dim (int): A dimension along which Softmax will be computed (so every slice
            along dim will sum to 1).

    .. note::
        This module doesn't work directly with NLLLoss,
        which expects the Log to be computed between the Softmax and itself.
        Use `LogSoftmax` instead (it's faster and has better numerical properties).

    Examples::

        >>> m = nn.Softmax(dim=1)
        >>> input = torch.randn(2, 3)
        >>> output = m(input)
    """

    __constants__ = ['dim']

    dim: Optional[int]

    def __init__(self, dim: Optional[int] = None) -> None:
        super(Softmax, self).__init__()
        self.dim = dim

    def __setstate__(self, state):
        # Pickles from before ``dim`` was stored lack the attribute; default it.
        self.__dict__.update(state)
        if not hasattr(self, 'dim'):
            self.dim = None

    def forward(self, input: Tensor) -> Tensor:
        return F.softmax(input, self.dim, _stacklevel=5)

    def extra_repr(self) -> str:
        return 'dim={dim}'.format(dim=self.dim)
class Softmax2d(Module):
    r"""Applies SoftMax over features to each spatial location.

    When given an image of ``Channels x Height x Width``, it will
    apply `Softmax` to each location :math:`(Channels, h_i, w_j)`

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Returns:
        a Tensor of the same dimension and shape as the input with
        values in the range [0, 1]

    Raises:
        AssertionError: if the input is not a 4D tensor.

    Examples::

        >>> m = nn.Softmax2d()
        >>> # you softmax over the 2nd dimension
        >>> input = torch.randn(2, 3, 12, 13)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        # Raise explicitly instead of ``assert`` so the check is not stripped
        # under ``python -O``; exception type and message are unchanged.
        if input.dim() != 4:
            raise AssertionError('Softmax2d requires a 4D tensor as input')
        # Dimension 1 is the channel axis of an (N, C, H, W) tensor.
        return F.softmax(input, 1, _stacklevel=5)
class LogSoftmax(Module):
    r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional
    input Tensor. The LogSoftmax formulation can be simplified as:

    .. math::
        \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)

    Shape:
        - Input: :math:`(*)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(*)`, same shape as the input

    Args:
        dim (int): A dimension along which LogSoftmax will be computed. Pass it
            explicitly; when left as ``None`` the choice of dimension is
            delegated to :func:`torch.nn.functional.log_softmax`.

    Returns:
        a Tensor of the same dimension and shape as the input with
        values in the range [-inf, 0)

    Examples::

        >>> m = nn.LogSoftmax(dim=1)
        >>> input = torch.randn(2, 3)
        >>> output = m(input)
    """

    __constants__ = ['dim']

    dim: Optional[int]

    def __init__(self, dim: Optional[int] = None) -> None:
        super(LogSoftmax, self).__init__()
        self.dim = dim

    def __setstate__(self, state):
        # Pickles from before ``dim`` was stored lack the attribute; default it.
        self.__dict__.update(state)
        if not hasattr(self, 'dim'):
            self.dim = None

    def forward(self, input: Tensor) -> Tensor:
        return F.log_softmax(input, self.dim, _stacklevel=5)

    def extra_repr(self) -> str:
        return 'dim={dim}'.format(dim=self.dim)