Applies the element-wise function: $$ \mbox{PReLU}(x) = \max(0,x) + a * \min(0,x) $$ or $$ \mbox{PReLU}(x) = \left\{ \begin{array}{ll} x, & \mbox{ if } x \geq 0 \\ ax, & \mbox{ otherwise } \end{array} \right. $$
Details
Here \(a\) is a learnable parameter. When called without arguments, nn_prelu() uses a single parameter \(a\) across all input channels. If called with nn_prelu(nChannels), a separate \(a\) is used for each input channel.
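As a minimal sketch of the two modes described above, the following compares a module with one shared slope against one with a slope per channel. The 3-channel input shape and the variable names are illustrative assumptions, and the channel count is passed positionally, as in nn_prelu(nChannels).

```r
library(torch)

if (torch_is_installed()) {
  # Default: one slope a shared by every channel
  shared <- nn_prelu()
  # One slope per channel; the input below is assumed to have 3 channels
  per_channel <- nn_prelu(3)

  input <- torch_randn(4, 3, 8)  # (N, C, L) with C = 3

  out_shared <- shared(input)
  out_per_channel <- per_channel(input)

  # Number of learnable slopes held by each module
  n_shared <- shared$weight$numel()
  n_per_channel <- per_channel$weight$numel()
}
```

Inspecting n_shared and n_per_channel shows how many values of \(a\) each module learns; both outputs keep the shape of the input.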
Note
For good performance, weight decay should not be used when learning \(a\).
The channel dimension is the 2nd dimension of the input. When the input has fewer than 2 dimensions, there is no channel dimension and the number of channels is 1.
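To illustrate that the per-channel slopes act along the 2nd dimension, the sketch below sets three distinct slopes by hand and compares the module output with the formula \(\max(0,x) + a * \min(0,x)\) computed manually. The slope values, the input shape, and the names a, manual and matches are assumptions made for illustration only.

```r
library(torch)

if (torch_is_installed()) {
  m <- nn_prelu(3)

  # Give each channel a distinct slope so the per-channel effect is visible
  with_no_grad({
    m$weight$copy_(torch_tensor(c(0.1, 0.2, 0.3)))
  })

  x <- torch_randn(2, 3, 5)  # channels sit on the 2nd dimension

  # Reshape the slopes to (1, 3, 1) so they broadcast over batch and length
  a <- m$weight$detach()$view(c(1, 3, 1))

  # max(0, x) + a * min(0, x), computed by hand
  manual <- torch_clamp(x, min = 0) + a * torch_clamp(x, max = 0)

  # Compare the module output with the manual computation
  matches <- torch_allclose(m(x), manual)
}
```

If the channel dimension is indeed the 2nd one, matches should indicate that the two results agree.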
Shape
Input: \((N, *)\), where * means any number of additional dimensions
Output: \((N, *)\), same shape as the input
Examples
if (torch_is_installed()) {
  # Single shared slope a (the default)
  m <- nn_prelu()
  input <- torch_randn(2)
  # The output has the same shape as the input
  output <- m(input)
}