Bmm
Note
This function does not broadcast.
For broadcasting matrix products, see torch_matmul.
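To illustrate the note above, here is a minimal sketch contrasting the two functions. It assumes the torch package is installed and loaded; the names batch, single, expanded, broadcasted, and explicit are illustrative only and are not part of the API.

```r
library(torch)

# A batch of ten 3 x 4 matrices and a single 4 x 5 matrix.
batch <- torch_randn(10, 3, 4)
single <- torch_randn(4, 5)

# torch_matmul broadcasts the 2-D matrix across the batch dimension.
broadcasted <- torch_matmul(batch, single)
broadcasted$shape

# torch_bmm does not broadcast: passing the 2-D matrix directly raises an error.
err <- tryCatch(torch_bmm(batch, single), error = function(e) e)
inherits(err, "error")

# To use torch_bmm, build an explicit 10 x 4 x 5 tensor by stacking copies.
expanded <- torch_stack(rep(list(single), 10), dim = 1)
explicit <- torch_bmm(batch, expanded)

# Both routes should agree up to floating-point tolerance.
torch_allclose(broadcasted, explicit, atol = 1e-5)
```

Stacking copies is just one way to make the batch dimension explicit; when one operand is shared across the batch, calling torch_matmul directly is usually simpler.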
bmm(input, mat2, out=NULL) -> Tensor
Performs a batch matrix-matrix product of matrices stored in input and mat2.
input and mat2 must be 3-D tensors each containing the same number of matrices.
If input is a \((b \times n \times m)\) tensor and mat2 is a \((b \times m \times p)\) tensor, out will be a \((b \times n \times p)\) tensor.
$$ \mbox{out}_i = \mbox{input}_i \mathbin{@} \mbox{mat2}_i $$
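The formula above says that each slice of the output is the ordinary matrix product of the corresponding slices of the inputs. The sketch below checks this slice by slice, assuming the torch package is installed and loaded; the names slices and manual are illustrative only.

```r
library(torch)

input <- torch_randn(10, 3, 4)  # b = 10, n = 3, m = 4
mat2 <- torch_randn(10, 4, 5)   # b = 10, m = 4, p = 5

# Batched product: the result has shape b x n x p.
out <- torch_bmm(input, mat2)
out$shape

# Compute each slice with a plain 2-D product, then stack along dimension 1.
slices <- lapply(1:10, function(i) torch_mm(input[i, , ], mat2[i, , ]))
manual <- torch_stack(slices, dim = 1)

# The batched and slice-by-slice results should agree up to tolerance.
torch_allclose(out, manual, atol = 1e-5)
```

The loop is only for illustration; torch_bmm performs the same computation in a single call.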
Examples
if (torch_is_installed()) {
# Ten 3 x 4 matrices multiplied by ten 4 x 5 matrices
input <- torch_randn(c(10, 3, 4))
mat2 <- torch_randn(c(10, 4, 5))
# The result holds ten 3 x 5 matrices
res <- torch_bmm(input, mat2)
res
}
#> torch_tensor
#> (1,.,.) =
#> 0.3372 -0.9793 -0.9178 1.6752 -3.9694
#> 0.6731 0.4739 0.6238 -0.8550 1.3649
#> 1.2408 -1.9492 4.0812 -0.5648 1.4745
#>
#> (2,.,.) =
#> 0.4185 -1.6425 0.8958 -0.5948 -0.4059
#> 0.3336 -0.8206 1.6552 1.8691 -3.3935
#> -1.9448 -1.6547 1.8130 -3.0308 0.8444
#>
#> (3,.,.) =
#> -1.7618 -0.0694 0.8885 -0.2286 4.5873
#> 2.8709 1.3536 0.6855 1.7678 -1.4571
#> -1.9785 2.2843 2.3418 -2.1916 5.1376
#>
#> (4,.,.) =
#> 1.0374 1.4050 0.1721 0.9362 1.3186
#> -1.2310 0.5466 3.3942 2.2200 -3.1539
#> 1.9582 -0.0251 -0.6099 -0.8545 -0.6797
#>
#> (5,.,.) =
#> 0.7116 1.7785 -0.6294 4.2205 -0.1129
#> 0.2466 0.0895 -1.5727 5.1407 0.3615
#> 0.1041 -1.8901 1.4343 -5.1313 -0.3243
#>
#> (6,.,.) =
#> -0.9272 -0.6083 1.1190 2.8363 -0.7331
#> -0.1927 0.3642 0.0112 0.2584 0.7779
#> 0.9377 -0.2305 -2.8036 -3.4465 -0.5447
#>
#> ... [the output was truncated (use n=-1 to disable)]
#> [ CPUFloatType{10,3,5} ]