Bmm
Note
This function does not broadcast. For broadcasting matrix products, see torch_matmul.
bmm(input, mat2, out=NULL) -> Tensor
Performs a batch matrix-matrix product of the matrices stored in input and mat2.
input and mat2 must be 3-D tensors, each containing the same number of matrices.
If input is a \((b \times n \times m)\) tensor and mat2 is a \((b \times m \times p)\) tensor, out will be a \((b \times n \times p)\) tensor.
$$ \mbox{out}_i = \mbox{input}_i \mathbin{@} \mbox{mat2}_i $$
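As a cross-check on the shape rules and the per-batch formula above, here is a minimal base-R sketch that computes the same product with an explicit loop over the batch dimension. `bmm_reference()` is a hypothetical helper written only for illustration; it is not part of torch, and it is not how torch_bmm is implemented. It assumes plain R arrays with the batch dimension first, mirroring the \((b \times n \times m)\) layout.

```r
# Hypothetical reference helper (not part of torch): computes
# out[i, , ] = input[i, , ] %*% mat2[i, , ] for every batch index i.
bmm_reference <- function(input, mat2) {
  di <- dim(input)
  dm <- dim(mat2)
  # No broadcasting: both operands must be 3-D ...
  stopifnot(length(di) == 3, length(dm) == 3)
  # ... with the same batch size b and a matching inner dimension m
  stopifnot(di[1] == dm[1], di[3] == dm[2])
  b <- di[1]; n <- di[2]; m <- di[3]; p <- dm[3]
  out <- array(0, dim = c(b, n, p))
  for (i in seq_len(b)) {
    # matrix() restores the n x m and m x p shapes even when n, m or p is 1
    a <- matrix(input[i, , ], nrow = n, ncol = m)
    w <- matrix(mat2[i, , ], nrow = m, ncol = p)
    out[i, , ] <- a %*% w
  }
  out
}
```

For example, with input of dim c(10, 3, 4) and mat2 of dim c(10, 4, 5), bmm_reference() returns an array of dim c(10, 3, 5), the same shape as the CPUFloatType{10,3,5} result in the example output.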
Examples
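if (torch_is_installed()) {
# A batch of 10 matrices of size 3 x 4 and a batch of 10 matrices of size 4 x 5
input <- torch_randn(c(10, 3, 4))
mat2 <- torch_randn(c(10, 4, 5))
# Multiply the i-th matrix of input by the i-th matrix of mat2; the result is 10 x 3 x 5
res <- torch_bmm(input, mat2)
res
}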
#> torch_tensor
#> (1,.,.) =
#> -1.7660 -0.0806 1.0352 0.0796 -2.2862
#> -0.3558 0.7149 -0.2041 0.3123 -0.6912
#> -0.4830 -0.3687 -1.6232 -0.0344 -1.8763
#>
#> (2,.,.) =
#> 0.2255 -0.6348 1.8854 0.4987 0.1007
#> 0.2278 -0.5640 -3.6552 -1.0890 1.0498
#> -0.9328 0.1237 2.2589 0.3648 -0.4633
#>
#> (3,.,.) =
#> -2.1111 2.2743 0.1642 -2.6410 0.1674
#> 1.3347 -1.0351 -0.2361 1.0920 -1.3609
#> 1.5108 0.2095 -0.9747 1.8780 -2.8055
#>
#> (4,.,.) =
#> -4.5610 -3.9504 1.9734 -1.3507 1.4226
#> 0.1380 -2.8356 -0.7851 -3.5248 1.1184
#> 1.3115 0.0982 2.3809 3.2303 0.9007
#>
#> (5,.,.) =
#> 1.3222 0.1414 -0.9732 -0.9802 0.6976
#> 0.5068 -0.5360 0.1973 0.6529 0.2519
#> -3.0596 2.3465 -0.8538 -2.5193 2.8394
#>
#> (6,.,.) =
#> 0.3266 -0.1740 1.9548 -0.4432 -5.0411
#> 0.2572 0.0003 -1.2587 -0.2590 -1.2956
#> 2.1741 1.2181 -2.0571 -1.0165 1.0470
#>
#> ... [the output was truncated (use n=-1 to disable)]
#> [ CPUFloatType{10,3,5} ]