Generating database migrations with acyclic graphs
July 3, 2024
Mountaineer is a fresh take on fullstack webapp development, with typehints and first-class IDE support from the database to the frontend. It's fully open source and live in production environments.
Mountaineer 0.5.0 introduced database migration support, so you can now upgrade production databases directly from the CLI. It generates SQL for you automatically instead of requiring hand-written table migrations, and it removes the need for third party packages to provide the same functionality. Let's dive into how we implemented a migration engine that relies on building and traversing a table schema graph.
But first, the finale. Once you make changes to your code definitions, all the migration process takes is:
$ poetry run migrate generate
$ poetry run migrate apply
$ poetry run migrate rollback
The migration files themselves declare an up and a down migration, to both update your database to the proposed specification and downgrade it to the previous schema in case of a breaking change.1
# myproject/migrations/rev_1719350766.py
from fastapi.param_functions import Depends
from sqlmodel import text
from mountaineer.migrations.actions import ColumnType
from mountaineer.migrations.dependency import MigrationDependencies
from mountaineer.migrations.migration import MigrationRevisionBase
from mountaineer.migrations.migrator import Migrator
class MigrationRevision(MigrationRevisionBase):
    up_revision: str = "1719350766"
    down_revision: str | None = "1718753305"

    async def up(
        self,
        migrator: Migrator = Depends(MigrationDependencies.get_migrator),
    ):
        await migrator.actor.add_column(
            table_name="pizza",
            column_name="topping",
            explicit_data_type=ColumnType.INTEGER,
            explicit_data_is_list=False,
            custom_data_type=None,
        )
        await migrator.actor.add_not_null(
            table_name="pizza", column_name="topping"
        )

    async def down(
        self,
        migrator: Migrator = Depends(MigrationDependencies.get_migrator),
    ):
        await migrator.actor.drop_not_null(
            table_name="pizza", column_name="topping"
        )
        await migrator.actor.drop_column(
            table_name="pizza", column_name="topping"
        )
Problem framing
Database migration is a special case of declarative state transformation2. We want to convert the current state of the database (Postgres) to the desired state (in code). We also want to do so with the minimum change possible: retaining the integrity of the data and backwards compatibility with previous clients where possible.
Formally, we can observe S_t as the current state of the code definitions. This represents the desired database schema that our project expects to interface with. Similarly, we can observe the current state of the database T_t, which is usually the sum of all past migrations applied to it, starting from an empty database at T_0. However, T_t might not exclusively be the result of past migrations. It may also include changes that were made live on the database (adding an index, updating an attribute), which we need to incorporate into the current state mapping as well.
Our goal is to find a migration function that maps m(T_t) = S_t. We also want to find the inverse, m^-1(S_t) = T_t, so we can roll back the changes.
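Concretely, every migration we generate ends up being exactly this pair of functions expressed in code. A minimal sketch of that shape, as a generic protocol rather than Mountaineer's actual interface:

from typing import Protocol

class Migration(Protocol):
    async def up(self) -> None:
        """Apply m: transform the database state T_t into the code-defined state S_t."""
        ...

    async def down(self) -> None:
        """Apply m^-1: roll the database back from S_t to the previous state T_t."""
        ...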
What's in a database
Databases use hierarchical data structures to index data efficiently, and we can leverage this when we consider how we want to implement data migrations. Databases are organized by tables, and then by columns, where each column can have varying properties. These properties can be attached at different levels of uniqueness within the database.
Database
    > Enums
    > Table
        > Constraints
        > Indexes
        > Column
            > Column Attributes
Constraints cover cases like foreign keys, primary keys, and unique values. Constraints and indexes in general can span multiple columns. Enums, by definition, have to be unique across the entire database: you can't have two enum types with the same name, because the same enum can be reused across tables.3
Take the following schema:
 __________________________________________________
| Table: pizza                                     |
|--------------------------------------------------|
| Column      | Type             | Constraints     |
|-------------|------------------|-----------------|
| pizza_id    | INTEGER          | PRIMARY KEY,    |
|             |                  | AUTOINCREMENT   |
| name        | VARCHAR(100)     | UNIQUE          |
| description | TEXT             |                 |
| price       | DECIMAL(8, 2)    |                 |
| size        | ENUM('ItemSize') |                 |
| vegetarian  | BOOLEAN          |                 |
| created_at  | DATE             | DEFAULT NOW()   |
 --------------------------------------------------
With a simple index across columns:
 __________________________________________________
| Table metadata: pizza                            |
|--------------------------------------------------|
| Index Name      | Indexed Columns       | Desc   |
|-----------------|-----------------------|--------|
| pizza_size_idx  | pizza_id, size        |        |
 --------------------------------------------------
And one enum definition:
 __________________________________________________
| Global Enum Values                               |
|--------------------------------------------------|
| Name            | Values                         |
|-----------------|--------------------------------|
| ItemSize        | 'Small', 'Medium', 'Large'     |
 --------------------------------------------------
The SQL to reconstruct the table definition looks like this:
-- Global Enum Values as constants
CREATE TYPE ItemSize AS ENUM ('Small', 'Medium', 'Large');
-- Create pizza table
CREATE TABLE pizza (
    pizza_id SERIAL PRIMARY KEY,
    name VARCHAR(100) UNIQUE,
    description TEXT,
    price DECIMAL(8, 2),
    size ItemSize,
    vegetarian BOOLEAN,
    created_at DATE
);
-- Create index on pizza table
CREATE INDEX pizza_size_idx ON pizza (pizza_id, size);
Our programmatic schema is defined with this same hierarchy, although it's a bit more implicit.
from enum import StrEnum
from sqlmodel import SQLModel, Field, Index
import sqlalchemy as sa
from datetime import datetime
class ItemSize(StrEnum):
    Small = 'Small'
    Medium = 'Medium'
    Large = 'Large'

class Pizza(SQLModel, table=True):
    pizza_id: int = Field(sa_column=sa.Column(sa.Integer, primary_key=True))
    name: str
    description: str | None = None
    price: float
    size: ItemSize = Field(sa_column=sa.Column(sa.Enum(ItemSize)))
    vegetarian: bool = False
    created_at: datetime

    __table_args__ = (
        sa.Index(
            "pizza_size_idx",
            pizza_id.sa_column,
            size.sa_column,
        ),
    )
Within this code schema we default to defining constraints and enums directly at the column level, since this mirrors Python's class definition syntax. But some cross-column values are defined at the table level in __table_args__. We also specify pizza_id.sa_column explicitly. Much like the database itself, we're still referencing columns that we expect to exist before we add the constraints.4
Everything thus far has just been vanilla SQL and vanilla Python. Let's look a little harder at the migration logic now that a base schema is defined. For the sake of these examples we'll focus on changing the type of the created_at column from a date to a datetime.
Intermediary schema
To find a proper mapping function between states, it's much easier to deal with the old state and the new state as equivalent object classes. Trying to compare in-memory SQLModels against in-database tables is a headache and harder to unit test. Instead we want to compare SQLModel<->SQLModel or Table<->Table. Which is better?
It's certainly possible to convert the database definitions into SQLModel objects and then compare the attributes of the Field() objects that we saw above.5 But because SQLModels prioritize semantic definition, there are a few ways to specify a single schema definition (see the sketch after this list). For instance:
- Defining columns implicitly (str) versus explicitly (sa.Column(sa.String))
- Declaring constraints at the column level (Field(unique=True)) versus the table level (sa.UniqueConstraint())
- Using an inline enum literal (Literal['A', 'B']) versus a separate enum class (MyEnum)
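Here's a hypothetical sketch of two SQLModel definitions that emit equivalent columns and constraints while looking quite different in memory. The class and table names are invented for this illustration:

from sqlmodel import SQLModel, Field
import sqlalchemy as sa

class UserImplicit(SQLModel, table=True):
    __tablename__ = "user_implicit"

    id: int = Field(primary_key=True)
    email: str = Field(unique=True)  # constraint declared at the column level

class UserExplicit(SQLModel, table=True):
    __tablename__ = "user_explicit"

    id: int = Field(sa_column=sa.Column(sa.Integer, primary_key=True))
    email: str = Field(sa_column=sa.Column(sa.String))

    __table_args__ = (sa.UniqueConstraint("email"),)  # constraint declared at the table level

Both classes produce the same integer primary key and UNIQUE constraint in SQL, but a naive attribute-by-attribute comparison of the two models would report them as different.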
Each of the above cases results in the same table output, but the in-memory representations will differ. Comparing schemas constructed this way involves a lot of edge cases.
Looking at the table directly provides a cleaner mapping of what we actually have to do eventually: output SQL commands that will manipulate the database state. Tables will need to be parsed into an intermediary in-memory representation anyway. So our best bet is to define an intermediary schema that more closely maps to the different database primitives. Taking the column definition as an example:
class DBColumn(BaseModel):
    table_name: str
    column_name: str
    column_type: ColumnType
    column_is_list: bool
    nullable: bool
It's simple to write separate database->intermediary and SQLModel->intermediary parsers, and harder to write a wholesale converter from one representation to the other. That's usually a tell-tale sign that an intermediary abstraction is warranted.
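As a rough sketch of what the database-side parser could look like (this is not Mountaineer's actual implementation, and the mapping from information_schema type names to ColumnType is simplified):

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

async def parse_db_columns(session: AsyncSession) -> list[DBColumn]:
    # Read the live column definitions straight from Postgres' catalog views
    rows = await session.execute(
        text(
            "SELECT table_name, column_name, data_type, is_nullable "
            "FROM information_schema.columns "
            "WHERE table_schema = 'public'"
        )
    )
    return [
        DBColumn(
            table_name=row.table_name,
            column_name=row.column_name,
            # Assumes ColumnType members mirror the SQL type names (a simplification)
            column_type=ColumnType(row.data_type.upper()),
            column_is_list=row.data_type == "ARRAY",
            nullable=row.is_nullable == "YES",
        )
        for row in rows
    ]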
Once the parsers run, we have a flat list of the different database objects: columns, constraints, tables, enums, etc. Everything that should be migrated must be converted into this intermediate space:
Database: [
    DBColumn(
        table_name="pizza",
        column_name="created_at",
        column_type=ColumnType.DATE,
        column_is_list=False,
        nullable=False,
    )
    ...
]

Schema: [
    DBColumn(
        table_name="pizza",
        column_name="created_at",
        column_type=ColumnType.DATETIME,  # Will require a type migration
        column_is_list=False,
        nullable=False,
    )
    ...
]
Aligning objects
We now need some way to map from one set of intermediate values to the other. Thus far we only have two different lists of object representations of the table. How can we tell when an object is truly new (like a new column) versus a desired transformation of an existing object (like a changed column type)?
Here we make our first design assumption. Users don't expect the migration process to handle every case automatically, only to provide reasonable defaults for the most common modifications. If users change the name of a column, for instance, it's okay for us to suggest deleting the old column and recreating it under the new name. If they know the data can be mapped automatically (or requires a manual intervention), they should be the ones to specify the steps.
For this reason we can establish a convention of each intermediate object having a representation string. A representation is a unique identifier that we expect not to change for the same underlying object. These values should be available in both sources (i.e. both the database and the SQLModel definition). A column would therefore be:
class DBColumn(BaseModel):
    ...

    def representation(self):
        return f"{self.table_name}-{self.column_name}"
Once a column has a different column name, or migrates to a different table, it's technically a different column within Postgres' schema. A constraint, by contrast, would only include the constraint name, since this attribute must be explicitly specified both in-database and in-memory:
class DBConstraint(BaseModel):
    ...

    def representation(self):
        return self.constraint_name
With these representations in place, we now have a clear mapping between the old state and the new state of the objects. An object that appears on both sides of this grid should migrate its values. Otherwise we need to fully create or destroy the database object.
+-------------------------------+-------------------------------+
| Old (Database)                | New (Schema)                  |
+-------------------------------+-------------------------------+
| rep: pizza-created_at         | rep: pizza-created_at         |
| table: pizza                  | table: pizza                  |
| column_name: created_at       | column_name: created_at       |
| column_type: DATE             | column_type: DATETIME         |  <- Migration required
| column_is_list: False         | column_is_list: False         |
| nullable: False               | nullable: False               |
+-------------------------------+-------------------------------+
| ...                           | ...                           |
+-------------------------------+-------------------------------+
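The alignment itself can be as simple as a dictionary merge keyed on the representation. A rough sketch, where pairs with a missing side fall into the create/destroy cases:

def align(
    old_objects: list[DBColumn], new_objects: list[DBColumn]
) -> list[tuple[DBColumn | None, DBColumn | None]]:
    # Index both sides by their stable representation string
    old_by_rep = {obj.representation(): obj for obj in old_objects}
    new_by_rep = {obj.representation(): obj for obj in new_objects}
    # Union of all representations: present on both sides -> migrate in place,
    # only old -> destroy, only new -> create
    return [
        (old_by_rep.get(rep), new_by_rep.get(rep))
        for rep in old_by_rep.keys() | new_by_rep.keys()
    ]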
Migrating object attributes
Now that we have a single logical object for each element in the database hierarchy, we need to figure out the actual transformation to manipulate one object into the other. There are three main transformations we can make to a database primitive:
- A fully new object, where we need to create the object with all of its constraints from scratch
- A fully removed object, where we need to drop whatever object exists from the database schema
- A modified instance, where some values have changed and we want to update the object in-place
The first and second cases are trivial. They're also by definition invertible (with some data loss): a rollback of a new object creation is a deletion, and vice-versa. This allows us to flip the ordering to create the inverse function. The object primitives are in the best place to determine their migration paths, so we set up an API that allows the migration to take place differently for different objects. After this function is executed, the current database state previous should be manipulated into the new definition provided by next.
class DBColumn(BaseModel):
    ...

    @classmethod
    def migrate(cls, previous: "DBColumn", next: "DBColumn"):
        pass
While modifying attributes of an existing object seems more difficult, it's mostly the same as the case for creation and deletion. We simply look at the values of the object properties and determine how to manipulate one version into the other. This ends up looking like a series of if/then conditional checks based on the equality of the previous and next values, which can then emit the relevant transformations. Sometimes this can be handled natively with an ALTER TABLE, and sometimes it requires a full deletion and recreation; the logic is conditional on what actually changed.
Exposing this migrate function to clients is probably not the best user experience. The API is too high level for users to inspect the business logic that's about to be executed against the database. Knowing that a column is going to be transformed to the metadata of another column buries the specifics of what we think they intend. In our case here, knowing that a column needs a new type obscures the way that we assign it that new type (via a delete/create, via a modification of schema, etc).
Instead we want to give users the final say. We want to set up our migrations in terms of the values that are added and the actions that will take place as part of the SQL, and to track these changes without actually performing them. Our goal is to generate user code that will allow users to modify their database.
Dry-run of migration
SQLModel (and SQLAlchemy by extension) provides utilities for constructing schemas from scratch, or interacting with existing tables. It doesn't bundle functions that make schema-level manipulations to an existing database.
We can set up a helper class DBTransforms to manage the creation and modification of our desired primitives.6 At the last count these covered:
- Adding/removing a new column
- Modifying a column type
- Adding/removing a constraint
- Adding/removing an enum
- Modifying enum values. This is an instance of a more complicated primitive transformation, since enum manipulations aren't supported by Postgres once they're created. At the SQL level this requires renaming the existing enum, creating a new enum with the new values, migrating all old columns over to the new enum, and removing the outdated enum (sketched below).
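For a sense of what that enum dance means at the SQL level, here's a hedged sketch of the statements an actor could issue to swap ItemSize over to a new value set (the 'ExtraLarge' value is invented for this example):

# Postgres won't rewrite an enum's value set in place, so we swap the whole type out
ENUM_SWAP_STATEMENTS = [
    # 1. Move the outdated enum out of the way
    "ALTER TYPE itemsize RENAME TO itemsize_old",
    # 2. Create a new enum with the desired values
    "CREATE TYPE itemsize AS ENUM ('Small', 'Medium', 'Large', 'ExtraLarge')",
    # 3. Point the existing columns at the new enum, casting through text
    "ALTER TABLE pizza ALTER COLUMN size TYPE itemsize USING size::text::itemsize",
    # 4. Drop the outdated enum now that nothing references it
    "DROP TYPE itemsize_old",
]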
These migration primitives are called from two places. They need to be callable by users directly, in case they want to specify any additional migration behavior by hand. They also have to generate the Python code that is placed into the automatic migration files. We could separate these two uses (so the mechanism that creates the Python code is different from the actual class that calls the primitives), but this introduces the risk of code skew over time that isn't caught by our static typechecking.
Instead it's a bit cleaner to add a dry_run mode to the database transformation class. If the dry run is enabled, it just stores a log of the function that was called. If we're calling it during the actual migration, it executes the desired logic against the database.
class DBTransforms:
    def __init__(self, db_session: AsyncSession):
        self.db_session = db_session
        self.dry_run = True
        self.dry_run_actions = []

    async def add_column(self, column_name: str, column_type: str):
        if self.dry_run:
            self.dry_run_actions.append(
                (
                    self.add_column,
                    {
                        "column_name": column_name,
                        "column_type": column_type,
                    },
                )
            )
        else:
            await self.db_session.execute("...")
Then each primitive can interface with the transform class directly, and the logic can remain the same regardless of whether we're running in production or a dry-run:
class DBColumn(BaseModel):
    ...

    @classmethod
    async def migrate(
        cls,
        transforms: DBTransforms,
        previous: "DBColumn | None",
        next: "DBColumn | None",
    ):
        if previous and next:
            if previous.column_type != next.column_type:
                await transforms.modify_column_type(...)
        elif previous:
            await transforms.drop_column(...)
        elif next:
            await transforms.add_column(next.column_name, next.column_type)

transforms = DBTransforms(...)
for old, new in aligned_objects:
    await (old or new).migrate(transforms, old, new)
We'll use the collected dry-run function calls to eventually build up the template files. But if we do so right now, we'll end up with object migrations happening in an arbitrary order, since there's no ordering baked into the aligned_objects list. This in turn risks dependency conflicts, like attempting to create a new column before its new table. We need to first establish a proper ordering of the dependencies.
Dependencies
Aligning objects themselves is one issue, but this still leaves the question of which transformations to apply in which order. Tables need to be created before the columns within them, enum types before columns, and columns before their constraints. This seems like a textbook case for topological sorting, which guarantees we'll prioritize parent nodes over all their children. But what are the parent and child nodes here?
It's time to consider the way that we actually construct these intermediate objects given the original schema. We want to start with the object model classes themselves - since these are the ones that are registered in the global model registry when they are imported into the runtime. Starting at these high level definitions, we can iterate over the columns, and from there the built-in column types and any column level constraints. We can separately iterate over the table level definitions for constraints that span multiple columns.
This recursive traversal of the schema roughly mirrors how we want to order our creation. But we can make it more explicit. With every intermediate object (DBColumn, DBConstraint, etc) that we extract from the core schema, we can also provide a list of dependencies that must be created before that object. We usually won't know the full dependency objects themselves, but we can certainly derive their representation values using the knowledge that we have at this particular leg of the schema traversal. In addition to the intermediate objects themselves, our schema generator should therefore yield pairs of (IntermediateDBObject, [representation strings of its dependencies]).
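A simplified sketch of that traversal, which only handles columns and assumes tables are represented by their name (Mountaineer's real extractor also covers constraints, enums, and the other primitives):

from typing import Iterator

def walk_schema(models: list[type[SQLModel]]) -> Iterator[tuple[DBColumn, list[str]]]:
    for model in models:
        table = model.__table__  # the SQLAlchemy Table that SQLModel builds for us
        for column in table.columns:
            db_column = DBColumn(
                table_name=table.name,
                column_name=column.name,
                column_type=ColumnType(str(column.type)),  # simplified type mapping
                column_is_list=False,
                nullable=bool(column.nullable),
            )
            # A column can only be created once its parent table exists
            yield db_column, [table.name]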
For the above schema it breaks down to something like this:
Node                       Depends on
------------------------   --------------------------------
Global Enum (ItemSize)     (nothing)
Table: pizza               (nothing)
Column: pizza_id           Table pizza
Column: name               Table pizza
Column: description        Table pizza
Column: price              Table pizza
Column: vegetarian         Table pizza
Column: size               Table pizza, Global Enum ItemSize
Index: pizza_size_idx      Columns pizza_id, size
With this DAG established, we can order our intermediate actions. Calling the migrations on the properly sorted nodes will resolve all of the implicit ordering rules that Postgres requires.
from graphlib import TopologicalSorter

ts = TopologicalSorter()
for node, dependencies in intermediate_schemas:
    # Key the graph by representation so that nodes and their dependency
    # strings share the same identity
    ts.add(node.representation(), *dependencies)

migration_order = list(ts.static_order())
representation_to_ordering = {
    representation: i for i, representation in enumerate(migration_order)
}

for old, new in sorted(
    aligned_objects,
    key=lambda pair: representation_to_ordering[(pair[0] or pair[1]).representation()],
):
    await (old or new).migrate(transforms, old, new)
We expect our topologically sorted DAG to resolve to:
ItemSize, Table pizza, All pizza columns, All pizza constraints
Template Files
From these properly sorted dry-run representations, building up the actual migration file is easy. We can use some simple string interpolation to emit each function call based on the arguments we've logged.
migration_contents = ""
for fn_pointer, args in transforms.dry_run_actions:
    fn_name = fn_pointer.__name__
    fn_args = ", ".join(f"{key}={value}" for key, value in args.items())
    migration_contents += f"{fn_name}({fn_args})\n"
These can be injected directly into a template for the full migration class. Our particular implementation just uses a simple Python string with templating parameters but you could also use Jinja or whatever other formatter you prefer.
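A hedged sketch of what such a template can look like; the real Mountaineer template and its indentation handling differ, and the substitution fields here are placeholders:

MIGRATION_TEMPLATE = '''
from fastapi.param_functions import Depends
from mountaineer.migrations.dependency import MigrationDependencies
from mountaineer.migrations.migration import MigrationRevisionBase
from mountaineer.migrations.migrator import Migrator

class MigrationRevision(MigrationRevisionBase):
    up_revision: str = "{up_revision}"
    down_revision: str | None = {down_revision}

    async def up(
        self,
        migrator: Migrator = Depends(MigrationDependencies.get_migrator),
    ):
{up_body}

    async def down(
        self,
        migrator: Migrator = Depends(MigrationDependencies.get_migrator),
    ):
{down_body}
'''

The logged calls from the dry run would then be indented, prefixed with await migrator.actor., and substituted into up_body, with the inverse calls going into down_body.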
Insert this string into a migrations directory within the user's project, and you have a version-controlled file that tracks the migration changes needed to bring the database into the current schema:
myproject/
└── migrations/
├── rev_1715305172.py
├── rev_1715367134.py
├── rev_1715639469.py
└── rev_1715816063.py
At execution time we simply run whatever revision files haven't been run yet. Since the migration files themselves track an up_revision and down_revision hash, we can also resolve the desired ordering of these migration files and throw an error if any revisions run out of sequence.
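One rough sketch of how that ordering check could work is to chain each revision's down_revision back to its parent's up_revision; a broken chain means a revision is missing or out of sequence (this is illustrative, not necessarily Mountaineer's implementation):

def order_revisions(revisions: list[MigrationRevisionBase]) -> list[MigrationRevisionBase]:
    # Each revision points at its parent; the very first revision has down_revision=None
    by_parent = {rev.down_revision: rev for rev in revisions}
    ordered: list[MigrationRevisionBase] = []
    current = by_parent.get(None)
    while current is not None:
        ordered.append(current)
        current = by_parent.get(current.up_revision)
    if len(ordered) != len(revisions):
        raise ValueError("Migration revisions are missing or out of sequence")
    return ordered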
Doing this robustly requires some coordination on the database side to track the current state of the migrated schema. We track the already-applied database state with a small helper migration_info table within the database.
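The tracking table itself can stay tiny. A hypothetical sketch of what it could hold; the actual migration_info schema in Mountaineer may differ:

from datetime import datetime
from sqlmodel import SQLModel, Field

class MigrationInfo(SQLModel, table=True):
    __tablename__ = "migration_info"

    id: int | None = Field(default=None, primary_key=True)
    active_revision: str  # hash of the revision currently applied to this database
    updated_at: datetime = Field(default_factory=datetime.utcnow)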
Conclusions
This approach of projecting our database schema into a dependency graph ends up being a really flexible abstraction for handling changes over time. Extending schema support or handling edge cases is usually just a few-line change in the parser. The rest of the pipeline (migration logic, code generation, order resolution) can stay unchanged once the intermediary objects are properly created.
As we continue to improve our support for these migrations in Mountaineer, we'll add more guard rails around proper data transformations and proactive warning of issues. The goal is to make database migrations as painless as local schema changes - even with production data that you need to migrate on the fly.
If you're interested in easy database migrations, check out the new migration logic on GitHub. It's also shipping in the latest stable release.
1. This layout should be familiar to those who have used Alembic in the past with SQLAlchemy. The approach in Mountaineer focuses on defining migration classes versus functions. It also supports arbitrary dependency injection as part of the function arguments, so you can pass any augmented object directly into the migration. ↢
2. As also seen in languages like HCL / Terraform. ↢
3. Enums in Postgres become actual column type objects themselves (at the same level as UUIDs, varchars, etc). We can just customize their values to validate the items that we want. ↢
4. If you specify these columns as strings instead of field instances in a sa.Index, that works too. The implicit ordering of creating columns before the constraints will be resolved for you. ↢
5. Within Pydantic, and SQLModel by extension, the Field() definition is actually a function call that returns a FieldInfo class instance. It stores the same data, but the FieldInfo is the actual struct that you manipulate at runtime. ↢
6. We'll see a clear reason why this helper class is useful in a second. But in the meantime we can assume that it's easier to unit test by itself, which is reason enough to own the database interface logic in a separate component. ↢