Comprehensions
A comprehension is a concise notation for performing some operation on each element of a collection of objects, and/or selecting a subset of elements that satisfy some condition. Comprehensions are borrowed from the functional programming language Haskell (https://www.haskell.org/) and, together with iterators and generators, contribute to giving Python a functional flavor.
Python offers several types of comprehensions: list, dictionary, and set. We will concentrate on list comprehensions; once you understand them, the other types will be easy to grasp.
Let us start with a simple example. We want to calculate a list with the squares of the first 10 natural numbers. We could use a for loop and append a square to the list in each iteration:
# squares.for.txt
>>> squares = []
>>> for n in range(10):
...     squares.append(n**2)
...
>>> squares
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
This is not very elegant as we have to initialize the list first. With map(), we can achieve the same thing in just one line of code:
# squares.map.txt
>>> squares = list(map(lambda n: n**2, range(10)))
>>> squares
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
Now, let us see how to achieve the same result using a list comprehension:
# squares.comprehension.txt
>>> [n**2 for n in range(10)]
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
This is much easier to read, and we no longer need to use a lambda. We have placed a for loop within square brackets. Let us now filter out the odd squares. We will show you how to do it with map() and filter() first, and then with a list comprehension:
# even.squares.py
# using map and filter
sq1 = list(
    map(lambda n: n**2, filter(lambda n: not n % 2, range(10)))
)
# equivalent, but using list comprehensions
sq2 = [n**2 for n in range(10) if not n % 2]
print(sq1, sq1 == sq2) # prints: [0, 4, 16, 36, 64] True
We think that the difference in readability is now evident. The list comprehension reads much better. It is almost English: give us all squares (n**2) for n between 0 and 9 if n is even.
According to the Python documentation (https://docs.python.org/3.12/tutorial/datastructures.html#list-comprehensions), the following is true:
A list comprehension consists of brackets containing an expression followed by a for clause, then zero or more for or if clauses. The result will be a new list resulting from evaluating the expression in the context of the for and if clauses which follow it.
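To make the definition quoted above concrete, here is a small sketch with two for clauses and one if clause, next to the nested loops it corresponds to. The names pairs and loop_pairs are ours, chosen for illustration. The clauses read left to right in the same order in which the loops would be nested:

```python
# Two for clauses and one if clause in a single comprehension
pairs = [(x, y) for x in range(3) for y in range(3) if x != y]

# The equivalent nested loops: each clause becomes one level of nesting,
# in the same left-to-right order as in the comprehension
loop_pairs = []
for x in range(3):
    for y in range(3):
        if x != y:
            loop_pairs.append((x, y))

print(pairs)  # [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]
print(pairs == loop_pairs)  # True
```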
Nested comprehensions
Let us see an example of nested loops. Nested loops are quite common because many algorithms involve iterating over a sequence using two placeholders. The first one runs through the whole sequence, left to right. The second one does, too, but it starts from the current position of the first one instead of from 0. The idea is to test all pairs without duplication. Let us see the classical for loop equivalent:
# pairs.for.loop.py
items = "ABCD"
pairs = []
for a in range(len(items)):
    for b in range(a, len(items)):
        pairs.append((items[a], items[b]))
If you print pairs at the end, you get the following:
$ python pairs.for.loop.py
[('A', 'A'), ('A', 'B'), ('A', 'C'), ('A', 'D'), ('B', 'B'), ('B', 'C'), ('B', 'D'), ('C', 'C'), ('C', 'D'), ('D', 'D')]
All the tuples with the same letter are those where b is at the same position as a. Now, let us see how we can translate this to a list comprehension:
# pairs.list.comprehension.py
items = "ABCD"
pairs = [
    (items[a], items[b])
    for a in range(len(items))
    for b in range(a, len(items))
]
Notice that because the for loop over b depends on a, it must come after the for loop over a in the comprehension. If you swap them around, you will get a NameError, because a has not been defined yet when range(a, len(items)) is evaluated.
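A quick way to confirm this is to run the swapped version in a fresh script, where no a exists at module level, and catch the resulting exception. This is only an illustrative sketch; the names swapped and error_message are ours:

```python
items = "ABCD"

# Correct order: the loop over a comes first, so a exists when range(a, ...) runs
pairs = [
    (items[a], items[b])
    for a in range(len(items))
    for b in range(a, len(items))
]

# Swapped order: range(a, len(items)) is evaluated before any a has been bound
try:
    swapped = [
        (items[a], items[b])
        for b in range(a, len(items))
        for a in range(len(items))
    ]
except NameError as exc:
    error_message = str(exc)
    print(error_message)  # name 'a' is not defined
```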
Another way of achieving the same result is to use the combinations_with_replacement() function from the itertools module (which we briefly introduced in Chapter 3, Conditionals and Iteration). You can read more about it in the official Python documentation.
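For comparison, here is a minimal sketch that builds the pairs both ways and checks that they match. combinations_with_replacement(iterable, r) yields r-length tuples in which an element may be paired with itself, in the order of the elements' positions in the input, which is exactly the order produced by the nested comprehension:

```python
from itertools import combinations_with_replacement

items = "ABCD"

# The nested comprehension from the text
pairs = [
    (items[a], items[b])
    for a in range(len(items))
    for b in range(a, len(items))
]

# The same pairs from itertools: every 2-tuple, repetitions allowed,
# never going "backwards" in the input sequence
combos = list(combinations_with_replacement(items, 2))

print(combos == pairs)  # True
```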
Filtering a comprehension
We can also apply filtering to a comprehension. Let us first do it with filter(), and find all Pythagorean triples whose short sides are numbers smaller than 10. A Pythagorean triple is a triple (a, b, c) of positive integers satisfying the equation $a^2 + b^2 = c^2$.
We obviously do not want to test a combination twice, and therefore, we will use a trick similar to the one we saw in the previous example:
# pythagorean.triple.py
from math import sqrt
# this will generate all possible pairs
mx = 10
triples = [
    (a, b, sqrt(a**2 + b**2))
    for a in range(1, mx)
    for b in range(a, mx)
]
# this will filter out all non-Pythagorean triples
triples = list(
    filter(lambda triple: triple[2].is_integer(), triples)
)
print(triples) # prints: [(3, 4, 5.0), (6, 8, 10.0)]
In the preceding code, we generated a list of three-tuples, triples. Each tuple contains two integer numbers (the legs) and the length of the hypotenuse of the right triangle whose legs are the first two numbers in the tuple. For example, when a is 3 and b is 4, the tuple will be (3, 4, 5.0), and when a is 5 and b is 7, the tuple will be (5, 7, 8.602325267042627).
After generating all the triples, we need to filter out all those where the hypotenuse is not an integer number. To achieve this, we filter based on float_number.is_integer() being True. This means that of the two example tuples we just showed you, the one with hypotenuse 5.0 will be retained, while the one with the 8.602325267042627 hypotenuse will be discarded.
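The sketch below first shows float.is_integer() on the two example values. It then offers an alternative that the original code does not use: an exact integer test based on math.isqrt() (available since Python 3.8), which avoids relying on floating-point square roots that can lose precision for very large integers. The helper is_perfect_square() is a hypothetical name introduced only for this illustration:

```python
from math import isqrt, sqrt

# float.is_integer() is True only when the float has no fractional part
print(sqrt(3**2 + 4**2).is_integer())  # True (the value is 5.0)
print(sqrt(5**2 + 7**2).is_integer())  # False (the value is about 8.6)


def is_perfect_square(n: int) -> bool:
    """Hypothetical helper: exact perfect-square test using integer arithmetic."""
    root = isqrt(n)  # floor of the exact square root of n
    return root * root == n


mx = 10
triples = [
    (a, b, isqrt(a**2 + b**2))
    for a in range(1, mx)
    for b in range(a, mx)
    if is_perfect_square(a**2 + b**2)
]
print(triples)  # [(3, 4, 5), (6, 8, 10)]
```

Since isqrt() already returns an int, this variant also avoids the later conversion from float to int.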
This is good, but we do not like the fact that the triple has two integer numbers and a float; they are all supposed to be integers. We can use map() to fix this:
# pythagorean.triple.int.py
from math import sqrt
mx = 10
triples = [
    (a, b, sqrt(a**2 + b**2))
    for a in range(1, mx)
    for b in range(a, mx)
]
triples = filter(lambda triple: triple[2].is_integer(), triples)
# this will make the third number in the tuples integer
triples = list(
    map(lambda triple: triple[:2] + (int(triple[2]),), triples)
)
print(triples) # prints: [(3, 4, 5), (6, 8, 10)]
Notice the step we added. We slice each element in triples, taking only the first two elements. Then, we concatenate the slice with a one-tuple containing the integer version of the float number we did not like. This code is getting quite complicated. We can achieve the same result with a much simpler list comprehension:
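To see what the lambda inside map() does to a single element, here is the same slice-and-concatenate step applied by hand to one tuple. The intermediate names head and tail are ours. Note the trailing comma: it is what turns the parentheses into a one-tuple rather than a mere grouping:

```python
triple = (3, 4, 5.0)

head = triple[:2]         # slicing a tuple returns a new tuple: (3, 4)
tail = (int(triple[2]),)  # the trailing comma makes this a one-tuple: (5,)
print(head + tail)        # (3, 4, 5)

# Without the comma, the parentheses only group an expression,
# so the result is a plain int, which cannot be concatenated to a tuple
print(type((int(triple[2]))))  # <class 'int'>
```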
# pythagorean.triple.comprehension.py
from math import sqrt
# this step is the same as before
mx = 10
triples = [
    (a, b, sqrt(a**2 + b**2))
    for a in range(1, mx)
    for b in range(a, mx)
]
# here we combine filter and map in one CLEAN list comprehension
triples = [
    (a, b, int(c)) for a, b, c in triples if c.is_integer()
]
print(triples) # prints: [(3, 4, 5), (6, 8, 10)]
That is cleaner, easier to read, and shorter. There is still room for improvement, though. We are still wasting memory by constructing a list with many triples that we end up discarding. We can fix that by combining the two comprehensions into one:
# pythagorean.triple.walrus.py
from math import sqrt
# this step is the same as before
mx = 10
# We can combine generating and filtering in one comprehension
triples = [
    (a, b, int(c))
    for a in range(1, mx)
    for b in range(a, mx)
    if (c := sqrt(a**2 + b**2)).is_integer()
]
print(triples) # prints: [(3, 4, 5), (6, 8, 10)]
Now that is elegant. By generating the triples and filtering them in the same list comprehension, we avoid keeping any triple that does not pass the test in memory. Notice that we used an assignment expression to avoid needing to compute the value of sqrt(a**2 + b**2) twice.
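One detail worth knowing about the assignment expression: as specified in PEP 572, a name bound with := inside a comprehension belongs to the enclosing scope, not to the comprehension. The sketch below reruns the comprehension and then inspects c, which survives and holds the last value computed (the one for a = 9 and b = 9), while the loop variables a and b do not leak:

```python
from math import sqrt

mx = 10
triples = [
    (a, b, int(c))
    for a in range(1, mx)
    for b in range(a, mx)
    if (c := sqrt(a**2 + b**2)).is_integer()
]

# c was bound in the enclosing (here, module) scope by the := operator,
# so it is still available and holds sqrt(9**2 + 9**2), roughly 12.73
print(c)

# The loop variables a and b stay local to the comprehension,
# so in a fresh script this prints False
print("a" in globals())
```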
Dictionary comprehensions
Dictionary comprehensions work exactly like list comprehensions, but to construct dictionaries. There is only a slight difference in the syntax. The following example will suffice to explain everything you need to know:
# dictionary.comprehensions.py
from string import ascii_lowercase
lettermap = {c: k for k, c in enumerate(ascii_lowercase, 1)}
If you print lettermap, you will see the following:
$ python dictionary.comprehensions.py
{'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8,
'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15,
'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22,
'w': 23, 'x': 24, 'y': 25, 'z': 26}
In the preceding code, we are enumerating the sequence of all lowercase ASCII letters (using the enumerate function, with the count starting at 1). We then construct a dictionary with the resulting letter/number pairs as keys and values. Notice how the syntax is similar to the familiar dictionary syntax.
There is also another way to do the same thing:
lettermap = dict((c, k) for k, c in enumerate(ascii_lowercase, 1))
In this case, we are feeding a generator expression (we will talk more about these later in this chapter) to the dict constructor.
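A short sketch to check that both forms build the same dictionary, and to show that an if clause works in a dictionary comprehension just as it does in a list comprehension. Keeping only the vowels is our own choice of filter:

```python
from string import ascii_lowercase

# The two equivalent ways of building the mapping
lettermap = {c: k for k, c in enumerate(ascii_lowercase, 1)}
lettermap2 = dict((c, k) for k, c in enumerate(ascii_lowercase, 1))
print(lettermap == lettermap2)  # True

# Filtering with an if clause: only the vowels are kept
vowels = {c: k for k, c in enumerate(ascii_lowercase, 1) if c in "aeiou"}
print(vowels)  # {'a': 1, 'e': 5, 'i': 9, 'o': 15, 'u': 21}
```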
Dictionaries do not allow duplicate keys, as shown in the following example:
# dictionary.comprehensions.duplicates.py
word = "Hello"
swaps = {c: c.swapcase() for c in word}
print(swaps) # prints: {'H': 'h', 'e': 'E', 'l': 'L', 'o': 'O'}
We create a dictionary with the letters of the string "Hello" as keys and the same letters, but with the case swapped, as values. Notice that there is only one "l": "L" pair. The comprehension does not complain; when a key repeats, the later value simply overwrites the earlier one. Let us make this clearer with another example that assigns to each key its position in the string:
# dictionary.comprehensions.positions.py
word = "Hello"
positions = {c: k for k, c in enumerate(word)}
print(positions) # prints: {'H': 0, 'e': 1, 'l': 3, 'o': 4}
Notice the value associated with the letter l: 3. The l: 2 pair is not there; it has been overridden by l: 3.
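If you need every position rather than just the last one, one option (a sketch of our own, not the only approach) is to make each value a list built by a nested list comprehension:

```python
word = "Hello"

# For each letter, the inner list comprehension collects every index
# at which that letter appears, so no position is lost to overwriting
all_positions = {c: [k for k, x in enumerate(word) if x == c] for c in word}
print(all_positions)  # {'H': [0], 'e': [1], 'l': [2, 3], 'o': [4]}
```

Because "l" occurs twice in word, its entry is computed twice and the second (identical) result overwrites the first. Iterating over set(word) would avoid the repeated work, at the cost of a key order that may vary between runs.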
Set comprehensions
Set comprehensions are similar to list and dictionary ones. Let us see one quick example:
# set.comprehensions.py
word = "Hello"
letters1 = {c for c in word}
letters2 = set(c for c in word)
print(letters1) # prints: {'H', 'o', 'e', 'l'} (the order may vary)
print(letters1 == letters2) # prints: True
Notice how for set comprehensions, as for dictionaries, duplication is not allowed, and therefore the resulting set has only four letters. Also, notice that the expressions assigned to letters1 and letters2 produce equivalent sets.
The syntax used to create letters1 is similar to that of a dictionary comprehension. The only difference is that dictionary comprehensions require keys and values separated by colons, while set comprehensions do not. For letters2, we fed a generator expression to the set() constructor.
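Because sets and dictionaries share curly braces, there is one corner case to keep in mind: empty braces always create a dictionary. The sketch below shows this, and also that a set comprehension produces a set even when its filter selects nothing. The names empty and nothing are ours:

```python
# Empty braces create a dict, not a set
print(type({}))     # <class 'dict'>

# An empty set has to be created with the set() constructor
empty = set()
print(type(empty))  # <class 'set'>

# A set comprehension always produces a set, even when nothing is selected
nothing = {c for c in "Hello" if c.isdigit()}
print(nothing)      # set()
```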