Context information storage for asyncio

At Sqreen, we are building an agent based on dynamic instrumentation. It detects security incidents from inside an application (injections, cross-site scripting etc.) and allows users to configure actions (blocking the attack, logging a stack trace etc.) without requiring code modification. The mechanisms behind dynamic instrumentation in Python are described in a previous blog post. Dynamic instrumentation is also used in Application Performance Management (APM) solutions, such as Datadog, Instana and New Relic.

Instrumenting the code allows us to execute callbacks before calling potentially hazardous functions. For example, to protect against SQL injections, we transparently wrap the method Cursor.execute with a security layer:

def sqreen_execute(self, sql_stmt, *sql_params):
    # Before executing the SQL statement, check it was not built with
    # malicious, unescaped request parameters in it.
    if has_malicious_param(sql_stmt, request.params):
        # If there is, this is an SQL injection. Abort!
        raise SQLinjection(remote_addr=request.remote_addr)
    else:
        # If not, we can safely call the original method.
        return self.execute(sql_stmt, *sql_params)

Let’s assume the instrumented code contains a vulnerable pattern like:

@app.route('/posts')
def posts(request):
    sql_stmt = 'SELECT * FROM posts WHERE id=%s' % request.params['id']

    # With dynamic instrumentation, sqreen_execute is transparently called
    # instead of cursor.execute.
    posts = cursor.execute(sql_stmt)

    return posts_template.render(posts)

Then the nominal request /posts/?id=42 will be executed (although unescaped, the request parameter is not malicious) but the malicious request /posts/?id=42 OR 1=1 won’t be. This way, we are able to protect the app without breaking it!

Context information storage

As we’ve seen above, the function sqreen_execute needs to know the current request to test the SQL statement safety. How can it get it?

Since our function transparently replaces Cursor.execute, it needs to have the same signature, hence we can’t pass the request as a parameter to sqreen_execute. Some web frameworks provide functions to get the current request, but not all of them, and we strive for a universal solution.

What we can do is to insert a middleware (or instrument the framework’s request handling mechanism) to store the current request in a global variable:

CURRENT_REQUEST = None

def set_request(request):
    global CURRENT_REQUEST
    CURRENT_REQUEST = request

def get_request():
    return CURRENT_REQUEST

But there is a catch: web frameworks handle several requests concurrently, for obvious performance reasons. So the above pattern won’t work: we may receive a first request request_1 (and store it in CURRENT_REQUEST) and, before we are done serving it, receive a second request request_2 (overwriting request_1 in CURRENT_REQUEST). By the time we look for SQL injections in request_1, get_request returns request_2, and we check the statement against the wrong parameters! So we need a stronger, concurrency-safe mechanism to store the current request.
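To make the problem concrete, here is a minimal sketch that forces the interleaving described above with two threads. The helper simulate_race and the use of plain strings as requests are assumptions made for illustration; a barrier makes both handlers store their request before either one reads it back.

```python
import threading

CURRENT_REQUEST = None

def set_request(request):
    global CURRENT_REQUEST
    CURRENT_REQUEST = request

def get_request():
    return CURRENT_REQUEST

def simulate_race():
    """Return what each handler reads back once both have stored a request."""
    seen = {}
    both_stored = threading.Barrier(2)

    def handle(request):
        set_request(request)
        # Wait until the other thread has stored its request too.
        both_stored.wait()
        seen[request] = get_request()

    threads = [threading.Thread(target=handle, args=(name,))
               for name in ("request_1", "request_2")]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    return seen

seen = simulate_race()
# Both handlers read the same value: whichever request was stored last.
print(seen)
```

Whichever thread wins, one of the two handlers ends up looking at the other one's request.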

Thread-local storage

To tackle this issue, we need to know how concurrency is implemented. Most Python web frameworks use threads: this is notably the case for Django, Flask and Pyramid, which are probably the most popular. They implement a common communication protocol with web servers, called WSGI (Web Server Gateway Interface) and initially described in PEP 333.

WSGI servers also use threads, along with processes, to spawn several application instances. Multiprocessing is not an issue here, since each process will handle its own copy of CURRENT_REQUEST. So, we just have to find a way to let a service thread store the request it is currently dealing with without impacting other threads.

And Python offers a solution for that. The function threading.local in the standard library returns a namespace object whose values are thread specific. This allows us to implement thread-safe request storage as follows:

import threading

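RUNTIME_STORAGE = threading.local()

def set_request(request):
    RUNTIME_STORAGE.request = request

def get_request():
    # Attributes set on a threading.local object exist only in the thread that
    # set them, so fall back to None in threads that have not stored a request.
    return getattr(RUNTIME_STORAGE, 'request', None)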
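As a quick check, the sketch below runs two handler threads against thread-local storage. The helper simulate_threads and the string requests are assumptions for illustration; a barrier again makes both threads store before either reads.

```python
import threading

RUNTIME_STORAGE = threading.local()

def set_request(request):
    RUNTIME_STORAGE.request = request

def get_request():
    # None in any thread that has not stored a request yet.
    return getattr(RUNTIME_STORAGE, "request", None)

def simulate_threads():
    """Return what each handler thread reads back from thread-local storage."""
    seen = {}
    both_stored = threading.Barrier(2)

    def handle(request):
        set_request(request)
        # Both threads have stored their request once the barrier is passed.
        both_stored.wait()
        seen[request] = get_request()

    threads = [threading.Thread(target=handle, args=(name,))
               for name in ("request_1", "request_2")]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    return seen

seen = simulate_threads()
print(seen)           # each handler sees its own request
print(get_request())  # the main thread never stored one
```

Each thread reads back exactly what it stored, and the main thread, which never called set_request, gets None.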

What about asyncio?

In Python 3.4, a new concurrency model was introduced: asyncio (the async and await keywords used below followed in Python 3.5). It provides infrastructure for single-threaded, asynchronous programming, including:

  • Coroutine functions, defined with async def, whose execution can be paused using the await keyword with another coroutine, and resumed once the other coroutine is completed.
  • An event loop to schedule and execute coroutines.

Here is an example of asynchronous code.

import asyncio

async def compute(x, y):
    print("Compute %s + %s ..." % (x, y))
    await asyncio.sleep(1.0)
    return x + y

async def print_sum(x, y):
    result = await compute(x, y)
    print("%s + %s = %s" % (x, y, result))

loop = asyncio.get_event_loop()
loop.run_until_complete(print_sum(1, 2))
loop.close()

There are two coroutine functions, print_sum and compute. At execution time:

  • The event loop enters print_sum and immediately hands over to compute.
  • compute prints the computation and hands over to asyncio.sleep.
  • Nothing is done in the next second. If other tasks were scheduled in the event loop, they could be executed in the meantime, something that is not possible with the blocking function time.sleep.
  • compute is resumed and completed.
  • print_sum is resumed and completed.
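The third step above is where concurrency comes from. As a small sketch (the events list and the name argument are additions for illustration), running two compute coroutines with asyncio.gather shows that the second one starts while the first one is still sleeping:

```python
import asyncio

events = []

async def compute(name, x, y):
    events.append("start " + name)
    # While this coroutine sleeps, the event loop is free to run others.
    await asyncio.sleep(0.1)
    events.append("end " + name)
    return x + y

async def main():
    return await asyncio.gather(compute("a", 1, 2), compute("b", 3, 4))

results = asyncio.run(main())
print(results)
print(events)
```

Both "start" events are recorded before either "end" event. With the blocking time.sleep, the second computation could not start before the first one had finished.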

asyncio is a great model for concurrency when IO is involved: when the code being executed is blocked waiting for an answer (for example, DB results), the program can switch to other tasks and come back to it later. It is less system-expensive than threads, and is usually faster when slow IO operations are involved.

This makes asyncio well suited for network operations and, despite being relatively young, several web frameworks have been developed around it. Among them, we recently added support for aiohttp to our agent. This was a very interesting and challenging task since we had no support for aiohttp at all so far, and one significant issue we ran into was with the request storage mechanism.

Here is what can happen: we receive a first request request_1 and start dealing with it in a coroutine. At some point, the coroutine is suspended and the event loop hands over to another one that handles request_2. The important point is that these two coroutines are executed in the same thread, so threading.local contains the same data for both. When the first coroutine resumes, RUNTIME_STORAGE.request has been set to request_2: that is precisely what we want to prevent.
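This scenario can be reproduced in a few lines. The sketch below is an illustration with string requests and a shared seen dictionary (both assumptions); asyncio.sleep(0) is used only to suspend the coroutine and let the other handler run.

```python
import asyncio
import threading

RUNTIME_STORAGE = threading.local()

def set_request(request):
    RUNTIME_STORAGE.request = request

def get_request():
    return getattr(RUNTIME_STORAGE, "request", None)

async def handle_request(request, seen):
    set_request(request)
    # Suspend here: the event loop runs the other handler in the same thread.
    await asyncio.sleep(0)
    seen[request] = get_request()

async def main():
    seen = {}
    await asyncio.gather(
        handle_request("request_1", seen),
        handle_request("request_2", seen),
    )
    return seen

seen = asyncio.run(main())
# The handler of request_1 reads back request_2.
print(seen)
```

Since both coroutines share one thread, they also share one thread-local namespace, and the last write wins.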

First attempt: let’s use tasks!

What we need is a mechanism similar to threading.local that works with asyncio, i.e. lets us store context variables and keep track of values per asynchronous execution context.

Unfortunately, there is currently no built-in mechanism in Python to handle this. Different proposals have been made to provide a generic solution in future versions of Python (PEP 550, PEP 567), but in the meantime, we have to devise a solution on our own.

Let’s dig a bit further into the internals of asyncio. A coroutine whose execution is scheduled is wrapped into an asyncio.Task object, responsible for executing the coroutine object in an event loop.

Figure: sequence diagram of the example (https://docs.python.org/3/_images/tulip_coro.png)

There is also a function asyncio.Task.current_task that returns the currently running task. Mmh… We could use this to map the current request to the task handling it. Something like this could work:

import asyncio

TASK_REQUESTS = {}

def set_request(request):
    task = asyncio.Task.current_task()
    TASK_REQUESTS[id(task)] = request

def get_request():
    task = asyncio.Task.current_task()
    return TASK_REQUESTS.get(id(task))

With this implementation, we’d also need a mechanism to ensure request deletion once the task is completed, to avoid accumulating old requests and causing a memory leak. A way to avoid dealing with it is to store the request within the Task object, as an extra attribute:

def set_request(request):
    task = asyncio.Task.current_task()
    setattr(task, 'current_request', request)

def get_request():
    task = asyncio.Task.current_task()
    return getattr(task, 'current_request', None)
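For comparison, the cleanup needed by the dictionary-based variant could be sketched with Task.add_done_callback. This is only an illustration, not what the agent does: it uses asyncio.current_task(), the spelling available from Python 3.7 that replaces asyncio.Task.current_task (removed in Python 3.9), and string requests.

```python
import asyncio

TASK_REQUESTS = {}

def set_request(request):
    task = asyncio.current_task()
    key = id(task)
    if key not in TASK_REQUESTS:
        # Drop the entry as soon as the task is done, so that old requests
        # do not pile up in the dictionary.
        task.add_done_callback(lambda _task: TASK_REQUESTS.pop(key, None))
    TASK_REQUESTS[key] = request

def get_request():
    return TASK_REQUESTS.get(id(asyncio.current_task()))

async def handle_request(request):
    set_request(request)
    await asyncio.sleep(0)
    return get_request()

async def main():
    results = await asyncio.gather(
        handle_request("request_1"),
        handle_request("request_2"),
    )
    # Done callbacks are scheduled on the loop; yield once so they have run.
    await asyncio.sleep(0)
    return results, len(TASK_REQUESTS)

results, remaining = asyncio.run(main())
print(results, remaining)
```

The attribute-based variant gets the same result for free, since the request disappears together with the task object. The rest of this post sticks with it.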

So, does it work? Let’s test!

import random

class Request:
    # Dummy request object, for the sake of testing.
    pass

async def handle_request(request):
    set_request(request)
    await asyncio.sleep(random.uniform(0, 2))
    await check_request(request)

async def check_request(request):
    # Check that the stored request corresponds to the current request. If not,
    # an AssertionError is raised and the test is interrupted with an error.
    assert get_request() is request, "wrong request"

NUM_REQUESTS = 1000

loop = asyncio.get_event_loop()
coros = [handle_request(Request()) for _ in range(NUM_REQUESTS)]
loop.run_until_complete(asyncio.gather(*coros))
loop.close()
print("Success!")

This test simulates one thousand concurrent requests. Each one is handled in a dedicated coroutine handle_request. This function stores the request, then pauses for a random duration (this simulates an async operation such as a DB access, and ensures that the coroutine execution flow is interleaved). When resumed, a nested coroutine check_request is called that ensures that get_request returns the correct request. If not, the test is interrupted by an error.

And here is some good news: the test runs smoothly with task-based request storage. With thread-local storage, on the other hand, it fails, which was expected and shows that the test is relevant. So, have we solved our problem?

Context inheritance between tasks

Let’s try something a bit more twisted:

async def handle_request(request):
    set_request(request)
    await asyncio.gather(
        asyncio.sleep(random.uniform(0, 2)),
        check_request(request),
    )

Instead of executing asyncio.sleep and check_request sequentially, this version of handle_request runs them concurrently. This should not be a big deal: the code is a bit more concurrent, but it does not impact request handling. In particular, check_request is still called after set_request for each request.

Nevertheless, this new test fails! Something went wrong when we introduced asyncio.gather, but what?

Well, remember that scheduled coroutines are wrapped into tasks? That’s exactly what happens here: asyncio.gather creates tasks around the arguments asyncio.sleep() and check_request() and these tasks are executed by the event loop.

async def handle_request(request):
    set_request(request)                              # Running in task 1.
    await asyncio.gather(
        asyncio.sleep(random.uniform(0, 2)),          # Create child task 2.
        check_request(request),                       # Create child task 3.
    )

async def check_request(request):
    assert get_request() is request, "wrong request"  # Running in task 3.

The consequence is that set_request and get_request are called in different tasks, making the test fail. This is not due to request interleaving, as we can check by setting NUM_REQUESTS to 1: the test keeps failing.
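The task boundary can be observed directly. In this sketch (the coroutine names parent and child are made up, and asyncio.current_task() is the Python 3.7+ spelling of asyncio.Task.current_task), the same coroutine function is first awaited directly and then passed through asyncio.gather:

```python
import asyncio

async def child():
    return asyncio.current_task()

async def parent():
    parent_task = asyncio.current_task()
    # Awaiting a coroutine directly keeps running inside the parent's task...
    direct = await child()
    # ...whereas asyncio.gather wraps each coroutine in a task of its own.
    (gathered,) = await asyncio.gather(child())
    return parent_task, direct, gathered

parent_task, direct, gathered = asyncio.run(parent())
print(direct is parent_task)    # same task
print(gathered is parent_task)  # a different, child task
```

An attribute stored on the parent task is therefore invisible from the gathered coroutine unless something copies it over.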

In fact, when calling get_request from a child task, we need a mechanism to retrieve the request from the parent task if it is not defined in the child task. But asyncio does not allow us to access the parent task, so this is not going to work.

On the other hand, something asyncio lets us do is replace the function called to create new tasks, a.k.a. the task factory. This function is called in the context of the parent task, and returns a fresh child task. Well, let’s use it to decorate the child task with the current request!

Here is what a “request-aware” task factory would look like:

def request_task_factory(loop, coro):
    # This is the default way to create a child task.
    child_task = asyncio.tasks.Task(coro, loop=loop)

    # Retrieve the request from the parent task...
    parent_task = asyncio.Task.current_task(loop=loop)
    current_request = getattr(parent_task, 'current_request', None)

    # ...and store it in the child task too.
    setattr(child_task, 'current_request', current_request)

    return child_task

To install the task factory, we also need to call loop.set_task_factory(request_task_factory) before running the loop. So, here is the final version of our code:

import asyncio
import random

class Request:
    pass

def set_request(request):
    task = asyncio.Task.current_task()
    setattr(task, 'current_request', request)

def get_request():
    task = asyncio.Task.current_task()
    return getattr(task, 'current_request', None)

def request_task_factory(loop, coro):
    child_task = asyncio.tasks.Task(coro, loop=loop)
    parent_task = asyncio.Task.current_task(loop=loop)
    current_request = getattr(parent_task, 'current_request', None)
    setattr(child_task, 'current_request', current_request)
    return child_task

async def handle_request(request):
    set_request(request)
    await asyncio.gather(
        asyncio.sleep(random.uniform(0, 2)),
        check_request(request),
    )

async def check_request(request):
    assert get_request() is request

NUM_REQUESTS = 1000

loop = asyncio.get_event_loop()
loop.set_task_factory(request_task_factory)
coros = [handle_request(Request()) for _ in range(NUM_REQUESTS)]
loop.run_until_complete(asyncio.gather(*coros))
loop.close()

And it works flawlessly!
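On recent Python versions, asyncio.Task.current_task no longer exists (it was removed in Python 3.9 in favor of asyncio.current_task), and the event loop may pass extra keyword arguments to the task factory. A sketch of the same test adapted to those versions could look as follows. The **kwargs forwarding, the run_requests helper and the shorter sleep are assumptions of this sketch rather than part of the code above.

```python
import asyncio
import random

class Request:
    pass

def set_request(request):
    asyncio.current_task().current_request = request

def get_request():
    return getattr(asyncio.current_task(), "current_request", None)

def request_task_factory(loop, coro, **kwargs):
    # Forward any keyword arguments the loop passes (e.g. name or context).
    child_task = asyncio.Task(coro, loop=loop, **kwargs)
    # The factory runs while the parent task, if any, is still current.
    parent_task = asyncio.current_task(loop=loop)
    child_task.current_request = getattr(parent_task, "current_request", None)
    return child_task

async def check_request(request):
    return get_request() is request

async def handle_request(request):
    set_request(request)
    _, ok = await asyncio.gather(
        asyncio.sleep(random.uniform(0, 0.05)),
        check_request(request),
    )
    return ok

async def main(num_requests):
    coros = [handle_request(Request()) for _ in range(num_requests)]
    return await asyncio.gather(*coros)

def run_requests(num_requests):
    loop = asyncio.new_event_loop()
    loop.set_task_factory(request_task_factory)
    try:
        return loop.run_until_complete(main(num_requests))
    finally:
        loop.close()

results = run_requests(100)
print(all(results))
```

Here check_request returns a boolean instead of asserting, so the outcome of every request can be inspected after the loop has finished.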

And now what?

We now have the foundations to solve the request storage issue in our agent. Since we want the agent to work as transparently as possible and not require code modification from the user, there are still two minor issues to be tackled:

  • We want to automatically set up the task factory. This will be done with dynamic instrumentation.
  • But if a custom task factory is set up by the user, we don’t want to overwrite it. Instead, we will wrap it with our own.

Let’s start with the second problem. We can define a generic function wrap_request_task_factory that takes a task factory as argument and returns a variant of it that supports request propagation. The code of the wrapped function is very close to that of request_task_factory above:

from functools import wraps

def wrap_request_task_factory(task_factory):

    @wraps(task_factory)
    def wrapped(loop, coro):
        child_task = task_factory(loop, coro)
        parent_task = asyncio.Task.current_task(loop=loop)
        current_request = getattr(parent_task, 'current_request', None)
        setattr(child_task, 'current_request', current_request)
        return child_task

    return wrapped

Then, the definition of request_task_factory can be simplified to:

@wrap_request_task_factory
def request_task_factory(loop, coro):
    # Default behavior: create a fresh task; the decorator adds propagation.
    return asyncio.tasks.Task(coro, loop=loop)

Time to go back to dynamic instrumentation. By hooking the import system, we can transparently replace an imported class by a custom one. So let’s define a function wrap_loop_cls that creates a custom loop class with the desired behavior:

def wrap_loop_cls(loop_cls):

    class SqreenLoop(loop_cls):

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # We want to use request_task_factory to be the default task
            # factory.
            super().set_task_factory(request_task_factory)

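        def set_task_factory(self, task_factory):
            # If the user sets up a custom task factory, let's wrap it with
            # request propagation. None means "back to the default factory",
            # which for this loop is request_task_factory.
            if task_factory is None:
                wrapped_task_factory = request_task_factory
            else:
                wrapped_task_factory = wrap_request_task_factory(task_factory)
            super().set_task_factory(wrapped_task_factory)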

    return SqreenLoop

This loop class transparently replaces the base one. It uses the correct task factory by default and allows the user to change it while keeping the request management layer.
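To see the pieces working together without the import hook, the loop class can be built by hand. The sketch below makes several assumptions: it wraps asyncio.SelectorEventLoop directly, uses the Python 3.7+ asyncio.current_task, forwards **kwargs to the factories, treats None as a reset to the default factory, and introduces a hypothetical user_task_factory that merely records the tasks it creates.

```python
import asyncio
from functools import wraps

def wrap_request_task_factory(task_factory):
    @wraps(task_factory)
    def wrapped(loop, coro, **kwargs):
        child_task = task_factory(loop, coro, **kwargs)
        parent_task = asyncio.current_task(loop=loop)
        child_task.current_request = getattr(parent_task, "current_request", None)
        return child_task
    return wrapped

@wrap_request_task_factory
def request_task_factory(loop, coro, **kwargs):
    return asyncio.Task(coro, loop=loop, **kwargs)

def wrap_loop_cls(loop_cls):
    class SqreenLoop(loop_cls):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            super().set_task_factory(request_task_factory)

        def set_task_factory(self, task_factory):
            # Assumption: None resets to our request-aware default factory.
            if task_factory is None:
                wrapped = request_task_factory
            else:
                wrapped = wrap_request_task_factory(task_factory)
            super().set_task_factory(wrapped)

    return SqreenLoop

# Hypothetical user-defined factory that records every task it creates.
CREATED = []

def user_task_factory(loop, coro, **kwargs):
    task = asyncio.Task(coro, loop=loop, **kwargs)
    CREATED.append(task)
    return task

async def child():
    return getattr(asyncio.current_task(), "current_request", None)

async def parent():
    asyncio.current_task().current_request = "request_1"
    (seen,) = await asyncio.gather(child())
    return seen

SqreenLoop = wrap_loop_cls(asyncio.SelectorEventLoop)
loop = SqreenLoop()
loop.set_task_factory(user_task_factory)
try:
    seen = loop.run_until_complete(parent())
finally:
    loop.close()
print(seen, len(CREATED))
```

The user's factory still creates every task, and the request stored by the parent is visible from the child.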

Closing words

We have published most of this work (without the instrumentation part) in a Python library called AioContext. It comes with generic Context objects that behave like dictionaries. It also allows restoring the original task factory if contexts are no longer needed, and stores contexts as an extra attribute of the task factory to avoid messing with the asyncio.Task class itself. The documentation is available here.

import asyncio
import aiocontext
import random

class Request:
    pass

CONTEXT = aiocontext.Context()

async def handle_request(request):
    CONTEXT['current_request'] = request
    await asyncio.gather(
        asyncio.sleep(random.uniform(0, 2)),
        check_request(request),
    )

async def check_request(request):
    assert CONTEXT['current_request'] is request

NUM_REQUESTS = 1000

loop = asyncio.get_event_loop()
aiocontext.wrap_task_factory(loop)
CONTEXT.attach(loop)
coros = [handle_request(Request()) for _ in range(NUM_REQUESTS)]
loop.run_until_complete(asyncio.gather(*coros))
loop.close()
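For readers on Python 3.7 or later, PEP 567, mentioned earlier, has since shipped in the standard library as the contextvars module, and asyncio tasks copy the current context when they are created. A minimal sketch of the same test with it, as an alternative rather than a description of how AioContext or our agent works, could look like this (the variable name CURRENT_REQUEST and the shortened sleep are choices made for this example):

```python
import asyncio
import contextvars

class Request:
    pass

CURRENT_REQUEST = contextvars.ContextVar("current_request", default=None)

async def check_request(request):
    # Runs in a child task, which received a copy of the parent's context.
    assert CURRENT_REQUEST.get() is request

async def handle_request(request):
    CURRENT_REQUEST.set(request)
    await asyncio.gather(
        asyncio.sleep(0.01),
        check_request(request),
    )
    return CURRENT_REQUEST.get() is request

async def main(num_requests):
    coros = [handle_request(Request()) for _ in range(num_requests)]
    return await asyncio.gather(*coros)

results = asyncio.run(main(100))
print(all(results))
```

No custom task factory is involved here: each task works on its own copy of the context, so a value set while handling one request is not seen by the handlers of other requests.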

This work was strongly inspired by Manuel Miranda’s blog post From Flask to aiohttp and the library aiotask-context. We want to thank him for the great contribution.