Advertisement
hhoppe

Advent of Code 2024, Day 6: fast Numba-jitted solution (12 ms)

Dec 6th, 2024 (edited)
94
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.89 KB | None | 0 0
  1. @numba.njit
  2. def day6_jit(grid: np.ndarray, y0: int, x0: int, part2: bool):
  3.   y, x, dy, dx = y0, x0, -1, 0
  4.   count = 0
  5.   while True:
  6.     if grid[y, x] != 'X':
  7.       count += 1
  8.     grid[y, x] = 'X'
  9.     y1, x1 = y + dy, x + dx
  10.     if not (0 <= y1 < grid.shape[0] and 0 <= x1 < grid.shape[1]):
  11.       break
  12.     if grid[y1, x1] == '#':
  13.       dy, dx = dx, -dy  # Rotate clockwise.
  14.     else:
  15.       y, x = y1, x1
  16.  
  17.   if not part2:
  18.     return count
  19.  
  20.   big = max(grid.shape)
  21.   jump_steps = np.empty((4, *grid.shape), np.int32)
  22.   for dir in range(4):
  23.     rotated_grid = np.rot90(grid, k=dir)
  24.     rotated_jump_steps = np.rot90(jump_steps[dir], k=dir)
  25.     for y, row in enumerate(rotated_grid):
  26.       num = big
  27.       for x, ch in enumerate(row):
  28.         num = -1 if ch == '#' else num + 1
  29.         rotated_jump_steps[y, x] = num
  30.  
  31.   grid[y0, x0] = '^'
  32.   dydx_from_dir = [(0, -1), (-1, 0), (0, 1), (1, 0)]
  33.   count = 0
  34.   for (obstacle_y, obstacle_x), ch in np.ndenumerate(grid):
  35.     if ch == 'X':  # Candidate obstacle locations must lie on the original path.
  36.       y, x, dir = y0, x0, 1
  37.       visited = set()
  38.       while True:
  39.         dy, dx = dydx_from_dir[dir]
  40.         steps = jump_steps[dir, y, x]
  41.         if obstacle_y == y:
  42.           d = obstacle_x - x
  43.           if d * dx > 0:
  44.             steps = min(steps, abs(d) - 1)
  45.         if obstacle_x == x:
  46.           d = obstacle_y - y
  47.           if d * dy > 0:
  48.             steps = min(steps, abs(d) - 1)
  49.         if steps >= big:
  50.           break
  51.         y, x = y + dy * steps, x + dx * steps
  52.         state = y, x, dir
  53.         if state in visited:
  54.           count += 1
  55.           break
  56.         visited.add(state)
  57.         dir = (dir + 1) % 4  # Rotate clockwise.
  58.  
  59.   return count
  60.  
  61.  
  62. def day6(s, part2=False):
  63.   grid = np.array([list(line) for line in s.splitlines()])
  64.   ((y0, x0),) = np.argwhere(grid == '^')
  65.   return day6_jit(grid, y0, x0, part2)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement