Optimizing the brute-force search for an n-queens solution
In this post, we optimize our brute-force search for n-queens solutions on a few small boards.

NOTE: This post is part of a series of posts on Solving problems with simple yet powerful abstractions in Python. This is Part 5 of the series. You can find an index of all the posts in the series at the series link.
In our last post, we built a solver for the n-queens problem out of the abstractions from our previous posts. In this post, we're going to optimize its brute-force search.
Before we start optimizing, let's gather some numbers so we can tell whether our efforts actually improve anything. The easiest way to do this is to change our main program (the calling code) so that it times each run. Let's run our current brute-force solution to the n-queens problem for a few board sizes. This will serve as our benchmark:
```python
from time import time

from board import Size
from n_queens import n_queens

summary = "size: {:3d}, solutions: {:5d}, time: {:8.3f}s, last: {}"


# find solutions for n queens problem given n
def solve(n: Size):
    time_start = time()
    solutions = list(n_queens(n))
    time_elapsed = time() - time_start
    board = solutions[-1]
    print(summary.format(n, len(solutions), time_elapsed, board))


for n in range(4, 9):
    solve(Size(n))

# prints
# size:   4, solutions:     2, time:    0.002s, last: [2, 0, 3, 1]
# size:   5, solutions:    10, time:    0.027s, last: [4, 2, 0, 3, 1]
# size:   6, solutions:     4, time:    0.345s, last: [4, 2, 0, 5, 3, 1]
# size:   7, solutions:    40, time:    5.946s, last: [6, 4, 2, 0, 5, 3, 1]
# size:   8, solutions:    92, time:  132.495s, last: [7, 3, 0, 2, 5, 1, 6, 4]
```
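The benchmark above times runs with `time.time()`, which reads the wall clock. `time.perf_counter()` is the standard library's clock for measuring short intervals. Below is a minimal sketch of a reusable helper built on it. `benchmark` is a hypothetical name that is not part of the series' code. It counts solutions lazily rather than storing them, so unlike the original it doesn't print the last board.

```python
from time import perf_counter
from typing import Callable, Iterable, List, Tuple


def benchmark(solver: Callable[[int], Iterable], sizes: Iterable[int]) -> List[Tuple[int, int, float]]:
    """Run `solver` for each size and record (size, solution count, seconds)."""
    results = []
    for n in sizes:
        start = perf_counter()  # high-resolution clock intended for intervals
        count = sum(1 for _ in solver(n))  # consume the solutions without storing them
        elapsed = perf_counter() - start
        results.append((n, count, elapsed))
        print(f"size: {n:3d}, solutions: {count:5d}, time: {elapsed:8.3f}s")
    return results
```

With the series' solver, a call would look like `benchmark(n_queens, range(4, 9))`. This assumes `Size` is a thin wrapper over `int`, so plain integers are accepted.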
Avoiding useless work
And reducing the search space in the process
When we look at the present search, we're trying every column for the queen in each row. However, a solution can never have two queens in the same column on two different rows. Given our representation (a list of size n holding the queen's column for each row), no two elements of the list can ever be the same. In other words, our real candidates are only the boards with n distinct columns, not every possible combination of columns.
Since our domain is a product of other domains (one per row), we need a way to filter out candidates to refine it. That's another neat abstraction! Let's write one:
```python
from typing import Protocol, TypeVar

from domain import Domain
from map_domain import map_domain
from solve import Problem

T = TypeVar("T")


def __identity(_: Problem, candidate: T) -> T:
    return candidate


class Predicate(Protocol[Problem, T]):
    def __call__(self, problem: Problem, candidate: T) -> bool:
        ...


class Filter(Protocol[Problem, T]):
    def __call__(self, domain: Domain[Problem, T]) -> Domain[Problem, T]:
        ...


def filter_domain(predicate: Predicate[Problem, T]) -> Filter[Problem, T]:
    def __filter(domain: Domain[Problem, T]) -> Domain[Problem, T]:
        # Skip forward from `candidate` until one satisfies the predicate
        # (or the wrapped domain runs out and we get None).
        def next_matching(problem: Problem, candidate: T | None) -> T | None:
            while candidate is not None and not predicate(problem, candidate):
                candidate = domain.next(problem, candidate)
            return candidate

        # Reuse map_domain: "up" skips to the next matching candidate,
        # "down" is the identity.
        return map_domain(next_matching, __identity)(domain)

    return __filter
```
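The snippet above depends on `Domain`, `map_domain`, and `Problem` from earlier posts, so it can't run on its own. Below is a self-contained sketch of the same filtering idea. It assumes a domain is simply a pair of `first`/`next` functions that return `None` when exhausted. This `Domain` dataclass and the `enumerate_domain` helper are our own hypothetical stand-ins, not the series' exact classes, and the `map_domain` step is inlined.

```python
from dataclasses import dataclass
from typing import Callable, Generic, Iterator, Optional, TypeVar

P = TypeVar("P")  # the problem (e.g. a board size)
T = TypeVar("T")  # a candidate solution


@dataclass(frozen=True)
class Domain(Generic[P, T]):
    """Hypothetical stand-in: enumerate candidates via first/next, None when done."""
    first: Callable[[P], Optional[T]]
    next: Callable[[P, T], Optional[T]]


def filter_domain(predicate: Callable[[P, T], bool]) -> Callable[[Domain[P, T]], Domain[P, T]]:
    def _filter(domain: Domain[P, T]) -> Domain[P, T]:
        def skip_to_match(problem: P, candidate: Optional[T]) -> Optional[T]:
            # Advance until a candidate satisfies the predicate or the domain runs out.
            while candidate is not None and not predicate(problem, candidate):
                candidate = domain.next(problem, candidate)
            return candidate

        return Domain(
            first=lambda p: skip_to_match(p, domain.first(p)),
            next=lambda p, c: skip_to_match(p, domain.next(p, c)),
        )

    return _filter


def enumerate_domain(domain: Domain[P, T], problem: P) -> Iterator[T]:
    candidate = domain.first(problem)
    while candidate is not None:
        yield candidate
        candidate = domain.next(problem, candidate)


# Usage: the integers 0..n-1, filtered down to the even ones.
digits = Domain(first=lambda n: 0 if n > 0 else None,
                next=lambda n, i: i + 1 if i + 1 < n else None)
evens = filter_domain(lambda n, i: i % 2 == 0)(digits)
print(list(enumerate_domain(evens, 10)))  # [0, 2, 4, 6, 8]
```

The filtered domain is itself a `Domain`, so filters compose with other domain transformations.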
Notice that we're reusing the map_domain abstraction we came up with in the previous post. With our latest abstraction, filtering out the unnecessary candidates is really simple:
```python
from functools import cache, partial
from typing import Tuple, TypeAlias

from board import Board, Size, has_collision, row_pairs
from brute_force import First, Next
from domain import Domain, brute_force
from filter_domain import Predicate, filter_domain
from func_utils import always
from integers import integers
from map_domain import Down, Up, map_domain
from product import product

Rows: TypeAlias = Tuple[int, ...]


@cache
def __rows_domain(size: Size) -> Domain[Size, Rows]:
    # Each row's queen can sit in any column 0..size-1; the board is the product.
    row_domain = integers(always(0), always(size))
    row_domains = [row_domain] * size
    return product(*row_domains)


__first: First[Size, Rows] = lambda n: __rows_domain(n).first(n)
__next: Next[Size, Rows] = lambda n, rows: __rows_domain(n).next(n, rows)
__rows = Domain(__first, __next)

# Keep only the candidates whose columns are all distinct.
__unique: Predicate[Size, Rows] = lambda n, rows: len(set(rows)) == n
__unique_rows = filter_domain(__unique)(__rows)

__up: Up[Size, Rows, Board] = lambda n, rows: Board(rows)
__down: Down[Size, Rows, Board] = lambda n, board: board
__boards = map_domain(__up, __down)(__unique_rows)


def __accept(size: Size, board: Board) -> bool:
    return not any(filter(partial(has_collision, board), row_pairs(size)))


n_queens = brute_force(__boards, __accept)
```
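For readers who don't have the series' `board` and `domain` modules at hand, here is a standalone sketch of the same pipeline built on `itertools`. It generates the product of per-row columns, applies the cheap unique-columns filter, and then runs the full collision check. `collides` is our own stand-in for `has_collision`, assumed here to check shared columns and diagonals.

```python
from itertools import combinations, product
from typing import Iterator, Tuple

Board = Tuple[int, ...]  # board[row] == column of the queen in that row


def collides(board: Board, r1: int, r2: int) -> bool:
    # Same column, or same diagonal (row distance equals column distance).
    return board[r1] == board[r2] or abs(board[r1] - board[r2]) == r2 - r1


def is_solution(board: Board) -> bool:
    return not any(collides(board, r1, r2) for r1, r2 in combinations(range(len(board)), 2))


def n_queens_unique(n: int) -> Iterator[Board]:
    for board in product(range(n), repeat=n):  # all n**n candidates
        if len(set(board)) == n:  # cheap filter: distinct columns only
            if is_solution(board):  # full collision check
                yield board
```

For n = 4 this yields `(1, 3, 0, 2)` and then `(2, 0, 3, 1)`, the same last board the benchmark reports.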
We can see the performance improve:
```python
from time import time

from board import Size
from n_queens import n_queens


# find solutions for n queens problem given n
def solve(n: Size):
    summary = "size: {:3d}, solutions: {:5d}, time: {:8.3f}s, last: {}"
    time_start = time()
    solutions = list(n_queens(n))
    time_elapsed = time() - time_start
    board = solutions[-1]
    print(summary.format(n, len(solutions), time_elapsed, board))


for n in range(4, 9):
    solve(Size(n))

# prints
# size:   4, solutions:     2, time:    0.002s, last: [2, 0, 3, 1]
# size:   5, solutions:    10, time:    0.019s, last: [4, 2, 0, 3, 1]
# size:   6, solutions:     4, time:    0.235s, last: [4, 2, 0, 5, 3, 1]
# size:   7, solutions:    40, time:    4.182s, last: [6, 4, 2, 0, 5, 3, 1]
# size:   8, solutions:    92, time:   94.317s, last: [7, 3, 0, 2, 5, 1, 6, 4]
```
The savings grow with the board size (about 38 seconds at size 8), but we're still generating every candidate board, only to filter most of them out later. Is there something better we can do?
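To put numbers on the wasted work, note that the product domain generates n^n boards (n column choices for each of the n rows). Only the n! permutations survive the unique-columns filter. A small sketch comparing the two; `candidate_counts` is a hypothetical helper name:

```python
from math import factorial
from typing import Tuple


def candidate_counts(n: int) -> Tuple[int, int]:
    # (boards generated by the product domain, boards kept by the unique filter)
    return n ** n, factorial(n)


for n in range(4, 9):
    generated, kept = candidate_counts(n)
    print(f"n={n}: generated {generated:>10,}, kept {kept:>6,} ({kept / generated:.2%})")
```

For n = 8, that's 16,777,216 generated boards, of which only 40,320 (about 0.24%) survive. The filter skips the expensive collision check, but it can't skip the generation.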
Looking at this another way
This is where knowing a lot of concepts helps
So far, we've been thinking of every row as having its own domain, with the board being the product of all the row domains. That's neat, and it works! But can we do better? If we look closely at the candidates that survive the filter, a pattern emerges. Here they are for a board of size 4:
```
[0, 1, 2, 3]
[0, 1, 3, 2]
[0, 2, 1, 3]
[0, 2, 3, 1]
[0, 3, 1, 2]
[0, 3, 2, 1]
[1, 0, 2, 3]
[1, 0, 3, 2]
[1, 2, 0, 3]
[1, 2, 3, 0]
[1, 3, 0, 2]
[1, 3, 2, 0]
[2, 0, 1, 3]
[2, 0, 3, 1]
[2, 1, 0, 3]
[2, 1, 3, 0]
[2, 3, 0, 1]
[2, 3, 1, 0]
[3, 0, 1, 2]
[3, 0, 2, 1]
[3, 1, 0, 2]
[3, 1, 2, 0]
[3, 2, 0, 1]
[3, 2, 1, 0]
```
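This pattern can be checked mechanically. The standalone sketch below compares the candidates that pass the unique-columns filter with `itertools.permutations`. The check relies on the fact that `itertools.product` and `itertools.permutations` (given a sorted input) both emit tuples in lexicographic order. `unique_column_boards` is a hypothetical helper name.

```python
from itertools import permutations, product
from typing import List, Tuple


def unique_column_boards(n: int) -> List[Tuple[int, ...]]:
    # Every product-domain candidate whose n entries (columns) are all distinct.
    return [board for board in product(range(n), repeat=n) if len(set(board)) == n]


# The survivors are exactly the permutations of range(n), in the same order.
assert unique_column_boards(4) == list(permutations(range(4)))
print(len(unique_column_boards(4)))  # 24
```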
All the surviving boards are just permutations of [0, 1, 2, 3]! What if we generated only permutations? This should be simple to write. Let's try:
```python
from functools import partial
from typing import List, TypeAlias

from board import Board, Size, has_collision, row_pairs
from domain import Domain, brute_force
from map_domain import Down, Up, map_domain

# Permutations are built by swapping and re-sorting elements, so they are lists.
Rows: TypeAlias = List[int]


def __first(size: Size) -> Rows | None:
    # The smallest permutation in lexicographic order: [0, 1, ..., size - 1].
    # Boards smaller than 4x4 are skipped (note: this also skips the 1x1 board).
    return list(range(size)) if size > 3 else None


def __next_permutation(rows: Rows) -> Rows | None:
    next_rows = list(rows)  # work on a copy; the current candidate stays untouched
    # Find the pivot: the start of the longest non-increasing suffix.
    pivot = len(next_rows) - 1
    while pivot > 0 and next_rows[pivot - 1] >= next_rows[pivot]:
        pivot -= 1
    if pivot == 0:
        # The whole list is non-increasing: this was the last permutation.
        return None
    # Swap the element before the pivot with the rightmost element larger than it.
    swap = len(next_rows) - 1
    while next_rows[pivot - 1] >= next_rows[swap]:
        swap -= 1
    next_rows[pivot - 1], next_rows[swap] = next_rows[swap], next_rows[pivot - 1]
    # Re-sort the suffix so it is as small as possible.
    next_rows[pivot:] = sorted(next_rows[pivot:])
    return next_rows


def __next(size: Size, rows: Rows) -> Rows | None:
    return __next_permutation(rows)


__rows = Domain(__first, __next)
__up: Up[Size, Rows, Board] = lambda n, rows: Board(rows)
__down: Down[Size, Rows, Board] = lambda n, board: board
__boards = map_domain(__up, __down)(__rows)


def __accept(size: Size, board: Board) -> bool:
    return not any(filter(partial(has_collision, board), row_pairs(size)))


n_queens = brute_force(__boards, __accept)
```
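`__next_permutation` is the classic next-lexicographic-permutation algorithm, the same idea behind C++'s `std::next_permutation`. Here is a standalone sketch of it, checked against `itertools.permutations`, which emits permutations of a sorted input in lexicographic order. The names `next_permutation` and `all_permutations` are our own. Instead of re-sorting the suffix, this version reverses it: after the swap the suffix is still non-increasing, so reversing it gives the same result in linear time.

```python
from itertools import permutations
from typing import List, Optional


def next_permutation(rows: List[int]) -> Optional[List[int]]:
    """Return the next permutation in lexicographic order, or None after the last."""
    nxt = list(rows)
    # 1. Find the pivot: nxt[pivot - 1] < nxt[pivot], with nxt[pivot:] non-increasing.
    pivot = len(nxt) - 1
    while pivot > 0 and nxt[pivot - 1] >= nxt[pivot]:
        pivot -= 1
    if pivot <= 0:
        return None  # fully non-increasing (or empty): no next permutation
    # 2. Swap nxt[pivot - 1] with the rightmost element that is larger than it.
    swap = len(nxt) - 1
    while nxt[swap] <= nxt[pivot - 1]:
        swap -= 1
    nxt[pivot - 1], nxt[swap] = nxt[swap], nxt[pivot - 1]
    # 3. The suffix is still non-increasing; reversing makes it the smallest ordering.
    nxt[pivot:] = nxt[pivot:][::-1]
    return nxt


def all_permutations(n: int) -> List[List[int]]:
    result = []
    rows: Optional[List[int]] = list(range(n))
    while rows is not None:
        result.append(rows)
        rows = next_permutation(rows)
    return result


assert all_permutations(4) == [list(p) for p in permutations(range(4))]
```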
Notice that since we're only generating permutations, every candidate already has unique columns, so we don't need the filter anymore. And here are our results:
```python
from time import time

from board import Size
from n_queens import n_queens


# find solutions for n queens problem given n
def solve(n: Size):
    summary = "size: {:3d}, solutions: {:5d}, time: {:8.3f}s, last: {}"
    time_start = time()
    solutions = list(n_queens(n))
    time_elapsed = time() - time_start
    board = solutions[-1]
    print(summary.format(n, len(solutions), time_elapsed, board))


for n in range(4, 9):
    solve(Size(n))

# prints
# size:   4, solutions:     2, time:    0.000s, last: [2, 0, 3, 1]
# size:   5, solutions:    10, time:    0.001s, last: [4, 2, 0, 3, 1]
# size:   6, solutions:     4, time:    0.008s, last: [4, 2, 0, 5, 3, 1]
# size:   7, solutions:    40, time:    0.056s, last: [6, 4, 2, 0, 5, 3, 1]
# size:   8, solutions:    92, time:    0.351s, last: [7, 3, 0, 2, 5, 1, 6, 4]
```
Sure enough, there's a huge improvement: finding all 92 solutions for a size of 8 now takes about a third of a second, down from over 90 seconds with the filter. That's pretty cool, and all it required was thinking about the problem slightly differently. A different perspective often helps a lot, and this is just one such instance. Note that there's nothing wrong with how we approached the solution initially. Sometimes there are simply other ways to think about a problem that drastically simplify it for us.
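Python's standard library already enumerates permutations, so the same change of perspective also fits in a few lines with `itertools.permutations`. This is a minimal sketch, independent of the series' `Domain` abstraction; `n_queens_permutations` is a hypothetical name. Every candidate has distinct columns by construction, so only diagonal conflicts need checking.

```python
from itertools import combinations, permutations
from typing import Iterator, Tuple


def n_queens_permutations(n: int) -> Iterator[Tuple[int, ...]]:
    for board in permutations(range(n)):
        # Columns are distinct by construction; reject shared diagonals only.
        if all(abs(board[r1] - board[r2]) != r2 - r1
               for r1, r2 in combinations(range(n), 2)):
            yield board
```

For example, `sum(1 for _ in n_queens_permutations(8))` counts the same 92 solutions reported above.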
Conclusion
How far can we go with this?
Well, for one, we just learnt that we can use brute force to solve problems like n-queens, and our initial brute-force search abstraction seems to be holding up well. But exhaustively searching through all possible candidates is not always feasible. To illustrate this, we're going to take a larger board size, see where brute force breaks down, and see what else we can come up with.
But that’s for the next post. Until then, have fun!