add branching for solving any sudoku

This commit is contained in:
Joseph Hopfmüller
2023-02-10 12:58:23 +01:00
parent 6bbbf72b1e
commit 26cc85b4a4
3 changed files with 139 additions and 28 deletions

1
.gitignore vendored
View File

@@ -160,3 +160,4 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder. # option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/ #.idea/
.csv

View File

@@ -1,3 +1,5 @@
# sudoku_collapse # sudoku_collapse
a sudoku solver using wave function collapse a sudoku solver using wave function collapse (ish)
data from https://www.kaggle.com/datasets/bryanpark/sudoku

156
sudoku.py
View File

@@ -1,11 +1,16 @@
# import colorama from copy import deepcopy
# from termcolor import colored from random import choice, randrange
import pandas as pd
import numpy as np
import sys
class cell(): class cell():
def __init__(self, row, col): def __init__(self, row, col):
self.possible = set([i for i in range(1,10)]) self.possible = set([i for i in range(1,10)])
self.solved = False self.solved = False
self.solution = 0 self.solution = 0
self.prefilled = False self.prefilled = False
self.branch = False
self.row = row self.row = row
self.col = col self.col = col
@@ -18,11 +23,15 @@ class cell():
retstr += f'with possible values {self.possible}' retstr += f'with possible values {self.possible}'
return retstr return retstr
def set_value(self, value): def set_value(self, value, branch=False):
self.possible = {value} self.possible = {value}
self.solution = value self.solution = value
self.solved = True self.solved = True
self.prefilled = True self.prefilled = not branch
self.branch = branch
def cleanup(self):
self.branch = False
def remove_value(self, value): def remove_value(self, value):
if not self.solved: if not self.solved:
@@ -54,6 +63,10 @@ class cell():
class sudoku_grid(): class sudoku_grid():
def __init__(self, prefilled_cells): def __init__(self, prefilled_cells):
self.grid = [[cell(i,j) for j in range(9)] for i in range(9)] self.grid = [[cell(i,j) for j in range(9)] for i in range(9)]
self.branches = list()
self.branch_points = list()
self.last_hash = None
self.iteration = 0
# prefilled_cells is a nested list (9x9) of values (1-9), 0 specifies an empty cell # prefilled_cells is a nested list (9x9) of values (1-9), 0 specifies an empty cell
try: try:
@@ -87,6 +100,8 @@ class sudoku_grid():
current_val = self.grid[i][j].solution current_val = self.grid[i][j].solution
if self.grid[i][j].prefilled: if self.grid[i][j].prefilled:
retstr += f'{current_val}' retstr += f'{current_val}'
elif self.grid[i][j].branch:
retstr += f'\033[91m{current_val}\033[00m'
else: else:
retstr += f'\033[92m{current_val}\033[00m' retstr += f'\033[92m{current_val}\033[00m'
else: else:
@@ -104,12 +119,24 @@ class sudoku_grid():
retstr += '\n ╚═══╧═══╧═══╩═══╧═══╧═══╩═══╧═══╧═══╝' retstr += '\n ╚═══╧═══╧═══╩═══╧═══╧═══╩═══╧═══╧═══╝'
return retstr return retstr
def iterate(self):
try:
self.iterate1()
self.iterate2()
self.check_branching()
self.solvable()
self.iteration += 1
except:
print('No solution found')
print(self)
sys.exit()
# remove values based on solved cells
def iterate1(self): def iterate1(self):
# iterate over all cells # iterate over all cells
for i in range(9): for i in range(9):
for j in range(9): for j in range(9):
# # remove posibble values based on solved cells # # remove possible values based on solved cells
current_value = self.grid[i][j].solution current_value = self.grid[i][j].solution
if current_value: if current_value:
for k in range(9): for k in range(9):
@@ -122,6 +149,7 @@ class sudoku_grid():
current_cell = self.grid[(i//3)*3+(i+k)%3][(j//3)*3+(j+l)%3] current_cell = self.grid[(i//3)*3+(i+k)%3][(j//3)*3+(j+l)%3]
current_cell.remove_value(current_value) # remove value from current 3x3 box current_cell.remove_value(current_value) # remove value from current 3x3 box
# remove values based on possible values in other cells of row/column/square
def iterate2(self): def iterate2(self):
for i in range(9): for i in range(9):
for j in range(9): for j in range(9):
@@ -159,8 +187,40 @@ class sudoku_grid():
current_set = current_set.union(current_cell2.possible) current_set = current_set.union(current_cell2.possible)
current_cell.collapse(current_set) current_cell.collapse(current_set)
def cleanup(self):
for i in range(9):
for j in range(9):
self.grid[i][j].cleanup()
def check_branching(self):
if self.invalid_states():
self.rollback_branch()
new_hash = hash(str(self))
if new_hash == self.last_hash:
[row, col, e, sols] = self.find_lowest_entropy()
self.branch(row, col, sols)
self.last_hash = new_hash
# copy current solution and branch point (branch_point: row, column, selected_solution)
def save_branch(self, branch_point):
self.branches.append(deepcopy(self.grid))
self.branch_points.append(branch_point)
# revert latest branch
def rollback_branch(self):
solution = self.branches.pop()
branch_point = self.branch_points.pop()
self.grid = solution
current_cell = self.grid[branch_point[0]][branch_point[1]]
current_cell.remove_value(branch_point[2])
print('rollback branch')
def branch(self, row, col, solutions):
current_cell = self.grid[row][col]
selected_value = choice(solutions)
self.save_branch((row, col, selected_value))
print(f'Branched: ({row},{col}): {current_cell.possible}, chose {selected_value}')
current_cell.set_value(selected_value, branch=True)
def find_lowest_entropy(self): def find_lowest_entropy(self):
lowest_i = -1 lowest_i = -1
@@ -193,7 +253,6 @@ class sudoku_grid():
if len(possible) == 1: if len(possible) == 1:
self.grid[row][col].solved = list(possible)[0] self.grid[row][col].solved = list(possible)[0]
def single_solutions_exist(self): def single_solutions_exist(self):
for i in range(9): for i in range(9):
for j in range(9): for j in range(9):
@@ -209,7 +268,37 @@ class sudoku_grid():
return False return False
return True return True
iteration = 1 def solvable(self):
if len(self.branches) == 0:
for i in range(9):
for j in range(9):
if len(self.grid[i][j]) == 0:
raise Exception
return True
def invalid_states(self):
for i in range(9):
row_list = list()
col_list = list()
for j in range(9):
row_list.append(self.grid[i][j].solution)
col_list.append(self.grid[j][i].solution)
if len(self.grid[i][j]) == 0:
return True
row_list.sort()
col_list.sort()
row_list = [x for x in row_list if x != 0]
col_list = [x for x in col_list if x != 0]
row_comp = list(set(row_list))
col_comp = list(set(col_list))
row_comp.sort()
col_comp.sort()
if row_comp != row_list:
return True
if col_comp != col_list:
return True
return False
if __name__ == '__main__': if __name__ == '__main__':
# colorama.init() # colorama.init()
# prefilled = [ [0,4,9, 7,0,5, 0,0,0], # prefilled = [ [0,4,9, 7,0,5, 0,0,0],
@@ -224,26 +313,45 @@ if __name__ == '__main__':
# [0,6,0, 1,0,0, 0,0,0], # [0,6,0, 1,0,0, 0,0,0],
# [4,0,5, 0,0,0, 3,0,2]] # [4,0,5, 0,0,0, 3,0,2]]
prefilled = [ [0,0,7, 0,0,5, 0,0,3], # prefilled = [ [0,0,7, 0,0,5, 0,0,3],
[0,0,9, 0,6,0, 0,0,0], # [0,0,9, 0,6,0, 0,0,0],
[3,6,0, 0,0,8, 2,0,0], # [3,6,0, 0,0,8, 2,0,0],
[0,0,6, 0,0,0, 0,0,0], # [0,0,6, 0,0,0, 0,0,0],
[5,1,0, 0,8,0, 0,0,9], # [5,1,0, 0,8,0, 0,0,9],
[0,0,0, 0,0,2, 0,4,0], # [0,0,0, 0,0,2, 0,4,0],
[0,0,0, 5,0,0, 9,0,0], # [0,0,0, 5,0,0, 9,0,0],
[8,3,0, 0,1,0, 0,0,5], # [8,3,0, 0,1,0, 0,0,5],
[7,0,0, 0,0,0, 0,0,0]] # [7,0,0, 0,0,0, 0,0,0]]
sudoku = sudoku_grid(prefilled) prefilled = [ [2,0,4, 0,6,1, 0,0,9],
[0,1,0, 0,0,4, 0,0,0],
[0,7,0, 0,0,0, 0,2,0],
[0,2,0, 0,0,0, 0,0,0],
[8,0,3, 0,0,7, 0,9,0],
[0,0,0, 5,0,0, 0,0,6],
[4,0,9, 0,0,3, 0,8,0],
[0,0,1, 0,0,0, 0,0,0],
[0,0,0, 0,7,0, 3,0,0]]
df = pd.read_csv('sudoku.csv')
row = randrange(len(df))
quiz = df.loc[row]['quizzes']
# print(quiz)
quiz = list(quiz)
quiz = [int(x) for x in quiz]
quiz = np.reshape(quiz, (9, 9)).tolist()
# sudoku = sudoku_grid(prefilled)
sudoku = sudoku_grid(quiz)
print(sudoku) print(sudoku)
while not sudoku.is_solved(): while not sudoku.is_solved():
sudoku.iterate1() sudoku.iterate()
print(f'Iteration {iteration}a') # print(f'Iteration {sudoku.iteration}')
print(sudoku) # print(sudoku)
sudoku.iterate2() sudoku.cleanup()
print(f'Iteration {iteration}b') print('Solved!')
print(sudoku)
iteration += 1
print(sudoku) print(sudoku)