
Logic underlying luz_callback_mixup().

Usage

nnf_mixup(x, y, weight)

Arguments

x

an input batch

y

a target batch

weight

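mixing weights, passed on as the weight argument of torch_lerp(); must be broadcastable against x (for a 2-D x, typically one weight per batch item, shaped (batch_size, 1))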
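
torch_lerp(start, end, weight) computes start + weight * (end - start) elementwise, broadcasting weight against the other two arguments. The toy check below illustrates this with a per-row weight of shape (2, 1); it assumes the torch package is installed and exercises only torch_lerp(), not nnf_mixup() itself.

```r
library(torch)

# two "items" (rows) with three features each
start <- torch_tensor(rbind(c(0, 0, 0), c(10, 10, 10)))
end   <- torch_tensor(rbind(c(100, 100, 100), c(20, 20, 20)))

# one weight per row, shaped (2, 1) so it broadcasts across the features
w <- torch_tensor(c(0.9, 0.5))$view(c(2, 1))

out <- torch_lerp(start, end, w)

# the same result written out by hand: start + w * (end - start)
manual <- start + w * (end - start)

# row 1 is about 90 (0 + 0.9 * 100), row 2 is 15 (10 + 0.5 * 10)
as_array(out)
```

A weight closer to 1 therefore pulls the result towards end, and a weight closer to 0 keeps it near start.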

Value

A list of:

  • x, the new, mixed-up input batch

  • y, a list of:

    • ys, a list of:

      • y1, the original targets (the y that was passed in)

      • y2, the targets of the mixed-in items

    • weight, the mixing weights
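
Because the result is a nested list, downstream code reaches each piece by name. The snippet below is a small sketch of those access paths; it assumes both torch and luz are installed and uses a toy batch of four items with three features each.

```r
library(torch)
library(luz)

x <- torch_randn(c(4, 3))                    # toy input batch
y <- torch_randn(4)                          # toy target batch
w <- torch_tensor(rep(0.7, 4))$view(c(4, 1)) # one weight per item

res <- nnf_mixup(x, y, w)

mixed_x <- res$x          # mixed-up inputs, same shape as x
y1      <- res$y$ys$y1    # original targets
y2      <- res$y$ys$y2    # targets of the mixed-in items
weights <- res$y$weight   # the mixing weights
```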

Details

Given an input batch, a target batch, and mixing weights, nnf_mixup() returns new tensors intended to replace the current batch. Each item of the new input batch is a weighted linear combination, computed with torch_lerp(), of the original item and another item from the same batch. The new target is a nested list that bundles the original targets, the targets of the mixed-in items, and the mixing weights.
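
The steps described above can be sketched as follows. This is a hypothetical re-implementation for illustration only (mixup_sketch() is not part of luz). It assumes the mixing partners come from a random permutation of the batch, as the shuffled y2 in the example output suggests, that the targets are 1-D, and that the weight multiplies the mixed-in item, which is what torch_lerp(x, x2, weight) yields.

```r
library(torch)

# hypothetical helper, not luz's code: mixes each batch item with a partner
# drawn from a random permutation of the same batch (1-D targets assumed)
mixup_sketch <- function(x, y, weight) {
  batch_size <- x$shape[1]
  # torch_randperm() returns 0-based positions; R tensor indexing is 1-based
  perm <- torch_randperm(batch_size, dtype = torch_long()) + 1L
  x2 <- x[perm, ..]
  y2 <- y[perm]
  list(
    # x + weight * (x2 - x), with weight broadcast over the feature dimension
    x = torch_lerp(x, x2, weight),
    y = list(ys = list(y1 = y, y2 = y2), weight = weight)
  )
}
```

Since a permutation can map an index to itself, some items may end up mixed with themselves; in the example output, the ninth target appears unchanged in both y1 and y2.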

Examples

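if (torch::torch_is_installed()) {
  # a batch of 10 items with 768 features each, plus one target per item
  batch_x <- torch::torch_randn(c(10, 768))
  batch_y <- torch::torch_randn(10)
  # one mixing weight per item, shaped (10, 1) so it broadcasts over the features
  weight <- torch::torch_tensor(rep(0.9, 10))$view(c(10, 1))
  nnf_mixup(batch_x, batch_y, weight)
}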
#> $x
#> torch_tensor
#> Columns 1 to 6
#> -9.2156e-01  1.7575e+00  1.5404e+00 -1.2119e+00 -1.1037e+00  6.0738e-01
#> -6.1126e-01 -1.1691e+00 -1.2765e+00  6.7805e-01  7.8990e-01  3.1883e-01
#>  6.2647e-02  1.2564e-01 -4.0070e-01 -7.0199e-01  4.2789e-01 -1.1532e+00
#>  1.2746e+00 -2.0882e+00  6.7265e-02 -1.0254e+00  4.8568e-01  4.5166e-01
#> -4.0371e-01  6.8361e-01 -1.1215e-01 -1.2972e+00 -3.3998e-01  1.2756e+00
#>  1.0619e+00  1.2164e+00  2.1873e+00  6.5320e-01  2.5060e-02 -2.1226e+00
#>  8.7148e-02 -1.4042e-01 -8.1498e-01 -2.1304e-01 -5.8140e-01  8.1769e-01
#> -1.9304e-01 -3.1196e-01 -2.2534e-01  2.1747e-01 -2.7495e-01  1.1422e+00
#> -5.7953e-01 -7.6793e-02  8.3499e-01 -5.5249e-01  6.7076e-01 -7.0809e-03
#> -6.2982e-01 -2.2006e-01  2.2986e+00 -1.0858e-01 -4.7518e-01 -6.5748e-02
#> 
#> Columns 7 to 12
#>  4.5106e-01  3.6115e-01 -1.3084e+00  2.9041e-02 -2.0911e-01 -8.2130e-01
#>  1.1912e+00 -7.3075e-01  3.8588e-01 -7.8069e-01 -4.0413e-01 -3.3518e-01
#>  1.0025e-02 -1.0487e+00 -2.8117e-02 -8.4114e-01  1.1518e+00  1.6673e+00
#>  2.5226e-01 -2.7181e-02  4.6357e-01 -6.7224e-01 -1.0575e+00  4.4460e-01
#> -5.0025e-01 -6.5837e-02 -1.2669e-01  5.2106e-01 -2.5001e-01 -1.2199e-01
#> -7.6469e-02 -9.0876e-01 -8.6287e-01  6.2168e-01  4.4867e-01 -1.1513e+00
#> -3.0791e-01  4.6719e-01 -6.6698e-01  4.5505e-01 -3.1452e-01  4.4677e-02
#> -2.3796e-01  2.9452e-01  1.8050e+00 -2.3014e-02 -7.6431e-01 -1.9115e-01
#> -2.6378e-01  1.0625e+00  1.6711e+00  1.4281e+00  2.8299e-01  6.6748e-01
#>  1.9737e-01 -2.2290e-01  1.3783e-01  9.1378e-01  1.1617e+00  1.4401e-01
#> 
#> Columns 13 to 18
#>  1.3247e+00  4.3923e-01  4.6452e-01  3.0960e-01 -8.8560e-01 -2.1976e-01
#> -5.1395e-01  1.3979e+00  1.6057e-01  1.3175e+00 -4.7377e-01 -4.1433e-01
#> -6.7036e-01 -1.3290e+00  2.9255e-01  9.8462e-01 -8.6272e-01  1.9397e+00
#> -3.1585e-01  4.1250e-01  4.5168e-01 -2.3258e-02  1.6954e+00 -2.0993e-01
#> -1.8619e+00  9.6965e-01 -1.0426e+00  1.5196e+00  1.3628e+00 -1.1300e+00
#> -4.2520e-01  3.5759e-01  3.6117e-01 -1.1502e+00 -1.5024e+00  8.0233e-01
#>  4.0323e-02  1.1931e+00  5.0230e-01 -6.3314e-01  9.0631e-01 -3.9448e-02
#> -1.0282e+00 -2.7839e+00 -2.3659e-01  5.3641e-01 -1.6557e-01 -5.5280e-01
#> ... [the output was truncated (use n=-1 to disable)]
#> [ CPUFloatType{10,768} ]
#> 
#> $y
#> $y$ys
#> $y$ys$y1
#> torch_tensor
#> -0.8288
#>  1.6901
#>  1.2201
#> -0.4529
#> -2.0239
#>  0.1304
#>  1.1964
#> -1.3784
#> -0.7967
#>  1.4333
#> [ CPUFloatType{10} ]
#> 
#> $y$ys$y2
#> torch_tensor
#>  0.1304
#>  1.2201
#>  1.6901
#> -1.3784
#>  1.1964
#> -0.8288
#> -0.4529
#>  1.4333
#> -0.7967
#> -2.0239
#> [ CPUFloatType{10} ]
#> 
#> 
#> $y$weight
#> torch_tensor
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#> [ CPUFloatType{10,1} ]
#> 
#>
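
Because the new target bundles y1, y2 and weight rather than holding a single tensor, a standard loss cannot consume it directly. One common approach, sketched below, interpolates the two per-item losses with the same weights. mixup_mse() is a hypothetical helper, not luz's own loss; it assumes 1-D regression targets, per-item mean squared error, and (as with torch_lerp(x, x2, weight)) that the weight belongs to the mixed-in targets y2.

```r
library(torch)

# hypothetical helper: mixup-aware MSE for 1-D regression targets
mixup_mse <- function(pred, target) {
  y1 <- target$ys$y1
  y2 <- target$ys$y2
  # flatten the (batch_size, 1) weights so they line up with the per-item
  # losses; otherwise (n, 1) against (n) would broadcast to an (n, n) matrix
  w <- target$weight$view(-1)
  loss1 <- nnf_mse_loss(pred, y1, reduction = "none")
  loss2 <- nnf_mse_loss(pred, y2, reduction = "none")
  # per item: loss1 + w * (loss2 - loss1), then average over the batch
  torch_lerp(loss1, loss2, w)$mean()
}

# usage with a target list shaped like the one nnf_mixup() returns
pred <- torch_randn(4)
target <- list(
  ys = list(y1 = torch_randn(4), y2 = torch_randn(4)),
  weight = torch_tensor(rep(0.9, 4))$view(c(4, 1))
)
mixup_mse(pred, target)
```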