
Logic underlying luz_callback_mixup().

Usage

nnf_mixup(x, y, weight)

Arguments

x

an input batch

y

a target batch

weight

mixing weights passed on to torch_lerp(); in the example below, a tensor of shape (10, 1) holding one weight per batch item
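To make the role of weight concrete: torch_lerp(start, end, weight) computes start + weight * (end - start), which is the same as (1 - weight) * start + weight * end. The base-R helper lerp() below is hypothetical and exists only to show that arithmetic. A weight of 0.9 yields 10% of start and 90% of end, and a vector of weights is applied element-wise.

```r
# Hypothetical helper mirroring the formula torch_lerp() uses:
# start + weight * (end - start) == (1 - weight) * start + weight * end
lerp <- function(start, end, weight) {
  start + weight * (end - start)
}

lerp(1, 11, 0.9)                         # 1 + 0.9 * 10, i.e. 10
lerp(c(0, 0), c(10, 20), c(0.25, 0.5))   # element-wise: 2.5 and 10
```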

Value

A list of:

  • x, the new, mixed-up input batch

  • y, a list of:

    • ys, a list of:

      • y1, the original target batch

      • y2, the targets of the mixed-in batch items

    • weight, the mixing weights, as passed in

Details

Given an input batch, a target batch, and mixing weights, nnf_mixup() returns new tensors intended to replace the current batch. Each item of the new input batch is a weighted linear combination, computed with torch_lerp(), of two items of the input batch. The targets themselves are not mixed: instead, the new target batch bundles the original targets, the targets of the mixed-in items, and the mixing weights in a nested list.
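The logic described above can be sketched as a small base-R analogue that works on plain matrices instead of tensors. mixup_sketch() is a hypothetical helper, not luz's source: it assumes the mixed-in items come from a random permutation of the same batch (consistent with y2 in the example output being a reordering of y1), and it writes out the torch_lerp() formula by hand.

```r
# Hypothetical base-R analogue of the mixup logic; NOT luz's implementation.
# Assumes mixed-in items are a random permutation of the same batch.
mixup_sketch <- function(x, y, weight) {
  perm <- sample(nrow(x))            # random reordering of batch items
  x2 <- x[perm, , drop = FALSE]      # mixed-in inputs
  y2 <- y[perm]                      # targets of the mixed-in inputs
  list(
    # Same formula as torch_lerp(): start + weight * (end - start).
    # A weight vector of length nrow(x) is recycled down each column,
    # so row i is mixed with weight[i].
    x = x + weight * (x2 - x),
    # Targets are not mixed; they are returned alongside the weights
    y = list(ys = list(y1 = y, y2 = y2), weight = weight)
  )
}

set.seed(1)
x <- matrix(rnorm(10 * 4), nrow = 10)   # 10 items, 4 features
y <- rnorm(10)                          # one target per item
res <- mixup_sketch(x, y, rep(0.9, 10))
str(res$y)
```

Returning both target vectors plus the weights, rather than a mixed target, leaves it to the loss computation to decide how the two targets are combined.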

Examples

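if (torch::torch_is_installed()) {
# A batch of 10 items with 768 features each, plus one target per item
batch_x <- torch::torch_randn(c(10, 768))
batch_y <- torch::torch_randn(10)
# One mixing weight per item, shaped (10, 1) so it broadcasts across features
weight <- torch::torch_tensor(rep(0.9, 10))$view(c(10, 1))
nnf_mixup(batch_x, batch_y, weight)
}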
#> $x
#> torch_tensor
#> Columns 1 to 10-0.6469 -0.9310  0.1524 -0.4395 -1.4442 -1.2684  0.5356 -0.6808  1.3154 -0.4752
#> -0.7167 -0.0031  0.9353 -0.6944  1.6350  1.1504  0.8301  0.8876 -0.6360 -0.8560
#>  1.4219 -1.4930 -0.0468  0.9367  0.5145  0.0812  0.1935 -0.6160 -1.3576  0.1073
#> -1.6242 -0.6707 -0.0747 -1.0209  1.0642 -0.9675  0.1963  0.1962 -0.1281 -0.5428
#> -0.6315 -0.9305  1.5236  0.2141 -0.7123  0.2535  0.0460  2.2670  1.6615  0.8583
#> -0.3951 -0.3679  0.4367 -0.8363 -0.9625 -0.5817 -0.9201  0.4515  0.0039 -0.2890
#>  0.3568  0.4232 -0.4579  0.5568  0.1149 -0.1283  0.2694 -0.5650  0.4589 -0.5266
#> -1.1249 -1.4162 -0.0954  0.2440  0.6322 -0.2990  0.5144 -1.2415 -0.7254  0.9307
#>  0.1869 -1.5618  1.3070 -1.7564 -0.2035 -1.0637 -0.4798 -1.3048 -1.1315 -0.0870
#>  1.1378 -0.3488 -1.0797 -1.0463  1.0734 -1.1346  0.8522  1.1682  0.8022 -0.4833
#> 
#> Columns 11 to 20 0.4183 -0.0387 -0.0418 -0.3676 -1.7040 -0.2589  1.8973 -0.6650  1.7713 -0.8560
#>  0.0803  1.0373 -0.9608 -0.2193  0.6525  1.7050 -0.0753 -0.1556 -1.3076  0.9122
#> -0.2489 -0.1274  0.7356  1.2849  0.1981 -0.5206 -0.2472 -1.3160  0.1743  0.0312
#> -0.6380  0.2101 -0.5718 -0.2799  1.4505 -0.2653 -0.1349  0.7861 -0.7610 -0.7263
#> -0.0696 -1.0550 -0.8051  1.5486  0.5974 -0.5647  0.9874 -0.0439  1.9012  0.0331
#> -0.1400 -0.5078  0.6952  0.1667 -0.5133  0.3830  0.1046  0.0747  0.4461 -1.4822
#>  1.0327  0.4595  0.7464  0.1188  1.0748  0.8945 -0.2496 -0.1329 -0.0348 -0.0080
#> -0.6273  0.7710 -0.0949  0.7634 -0.1691  1.1655 -0.5125 -0.3077  1.1594 -0.8206
#>  0.7015 -0.7710  0.6539 -0.5061 -0.9573  0.2129 -1.3081 -0.7200  1.1786  1.2563
#> -0.2125  1.3604 -0.0418  1.2541 -0.2737  1.1466  0.5154  0.3405  0.9556 -0.5195
#> 
#> Columns 21 to 30-1.3697  0.6732  0.1280  0.4123 -0.0084  1.0859  0.3769  0.0599  0.4626  0.0802
#>  1.5471  0.5927 -0.2893 -0.0725 -0.4028  0.0656  1.5710 -0.6081 -1.4766 -0.1458
#> -0.9554 -0.0182 -0.3317  0.6831 -0.6641 -0.5517  0.0284  0.8557 -0.8856 -0.2497
#> -1.4536  1.0637 -0.3293  1.9786  0.8624 -1.6988  0.0979 -0.6237 -0.4204 -0.2358
#>  0.2312  0.2066  0.3519 -0.0037  0.4934 -0.1126 -0.5809  0.1636  0.6992 -0.2484
#>  0.7449  0.6745  0.4239 -1.1411  0.7445 -0.6189 -1.7413 -0.5367  0.2897 -1.7761
#>  0.5034  0.9260 -0.2767 -0.6361 -0.5847  0.4803 -0.1097  0.3087  0.2315  0.0879
#> -0.5007  0.7723 -1.2654  1.6867  0.0664 -0.8799 -0.6741  0.0403  0.3340 -0.4850
#> ... [the output was truncated (use n=-1 to disable)]
#> [ CPUFloatType{10,768} ]
#> 
#> $y
#> $y$ys
#> $y$ys$y1
#> torch_tensor
#> -0.2407
#> -1.0981
#>  1.5146
#> -0.7193
#>  0.3000
#>  0.5489
#>  0.8206
#> -1.4301
#> -0.1189
#> -0.2843
#> [ CPUFloatType{10} ]
#> 
#> $y$ys$y2
#> torch_tensor
#> -0.2843
#>  0.5489
#> -0.7193
#> -1.4301
#> -0.2407
#>  1.5146
#> -1.0981
#> -0.1189
#>  0.3000
#>  0.8206
#> [ CPUFloatType{10} ]
#> 
#> 
#> $y$weight
#> torch_tensor
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#> [ CPUFloatType{10,1} ]
#> 
#>
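In the output above, y2 contains the same values as y1 in a different order, and the mixed x keeps the (10, 768) shape of the input batch. The sketch below shows how to unpack the nested list and check those two properties on a fresh batch. It assumes luz and torch are installed and only inspects what nnf_mixup() returns; the reordering check relies on the behaviour seen in this example rather than on a documented guarantee.

```r
# Run only when luz is available and the torch backend is installed
if (requireNamespace("luz", quietly = TRUE) && torch::torch_is_installed()) {
  batch_x <- torch::torch_randn(c(10, 768))
  batch_y <- torch::torch_randn(10)
  weight <- torch::torch_tensor(rep(0.9, 10))$view(c(10, 1))
  res <- luz::nnf_mixup(batch_x, batch_y, weight)

  # Unpack the nested list and convert the targets to R vectors
  y1 <- torch::as_array(res$y$ys$y1)
  y2 <- torch::as_array(res$y$ys$y2)

  # The mixed inputs keep the shape of the original batch
  stopifnot(identical(res$x$shape, batch_x$shape))
  # y2 holds the same target values as y1, possibly reordered
  stopifnot(isTRUE(all.equal(sort(y1), sort(y2))))
}
```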