Logic underlying luz_callback_mixup().
Usage
nnf_mixup(x, y, weight)
Arguments
- x
an input batch
- y
a target batch
- weight
weighting coefficient to be used by torch_lerp()
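torch_lerp() performs linear interpolation. Given start a, end b and weight w, it returns a + w * (b - a), which equals (1 - w) * a + w * b. The base-R sketch below mimics that arithmetic on plain matrices, so no torch installation is needed. lerp_sketch() is a hypothetical helper for illustration and is not part of luz or torch. The sketch also shows why one weight per batch row works: R recycles a length-n vector down the columns of an n-row matrix, so row i uses w[i]. This is similar to how a c(n, 1) weight tensor broadcasts across the feature columns.

```r
# Hypothetical helper mirroring the arithmetic of torch_lerp():
# start + weight * (end - start)
lerp_sketch <- function(start, end, weight) {
  start + weight * (end - start)
}

# Scalar case: 30% of the way from 0 to 10
lerp_sketch(0, 10, 0.3)
#> [1] 3

# Row-wise case: one weight per row of a 3 x 2 matrix.
a <- matrix(0, nrow = 3, ncol = 2)
b <- matrix(10, nrow = 3, ncol = 2)
w <- c(0.1, 0.5, 0.9)  # recycled down each column, so row i uses w[i]
lerp_sketch(a, b, w)   # every row i becomes 10 * w[i]: rows of 1, 5 and 9
```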
Value
A list of:
- x, the new, mixed-up input batch
- y, a list of:
  - ys, a list of:
    - y1, the original target
    - y2, the mixed-in target
  - weight, the mixing weights
Details
Given the input batch, the target batch and the mixing weights, nnf_mixup() returns new tensors intended to replace the current batch. The new input batch is a weighted linear combination of items from the input batch. Targets are not interpolated. Instead, the new target batch bundles the original targets, the targets of the mixed-in items, and the mixing weights in a nested list. In the example output below, the mixed-in targets y2 are a reordering of the original targets y1.
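The logic described above can be sketched in base R under a few assumptions. Here x is a matrix with one row per batch item, y is a plain vector, and weight is a length-n vector that stands in for the c(n, 1) tensor. Each item's mixed-in partner comes from a random permutation of the batch, which is consistent with y2 being a reordering of y1 in the example output. The mixed input is computed as lerp(original, partner, weight). This page does not say whether luz passes the original batch as the start or the end of torch_lerp(), so treat that direction as an assumption. mixup_sketch() is hypothetical and is not luz's implementation.

```r
# Hypothetical base-R sketch of the mixup logic; not luz's implementation.
# Assumes x is an n x p matrix, y a length-n vector, weight a length-n vector.
mixup_sketch <- function(x, y, weight, perm = sample(nrow(x))) {
  # perm[i] picks the batch item that is mixed into item i
  x1 <- x
  x2 <- x[perm, , drop = FALSE]
  # Same arithmetic as torch_lerp(x1, x2, weight); the x1/x2 order is an assumption
  x_mixed <- x1 + weight * (x2 - x1)
  list(
    x = x_mixed,
    y = list(
      ys = list(y1 = y, y2 = y[perm]),  # original and mixed-in targets
      weight = weight                   # mixing weights, returned with the targets
    )
  )
}

set.seed(1)
batch_x <- matrix(rnorm(4 * 3), nrow = 4)
batch_y <- rnorm(4)
out <- mixup_sketch(batch_x, batch_y, weight = rep(0.9, 4), perm = c(4, 1, 2, 3))
str(out$y)
```

Passing perm explicitly makes the pairing reproducible. By default, sample() draws a fresh permutation, and that permutation can pair an item with itself. This happens to the third item in the example output, where y1 and y2 both hold 0.8066.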
Examples
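if (torch::torch_is_installed()) {
  # a batch of 10 items with 768 features each, and one target per item
  batch_x <- torch::torch_randn(c(10, 768))
  batch_y <- torch::torch_randn(10)
  # one mixing weight per item, shaped c(10, 1) so it broadcasts across the features
  weight <- torch::torch_tensor(rep(0.9, 10))$view(c(10, 1))
  nnf_mixup(batch_x, batch_y, weight)
}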
#> $x
#> torch_tensor
#> Columns 1 to 10-0.4666 -0.3302 -0.0954 0.1768 0.9489 0.6147 0.5509 0.4224 -0.2634 -0.4973
#> -0.1252 1.1699 1.5377 0.9160 1.2721 0.0309 -0.9380 -0.0820 -0.3290 0.2987
#> 0.0542 0.3228 -0.5229 0.6594 -1.3018 -0.2767 1.4184 -0.6628 1.4401 -0.3181
#> -1.5442 -0.5685 0.4252 -0.6779 0.4632 -0.5835 -1.9980 0.4899 -0.9148 -0.2087
#> -0.8431 -0.8364 -0.4039 -0.7497 -0.5604 -0.3975 -0.1721 -1.6273 2.7207 1.5245
#> 1.2143 1.6179 0.7663 0.8420 -0.0036 -0.0467 0.7109 -0.0013 -0.7940 0.4381
#> -1.3108 -0.1772 0.4936 -0.9542 -1.3703 1.1584 2.1954 -0.1444 0.3671 0.3405
#> 0.6942 -0.0074 0.6826 -0.0585 -0.6870 -0.0381 -1.6109 0.9607 0.6577 -0.6296
#> 1.1197 -0.8918 0.5665 0.5156 -0.5087 -0.6637 -0.2623 -1.5227 1.6742 -1.4385
#> 0.1608 1.1332 -0.2821 0.2620 0.6137 0.7640 0.7132 -0.5285 0.2820 0.3814
#>
#> Columns 11 to 20 0.6963 0.9946 -0.0897 -1.5790 1.2424 0.4818 0.6354 -0.6699 -1.2644 0.4841
#> 0.3006 1.0723 -0.6502 -1.9213 -0.3512 0.2582 -2.5799 0.5079 0.4422 -0.8090
#> -1.6571 -0.7464 -1.7214 -0.5732 1.9138 0.4625 0.1397 -0.3193 -0.6694 -0.4603
#> 0.9213 0.2529 -0.2167 -0.5549 -0.6101 -0.0981 0.4016 -0.1353 0.2715 -0.7108
#> -0.7686 -0.1930 -0.1708 -0.3494 1.1729 -0.7997 -0.5097 -0.5724 0.4498 -0.2360
#> -1.3915 0.4599 -0.5039 -0.4208 -0.0983 0.9079 0.2558 -1.0477 -0.4175 -0.1999
#> -0.5736 -0.0324 -0.4937 0.2072 -0.0246 -0.2932 -0.9702 -0.9180 -0.2510 0.2096
#> 1.8065 -0.7042 -1.5242 1.1530 0.0400 1.5453 -1.4683 -0.0437 0.6381 0.0741
#> 0.9144 0.4296 -0.1361 -0.2350 1.2853 -1.5669 -0.8798 0.3402 0.6803 0.3532
#> -0.5817 0.3353 -0.4928 -0.3588 0.4600 -1.1416 -0.6120 -1.0405 -0.0766 -0.7525
#>
#> Columns 21 to 30 0.4310 0.5068 1.6607 1.8781 1.3990 1.2737 -0.5521 -0.7930 0.5067 -1.1764
#> -0.3932 0.6965 0.8264 0.8000 0.5519 -1.4412 0.6280 -1.6421 -0.3239 -1.7591
#> -0.2934 0.7490 0.6991 0.4614 0.4635 0.9153 0.0861 0.7032 -0.7959 0.3185
#> -0.0027 -0.8703 0.2375 0.9798 1.3430 1.3003 -0.2012 0.9301 0.0822 0.1943
#> -1.1636 1.5258 0.2249 0.9825 -0.1999 0.6641 0.4574 1.8420 0.7305 0.1278
#> 0.9659 0.1129 0.2226 0.3009 1.1808 -2.5042 -0.7563 0.6117 -0.1254 -0.8380
#> 1.0677 -0.6868 1.1988 -0.0568 -0.0821 1.4870 -0.1164 1.2909 -0.1225 1.4425
#> -0.5155 -0.0619 0.4617 0.9681 -0.0395 0.4291 -0.0748 -0.7954 0.1731 -1.0585
#> ... [the output was truncated (use n=-1 to disable)]
#> [ CPUFloatType{10,768} ]
#>
#> $y
#> $y$ys
#> $y$ys$y1
#> torch_tensor
#> 0.8623
#> 2.2379
#> 0.8066
#> 0.3824
#> 1.2061
#> -0.1632
#> -0.8430
#> 2.1627
#> -0.0679
#> -0.2768
#> [ CPUFloatType{10} ]
#>
#> $y$ys$y2
#> torch_tensor
#> -0.2768
#> -0.0679
#> 0.8066
#> 0.8623
#> 0.3824
#> 2.1627
#> 1.2061
#> 2.2379
#> -0.1632
#> -0.8430
#> [ CPUFloatType{10} ]
#>
#>
#> $y$weight
#> torch_tensor
#> 0.9000
#> 0.9000
#> 0.9000
#> 0.9000
#> 0.9000
#> 0.9000
#> 0.9000
#> 0.9000
#> 0.9000
#> 0.9000
#> [ CPUFloatType{10,1} ]
#>
#>