
Logic underlying luz_callback_mixup().

Usage

nnf_mixup(x, y, weight)

Arguments

x

an input batch, with batch items along the first dimension

y

a target batch, with one target per item of x

weight

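mixing weights to be used by torch_lerp(); must be broadcastable to the shape of x, e.g. one weight per batch item, shaped c(batch_size, 1) for a two-dimensional x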
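For reference, torch_lerp(start, end, weight) interpolates elementwise between two tensors, with weight broadcast against them: a weight of 0 returns start and a weight of 1 returns end. Writing s for start, e for end and w for weight (these symbols are introduced here for illustration), the computation is shown below. Which of the two batches nnf_mixup() passes as start is not stated on this page.

```latex
\operatorname{lerp}(s, e, w) = s + w \odot (e - s) = (1 - w) \odot s + w \odot e
```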

Value

A list of:

  • x, the new, mixed-up input batch

  • y, a list of:

    • ys, a list of:

      • y1, the original targets (the y that was passed in)

      • y2, the targets of the batch items that were mixed in

    • weight, the mixing weights
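Because the targets are returned unmixed, the loss computation has to combine them. The following is a minimal base-R sketch of one way to do that, using plain vectors instead of torch tensors and a squared-error loss. mixup_mse() and lerp() are hypothetical helpers, not part of luz, and the sketch assumes each input item was mixed as (1 - w) parts original and w parts mixed-in, so its two losses are weighted the same way.

```r
# Elementwise linear interpolation, mirroring torch_lerp(start, end, weight):
# start + weight * (end - start)
lerp <- function(start, end, weight) start + weight * (end - start)

# Hypothetical loss for the nested target structure returned by nnf_mixup().
# pred:   numeric predictions, one per batch item
# target: list(ys = list(y1 = ..., y2 = ...), weight = ...)
mixup_mse <- function(pred, target) {
  l1 <- (pred - target$ys$y1)^2   # per-item loss against the original targets
  l2 <- (pred - target$ys$y2)^2   # per-item loss against the mixed-in targets
  w <- as.vector(target$weight)   # flatten e.g. a c(n, 1) weight to length n
  mean(lerp(l1, l2, w))           # weight each item's two losses, then average
}

target <- list(ys = list(y1 = c(1, 2), y2 = c(3, 2)), weight = c(0.5, 0.5))
mixup_mse(c(1, 2), target)  # returns 1
```

With a weight of 0.5, the first item's losses (0 and 4) average to 2, the second item's are both 0, and the batch mean is 1.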

Details

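Given an input batch, a target batch, and mixing weights, nnf_mixup() returns new tensors intended to replace the current batch. Each item of the new input batch is a weighted linear combination, computed with torch_lerp(), of an original item and an item from a shuffled copy of the same batch. The targets themselves are not mixed: the new target is a nested list that bundles the original targets, the targets of the mixed-in items, and the mixing weights, so that the loss can be computed against both sets of targets and weighted accordingly. In the example below, y2 contains the same values as y1 in a different order.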
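To make these steps concrete, here is a minimal base-R sketch using plain matrices and vectors instead of torch tensors. It is not luz's implementation: mixup_sketch() and lerp() are hypothetical helpers, the shuffle is assumed to be a random permutation drawn with sample.int(), and interpolation is assumed to run from the original item (start) toward the mixed-in item (end).

```r
# Elementwise linear interpolation, mirroring torch_lerp(start, end, weight):
# start + weight * (end - start)
lerp <- function(start, end, weight) start + weight * (end - start)

# Hypothetical base-R sketch of the mixup step (not luz's code).
# x:      matrix with one batch item per row
# y:      one target per row of x
# weight: one mixing weight per row of x
mixup_sketch <- function(x, y, weight) {
  n <- nrow(x)
  shuffle <- sample.int(n)           # assumed: random reordering of the batch
  x2 <- x[shuffle, , drop = FALSE]   # the items to mix in
  y2 <- y[shuffle]                   # their targets, reordered the same way
  # A length-n weight vector is recycled down each column of the matrix,
  # so row i is interpolated with weight[i].
  x_mixed <- lerp(x, x2, weight)
  list(
    x = x_mixed,
    y = list(ys = list(y1 = y, y2 = y2), weight = weight)
  )
}

set.seed(42)
x <- matrix(as.numeric(1:12), nrow = 4)  # 4 items, 3 features each
y <- c(10, 20, 30, 40)
str(mixup_sketch(x, y, rep(0.25, 4)))
```

The returned list has the same shape as the one documented under Value: a mixed input matrix, plus the original targets, the reordered targets, and the weights, left for the loss computation to combine.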

Examples

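library(luz)
if (torch::torch_is_installed()) {
  # a batch of 10 items with 768 features each, and one target per item
  batch_x <- torch::torch_randn(c(10, 768))
  batch_y <- torch::torch_randn(10)
  # one mixing weight per item, shaped c(10, 1) so it broadcasts across the 768 features
  weight <- torch::torch_tensor(rep(0.9, 10))$view(c(10, 1))
  nnf_mixup(batch_x, batch_y, weight)
}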
#> $x
#> torch_tensor
#> Columns 1 to 10-0.0388 -0.5147  0.1528  1.4681 -0.1335 -0.4187  0.0916  0.4432  0.3333  1.1760
#>  0.1065  1.2875  1.1909  0.3518 -0.0729  1.1922  1.9033  0.7038 -0.0371 -0.6199
#>  0.3751 -0.4785  1.4614 -1.0684  0.0674 -0.2723  0.0538  0.8777 -0.7899 -0.0546
#>  1.7172  0.6017  0.5209 -0.8755 -2.8726  0.7486  0.7031  1.3526  0.2161  0.0619
#> -0.4907 -0.5043  1.7053 -1.2472  1.3692 -1.2678 -0.6934 -0.4278  1.1817  1.6249
#>  0.7504 -0.9469  0.5333 -0.9636  0.6600  0.2351 -0.3975 -0.5084  2.0290 -1.2584
#> -0.4016 -0.5389  0.2617 -0.3103 -0.3897 -0.0355 -1.8771 -0.4196  0.9311 -1.5272
#> -0.1946 -1.2560 -0.8337  0.3822 -1.0204  0.3273  0.3029  0.7270  1.7815 -2.0175
#>  1.8852  1.4489  2.1339  1.6041  0.5192 -0.1285  0.0362  0.8030  0.5516  0.3031
#>  0.1802  0.4184 -0.3096  0.4516 -0.3219  0.9342  0.8262 -0.9324  0.6644 -0.1857
#> 
#> Columns 11 to 20 0.7707  0.6651 -1.6346 -0.5143  1.0511  0.0390 -0.6011 -0.9578 -0.5305 -0.1290
#> -0.6383 -0.4460 -0.7367 -1.9040 -0.1486 -0.0339  0.9253 -0.0205 -0.5200  0.2068
#>  0.0075  0.0436 -1.0644 -0.6352 -0.0335  0.0208  0.5054  1.0118 -1.2105  0.0660
#>  1.3097 -0.4146  1.3684 -0.2652 -0.3315 -0.9235 -1.4044  0.9823  0.0922 -1.1068
#> -0.7146  0.1868 -0.8744 -0.4222  1.0329 -1.1169  1.3180 -0.6371  0.5652 -1.5883
#> -0.5268  0.6825  0.9629 -1.3293 -0.9162  0.2485 -0.5864 -0.4100 -0.3027  0.4357
#>  0.6108  0.1319 -0.5298  0.2008 -0.5582  0.4560  0.6962 -2.6540 -0.8816  1.2026
#> -0.1099 -0.1066 -0.0099  0.3775 -0.2429 -0.4107 -1.4051 -0.8963 -0.9110 -0.1178
#>  0.7021  0.0067 -1.2622  0.3111  0.0115  0.6144  0.1081  1.4005  0.3951  0.5762
#> -1.5525 -0.0775 -1.3004  1.8015  0.7441  0.0715  0.2697  0.7478  0.5416  1.6924
#> 
#> Columns 21 to 30 0.1282 -0.0046  1.1695 -1.0575 -0.2158  0.0105 -0.9580 -0.9434  0.2543 -0.2760
#> -0.1143 -1.9662 -0.7360 -0.2951  1.5821 -0.9543  0.6810 -1.3308 -0.4612  0.5321
#> -0.5334  0.7384  0.6912  0.6091 -0.8392  0.2161 -1.0272  0.3630 -0.5804  1.1235
#>  0.5819  1.1540  0.2377 -0.9044 -1.3636 -0.7521  0.1972 -0.2546 -0.3545  0.0031
#>  1.6457  0.0663 -0.4243 -1.2821  0.6887 -0.8063 -1.4193  0.5372  1.7860 -0.5206
#>  0.7545 -0.5751 -1.0241 -1.1748 -0.3331  0.9950 -0.6852 -1.5422 -0.5874 -0.7968
#> -1.3007 -0.9115 -0.6044 -0.1871  0.0760  0.5478 -1.8621 -0.6550 -0.6051  1.3657
#> -0.3948 -1.2257  0.9265 -1.3015  1.6928 -0.9030 -0.1255 -0.1473  0.1594 -0.4501
#> ... [the output was truncated (use n=-1 to disable)]
#> [ CPUFloatType{10,768} ]
#> 
#> $y
#> $y$ys
#> $y$ys$y1
#> torch_tensor
#>  1.1525
#>  0.5113
#>  1.5977
#> -0.9638
#>  1.2944
#>  0.1385
#> -0.1867
#>  0.9214
#>  0.8290
#>  0.3817
#> [ CPUFloatType{10} ]
#> 
#> $y$ys$y2
#> torch_tensor
#>  1.5977
#>  0.8290
#>  0.9214
#> -0.9638
#>  1.1525
#>  1.2944
#>  0.1385
#> -0.1867
#>  0.3817
#>  0.5113
#> [ CPUFloatType{10} ]
#> 
#> 
#> $y$weight
#> torch_tensor
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#> [ CPUFloatType{10,1} ]
#> 
#>