A Fast Sudoku Solver

Everyone has to write a Sudoku solver!

A few months back I was in a hotel room at 4 am when the thermostat inexplicably decided to heat my room to 90 degrees Fahrenheit. That woke me up, and since I had nothing to do for a few hours I decided to write a Sudoku solver (because everyone has to write one).

I hadn’t thought about it since then, until I ran across this post on reddit and decided to run the same puzzles through my solver and see how it did.

The upshot is that I can do all 49151 puzzles from the post above in about 650ms. For comparison, the Haskell solution referenced above does it in about 35s. We’re using different machines, so it’s not the most scientific comparison in the world (though I doubt my PC is 50x faster), and I don’t really know how much of this is C vs Haskell speed and how much of it is my implementation choices, so I’m really not trying to rag on their solution or Haskell (in fact, you should read the posts; they’re well written and I enjoyed reading Haskell again after many years!).

Anyway, since my solution is slightly different, and I haven’t blogged in a long time, this seemed worth writing up.

Before I wrote any code I went to Peter Norvig’s Sudoku post, just because I remembered reading it a few years ago and it seemed like a good place to refresh my memory of the Sudoku rules and some of the terminology. I didn’t follow along with his solution, so I ended up with a slightly different way of doing it, but if the method names and terminology seem similar, that’s why.

Performance from the start

One of the things I wanted to do when playing around with this was to make sure things were written to be fast (the way you write code when working on games, or other performance sensitive code).

One of my pet peeves is the “Premature optimization is the root of all evil”-crowd, where it almost seems as if people write intentionally slow code because giving any consideration to performance while writing the code is “evil”. This is a gross mischaracterization of what Knuth actually said, and it is certainly not how you write fast systems. For fast systems, you have to consider performance up front and throughout - that doesn’t mean you optimize every line of code with a profiler right away, just that you write it with reasonable performance choices.

The code

I knew that this was going to involve a lot of brute force searching and copying of the board state, which meant I wanted to keep it simple and without allocations or other overhead (which isn’t needed for this solution). The main data structure looks like this:

struct board {

  // Contains a 9-bit mask per square, indicating possible digits in each square
  uint16_t squares[9][9];

  // For each "unit" (rows, cols, blocks) and each digit (0..8), store a bit mask
  // indicating where this digit is still possible. This could be reconstructed from
  // the squares array above, but it's faster to maintain a redundant "reverse index".
  uint16_t possible_digit_locations_blocks[3][3][9];
  uint16_t possible_digit_locations_rows[9][9];
  uint16_t possible_digit_locations_cols[9][9];

  board() {
    std::fill_n((uint16_t*)squares, 81, (uint16_t)0x1FF);
    std::fill_n((uint16_t*)possible_digit_locations_rows, 9 * 9, (uint16_t)0x1FF);
    std::fill_n((uint16_t*)possible_digit_locations_cols, 9 * 9, (uint16_t)0x1FF);
    std::fill_n((uint16_t*)possible_digit_locations_blocks, 9 * 9, (uint16_t)0x1FF);
  }

  //.... methods go here ...
};

First, I have a 9x9 grid of squares; each one contains a 9-bit bitmask of the possible digits that this square can contain (thus they are initialized to 9 ones, or 0x1FF). I use a uint16_t because it’s the smallest type that fits 9 bits (the top 7 bits are unused). Some other solutions start off by using some abstract “set” data structure, which IMO seems like overkill. It seems fair to assume that most programmers will know (or can learn) how one might use a bit field to represent small sets, but it wouldn’t be hard to wrap that uint16_t in a struct with some set operations on it.

This is the main board state, and once each of these bit masks has only one bit set, we are done and have a solution. Note that I assume all digits are zero-based for convenience (i.e. I subtract one when loading, and add one when printing).

Next, and here’s where my solution is a bit different from others I’ve seen, I maintain a “reverse index” for each digit. So for each digit in each “unit” (row, column or 3x3 block), it maintains a bit mask of where that digit is still a possibility. Initially each digit is possible in every location of the unit, so they too are initialized to 0x1FF. Having this reverse index means I don’t have to search through all of the “peers” (i.e. squares in the same row, column or block) in order to eliminate a digit - I can directly find which squares need it eliminated. It also means I can easily figure out when a given digit only has one remaining possible location (see further below).

In order to copy this board state, a simple memcpy will do. There are no allocations, and this will all live on the stack during the search.

In a few places below I need to loop over all the set bits in one of these bit masks, so I have this simple macro for that.

