Tyler Sengia

Finding How PyTorch Samples from a Multinomial Distribution

Ever wondered how PyTorch exposes low-level C++ operations to Python? I have, and maybe I can save you a few hours of digging through source code to figure out how they do it.

At the time of writing, the most recent release of PyTorch is version 2.6.0 (commit 1a68837), so things may have changed if you are reading this post a year or more later.

The Rabbit Hole that Brought Me Here

On a separate, related hobby project (soon to come, I'll link to it once it's out), I need to sample from a categorical distribution. Wikipedia lists several ways to do that, but I really want to see how PyTorch does it.
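For reference, one of the textbook approaches is inverse transform sampling: take the running sum of the weights, draw a uniform number, and binary-search for the bucket it lands in. Below is a minimal plain-Python sketch of that idea. The name sample_categorical is my own, and this is not a claim about what PyTorch does internally.

```python
import bisect
import itertools
import random


def sample_categorical(weights, num_samples, rng=None):
    """Draw num_samples category indices (with replacement) by inverse transform sampling.

    weights may be unnormalized non-negative numbers; only their ratios matter.
    """
    if rng is None:
        rng = random.Random()
    # cdf[i] is the total weight of categories 0..i.
    cdf = list(itertools.accumulate(weights))
    total = cdf[-1]
    if total <= 0:
        raise ValueError("weights must sum to a positive value")
    samples = []
    for _ in range(num_samples):
        # u is uniform in [0, total); the first bucket whose running total
        # exceeds u is the sampled category.
        u = rng.random() * total
        index = bisect.bisect_right(cdf, u)
        # Guard against u rounding up to total in floating point.
        samples.append(min(index, len(cdf) - 1))
    return samples
```

For example, sample_categorical([0.2, 0.5, 0.3], 5, random.Random(0)) returns five indices from {0, 1, 2}, with index 1 drawn about half the time over many samples.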

PyTorch's categorical distribution is implemented in the torch.distributions.categorical module, and its Categorical.sample() method calls the torch.multinomial() function.

Ok, great, now to see where torch.multinomial() is implemented. Shift + Click on it in my IDE to take me to the definition and… nothing. Great, now I have to really dig.

Start from the Root of the Tree

PyTorch is overwhelmingly complex, but it is a Python package, and like any Python package, it defines its default imports in its __init__.py file.

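Every public name you reach as torch.<name> in Python is expected to be listed in the __all__ list in __init__.py. Strictly speaking, __all__ only controls what from torch import * pulls in, and the attribute itself exists because __init__.py binds it into the module namespace, but __all__ is still a good thing to search for. So, we need to figure out where multinomial gets added to __all__.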

At line 2101 of __init__.py I find a mention of __all__.append():

__name, __obj = "", None
for __name in dir(_C._VariableFunctions):
    if __name.startswith("__") or __name in PRIVATE_OPS:
        continue
    __obj = getattr(_C._VariableFunctions, __name)
    __obj.__module__ = __name__  # "torch"
    # Hide some APIs that should not be public
    if __name == "segment_reduce":
        # TODO: Once the undocumented FC window is passed, remove the line bellow
        globals()[__name] = __obj
        __name = "_" + __name
    globals()[__name] = __obj
    if not __name.startswith("_"):
        __all__.append(__name)
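To see what that loop does in isolation, here is a toy reproduction of it. FakeVariableFunctions and populate are hypothetical names standing in for _C._VariableFunctions and the loop body; setattr plays the role of globals()[__name] = __obj, and I've left out the segment_reduce special case.

```python
import types


class FakeVariableFunctions:
    """Hypothetical stand-in for torch._C._VariableFunctions."""

    @staticmethod
    def multinomial(weights, num_samples):
        return ("multinomial", tuple(weights), num_samples)

    @staticmethod
    def _internal_op():
        return "internal"


def populate(module, functions, private_ops=()):
    """Mimic the torch/__init__.py loop: copy attributes of `functions` onto `module`."""
    if not hasattr(module, "__all__"):
        module.__all__ = []
    for name in dir(functions):
        # Skip dunders (e.g. __doc__) and anything explicitly marked private.
        if name.startswith("__") or name in private_ops:
            continue
        obj = getattr(functions, name)
        obj.__module__ = module.__name__  # report the public module, like "torch"
        setattr(module, name, obj)  # this is what makes module.<name> resolve
        if not name.startswith("_"):
            module.__all__.append(name)  # this only marks the name as public
    return module


fake_torch = populate(types.ModuleType("fake_torch"), FakeVariableFunctions)
```

The toy makes one distinction visible: setattr (globals() in the real file) is what makes fake_torch.multinomial resolve, while appending to __all__ only marks the name as public, which is what from fake_torch import * uses. Underscore names such as _internal_op are still reachable as attributes; they just are not advertised.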

Ok, looks like PyTorch is importing a compiled C/C++ extension exposed as the _C module, and the loop copies the attributes of _C._VariableFunctions into the torch namespace. Oh, there's a torch/_C/_VariableFunctions.pyi.in file! That must be what we're looking for.

Apart from some boilerplate imports at the beginning, it's blank. _VariableFunctions.pyi.in is a template that has not been filled in yet: at the end of the file you can see the ${function_hints} placeholder that still needs to be substituted.
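The substitution step itself is simple. As an analogy only, Python's standard string.Template happens to use the same ${name} placeholder syntax, so it can fill a toy version of such a template. I haven't checked which templating helper PyTorch's generator actually uses, and the template text and hint line below are made up for illustration.

```python
from string import Template

# Hypothetical, heavily simplified stand-in for a .pyi.in template:
# hand-written boilerplate first, then a placeholder for generated hints.
PYI_TEMPLATE = Template(
    "from torch import Tensor\n"
    "\n"
    "${function_hints}\n"
)

# Hypothetical, simplified hint lines; real generated signatures are longer.
function_hints = [
    "def multinomial(input: Tensor, num_samples: int) -> Tensor: ...",
]

stub = PYI_TEMPLATE.substitute(function_hints="\n".join(function_hints))
print(stub)
```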

Template Substitution

Ok, something must be populating the _VariableFunctions.pyi.in template. A quick search across the repository turns up line 1462 of tools/pyi/gen_pyi.py. The docstring at the top of gen_pyi.py reads:

This module implements generation of type stubs for PyTorch,
enabling use of autocomplete in IDEs like PyCharm, which otherwise
don't understand C extension modules.

At the moment, this module only handles type stubs for torch and
torch.Tensor.  It should eventually be expanded to cover all functions
which come are autogenerated.

Here's our general strategy:

- We start off with a hand-written __init__.pyi.in file.  This
  file contains type definitions for everything we cannot automatically
  generate, including pure Python definitions directly in __init__.py
  (the latter case should be pretty rare).

- We go through automatically bound functions based on the
  type information recorded in native_functions.yaml and
  generate type hints for them (generate_type_hints)

There are a number of type hints which we've special-cased;
read gen_pyi for the gory details.

Wow! Ok, this tells me a lot, and it gives me my next lead!
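To make the docstring's second bullet concrete, here is a heavily simplified toy that turns a schema-like string into a stub line. The schema string is simplified (real entries in native_functions.yaml carry more arguments and annotations), TYPE_MAP and schema_to_hint are hypothetical names of my own, and gen_pyi.py itself is far more involved.

```python
import re

# Hypothetical mapping from a few schema types to Python type-hint names.
TYPE_MAP = {"Tensor": "Tensor", "int": "int", "bool": "bool", "float": "float"}

# name(args) -> ReturnType, with a single plain return type.
SCHEMA_RE = re.compile(r"^(?P<name>\w+)\((?P<args>.*)\)\s*->\s*(?P<ret>\w+)$")


def schema_to_hint(schema: str) -> str:
    """Turn a simplified schema string into a .pyi-style function stub.

    Only handles arguments of the form "Type name" or "Type name=default".
    """
    match = SCHEMA_RE.match(schema.strip())
    if match is None:
        raise ValueError(f"unsupported schema: {schema!r}")
    params = []
    for arg in filter(None, (a.strip() for a in match["args"].split(","))):
        type_name, _, rest = arg.partition(" ")
        name, _, default = rest.partition("=")
        hint = f"{name}: {TYPE_MAP[type_name]}"
        if default:
            # Defaults like False, True, or 1 are already valid Python literals.
            hint += f" = {default}"
        params.append(hint)
    return f"def {match['name']}({', '.join(params)}) -> {TYPE_MAP[match['ret']]}: ..."
```

Feeding it "multinomial(Tensor self, int num_samples, bool replacement=False) -> Tensor" produces "def multinomial(self: Tensor, num_samples: int, replacement: bool = False) -> Tensor: ...", a line of the kind that ends up in place of ${function_hints}.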

The native_functions.yaml File

Apparently this gen_pyi.py file generates that _VariableFunctions.pyi using data declared in native_functions.yaml. Another quick search shows that the native_functions.yaml file being referenced is aten/src/ATen/native/native_functions.yaml.

A search shows that my torch.multinomial() function is declared in this file on line 9524, and the CPU kernel it uses is named multinomial.

Ok, which C++ file does this multinomial kernel live in? The first line of native_functions.yaml helpfully says "See README.md in this directory for more guidance", so I should check that out first. Always RTFM, especially when the documentation tells you to.

The aten/src/ATen/native/README.md file has more details about the native_functions.yaml mechanism. It explains that each function can dispatch to a different kernel depending on the backend in use, and that kernels are assumed to live in the at::native namespace by default.

Ok, so the multinomial kernel referenced in native_functions.yaml is actually at::native::multinomial in C++. Time to find that.
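Here is a toy model of that name lookup. MULTINOMIAL_ENTRY is a simplified, hypothetical record (a shortened schema plus only the CPU dispatch line mentioned above, written as the dict a YAML loader might return), and kernel_for is a hypothetical helper. The real mapping is handled by PyTorch's code generator and C++ dispatcher, not by anything like this.

```python
# Hypothetical, simplified record; real entries have more fields.
MULTINOMIAL_ENTRY = {
    "func": "multinomial(Tensor self, int num_samples, bool replacement=False) -> Tensor",
    "dispatch": {"CPU": "multinomial"},
}

DEFAULT_NAMESPACE = "at::native"


def kernel_for(entry: dict, backend: str) -> str:
    """Return the C++ kernel an entry dispatches to for `backend` (toy model)."""
    for keys, kernel in entry.get("dispatch", {}).items():
        # One dispatch line may name several backends, e.g. "CPU, CUDA".
        backends = [key.strip() for key in keys.split(",")]
        if backend in backends:
            # This toy ignores explicit namespaces and always uses the default.
            return f"{DEFAULT_NAMESPACE}::{kernel}"
    raise KeyError(f"no kernel listed for backend {backend!r}")
```

With this model, kernel_for(MULTINOMIAL_ENTRY, "CPU") resolves to "at::native::multinomial", which is exactly the symbol to go looking for.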

Finding at::native::multinomial

Instead of using Ctrl + F to find it, I scanned the many .cpp files under aten/src/ATen/native and spotted Distributions.cpp. Multinomial is a probability distribution, so let's search there.

And I found it! All the way at the end of the file, line 627 of aten/src/ATen/native/Distributions.cpp.
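If you would rather not eyeball the file list, a few lines of Python can do the same search. This is a rough heuristic of my own (find_definitions is a hypothetical helper, not part of PyTorch): it looks for a line that starts with a return-type-like token followed by the function name and an opening parenthesis, so it can also match declarations.

```python
import re
from pathlib import Path


def find_definitions(root, func_name, suffixes=(".cpp",)):
    """Yield (path, line_number, line) for lines that look like a C++ definition of func_name.

    Crude heuristic: optional indentation, one type-like token, whitespace,
    then func_name immediately followed by "(". Calls such as
    "return at::multinomial(" and names like "multinomial_out(" do not match.
    """
    pattern = re.compile(rf"^\s*[\w:<>&*]+\s+{re.escape(func_name)}\s*\(")
    for path in sorted(Path(root).rglob("*")):
        if path.suffix not in suffixes or not path.is_file():
            continue
        with path.open(encoding="utf-8", errors="replace") as f:
            for lineno, line in enumerate(f, start=1):
                if pattern.search(line):
                    yield path, lineno, line.rstrip()
```

Run from a PyTorch checkout, something like for path, lineno, line in find_definitions("aten/src/ATen/native", "multinomial"): print(f"{path}:{lineno}: {line}") lists candidate locations to open.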

Conclusion

Now I can read the exact algorithm that PyTorch uses to sample from a multinomial distribution.

But also, I now have a better understanding of the colossal PyTorch codebase in case I ever need to track down implementation details again.

In summary, these are the steps PyTorch takes to expose an ATen kernel to you through the torch module:

  1. The kernel is written somewhere under aten/src/ATen/ and compiled into PyTorch's shared libraries.
  2. Developers declare the function in native_functions.yaml, which, among other things, is used to generate a .pyi file that gives your IDE type hints.
  3. At import time, the __init__.py file binds the function into the torch namespace and adds its name to the __all__ list.