Advertisement
Not a member of Pastebin yet?
Sign up — it unlocks many cool features!
@numba.njit
def day20_jit(distance, ys, xs, ye, xe, part2, min_savings):
    """Count shortcuts ("cheats") between cells of a single-path track.

    `distance` is a 2-D integer array. On entry, cells equal to -1 are
    unvisited track; any other negative value is treated as a wall (the
    caller `day20` uses -2). The array is modified in place: starting at
    (ys, xs), the track is walked one step at a time to (ye, xe). Each
    visited cell is overwritten with its step count from the start.

    A cheat joins two visited cells whose Manhattan distance is at most
    `radius` (20 if `part2` else 2). Cells in between are ignored. Its
    saving is |difference of step counts| - Manhattan distance. Each
    unordered pair of cells is considered exactly once.

    Returns the number of pairs whose saving is >= `min_savings`.

    Raises ValueError if, before reaching (ye, xe), the walk finds no
    in-bounds unvisited track cell adjacent to the current one.
    """
    height, width = distance.shape

    # Walk the track. From each cell, the first adjacent unvisited track
    # cell (order: left, right, up, down) becomes the next step.
    d, y, x = 0, ys, xs
    distance[y, x] = d
    while (y, x) != (ye, xe):
        found = False
        ny, nx = y, x
        for cy, cx in ((y, x - 1), (y, x + 1), (y - 1, x), (y + 1, x)):
            # Bounds check so a track cell on the grid edge cannot wrap
            # around (negative index) or read past the array.
            if 0 <= cy < height and 0 <= cx < width and distance[cy, cx] == -1:
                ny, nx = cy, cx
                found = True
                break
        if not found:
            # Without this, the walk would step onto an arbitrary
            # neighbour and corrupt `distance`.
            raise ValueError('dead end: no unvisited track cell next to the current one')
        y, x = ny, nx
        d += 1
        distance[y, x] = d

    radius = 20 if part2 else 2
    count = 0
    for (y1, x1), d1 in np.ndenumerate(distance):
        if d1 < 0:
            continue
        # Only look at partners after (y1, x1) in row-major order: strictly
        # to the right on the same row, or anywhere in the diamond on later
        # rows. This counts each unordered pair once. The upper bounds are
        # the full array extent, so the last row/column are included.
        for y2 in range(y1, min(y1 + radius + 1, height)):
            radius_x = radius - abs(y2 - y1)
            x_lo = x1 - radius_x if y2 > y1 else x1 + 1
            for x2 in range(max(x_lo, 0), min(x1 + radius_x + 1, width)):
                d2 = distance[y2, x2]
                if d2 >= 0:
                    savings = abs(d2 - d1) - abs(y2 - y1) - abs(x2 - x1)
                    if savings >= min_savings:
                        count += 1
    return count
def day20(s, *, part2=False, min_savings=100):
    """Count cheats in the grid text *s* that save at least *min_savings*.

    *s* is a grid of characters with one row per line. '#' marks a wall, and
    exactly one 'S' and one 'E' mark the start and end of the track. Every
    other cell is treated as traversable.

    *part2* is forwarded to `day20_jit`, where it widens the cheat radius
    from 2 to 20.

    Raises ValueError unless the grid contains exactly one 'S' and exactly
    one 'E'. The tuple unpacking below enforces this.
    """
    rows = [list(row) for row in s.splitlines()]
    grid = np.array(rows)

    # Single-element unpacking: fails unless exactly one match exists.
    (start,) = np.argwhere(grid == 'S')
    start_y, start_x = start
    (end,) = np.argwhere(grid == 'E')
    end_y, end_x = end

    # Walls become -2; all other cells become -1, meaning "not reached yet".
    is_wall = grid == '#'
    distance = np.where(is_wall, -2, -1)
    return day20_jit(distance, start_y, start_x, end_y, end_x, part2, min_savings)
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement