
If you’ve been writing Python for any length of time, you’ve likely encountered a situation that left you scratching your head. This post collects a few such examples.

Scopes

Binding in closures:

You write a loop, create a lambda function inside it, and expect each lambda to remember the loop variable’s value from the iteration in which it was created. But it doesn’t. You’ve just fallen into one of the most classic “gotchas”: late binding in closures.

We’ll create a list of functions, where each function should print a different number from 0 to 4.

funcs = []
for i in range(5):
    funcs.append(lambda: print(i))

# What do you expect this to print?
for f in funcs:
    f()

Most people (myself included, at one point) expect this to print 0, 1, 2, 3, 4. Instead, you get this:

4
4
4
4
4

Why? The lambda doesn’t capture the value of i at the time of its creation. It captures a reference to i. When you finally call the functions later, the loop has already finished, and the variable i is left with its final value, which is 4. All the lambda functions refer to that same, single variable i.
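You can see the shared binding directly by inspecting each lambda’s `__closure__`. Wrapping the loop in a function (a hypothetical `make_funcs` helper, just for illustration) makes i a true closure variable, so every lambda carries a cell object pointing at it:

```python
def make_funcs():
    funcs = []
    for i in range(5):
        funcs.append(lambda: i)
    return funcs

funcs = make_funcs()

# Every lambda closes over the very same cell object:
cells = {id(f.__closure__[0]) for f in funcs}
print(len(cells))            # 1 -- a single shared cell
print([f() for f in funcs])  # [4, 4, 4, 4, 4]
```

One cell, five lambdas: they can only ever agree on whatever value i last held.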

A real-world example: custom PyTorch optimizer

This gotcha had real-world consequences for me: in a recent project, I was building a custom PyTorch optimizer. I needed a function to “freeze” certain model parameters by applying a mask to their gradients. The plan was to loop through the model’s parameters, and for those that needed masking, register a post_accumulate_grad_hook to apply the mask before the optimizer steps. Here’s a simplified version of the code:

# ... inside a method of a PyTorch Optimizer ...
for p in group["params"]:
    # Get a mask for the parameter p, or None if it's not masked
    mask = masks.get(p, None)

    # If there's no mask, skip to the next parameter
    if mask is None:
        continue

    # Register a hook to apply the mask during backpropagation
    # THIS CODE IS BUGGED!
    state["param_hook"] = p.register_post_accumulate_grad_hook(
        lambda param: param.grad.mul_(mask)
    )

The logic seems sound. We find a mask, and if it exists, we register a hook that uses it. The problem is that the mask variable is captured by the lambda, just like i was in our first example.

Imagine the optimizer has two parameters. The first has a mask, but the second does not:

  1. Loop 1: p is the first parameter. mask gets a valid Tensor. A lambda is created that captures the variable mask.

  2. Loop 2: p is the second parameter. masks.get(p, None) returns None. The if condition is not met, and the loop continues. The mask variable is now None.

  3. Loop ends. The mask variable’s final value is None.

Later, when loss.backward() is called, the hook for the first parameter executes: it looks up the value of the captured variable mask, which is now None, leading to a TypeError: mul_(): argument 'other' (position 1) must be Tensor, not NoneType.
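The same failure mode can be reproduced without PyTorch at all. Here is a framework-free sketch, with plain callables standing in for gradient hooks (the parameter names, the masks dict, and the multiply stand in for the real tensors and are purely illustrative):

```python
hooks = []
params = ["weight", "bias"]   # stand-ins for model parameters
masks = {"weight": 2}         # only "weight" has a mask

for p in params:
    mask = masks.get(p, None)
    if mask is None:
        continue
    # BUG: the lambda captures the *variable* mask, not its current value
    hooks.append(lambda grad: grad * mask)

# After the loop, mask is None (it was set during the "bias" iteration),
# so the hook registered for "weight" blows up when called:
try:
    hooks[0](10)
except TypeError as e:
    print("hook failed:", e)
```

Exactly one hook was registered, yet it fails, because by call time its captured mask variable holds None.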

The Solution: capture by value with a default

The fix is a simple and common Python idiom: force the lambda to capture the value at definition time by using a default argument:

funcs = []
for i in range(5):
    # The magic is m=i
    funcs.append(lambda m=i: print(m))

# This now prints 0, 1, 2, 3, 4 as expected
for f in funcs:
    f()

By using m=i, we create a new variable m local to the lambda. Its default value is evaluated at the time the lambda is defined, effectively capturing the current value of i for each iteration. Applying this to our PyTorch optimizer code:

# ...
# The corrected, working code
state["param_hook"] = p.register_post_accumulate_grad_hook(
    lambda param, m=mask: param.grad.mul_(m)
)

By changing lambda param: ... to lambda param, m=mask: ..., we ensure that each hook captures its own specific mask tensor, solving the bug.
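The default-argument trick is not the only fix. functools.partial binds the value explicitly at creation time, and a small factory function gives each lambda its own fresh binding through the function parameter. A sketch of both, outside PyTorch (apply_mask and make_hook are illustrative names, not part of any API):

```python
from functools import partial

def apply_mask(grad, mask):
    return grad * mask

# Option 1: partial binds mask's *current* value when the hook is created
hooks = [partial(apply_mask, mask=m) for m in (1, 2, 3)]
print([h(10) for h in hooks])  # [10, 20, 30]

# Option 2: a factory function -- each call creates a new scope,
# so each returned lambda closes over its own mask
def make_hook(mask):
    return lambda grad: grad * mask

hooks = [make_hook(m) for m in (1, 2, 3)]
print([h(10) for h in hooks])  # [10, 20, 30]
```

Some consider the factory version more readable than the default-argument trick, since the binding is explicit rather than hidden in the lambda’s signature.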

Torch __iadd__ mutation

I always thought that x = x + 1 and x += 1 were equivalent... but it turns out they behave differently for mutable objects like torch tensors.

For x = x + 1, Python evaluates the expression x + 1, creates a new object in memory to hold the result, and then rebinds the name x to point to this new object. The original object that x pointed to remains unchanged.

When you write x += 1, Python calls the __iadd__ (in-place add) method. For mutable objects like PyTorch tensors, this modifies the existing object directly in memory. Any other variables that reference this object will see the change immediately.
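The same distinction shows up with plain Python lists, which makes it easy to verify without torch installed:

```python
# Out-of-place: x = x + [...] builds a new list and rebinds x;
# y still points at the original object
x = [0]
y = x
x = x + [1]
print(x, y)   # [0, 1] [0]

# In-place: x += [...] calls list.__iadd__ and mutates the shared object,
# so y sees the change too
x = [0]
y = x
x += [1]
print(x, y)   # [0, 1] [0, 1]
```

Immutable types like int and str don’t define an in-place `__iadd__` that mutates, so for them the two spellings really are equivalent; the divergence only appears with mutable objects.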

Inside some PyTorch optimizer.step() function, we often see code like:

# Inside Adam optimizer
step_t = state_steps[i] # step_t is a reference to the tensor in the state dict
step_t += 1             # Modifies the tensor IN-PLACE

Because += is used, the tensor stored inside the optimizer’s state dictionary is updated. If the code had used step_t = step_t + 1, step_t would just point to a new local tensor, and the optimizer’s state would never update!

Example

import torch

# Scenario A: Out-of-place
x = torch.tensor([0])
y = x           # y points to the same memory as x
x = x + 1       # x now points to a NEW tensor
print(f"x: {x.item()}, y: {y.item()}") 
# > x: 1, y: 0

# Scenario B: In-place
x = torch.tensor([0])
y = x           # y points to the same memory as x
x += 1          # x is modified in-place
print(f"x: {x.item()}, y: {y.item()}") 
# > x: 1, y: 1