#define FOREACH_SET_BIT(digit_mask_input, body) { \
  unsigned long bit_index; uint16_t digit_mask = (digit_mask_input); \
  while (_BitScanForward(&bit_index, digit_mask)) { \
    digit_mask &= digit_mask-1; body\
  }\
}

_BitScanForward just gives you the index of the least significant set bit (on MSVC, other compilers have similar intrinsics), and the magic bit fiddling clears the first set bit. Then it runs the loop body with the variable bit_index set to the index of the bit.

The meat of the algorithm is the eliminate method. It takes a location and a set of digits to eliminate, and also propagates the constraints flowing from that elimination.

bool eliminate(int x, int y, uint16_t digits_to_eliminate) {    
    digits_to_eliminate &= squares[x][y]; // Only eliminate digits that exist in this square

    if (digits_to_eliminate == 0) {
        return true; // already eliminated.
    }

    squares[x][y] &= ~digits_to_eliminate; // clear digit

    uint16_t remaining_digits = squares[x][y];

    if (__popcnt16(remaining_digits) == 0) {
        return false; // contradiction found, no possible digits left.
    }

First, we just clear out the right bits in the current square. If we don’t have any remaining digits we have found a contradiction and can early out (this implies that we’ve made a mistake earlier and have to backtrack).

int block_x = x / 3;
int block_y = y / 3;
int block_bit_index = (x%3) + 3*(y%3);

// Clear out the "reverse index" 
FOREACH_SET_BIT(digits_to_eliminate, 
    assert(possible_digit_locations_cols[x][bit_index] & (1 << y));
    possible_digit_locations_cols[x][bit_index] &= ~(1 << y);

    assert(possible_digit_locations_rows[y][bit_index] & (1 << x));
    possible_digit_locations_rows[y][bit_index] &= ~(1 << x);

    assert(possible_digit_locations_blocks[block_x][block_y][bit_index] & (1 << block_bit_index));
    possible_digit_locations_blocks[block_x][block_y][bit_index] &= ~(1 << block_bit_index);
);  

We also have to update the “reverse index”. To do this we loop through all the bits that we’re eliminating, and for each one we clear out the current coordinate in its bitmasks.

One slightly interesting thing to note here is that I’m intentionally keeping things very simple in terms of abstraction and complexity. I don’t make any attempt to abstract over the “units” or anything like that. Why? Because we only have three kinds of units, not twenty, so it really doesn’t buy you much. Yes, I’m kinda repeating myself three times here (and I will again below), but in return you get code that has fewer layers of abstractions for people to internalize and understand.

Plus, as I mentioned above, I wanted to write this with performance in mind from the start. That means not introducing abstractions in the inner loop just to avoid minor duplication. If two choices are approximately of equal value (each has pros and cons) and one of them is likely to be faster, then pick the likely-to-be-faster one. You still need to measure at some point to see if you were right, but you can often make decent guesses too. So let’s hold off on adding in abstractions until they buy us a lot more in terms of readability.

Anyway, back to the code. At this stage the board state is updated with the elimination, and now we have to handle the consequences of the elimination.

// If we've eliminated all but one digit, then we should eliminate that digit from all the peers.
if (__popcnt16(remaining_digits) == 1) {
    int remaining_digit_index = get_set_bit(remaining_digits);

    // Get all the positions where this digit is set in the current row, column and block,
    // and eliminate them from those squares.
    
    // Start with the current row.
    uint16_t remaining_pos_mask = possible_digit_locations_rows[y][remaining_digit_index];
    remaining_pos_mask &= ~(1 << x); // Don't eliminate from the current square.
    FOREACH_SET_BIT(remaining_pos_mask,       
        if (!eliminate(bit_index,y, remaining_digits)) {
            return false;
        }
    )

    // Next eliminate it from the column.
    remaining_pos_mask = possible_digit_locations_cols[x][remaining_digit_index];
    remaining_pos_mask &= ~(1 << y); // Don't eliminate from the current square
    FOREACH_SET_BIT(remaining_pos_mask,
        if (!eliminate(x, bit_index, remaining_digits)) {
            return false;
        }
    )

    // Next eliminate it from the current block
    remaining_pos_mask = possible_digit_locations_blocks[block_x][block_y][remaining_digit_index];
    remaining_pos_mask &= ~(1 << block_bit_index); // Don't eliminate from the current square
    FOREACH_SET_BIT(remaining_pos_mask,
        int x_offset = bit_index % 3;
        int y_offset = bit_index / 3;
        if (!eliminate(block_x*3 + x_offset, block_y *3 + y_offset, remaining_digits)) {
            return false;
        }
    )
}

The first and most obvious rule is to check if a square has been reduced to only one possibility, which means it’s “done”, and by the rules of Sudoku, that digit needs to be eliminated from all other squares amongst its peers. To do that, we look at the “reverse index” to find any places where the digit exists, and clear it out by recursively calling eliminate on those squares. Again, we just do mostly the same thing for the row, column, and block (and again, while there’s some duplicated logic here, it’s fairly brain-dead duplication, and understanding the abstraction required to avoid it would likely add more cognitive overhead than it saves).

    // For each digit we just eliminated, find if it now only has one remaining possible location
    // in either the row, column or block. If so, set the digit
    FOREACH_SET_BIT(digits_to_eliminate,
        // Check the row
        if (__popcnt16(possible_digit_locations_rows[y][bit_index]) == 1) {
            int digit_x = get_set_bit(possible_digit_locations_rows[y][bit_index]);
            if (!set_digit(digit_x, y, bit_index)) {
                return false;
            }
        }

        // Column
        if (__popcnt16(possible_digit_locations_cols[x][bit_index]) == 1) {
            int digit_y = get_set_bit(possible_digit_locations_cols[x][bit_index]);
            if (!set_digit(x, digit_y, bit_index)) {
                return false;
            }
        }

        // Block
        if (__popcnt16(possible_digit_locations_blocks[block_x][block_y][bit_index]) == 1) {
            int bit_with_digit = get_set_bit(possible_digit_locations_blocks[block_x][block_y][bit_index]);
            int x_offset = bit_with_digit % 3;
            int y_offset = bit_with_digit / 3;
            if (!set_digit(block_x*3 + x_offset, block_y *3 + y_offset, bit_index)) {
                return false;
            }
        }
    )
    return true; // Successfully eliminated the digits
}

