Baddbmm
Source: R/gen-namespace-docs.R, R/gen-namespace-examples.R, R/gen-namespace.R
Arguments
- self
(Tensor) the tensor to be added (called input in the description and formula below)
- batch1
(Tensor) the first batch of matrices to be multiplied
- batch2
(Tensor) the second batch of matrices to be multiplied
- beta
(Number, optional) multiplier for input (\(\beta\))
- alpha
(Number, optional) multiplier for \(\mbox{batch1} \mathbin{@} \mbox{batch2}\) (\(\alpha\))
baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=NULL) -> Tensor
Performs a batch matrix-matrix product of the matrices in batch1 and batch2. input is added to the final result.
batch1 and batch2 must be 3-D tensors, each containing the same number of matrices.

If batch1 is a \((b \times n \times m)\) tensor and batch2 is a \((b \times m \times p)\) tensor, then input must be broadcastable with a \((b \times n \times p)\) tensor, and out will be a \((b \times n \times p)\) tensor. Both alpha and beta have the same meaning as the scaling factors used in torch_addbmm.
$$
\mbox{out}_i = \beta\ \mbox{input}_i + \alpha\ (\mbox{batch1}_i \mathbin{@} \mbox{batch2}_i)
$$
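As a cross-check of the formula, the sketch below evaluates it in plain base R, one batch index at a time. baddbmm_ref is a hypothetical helper written for illustration, not part of torch. It stores each batch as a 3-D array indexed [batch, row, column], checks the shape rules described above, and, to stay short, assumes input already has the full \((b \times n \times p)\) shape instead of broadcasting it.

```r
# Hypothetical base-R reference (not part of torch) for
#   out_i = beta * input_i + alpha * (batch1_i %*% batch2_i)
# Arrays are indexed [batch, row, column]. No broadcasting is done:
# `input` must already have dims (b, n, p).
baddbmm_ref <- function(input, batch1, batch2, beta = 1, alpha = 1) {
  d1 <- dim(batch1)
  d2 <- dim(batch2)
  stopifnot(length(d1) == 3, length(d2) == 3)
  b <- d1[1]; n <- d1[2]; m <- d1[3]; p <- d2[3]
  # Same number of matrices, and compatible inner dimensions
  stopifnot(d2[1] == b, d2[2] == m)
  stopifnot(identical(dim(input), c(b, n, p)))

  out <- array(0, dim = c(b, n, p))
  for (i in seq_len(b)) {
    # matrix() restores the 2-D shape even when n, m or p equals 1
    a_i <- matrix(batch1[i, , ], nrow = n, ncol = m)
    b_i <- matrix(batch2[i, , ], nrow = m, ncol = p)
    x_i <- matrix(input[i, , ], nrow = n, ncol = p)
    out[i, , ] <- beta * x_i + alpha * (a_i %*% b_i)
  }
  out
}

# Usage: one batch where batch1 is the 2 x 2 identity, so the product is batch2
inp <- array(1, dim = c(1, 2, 2))
b1 <- array(c(1, 0, 0, 1), dim = c(1, 2, 2))
b2 <- array(c(1, 2, 3, 4), dim = c(1, 2, 2))
res <- baddbmm_ref(inp, b1, b2, beta = 2, alpha = 3)
res[1, , ]  # 2 * 1 + 3 * batch2_1, i.e. matrix(c(5, 8, 11, 14), 2, 2)
```

A loop over the batch index mirrors the subscript i in the formula directly; it is meant for checking small cases by hand, not for speed.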
For inputs of type FloatTensor or DoubleTensor, the arguments beta and alpha must be real numbers; otherwise they should be integers.
Examples
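library(torch)
if (torch_is_installed()) {
  # input: 10 matrices of shape 3 x 5
  M <- torch_randn(c(10, 3, 5))
  # 10 pairs of matrices to multiply: (3 x 4) times (4 x 5)
  batch1 <- torch_randn(c(10, 3, 4))
  batch2 <- torch_randn(c(10, 4, 5))
  # result: 10 matrices of shape 3 x 5
  torch_baddbmm(M, batch1, batch2)
}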
#> torch_tensor
#> (1,.,.) =
#> 2.7117 -1.1912 4.9084 0.6841 0.2307
#> 1.7016 1.6146 1.3161 -4.9664 -3.2845
#> -0.8329 -1.5403 0.5754 1.2187 0.9986
#>
#> (2,.,.) =
#> -2.5038 -2.2476 1.1316 -0.8145 2.3152
#> 0.5651 -1.3365 -1.0422 1.9366 -3.1478
#> -1.4987 1.8500 -1.7202 -1.1271 -2.6041
#>
#> (3,.,.) =
#> 6.6814 -3.4016 4.2021 1.4307 -1.3391
#> 2.3543 1.0567 0.1357 0.3702 0.5046
#> 1.9748 1.1347 2.6923 -0.6518 -2.0495
#>
#> (4,.,.) =
#> 1.2919 2.4984 -1.3428 2.2040 1.0927
#> 3.6114 -0.1126 -0.5390 -1.1033 2.0112
#> 1.9713 1.4530 -0.3242 1.0933 1.2345
#>
#> (5,.,.) =
#> -1.5820 2.5815 0.2429 2.6317 -0.9560
#> 1.0474 -1.6117 -0.8069 -1.2014 -0.2397
#> -3.0932 1.9581 1.4640 0.9919 -0.7975
#>
#> (6,.,.) =
#> 0.6682 0.3185 0.8619 0.4187 0.2971
#> 2.2023 1.9607 -2.1154 1.8012 0.4424
#> -0.2887 -1.2457 3.0106 1.0587 -0.5133
#>
#> ... [the output was truncated (use n=-1 to disable)]
#> [ CPUFloatType{10,3,5} ]
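The example above uses the default beta = 1 and alpha = 1. One way to check the documented formula, including the broadcasting of input, is to compare torch_baddbmm against the same expression assembled from torch_bmm and elementwise arithmetic. This is a sketch that assumes torch is installed; the seed, the scaling factors 0.5 and 2, the (3 x 5) shape chosen for input and the tolerance are arbitrary choices for illustration, and matches should be TRUE up to floating-point rounding.

```r
library(torch)

if (torch_is_installed()) {
  torch_manual_seed(1)
  # input has shape (3, 5); it is broadcast against the (10, 3, 5) result
  M <- torch_randn(c(3, 5))
  batch1 <- torch_randn(c(10, 3, 4))
  batch2 <- torch_randn(c(10, 4, 5))

  # torch_baddbmm with non-default scaling factors
  res <- torch_baddbmm(M, batch1, batch2, beta = 0.5, alpha = 2)

  # The same quantity built from torch_bmm plus elementwise arithmetic
  manual <- 0.5 * M + 2 * torch_bmm(batch1, batch2)

  print(res$shape)
  matches <- torch_allclose(res, manual, atol = 1e-6)
  print(matches)
}
```

The comparison uses a tolerance rather than exact equality because the fused operation and the two-step expression may round differently in single precision.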