Logic underlying luz_callback_mixup().
Arguments
- x
an input batch
- y
a target batch
- weight
weighting coefficient passed to torch_lerp(); it must broadcast against x (the example uses one weight per batch item, shaped (batch_size, 1))
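The weight argument follows torch_lerp() semantics: torch_lerp(start, end, weight) computes start + weight * (end - start), elementwise. The base-R sketch below uses plain numeric matrices instead of tensors to show that formula. lerp() is a hypothetical helper, not part of luz or torch, and which of x and its mixed-in partner plays start or end inside nnf_mixup() is not stated here.

```r
# Hypothetical helper mirroring torch_lerp(start, end, weight):
# start + weight * (end - start), i.e. (1 - weight) * start + weight * end.
lerp <- function(start, end, weight) {
  start + weight * (end - start)
}

start <- matrix(c(0, 0, 10, 10), nrow = 2, byrow = TRUE)  # 2 items x 2 features
end   <- matrix(c(1, 1, 20, 20), nrow = 2, byrow = TRUE)

# One weight per item (row). R recycles a length-2 vector down each column,
# so every row gets its own weight, much like a (2, 1) tensor broadcasts.
w <- c(0.9, 0.5)

lerp(start, end, w)
```

A weight of 0 returns start unchanged and a weight of 1 returns end, so the weight controls how much of the second item enters the mix.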
Value
A list of:
- x, the new, mixed-up input batch
- y, a list of:
  - ys, a list of:
    - y1, the original target
    - y2, the mixed-in target
  - weight, the mixing weights
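The nested y structure carries everything a loss needs to score a mixed-up batch. One common way to consume it, offered here as an assumption rather than a description of luz internals, is to compute the loss against y1 and against y2 and interpolate the two with the same weights, lerp-style. The sketch below uses a hypothetical stand-in for the returned list (plain numeric vectors in place of tensors) and a hypothetical squared-error loss, sq_err().

```r
# Hypothetical stand-in for the value returned by nnf_mixup(): same nesting,
# plain numeric vectors instead of tensors.
out <- list(
  x = c(0.5, 1.5),  # placeholder mixed inputs (not used below)
  y = list(
    ys = list(y1 = c(1, 0), y2 = c(0, 1)),
    weight = c(0.25, 0.25)
  )
)

pred <- c(0.75, 0.25)  # hypothetical model predictions for the mixed batch

# Hypothetical per-item loss used only for illustration
sq_err <- function(pred, target) (pred - target)^2

y1 <- out$y$ys$y1
y2 <- out$y$ys$y2
w  <- out$y$weight

# Interpolate the two losses with the mixing weights, as torch_lerp() would
mixed_loss <- sq_err(pred, y1) + w * (sq_err(pred, y2) - sq_err(pred, y1))
mean(mixed_loss)
```

Whatever loss is used, it has to apply the weights with the same convention as the input interpolation, otherwise inputs and targets are mixed in different proportions.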
Details
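Based on the passed-in input and target batches, as well as the mixing weights, nnf_mixup() returns new tensors intended to replace the current batch. Each item of the new input batch is a weighted linear combination, computed with torch_lerp(), of an original input item and a second item taken from the same batch. The new target batch is a nested list that bundles the original targets (y1), the targets of the mixed-in items (y2), and the mixing weights; in the example below, y2 holds the same values as y1 in a different order.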
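The steps described above can be sketched in base R. This is a hypothetical illustration, not luz's implementation: mixup_sketch() is an invented name, it works on numeric matrices instead of tensors, and it assumes the mixed-in items come from a random permutation of the batch (consistent with y2 being a reordering of y1 in the example output).

```r
# Hypothetical base-R sketch of the mixup step; not luz's implementation.
# x: numeric matrix (batch_size x features), y: numeric vector (batch_size),
# weight: numeric vector with one mixing weight per batch item.
mixup_sketch <- function(x, y, weight) {
  batch_size <- nrow(x)
  shuffle <- sample.int(batch_size)   # random reordering of the batch (assumption)
  x1 <- x
  y1 <- y
  x2 <- x[shuffle, , drop = FALSE]    # items to be mixed in
  y2 <- y[shuffle]                    # their targets, in the same order
  # lerp-style combination: x1 + weight * (x2 - x1);
  # weight (length batch_size) recycles down each column, i.e. one weight per row
  mixed <- x1 + weight * (x2 - x1)
  list(x = mixed, y = list(ys = list(y1 = y1, y2 = y2), weight = weight))
}

set.seed(1)
batch_x <- matrix(rnorm(10 * 4), nrow = 10)
batch_y <- rnorm(10)
out <- mixup_sketch(batch_x, batch_y, rep(0.9, 10))
str(out$y)
```

Shuffling x and y with the same index vector is what keeps each mixed-in target aligned with the input it came from.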
Examples
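if (torch::torch_is_installed()) {
  # a batch of 10 items with 768 features each, plus one target per item
  batch_x <- torch::torch_randn(c(10, 768))
  batch_y <- torch::torch_randn(10)
  # one mixing weight per item, shaped (10, 1) so it broadcasts over the features
  weight <- torch::torch_tensor(rep(0.9, 10))$view(c(10, 1))
  nnf_mixup(batch_x, batch_y, weight)
}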
#> $x
#> torch_tensor
#> Columns 1 to 6 2.0106e-02 1.2700e+00 -4.5297e-01 1.6274e-01 5.2511e-01 -2.9097e+00
#> 1.1165e+00 -1.1651e+00 -1.2490e-02 9.4594e-01 6.0349e-01 -1.2466e-01
#> -1.4823e+00 -1.0984e+00 1.3790e-01 8.2407e-01 1.0687e+00 1.5771e+00
#> -3.8448e-01 -2.4186e-01 -1.8506e+00 -2.4928e-01 7.0680e-01 -8.1963e-01
#> -6.9323e-01 1.0644e+00 7.2743e-02 -6.5212e-01 -1.0365e+00 -1.6201e+00
#> 3.0431e-01 1.3506e+00 6.0425e-01 3.9446e-02 1.1896e+00 -3.1331e-01
#> -6.5934e-01 -1.0023e-01 -1.6981e+00 7.7656e-01 -3.7352e-02 -1.2685e+00
#> -8.8970e-01 -6.1145e-01 -1.5474e-01 -1.1202e+00 9.7077e-01 -1.2335e-01
#> 1.0432e-01 4.7935e-01 1.1206e+00 8.4541e-01 -6.2771e-01 -6.9416e-01
#> -7.2512e-01 -8.7937e-01 -1.1267e+00 1.1258e+00 -6.2741e-01 1.1046e+00
#>
#> Columns 7 to 12 -5.5276e-01 1.2940e+00 6.8571e-02 5.2053e-01 -1.4397e-01 6.3606e-01
#> 9.5878e-01 1.1129e+00 -8.2416e-01 1.3010e+00 3.1479e-01 -2.4285e-01
#> -4.6636e-01 2.2018e-02 9.8940e-01 -9.2898e-01 1.1628e+00 1.6556e+00
#> 6.0739e-01 9.3763e-01 -2.1561e-01 9.9574e-01 -4.3450e-01 1.2759e+00
#> -1.2120e+00 -2.9577e-01 -1.7762e-01 -2.0190e+00 3.5191e-02 -6.8218e-01
#> -1.2246e+00 -3.5967e-01 -9.4566e-01 7.7871e-01 -1.8878e+00 -1.1342e-01
#> -1.3023e-01 -1.0773e-01 -3.9717e-01 4.8980e-01 -7.6161e-01 9.8832e-01
#> -7.0246e-03 -1.9394e+00 7.4958e-01 -2.8474e-01 -1.3539e+00 -6.4441e-02
#> 5.4538e-01 -1.5291e-01 -1.5370e+00 -1.3794e+00 -9.8801e-01 5.7222e-01
#> 8.1105e-01 -2.5166e-01 4.3286e-01 -1.2386e+00 -2.8035e-01 7.9271e-01
#>
#> Columns 13 to 18 1.4179e-01 -9.1660e-02 8.7146e-01 6.1755e-01 4.4270e-01 1.4312e+00
#> 3.5975e-01 4.4905e-01 1.1004e+00 -1.9853e+00 1.0673e+00 1.8135e-01
#> 2.6951e-01 -4.6485e-01 -2.3144e+00 1.1701e+00 7.6070e-01 2.2296e-01
#> -4.0147e-01 -8.2541e-01 -5.7005e-01 -7.7117e-01 -1.5695e+00 -9.3967e-01
#> 9.4121e-02 -1.1493e+00 -3.9286e-01 -1.6607e+00 4.3537e-02 -4.7863e-01
#> 4.3758e-01 -7.6054e-01 2.7724e-01 -2.8658e-02 -5.9837e-01 -6.7632e-01
#> 1.1467e+00 -1.3246e-01 1.3957e+00 -2.2697e-01 -3.3288e-01 -4.5025e-01
#> -1.8194e+00 6.1576e-01 -7.0159e-01 -1.8710e+00 5.4053e-01 6.6031e-01
#> ... [the output was truncated (use n=-1 to disable)]
#> [ CPUFloatType{10,768} ]
#>
#> $y
#> $y$ys
#> $y$ys$y1
#> torch_tensor
#> 1.4136
#> -1.0387
#> 0.8239
#> -0.1957
#> 0.1990
#> 1.8951
#> -0.6351
#> -0.0262
#> 0.5783
#> -0.9430
#> [ CPUFloatType{10} ]
#>
#> $y$ys$y2
#> torch_tensor
#> -0.6351
#> -0.1957
#> -0.9430
#> 0.8239
#> -0.0262
#> 0.1990
#> 0.5783
#> 1.4136
#> 1.8951
#> -1.0387
#> [ CPUFloatType{10} ]
#>
#>
#> $y$weight
#> torch_tensor
#> 0.9000
#> 0.9000
#> 0.9000
#> 0.9000
#> 0.9000
#> 0.9000
#> 0.9000
#> 0.9000
#> 0.9000
#> 0.9000
#> [ CPUFloatType{10,1} ]
#>
#>