Baddbmm
Source: R/gen-namespace-docs.R, R/gen-namespace-examples.R, R/gen-namespace.R
torch_baddbmm.Rd
Arguments
- self
(Tensor) the tensor to be added; it plays the role of input in the description and formula below
- batch1
(Tensor) the first batch of matrices to be multiplied
- batch2
(Tensor) the second batch of matrices to be multiplied
- beta
(Number, optional) multiplier for input (\(\beta\))
- alpha
(Number, optional) multiplier for \(\mbox{batch1} \mathbin{@} \mbox{batch2}\) (\(\alpha\))
baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=NULL) -> Tensor
Performs a batch matrix-matrix product of matrices in batch1 and batch2. input is added to the final result.
batch1 and batch2 must be 3-D tensors each containing the same number of matrices.
If batch1 is a \((b \times n \times m)\) tensor and batch2 is a \((b \times m \times p)\) tensor, then input must be broadcastable with a \((b \times n \times p)\) tensor, and the output out will be a \((b \times n \times p)\) tensor. Both alpha and beta mean the same as the scaling factors used in torch_addbmm.
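The shape rules above can be checked mechanically. The following is a minimal base-R sketch, not part of torch: `baddbmm_shape()` is a hypothetical helper that takes dimension vectors in the order \((b, n, m)\) and \((b, m, p)\), and it assumes the usual right-aligned broadcasting rule (each trailing dimension of input either equals the matching result dimension or is 1).

```r
# Hypothetical helper (not part of torch): returns the result shape c(b, n, p),
# or stops with an error if the shapes violate the rules for baddbmm.
baddbmm_shape <- function(input_dim, batch1_dim, batch2_dim) {
  if (length(batch1_dim) != 3 || length(batch2_dim) != 3)
    stop("batch1 and batch2 must be 3-D")
  if (batch1_dim[1] != batch2_dim[1])
    stop("batch1 and batch2 must contain the same number of matrices")
  if (batch1_dim[3] != batch2_dim[2])
    stop("inner dimensions differ: batch1 is (b x n x m), so batch2 must be (b x m x p)")
  out_dim <- c(batch1_dim[1], batch1_dim[2], batch2_dim[3])
  if (length(input_dim) > 3)
    stop("input has more dimensions than the result")
  # Assumed broadcasting rule: right-align input against (b, n, p) by
  # left-padding with 1s; every dimension must then match or be 1.
  padded <- c(rep(1, 3 - length(input_dim)), input_dim)
  if (!all(padded == 1 | padded == out_dim))
    stop("input is not broadcastable with a (b x n x p) tensor")
  out_dim
}

baddbmm_shape(c(3, 5), c(10, 3, 4), c(10, 4, 5))  # a (3 x 5) input is reused for all 10 batches
```

Under this rule, a \((3 \times 5)\) input is accepted for batches of shape \((10 \times 3 \times 4)\) and \((10 \times 4 \times 5)\), while a \((2 \times 5)\) input is rejected because 2 is neither 3 nor 1.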
$$
\mbox{out}_i = \beta\ \mbox{input}_i + \alpha\ (\mbox{batch1}_i \mathbin{@} \mbox{batch2}_i)
$$
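To make the formula concrete, here is a plain base-R sketch that evaluates it slice by slice on ordinary arrays, using R's %*% for \(\mathbin{@}\). `baddbmm_ref()` is a hypothetical helper for illustration only; it assumes input already has the full \((b \times n \times p)\) shape (no broadcasting), and it is not how torch computes the result internally.

```r
# Hypothetical reference helper (not part of torch): evaluates
# out_i = beta * input_i + alpha * (batch1_i %*% batch2_i) for each batch index i,
# on plain R arrays of shape (b, n, m), (b, m, p) and (b, n, p).
baddbmm_ref <- function(input, batch1, batch2, beta = 1, alpha = 1) {
  d1 <- dim(batch1)
  d2 <- dim(batch2)
  stopifnot(length(d1) == 3, length(d2) == 3,
            d1[1] == d2[1], d1[3] == d2[2],
            identical(dim(input), c(d1[1], d1[2], d2[3])))
  b <- d1[1]; n <- d1[2]; m <- d1[3]; p <- d2[3]
  out <- array(0, dim = c(b, n, p))
  for (i in seq_len(b)) {
    # matrix() restores each slice's shape even when n, m or p equals 1
    lhs  <- matrix(batch1[i, , ], nrow = n, ncol = m)
    rhs  <- matrix(batch2[i, , ], nrow = m, ncol = p)
    in_i <- matrix(input[i, , ], nrow = n, ncol = p)
    out[i, , ] <- beta * in_i + alpha * (lhs %*% rhs)
  }
  out
}

# Single 1 x 1 "matrix" per batch: 2 * 1 + 3 * (2 * 3) = 20
baddbmm_ref(array(1, c(1, 1, 1)), array(2, c(1, 1, 1)), array(3, c(1, 1, 1)),
            beta = 2, alpha = 3)[1, 1, 1]
```

Because the loop runs over the leading (batch) index, each slice out[i, , ] depends only on the i-th matrices of input, batch1 and batch2. This is what distinguishes baddbmm from torch_addbmm, which sums the products over the batch and returns a single \((n \times p)\) matrix.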
For inputs of type FloatTensor or DoubleTensor, arguments beta and alpha must be real numbers; otherwise they should be integers.
Examples
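library(torch)

if (torch_is_installed()) {
  # 10 batches: M is the added term, batch1 and batch2 are multiplied pairwise
  M <- torch_randn(c(10, 3, 5))
  batch1 <- torch_randn(c(10, 3, 4))
  batch2 <- torch_randn(c(10, 4, 5))
  torch_baddbmm(M, batch1, batch2)  # result has shape (10, 3, 5)
}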
#> torch_tensor
#> (1,.,.) =
#> -0.9133 -0.7796 6.9415 -0.8429 1.6910
#> -1.1649 2.3691 0.9948 -0.0708 2.0124
#> 3.1791 -0.4127 0.7201 -0.5607 1.4573
#>
#> (2,.,.) =
#> -0.4013 0.4887 3.9224 -0.8925 -2.2159
#> -0.9042 -0.8472 -1.3981 -1.6880 1.8332
#> -2.3427 -2.7121 0.6380 -0.6552 -0.2278
#>
#> (3,.,.) =
#> 3.8180 -4.7817 2.1171 1.7672 -0.3649
#> 4.3834 -1.2660 0.2749 1.6570 -0.1524
#> 5.5033 -6.1041 7.7504 3.6278 2.1958
#>
#> (4,.,.) =
#> -0.4392 1.6435 2.9361 -1.2175 3.0783
#> 0.9815 -3.3663 -0.7723 -1.7715 -0.9704
#> -3.3060 1.6847 -2.7370 0.2200 -1.9242
#>
#> (5,.,.) =
#> 0.4078 -0.5546 1.4380 0.9488 -0.9015
#> 1.1546 2.0692 -1.1276 -0.8037 -0.3632
#> -0.0663 0.7661 -0.0254 2.5303 -0.2988
#>
#> (6,.,.) =
#> -2.8348 -1.9627 -0.9978 0.5599 0.0186
#> 2.3897 -2.8633 -1.5330 4.7937 -4.3234
#> 1.0622 0.1465 1.6078 0.1947 0.1206
#>
#> ... [the output was truncated (use n=-1 to disable)]
#> [ CPUFloatType{10,3,5} ]