Finally, we check if any digit has only one possible location, and if so we “set” it by calling set_digit. Again, this code uses the reverse index to find all the places where eliminated digits are still possible, and if any of them only have one bit set, then that must be where that digit goes.

The set_digit function is just a trivial helper:

bool set_digit(int x, int y, int d) {
    int digit_mask = 1 << d;
    assert(squares[x][y] & digit_mask);

    if (!eliminate(x, y, ~digit_mask)) {
        return false;
    }

    assert(squares[x][y] == digit_mask);
    return true;
}

Next, we just have to implement the search itself.

bool search(const board& current_board, board& final_board) {
  // Find the next square to do a trial assignment on
  int x, y;
  if (current_board.find_next_square_to_assign(x, y)) {
    uint16_t digits = current_board.squares[x][y];

    // Then assign each possible digit to this square
    FOREACH_SET_BIT(digits, {
      board board_copy = current_board;
      // If we can successfully set this digit, do search from here
      if (board_copy.set_digit(x, y, bit_index)) {
        if (search(board_copy, final_board)) {
          return true;
        }
      }
    })
  }
  else {
    // No more squares to assign, so we're done!    
    assert(current_board.is_solved());
    final_board = current_board;
    return true;
  }
  return false;
}

All we’re really doing is taking a candidate board, picking a square that has more than one possibility, setting it to one of those possibilities, and recursing until we either get a contradiction (eliminate returns false) or we have no more squares with more than one possibility (which means the board is solved). The search function picks the square with the fewest possible digits first, in order to maximize the chances that we pick the right digit - e.g. if we set a random digit in a square with only two options we have a 50% chance of getting it right, but if it has 4 options it’s only 25%.

Conclusion

So there you have it, my Sudoku solver. You can get the full code here. I didn’t bother directly checking for “naked pairs” or any other strategies, because honestly at an average of 13µs per puzzle it just doesn’t seem worth it. Same goes for parallelization.

I know Sudoku solvers are a bit CS 101, but perhaps this can at least illustrate that if you make some fairly minor performance considerations up front and throughout, you can get a reasonably fast solution from the start, without having to jump through any crazy hoops or even optimizing in any serious way at all.

Update

After posting on twitter, some people pointed out a few things. Fabian Giesen and Peter Alexander both pointed out that the same trick I use in FOREACH_SET_BIT to clear the LSB can be used to check for popcnt == 1 (for non-zero masks), which is slightly cheaper. I also discovered a timing bug where the initial “elimination” that happens while constructing the board wasn’t getting included in the timing (this was under-measuring the actual time by a factor of ~3x, so fairly significant). I’ve updated the code to fix that (the timing numbers above now reference the fixed version), but left the in-line code snippets the same.
