Logic underlying luz_callback_mixup().
Arguments
- x
an input batch
- y
a target batch
- weight
weighting coefficient to be used by torch_lerp()
Value
A list of:
- x, the new, mixed-up input batch
- y, a list of:
  - ys, a list of:
    - y1, the original target y1
    - y2, the mixed-in target y2
  - weight, the mixing weights
Details
Based on the passed-in input and target batches, as well as applicable mixing weights, we return new tensors intended to replace the current batch. The new input batch is a weighted linear combination of input batch items, while the new target batch bundles the original targets, as well as the mixing weights, in a nested list.
Examples
if (torch::torch_is_installed()) {
batch_x <- torch::torch_randn(c(10, 768))
batch_y <- torch::torch_randn(10)
weight <- torch::torch_tensor(rep(0.9, 10))$view(c(10, 1))
nnf_mixup(batch_x, batch_y, weight)
}
#> $x
#> torch_tensor
#> Columns 1 to 6 -9.2156e-01 1.7575e+00 1.5404e+00 -1.2119e+00 -1.1037e+00 6.0738e-01
#> -6.1126e-01 -1.1691e+00 -1.2765e+00 6.7805e-01 7.8990e-01 3.1883e-01
#> 6.2647e-02 1.2564e-01 -4.0070e-01 -7.0199e-01 4.2789e-01 -1.1532e+00
#> 1.2746e+00 -2.0882e+00 6.7265e-02 -1.0254e+00 4.8568e-01 4.5166e-01
#> -4.0371e-01 6.8361e-01 -1.1215e-01 -1.2972e+00 -3.3998e-01 1.2756e+00
#> 1.0619e+00 1.2164e+00 2.1873e+00 6.5320e-01 2.5060e-02 -2.1226e+00
#> 8.7148e-02 -1.4042e-01 -8.1498e-01 -2.1304e-01 -5.8140e-01 8.1769e-01
#> -1.9304e-01 -3.1196e-01 -2.2534e-01 2.1747e-01 -2.7495e-01 1.1422e+00
#> -5.7953e-01 -7.6793e-02 8.3499e-01 -5.5249e-01 6.7076e-01 -7.0809e-03
#> -6.2982e-01 -2.2006e-01 2.2986e+00 -1.0858e-01 -4.7518e-01 -6.5748e-02
#>
#> Columns 7 to 12 4.5106e-01 3.6115e-01 -1.3084e+00 2.9041e-02 -2.0911e-01 -8.2130e-01
#> 1.1912e+00 -7.3075e-01 3.8588e-01 -7.8069e-01 -4.0413e-01 -3.3518e-01
#> 1.0025e-02 -1.0487e+00 -2.8117e-02 -8.4114e-01 1.1518e+00 1.6673e+00
#> 2.5226e-01 -2.7181e-02 4.6357e-01 -6.7224e-01 -1.0575e+00 4.4460e-01
#> -5.0025e-01 -6.5837e-02 -1.2669e-01 5.2106e-01 -2.5001e-01 -1.2199e-01
#> -7.6469e-02 -9.0876e-01 -8.6287e-01 6.2168e-01 4.4867e-01 -1.1513e+00
#> -3.0791e-01 4.6719e-01 -6.6698e-01 4.5505e-01 -3.1452e-01 4.4677e-02
#> -2.3796e-01 2.9452e-01 1.8050e+00 -2.3014e-02 -7.6431e-01 -1.9115e-01
#> -2.6378e-01 1.0625e+00 1.6711e+00 1.4281e+00 2.8299e-01 6.6748e-01
#> 1.9737e-01 -2.2290e-01 1.3783e-01 9.1378e-01 1.1617e+00 1.4401e-01
#>
#> Columns 13 to 18 1.3247e+00 4.3923e-01 4.6452e-01 3.0960e-01 -8.8560e-01 -2.1976e-01
#> -5.1395e-01 1.3979e+00 1.6057e-01 1.3175e+00 -4.7377e-01 -4.1433e-01
#> -6.7036e-01 -1.3290e+00 2.9255e-01 9.8462e-01 -8.6272e-01 1.9397e+00
#> -3.1585e-01 4.1250e-01 4.5168e-01 -2.3258e-02 1.6954e+00 -2.0993e-01
#> -1.8619e+00 9.6965e-01 -1.0426e+00 1.5196e+00 1.3628e+00 -1.1300e+00
#> -4.2520e-01 3.5759e-01 3.6117e-01 -1.1502e+00 -1.5024e+00 8.0233e-01
#> 4.0323e-02 1.1931e+00 5.0230e-01 -6.3314e-01 9.0631e-01 -3.9448e-02
#> -1.0282e+00 -2.7839e+00 -2.3659e-01 5.3641e-01 -1.6557e-01 -5.5280e-01
#> ... [the output was truncated (use n=-1 to disable)]
#> [ CPUFloatType{10,768} ]
#>
#> $y
#> $y$ys
#> $y$ys$y1
#> torch_tensor
#> -0.8288
#> 1.6901
#> 1.2201
#> -0.4529
#> -2.0239
#> 0.1304
#> 1.1964
#> -1.3784
#> -0.7967
#> 1.4333
#> [ CPUFloatType{10} ]
#>
#> $y$ys$y2
#> torch_tensor
#> 0.1304
#> 1.2201
#> 1.6901
#> -1.3784
#> 1.1964
#> -0.8288
#> -0.4529
#> 1.4333
#> -0.7967
#> -2.0239
#> [ CPUFloatType{10} ]
#>
#>
#> $y$weight
#> torch_tensor
#> 0.9000
#> 0.9000
#> 0.9000
#> 0.9000
#> 0.9000
#> 0.9000
#> 0.9000
#> 0.9000
#> 0.9000
#> 0.9000
#> [ CPUFloatType{10,1} ]
#>
#>