Advertisement
hhoppe

Advent of Code 2021, Day 19

Dec 19th, 2021
1,097
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.71 KB | None | 0 0
  1. def process1(s, part2=False):
  2.   scanners = []
  3.   for i, s2 in enumerate(s.strip().split('\n\n')):
  4.     lines = s2.split('\n')
  5.     array = np.array([list(map(int, line.split(','))) for line in lines[1:]])
  6.     scanners.append(array)
  7.  
  8.   scanner_transforms = [{i: np.eye(4, dtype=int)} for i in range(len(scanners))]
  9.   scanner_rep = list(range(len(scanners)))
  10.   TRANSFORMS = tuple(
  11.       np.array(matrix)
  12.       for diag in itertools.product([-1, 1], repeat=3)
  13.       for matrix in itertools.permutations(np.diag(diag))
  14.   )
  15.  
  16.   def compute_signatures():
  17.     all_signatures = []
  18.     for scanner in scanners:
  19.       n = len(scanner)
  20.       comb = itertools.chain.from_iterable(itertools.combinations(range(n), 2))
  21.       indices = np.fromiter(comb, dtype=int, count=n * (n - 1))
  22.       points = scanner[indices].reshape(-1, 2, 3)
  23.       diff = np.sort(abs(points[:, 1] - points[:, 0]), axis=-1)
  24.       signatures = dict(zip((tuple(e) for e in diff),
  25.                             (tuple(e) for e in indices.reshape(-1, 2))))
  26.       all_signatures.append(signatures)
  27.     return all_signatures
  28.  
  29.   all_signatures = compute_signatures()
  30.   all_signature_sets = [set(signature) for signature in all_signatures]
  31.   intersection_counts = [
  32.       (len(all_signature_sets[i] & all_signature_sets[j]), i, j)
  33.       for i, j in itertools.combinations(range(len(scanners)), 2)]
  34.  
  35.   for _, i, j in sorted(intersection_counts, reverse=True):
  36.     ir, jr = scanner_rep[i], scanner_rep[j]
  37.     if ir == jr:
  38.       continue  # Already joined.
  39.     intersection = all_signature_sets[i] & all_signature_sets[j]
  40.  
  41.     match_count = collections.defaultdict(lambda: collections.defaultdict(int))
  42.     for encoding, indices_i in all_signatures[i].items():
  43.       if encoding in intersection:
  44.         indices_j = all_signatures[j][encoding]
  45.         for index_i in indices_i:
  46.           for index_j in indices_j:
  47.             match_count[index_i][index_j] += 1
  48.  
  49.     index_mapping = {}
  50.     for index_i in range(len(scanners[i])):
  51.       max_count, index_j = max(
  52.           ((count, index) for index, count in match_count[index_i].items()),
  53.           default=(0, 0))
  54.       if max_count >= 11:  # Heuristically selected.
  55.         index_mapping[index_i] = index_j
  56.  
  57.     def get_transform4():
  58.       index_pairs = list(index_mapping.items())
  59.       index_i0, index_j0 = index_pairs[0]
  60.       for transform in TRANSFORMS:
  61.         offset = scanners[i][index_i0] - (transform @ scanners[j][index_j0])
  62.         for index_i, index_j in index_pairs[1:]:
  63.           transformed = (transform @ scanners[j][index_j]) + offset
  64.           if np.any(transformed != scanners[i][index_i]):
  65.             break
  66.         else:
  67.           return np.vstack((np.hstack((transform, offset[:, None])),
  68.                             [[0, 0, 0, 1]]))
  69.  
  70.     transform4 = get_transform4()  # To i from j.  (We want ir from jr.)
  71.     transform4 = scanner_transforms[ir][i] @ transform4 @ np.linalg.inv(
  72.         scanner_transforms[jr][j].astype(np.float32)).astype(int)
  73.  
  74.     points_ir = {tuple(point) for point in scanners[ir]}
  75.     new_points = []
  76.     for point in scanners[jr]:
  77.       transformed = (transform4 @ [*point, 1])[:3]
  78.       if tuple(transformed) not in points_ir:
  79.         new_points.append(transformed)
  80.     scanners[ir] = np.concatenate((scanners[ir], new_points), axis=0)
  81.  
  82.     for j2, transform2 in scanner_transforms[jr].items():
  83.       scanner_transforms[ir][j2] = transform4 @ transform2
  84.       scanner_rep[j2] = ir
  85.     scanner_rep[jr] = ir
  86.     last_merged = ir
  87.  
  88.   if not part2:
  89.     return len(scanners[last_merged])
  90.  
  91.   final_transforms = scanner_transforms[last_merged].values()
  92.   return max(abs(t1[:3, 3] - t2[:3, 3]).sum()
  93.              for t1, t2 in itertools.combinations(final_transforms, 2))
  94.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement