
Logic underlying luz_callback_mixup().

Usage

nnf_mixup(x, y, weight)

Arguments

x

an input batch

y

a target batch

weight

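weighting coefficient(s) to be used by torch_lerp(); must be broadcastable against x, e.g. a tensor of shape (batch_size, 1) for a two-dimensional input batch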
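torch_lerp() interpolates elementwise as start + weight * (end - start), broadcasting weight against the other two arguments. The base R sketch below mimics that arithmetic for a matrix batch with one weight per row. It shows why a per-item weight needs shape (batch_size, 1): each weight then scales every feature of its own item. The helper lerp_rows() is hypothetical and is not part of luz or torch.

```r
# Hypothetical helper: mimics torch_lerp(start, end, weight) for matrices,
# i.e. start + weight * (end - start), with one weight per row (batch item).
lerp_rows <- function(start, end, weight) {
  stopifnot(all(dim(start) == dim(end)), length(weight) == nrow(start))
  # R recycles a length-nrow vector down each column, so weight[i]
  # scales every element of row i (like a (batch_size, 1) torch tensor).
  start + weight * (end - start)
}

a <- matrix(0, nrow = 2, ncol = 3)
b <- matrix(10, nrow = 2, ncol = 3)
lerp_rows(a, b, c(0.9, 0.5))
```

Row 1 moves 90% of the way from a to b (all 9s). Row 2 moves halfway (all 5s).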

Value

A list of:

  • x, the new, mixed-up input batch

  • y, a list of:

    • ys, a list of:

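      • y1, the original targets (y, unchanged)

      • y2, the targets of the mixed-in items, i.e. y reordered to match the partner each input item was mixed with

    • weight, the mixing weights (the weight argument, unchanged)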
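The nested y carries everything needed to score predictions on a mixed-up batch. This page does not prescribe how it is consumed. One common approach in mixup training, assumed here, is to compute the per-item loss against both targets and interpolate those losses with the same weights used for the inputs. The base R sketch below does this with squared error. mixup_loss_sketch() is a hypothetical helper, and y is assumed to follow the structure listed above.

```r
# Hypothetical helper: scores predictions on a mixed-up batch, given a `y`
# shaped like nnf_mixup()'s output: list(ys = list(y1, y2), weight).
# Uses squared error and the same interpolation as torch_lerp():
# loss1 + weight * (loss2 - loss1), averaged over the batch.
mixup_loss_sketch <- function(pred, y) {
  w <- as.numeric(y$weight)                # one weight per batch item
  loss1 <- (pred - y$ys$y1)^2              # per-item loss against y1
  loss2 <- (pred - y$ys$y2)^2              # per-item loss against y2
  mean(loss1 + w * (loss2 - loss1))
}

y <- list(ys = list(y1 = c(1, 0), y2 = c(0, 1)), weight = c(0.9, 0.9))
mixup_loss_sketch(pred = c(0, 1), y = y)
```

Here every item's loss is 1 against y1 and 0 against y2, so with weight 0.9 the result is 0.1. Interpolating losses rather than targets also works when targets are class indices, which cannot be averaged directly.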

Details

Based on the passed-in input and target batches, as well as the applicable mixing weights, we return new tensors intended to replace the current batch. Each item of the new input batch is a weighted linear combination, computed by torch_lerp(), of an original item and another item from the same batch. The new target batch bundles the original targets, the targets of the mixed-in items, and the mixing weights in a nested list.
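The mixing step can be sketched in base R on plain matrices. This is not luz's implementation. Several details are assumptions:

- x has one row per batch item and y has one target per item.
- Mixing partners come from a random permutation of the batch. This is consistent with y2 being a reordering of y1 in the example output below.
- The original batch is torch_lerp()'s start and the shuffled batch its end, so weight is the share contributed by the mixed-in item.

mixup_sketch() is a hypothetical name.

```r
# A base R sketch of the mixup step described above, NOT luz's implementation.
# Assumptions (hypothetical): x is a matrix with one row per batch item, y is a
# vector with one target per item, weight has one value per item, and the
# mixing partners come from a random permutation of the batch.
mixup_sketch <- function(x, y, weight) {
  perm <- sample(nrow(x))            # random partner for each batch item
  x2 <- x[perm, , drop = FALSE]      # partner inputs
  y2 <- y[perm]                      # partner targets, same permutation
  list(
    # same arithmetic as torch_lerp(x, x2, weight): x + weight * (x2 - x)
    x = x + weight * (x2 - x),
    y = list(ys = list(y1 = y, y2 = y2), weight = weight)
  )
}

set.seed(1)
batch_x <- matrix(rnorm(10 * 4), nrow = 10)
batch_y <- rnorm(10)
out <- mixup_sketch(batch_x, batch_y, rep(0.9, 10))
str(out$y)
```

The input batch keeps its shape, y1 and weight come back unchanged, and y2 holds the same values as y1 in permuted order.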

Examples

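if (torch::torch_is_installed()) {
# a batch of 10 items with 768 features each, plus one target per item
batch_x <- torch::torch_randn(c(10, 768))
batch_y <- torch::torch_randn(10)
# one weight per item, shaped (10, 1) so it broadcasts across the 768 features
weight <- torch::torch_tensor(rep(0.9, 10))$view(c(10, 1))
nnf_mixup(batch_x, batch_y, weight)
}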
#> $x
#> torch_tensor
#> Columns 1 to 6 2.0106e-02  1.2700e+00 -4.5297e-01  1.6274e-01  5.2511e-01 -2.9097e+00
#>  1.1165e+00 -1.1651e+00 -1.2490e-02  9.4594e-01  6.0349e-01 -1.2466e-01
#> -1.4823e+00 -1.0984e+00  1.3790e-01  8.2407e-01  1.0687e+00  1.5771e+00
#> -3.8448e-01 -2.4186e-01 -1.8506e+00 -2.4928e-01  7.0680e-01 -8.1963e-01
#> -6.9323e-01  1.0644e+00  7.2743e-02 -6.5212e-01 -1.0365e+00 -1.6201e+00
#>  3.0431e-01  1.3506e+00  6.0425e-01  3.9446e-02  1.1896e+00 -3.1331e-01
#> -6.5934e-01 -1.0023e-01 -1.6981e+00  7.7656e-01 -3.7352e-02 -1.2685e+00
#> -8.8970e-01 -6.1145e-01 -1.5474e-01 -1.1202e+00  9.7077e-01 -1.2335e-01
#>  1.0432e-01  4.7935e-01  1.1206e+00  8.4541e-01 -6.2771e-01 -6.9416e-01
#> -7.2512e-01 -8.7937e-01 -1.1267e+00  1.1258e+00 -6.2741e-01  1.1046e+00
#> 
#> Columns 7 to 12-5.5276e-01  1.2940e+00  6.8571e-02  5.2053e-01 -1.4397e-01  6.3606e-01
#>  9.5878e-01  1.1129e+00 -8.2416e-01  1.3010e+00  3.1479e-01 -2.4285e-01
#> -4.6636e-01  2.2018e-02  9.8940e-01 -9.2898e-01  1.1628e+00  1.6556e+00
#>  6.0739e-01  9.3763e-01 -2.1561e-01  9.9574e-01 -4.3450e-01  1.2759e+00
#> -1.2120e+00 -2.9577e-01 -1.7762e-01 -2.0190e+00  3.5191e-02 -6.8218e-01
#> -1.2246e+00 -3.5967e-01 -9.4566e-01  7.7871e-01 -1.8878e+00 -1.1342e-01
#> -1.3023e-01 -1.0773e-01 -3.9717e-01  4.8980e-01 -7.6161e-01  9.8832e-01
#> -7.0246e-03 -1.9394e+00  7.4958e-01 -2.8474e-01 -1.3539e+00 -6.4441e-02
#>  5.4538e-01 -1.5291e-01 -1.5370e+00 -1.3794e+00 -9.8801e-01  5.7222e-01
#>  8.1105e-01 -2.5166e-01  4.3286e-01 -1.2386e+00 -2.8035e-01  7.9271e-01
#> 
#> Columns 13 to 18 1.4179e-01 -9.1660e-02  8.7146e-01  6.1755e-01  4.4270e-01  1.4312e+00
#>  3.5975e-01  4.4905e-01  1.1004e+00 -1.9853e+00  1.0673e+00  1.8135e-01
#>  2.6951e-01 -4.6485e-01 -2.3144e+00  1.1701e+00  7.6070e-01  2.2296e-01
#> -4.0147e-01 -8.2541e-01 -5.7005e-01 -7.7117e-01 -1.5695e+00 -9.3967e-01
#>  9.4121e-02 -1.1493e+00 -3.9286e-01 -1.6607e+00  4.3537e-02 -4.7863e-01
#>  4.3758e-01 -7.6054e-01  2.7724e-01 -2.8658e-02 -5.9837e-01 -6.7632e-01
#>  1.1467e+00 -1.3246e-01  1.3957e+00 -2.2697e-01 -3.3288e-01 -4.5025e-01
#> -1.8194e+00  6.1576e-01 -7.0159e-01 -1.8710e+00  5.4053e-01  6.6031e-01
#> ... [the output was truncated (use n=-1 to disable)]
#> [ CPUFloatType{10,768} ]
#> 
#> $y
#> $y$ys
#> $y$ys$y1
#> torch_tensor
#>  1.4136
#> -1.0387
#>  0.8239
#> -0.1957
#>  0.1990
#>  1.8951
#> -0.6351
#> -0.0262
#>  0.5783
#> -0.9430
#> [ CPUFloatType{10} ]
#> 
#> $y$ys$y2
#> torch_tensor
#> -0.6351
#> -0.1957
#> -0.9430
#>  0.8239
#> -0.0262
#>  0.1990
#>  0.5783
#>  1.4136
#>  1.8951
#> -1.0387
#> [ CPUFloatType{10} ]
#> 
#> 
#> $y$weight
#> torch_tensor
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#> [ CPUFloatType{10,1} ]
#> 
#>