Mikko Kortelainen

Flask-SQLAlchemy and PostgreSQL Unit Testing with Transaction Savepoints

PostgreSQL provides two powerful features that help with unit testing: transactional DDL and transaction savepoints. In this article I will show you how to use them in Flask-SQLAlchemy unit tests.

Transactional DDL means you can create tables inside a transaction, run tests against them, and roll back your changes after you are done. The database will be left in the same state as it was when you started. If you started with an empty database, nothing will be left in the database afterwards.
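Here is a minimal sketch of transactional DDL. It uses Python's built-in sqlite3 module in place of PostgreSQL only so that it runs without a database server, since SQLite also rolls back a CREATE TABLE issued inside a transaction. The `widgets` table is hypothetical. Passing `isolation_level=None` stops sqlite3 from issuing its own BEGIN/COMMIT, so the explicit SQL statements decide where the transaction begins and ends.

```python
import sqlite3

# isolation_level=None: sqlite3 sends no implicit BEGIN/COMMIT, so the
# transaction boundaries are exactly the statements executed below.
conn = sqlite3.connect(":memory:", isolation_level=None)

def table_names(connection):
    """Return the names of all tables currently visible on the connection."""
    rows = connection.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name")
    return [name for (name,) in rows]

conn.execute("BEGIN")
# DDL and data inside the same transaction (hypothetical demo table)
conn.execute("CREATE TABLE widgets (id INTEGER PRIMARY KEY, name TEXT)")
conn.execute("INSERT INTO widgets (name) VALUES ('test')")
tables_during = table_names(conn)   # the table exists inside the transaction

conn.execute("ROLLBACK")
tables_after = table_names(conn)    # the CREATE TABLE was undone as well

print(tables_during, tables_after)
```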

Savepoints allow you to roll back a part of a transaction without affecting what has happened before that savepoint was created. You can run a single test inside a savepoint and roll back just the changes that single test made without affecting the changes your set-up code has done before the savepoint.

That means you can create a large number of tables and other database objects at the beginning of your test suite and then run individual tests inside nested transactions using savepoints. There is no need to drop and re-create the whole database schema for each test.
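The same sqlite3 stand-in can sketch how a savepoint isolates one test from the set-up work done before it. The table and values are made up; the SAVEPOINT, ROLLBACK TO SAVEPOINT and RELEASE SAVEPOINT statements used here have the same meaning in PostgreSQL.

```python
import sqlite3

conn = sqlite3.connect(":memory:", isolation_level=None)

def names(connection):
    """All names in the demo table, ordered by primary key."""
    return [n for (n,) in connection.execute("SELECT name FROM widgets ORDER BY id")]

conn.execute("BEGIN")

# "Set-up code": runs before the savepoint and must survive the test's rollback.
conn.execute("CREATE TABLE widgets (id INTEGER PRIMARY KEY, name TEXT)")
conn.execute("INSERT INTO widgets (name) VALUES ('from setup')")

# "Test": runs inside a savepoint and modifies both new and existing rows.
conn.execute("SAVEPOINT test_sp")
conn.execute("INSERT INTO widgets (name) VALUES ('from test')")
conn.execute("UPDATE widgets SET name = 'changed by test' WHERE id = 1")
during_test = names(conn)

# Undo only the test's changes; the outer transaction stays open.
conn.execute("ROLLBACK TO SAVEPOINT test_sp")
conn.execute("RELEASE SAVEPOINT test_sp")
after_test = names(conn)

conn.execute("ROLLBACK")  # finally discard the set-up work too
print(during_test, after_test)
```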

1   PostgreSQL Database for Unit Tests

We need an empty database to run our unit tests against. If you already have PostgreSQL installed, its command line tools can create one (shown here on Ubuntu, run as the postgres user):

postgres@host:~$ createuser -P unittestuser
Enter password for new role: password
Enter it again: password

postgres@host:~$ createdb -O unittestuser unittestdb

The createuser command creates a database user. The createdb creates the database, and the -O option sets the owner of the database.

2   Other Requirements

You obviously need to install Flask-SQLAlchemy and psycopg2 for PostgreSQL connectivity:

(virtualenv)user@host:~$ pip install flask flask-sqlalchemy psycopg2

3   An App, a Model, and Some Tests

First let's take a look at an ordinary test suite. Flask-SQLAlchemy wraps SQLAlchemy and integrates it with Flask, providing some additional functionality with it. A very minimal Flask app with Flask-SQLAlchemy and PostgreSQL support could look something like below.

Please note that I haven't run the following example or its tests. The downloadable file at the end of this article contains a working, tested example along with its tests. The example below simply gives us something to discuss. Let me know if you find errors.

from flask import Flask
from flask_sqlalchemy import SQLAlchemy

app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = 'postgresql+psycopg2://unittestuser:password@localhost/unittestdb'
db = SQLAlchemy(app)

class MyModel(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String)

    def update_name(self, new_name):
        self.name = new_name

Our app is indeed very minimal. It doesn't even have any views, but it is enough for our purposes.

A minimal test suite for our database model might look like this:

from unittest import TestCase
from app import db, MyModel

class MyModelTestCase(TestCase):
    def setUp(self):
        db.create_all()
        db.session.add(MyModel(name='Test'))
        db.session.commit()

    def tearDown(self):
        db.drop_all()

    def test_model_exists(self):
        """Model should exist"""
        self.assertTrue(
            db.session.query(MyModel).filter(MyModel.name=='Test').first()
            is not None)

    def test_model_update(self):
        """Model modification should modify model"""
        model = db.session.query(MyModel).filter(MyModel.name=='Test').first()
        model.update_name('New name')
        db.session.commit()
        self.assertTrue(
            db.session.query(MyModel).filter(MyModel.name=='New name').first()
            is not None)

Note that db.create_all() and db.drop_all() run for every test, and so does adding the initial model instance. We can avoid this by using a transaction and some savepoints.

4   Tests in Nested Transactions

So how do we use savepoints to our advantage when writing unit tests for Flask-SQLAlchemy apps? To test arbitrary Flask-SQLAlchemy code inside a nested transaction, we must replace the global db.session object with one that runs in a nested transaction. After all, your code is likely to perform all of its database operations through the db.session object.

We can use either the TestCase class's setUpClass/tearDownClass classmethods, which unittest calls once per TestCase class, or the setUpModule/tearDownModule functions, which it calls once per module. We can even use both if we want to do some setup work at the module level and some more in each class.
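If the order in which unittest calls these hooks is unclear, a small stdlib-only sketch can record it. The class and test names are hypothetical. Module-level hooks are left out here because unittest finds them through the test class's module, which depends on how the file is run.

```python
import unittest

events = []

class HookOrderDemo(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        events.append("setUpClass")      # once, before the first test in the class

    @classmethod
    def tearDownClass(cls):
        events.append("tearDownClass")   # once, after the last test in the class

    def setUp(self):
        events.append("setUp")           # before every test

    def tearDown(self):
        events.append("tearDown")        # after every test

    def test_a(self):
        events.append("test_a")

    def test_b(self):
        events.append("test_b")

# The loader sorts test methods by name, so test_a runs before test_b.
suite = unittest.TestLoader().loadTestsFromTestCase(HookOrderDemo)
result = unittest.TestResult()
suite.run(result)
recorded = list(events)  # snapshot of the hook order for this run
print(recorded)
```

This is why per-class work such as creating the schema goes in setUpClass, while the per-test savepoints go in setUp/tearDown.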

Here's an example of how to do it using the classmethods:

from unittest import TestCase
from app import db, MyModel

class MyModelTestCase(TestCase):
    @classmethod
    def setUpClass(cls):
        # Create schema
        db.create_all()
        db.session.add(MyModel(name='Test'))
        db.session.commit()

    @classmethod
    def tearDownClass(cls):
        # Drop schema
        db.drop_all()

    def setUp(self):
        # Create two savepoints
        self.savepoint1 = db.session.begin_nested()
        self.savepoint2 = db.session.begin_nested()

        # Make backup of session and replace with savepoint
        self.session_backup = db.session
        db.session = self.savepoint2.session

    def tearDown(self):
        # Roll back to first savepoint
        self.savepoint1.rollback()

        # Restore original session
        db.session = self.session_backup

    def test_models_exists(self):
        """Models should exist"""
        self.assertTrue(
            db.session.query(MyModel).filter(MyModel.name=='Test').first()
            is not None)

    def test_model_update(self):
        """Model modification should modify model"""
        model = db.session.query(MyModel).filter(MyModel.name=='Test').first()
        model.update_name('New name')
        db.session.flush()
        self.assertTrue(
            db.session.query(MyModel).filter(MyModel.name=='New name').first()
            is not None)

The setUpClass method runs only once, before the tests in the class. The database schema is created there and dropped at the end in the tearDownClass method.

The setUp method now creates two savepoints. Strictly speaking only one is needed, but I recommend two because the code you are testing might call db.session.commit(), which commits (releases) the savepoint it is running in, and a committed savepoint cannot be rolled back to. So in the tearDown method we roll back to the first savepoint, which is still available even if the tested code commits. Once it is rolled back, the next test sees the database exactly as it was after the setUpClass call.
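The effect of a commit in the tested code can be sketched at the SQL level, again with sqlite3 standing in for PostgreSQL. The assumption, matching the paragraph above, is that a commit inside a nested transaction releases only the innermost savepoint (sp2); the outer sp1 survives, so tearDown can still roll back to it.

```python
import sqlite3

conn = sqlite3.connect(":memory:", isolation_level=None)
conn.execute("BEGIN")  # outer transaction, as opened once per class or module
conn.execute("CREATE TABLE widgets (id INTEGER PRIMARY KEY, name TEXT)")

def count(connection):
    return connection.execute("SELECT COUNT(*) FROM widgets").fetchone()[0]

# setUp: two nested savepoints
conn.execute("SAVEPOINT sp1")
conn.execute("SAVEPOINT sp2")

# Code under test writes a row and "commits", which releases sp2 only.
conn.execute("INSERT INTO widgets (name) VALUES ('written by test')")
conn.execute("RELEASE SAVEPOINT sp2")

# A released savepoint can no longer be rolled back to...
try:
    conn.execute("ROLLBACK TO SAVEPOINT sp2")
    sp2_rollback_failed = False
except sqlite3.Error:
    sp2_rollback_failed = True

# ...but sp1 still exists, and rolling back to it undoes the test's insert.
conn.execute("ROLLBACK TO SAVEPOINT sp1")
rows_after_teardown = count(conn)

conn.execute("ROLLBACK")
print(sp2_rollback_failed, rows_after_teardown)
```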

Note that if your code does multiple commits per test, this doesn't work and you will receive exceptions from SQLAlchemy. You can add more savepoints to work around that. You must have one more savepoint than you have commits in a single test. On the other hand, it is a good way of finding code paths that perform multiple commits. Doing lots of commits will result in a lot of disk activity and will slow down your application.
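To see why you need one more savepoint than commits, here is a hypothetical helper, `run_test`, that opens a given number of savepoints, simulates a number of commits (each releasing the innermost savepoint), and then attempts the tearDown rollback to the outermost one. It uses sqlite3 as a stand-in and only sketches the counting rule; it is not part of the downloadable code.

```python
import sqlite3

def run_test(savepoint_count, commits_in_test):
    """Hypothetical simulation of one test.

    Opens `savepoint_count` nested savepoints, lets the "tested code" commit
    `commits_in_test` times (each commit releases the innermost savepoint),
    then rolls back to the outermost savepoint as tearDown would.
    Returns the number of rows left behind, or None if that rollback failed.
    Assumes commits_in_test <= savepoint_count.
    """
    conn = sqlite3.connect(":memory:", isolation_level=None)
    conn.execute("BEGIN")
    conn.execute("CREATE TABLE widgets (id INTEGER PRIMARY KEY)")

    stack = []
    for i in range(1, savepoint_count + 1):
        name = f"sp{i}"
        conn.execute(f"SAVEPOINT {name}")
        stack.append(name)

    for _ in range(commits_in_test):
        conn.execute("INSERT INTO widgets DEFAULT VALUES")
        conn.execute(f"RELEASE SAVEPOINT {stack.pop()}")  # one "commit"

    try:
        conn.execute("ROLLBACK TO SAVEPOINT sp1")  # what tearDown does
        rows = conn.execute("SELECT COUNT(*) FROM widgets").fetchone()[0]
    except sqlite3.Error:
        rows = None  # sp1 was already released by a commit
    finally:
        conn.execute("ROLLBACK")  # the outer transaction is still open
        conn.close()
    return rows

print(run_test(2, 1), run_test(2, 2), run_test(3, 2))
```

With two savepoints a single commit is harmless, but a second commit releases sp1 and tearDown fails; a third savepoint makes two commits safe again.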

If you want to do multiple commits in your test case, you can use db.session.begin_nested and commit that:

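class MyModelTestCase(TestCase):
    def test_with_commits(self):
        # First nested transaction: add an example instance and commit it
        savepoint1 = db.session.begin_nested()
        savepoint1.session.add(MyModel(name='First'))
        savepoint1.commit()
        self.assertIsNotNone(
            db.session.query(MyModel).filter(MyModel.name == 'First').first())

        # Second nested transaction: the 'First' instance is still visible
        savepoint2 = db.session.begin_nested()
        savepoint2.session.add(MyModel(name='Second'))
        savepoint2.commit()
        self.assertIsNotNone(
            db.session.query(MyModel).filter(MyModel.name == 'Second').first())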

Stuff you do inside a committed savepoint remains available afterwards, so the savepoint1 changes in the example above are visible inside savepoint2, unless you choose to roll back savepoint1 instead of committing it. In the end, the test's tearDown rolls back to its own savepoint, undoing everything you did before the next test runs.

5   Transactional DDL

The next logical step is to run db.create_all() inside the same transaction as the tests. This way the database you run unit tests against will start empty and end up empty without even the need to drop tables afterwards.

It might even be possible to run the tests against a database that already contains tables: drop them first inside the transaction, and your old data will be back after the final rollback. I haven't tested whether this works, and it is probably not a good idea anyway. There might be side effects, since some database objects, such as sequences, are not completely isolated inside the transaction.

Flask-SQLAlchemy does not allow us to run the metadata creation inside the same transaction as the rest of the app. Again we need to monkey patch the db object and replace the session with our own, this time bound to a new database connection. Flask-SQLAlchemy expects the db.session object to be a factory for new sessions, but we want to avoid new sessions because everything must run inside a single transaction. So we use a custom session that acts as a "factory" returning itself.
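The "factory that returns itself" trick is plain Python and can be sketched without a database. All names below are hypothetical stand-ins: `library_code` plays the role of code that calls `db.session()` expecting a session, and `SelfReturningSession` plays the role of the TestingSession class.

```python
class FakeSession:
    """Hypothetical session; a real one would own a transaction."""

class FreshSessionFactory:
    """Ordinary factory: every call hands out a brand-new session."""
    def __call__(self):
        return FakeSession()

class SelfReturningSession(FakeSession):
    """A session that doubles as its own factory."""
    def __call__(self):
        # Callers asking for "a new session" get this very object back,
        # so all of their work lands in the one shared transaction.
        return self

def library_code(db_session):
    # Stand-in for code that calls db.session() internally.
    return db_session()

factory = FreshSessionFactory()
testing_session = SelfReturningSession()

print(library_code(factory) is library_code(factory))                  # False
print(library_code(testing_session) is library_code(testing_session))  # True
```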

Below is an example of how to do it in the setUpModule/tearDownModule functions.

import flask_sqlalchemy
from sqlalchemy.orm import Session as SessionBase

class TestingSession(SessionBase):
    def __init__(self, db, bind, **options):
        self.app = db.get_app()

        SessionBase.__init__(
            self, autocommit=False, autoflush=True,
            bind=bind, **options
        )

    def __call__(self):
        # Flask-SQLAlchemy wants to create a new session
        # Simply return the existing session
        return self

    def get_bind(self, mapper=None, clause=None):
        # mapper is None if someone tries to just get a connection
        if mapper is not None:
            info = getattr(mapper.mapped_table, 'info', {})
            bind_key = info.get('bind_key')
            if bind_key is not None:
                state = flask_sqlalchemy.get_state(self.app)
                return state.db.get_engine(self.app, bind=bind_key)
        return SessionBase.get_bind(self, mapper, clause)

from unittest import TestCase
from app import db, MyModel

def setUpModule():
    global session_backup

    # Create a connection and start a transaction. This is needed so that
    # we can run the drop_all/create_all inside the same transaction as
    # the tests
    connection = db.engine.connect()
    transaction = connection.begin()

    # Back up the original session and replace with our own
    session_backup = db.session
    db.session = TestingSession(db, connection)

    ## Drop all to get an empty database free of old crud just in case
    db.metadata.drop_all(transaction.connection)

    ## Create everything
    db.metadata.create_all(transaction.connection)

def tearDownModule():
    # Roll back everything
    db.session.rollback()

    # Restore backup
    db.session = session_backup

class MyModelTestCase(TestCase):
    @classmethod
    def setUpClass(cls):
        # Create the class savepoint first, so that rolling back to it in
        # tearDownClass also removes the model instance added below
        cls.savepoint = db.session.begin_nested()

        # Backup and replace
        cls.session_backup = db.session
        db.session = cls.savepoint.session

        # Create a model instance for this class
        db.session.add(MyModel(name='Test'))
        db.session.flush()

    @classmethod
    def tearDownClass(cls):
        # Roll back to class savepoint
        cls.savepoint.rollback()

        # Restore original session
        db.session = cls.session_backup

    def setUp(self):
        # Create two test savepoints
        self.savepoint1 = db.session.begin_nested()
        self.savepoint2 = db.session.begin_nested()

        # Make backup of session and replace with savepoint
        self.session_backup = db.session
        db.session = self.savepoint2.session

        # This is for using app_context().pop()
        db.session.remove = lambda: None

    def tearDown(self):
        # Roll back to first test savepoint
        self.savepoint1.rollback()

        # Restore original session
        db.session = self.session_backup
    # ... test methods as in the previous example ...

setUpModule opens its own database connection instead of using the one Flask-SQLAlchemy manages. This lets us begin a transaction, run the DDL in it, and keep running the tests in that same transaction; if we called Flask-SQLAlchemy's db.create_all() directly, it would commit the schema to the database. TestingSession is a modified version of Flask-SQLAlchemy's SignallingSession class, and it keeps everything running in our single transaction.

Another detail to note is the db.session.remove = lambda: None statement in the test setUp code. It is a required fix when you push the application context in your own setUp method and pop it in the tearDown method. Flask-SQLAlchemy will raise an exception during app context teardown without it.

The rest of it is a simple extension of the original idea. Here we have two levels of setup code instead of one, and savepoints to guard them both. The setUpModule call creates the schema once for the whole module, while the model creation is done in the setUpClass method. It means that if you add another TestCase in the same module the schema will stay intact but the model inserted or any other operation performed in MyModelTestCase's setUpClass will only be there for MyModelTestCase's tests.
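The three levels can be traced in a single run with the sqlite3 stand-in (hypothetical table and labels). For brevity the sketch uses one savepoint per level, whereas the code above uses two per test to survive a commit in the tested code; the comments mark which unittest hook each step corresponds to.

```python
import sqlite3

conn = sqlite3.connect(":memory:", isolation_level=None)

def labels(connection):
    return [v for (v,) in connection.execute("SELECT label FROM items ORDER BY id")]

def user_tables(connection):
    return [n for (n,) in connection.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table'")]

# setUpModule: one transaction holding the schema and module-wide data
conn.execute("BEGIN")
conn.execute("CREATE TABLE items (id INTEGER PRIMARY KEY, label TEXT)")
conn.execute("INSERT INTO items (label) VALUES ('module')")

seen = {}
for class_name in ("A", "B"):
    # setUpClass: class savepoint, then class-level data
    conn.execute("SAVEPOINT class_sp")
    conn.execute("INSERT INTO items (label) VALUES (?)", (f"class {class_name}",))

    for test_name in ("test_1", "test_2"):
        # setUp: test savepoint, then per-test data
        conn.execute("SAVEPOINT test_sp")
        conn.execute("INSERT INTO items (label) VALUES (?)",
                     (f"{class_name}.{test_name}",))
        seen[f"{class_name}.{test_name}"] = labels(conn)  # what the test sees
        # tearDown
        conn.execute("ROLLBACK TO SAVEPOINT test_sp")
        conn.execute("RELEASE SAVEPOINT test_sp")

    # tearDownClass
    conn.execute("ROLLBACK TO SAVEPOINT class_sp")
    conn.execute("RELEASE SAVEPOINT class_sp")

# tearDownModule: undo everything, including the CREATE TABLE
conn.execute("ROLLBACK")
tables_left = user_tables(conn)
print(seen, tables_left)
```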

6   Abstract it Out

It is obviously not a good idea to repeat code like that in each test module, so let's extract the functionality into a separate module, which we will call flask_sqlalchemy_testing. The code for this and the tests can be downloaded at the end of the article.

import flask_sqlalchemy
from sqlalchemy.orm import Session as SessionBase

class TestingSession(SessionBase):
    def __init__(self, db, bind, **options):
        self.app = db.get_app()

        SessionBase.__init__(
            self, autocommit=False, autoflush=True,
            bind=bind, **options
        )

    def __call__(self):
        # Flask-SQLAlchemy wants to create a new session
        # Simply return the existing session
        return self

    def get_bind(self, mapper=None, clause=None):
        # mapper is None if someone tries to just get a connection
        if mapper is not None:
            info = getattr(mapper.mapped_table, 'info', {})
            bind_key = info.get('bind_key')
            if bind_key is not None:
                state = flask_sqlalchemy.get_state(self.app)
                return state.db.get_engine(self.app, bind=bind_key)
        return SessionBase.get_bind(self, mapper, clause)

from unittest import TestCase

from app import db

def setUpModule():
    """Sets up the module so that we can run the whole test suite inside
    a single transaction. Each test will be executed in a nested transaction
    using PostgreSQL savepoints.

    We will:

    1. Start a transaction
    2. Drop everything
    3. (Re-)create everything
    """
    global session_backup, connection

    # Create a connection and start a transaction. This is needed so that
    # we can run the drop_all/create_all inside the same transaction as
    # the tests
    connection = db.engine.connect()
    transaction = connection.begin()

    # Back up the original session and replace with our own
    session_backup = db.session
    db.session = TestingSession(db, connection)

    ## Drop all to get an empty database free of old crud
    db.metadata.drop_all(transaction.connection)

    ## Create everything
    db.metadata.create_all(transaction.connection)

def tearDownModule():
    """Roll back everything, leaving database as it was before we started"""
    db.session.rollback()
    connection.close()

    # Restore backup
    db.session = session_backup

class FlaskDbTestCase(TestCase):
    @classmethod
    def setUpClass(cls):
        """Makes a savepoint for this class"""
        cls.savepoint1 = db.session.begin_nested()
        cls.savepoint2 = db.session.begin_nested()

        # Replace app db object temporarily with a nested tx
        cls.session_backup = db.session
        db.session = cls.savepoint2.session

    @classmethod
    def tearDownClass(cls):
        """Rolls back to the class savepoint"""
        cls.savepoint1.rollback()

        # Restore original session
        db.session = cls.session_backup

    def setUp(self):
        """Makes a savepoint for this test"""
        # We actually need to make 2 savepoints in case the code to be tested
        # happens to issue a commit so that we can roll back to this point.
        # A commit in the tested code can only commit "savepoint2", so we
        # will use "savepoint1" for the rollback.
        self.savepoint1 = db.session.begin_nested()
        self.savepoint2 = db.session.begin_nested()

        # Replace app db object temporarily with a nested tx
        self.session_backup = db.session
        db.session = self.savepoint2.session

        # This is for using app_context().pop()
        db.session.remove = lambda: None

    def tearDown(self):
        """Rolls back to the test savepoint"""
        # Roll back savepoint - the database is back to its pre-test state
        self.savepoint1.rollback()

        # Restore original session
        db.session = self.session_backup

Let's use an app which has two models for a change:

from flask import Flask
from flask_sqlalchemy import SQLAlchemy

app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = 'postgresql+psycopg2://unittestuser:password@localhost/unittestdb'
db = SQLAlchemy(app)

class ModelX(db.Model):
    x = db.Column(db.Integer, primary_key=True)

class ModelY(db.Model):
    y = db.Column(db.Integer, primary_key=True)

Here's the test module:

import flask_sqlalchemy_testing
from flask_sqlalchemy_testing import FlaskDbTestCase

from app import db, ModelX, ModelY

def setUpModule():
    flask_sqlalchemy_testing.setUpModule()

    # You can issue module-wide setup operations here
    # (these will be rolled back after all tests in this module have run)
    db.session.add(ModelX(x=0))
    db.session.add(ModelY(y=0))
    db.session.commit()

from flask_sqlalchemy_testing import tearDownModule

class ModelXTestCase(FlaskDbTestCase):
    @classmethod
    def setUpClass(cls):
        super(ModelXTestCase, cls).setUpClass()

        # These will be rolled back after all the tests in this class
        db.session.add(ModelX(x=1))
        db.session.flush()

    def setUp(self):
        super(ModelXTestCase, self).setUp()

        # These will be rolled back after each test
        db.session.add(ModelX(x=2))
        db.session.flush()

    def test_model_x_instance_0_exists(self):
        """Model X instance 0 should exist"""
        self.assertIsNotNone(
            ModelX.query.filter_by(x=0).first())

    def test_model_x_instance_1_exists(self):
        """Model X instance 1 should exist"""
        self.assertIsNotNone(
            ModelX.query.filter_by(x=1).first())

    def test_model_x_instance_2_exists(self):
        """Model X instance 2 should exist"""
        self.assertIsNotNone(
            ModelX.query.filter_by(x=2).first())

    def test_model_y_instance_0_exists(self):
        """Model Y instance 0 should exist also in ModelXTestCase"""
        self.assertIsNotNone(
            ModelY.query.filter_by(y=0).first())

    def test_model_y_instance_1_exists_not(self):
        """Model Y instance 1 should not exist in ModelXTestCase"""
        self.assertIsNone(
            ModelY.query.filter_by(y=1).first())

    def test_model_y_instance_2_exists_not(self):
        """Model Y instance 2 should not exist in ModelXTestCase"""
        self.assertIsNone(
            ModelY.query.filter_by(y=2).first())

    def test_savepoint(self):
        """Test model states with savepoint"""
        from sqlalchemy import inspect
        modelx = ModelX(x=3)
        model_state = inspect(modelx)
        self.assertTrue(model_state.transient)

        savepoint = db.session.begin_nested()
        savepoint.session.add(modelx)
        self.assertTrue(model_state.pending)

        savepoint.commit()
        self.assertTrue(model_state.persistent)

class ModelYTestCase(FlaskDbTestCase):
    @classmethod
    def setUpClass(cls):
        super(ModelYTestCase, cls).setUpClass()

        db.session.add(ModelY(y=1))
        db.session.flush()

    def setUp(self):
        super(ModelYTestCase, self).setUp()

        db.session.add(ModelY(y=2))
        db.session.flush()

    def test_model_y_instance_0_exists(self):
        """Model Y instance 0 should exist"""
        self.assertIsNotNone(
            ModelY.query.filter_by(y=0).first())

    def test_model_y_instance_1_exists(self):
        """Model Y instance 1 should exist"""
        self.assertIsNotNone(
            ModelY.query.filter_by(y=1).first())
        self.assertTrue(
            db.session.query(ModelY).filter(ModelY.y==1).first()
            is not None)

    def test_model_y_instance_2_exists(self):
        """Model Y instance 2 should exist"""
        self.assertIsNotNone(
            ModelY.query.filter_by(y=2).first())

    def test_model_x_instance_0_exists(self):
        """Model X instance 0 should exist also in ModelYTestCase"""
        self.assertIsNotNone(
            ModelX.query.filter_by(x=0).first())

    def test_model_x_instance_1_exists_not(self):
        """Model X instance 1 should not exist in ModelYTestCase"""
        self.assertIsNone(
            ModelX.query.filter_by(x=1).first())

    def test_model_x_instance_2_exists_not(self):
        """Model X instance 2 should not exist in ModelYTestCase"""
        self.assertIsNone(
            ModelX.query.filter_by(x=2).first())

if __name__ == '__main__':
    #import logging
    #log = logging.getLogger(__name__)
    #logging.basicConfig(level=logging.INFO)
    #logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)

    import unittest
    unittest.main(verbosity=2)

You can run the tests with "python test_app.py". You can uncomment the lines which enable sqlalchemy.engine logging to see the SQL it is issuing.

We must of course call the setUpModule of the flask_sqlalchemy_testing module at the start of our own setUpModule. If you don't need to override it, simply import it into your test module. The tearDownModule also needs to be imported so that unittest finds and runs it. The same goes for setUpClass/tearDownClass: if you override them, call the FlaskDbTestCase versions through super(), as the example does.

7   Download the Code

The code is available to download and use as you wish. Here it is:

flask_sa_testing.tar.gz

I have used it, but not very extensively yet. Let me know if you find errors in it.

Here are the commands to run it in a virtualenv:

tar zxvf flask_sa_testing.tar.gz
cd flask_sa_testing
virtualenv virtualenv
source virtualenv/bin/activate
pip install -r requirements.txt
createuser -P unittestuser
createdb -O unittestuser unittestdb
python test_app.py