.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "intermediate/scaled_dot_product_attention_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_intermediate_scaled_dot_product_attention_tutorial.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_intermediate_scaled_dot_product_attention_tutorial.py:


(Beta) Implementing High-Performance Transformers with Scaled Dot Product Attention (SDPA)
==========================================================================================


**Author:** `Driss Guessous <https://github.com/drisspg>`_

.. GENERATED FROM PYTHON SOURCE LINES 10-42

Summary
~~~~~~~~

In this tutorial, we want to highlight a new ``torch.nn.functional`` function
that can be helpful for implementing transformer architectures. The
function is named ``torch.nn.functional.scaled_dot_product_attention``.
For a detailed description of the function, see the `PyTorch documentation <https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention>`__.
This function has already been incorporated into ``torch.nn.MultiheadAttention`` and ``torch.nn.TransformerEncoderLayer``.

Overview
~~~~~~~~~
At a high level, this PyTorch function calculates the
scaled dot product attention (SDPA) between query, key, and value according to
the definition found in the paper `Attention is all you
need <https://arxiv.org/abs/1706.03762>`__. While this function can
be written in PyTorch using existing functions, a fused implementation can provide
large performance benefits over a naive implementation.
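
To make the comparison concrete, here is a minimal sketch of such a naive
implementation built from existing PyTorch operations. The helper name
``naive_sdpa`` is hypothetical and exists only for illustration; the sketch
ignores masking and dropout, and it runs on CPU so that it works without a GPU.

.. code-block:: default

    import math

    import torch
    import torch.nn.functional as F


    def naive_sdpa(query, key, value):
        # Hypothetical helper: softmax(Q @ K^T / sqrt(d)) @ V, without mask or dropout
        scale = 1.0 / math.sqrt(query.size(-1))
        scores = query @ key.transpose(-2, -1) * scale
        weights = torch.softmax(scores, dim=-1)
        return weights @ value


    torch.manual_seed(0)
    q, k, v = (torch.randn(2, 3, 8) for _ in range(3))

    reference = F.scaled_dot_product_attention(q, k, v)
    naive = naive_sdpa(q, k, v)

    # The two results should agree up to floating point error
    max_abs_diff = (naive - reference).abs().max().item()
    print(f"Maximum absolute difference: {max_abs_diff:.2e}")

The naive version materializes the full attention weight matrix, which is the
intermediate that a fused implementation can avoid writing out.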

Fused implementations
~~~~~~~~~~~~~~~~~~~~~~

For CUDA tensor inputs, the function will dispatch into one of the following
implementations:

* `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness <https://arxiv.org/abs/2205.14135>`__
* `Memory-Efficient Attention <https://github.com/facebookresearch/xformers>`__
* A PyTorch implementation defined in C++

.. note::

  This tutorial requires PyTorch 2.0.0 or later.


.. GENERATED FROM PYTHON SOURCE LINES 42-53

.. code-block:: default


    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Example Usage:
    query, key, value = torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device)
    F.scaled_dot_product_attention(query, key, value)



.. GENERATED FROM PYTHON SOURCE LINES 54-65

Explicit Dispatcher Control
~~~~~~~~~~~~~~~~~~~~~~~~~~~

While the function will implicitly dispatch to one of the three
implementations, the user can also explicitly control the dispatch via
the use of a context manager. This context manager allows users to
explicitly disable certain implementations. If a user wants to ensure
the function is indeed using the fastest implementation for their
specific inputs, the context manager can be used to sweep through
the implementations and measure the performance of each.


.. GENERATED FROM PYTHON SOURCE LINES 65-116

.. code-block:: default


    # Let's define a helpful benchmarking function:
    import torch.utils.benchmark as benchmark
    def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
        t0 = benchmark.Timer(
            stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
        )
        return t0.blocked_autorange().mean * 1e6

    # Let's define the hyper-parameters of our input
    batch_size = 32
    max_sequence_len = 1024
    num_heads = 32
    embed_dimension = 32

    dtype = torch.float16

    query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
    key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
    value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)

    print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")

    # Let's explore the speed of each of the 3 implementations
    from torch.backends.cuda import sdp_kernel, SDPBackend

    # Helpful arguments mapper
    backend_map = {
        SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
        SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
        SDPBackend.EFFICIENT_ATTENTION: {
            "enable_math": False, "enable_flash": False, "enable_mem_efficient": True}
    }

    with sdp_kernel(**backend_map[SDPBackend.MATH]):
        print(f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")


    with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
        try:
            print(f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
        except RuntimeError:
            print("FlashAttention is not supported. See warnings for reasons.")

    with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
        try:
            print(f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
        except RuntimeError:
            print("EfficientAttention is not supported. See warnings for reasons.")



.. GENERATED FROM PYTHON SOURCE LINES 117-126

Hardware dependence
~~~~~~~~~~~~~~~~~~~

Depending on which machine you ran the above cell on and what hardware is
available, your results might be different.

- If you don’t have a GPU and are running on CPU, then the context manager
  will have no effect and all three runs should return similar timings.
- Depending on what compute capability your graphics card supports,
  the FlashAttention or memory-efficient implementation might have failed.
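
As a rough aid for interpreting the timings, the sketch below prints what the
current machine reports. It assumes the query functions
``flash_sdp_enabled``, ``mem_efficient_sdp_enabled`` and ``math_sdp_enabled``
from ``torch.backends.cuda``. These report whether a backend is globally
enabled, not whether it supports a particular set of inputs, so treat the
output as a hint only.

.. code-block:: default

    import torch

    # Global enable flags for the three SDPA backends
    report = {
        "cuda_available": torch.cuda.is_available(),
        "flash_enabled": torch.backends.cuda.flash_sdp_enabled(),
        "mem_efficient_enabled": torch.backends.cuda.mem_efficient_sdp_enabled(),
        "math_enabled": torch.backends.cuda.math_sdp_enabled(),
    }

    # The compute capability is only defined when a CUDA device is present
    if report["cuda_available"]:
        report["compute_capability"] = torch.cuda.get_device_capability()

    for name, value in report.items():
        print(f"{name}: {value}")

If a fused backend is enabled but still fails for your inputs, the warnings
emitted by the failing call describe the reason.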

.. GENERATED FROM PYTHON SOURCE LINES 129-136

Causal Self Attention
~~~~~~~~~~~~~~~~~~~~~

Below is an example implementation of a multi-headed causal
self-attention block inspired by Andrej Karpathy's
`nanoGPT <https://github.com/karpathy/nanoGPT>`__ repository.
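
"Causal" here means that each position may attend only to itself and to
earlier positions. The following minimal sketch, which assumes square
self-attention inputs and runs on CPU, illustrates that passing
``is_causal=True`` should match passing an explicit lower-triangular boolean
``attn_mask``. The tensor sizes are arbitrary and chosen only for illustration.

.. code-block:: default

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    seq_len = 5
    # Shape: (batch, num_heads, seq_len, head_dim)
    q, k, v = (torch.randn(1, 2, seq_len, 8) for _ in range(3))

    # Boolean mask: True marks positions that are allowed to take part in attention
    causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

    out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)

    causal_max_diff = (out_flag - out_mask).abs().max().item()
    print(f"Maximum absolute difference: {causal_max_diff:.2e}")

Using the flag rather than an explicit mask avoids building the mask tensor
and lets the function know the mask structure up front.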


.. GENERATED FROM PYTHON SOURCE LINES 136-189

.. code-block:: default


    class CausalSelfAttention(nn.Module):

        def __init__(self, num_heads: int, embed_dimension: int, bias: bool=False, is_causal: bool=False, dropout:float=0.0):
            super().__init__()
            assert embed_dimension % num_heads == 0
            # key, query, value projections for all heads, but in a batch
            self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias=bias)
            # output projection
            self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=bias)
            # regularization
            self.dropout = dropout
            self.resid_dropout = nn.Dropout(dropout)
            self.num_heads = num_heads
            self.embed_dimension = embed_dimension
            # Perform causal masking
            self.is_causal = is_causal

        def forward(self, x):
            # calculate query, key, values for all heads in batch and move head forward to be the batch dim
            query_projected = self.c_attn(x)

            batch_size = query_projected.size(0)
            embed_dim = query_projected.size(2)
            head_dim = embed_dim // (self.num_heads * 3)

            query, key, value = query_projected.chunk(3, -1)
            query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
            key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
            value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)

            if self.training:
                dropout = self.dropout
                is_causal = self.is_causal
            else:
                dropout = 0.0
                is_causal = False

            y = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=dropout, is_causal=is_causal)
            y = y.transpose(1, 2).view(batch_size, -1, self.num_heads * head_dim)

            y = self.resid_dropout(self.c_proj(y))
            return y


    num_heads = 8
    dim_per_head = 64
    embed_dimension = num_heads * dim_per_head
    dtype = torch.float16
    model = CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension, bias=False, is_causal=True, dropout=0.1).to(device).to(dtype).eval()
    print(model)



.. GENERATED FROM PYTHON SOURCE LINES 190-197

``NestedTensor`` and Dense tensor support
-----------------------------------------

SDPA supports both ``NestedTensor`` and dense tensor inputs. ``NestedTensors`` handle the case where the input is a batch of variable-length sequences
without needing to pad each sequence to the maximum length in the batch. For more information about ``NestedTensors``, see
`torch.nested <https://pytorch.org/docs/stable/nested.html>`__ and `NestedTensors Tutorial <https://pytorch.org/tutorials/prototype/nestedtensor.html>`__.
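
To illustrate the difference from padding, here is a small sketch that builds
a ``NestedTensor`` from three sequences of different lengths and then converts
it to a padded dense tensor. The sequence lengths and feature size are
arbitrary values chosen for illustration.

.. code-block:: default

    import torch

    torch.manual_seed(0)
    # Three sequences of lengths 2, 5 and 3, each with 4 features per token
    seqs = [torch.randn(n, 4) for n in (2, 5, 3)]

    # The nested tensor stores each sequence at its own length
    nt = torch.nested.nested_tensor(seqs)
    lengths = [t.size(0) for t in nt.unbind()]

    # The dense equivalent pads every sequence to the longest one
    padded = torch.nested.to_padded_tensor(nt, 0.0)

    print(f"Sequence lengths: {lengths}")
    print(f"Padded shape: {tuple(padded.shape)}")

In the padded tensor, the rows past the end of each shorter sequence are
filled with the padding value, which is storage and compute that the nested
layout does not need.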


.. GENERATED FROM PYTHON SOURCE LINES 197-250

.. code-block:: default


    import random
    def generate_rand_batch(
        batch_size,
        max_sequence_len,
        embed_dimension,
        pad_percentage=None,
        dtype=torch.float16,
        device="cuda",
    ):
        if not pad_percentage:
            return (
                torch.randn(
                    batch_size,
                    max_sequence_len,
                    embed_dimension,
                    dtype=dtype,
                    device=device,
                ),
                None,
            )
        # Random sequence lengths
        seq_len_list = [
            int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01)))
            for _ in range(batch_size)
        ]
        # Make random entry in the batch have max sequence length
        seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len
        return (
            torch.nested.nested_tensor(
                [
                    torch.randn(seq_len, embed_dimension,
                                dtype=dtype, device=device)
                    for seq_len in seq_len_list
                ]
            ),
            seq_len_list,
        )

    random_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device)
    random_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device)

    # Currently the fused implementations don't support ``NestedTensor`` for training
    model.eval()

    with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
        try:
            print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
            print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
        except RuntimeError:
            print("FlashAttention is not supported. See warnings for reasons.")



.. GENERATED FROM PYTHON SOURCE LINES 251-261

Using SDPA with ``torch.compile``
=================================

With the release of PyTorch 2.0, a new feature called
``torch.compile()`` has been introduced, which can provide
significant performance improvements over eager mode.
Scaled dot product attention is fully composable with ``torch.compile()``.
To demonstrate this, let's compile the ``CausalSelfAttention`` module using
``torch.compile()`` and observe the resulting performance improvements.


.. GENERATED FROM PYTHON SOURCE LINES 261-277

.. code-block:: default


    batch_size = 32
    max_sequence_len = 256
    x = torch.rand(batch_size, max_sequence_len,
                   embed_dimension, device=device, dtype=dtype)
    print(
        f"The non compiled module runs in  {benchmark_torch_function_in_microseconds(model, x):.3f} microseconds")


    # Compile the module, then run it once so that compilation happens before benchmarking
    compiled_model = torch.compile(model)
    compiled_model(x)
    print(
        f"The compiled module runs in  {benchmark_torch_function_in_microseconds(compiled_model, x):.3f} microseconds")



.. GENERATED FROM PYTHON SOURCE LINES 278-285

The exact execution time depends on the machine. On my machine, the results were:

* The non-compiled module runs in 166.616 microseconds
* The compiled module runs in 166.726 microseconds

That is not what we were expecting. Let's dig a little deeper.
PyTorch comes with a built-in profiler that you can use to
inspect the performance characteristics of your code.


.. GENERATED FROM PYTHON SOURCE LINES 286-313

.. code-block:: default


    from torch.profiler import profile, record_function, ProfilerActivity
    activities = [ProfilerActivity.CPU]
    if device == 'cuda':
        activities.append(ProfilerActivity.CUDA)

    with profile(activities=activities, record_shapes=False) as prof:
        with record_function("Non-Compiled Causal Attention"):
            for _ in range(25):
                model(x)
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


    with profile(activities=activities, record_shapes=False) as prof:
        with record_function("Compiled Causal Attention"):
            for _ in range(25):
                compiled_model(x)
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    # For even more insight, you can export the trace and view it with ``chrome://tracing``:
    #
    #    prof.export_chrome_trace("compiled_causal_attention_trace.json")





.. GENERATED FROM PYTHON SOURCE LINES 314-329

The previous code snippet generates a report of the top 10 PyTorch functions
that consumed the most GPU execution time, for both the compiled and non-compiled module.
The analysis reveals that the majority of time spent on the GPU is concentrated
on the same set of functions for both modules.
The reason is that ``torch.compile`` is very good at removing the
framework overhead associated with PyTorch. If your model is already launching
large, efficient CUDA kernels, which ``CausalSelfAttention`` is in this case,
then that overhead is hidden behind GPU execution and compilation has little left to remove.

In reality, your module does not normally consist of a single
``CausalSelfAttention`` block. When experimenting with Andrej Karpathy's
`nanoGPT <https://github.com/karpathy/nanoGPT>`__ repository, compiling
the module reduced the time per train step from ``6090.49ms`` to
``3273.17ms``. This was measured on commit ``ae3a8d5`` of nanoGPT, training on
the Shakespeare dataset.


.. GENERATED FROM PYTHON SOURCE LINES 332-344

Conclusion
==========

In this tutorial, we have demonstrated the basic usage of
``torch.nn.functional.scaled_dot_product_attention``. We have shown how
the ``sdp_kernel`` context manager can be used to assert that a certain
implementation is used on GPU. We also built a simple
``CausalSelfAttention`` module that works with ``NestedTensor`` and can be
compiled with ``torch.compile``. In the process, we have shown how the
profiling tools can be used to explore the performance characteristics of a
user-defined module.



.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  0.000 seconds)


.. _sphx_glr_download_intermediate_scaled_dot_product_attention_tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: scaled_dot_product_attention_tutorial.py <scaled_dot_product_attention_tutorial.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: scaled_dot_product_attention_tutorial.ipynb <scaled_dot_product_attention_tutorial.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_