Adding wheels to flash-attention
# August 20, 2023
flash-attention is a low-level implementation of exact attention. Unlike torch, which executes the attention operations as a sequence of separate kernels, flash-attention
combines them into a single fused kernel, which can speed up execution by 85%. And since attention is such a core primitive of most modern language models, it makes for much faster training and inference across the board.
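To give a sense of the interface, here's a minimal sketch of invoking the fused kernel directly - this assumes the flash-attention 2.x API, where `flash_attn_func` expects half-precision tensors of shape `(batch, seqlen, nheads, headdim)` on a CUDA device:

```python
import torch
from flash_attn import flash_attn_func

# Half-precision Q/K/V on the GPU: (batch, seqlen, nheads, headdim)
q = torch.randn(2, 1024, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn(2, 1024, 8, 64, dtype=torch.float16, device="cuda")
v = torch.randn(2, 1024, 8, 64, dtype=torch.float16, device="cuda")

# One fused kernel launch instead of separate matmul / softmax / matmul passes
out = flash_attn_func(q, k, v, causal=True)
```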
It now has an install time that's just as fast. Tri Dao (the main package author) and I recently added precompiled binaries to the Python package. I'll say upfront: this particular implementation is a rather unorthodox use of wheels. Standard wheels can only depend on the operating system and Python version; flash-attention requires specific CUDA and torch versions as well. So it naturally required a bit of off-roading. This is a breakdown of the approach we took.
What's in a wheel
Many optimized libraries contain C or C++ extensions, which must be built at some point before Python can execute them at runtime. Python has had support for these for a long time: first the setuptools egg format and now wheels. Both let maintainers delegate compilation to a CI machine before clients install the package. They pre-build a version of the code - targeted at a specific OS and Python version - and push this to pypi alongside the raw code. One build, potentially millions of installs.
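For reference, such extensions are typically declared through `ext_modules` in setup.py; a minimal sketch (the module and source names here are made up):

```python
from setuptools import setup, Extension

setup(
    name="example",
    version="0.1.0",
    # setuptools invokes the system C compiler to turn this source into
    # a platform-specific shared library at build time
    ext_modules=[Extension("example._native", sources=["src/native.c"])],
)
```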
There are three main concepts for wheels:
- `sdist` (source distribution): raw code that is uploaded to pypi, so individuals can build from scratch if necessary. This is the fallback behavior when wheels are not available.
- `bdist` (binary distribution): a compiled version of the code, shipped as binaries. This could be wheels, but also `.exe` for Windows executables, `.rpm` for the Red Hat package manager, etc.
- `bdist_wheel`: the type of `bdist` that creates a `.whl` file. These wheels conform to PEP 427, which specifies the naming and layout conventions a built wheel must follow for pypi (and pip) to pick it up.
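Concretely, PEP 427 pins the filename down to `{distribution}-{version}(-{build})?-{python}-{abi}-{platform}.whl`. A quick sketch of pulling a (made-up) wheel name apart:

```python
# A hypothetical wheel name, following the PEP 427 convention:
# {distribution}-{version}-{python tag}-{abi tag}-{platform tag}.whl
filename = "flash_attn-2.0.0-cp310-cp310-linux_x86_64.whl"

name, version, python_tag, abi_tag, platform_tag = filename.removesuffix(".whl").split("-")
print(python_tag)    # cp310: CPython 3.10
print(abi_tag)       # cp310: built against the CPython 3.10 ABI
print(platform_tag)  # linux_x86_64
```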
Solution Sketch
The flash-attention library is a Python wrapper over C++ and CUDA, so at install time it needs to compile itself for the current OS and installed dependencies. Here we have a 5D tensor of dependencies: OS, Python version, torch version, CUDA version, and the C++11 ABI.
The existing wheel installation behavior was so close to what we needed, but it couldn't quite be shoehorned into place. Pip makes the hard assumption that wheels vary only by host operating system and Python version, and that wheel filenames reflect exactly that. As such, there's a lot of path-sniffing logic to determine matching resources. It would be nearly impossible to override all of these places without writing a fully custom wheel installer.
So let's take a step back. How is this whole process architected?
- pip / setuptools determine if there's a compatible resource on pypi
- If so, we install this version
- Otherwise, we use the sdist raw code to build the binary from scratch.
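Under the hood, that compatibility check comes down to interpreter tags. The `packaging` library (which pip vendors) exposes the same ordered tag list that pip matches wheel filenames against:

```python
from packaging import tags

# pip considers a wheel installable when its (python, abi, platform)
# triple appears in this ordered compatibility list
for tag in list(tags.sys_tags())[:5]:
    print(tag)  # e.g. cp310-cp310-manylinux_2_35_x86_64
```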
Building is handled through `bdist_wheel`, which implements the compilation pipeline for packages that aren't already built:
- Determine dependencies: OS, Python version
- Construct a string for the wheel filename, encoding those dependencies
- `run()` the wheel build, which sets up the C-level compiler and builds the file in a temporary directory
- After running, determine if there were any errors in the build process
- If everything worked, move the built artifact out of its temporary directory into a permanent location
There might not be an easy way to modify the standard wheel installation logic, but there is an easy way to short-circuit this build process. We target the third step. Instead of always running the build, we make it conditional: if a matching wheel is found remotely, use that as the built file. If not, build from scratch. If we're clever about filepaths, to the rest of the bdist_wheel pipeline it will look like we just built the file. From there, all downstream linking and installation should happen the same as if everything were completely local.
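In rough pseudocode - the helper names here are placeholders that get expanded in the numbered steps below - the override looks like:

```python
class CachedWheelsCommand(_bdist_wheel):
    def run(self):
        wheel_url = guess_wheel_url()   # placeholder: steps 2 and 3
        try:
            fetch_and_place(wheel_url)  # placeholder: steps 4 and 5
        except HTTPError:
            super().run()               # no matching wheel: build from scratch
```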
The Code
1. Install a custom `cmdclass` in `setup()`
```python
class CachedWheelsCommand(_bdist_wheel):
    def run(self):
        if FORCE_BUILD:
            return super().run()
        ...

setup(
    ...
    cmdclass={
        "bdist_wheel": CachedWheelsCommand,
        "build_ext": BuildExtension,
    } if ext_modules else {
        "bdist_wheel": CachedWheelsCommand,
    },
    ...
)
```
By default, setuptools will build wheels through the `bdist_wheel` command. It supports overriding the class that's used, however, by specifying an alternative command class in `setup()`. We keep the bulk of the logic but re-implement the behavior of the main runner.
We also keep an environment variable to force building from scratch. This is useful if clients run into issues at install time, and it's also how CI uses the same setup.py file to force a new wheel build before pushing it to the artifact repository.
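The flag itself can be a one-liner at the top of setup.py; flash-attention reads it from an environment variable along these lines:

```python
import os

# Set FLASH_ATTENTION_FORCE_BUILD=TRUE to skip the cached wheel lookup
# and always compile from source
FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE"
```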
2. Determining Dependencies
```python
class CachedWheelsCommand(_bdist_wheel):
    def run(self):
        ...
        # Determine the version numbers that will be used to determine
        # the correct wheel
        # We're using the CUDA version used to build torch, not the
        # one currently installed
        torch_cuda_version = parse(torch.version.cuda)
        torch_version_raw = parse(torch.__version__)
        python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
        platform_name = get_platform()
        flash_version = get_package_version()
        cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
        torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
        cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
        ...
```
In addition to the standard wheel dependencies, flash-attention requires specific CUDA, torch, and C++11 ABI versions to run. We first parse the versions that are installed locally to make sure we pull a compatible wheel.
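For a concrete feel, here's what those values would look like in one hypothetical environment (Python 3.10, torch 2.0.1 compiled against CUDA 11.8, on Linux):

```python
# Illustrative values only, for Python 3.10 + torch 2.0.1 + CUDA 11.8 on Linux
python_version = "cp310"
platform_name = "linux_x86_64"
cuda_version = "118"    # CUDA major + minor that torch was built against
torch_version = "2.0"   # patch release dropped from the wheel name
cxx11_abi = "TRUE"      # str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
```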
3. Conventional GitHub source
```python
BASE_WHEEL_URL = "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"

class CachedWheelsCommand(_bdist_wheel):
    def run(self):
        ...
        # Determine wheel URL based on CUDA version, torch version, python version and OS
        wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
        wheel_url = BASE_WHEEL_URL.format(
            tag_name=f"v{flash_version}",
            wheel_name=wheel_filename,
        )
        print("Guessing wheel URL: ", wheel_url)
        ...
```
Since the package is already hosted on GitHub, including the wheels in a release was pretty natural. Each file attached to a GitHub release is associated with the `tag_name` of that release. We let the setup script guess this automatically using the current flash version specified in the versions file. The general artifact pattern is specified in `BASE_WHEEL_URL`.
From there, we use the dependencies to build up a conventional name for the wheel. The format of `wheel_filename` doesn't technically matter. We just need to make sure CI builds the file to the same path, so it's uploaded properly to the GitHub release.
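With the hypothetical values from step 2, the guessed name and URL would come out as:

```python
# Illustrative result for flash_attn 2.0.0 in the environment sketched above
wheel_filename = "flash_attn-2.0.0+cu118torch2.0cxx11abiTRUE-cp310-cp310-linux_x86_64.whl"
wheel_url = (
    "https://github.com/Dao-AILab/flash-attention/releases/download/"
    "v2.0.0/" + wheel_filename
)
```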
4. Download
```python
class CachedWheelsCommand(_bdist_wheel):
    def run(self):
        ...
        try:
            urllib.request.urlretrieve(wheel_url, wheel_filename)

            # Make the archive
            if not os.path.exists(self.dist_dir):
                os.makedirs(self.dist_dir)
            ...
        except urllib.error.HTTPError:
            print("Precompiled wheel not found. Building from source...")
            # If the wheel could not be downloaded, build from source
            super().run()
```
While we could also parse the release page and determine artifacts that way, GitHub artifacts can be fetched directly by URL. This is the most direct way to check whether a compatible wheel has been uploaded. If the file exists, we have a prebuilt wheel available and download it locally. If not, the request raises an `HTTPError` and we fall back to the superclass's `run()`, which proceeds to build the file from scratch.
5. Swapping the file
```python
class CachedWheelsCommand(_bdist_wheel):
    def run(self):
        ...
        try:
            ...
            impl_tag, abi_tag, plat_tag = self.get_tag()
            archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
            wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
            print("Raw wheel path", wheel_path)
            os.rename(wheel_filename, wheel_path)
            ...
```
The rest of the bdist_wheel pipeline assumes that `run()` will write its artifact to a specific place. It uses the convention in `wheel_path` to validate success and to move the file to other system paths. We copy this pattern from where it's defined originally in the `bdist_wheel` class: `f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"`. Writing to this convention ensures that all downstream logic treats the wheel like a locally built copy of the build files.
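To see what that convention produces, here's the same naming logic with hypothetical stand-ins for `self.wheel_dist_name` and `self.get_tag()`:

```python
import os

# Hypothetical stand-ins for self.wheel_dist_name and self.get_tag()
wheel_dist_name = "flash_attn-2.0.0"  # {package}-{version}
impl_tag, abi_tag, plat_tag = "cp310", "cp310", "linux_x86_64"

archive_basename = f"{wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
wheel_path = os.path.join("dist", archive_basename + ".whl")
print(wheel_path)  # dist/flash_attn-2.0.0-cp310-cp310-linux_x86_64.whl
```

In other words, the downloaded artifact is renamed so it sits at exactly the path the vanilla pipeline would have written its own build to.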
Conclusion
I've written custom cmd classes before, so I knew on some level that they were just executing arbitrary Python code. But my mental model was still that they're only used for building local binaries. Doing surgery on the existing `build_ext` logic seemed out of the question.
But at the end of the day, the existing `run()` logic has a pretty simple API contract. Callers expect it to take in the raw source and write a binary file to a path. Everything else is left up to the implementation. `cmdclass` really is just a generic install hook for packages that need wheels; you could write as much logic as you'd like here and pip will run it alongside the install.
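A toy illustration of that contract (not flash-attention's code): any class registered in `cmdclass` gets to run arbitrary Python before deferring to the stock behavior.

```python
from setuptools import setup
from setuptools.command.build_ext import build_ext as _build_ext

class NoisyBuildExt(_build_ext):
    def run(self):
        # Arbitrary Python executes here at build/install time
        print("About to compile extensions...")
        super().run()

setup(
    name="example",
    version="0.1.0",
    cmdclass={"build_ext": NoisyBuildExt},
)
```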
This custom wheel command decreased the package install time from a worst case of 2 hours (on one core) to an easy 10 seconds. The remaining time just comes from downloading the wheel itself. It's certainly made our CI pipelines faster - hope it can also speed up your development workflow.