writing

[Python] Typing the hard stuff

Jan 2026

Rules of Thumb

  1. Accept generic types, return specific types. This gives users of your functions and classes the most flexibility and safety.
def process(items: Iterable[Item]) -> list[Item]:
    """Keep only the valid items.

    Takes any iterable (generic input) but always hands back a concrete list
    (specific output), so callers can index it, re-iterate it, or take len().
    """
    return [candidate for candidate in items if candidate.is_valid()]
  2. Always use annotations: they help the IDE, they help AI tools, and they make the code self-documenting.

Typing: Generators

def my_generator() -> Iterator[int]:
    """Lazily produce 0 through 9, one value per iteration."""
    current = 0
    while current < 10:
        yield current
        current += 1

# Generator[YieldType, SendType, ReturnType]
def my_generator() -> Generator[int, None, None]:
    """Yield 0..9; nothing is sent in (SendType None) and nothing is returned."""
    for value in range(10):
        yield value

# SendType: the type of the values passed in with .send().
#
# A generator that *receives* values via `x = yield` (the classic
# generator-based "coroutine") is annotated Generator[YieldType, SendType, ReturnType].
# Don't confuse it with an `async def` coroutine: there you annotate only the
# eventual return value (`async def f() -> int`), and calling it produces a
# Coroutine[Any, Any, int], which is also an Awaitable[int].
# Note that putting `yield` inside `async def` creates an *async generator*,
# which has no __next__()/send() methods, so this example must use plain `def`.
from collections.abc import Generator


def print_if_has_prefix(prefix: str) -> Generator[None, str, None]:
    """Print every name sent in that starts with *prefix*.

    Yields nothing useful (YieldType None), receives str values through
    send() (SendType str), and loops forever, so it never returns (ReturnType None).
    """
    print("Searching prefix:{}".format(prefix))
    while True:
        name = yield
        # startswith, not `in`: "Atul Dear" contains "Dear" but doesn't start with it.
        if name.startswith(prefix):
            print(name)


polite_coro = print_if_has_prefix("Dear")

# Start the coroutine: this prints "Searching prefix..." and advances
# execution to the first yield expression.
next(polite_coro)

# Sending inputs
polite_coro.send("Atul")  # No output
polite_coro.send("Dear Atul")  # Prints "Dear Atul"

Typing: Classes

from typing import TypeVar

# The bound is given by name ("Animal") so the TypeVar can be declared before the class.
AnimalT = TypeVar("AnimalT", bound="Animal")


def birth(self, animal: type[AnimalT]) -> AnimalT:
    """Instantiate *animal* and return the new instance.

    Taking ``type[AnimalT]`` rather than ``type[Animal]`` lets the checker infer
    the concrete subclass: ``birth(Dog)`` is typed as ``Dog``, not just ``Animal``.
    """
    return animal()

Typing: Context Managers 1

from __future__ import annotations

from typing import overload
from types import TracebackType

class MyContextManager:
    """Minimal context manager whose ``__exit__`` is overloaded for precise typing.

    The overloads tell a type checker that the three exception arguments are
    either all None (clean exit) or all populated (exit through an exception).
    """

    def __enter__(self) -> None:
        return None

    @overload
    def __exit__(self, exc_type: None, exc_val: None, exc_tb: None) -> None: ...

    @overload
    def __exit__(
        self,
        exc_type: type[BaseException],
        exc_val: BaseException,
        exc_tb: TracebackType,
    ) -> None: ...

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        # A falsy return value means any in-flight exception keeps propagating.
        return None

Typing: Decorators

from typing import Callable, TypeVar, Any

_ReturnType = TypeVar('_ReturnType')

def foo(arg: str) -> Callable[[Callable[..., _ReturnType]], Callable[..., _ReturnType]]:
    def decorator(function: Callable[..., _ReturnType]) -> Callable[..., _ReturnType]:
        # Run on function registration only
        print(f"Decorator registered: {arg}")
        @functools.wraps(function)
        def wrapper(*args: Any, **kwargs: Any) -> _ReturnType:
            # Run on every function call
            print(f"Decorated function call: {arg}")
            result = function(*args, **kwargs)
            return result
        # Note we can collapse the wrapper into just `function` if we don't need to modify the function call
        return wrapper
    return decorator


# If the decorator hands back the very same function (no wrapper), bind a
# TypeVar to Callable: the checker then keeps the decorated function's full
# signature, parameters and return type alike.
from typing import Callable, TypeVar, Any

WrappedFn = TypeVar("WrappedFn", bound=Callable[..., Any])


def foo(arg: str) -> Callable[[WrappedFn], WrappedFn]:
    """Return a decorator that leaves the decorated function untouched."""

    def _identity(fn: WrappedFn) -> WrappedFn:
        # Nothing is wrapped, so the original object (and its type) comes back.
        return fn

    return _identity

Typing: Optional Imports 1

try:
    import matplotlib.pyplot as plt

    _HAS_PLT = True
except ImportError:
    # matplotlib is optional: code must check _HAS_PLT before touching plt.
    _HAS_PLT = False


def train() -> None:
    """Placeholder routine; any plotting goes inside the ``_HAS_PLT`` guard."""
    if _HAS_PLT:
        ...

Typing: External Modules 1

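# In pyproject.toml: point mypy at a local stubs directory
[tool.mypy]
mypy_path = "mypy_stubs"

Stub layout for the external `example` package:

mypy_stubs
└── example
    ├── __init__.pyi
    └── widgets.pyi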
# In widgets.pyi: the stub that declares Widget
from typing import Any

# In a stub file, this module-level __getattr__ makes every undeclared name type Any.
def __getattr__(name: str) -> Any: ...

# Declared explicitly, so code using Widget is fully type checked.
class Widget:
  def __init__(self, name: str) -> None: ...
  def frobnicate(self) -> None: ...
# In __init__.pyi
from typing import Any

# Any name not declared in this stub is treated as type Any by type checkers.
def __getattr__(name: str) -> Any: ...

In runnable code, Python calls a module-level `__getattr__` function whenever code accesses an attribute that the module does not define (PEP 562). This lets you do anything you like, such as handling complicated deprecations.

In stub files, type checkers understand this particular `__getattr__` definition as marking all unmentioned names as type `Any`. So, in our example, code that uses `Widget` can be fully type checked, while any other name from the `example` package falls back to `Any`.

Typing: Array Shapes (NumPy, JAX, PyTorch)

Use jaxtyping for readable shape annotations. Despite the name, it is not just for JAX: it also works with NumPy arrays and PyTorch tensors.

import numpy as np
from jaxtyping import Float, Int, Array


def normalize(x: Float[np.ndarray, "batch features"]) -> Float[np.ndarray, "batch features"]:
    """Scale each row of *x* so its entries sum to 1 along the last axis.

    The output keeps the input's ``batch features`` shape. A row that sums to
    zero yields inf/nan, as with any float division by zero in NumPy.
    """
    return x / x.sum(axis=-1, keepdims=True)

def attention(
    q: Float[Array, "batch heads seq_q dim"],
    k: Float[Array, "batch heads seq_k dim"],
    v: Float[Array, "batch heads seq_k dim"],
) -> Float[Array, "batch heads seq_q dim"]:
    """Placeholder whose signature documents the shape contract.

    Queries have ``seq_q`` positions while keys and values share ``seq_k``.
    Reusing the names ``batch``, ``heads`` and ``dim`` ties those axes
    together across all arguments and the result.
    """

Shape syntax:

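| Syntax | Meaning |
|---|---|
| `"batch features"` | Named dimensions (reusable: the same name must have the same size everywhere it appears) |
| `"batch 3"` | Fixed-size dimension |
| `"*batch"` | Named variadic dimensions (zero or more axes) |
| `"batch ..."` | Arbitrary trailing dimensions (anonymous variadic) |
| `"batch #channels"` | Broadcastable dimension (size `channels`, or 1) |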

Type aliases for common patterns:

import numpy as np
from jaxtyping import Float

# Aliases keep long, repeated shape annotations short and consistent.
Image = Float[np.ndarray, "height width channels"]
BatchedImages = Float[np.ndarray, "batch height width channels"]


def resize(img: Image, scale: float) -> Image:
    """Placeholder signature: takes and returns a single (height, width, channels) image."""
    ...

Typing: Duck-types 1

from typing import Callable, Protocol, TypeVar


class SupportsClose(Protocol):
    """Structural type: any object with a no-argument ``close()`` method matches."""

    def close(self) -> None: ...


# If you need a callable that accepts arbitrary arguments, write ``...`` as the
# parameter list: ``Operation[int]`` means "any callable that returns int".
T = TypeVar("T", covariant=True)
Operation = Callable[..., T]

Typing: Wrapper Functions

from typing import Any, Callable

from typing_extensions import Annotated, Doc, ParamSpec  # type: ignore [attr-defined]

P = ParamSpec("P")

def dispatch(
  func: Annotated[
    Callable[P, Any],
    Doc(
      """
      Extra documentation for the function.
      """
    ),
  ],
  *args: P.args,
  **kwargs: P.kwargs,
) -> Any:
    """Accept a callable together with the exact arguments it takes.

    ``P`` captures ``func``'s parameter list, so a type checker rejects
    ``*args``/``**kwargs`` that ``func`` itself would not accept. The body is
    a placeholder.
    """
    ...

Typing: Sentinels

from enum import Enum
from typing import Literal


class Sentinel(str, Enum):
    """Marker for "no value supplied", distinct from ``None``.

    ``__bool__`` is annotated ``Literal[False]`` so type checkers know the
    member is always falsy and can narrow on ``if value:`` checks.
    """

    NOT_SET = "NOT_SET"

    def __bool__(self) -> Literal[False]:
        return False

Typing: None Preserving Functions

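Use overloads to preserve None through function calls: a `datetime` argument gives back a `datetime`, and only a possibly-`None` argument gives back `datetime | None`. Put the most specific overload first, because type checkers pick the first overload that matches.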

@overload
def ensure_utc(value: datetime) -> datetime: ...

@overload
def ensure_utc(value: datetime | None) -> datetime | None: ...

def ensure_utc(value: datetime | None) -> datetime | None:
    if value is None:
        return None
    return value.astimezone(timezone.utc)

Typing: Add metadata

Use Annotated to attach metadata without polluting the type.

# Meh: the parameter's default is a Depends(...) marker, not a Session
def handler(session: Session = Depends(get_session)):
    """Default-value style: the dependency marker masquerades as the default."""


# Better: the type stays Session, and the metadata lives separately
def handler(session: Annotated[Session, Depends(get_session)]):
    """Annotated style: ``Session`` is the type; ``Depends(...)`` rides along as metadata."""

Typing: Discriminated Unions

class Processing(TypedDict):
    status: Literal["processing"]
    solution: None

class Ready(TypedDict):
    status: Literal["ready"]
    solution: dict

Result = Processing | Ready

def handle(r: Result):
    if r["status"] == "ready":
        print(r["solution"])  # narrowed to dict
  1. Source from Adam Johnson’s Blog 2 3 4