Advertisement
mate2code

simplified SetPart class

Jun 3rd, 2024 (edited)
500
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 8.59 KB | None | 0 0
  1. # This is a simplified version of the set partition class in the discrete helpers library.
  2. # https://github.com/watchduck/discrete_helpers/tree/main/discretehelpers/set_part
  3. # https://codereview.stackexchange.com/questions/292376/find-the-join-of-two-set-partitions-can-the-calculation-be-improved
  4.  
  5. from functools import cached_property
  6. from itertools import combinations, product, chain
  7.  
  8.  
  9. def have(val):
  10.     return val is not None
  11.  
  12.  
  13. ########################################################################################################################
  14.  
  15.  
  16. class SetPart(object):
  17.  
  18.     def __init__(self, blocks=None, domain='N'):
  19.  
  20.         """
  21.        :param blocks: List of lists. Each block is a list with at least two elements. The blocks do not intersect.
  22.        :param domain: Set of allowed elements of blocks. Can be a finite set. Elements are usually integers.
  23.                       By default, the domain is the set of non-negative integers.
  24.        """
  25.  
  26.         if domain not in ['N', 'Z']:
  27.             assert type(domain) in [set, list, tuple, range]
  28.             self.domain = set(domain)
  29.         else:
  30.             self.domain = domain  # keep the letters
  31.  
  32.         if blocks is None:
  33.             self.set_trivial()
  34.             return
  35.  
  36.         blocks = sorted(sorted(block) for block in blocks if len(block) > 1)
  37.         if not blocks:
  38.             self.set_trivial()
  39.             return
  40.  
  41.         self.blocks = blocks
  42.  
  43.         _ = dict()
  44.         for block_index, block in enumerate(self.blocks):
  45.             for element in block:
  46.                 _[element] = block_index
  47.         self.non_singleton_to_block_index = _
  48.  
  49.         self.non_singletons = set(self.non_singleton_to_block_index.keys())
  50.         assert len(self.non_singletons) == sum([len(block) for block in self.blocks])
  51.  
  52.         if self.domain == 'N':
  53.             self.length = max(self.non_singletons) + 1
  54.  
  55.         self.trivial = False
  56.  
  57.     ##############################################################################################
  58.  
  59.     def set_trivial(self):
  60.         self.trivial = True
  61.         self.blocks = []
  62.         self.non_singleton_to_block_index = dict()
  63.         self.non_singletons = set()
  64.  
  65.         if self.domain == 'N':
  66.             self.length = 0
  67.  
  68.     ###############################################
  69.  
  70.     def __eq__(self, other):
  71.         return self.blocks == other.blocks
  72.  
  73.     ###############################################
  74.  
  75.     def element_in_domain(self, element):
  76.         if self.domain == 'N':
  77.             return type(element) == int and element >= 0
  78.         elif self.domain == 'Z':
  79.             return type(element) == int
  80.         else:
  81.             return element in self.domain
  82.  
  83.     ###############################################
  84.  
  85.     def merge_pair(self, a, b):
  86.  
  87.         """
  88.        When elements `a` and `b` are in different blocks, both blocks will be merged.
  89.        Changes the partition. Returns nothing.
  90.        """
  91.  
  92.         if a == b:
  93.             return  # nothing to do
  94.  
  95.         assert self.element_in_domain(a) and self.element_in_domain(b)
  96.  
  97.         a_found = a in self.non_singletons
  98.         b_found = b in self.non_singletons
  99.  
  100.         if a_found and b_found:
  101.  
  102.             block_index_a = self.non_singleton_to_block_index[a]
  103.             block_index_b = self.non_singleton_to_block_index[b]
  104.  
  105.             if block_index_a == block_index_b:
  106.                 return  # nothing to do
  107.  
  108.             block_a = self.blocks[block_index_a]
  109.             block_b = self.blocks[block_index_b]
  110.             merged_block = sorted(block_a + block_b)
  111.  
  112.             self.blocks.remove(block_a)
  113.             self.blocks.remove(block_b)
  114.             self.blocks.append(merged_block)
  115.  
  116.         elif not a_found and not b_found:
  117.             self.blocks.append([a, b])
  118.  
  119.         else:  # a_found and not b_found
  120.             if b_found and not a_found:
  121.                 a, b = b, a
  122.             block_index_a = self.non_singleton_to_block_index[a]
  123.             self.blocks[block_index_a].append(b)
  124.  
  125.         self.__init__(self.blocks, self.domain)  # reinitialize
  126.  
  127.     ###############################################
  128.  
  129.     def blocks_with_singletons(self, elements=None):
  130.         """
  131.        :param elements: Any subset of the domain.
  132.        :return: Blocks with added singleton-blocks for each element in `elements` that is not in an actual block.
  133.        """
  134.         assert type(elements) in [set, list, range]
  135.         assert self.non_singletons.issubset(set(elements))
  136.         singletons = set(elements).difference(self.non_singletons)
  137.         singleton_blocks = [[_] for _ in singletons]
  138.         return sorted(self.blocks + singleton_blocks)
  139.  
  140.     ##############################################################################################
  141.  
  142.     @cached_property
  143.     def pairs(self):
  144.         """
  145.        :return: For each block the set of all is 2-element subsets. All those in one set.
  146.        Slow for big blocks, because of factorial growth.
  147.        """
  148.         result = set()
  149.         for block in self.blocks:
  150.             for pair in combinations(block, 2):
  151.                 result.add(pair)
  152.         return result
  153.  
  154.     ###############################################
  155.  
  156.     def join_pairs(self, other):
  157.         """
  158.        :param other: another set partition
  159.        :return: The join of the two set partitions.
  160.        This method uses the property `pairs`, so it is also slow for big blocks.
  161.        """
  162.         assert self.domain == other.domain
  163.         result = SetPart([], self.domain)
  164.         for pair in self.pairs.union(other.pairs):
  165.             result.merge_pair(*pair)
  166.         return result
  167.  
  168.     ###############################################
  169.  
  170.     def meet(self, other):
  171.         """
  172.        :param other: another set partition
  173.        :return: The meet of the two set partitions.
  174.        Let M be the meet of partitions A and B. The blocks of M are the intersections of the blocks of A and B.
  175.        """
  176.         meet_blocks = []
  177.         for s_block, o_block in product(self.blocks, other.blocks):
  178.             intersection = set(s_block) & set(o_block)
  179.             if intersection:
  180.                 meet_blocks.append(sorted(intersection))
  181.         return SetPart(meet_blocks, self.domain)
  182.  
  183.     ###############################################
  184.  
  185.     def join(self, other):
  186.         """
  187.        :param other: another set partition
  188.        :return: The join of the two set partitions.
  189.        This method uses the method `join_pairs`.
  190.        The danger of factorial growth is reduced, by making the input partitions smaller.
  191.        """
  192.         meet_part = self.meet(other)
  193.  
  194.         trash = set()
  195.         rep_to_trash = dict()
  196.         for block in meet_part.blocks:
  197.             block_rep = min(block)
  198.             block_trash = set(block) - {block_rep}
  199.             trash |= block_trash
  200.             rep_to_trash[block_rep] = block_trash
  201.  
  202.         clean_s_blocks = [sorted(set(block) - trash) for block in self.blocks]
  203.         clean_o_blocks = [sorted(set(block) - trash) for block in other.blocks]
  204.  
  205.         clean_s_part = SetPart(clean_s_blocks)
  206.         clean_o_part = SetPart(clean_o_blocks)
  207.  
  208.         clean_join_part = clean_s_part.join_pairs(clean_o_part)
  209.  
  210.         s_elements = set(chain.from_iterable(self.blocks))
  211.         o_elements = set(chain.from_iterable(other.blocks))
  212.         dirty_domain = s_elements | o_elements
  213.         clean_domain = dirty_domain - trash
  214.  
  215.         dirty_join_blocks = []
  216.         clean_blocks_with_singletons = clean_join_part.blocks_with_singletons(elements=clean_domain)
  217.         for clean_block in clean_blocks_with_singletons:
  218.             dirty_block = set(clean_block)
  219.             for element in clean_block:
  220.                 if element in rep_to_trash:
  221.                     dirty_block |= rep_to_trash[element]
  222.             dirty_join_blocks.append(sorted(dirty_block))
  223.  
  224.         return SetPart(dirty_join_blocks, domain=self.domain)
  225.  
  226.  
  227. ########################################################################################################################
  228.  
  229. p = SetPart([[1, 2], [7, 8, 9]])
  230.  
  231. assert p.pairs == {(1, 2), (7, 8), (7, 9), (8, 9)}
  232.  
  233. assert p.blocks_with_singletons({1, 2, 5, 7, 8, 9}) == [[1, 2], [5], [7, 8, 9]]
  234.  
  235. p.merge_pair(5, 6)
  236. assert p == SetPart([[1, 2], [5, 6], [7, 8, 9]])
  237. p.merge_pair(2, 5)
  238. assert p == SetPart([[1, 2, 5, 6], [7, 8, 9]])
  239.  
  240. ###############################################
  241.  
  242. a = SetPart([[0, 1, 2, 4], [5, 6, 9], [7, 8]])
  243. b = SetPart([[0, 1], [2, 3, 4], [6, 8, 9]])
  244.  
  245. assert a.meet(b) == SetPart([[0, 1], [2, 4], [6, 9]])
  246. assert a.join(b) == a.join_pairs(b) == SetPart([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
  247.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement