torch_take_along_dim
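Selects values from input at the 1-dimensional indices given in indices, along the dimension dim. For example, with a 2-D input and dim = 2, element [i, j] of the result is input[i, indices[i, j]].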
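For a matrix, the selection rule can be sketched in base R without torch. This is a minimal illustration under stated assumptions: indices are 1-based (as with R torch), the input is 2-D, and take_along_dim_sketch() is a hypothetical helper written for this page, not part of torch. It handles only dim = 1 or dim = 2 and does no broadcasting.

```r
# Hypothetical helper (not part of torch): take-along-dim for base R matrices.
# Indices are 1-based; idx must match x in the dimension not being selected.
take_along_dim_sketch <- function(x, idx, dim) {
  stopifnot(is.matrix(x), is.matrix(idx), dim %in% c(1, 2))
  if (dim == 1) {
    stopifnot(ncol(idx) == ncol(x))
    # result[i, j] = x[idx[i, j], j]
    picked <- x[cbind(as.vector(idx), as.vector(col(idx)))]
  } else {
    stopifnot(nrow(idx) == nrow(x))
    # result[i, j] = x[i, idx[i, j]]
    picked <- x[cbind(as.vector(row(idx)), as.vector(idx))]
  }
  # picked is in column-major order, matching how matrix() refills it
  matrix(picked, nrow = nrow(idx), ncol = ncol(idx))
}

x <- matrix(c(10, 30, 20, 60, 40, 50), nrow = 2)  # rows (10, 20, 40), (30, 60, 50)
idx <- matrix(c(1, 1, 2, 3, 3, 2), nrow = 2)       # rows (1, 2, 3), (1, 3, 2)
take_along_dim_sketch(x, idx, dim = 2)            # rows (10, 20, 40), (30, 50, 60)
```

The index matrix may differ from x in size along the selected dimension, which is what lets a single column of positions pick one value per row.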
Note
If dim is NULL, the input tensor is treated as if it had been flattened to 1-D, and indices are read as positions in that flattened tensor.
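With dim = NULL, the flattening follows torch's row-major (C-style) order, which differs from the column-major order in which R stores matrices. A minimal base R sketch of this case, assuming 1-based positions; take_flat_sketch() is a hypothetical helper, not a torch function:

```r
# Hypothetical helper (not part of torch): the dim = NULL case for a matrix.
# t(x) before as.vector() yields row-major order, like torch's flattening.
take_flat_sketch <- function(x, idx) {
  flat <- as.vector(t(x))
  flat[as.vector(idx)]
}

x <- matrix(c(10, 30, 20, 60, 40, 50), nrow = 2)  # rows (10, 20, 40), (30, 60, 50)
as.vector(t(x))                                   # 10 20 40 30 60 50
pos <- which.max(as.vector(t(x)))                 # 5, position of the maximum
take_flat_sketch(x, pos)                          # 60
```

Calling as.vector(x) directly would read the matrix column by column (10 30 20 60 40 50), so positions computed that way would not line up with torch's flattened order.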
Functions that return indices along a dimension, like torch_argmax() and torch_argsort(), are designed to work with this function. See the examples below.
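The pairing works because the index tensor may have a different size from input along dim, so an index of size 1 along dim picks one value per row. A base R sketch of that idea, assuming 1-based positions; here apply() with which.max stands in for an argmax along dim 2 that keeps a size-1 dimension (the role keepdim = TRUE plays in torch_argmax()):

```r
x <- matrix(c(10, 30, 20, 60, 40, 50), nrow = 2)  # rows (10, 20, 40), (30, 60, 50)

# Per-row position of the maximum, kept as a 2 x 1 index matrix
row_argmax <- matrix(apply(x, 1, which.max), ncol = 1)  # rows 3, 2

# Take along dim 2: result[i, 1] = x[i, row_argmax[i, 1]]
row_max <- matrix(x[cbind(seq_len(nrow(x)), row_argmax[, 1])], ncol = 1)
row_max                                                 # rows 40, 60
```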
Examples
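if (torch_is_installed()) {
  # t is a 2 x 3 tensor with rows (10, 20, 40) and (30, 60, 50)
  t <- torch_tensor(matrix(c(10, 30, 20, 60, 40, 50), nrow = 2))
  # Without dim, argmax indexes the flattened tensor, and take_along_dim()
  # with dim = NULL reads the same flattened view, selecting 60
  max_idx <- torch_argmax(t)
  torch_take_along_dim(t, max_idx)
  # Sort each row: argsort along dim 2, then take along dim 2
  # (only this last value is printed in the output below)
  sorted_idx <- torch_argsort(t, dim = 2)
  torch_take_along_dim(t, sorted_idx, dim = 2)
}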
#> torch_tensor
#> 10 20 40
#> 30 50 60
#> [ CPUFloatType{2,3} ]