benchmark softmax on cuda
nouvia, Dec 18th, 2020 (edited)
using CUDA, NNlib, BenchmarkTools, Test, BSON
using Plots
pyplot() # for scale = :log2

# Keep each trial cheap: the sweep below covers many sizes, so a handful of
# short samples per size is enough.
BenchmarkTools.DEFAULT_PARAMETERS.gctrial = false
BenchmarkTools.DEFAULT_PARAMETERS.samples = 5
BenchmarkTools.DEFAULT_PARAMETERS.seconds = 0.2
BenchmarkTools.DEFAULT_PARAMETERS.evals = 1

function _softmax!(y::T, x::T; dims) where {T<:DenseCuArray}
    y .= exp.(x .- maximum(x; dims))
    y ./= sum(y; dims)
end
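# Why the max shift: softmax is shift-invariant, and after subtracting the
# per-slice maximum every exponent is <= 0, so exp cannot overflow.
# Illustrative check:
#   v = Float32[1000, 1001]
#   exp.(v)                 # (Inf, Inf)
#   exp.(v .- maximum(v))   # (0.36788, 1.0), finite, and softmax is unchanged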

# Validate both implementations against the CPU NNlib reference along every
# dimension of a small 7-d array.
@testset begin
    x = CUDA.rand((2:8)...)
    out = similar(x)
    for i = 1:ndims(x)
        ref = softmax(Array(x), dims = i) |> cu
        @test softmax!(out, x, dims = i) ≈ ref
        @test _softmax!(out, x, dims = i) ≈ ref
    end
end

function plot_perf!(plt, x, y, ds, dims)
    plot!(
        plt,
        x,
        y,
        label = "$(ds) dims=$(dims)",
        legend = :outertopright,
        title = "benchmark softmax
Ratio = log(Julia Time / CUDNN Time)
>0 means CUDNN is faster than Julia",
        xlabel = "batch size",
        ylabel = "Ratio",
        # yscale = :log10,
        xscale = :log2,
        dpi = 300,
    )
    plt
end

select_last_dim(xs::AbstractArray{T,N}, inds) where {T,N} =
    @views xs[ntuple(_ -> (:), N - 1)..., inds]
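# E.g. for a 3-d array A, select_last_dim(A, 1:b) is @views A[:, :, 1:b]: a
# view of the first b batch slices, so no copy is made when slicing below.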

function benchsoftmax(nd, maxbatch)
    x = CUDA.rand(nd..., 2^maxbatch)
    out = similar(x)
    batches = 2 .^ (1:maxbatch) # change it to the desired values.
    results = Array{BenchmarkTools.Trial}(undef, length(batches), 2)
    for dims = 1:ndims(x)
        # Skip dims that already have results on disk, so an interrupted sweep
        # can be resumed.
        isfile("benchmark_softmax_$(nd)_$(dims).BSON") && continue
        for (i, b) in enumerate(batches)
            println("$i / $(length(batches))")
            y = select_last_dim(x, 1:b)
            o = select_last_dim(out, 1:b) # write into the preallocated output buffer
            # Column 1: plain Julia; column 2: CUDNN-backed NNlib.
            for (j, fn) in [(1, _softmax!), (2, softmax!)]
                results[i, j] = @benchmark CUDA.@sync $fn($o, $y, dims = $dims)
            end
        end
        BSON.@save "benchmark_softmax_$(nd)_$(dims).BSON" nd batches dims results
    end
end
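# Optional smoke test before committing to the full sweep (illustrative, not
# part of the original script):
#   x = CUDA.rand(1024, 256); out = similar(x)
#   @benchmark CUDA.@sync _softmax!($out, $x, dims = 1)
#   @benchmark CUDA.@sync softmax!($out, $x, dims = 1)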

for i in (6:12)
    benchsoftmax((2^i,), 24 - i)
end
benchsoftmax((1024, 1024), 8)
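# Each 1-d call above allocates 2^i * 2^(24 - i) = 2^24 elements, so the
# footprint stays constant and wider feature sizes get fewer batch points; the
# 2-d call sweeps batches up to 2^8 over 1024x1024 slices.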

# Result files to plot, grouped seven per figure. The trailing empty strings
# pad the list to a multiple of seven so the `i % 7 == 0` save below still
# fires for the last figure; their load errors are swallowed by the try/catch.
fnames = [
    "benchmark_softmax_(64,)_1.BSON"
    "benchmark_softmax_(128,)_1.BSON"
    "benchmark_softmax_(256,)_1.BSON"
    "benchmark_softmax_(512,)_1.BSON"
    "benchmark_softmax_(1024,)_1.BSON"
    "benchmark_softmax_(2048,)_1.BSON"
    "benchmark_softmax_(4096,)_1.BSON"
    "benchmark_softmax_(64,)_2.BSON"
    "benchmark_softmax_(128,)_2.BSON"
    "benchmark_softmax_(256,)_2.BSON"
    "benchmark_softmax_(512,)_2.BSON"
    "benchmark_softmax_(1024,)_2.BSON"
    "benchmark_softmax_(2048,)_2.BSON"
    "benchmark_softmax_(4096,)_2.BSON"
    "benchmark_softmax_(1024, 1024)_1.BSON"
    "benchmark_softmax_(1024, 1024)_2.BSON"
    "benchmark_softmax_(1024, 1024)_3.BSON"
    ""
    ""
    ""
    ""
]

for (i, fname) in enumerate(fnames)
    if i == 1
        plt = plot()
    end
    try
        BSON.@load fname nd batches dims results
        dur = time.(median.(results))
        ratio = log.(dur[:, 1] ./ dur[:, 2]) # column 1: Julia, column 2: CUDNN
        plot_perf!(plt, batches, ratio, "$nd x B", dims)
    catch
        # missing files and padding entries are skipped silently
    end
    if i % 7 == 0
        savefig(plt, "benchmark_softmax_$i.svg")
        savefig(plt, "benchmark_softmax_$i.png")
        plt = plot()
    end
end