Baddbmm
Source: R/gen-namespace-docs.R, R/gen-namespace-examples.R, R/gen-namespace.R (torch_baddbmm.Rd)
Arguments
- self: (Tensor) the tensor to be added (called input in the description below)
- batch1: (Tensor) the first batch of matrices to be multiplied
- batch2: (Tensor) the second batch of matrices to be multiplied
- beta: (Number, optional) multiplier for input (\(\beta\)), default 1
- alpha: (Number, optional) multiplier for \(\mbox{batch1} \mathbin{@} \mbox{batch2}\) (\(\alpha\)), default 1
baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=NULL) -> Tensor
Performs a batch matrix-matrix product of the matrices in batch1 and batch2. input (the self argument) is added to the final result.

batch1 and batch2 must be 3-D tensors, each containing the same number of matrices. If batch1 is a \((b \times n \times m)\) tensor and batch2 is a \((b \times m \times p)\) tensor, then input must be broadcastable with a \((b \times n \times p)\) tensor, and the result will be a \((b \times n \times p)\) tensor. alpha and beta have the same meaning as the scaling factors used in torch_addbmm.
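To illustrate the shape rule, here is a minimal sketch, assuming the torch R package is installed. The shapes (b = 4, n = 2, m = 3, p = 5) and the name bias are illustrative choices, not part of this reference page. A single \((n \times p)\) matrix is broadcast across all b batches, and the first batch of the result is checked against an explicit matrix product.

```r
library(torch)

# Illustrative shapes: b = 4, n = 2, m = 3, p = 5
batch1 <- torch_randn(c(4, 2, 3))  # (b x n x m)
batch2 <- torch_randn(c(4, 3, 5))  # (b x m x p)

# input (self) only has to be broadcastable with (b x n x p):
# one (n x p) bias matrix is reused for every batch.
bias <- torch_randn(c(2, 5))
out <- torch_baddbmm(bias, batch1, batch2)

print(out$shape)  # expected: 4 2 5

# Batch 1 (R torch indexing is 1-based) should equal
# the bias plus the product of the first pair of matrices.
first <- bias + torch_mm(batch1[1, , ], batch2[1, , ])
print(torch_allclose(out[1, , ], first, atol = 1e-5))
```

The tolerance in torch_allclose() allows for small float32 rounding differences between the fused call and the explicit product.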
$$
\mbox{out}_i = \beta\ \mbox{input}_i + \alpha\ (\mbox{batch1}_i \mathbin{@} \mbox{batch2}_i)
$$
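The formula can be checked term by term. The sketch below, assuming the torch R package is installed, compares the fused call with the same computation written out using torch_bmm(); the values beta = 0.5 and alpha = 2 and the seed are arbitrary choices for illustration.

```r
library(torch)

torch_manual_seed(42)  # only to make this sketch reproducible
M <- torch_randn(c(10, 3, 5))
batch1 <- torch_randn(c(10, 3, 4))
batch2 <- torch_randn(c(10, 4, 5))

beta <- 0.5
alpha <- 2

# Fused call
fused <- torch_baddbmm(M, batch1, batch2, beta = beta, alpha = alpha)

# The formula spelled out:
# out_i = beta * input_i + alpha * (batch1_i @ batch2_i)
manual <- beta * M + alpha * torch_bmm(batch1, batch2)

# Float32 rounding may differ slightly between the two paths,
# so compare with a tolerance rather than exact equality.
print(torch_allclose(fused, manual, rtol = 1e-4, atol = 1e-5))
```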
For inputs of type FloatTensor or DoubleTensor, the arguments beta and alpha must be real numbers; otherwise they should be integers.
Examples
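library(torch)

if (torch_is_installed()) {
  M <- torch_randn(c(10, 3, 5))       # tensor added to the result (input)
  batch1 <- torch_randn(c(10, 3, 4))  # 10 matrices of shape 3 x 4
  batch2 <- torch_randn(c(10, 4, 5))  # 10 matrices of shape 4 x 5
  torch_baddbmm(M, batch1, batch2)    # result has shape 10 x 3 x 5
}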
#> torch_tensor
#> (1,.,.) =
#> -2.2340 1.7635 0.9479 2.3790 -1.2281
#> -0.3416 -0.7033 4.4765 -0.5124 -2.9818
#> 6.6451 -1.7640 4.7803 1.0714 -5.2186
#>
#> (2,.,.) =
#> -0.4415 0.7622 1.4382 2.0461 -1.9238
#> -2.2134 -1.0698 2.0950 -0.5376 -0.3850
#> 2.8008 -1.2804 1.0730 0.8620 1.8867
#>
#> (3,.,.) =
#> -3.1838 -2.5005 1.1407 3.6254 3.3644
#> -0.6580 1.4790 -0.6131 -2.4125 -0.6438
#> -2.1321 1.3556 -1.0501 3.4206 -0.0329
#>
#> (4,.,.) =
#> -0.0207 3.5412 6.3124 3.5410 -2.5211
#> 2.4532 -0.1371 -2.2309 -0.2358 -3.9424
#> -0.3182 0.5569 -0.4766 -1.1889 -2.4155
#>
#> (5,.,.) =
#> -2.0903 2.1367 -0.1270 -1.3680 3.3815
#> 0.7117 -0.3244 -0.5709 2.5757 -1.6880
#> 2.3461 -1.0594 -0.3505 0.2868 -3.2575
#>
#> (6,.,.) =
#> -0.3449 -2.8094 4.7409 -0.9218 -0.3620
#> 2.1505 -0.3360 1.3940 2.2367 -5.4588
#> 0.1223 2.8051 1.0574 -1.2303 1.6856
#>
#> ... [the output was truncated (use n=-1 to disable)]
#> [ CPUFloatType{10,3,5} ]