Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
@numba.njit  # Optional speedup.
def day16_num_visited(grid, start):  # start == (y, x, dir); dir 0,1,2,3 == S,E,N,W.
    """Count the grid cells reached by a beam that enters at `start`.

    `grid` is indexed as grid[y, x] and holds character codes; the codes for
    '\\', '/', '-' and '|' redirect or split the beam, and any other code lets
    it pass straight through.  `start` is a (y, x, dir) triple giving the first
    cell and the heading (0=S, 1=E, 2=N, 3=W).

    Returns the number of cells entered with at least one heading.
    """
    num_rows, num_cols = grid.shape[0], grid.shape[1]
    # Per-heading lookup tables, indexed by S, E, N, W.
    row_step = (1, 0, -1, 0)
    col_step = (0, 1, 0, -1)
    backslash_turn = (1, 0, 3, 2)

    # seen[y, x, h] records that a beam already entered (y, x) with heading h,
    # which is what terminates beams caught in a cycle.
    seen = np.full((*grid.shape, 4), False)
    pending = [start]
    while pending:
        row, col, heading = pending.pop()
        while True:
            if row < 0 or row >= num_rows or col < 0 or col >= num_cols:
                break  # Beam left the grid.
            if seen[row, col, heading]:
                break  # This path was traced before.
            seen[row, col, heading] = True

            cell = grid[row, col]
            if cell == 92:  # ord('\\').
                heading = backslash_turn[heading]
            elif cell == 47:  # ord('/').
                heading = 3 - heading
            elif cell == 45:  # ord('-').
                if heading == 0 or heading == 2:
                    # A vertical beam splits: continue east now, go west later.
                    heading = 1
                    pending.append((row, col, 3))
            elif cell == 124:  # ord('|').
                if heading == 1 or heading == 3:
                    # A horizontal beam splits: continue south now, go north later.
                    heading = 0
                    pending.append((row, col, 2))

            row += row_step[heading]
            col += col_step[heading]

    per_cell_hits = seen.sum(2)
    return (per_cell_hits > 0).sum()  # seen.any(2).sum() if not numba.
def day16(s, *, part2=False):
    """Return the visited-cell count for the grid described by text `s`.

    Each line of `s` becomes one grid row of character codes.  By default the
    beam enters at the top-left cell heading east.  With `part2=True`, every
    border cell is tried as an entry point (heading into the grid) and the
    largest count is returned.
    """
    rows = [list(map(ord, line)) for line in s.splitlines()]
    grid = np.array(rows, np.uint8)

    if part2:
        height, width = grid.shape[0], grid.shape[1]
        # Headings: 0=S, 1=E, 2=N, 3=W; each edge shoots inward.
        from_top = [(0, x, 0) for x in range(width)]
        from_bottom = [(height - 1, x, 2) for x in range(width)]
        from_left = [(y, 0, 1) for y in range(height)]
        from_right = [(y, width - 1, 3) for y in range(height)]
        entries = [*from_top, *from_bottom, *from_left, *from_right]
        counts = (day16_num_visited(grid, entry) for entry in entries)
        return max(counts)

    return day16_num_visited(grid, (0, 0, 1))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement