add branching for solving any sudoku

This commit is contained in:
Joseph Hopfmüller
2023-02-10 12:58:23 +01:00
parent 6bbbf72b1e
commit 26cc85b4a4
3 changed files with 139 additions and 28 deletions

162
sudoku.py
View File

@@ -1,11 +1,16 @@
# import colorama
# from termcolor import colored
from copy import deepcopy
from random import choice, randrange
import pandas as pd
import numpy as np
import sys
class cell():
def __init__(self, row, col):
self.possible = set([i for i in range(1,10)])
self.solved = False
self.solution = 0
self.prefilled = False
self.branch = False
self.row = row
self.col = col
@@ -18,11 +23,15 @@ class cell():
retstr += f'with possible values {self.possible}'
return retstr
def set_value(self, value):
def set_value(self, value, branch=False):
self.possible = {value}
self.solution = value
self.solved = True
self.prefilled = True
self.prefilled = not branch
self.branch = branch
def cleanup(self):
self.branch = False
def remove_value(self, value):
if not self.solved:
@@ -54,6 +63,10 @@ class cell():
class sudoku_grid():
def __init__(self, prefilled_cells):
self.grid = [[cell(i,j) for j in range(9)] for i in range(9)]
self.branches = list()
self.branch_points = list()
self.last_hash = None
self.iteration = 0
# prefilled_cells is a nested list (9x9) of values (1-9), 0 specifies an empty cell
try:
@@ -87,6 +100,8 @@ class sudoku_grid():
current_val = self.grid[i][j].solution
if self.grid[i][j].prefilled:
retstr += f'{current_val}'
elif self.grid[i][j].branch:
retstr += f'\033[91m{current_val}\033[00m'
else:
retstr += f'\033[92m{current_val}\033[00m'
else:
@@ -104,12 +119,24 @@ class sudoku_grid():
retstr += '\n ╚═══╧═══╧═══╩═══╧═══╧═══╩═══╧═══╧═══╝'
return retstr
def iterate(self):
try:
self.iterate1()
self.iterate2()
self.check_branching()
self.solvable()
self.iteration += 1
except:
print('No solution found')
print(self)
sys.exit()
# remove values based on solved cells
def iterate1(self):
# iterate over all cells
for i in range(9):
for j in range(9):
# # remove posibble values based on solved cells
for j in range(9):
# # remove possible values based on solved cells
current_value = self.grid[i][j].solution
if current_value:
for k in range(9):
@@ -122,6 +149,7 @@ class sudoku_grid():
current_cell = self.grid[(i//3)*3+(i+k)%3][(j//3)*3+(j+l)%3]
current_cell.remove_value(current_value) # remove value from current 3x3 box
# remove values based on possible values in other cells of row/column/square
def iterate2(self):
for i in range(9):
for j in range(9):
@@ -158,9 +186,41 @@ class sudoku_grid():
current_cell2 = self.grid[row][col]
current_set = current_set.union(current_cell2.possible)
current_cell.collapse(current_set)
def cleanup(self):
for i in range(9):
for j in range(9):
self.grid[i][j].cleanup()
def check_branching(self):
if self.invalid_states():
self.rollback_branch()
new_hash = hash(str(self))
if new_hash == self.last_hash:
[row, col, e, sols] = self.find_lowest_entropy()
self.branch(row, col, sols)
self.last_hash = new_hash
# copy current solution and branch point (branch_point: row, column, selected_solution)
def save_branch(self, branch_point):
self.branches.append(deepcopy(self.grid))
self.branch_points.append(branch_point)
# revert latest branch
def rollback_branch(self):
solution = self.branches.pop()
branch_point = self.branch_points.pop()
self.grid = solution
current_cell = self.grid[branch_point[0]][branch_point[1]]
current_cell.remove_value(branch_point[2])
print('rollback branch')
def branch(self, row, col, solutions):
current_cell = self.grid[row][col]
selected_value = choice(solutions)
self.save_branch((row, col, selected_value))
print(f'Branched: ({row},{col}): {current_cell.possible}, chose {selected_value}')
current_cell.set_value(selected_value, branch=True)
def find_lowest_entropy(self):
lowest_i = -1
@@ -191,8 +251,7 @@ class sudoku_grid():
return None
possible = self.grid[row][col].possible
if len(possible) == 1:
self.grid[row][col].solved = list(possible)[0]
self.grid[row][col].solved = list(possible)[0]
def single_solutions_exist(self):
for i in range(9):
@@ -208,8 +267,38 @@ class sudoku_grid():
if not self.grid[i][j].solved:
return False
return True
def solvable(self):
if len(self.branches) == 0:
for i in range(9):
for j in range(9):
if len(self.grid[i][j]) == 0:
raise Exception
return True
def invalid_states(self):
for i in range(9):
row_list = list()
col_list = list()
for j in range(9):
row_list.append(self.grid[i][j].solution)
col_list.append(self.grid[j][i].solution)
if len(self.grid[i][j]) == 0:
return True
row_list.sort()
col_list.sort()
row_list = [x for x in row_list if x != 0]
col_list = [x for x in col_list if x != 0]
row_comp = list(set(row_list))
col_comp = list(set(col_list))
row_comp.sort()
col_comp.sort()
if row_comp != row_list:
return True
if col_comp != col_list:
return True
return False
iteration = 1
if __name__ == '__main__':
# colorama.init()
# prefilled = [ [0,4,9, 7,0,5, 0,0,0],
@@ -224,26 +313,45 @@ if __name__ == '__main__':
# [0,6,0, 1,0,0, 0,0,0],
# [4,0,5, 0,0,0, 3,0,2]]
prefilled = [ [0,0,7, 0,0,5, 0,0,3],
[0,0,9, 0,6,0, 0,0,0],
[3,6,0, 0,0,8, 2,0,0],
# prefilled = [ [0,0,7, 0,0,5, 0,0,3],
# [0,0,9, 0,6,0, 0,0,0],
# [3,6,0, 0,0,8, 2,0,0],
[0,0,6, 0,0,0, 0,0,0],
[5,1,0, 0,8,0, 0,0,9],
[0,0,0, 0,0,2, 0,4,0],
# [0,0,6, 0,0,0, 0,0,0],
# [5,1,0, 0,8,0, 0,0,9],
# [0,0,0, 0,0,2, 0,4,0],
[0,0,0, 5,0,0, 9,0,0],
[8,3,0, 0,1,0, 0,0,5],
[7,0,0, 0,0,0, 0,0,0]]
# [0,0,0, 5,0,0, 9,0,0],
# [8,3,0, 0,1,0, 0,0,5],
# [7,0,0, 0,0,0, 0,0,0]]
sudoku = sudoku_grid(prefilled)
prefilled = [ [2,0,4, 0,6,1, 0,0,9],
[0,1,0, 0,0,4, 0,0,0],
[0,7,0, 0,0,0, 0,2,0],
[0,2,0, 0,0,0, 0,0,0],
[8,0,3, 0,0,7, 0,9,0],
[0,0,0, 5,0,0, 0,0,6],
[4,0,9, 0,0,3, 0,8,0],
[0,0,1, 0,0,0, 0,0,0],
[0,0,0, 0,7,0, 3,0,0]]
df = pd.read_csv('sudoku.csv')
row = randrange(len(df))
quiz = df.loc[row]['quizzes']
# print(quiz)
quiz = list(quiz)
quiz = [int(x) for x in quiz]
quiz = np.reshape(quiz, (9, 9)).tolist()
# sudoku = sudoku_grid(prefilled)
sudoku = sudoku_grid(quiz)
print(sudoku)
while not sudoku.is_solved():
sudoku.iterate1()
print(f'Iteration {iteration}a')
print(sudoku)
sudoku.iterate2()
print(f'Iteration {iteration}b')
print(sudoku)
iteration += 1
sudoku.iterate()
# print(f'Iteration {sudoku.iteration}')
# print(sudoku)
sudoku.cleanup()
print('Solved!')
print(sudoku)