Advanced Static Typing in Python: How to type a Decorator

This post is aimed at people who use (or are forced to use) Python and are interested in the benefits of Functional Programming and static typing, which let you detect many potential errors at dev-time and prevent costly follow-up errors.

First, I will give a quick recap of basic static typing (using mypy). Then we create a flexible identity function with a type variable, using Python 3.12 syntax. We will also look into functions that use multiple type variables, as well as functions that take other functions as arguments. Lastly, we use all of the above to build a type-safe logging decorator that pairs maximum flexibility (akin to the Any type) with self-enforcing code that complains if you use it wrong, even before running it (the opposite of the Any type).

Identity function

Let's create a function that simply returns what was passed in.

def identity(x):
    return x

a = identity(5)
b = identity("Hello")
c = a + b
print(c)

When I asked some developers to predict the output, the majority said:

5Hello

But the actual output is:

Traceback (most recent call last):
  File "main.py", line 6, in <module>
    c = a + b
TypeError: unsupported operand type(s) for +: 'int' and 'str'

The prediction would have been true for JavaScript, but in Python, adding a str to an int is not allowed: it raises a TypeError, which crashes the program.

But it will only crash when reaching that line, so for example:

import random

if random.randint(1, 10) > 9:
    print("Hello" / 2)

This code will only crash sometimes (roughly one run in ten). Extrapolate this to a large code base, and you get bugs that hide in rarely executed branches until they hit production.
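To make the "only sometimes" part reproducible, here is a small sketch where the die roll is passed in as a parameter (risky is a hypothetical name, not from the original code). Python only fails when the broken branch actually runs, whereas mypy flags the division line without executing anything.

```python
def risky(roll: int) -> str:
    # Hypothetical helper: the broken line only executes when roll > 9.
    if roll > 9:
        return "Hello" / 2  # mypy flags this line; Python only fails if it runs
    return "fine"

print(risky(3))  # the broken branch is never reached, so this prints: fine
```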

The issue

We were allowed to "ship" the broken code. Since you are not guaranteed to run every possible branch in your code, and certainly aren't guaranteed to write tests for every possible branch, I see it this way: code you are allowed to write == code you are allowed to ship.

How do we prevent ourselves from being allowed to ship code that is broken by definition? The answer is static typing - a fancy term for "letting a machine tell us when we are wrong".

Disclaimer: we will use type hints, but for them to have any desirable effect in Python you MUST install a static analyzer. We will use mypy throughout this article. Once you have installed it, you can simply run mypy . to scan your code. Note: static analyzers do not affect your production code in any way; they are meant to be run locally and in your CI pipeline, where they validate the source code mechanically.

Now, before we continue: just by adding a type checker and simple types to your functions, you will already solve most of your problems (in my experience, about 80%). Take this example:

def add(a: int, b: int) -> int:
    return a + b

e1 = add(4, 4)
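e2 = add(None, 4) # ❌ mypy: Argument 1 to "add" has incompatible type "None"; expected "int"
e3 = add(2, 3) + "yep" # ❌ mypy: Unsupported operand types for + ("int" and "str")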
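A quick way to convince yourself that hints alone change nothing at runtime: the sketch below reuses add from above, inspects the stored annotations, and calls it with strings. mypy rejects that call, but the interpreter runs it anyway.

```python
def add(a: int, b: int) -> int:
    return a + b

# Python stores the hints as plain metadata on the function...
print(add.__annotations__)

# ...but never enforces them: mypy rejects this call, yet it runs and prints 44.
print(add("4", "4"))
```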

But with this simplistic approach to typing you will run into issues for the remaining 20% of your code base.

The challenge

Create a function that fulfills all of the following criteria:

  1. The function returns the value that was passed in.
  2. The function accepts ALL possible input types.
  3. The return type is ALWAYS the same as the input type.
  4. mypy shall complain when a call would result in a runtime error.
  5. mypy shall not complain about valid calls.

Here are some common, but ineffective approaches:

The simple type hint approach

Let's modify the code to include type hints (and remember: these ONLY do anything if you run mypy or another static analyzer. DO NOT USE TYPE HINTS IF YOU HAVE NO STATIC ANALYSIS STEP)

def identity(x: int) -> int:
    return x

a = identity(5)
b = identity("Hello") # ❌ mypy: incompatible type "str"; expected "int" 
c = a + b
print(c)

We get an error, but it's not the error we wanted. Due to our type signature, calling identity("Hello") is no longer allowed, since identity now only accepts int.

We violate the criterion "The function accepts ALL possible input types".

The Any approach

from typing import Any

def identity(x: Any) -> Any:
    return x

a = identity(5)
b = identity("Hello")
c = a + b # no warning, but blows up at runtime
print(c)

Now, mypy shuts up (0 errors). That is because the Any type is treated as all-powerful and basically no rules apply to it. But when we run the program, we run into the original runtime exception / program crash. So by adding the Any type we added pointless noise with zero benefit.

We violated the criterion "mypy shall complain when a call would result in a runtime error".
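For comparison, here is a sketch of a related idea that is not one of the approaches above: annotating with object instead of Any. Every value is an object, but unlike Any, object allows almost no operations, so mypy swings to the other extreme and rejects valid code too. This violates criterion 5 instead.

```python
def identity(x: object) -> object:
    return x

a = identity(5)
b = identity("Hello")
x = a + a  # ❌ mypy rejects this valid line: it only knows a is an object (runtime result: 10)
# c = a + b  # mypy rejects this too, and it really would crash at runtime
```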

The union type approach

def identity(x: int | str) -> int | str:
    return x

a = identity(5)
x = a + a # ❌ mypy error: unsupported operand types
b = identity("Hello")
c = a + b
print(c)

What if you want to pass in a list[tuple[str, str]]? "Just add it to the union." Well, obviously that does not scale, but there is a much bigger issue.

The signature suggests that the function could be int -> int, int -> str, str -> str, or str -> int.

In line 5, x = a + a, we know the result should just be 10. However, the type checker gives us two errors for this line:

main.py:5: error: Unsupported operand types for + ("int" and "str")  [operator]
main.py:5: error: Unsupported operand types for + ("str" and "int")  [operator]

The reason is that according to the signature, these outcomes are possible, and the type checker always operates on all possibilities expressed by the type signature.

This solution violates the criteria 2, 3 and 5.

The overload approach

from typing import overload

@overload
def identity(x: int) -> int: ...

@overload
def identity(x: str) -> str: ...

def identity(x: int | str) -> int | str:
    return x

a = identity(5)
x = a + a
b = identity("Hello")
c = a + b # ❌ mypy error: unsupported operand types
print(c)

Okay, so technically this is a solution. It fails on the correct line (where we would actually get a runtime error) and allows everything before the problematic call.

But let's add another call:

d = identity(["Hello", "World"]) # ❌ No overload variant of "identity" matches argument type "list[str]"  [call-overload]

This solution still violates criterion 2: "The function accepts ALL possible input types".

The solution: a Generic function

This is valid Python 3.12 code:

def identity[T](x: T) -> T:
    return x

a = identity(5)
x = a + a
b = identity("Hello")
c = a + b # ❌ mypy error: unsupported operand types
print(c)

Perfect: mypy complains in the exact place where we would get the runtime error, and not any other place.

Let's add some more calls:

d = identity(["Hello", "World"])
d / 2 # ❌ Unsupported operand types for / ("list[str]" and "int")

e = identity(print)
e + 1 # ❌ Unsupported operand types for + (overloaded function and "int")

We are allowed to pass in whatever we want, but when we use the return type wrongly, we get the correct predictions about runtime errors.

This solution satisfies all five criteria. 🥳
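If you are on a Python version older than 3.12, the same generic function needs an explicitly declared type variable from the typing module. mypy checks it the same way; only the spelling differs.

```python
from typing import TypeVar

# Pre-3.12 spelling: declare the type variable once at module level.
T = TypeVar("T")

def identity(x: T) -> T:
    return x

a = identity(5)        # T is solved as int
b = identity("Hello")  # T is solved as str
# c = a + b            # ❌ mypy: unsupported operand types, exactly as before
```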

How it works

# Python 3.12
def identity[T](x: T) -> T:
    return x

The function name is followed by square brackets [] before the parameter list, in which you declare one or more type variables (type parameters). You can then reference these variables in your parameter and return types.

When calling identity(5), mypy works out T from the argument:

  • x has type T
  • we pass 5 as x
  • 5 is an int
  • therefore T = int
  • therefore the function has signature int -> int

When calling identity("Hello") we automatically have signature:
(str) -> str

When calling identity([(3, "hi"), (17, "yep")]) we have signature:
(list[tuple[int, str]]) -> (list[tuple[int, str]])

Let's run the latter example through mypy:

a = identity([(3, "hi"), (17, "yep")])
print(a[1][1]) # yep
print(a[1][2]) # ❌ mypy: error: Tuple index out of range

Would you have spotted the error on line 3? In any case, I think having a machine prevent me from shipping this code is awesome. You can run it through mypy yourself to see the error.

Two type parameters

Now that we know how a Generic function works, let's amp it up a little with a function that has two type parameters.

def bundle[A, B](val1: A, val2: B) -> tuple[A, B]:
    return val1, val2

a = bundle(1, 2)
# a: tuple[int, int]

b = bundle(1, "Hello")
# b: tuple[int, str]

c = bundle([1, 2, 3], ("hi", "ho"))
# c: tuple[list[int], tuple[str, str]]

The above is straightforward. Whatever types you pass in for A and B, the return type is tuple[A, B]. (A and B can also be the same type.)

Type parameter as the return type of a Callable

Now it becomes more advanced. In many modern languages, you are allowed to pass functions into other functions. This is a core capability of Functional Programming and enables you to implement Dependency Injection, Strategy Pattern, Decorator pattern and many more without relying on OOP.

In the world of static typing, everything has a type, including functions. In Python, you can describe any function as a Callable. However, to be more precise you can specify the inputs and outputs. For example:

# you can call it...
Callable

# takes one int param, returns a str
Callable[[int], str]

# takes two int params, returns a str
Callable[[int, int], str]

# takes any number of params of any type, returns a str
Callable[..., str]

# takes a callable (int) -> str and returns a callable (int) -> str
Callable[[Callable[[int], str]], Callable[[int], str]]
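To see these types in action, here is a sketch that assigns real functions to variables annotated with some of the Callable types above. describe and shout are hypothetical example functions.

```python
from typing import Callable

def describe(n: int) -> str:
    return f"number {n}"

def shout(f: Callable[[int], str]) -> Callable[[int], str]:
    # Takes an (int) -> str function and returns a new (int) -> str function.
    def wrapped(n: int) -> str:
        return f(n).upper() + "!"
    return wrapped

takes_int: Callable[[int], str] = describe
takes_anything: Callable[..., str] = describe  # "..." accepts any parameter list
transformer: Callable[[Callable[[int], str]], Callable[[int], str]] = shout
# wrong: Callable[[str], str] = describe  # ❌ mypy rejects: describe expects an int

loud = transformer(takes_int)
print(loud(3))  # NUMBER 3!
```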

Exercise: What is the type of this function, expressed in Callable syntax?

def add(a: int, b: int) -> int:
    return a + b

Generic Types and Callables combined

from typing import Callable

def into[A, B](value: A, transform: Callable[[A], B]) -> B:
    return transform(value)

This signature is a bit much if you are new to the topic. Let's approach it through an example.

def number_to_message(n: int) -> str:
    return f"The number is {n}"

a = into(value=3, transform=number_to_message)
print(a) # prints "The number is 3"

We pass value=3, therefore we can resolve A = int:

def into[B](value: int, transform: Callable[[int], B]) -> B:

We pass transform=number_to_message. number_to_message is a function with the signature (int) -> str, or, in Python's formal notation, Callable[[int], str]. Since our type variable B sits in the place of the callable's return type, we can resolve B as str.

def into(value: int, transform: Callable[[int], str]) -> str:

Since B shows up both as the return type of the Callable that is passed in and as the return type of the function itself, the return type of the entire function call is str.

a = into(value=3, transform=number_to_message)
reveal_type(a) # str

But... why

The examples I have shown so far may seem a bit unnecessary. After all, instead of all this typing nonsense:

from typing import Callable

def into[A, B](value: A, transform: Callable[[A], B]) -> B:
    return transform(value)

def number_to_message(n: int) -> str:
    return f"The number is {n}"

a = into(value=3, transform=number_to_message)
print(a) # prints "The number is 3"

We could have just done this:

val = 3
a = f"The number is {val}"
print(a) # prints "The number is 3"

And even the type checking would still work. (mypy understands that val is int, and that a is a str)

The typed Decorator

But now let us build a decorator that fulfills the following criteria:

  1. Can take any function of any type
  2. Prints the function name before and after executing it
  3. Returns a function with the same signature as the one passed in
  4. Tracks the types such that mypy complains when an error would happen
  5. mypy does not complain for wrong reasons

Let's create two small functions that we will decorate later.

def greet(name: str) -> str:
    return f"Hello, {name}"

def add(a: int, b: int) -> int:
    return a + b

And our decorator:

def log_it(func):
    def wrapper(*args, **kwargs):
        print("Calling " + func.__name__)
        result = func(*args, **kwargs)
        print(f"Result was {result}")
        return result
    return wrapper

Now, we add the decorator to our functions and call them:

@log_it
def greet(name: str) -> str:
    return f"Hello, {name}"

@log_it
def add(a: int, b: int) -> int:
    return a + b

a = greet("Alice")
b = add(3, 8)

This works just fine. a = "Hello, Alice" and b = 11. The function calls were printed as well.

So what's the issue?

The decorator signature def log_it(func): does not specify the parameter or return type. Therefore, it gets assigned the type (Any) -> Any in the land of static typing.

Let's add some more code:

a = greet("Alice")
r1 = a / 2

b = add(3, 8)
r2 = b + "yep"

Our greet() and add() functions are typed nicely and the static analyzer is running, but it says "all good". Yet both the r1 and r2 lines will crash the program, and we are back at the start of this article: a situation we want to avoid.

First of all, you can configure mypy to complain about untyped functions by adding this to your mypy.ini:

[mypy]
disallow_untyped_defs = True
disallow_untyped_calls = True

Then it would no longer allow you to "ship" code with missing signatures that may hide crashes, like the code above. mypy also has a disallow_untyped_decorators option that targets exactly this situation: a typed function wrapped in an untyped decorator.

But how DO you type the decorator? Let's do it progressively:

def log_it(func: Callable) -> Callable:

This is a start, and not wrong. The decorator takes a function and returns a (slightly modified) function. But since we do not specify the inputs and outputs, we still end up with Callable[..., Any] as the return type, and calling the Callable therefore returns Any, our arch nemesis. Effectively we are not checking much.

Let's type the callable's return type with a TypeVar.

def log_it[R](func: Callable[..., R]) -> Callable[..., R]:

This is already much better:

a = greet("Alice")
r1 = a / 2 # ❌ Unsupported operand types for / ("str" and "int")

b = add(3, 8)
r2 = b + "yep" # ❌ Unsupported operand types for + ("int" and "str")

Thanks to our signature, mypy knows the return type is preserved. The result of greet("Alice") is known to be a str (even after decoration), so mypy will not let us divide that result by a number.

However, consider this:

a = greet(1)

Our original signature of greet was def greet(name: str) -> str:

The above call, however, is allowed. The reason is that our decorator declared the function that goes in and comes out as Callable[..., R]. The ... means the parameters are unchecked, as if the function took any number of Any arguments, so we lose all type tracking of the inputs.

Final signature and content of the decorator:

def log_it[**P, R](func: Callable[P, R]) -> Callable[P, R]:
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        print("Calling " + func.__name__)
        result = func(*args, **kwargs)
        print(f"Result was {result}")
        return result
    return wrapper

Note: the **P syntax is new here. It declares a ParamSpec, and I have not fully grasped the concept myself. But essentially, you use it when you need to capture the entire set of parameter types of a function in one variable. You can then use P.args and P.kwargs to annotate *args and **kwargs, so that the wrapper accepts exactly the positional and keyword parameters of the wrapped function 🤯 (you will get used to it).

Here is the fully typed code, with examples of errors that mypy catches:

from typing import Callable

def log_it[**P, R](func: Callable[P, R]) -> Callable[P, R]:
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        print("Calling " + func.__name__)
        result = func(*args, **kwargs)
        print(f"Result was {result}")
        return result
    return wrapper
    
@log_it
def greet(name: str) -> str:
    return f"Hello, {name}"

@log_it
def add(a: int, b: int) -> int:
    return a + b
        
greet("Alice") # good
greet("Alice", "Bob") # ❌ Too many arguments for "greet"
greet("Alice") / 2 # ❌  Unsupported operand types for / ("str" and "int")

add(1, 2) # good
add(1) # ❌ Missing positional argument "b" in call to "add"
add(1, 2) + "Yep" # ❌ Unsupported operand types for + ("int" and "str")
add("Bob") # ❌ Argument 1 to "add" has incompatible type "str"; expected "int"

You can run mypy on this code yourself to reproduce these errors.

What's next

Now, if you have questions left, like

  • How do I type a decorator that takes an async function?
  • How do I type a decorator that adds, removes, or replaces parameters of the Callable it takes?

Then I'm glad to tell you, I have already solved these problems in the codebase I'm currently working in. I just need to find the time to write another article. If you're interested, take this as an opportunity to sign up to my newsletter to be informed when the next article drops.