Logic underlying `luz_callback_mixup()`.
Arguments
- x: an input batch
- y: a target batch
- weight: weighting coefficient to be used by `torch_lerp()`
Value
A list of:
- x, the new, mixed-up input batch
- y, a list of:
  - ys, a list of:
    - y1, the original target
    - y2, the mixed-in target
  - weight, the mixing weights
Details
Based on the passed-in input and target batches, as well as applicable mixing weights, we return new tensors intended to replace the current batch. The new input batch is a weighted linear combination of input batch items, while the new target batch bundles the original targets, as well as the mixing weights, in a nested list.
Examples
# Mix up a toy batch of 10 random examples (768 features each) using a
# constant mixing weight of 0.9 for every example.
if (torch::torch_is_installed()) {
  batch_x <- torch::torch_randn(c(10, 768))
  batch_y <- torch::torch_randn(10)
  # One weight per example, reshaped to a 10 x 1 column
  # (presumably so it broadcasts across the feature dimension of batch_x).
  weight <- torch::torch_tensor(rep(0.9, 10))$view(c(10, 1))
  nnf_mixup(x = batch_x, y = batch_y, weight = weight)
}
#> $x
#> torch_tensor
#> Columns 1 to 10-0.6469 -0.9310 0.1524 -0.4395 -1.4442 -1.2684 0.5356 -0.6808 1.3154 -0.4752
#> -0.7167 -0.0031 0.9353 -0.6944 1.6350 1.1504 0.8301 0.8876 -0.6360 -0.8560
#> 1.4219 -1.4930 -0.0468 0.9367 0.5145 0.0812 0.1935 -0.6160 -1.3576 0.1073
#> -1.6242 -0.6707 -0.0747 -1.0209 1.0642 -0.9675 0.1963 0.1962 -0.1281 -0.5428
#> -0.6315 -0.9305 1.5236 0.2141 -0.7123 0.2535 0.0460 2.2670 1.6615 0.8583
#> -0.3951 -0.3679 0.4367 -0.8363 -0.9625 -0.5817 -0.9201 0.4515 0.0039 -0.2890
#> 0.3568 0.4232 -0.4579 0.5568 0.1149 -0.1283 0.2694 -0.5650 0.4589 -0.5266
#> -1.1249 -1.4162 -0.0954 0.2440 0.6322 -0.2990 0.5144 -1.2415 -0.7254 0.9307
#> 0.1869 -1.5618 1.3070 -1.7564 -0.2035 -1.0637 -0.4798 -1.3048 -1.1315 -0.0870
#> 1.1378 -0.3488 -1.0797 -1.0463 1.0734 -1.1346 0.8522 1.1682 0.8022 -0.4833
#>
#> Columns 11 to 20 0.4183 -0.0387 -0.0418 -0.3676 -1.7040 -0.2589 1.8973 -0.6650 1.7713 -0.8560
#> 0.0803 1.0373 -0.9608 -0.2193 0.6525 1.7050 -0.0753 -0.1556 -1.3076 0.9122
#> -0.2489 -0.1274 0.7356 1.2849 0.1981 -0.5206 -0.2472 -1.3160 0.1743 0.0312
#> -0.6380 0.2101 -0.5718 -0.2799 1.4505 -0.2653 -0.1349 0.7861 -0.7610 -0.7263
#> -0.0696 -1.0550 -0.8051 1.5486 0.5974 -0.5647 0.9874 -0.0439 1.9012 0.0331
#> -0.1400 -0.5078 0.6952 0.1667 -0.5133 0.3830 0.1046 0.0747 0.4461 -1.4822
#> 1.0327 0.4595 0.7464 0.1188 1.0748 0.8945 -0.2496 -0.1329 -0.0348 -0.0080
#> -0.6273 0.7710 -0.0949 0.7634 -0.1691 1.1655 -0.5125 -0.3077 1.1594 -0.8206
#> 0.7015 -0.7710 0.6539 -0.5061 -0.9573 0.2129 -1.3081 -0.7200 1.1786 1.2563
#> -0.2125 1.3604 -0.0418 1.2541 -0.2737 1.1466 0.5154 0.3405 0.9556 -0.5195
#>
#> Columns 21 to 30-1.3697 0.6732 0.1280 0.4123 -0.0084 1.0859 0.3769 0.0599 0.4626 0.0802
#> 1.5471 0.5927 -0.2893 -0.0725 -0.4028 0.0656 1.5710 -0.6081 -1.4766 -0.1458
#> -0.9554 -0.0182 -0.3317 0.6831 -0.6641 -0.5517 0.0284 0.8557 -0.8856 -0.2497
#> -1.4536 1.0637 -0.3293 1.9786 0.8624 -1.6988 0.0979 -0.6237 -0.4204 -0.2358
#> 0.2312 0.2066 0.3519 -0.0037 0.4934 -0.1126 -0.5809 0.1636 0.6992 -0.2484
#> 0.7449 0.6745 0.4239 -1.1411 0.7445 -0.6189 -1.7413 -0.5367 0.2897 -1.7761
#> 0.5034 0.9260 -0.2767 -0.6361 -0.5847 0.4803 -0.1097 0.3087 0.2315 0.0879
#> -0.5007 0.7723 -1.2654 1.6867 0.0664 -0.8799 -0.6741 0.0403 0.3340 -0.4850
#> ... [the output was truncated (use n=-1 to disable)]
#> [ CPUFloatType{10,768} ]
#>
#> $y
#> $y$ys
#> $y$ys$y1
#> torch_tensor
#> -0.2407
#> -1.0981
#> 1.5146
#> -0.7193
#> 0.3000
#> 0.5489
#> 0.8206
#> -1.4301
#> -0.1189
#> -0.2843
#> [ CPUFloatType{10} ]
#>
#> $y$ys$y2
#> torch_tensor
#> -0.2843
#> 0.5489
#> -0.7193
#> -1.4301
#> -0.2407
#> 1.5146
#> -1.0981
#> -0.1189
#> 0.3000
#> 0.8206
#> [ CPUFloatType{10} ]
#>
#>
#> $y$weight
#> torch_tensor
#> 0.9000
#> 0.9000
#> 0.9000
#> 0.9000
#> 0.9000
#> 0.9000
#> 0.9000
#> 0.9000
#> 0.9000
#> 0.9000
#> [ CPUFloatType{10,1} ]
#>
#>