
Logic underlying luz_callback_mixup().

Usage

nnf_mixup(x, y, weight)

Arguments

x

an input batch

y

a target batch

weight

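weighting coefficient(s) to be used by torch_lerp(); must broadcast against x, e.g. a tensor of shape (batch size, 1) for a two-dimensional input batch, as in the example below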

Value

A list of:

  • x, the new, mixed-up input batch

  • y, a list of:

    • ys, a list of:

      • y1, the original targets

      • y2, the targets of the mixed-in items

    • weight, the mixing weights
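The nested target list is meant to be consumed by a loss that evaluates the predictions against both y1 and y2. The base-R sketch below illustrates the idea with a squared-error loss on plain numeric vectors. The helper mixup_mse() is hypothetical and not part of luz, and it assumes that weight is the share of the original target y1 in each item's mix; luz's own convention may differ.

```r
# Hypothetical helper (not part of luz): a mixup-aware mean squared error.
# Assumption: `weight` is the share of the original target y1 in each item.
mixup_mse <- function(pred, target) {
  y1 <- target$ys$y1               # original targets
  y2 <- target$ys$y2               # targets of the mixed-in items
  w  <- as.vector(target$weight)   # one weight per item
  l1 <- (pred - y1)^2              # per-item loss against y1
  l2 <- (pred - y2)^2              # per-item loss against y2
  mean(w * l1 + (1 - w) * l2)      # weighted per item, then averaged
}

# Usage: two items whose targets were swapped by the mixing step
target <- list(ys = list(y1 = c(1, 2), y2 = c(2, 1)), weight = c(0.9, 0.9))
mixup_mse(c(1, 2), target)
#> [1] 0.1
```

Because each item keeps its own weight, the same structure works whether the weights are constant (as in the example below) or drawn per item.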

Details

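Based on the passed-in input and target batches, as well as the applicable mixing weights, we return new tensors intended to replace the current batch. The new input batch is a weighted linear combination, computed via torch_lerp(), of each input item and another item from the same batch. The targets themselves are not mixed: the new target batch bundles the original targets (y1), the targets of the mixed-in items (y2), and the mixing weights in a nested list, so that the loss can later be computed against both targets and weighted accordingly. As the example output shows, y2 is a reshuffled copy of y1.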
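To make the mixing step concrete without torch, here is a base-R sketch on a plain matrix batch. It is an approximation under stated assumptions, not luz's implementation: mixup_sketch() and lerp() are hypothetical helpers, the mixed-in items are assumed to come from a random permutation of the batch, and weight is assumed to be the share of the original item. lerp() follows the torch_lerp() formula, start + weight * (end - start).

```r
# Linear interpolation, same formula as torch_lerp(): start + weight * (end - start)
lerp <- function(start, end, weight) start + weight * (end - start)

# Hypothetical helper (not part of luz) mimicking the mixup step on a matrix
# whose rows are batch items. Assumptions: partners come from a random
# permutation of the batch, and `weight` is the share of the original item.
mixup_sketch <- function(x, y, weight) {
  shuffle <- sample(nrow(x))             # random partner for every item
  x1 <- x
  x2 <- x[shuffle, , drop = FALSE]       # partner items
  y1 <- y
  y2 <- y[shuffle]                       # partner targets
  list(
    # a weight vector of length nrow(x) is recycled down each column,
    # so row i is mixed with weight[i]
    x = lerp(x2, x1, weight),            # = weight * x1 + (1 - weight) * x2
    y = list(ys = list(y1 = y1, y2 = y2), weight = weight)
  )
}

set.seed(1)
x <- matrix(1:12, nrow = 4)
res <- mixup_sketch(x, y = c(10, 20, 30, 40), weight = rep(0.9, 4))
res$y$ys$y2   # a reordering of the original targets
```

In torch, the (batch size, 1) weight tensor plays the role that column-wise recycling of the weight vector plays here: it broadcasts one weight across all features of an item.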

Examples

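if (torch::torch_is_installed()) {
  # a batch of 10 items with 768 features each, and one target per item
  batch_x <- torch::torch_randn(c(10, 768))
  batch_y <- torch::torch_randn(10)
  # one mixing weight per item, shaped (10, 1) so it broadcasts over the features
  weight <- torch::torch_tensor(rep(0.9, 10))$view(c(10, 1))
  nnf_mixup(batch_x, batch_y, weight)
}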
#> $x
#> torch_tensor
#> Columns 1 to 10-0.4666 -0.3302 -0.0954  0.1768  0.9489  0.6147  0.5509  0.4224 -0.2634 -0.4973
#> -0.1252  1.1699  1.5377  0.9160  1.2721  0.0309 -0.9380 -0.0820 -0.3290  0.2987
#>  0.0542  0.3228 -0.5229  0.6594 -1.3018 -0.2767  1.4184 -0.6628  1.4401 -0.3181
#> -1.5442 -0.5685  0.4252 -0.6779  0.4632 -0.5835 -1.9980  0.4899 -0.9148 -0.2087
#> -0.8431 -0.8364 -0.4039 -0.7497 -0.5604 -0.3975 -0.1721 -1.6273  2.7207  1.5245
#>  1.2143  1.6179  0.7663  0.8420 -0.0036 -0.0467  0.7109 -0.0013 -0.7940  0.4381
#> -1.3108 -0.1772  0.4936 -0.9542 -1.3703  1.1584  2.1954 -0.1444  0.3671  0.3405
#>  0.6942 -0.0074  0.6826 -0.0585 -0.6870 -0.0381 -1.6109  0.9607  0.6577 -0.6296
#>  1.1197 -0.8918  0.5665  0.5156 -0.5087 -0.6637 -0.2623 -1.5227  1.6742 -1.4385
#>  0.1608  1.1332 -0.2821  0.2620  0.6137  0.7640  0.7132 -0.5285  0.2820  0.3814
#> 
#> Columns 11 to 20 0.6963  0.9946 -0.0897 -1.5790  1.2424  0.4818  0.6354 -0.6699 -1.2644  0.4841
#>  0.3006  1.0723 -0.6502 -1.9213 -0.3512  0.2582 -2.5799  0.5079  0.4422 -0.8090
#> -1.6571 -0.7464 -1.7214 -0.5732  1.9138  0.4625  0.1397 -0.3193 -0.6694 -0.4603
#>  0.9213  0.2529 -0.2167 -0.5549 -0.6101 -0.0981  0.4016 -0.1353  0.2715 -0.7108
#> -0.7686 -0.1930 -0.1708 -0.3494  1.1729 -0.7997 -0.5097 -0.5724  0.4498 -0.2360
#> -1.3915  0.4599 -0.5039 -0.4208 -0.0983  0.9079  0.2558 -1.0477 -0.4175 -0.1999
#> -0.5736 -0.0324 -0.4937  0.2072 -0.0246 -0.2932 -0.9702 -0.9180 -0.2510  0.2096
#>  1.8065 -0.7042 -1.5242  1.1530  0.0400  1.5453 -1.4683 -0.0437  0.6381  0.0741
#>  0.9144  0.4296 -0.1361 -0.2350  1.2853 -1.5669 -0.8798  0.3402  0.6803  0.3532
#> -0.5817  0.3353 -0.4928 -0.3588  0.4600 -1.1416 -0.6120 -1.0405 -0.0766 -0.7525
#> 
#> Columns 21 to 30 0.4310  0.5068  1.6607  1.8781  1.3990  1.2737 -0.5521 -0.7930  0.5067 -1.1764
#> -0.3932  0.6965  0.8264  0.8000  0.5519 -1.4412  0.6280 -1.6421 -0.3239 -1.7591
#> -0.2934  0.7490  0.6991  0.4614  0.4635  0.9153  0.0861  0.7032 -0.7959  0.3185
#> -0.0027 -0.8703  0.2375  0.9798  1.3430  1.3003 -0.2012  0.9301  0.0822  0.1943
#> -1.1636  1.5258  0.2249  0.9825 -0.1999  0.6641  0.4574  1.8420  0.7305  0.1278
#>  0.9659  0.1129  0.2226  0.3009  1.1808 -2.5042 -0.7563  0.6117 -0.1254 -0.8380
#>  1.0677 -0.6868  1.1988 -0.0568 -0.0821  1.4870 -0.1164  1.2909 -0.1225  1.4425
#> -0.5155 -0.0619  0.4617  0.9681 -0.0395  0.4291 -0.0748 -0.7954  0.1731 -1.0585
#> ... [the output was truncated (use n=-1 to disable)]
#> [ CPUFloatType{10,768} ]
#> 
#> $y
#> $y$ys
#> $y$ys$y1
#> torch_tensor
#>  0.8623
#>  2.2379
#>  0.8066
#>  0.3824
#>  1.2061
#> -0.1632
#> -0.8430
#>  2.1627
#> -0.0679
#> -0.2768
#> [ CPUFloatType{10} ]
#> 
#> $y$ys$y2
#> torch_tensor
#> -0.2768
#> -0.0679
#>  0.8066
#>  0.8623
#>  0.3824
#>  2.1627
#>  1.2061
#>  2.2379
#> -0.1632
#> -0.8430
#> [ CPUFloatType{10} ]
#> 
#> 
#> $y$weight
#> torch_tensor
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#>  0.9000
#> [ CPUFloatType{10,1} ]
#> 
#>