Fused Softmax — Triton documentation

Note

Go to the end to download the full example code.

In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.

In doing so, you will learn about:

  • The benefits of kernel fusion for bandwidth-bound operations.

  • Reduction operators in Triton.

Motivations

Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:

import torch

import triton
import triton.language as tl
from triton.runtime import driver


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')


def naive_softmax(x):
    """Compute row-wise softmax of X using native PyTorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read  MN elements ; write M  elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read  MN elements ; write MN elements
    numerator = torch.exp(z)
    # read  MN elements ; write M  elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret
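The shift invariance mentioned in the docstring can be checked numerically in plain Python (a small illustrative sketch, not part of the tutorial's code):

```python
import math

def softmax_row(row):
    # Subtract the row maximum before exponentiating, exactly as naive_softmax does.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]

row = [1.0, 2.0, 3.0]
shifted = [v + 100.0 for v in row]  # adding a constant must not change the result
a = softmax_row(row)
b = softmax_row(shifted)
assert all(abs(x - y) < 1e-12 for x, y in zip(a, b))
```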

When implemented naively in PyTorch, computing y = naive_softmax(x) for an M-by-N matrix requires reading 5MN + 2M elements from DRAM and writing back 3MN + 2M elements. This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads X once and does all the necessary computations on-chip. Doing so would require reading and writing back only 2MN elements (read X once, write Y once), so we could expect a theoretical speed-up of ~4x (i.e., (8MN + 4M) / 2MN). The torch.jit.script flag aims to perform this kind of "kernel fusion" automatically but, as we will see later, it is still far from ideal.
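The memory-traffic argument above can be checked with a short back-of-the-envelope script (pure Python, no GPU needed; the shapes are illustrative):

```python
def naive_softmax_traffic(M, N):
    """Elements moved to/from DRAM by the naive, unfused softmax:
    reads = MN + (MN + M) + MN + MN + (MN + M) = 5MN + 2M
    writes = M + MN + MN + M + MN = 3MN + 2M"""
    return (5 * M * N + 2 * M) + (3 * M * N + 2 * M)

def fused_softmax_traffic(M, N):
    """A fused kernel reads X once and writes Y once."""
    return 2 * M * N

M, N = 4096, 1024
speedup = naive_softmax_traffic(M, N) / fused_softmax_traffic(M, N)
print(f"theoretical speed-up ~= {speedup:.2f}x")  # approaches 4x as N grows
```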

Compute Kernel

Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes them, and writes back the results to the output Y.
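In pure Python, this grid-stride assignment of rows to programs looks like the following (a sketch with illustrative values for `num_programs` and `n_rows`):

```python
def rows_for_program(program_id, num_programs, n_rows):
    """Rows handled by one persistent program: start at its id, stride by the grid size."""
    return list(range(program_id, n_rows, num_programs))

n_rows, num_programs = 10, 4
partition = [rows_for_program(pid, num_programs, n_rows) for pid in range(num_programs)]
# Every row is covered exactly once across all programs:
assert sorted(r for rows in partition for r in rows) == list(range(n_rows))
```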

Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:

@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the next power of two greater than n_cols, so we can fit each
        # row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
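The padding trick used by the kernel can be illustrated in plain Python (a sketch mirroring `triton.next_power_of_2` and the load mask; filling masked-off lanes with `-inf` leaves both the row maximum and the exponential sum unchanged, since exp(-inf) = 0):

```python
import math

def next_power_of_2(n):
    """Smallest power of two >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()

n_cols = 781                                     # an irregular column count
BLOCK_SIZE = next_power_of_2(n_cols)             # 1024
mask = [i < n_cols for i in range((BLOCK_SIZE))]
# Valid lanes hold a value (0.0 here); masked-off lanes hold -inf:
row = [0.0 if m else float("-inf") for m in mask]
# The padding contributes exactly 0 to the softmax denominator:
assert sum(math.exp(v) for v in row) == sum(mask)
```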

We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

device = torch.cuda.current_device()
properties = driver.active.utils.get_device_properties(device)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel, num_programs = kernels.get(BLOCK_SIZE, (None, 0))
    if kernel is None:
        kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                       num_stages=num_stages, num_warps=num_warps, grid=(1, ))
        kernel._init_handles()
        n_regs = kernel.n_regs
        size_smem = kernel.metadata.shared
        if is_hip():
            # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
            # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
            # ISA SECTION (3.6.4 for CDNA3)
            # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
            # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
            # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
            # not required to be equal numbers of both types.
            NUM_GPRS = NUM_REGS
            if is_cdna():
                NUM_GPRS = NUM_REGS * 2

            # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
            # When we divide this number with WARP_SIZE we get maximum number of waves that can
            # execute on a CU (multi-processor) in parallel.
            MAX_NUM_THREADS = properties["max_threads_per_sm"]
            max_num_waves = MAX_NUM_THREADS // WARP_SIZE
            occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
        else:
            occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
        occupancy = min(occupancy, SIZE_SMEM // size_smem)
        num_programs = NUM_SM * occupancy
        kernels[BLOCK_SIZE] = (kernel, num_programs)

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](
        y,
        x,
        x.stride(0),
        y.stride(0),
        n_rows,
        n_cols,
    )
    return y
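The occupancy arithmetic in the non-HIP branch can be walked through with hypothetical numbers (the device and kernel figures below are made-up examples for illustration, not queried from any real device):

```python
# Hypothetical device properties (illustrative only).
NUM_REGS = 65536      # registers per SM
WARP_SIZE = 32
NUM_SM = 108
SIZE_SMEM = 166912    # shared memory per SM, in bytes

# Hypothetical compiled-kernel figures (illustrative only).
n_regs = 32           # registers used per thread
size_smem = 4096      # shared memory used per program, in bytes
num_warps = 8

# Programs that fit per SM, limited by the register file...
occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
# ...and further limited by shared memory:
occupancy = min(occupancy, SIZE_SMEM // size_smem)
# One persistent program per available slot across all SMs:
num_programs = NUM_SM * occupancy
print(occupancy, num_programs)
```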

Unit Test

We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.

torch.manual_seed(0)
x = torch.randn(1823, 781, device='cuda')
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

As expected, the results match to within floating-point tolerance.

Benchmark

Here we will benchmark our operation as a function of the number of columns in the input matrix, assuming 4096 rows. We will then compare its performance against torch.softmax.

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_names`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch'],  # possible values for `line_arg`
        line_names=[
            "Triton",
            "Torch",
        ],  # label name for the lines
        styles=[('blue', '-'), ('green', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device='cuda', dtype=torch.float32)
    stream = torch.cuda.Stream()
    torch.cuda.set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)
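The GB/s figure reported by the benchmark counts each element as read once and written once, hence the factor of 2 in the `gbps` lambda. As a standalone sketch (the 4096-by-8192 float32 shape and 0.5 ms timing below are illustrative values):

```python
def gbps(ms, numel, element_size=4):
    """Effective bandwidth in GB/s: one read plus one write of the whole tensor."""
    bytes_moved = 2 * numel * element_size
    return bytes_moved * 1e-9 / (ms * 1e-3)

# A 4096 x 8192 float32 tensor processed in 0.5 ms:
bw = gbps(0.5, 4096 * 8192)
print(f"{bw:.1f} GB/s")  # -> 536.9 GB/s
```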


softmax-performance:
          N       Triton        Torch
0     256.0   478.919310   687.273787
1     384.0   605.411715   819.289040
2     512.0   761.099367   934.683618
3     640.0   785.552905   962.016130
4     768.0   880.073059  1025.759534
5     896.0   927.677054  1058.137819
6    1024.0   983.837789  1108.743692
7    1152.0  1096.790360   614.358556
8    1280.0  1148.233274   668.818744
9    1408.0  1150.227640   720.693012
10   1536.0  1182.473582   782.216974
11   1664.0  1216.574883   814.715998
12   1792.0  1242.628231   854.750376
13   1920.0  1253.575856   907.742024
14   2048.0  1269.852891   953.080248
15   2176.0  1257.273610   974.432016
16   2304.0  1262.677874  1011.595565
17   2432.0  1299.296021  1056.308627
18   2560.0  1306.784949  1082.867302
19   2688.0  1313.819078  1099.897287
20   2816.0  1318.463265  1126.399926
21   2944.0  1322.835757  1166.786437
22   3072.0  1346.370285  1180.412745
23   3200.0  1350.082383  1195.441919
24   3328.0  1360.093837  1225.507456
25   3456.0  1375.795286  1245.904506
26   3584.0  1372.955622  1261.663759
27   3712.0  1386.851514  1269.518524
28   3840.0  1383.941865  1303.207851
29   3968.0  1389.301195  1317.206272
30   4096.0  1397.348688  1324.026193
31   4224.0  1337.787438  1161.983073
32   4352.0  1334.628317  1175.156812
33   4480.0  1357.196330  1181.139507
34   4608.0  1364.957502  1195.702494
35   4736.0  1361.432159  1198.436404
36   4864.0  1378.560609  1219.552557
37   4992.0  1367.111266  1232.723697
38   5120.0  1374.124506  1255.575493
39   5248.0  1378.665107  1261.659380
40   5376.0  1376.411844  1287.514489
41   5504.0  1380.046906  1299.148837
42   5632.0  1382.392228  1315.560426
43   5760.0  1393.669997  1326.327157
44   5888.0  1389.150289  1340.074174
45   6016.0  1398.670392  1350.904213
46   6144.0  1410.218972  1376.087527
47   6272.0  1414.080366  1376.465508
48   6400.0  1416.444662  1385.075582
49   6528.0  1418.413758  1396.029679
50   6656.0  1424.202937  1400.217751
51   6784.0  1409.345707  1415.868223
52   6912.0  1428.837209  1421.806350
53   7040.0  1419.389760  1432.133212
54   7168.0  1424.869740  1433.771252
55   7296.0  1428.331895  1444.078677
56   7424.0  1430.010094  1446.560973
57   7552.0  1423.821714  1456.153560
58   7680.0  1433.745404  1458.530407
59   7808.0  1432.200076  1463.867323
60   7936.0  1433.262612  1465.890268
61   8064.0  1437.370991  1475.045892
62   8192.0  1434.594521  1482.481691
63   8320.0  1386.720429  1399.620234
64   8448.0  1381.399651  1402.663422
65   8576.0  1394.701625  1398.532505
66   8704.0  1384.298625  1400.508376
67   8832.0  1382.690748  1405.417788
68   8960.0  1396.893196  1410.663745
69   9088.0  1407.815867  1416.328473
70   9216.0  1402.560109  1426.002938
71   9344.0  1398.795763  1423.569547
72   9472.0  1398.391728  1434.063187
73   9600.0  1399.009274  1435.793545
74   9728.0  1396.017273  1440.856696
75   9856.0  1414.772213  1440.195824
76   9984.0  1393.640290  1453.226701
77  10112.0  1411.025063  1454.434968
78  10240.0  1416.371231  1466.229997
79  10368.0  1411.954147  1465.497981
80  10496.0  1414.468529  1469.190538
81  10624.0  1408.722590  1467.668473
82  10752.0  1404.295464  1469.379278
83  10880.0  1400.953900  1482.342526
84  11008.0  1420.597508  1474.695685
85  11136.0  1424.100701  1485.490602
86  11264.0  1430.205349  1488.547851
87  11392.0  1411.913312  1489.778907
88  11520.0  1422.768767  1494.569038
89  11648.0  1423.179210  1498.987919
90  11776.0  1431.071404  1500.151596
91  11904.0  1442.538703  1507.563154
92  12032.0  1421.563295  1508.734560
93  12160.0  1419.964136  1509.722295
94  12288.0  1435.906495  1391.634053
95  12416.0  1449.353849  1390.744306
96  12544.0  1441.752742  1392.910081
97  12672.0  1445.543061  1392.769793

In the above data, we can see that:

  • Triton is noticeably faster than torch.softmax over a broad mid-range of column counts in this run, and its throughput stays smooth where the Torch curve shows sharp drops (e.g., just past N = 1024). This is consistent with our analysis: the fused kernel reads and writes each element only once.

  • The Triton version is also easier to read, understand and maintain. Note however that the PyTorch softmax operation is more general and will work on tensors of any shape.

Total running time of the script: (0 minutes 23.443 seconds)

Gallery generated by Sphinx-Gallery