Baddbmm
Source: R/gen-namespace-docs.R, R/gen-namespace-examples.R, R/gen-namespace.R
torch_baddbmm.Rd
Baddbmm
Arguments
- self
(Tensor) the tensor to be added (called input in the description and formula below)
- batch1
(Tensor) the first batch of matrices to be multiplied
- batch2
(Tensor) the second batch of matrices to be multiplied
- beta
(Number, optional) multiplier for input (\(\beta\))
- alpha
(Number, optional) multiplier for \(\mbox{batch1} \mathbin{@} \mbox{batch2}\) (\(\alpha\))
baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=NULL) -> Tensor
Performs a batch matrix-matrix product of the matrices in batch1 and batch2 and adds input to the result, with input scaled by beta and the product scaled by alpha.
batch1 and batch2 must be 3-D tensors containing the same number of matrices.
If batch1 is a \((b \times n \times m)\) tensor and batch2 is a \((b \times m \times p)\) tensor, then input must be broadcastable with a \((b \times n \times p)\) tensor, and out will be a \((b \times n \times p)\) tensor. alpha and beta are the same scaling factors used in torch_addbmm.
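The shape rules above can be checked without creating any tensors. The following base-R sketch is illustrative only. baddbmm_result_shape() and broadcastable_to() are hypothetical helpers, not torch functions. broadcastable_to() assumes the usual right-aligned broadcasting rule: each dimension of input must equal the matching result dimension or be 1, and missing leading dimensions are filled in.

```r
# Hypothetical helpers (not part of torch) mirroring the shape rules above.

# Shape c(b, n, p) of the result, given dim(batch1) = c(b, n, m) and
# dim(batch2) = c(b, m, p); stops if the shapes are incompatible.
baddbmm_result_shape <- function(dim1, dim2) {
  if (length(dim1) != 3 || length(dim2) != 3) {
    stop("batch1 and batch2 must both be 3-D")
  }
  if (dim1[1] != dim2[1]) {
    stop("batch1 and batch2 must contain the same number of matrices")
  }
  if (dim1[3] != dim2[2]) {
    stop("inner dimensions differ: batch1 is b x n x m, batch2 must be b x m x p")
  }
  c(dim1[1], dim1[2], dim2[3])
}

# TRUE if dims `from` can be expanded to dims `to`: dims are aligned from the
# right, and each dimension of `from` must equal its partner or be 1.
broadcastable_to <- function(from, to) {
  if (length(from) > length(to)) return(FALSE)
  all(from == utils::tail(to, length(from)) | from == 1)
}

target <- baddbmm_result_shape(c(10, 3, 4), c(10, 4, 5))
target
#> [1] 10  3  5
broadcastable_to(c(10, 3, 5), target)  # same shape as the result
#> [1] TRUE
broadcastable_to(c(3, 1), target)      # expanded over batches and columns
#> [1] TRUE
broadcastable_to(c(2, 5), target)      # 2 cannot be expanded to 3
#> [1] FALSE
```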
$$
\mbox{out}_i = \beta\ \mbox{input}_i + \alpha\ (\mbox{batch1}_i \mathbin{@} \mbox{batch2}_i)
$$
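The following minimal base-R sketch makes the formula concrete by evaluating it one batch at a time with %*%. It is not how torch computes baddbmm, and baddbmm_ref() is a hypothetical name. The sketch assumes batches are stored along the first dimension of ordinary R arrays. It supports only three forms of input: a full \((b \times n \times p)\) array, an \(n \times p\) matrix reused for every batch, or a single number.

```r
# Base-R reference for the formula above; a sketch, not torch's implementation.
# Assumed layout: batch1 is an array with dim c(b, n, m), batch2 has dim
# c(b, m, p); input is a c(b, n, p) array, an n x p matrix shared by every
# batch, or a single number (other broadcasting cases are not handled).
baddbmm_ref <- function(input, batch1, batch2, beta = 1, alpha = 1) {
  d1 <- dim(batch1)
  d2 <- dim(batch2)
  stopifnot(length(d1) == 3, length(d2) == 3,
            d1[1] == d2[1], d1[3] == d2[2])
  b <- d1[1]; n <- d1[2]; m <- d1[3]; p <- d2[3]
  out <- array(0, dim = c(b, n, p))
  for (i in seq_len(b)) {
    # Rebuild the i-th matrices explicitly so size-1 dims are not dropped
    a_i <- matrix(batch1[i, , ], nrow = n, ncol = m)
    b_i <- matrix(batch2[i, , ], nrow = m, ncol = p)
    in_i <- if (length(dim(input)) == 3) {
      matrix(input[i, , ], nrow = n, ncol = p)
    } else {
      input  # n x p matrix or scalar, reused for every batch
    }
    # out_i = beta * input_i + alpha * (batch1_i %*% batch2_i)
    out[i, , ] <- beta * in_i + alpha * (a_i %*% b_i)
  }
  out
}

# One 2 x 2 batch: [1 2; 3 4] times the identity, plus the scalar input 10
batch1 <- array(c(1, 3, 2, 4), dim = c(1, 2, 2))
batch2 <- array(c(1, 0, 0, 1), dim = c(1, 2, 2))
baddbmm_ref(10, batch1, batch2, alpha = 2)[1, , ]
#>      [,1] [,2]
#> [1,]   12   14
#> [2,]   16   18
```

Because the loop calls %*% once per batch, this sketch is only suited to checking small cases by hand.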
For inputs of type FloatTensor or DoubleTensor, beta and alpha must be real numbers; for other types, they should be integers.
Examples
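if (torch_is_installed()) {
  # input: 10 matrices of size 3 x 5
  M <- torch_randn(c(10, 3, 5))
  # 10 matrices of size 3 x 4 and 10 matrices of size 4 x 5
  batch1 <- torch_randn(c(10, 3, 4))
  batch2 <- torch_randn(c(10, 4, 5))
  # result has shape (10, 3, 5); beta and alpha default to 1
  torch_baddbmm(M, batch1, batch2)
}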
#> torch_tensor
#> (1,.,.) =
#> -0.3125 0.4709 -6.0409 -0.8507 1.4647
#> -3.6003 0.0514 -0.3210 0.2408 -0.2266
#> -0.3610 -4.2665 0.7647 3.7500 0.7592
#>
#> (2,.,.) =
#> 2.8620 -4.8879 -2.4154 1.9197 1.1557
#> -0.9771 -2.0414 -2.7588 2.0003 0.1726
#> -2.5435 0.7578 -0.8908 1.0419 -1.1945
#>
#> (3,.,.) =
#> -1.8110 0.7073 -0.2344 -0.0336 -0.6837
#> -0.9149 -0.8904 -0.0850 0.3731 0.6879
#> 1.8477 0.3263 0.9845 -1.1609 0.4395
#>
#> (4,.,.) =
#> -2.0647 -1.2823 -1.0871 -0.6356 2.8242
#> -0.4440 0.2274 -1.3624 -2.1402 0.8623
#> -2.3443 3.8567 0.1094 4.2228 -7.2962
#>
#> (5,.,.) =
#> 0.0310 0.9185 -2.9944 2.9899 2.5773
#> 0.3958 -0.1560 -2.8712 0.3905 -4.6104
#> 3.8791 -1.3918 -0.7317 -3.3647 -3.5260
#>
#> (6,.,.) =
#> -0.9752 -6.3182 5.5663 3.5570 -6.6013
#> 4.9685 4.1413 -0.8252 -4.7024 3.7466
#> -2.7136 -1.9025 1.8377 2.3034 -1.3702
#>
#> ... [the output was truncated (use n=-1 to disable)]
#> [ CPUFloatType{10,3,5} ]