Selects values from input at the 1-dimensional indices given by indices, along the given dim.
Source: R/gen-namespace-docs.R, R/gen-namespace.R
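To make the indexing concrete, here is a small sketch with made-up tensors x and idx (illustrative values, not from the original page). When dim is given, each output element is read from input at the position that indices supplies along dim. In the underlying LibTorch implementation, take_along_dim broadcasts input and indices outside dim and then calls gather, so when the shapes already line up the result matches torch_gather(). Indices are 1-based, as elsewhere in R torch.

```r
library(torch)

# Illustrative 2 x 3 input; byrow = TRUE so the R layout matches the printed rows
x <- torch_tensor(matrix(c(1, 2, 3,
                           4, 5, 6), nrow = 2, byrow = TRUE))

# 1-based column indices to read in each row (same number of dims as x)
idx <- torch_tensor(matrix(c(3L, 1L,
                             2L, 2L), nrow = 2, byrow = TRUE),
                    dtype = torch_long())

out <- torch_take_along_dim(x, idx, dim = 2)
as_array(out)         # rows: 3 1 and 5 5

# Same result via gather, since idx already lines up with x outside dim 2
via_gather <- torch_gather(x, dim = 2, index = idx)
as_array(via_gather)  # rows: 3 1 and 5 5
```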
Note
If dim is NULL, the input tensor is treated as if it had been flattened to 1-D.
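The flattening in the NULL case follows torch's row-major (C) order, not R's column-major matrix storage. The sketch below, with illustrative values, checks that calling torch_take_along_dim() without dim gives the same result as flattening first with torch_flatten() and then indexing along dim = 1.

```r
library(torch)

# Illustrative 2 x 3 tensor; byrow = TRUE so the R layout matches the printed rows
x <- torch_tensor(matrix(c(10, 20, 40,
                           30, 60, 50), nrow = 2, byrow = TRUE))

# 1-based positions in the row-major flattening (10, 20, 40, 30, 60, 50)
idx <- torch_tensor(c(5L, 1L), dtype = torch_long())

a <- torch_take_along_dim(x, idx)                          # dim = NULL
b <- torch_take_along_dim(torch_flatten(x), idx, dim = 1)  # explicit flatten

as_array(a)  # 60 10
as_array(b)  # 60 10
```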
Functions that return indices along a dimension, like torch_argmax() and torch_argsort(),
are designed to work with this function. See the examples below.
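Beyond the examples below, a reduction such as torch_argmax() can also be paired with an explicit dim, provided keepdim = TRUE so the index tensor keeps the same number of dimensions as the input (LibTorch checks this when dim is given). A sketch with illustrative values, which also uses descending = TRUE in torch_argsort():

```r
library(torch)

x <- torch_tensor(matrix(c(10, 20, 40,
                           30, 60, 50), nrow = 2, byrow = TRUE))

# Column index of each row's maximum, kept as a 2 x 1 tensor
row_max_idx <- torch_argmax(x, dim = 2, keepdim = TRUE)
row_max <- torch_take_along_dim(x, row_max_idx, dim = 2)
as_array(row_max)   # 2 x 1 matrix: 40, 60

# Each row sorted from largest to smallest
desc_idx <- torch_argsort(x, dim = 2, descending = TRUE)
row_desc <- torch_take_along_dim(x, desc_idx, dim = 2)
as_array(row_desc)  # rows: 40 20 10 and 60 50 30
```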
Examples
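if (torch_is_installed()) {
  # Column-major fill: rows are (10, 20, 40) and (30, 60, 50)
  t <- torch_tensor(matrix(c(10, 30, 20, 60, 40, 50), nrow = 2))

  # With dim = NULL, argmax() indexes the flattened tensor, which is
  # also how take_along_dim() reads it when dim is NULL
  max_idx <- torch_argmax(t)
  torch_take_along_dim(t, max_idx)

  # argsort() along dim 2 gives the column order that sorts each row
  sorted_idx <- torch_argsort(t, dim = 2)
  torch_take_along_dim(t, sorted_idx, dim = 2)
}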
#> torch_tensor
#> 10 20 40
#> 30 50 60
#> [ CPUFloatType{2,3} ]