Baddbmm
Source: R/gen-namespace-docs.R, R/gen-namespace-examples.R, R/gen-namespace.R (torch_baddbmm.Rd)
Arguments
- self
(Tensor) the tensor to be added
- batch1
(Tensor) the first batch of matrices to be multiplied
- batch2
(Tensor) the second batch of matrices to be multiplied
- beta
(Number, optional) multiplier for input (\(\beta\))
- alpha
(Number, optional) multiplier for \(\mbox{batch1} \mathbin{@} \mbox{batch2}\) (\(\alpha\))
baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=NULL) -> Tensor
Performs a batch matrix-matrix product of the matrices in batch1 and batch2, and adds input to the result. batch1 and batch2 must be 3-D tensors containing the same number of matrices.
If batch1 is a \((b \times n \times m)\) tensor and batch2 is a \((b \times m \times p)\) tensor, then input must be broadcastable with a \((b \times n \times p)\) tensor, and out will be a \((b \times n \times p)\) tensor. alpha and beta have the same meaning as the scaling factors in torch_addbmm.
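The shape rules above can be checked without torch. The following is a minimal base-R sketch. The helper name baddbmm_out_shape() is hypothetical and is not part of the torch package. It assumes the usual trailing-dimension broadcasting rule: each dimension of input, aligned from the right, must either equal the target dimension or be 1.

```r
# Hypothetical helper (not a torch function): given the dimension vectors of
# input, batch1 and batch2, return c(b, n, p) or stop with an error.
baddbmm_out_shape <- function(input_dim, batch1_dim, batch2_dim) {
  if (length(batch1_dim) != 3 || length(batch2_dim) != 3)
    stop("batch1 and batch2 must be 3-D")
  b <- batch1_dim[1]; n <- batch1_dim[2]; m <- batch1_dim[3]
  if (batch2_dim[1] != b)
    stop("batch1 and batch2 must contain the same number of matrices")
  if (batch2_dim[2] != m)
    stop("inner dimensions of batch1 and batch2 must match")
  p <- batch2_dim[3]
  target <- c(b, n, p)
  # Assumed broadcasting rule: align input's dimensions with the trailing
  # dimensions of the target; each must equal the target dimension or be 1.
  if (length(input_dim) > 3) stop("input has too many dimensions")
  aligned <- tail(target, length(input_dim))
  if (!all(input_dim == aligned | input_dim == 1))
    stop("input is not broadcastable with c(b, n, p)")
  target
}
```

For example, baddbmm_out_shape(c(3, 5), c(10, 3, 4), c(10, 4, 5)) returns c(10, 3, 5), because a 3 x 5 input is broadcast across the batch dimension.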
$$
\mbox{out}_i = \beta\ \mbox{input}_i + \alpha\ (\mbox{batch1}_i \mathbin{@} \mbox{batch2}_i)
$$
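To make the formula concrete, here is a plain base-R sketch that evaluates it one batch index i at a time. baddbmm_ref() is a hypothetical helper for illustration, not torch's implementation. It assumes input already has the full \((b \times n \times p)\) shape, so it skips broadcasting.

```r
# Hypothetical base-R reference for out_i = beta * input_i + alpha * (batch1_i %*% batch2_i).
# Arrays are indexed [i, row, col]: batch1 is b x n x m, batch2 is b x m x p,
# and input is assumed to already be b x n x p (no broadcasting).
baddbmm_ref <- function(input, batch1, batch2, beta = 1, alpha = 1) {
  b <- dim(batch1)[1]; n <- dim(batch1)[2]; m <- dim(batch1)[3]
  p <- dim(batch2)[3]
  stopifnot(dim(batch2)[1] == b, dim(batch2)[2] == m,
            length(dim(input)) == 3, all(dim(input) == c(b, n, p)))
  out <- array(0, dim = c(b, n, p))
  for (i in seq_len(b)) {
    # matrix() restores the 2-D shape that [i, , ] drops when a dimension is 1
    x <- matrix(batch1[i, , ], nrow = n, ncol = m)
    y <- matrix(batch2[i, , ], nrow = m, ncol = p)
    z <- matrix(input[i, , ], nrow = n, ncol = p)
    out[i, , ] <- beta * z + alpha * (x %*% y)
  }
  out
}
```

Setting beta = 0 reduces the sketch to a plain batched product, which is a quick way to see the two terms of the formula separately.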
For inputs of type FloatTensor or DoubleTensor, the arguments beta and alpha must be real numbers; otherwise they should be integers.
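As a sketch of non-default scaling, the following calls torch_baddbmm() on float tensors with a non-integer beta and compares the result with the formula written out using torch_bmm(). The tensor sizes, seed, and tolerances are arbitrary choices for illustration, and the block is guarded so that it only runs when torch and its backend are installed.

```r
# Runs only when the torch R package and its libtorch backend are installed.
if (requireNamespace("torch", quietly = TRUE) && torch::torch_is_installed()) {
  library(torch)
  torch_manual_seed(1)
  M <- torch_randn(c(4, 2, 3))       # float input, shape (b, n, p)
  batch1 <- torch_randn(c(4, 2, 5))  # shape (b, n, m)
  batch2 <- torch_randn(c(4, 5, 3))  # shape (b, m, p)
  # Float tensors, so beta and alpha may be non-integer real numbers
  out <- torch_baddbmm(M, batch1, batch2, beta = 0.5, alpha = 2)
  # The same quantity written out from the formula with torch_bmm()
  manual <- M * 0.5 + torch_bmm(batch1, batch2) * 2
  matches <- torch_allclose(out, manual, rtol = 1e-4, atol = 1e-5)
  print(matches)
}
```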
Examples
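if (torch_is_installed()) {
# input tensor of shape (10, 3, 5), added to the batched product
M <- torch_randn(c(10, 3, 5))
# 10 matrices of shape 3 x 4 and 10 matrices of shape 4 x 5
batch1 <- torch_randn(c(10, 3, 4))
batch2 <- torch_randn(c(10, 4, 5))
# the result has shape (10, 3, 5)
torch_baddbmm(M, batch1, batch2)
}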
#> torch_tensor
#> (1,.,.) =
#> -4.6742 -0.7577 2.6754 0.8191 -0.7530
#> 4.3995 -1.2198 -2.3838 -1.5699 -0.1833
#> -2.8331 -1.6828 -0.1383 2.6218 0.2945
#>
#> (2,.,.) =
#> -4.0776 2.4758 -2.4423 0.0407 1.0181
#> 0.9602 -0.7632 -1.9180 -1.5237 2.1434
#> -4.3002 -4.8091 -1.5533 1.2892 0.8208
#>
#> (3,.,.) =
#> 1.8408 -0.2855 -1.1812 -0.8825 0.3052
#> -2.4125 -1.0249 2.2428 1.0103 1.4386
#> -0.8179 2.3056 0.6533 -0.3723 0.8718
#>
#> (4,.,.) =
#> 4.3023 10.9624 1.4803 0.7257 -3.6036
#> 2.6867 4.3710 2.7404 -0.6179 -0.0689
#> 3.8225 4.8994 2.5229 0.5117 -1.4507
#>
#> (5,.,.) =
#> 0.0914 -1.7040 0.5731 -3.0537 -2.3071
#> -1.6430 -0.9083 -0.2146 -1.2739 1.5534
#> -0.8725 1.0120 -0.3942 -0.9003 -3.1054
#>
#> (6,.,.) =
#> -1.5916 -2.2534 3.0997 1.6804 3.2010
#> -0.3672 -0.5168 -3.0351 3.4030 5.5081
#> 1.6120 -1.9197 -4.6837 2.6320 1.9936
#>
#> ... [the output was truncated (use n=-1 to disable)]
#> [ CPUFloatType{10,3,5} ]
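Because input only needs to be broadcastable with a \((b \times n \times p)\) tensor, a single \(n \times p\) matrix can be added to every product in the batch. The following minimal sketch uses shapes chosen to match the example above, and the name bias is illustrative. The result should have shape (10, 3, 5).

```r
# Runs only when the torch R package and its libtorch backend are installed.
if (requireNamespace("torch", quietly = TRUE) && torch::torch_is_installed()) {
  library(torch)
  bias <- torch_randn(c(3, 5))        # one 3 x 5 matrix, broadcast over the batch
  batch1 <- torch_randn(c(10, 3, 4))
  batch2 <- torch_randn(c(10, 4, 5))
  out_shape <- torch_baddbmm(bias, batch1, batch2)$shape
  print(out_shape)
}
```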