Nav apraksta

xword.py 10KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. import itertools
  2. import math
  3. import cv2
  4. import numpy as np
  5. import copy
  6. import argparse
  7. def non_greys_to_white(img, threshold=48):
  8. b, g, r = cv2.split(img)
  9. rgb_diff = cv2.subtract(cv2.max(cv2.max(b, g), r), cv2.min(cv2.min(b, g), r))
  10. filtered = img.copy()
  11. filtered[np.where(rgb_diff > threshold)] = (255, 255, 255)
  12. return filtered
  13. def load_image_as_greyscale(file_name, filter_colours, colour_filter_threshold):
  14. img = cv2.imread(file_name)
  15. if img is None:
  16. raise RuntimeError("Failed to load image")
  17. if filter_colours:
  18. img = non_greys_to_white(img, colour_filter_threshold)
  19. return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  20. def preprocess_image(original, gaussian_blur_size, adaptive_threshold_block_size, adaptive_threshold_mean_adjustment, num_dilations):
  21. img = cv2.GaussianBlur(original, (gaussian_blur_size, gaussian_blur_size), 0)
  22. img = cv2.adaptiveThreshold(img, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, adaptive_threshold_block_size, adaptive_threshold_mean_adjustment)
  23. kernel = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], np.uint8)
  24. for i in range(num_dilations):
  25. img = cv2.dilate(img, kernel)
  26. return img
  27. def morph_open_image(img, kernel_size, iterations=1):
  28. kernel = cv2.getStructuringElement(cv2.MORPH_RECT, kernel_size)
  29. return cv2.morphologyEx(img, cv2.MORPH_OPEN, kernel, iterations=iterations)
  30. def get_fundamental_frequency(fft):
  31. mag = abs(fft[0:len(fft) // 2])
  32. mag[0] = 0
  33. return int(np.argmax(mag))
  34. def get_line_fft(img, line_detector_element_size, axis):
  35. lines = morph_open_image(img, (line_detector_element_size, 1) if axis == 1 else (1, line_detector_element_size))
  36. return np.fft.fft(np.sum(lines, axis=axis))
  37. def get_line_frequency(img, line_detector_element_size, axis):
  38. return get_fundamental_frequency(get_line_fft(img, line_detector_element_size, axis))
  39. def find_biggest_contour(img):
  40. contours, hierarchy = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
  41. biggest = None
  42. max_area = 0
  43. for contour in contours:
  44. area = cv2.contourArea(contour)
  45. if area > max_area:
  46. biggest = contour
  47. max_area = area
  48. return biggest
  49. def erode_contour(img_shape, contour, erosion_kernel_size, iterations):
  50. contour_img = np.zeros(img_shape, dtype=np.uint8)
  51. cv2.drawContours(contour_img, [contour], 0, 255, -1)
  52. contour_img = morph_open_image(contour_img, (erosion_kernel_size, erosion_kernel_size), iterations)
  53. return find_biggest_contour(contour_img)
  54. def get_contour_corners(img, contour):
  55. height, width = img.shape
  56. top_left = [width, height]
  57. top_right = [-1, height]
  58. bottom_left = [width, -1]
  59. bottom_right = [-1, -1]
  60. for vertex in contour:
  61. point = vertex[0]
  62. sum = point[0] + point[1]
  63. diff = point[0] - point[1]
  64. if sum < top_left[0] + top_left[1]:
  65. top_left = point
  66. if sum > bottom_right[0] + bottom_right[1]:
  67. bottom_right = point
  68. if diff < bottom_left[0] - bottom_left[1]:
  69. bottom_left = point
  70. if diff > top_right[0] - top_right[1]:
  71. top_right = point
  72. return top_left, top_right, bottom_right, bottom_left
  73. def segment_length(p1, p2):
  74. dx = p1[0] - p2[0]
  75. dy = p1[1] - p2[1]
  76. return math.sqrt(dx ** 2 + dy ** 2)
  77. def get_longest_side(poly):
  78. previous = poly[-1]
  79. max = 0
  80. for current in poly:
  81. len = segment_length(previous, current)
  82. if len > max:
  83. max = len
  84. previous = current
  85. return max
  86. def extract_square(img, top_left, top_right, bottom_right, bottom_left):
  87. src = [top_left, top_right, bottom_right, bottom_left]
  88. longest = get_longest_side(src)
  89. dst = [[0, 0], [longest - 1, 0], [longest - 1, longest - 1], [0, longest - 1]]
  90. m = cv2.getPerspectiveTransform(np.array(src, dtype=np.float32), np.array(dst, dtype=np.float32))
  91. return cv2.warpPerspective(img, m, (int(longest), int(longest)))
  92. def get_threshold_from_quantile(img, quantile):
  93. height, width = img.shape
  94. num_pixels = height * width
  95. pixels = np.sort(np.reshape(img, num_pixels))
  96. return pixels[int(num_pixels * quantile)]
  97. def extract_grid_colours(img, num_rows, num_cols, sampling_block_size_ratio):
  98. height, width = img.shape
  99. row_delta = int(height * sampling_block_size_ratio / num_rows / 2)
  100. col_delta = int(width * sampling_block_size_ratio / num_cols / 2)
  101. sampling_block_area = (2 * row_delta + 1) * (2 * col_delta + 1)
  102. grid = []
  103. for row in range(num_rows):
  104. line = []
  105. y = int(((row + 0.5) / num_rows) * height)
  106. for col in range(num_cols):
  107. sum = 0
  108. x = int(((col + 0.5) / num_cols) * width)
  109. for dy in range(-row_delta, row_delta + 1):
  110. for dx in range(-col_delta, col_delta + 1):
  111. sum += img[y + dy, x + dx]
  112. line.append(sum / sampling_block_area)
  113. grid.append(line)
  114. return grid
  115. def get_grid_colour_threshold(grid_colours):
  116. pixels = sorted(itertools.chain.from_iterable(grid_colours))
  117. delta_max = -1
  118. i_max = -1
  119. for i in range(1, len(pixels)):
  120. delta = pixels[i] - pixels[i - 1]
  121. if delta > delta_max:
  122. delta_max = delta
  123. i_max = i
  124. return (pixels[i_max] + pixels[i_max - 1]) / 2
  125. def grid_colours_to_blocks(grid_colours, num_rows, num_cols, sampling_threshold):
  126. grid = copy.deepcopy(grid_colours)
  127. warning = False
  128. midpoint = num_rows // 2 + (0 if num_rows % 2 == 0 else 1)
  129. for row in range(midpoint):
  130. for col in range(num_cols):
  131. # If there is an odd number of rows then row and row2 will point to
  132. # the same row when we reach the middle. Doesn't seem worth adding a
  133. # special case.
  134. row2 = num_rows - row - 1
  135. col2 = num_cols - col - 1
  136. delta1 = grid_colours[row][col] - sampling_threshold
  137. delta2 = grid_colours[row2][col2] - sampling_threshold
  138. if (delta1 > 0) and (delta2 > 0):
  139. filled = False
  140. elif (delta1 < 0) and (delta2 < 0):
  141. filled = True
  142. else:
  143. warning = True
  144. if abs(delta1) > abs(delta2):
  145. filled = delta1 < 0
  146. else:
  147. filled = delta2 < 0
  148. grid[row][col] = {'filled': filled}
  149. grid[row2][col2] = {'filled': filled}
  150. number = 1
  151. for row in range(num_rows):
  152. for col in range(num_cols):
  153. if (not grid[row][col]['filled'] and (
  154. (((col == 0) or grid[row][col - 1]['filled']) and (col < num_cols - 1) and not grid[row][col + 1]['filled']) or
  155. (((row == 0) or grid[row - 1][col]['filled']) and (row < num_rows - 1) and not grid[row + 1][col]['filled'])
  156. )):
  157. grid[row][col]['number'] = number
  158. number += 1
  159. return warning, grid
  160. def draw_point(image, point, colour):
  161. height, width, _ = image.shape
  162. for dx in range(-10, 11):
  163. for dy in range(-10, 11):
  164. x = point[0] + dx
  165. y = point[1] + dy
  166. if (x >= 0) and (y >= 0) and (x < width) and (y < height):
  167. image[y, x] = colour
  168. def extract_crossword_grid(
  169. file_name,
  170. callback=None,
  171. remove_colours=False,
  172. colour_removal_threshold=48,
  173. gaussian_blur_size=11,
  174. adaptive_threshold_block_size=11,
  175. adaptive_threshold_mean_adjustment=2,
  176. square=True,
  177. num_dilations=1,
  178. contour_erosion_kernel_size=5,
  179. contour_erosion_iterations=5,
  180. line_detector_element_size=51,
  181. sampling_block_size_ratio=0.25,
  182. sampling_threshold_quantile=0.3,
  183. sampling_threshold=None
  184. ):
  185. warnings = []
  186. original = load_image_as_greyscale(file_name, remove_colours, colour_removal_threshold)
  187. if callback is not None:
  188. callback('original', original)
  189. img = preprocess_image(original, gaussian_blur_size, adaptive_threshold_block_size, adaptive_threshold_mean_adjustment, num_dilations)
  190. if callback is not None:
  191. callback('preprocessed', img)
  192. biggest = find_biggest_contour(img)
  193. biggest = erode_contour(img.shape, biggest, contour_erosion_kernel_size, contour_erosion_iterations)
  194. top_left, top_right, bottom_right, bottom_left = get_contour_corners(img, biggest)
  195. img = extract_square(img, top_left, top_right, bottom_right, bottom_left)
  196. if callback is not None:
  197. callback('pre-fft', img)
  198. num_rows = get_line_frequency(img, line_detector_element_size, 1)
  199. num_cols = get_line_frequency(img, line_detector_element_size, 0)
  200. if square and (num_rows != num_cols):
  201. warnings.append("Crossword is not square")
  202. block_img = extract_square(original, top_left, top_right, bottom_right, bottom_left)
  203. grid_colours = extract_grid_colours(block_img, num_rows, num_cols, sampling_block_size_ratio)
  204. if sampling_threshold is None:
  205. #sampling_threshold = get_threshold_from_quantile(block_img, sampling_threshold_quantile)
  206. sampling_threshold = get_grid_colour_threshold(grid_colours)
  207. else:
  208. sampling_threshold = sampling_threshold
  209. warning, grid = grid_colours_to_blocks(grid_colours, num_rows, num_cols, sampling_threshold)
  210. if warning:
  211. warnings.append("Some blocks may be the wrong colour")
  212. return warnings, grid, num_rows, num_cols, block_img
  213. def draw_grid(
  214. grid,
  215. num_rows,
  216. num_cols,
  217. grid_line_thickness=4,
  218. grid_square_size=64,
  219. grid_border_size=20
  220. ):
  221. step = grid_square_size + grid_line_thickness
  222. grid_height = num_rows * step + grid_line_thickness
  223. grid_width = num_cols * step + grid_line_thickness
  224. output = np.full([2 * grid_border_size + grid_height, 2 * grid_border_size + grid_width], 255, dtype=np.uint8)
  225. cv2.rectangle(output, (grid_border_size, grid_border_size), (grid_border_size + grid_width - 1, grid_border_size + grid_height - 1), 0, -1)
  226. for row in range(num_rows):
  227. y = row * step + grid_line_thickness + grid_border_size
  228. for col in range(num_cols):
  229. if not grid[row][col]['filled']:
  230. x = col * step + grid_line_thickness + grid_border_size
  231. cv2.rectangle(output, (x, y), (x + grid_square_size - 1, y + grid_square_size - 1), 255, -1)
  232. return output