Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- # This is a simplified version of the set partition class in the discrete helpers library.
- # https://github.com/watchduck/discrete_helpers/tree/main/discretehelpers/set_part
- # https://codereview.stackexchange.com/questions/292376/find-the-join-of-two-set-partitions-can-the-calculation-be-improved
- from functools import cached_property
- from itertools import combinations, product, chain
def have(val):
    """
    :param val: Any object.
    :return: False if `val` is None, otherwise True. Falsy values like 0 or '' still count as present.
    """
    if val is None:
        return False
    return True
- ########################################################################################################################
class SetPart(object):
    """
    Partition of a set, stored as the sorted list of its non-singleton blocks.
    Elements that appear in no block are implicitly singletons.

    Attributes set by the constructor:
        domain: The letter 'N' or 'Z', or a set of allowed elements.
        blocks: Sorted list of sorted lists, each with at least two elements.
        non_singleton_to_block_index: Maps each element of a block to the index of that block in `blocks`.
        non_singletons: Set of all elements that are in some block.
        trivial: True if there are no blocks, otherwise False.
        length: Only for domain 'N'. Biggest non-singleton plus one (0 for the trivial partition).
    """

    # Instances are mutable (see `merge_pair`) and define `__eq__`, so they are explicitly unhashable.
    __hash__ = None

    def __init__(self, blocks=None, domain='N'):
        """
        :param blocks: List of lists. Each block is a list with at least two elements. The blocks do not intersect.
                       Blocks with fewer than two elements are ignored.
        :param domain: Set of allowed elements of blocks. Can be a finite set. Elements are usually integers.
                       By default, the domain is the set of non-negative integers.
        """
        # `merge_pair` re-runs this constructor on an existing instance,
        # so a previously cached `pairs` value would be stale and must be dropped.
        self.__dict__.pop('pairs', None)

        if isinstance(domain, str) and domain in ('N', 'Z'):
            self.domain = domain  # keep the letters
        else:
            assert isinstance(domain, (set, frozenset, list, tuple, range))
            self.domain = set(domain)

        if blocks is None:
            self.set_trivial()
            return

        blocks = sorted(sorted(block) for block in blocks if len(block) > 1)
        if not blocks:
            self.set_trivial()
            return

        self.blocks = blocks
        self.non_singleton_to_block_index = {
            element: block_index
            for block_index, block in enumerate(blocks)
            for element in block
        }
        self.non_singletons = set(self.non_singleton_to_block_index)

        # Fails when some element occurs more than once, i.e. when blocks intersect.
        assert len(self.non_singletons) == sum(len(block) for block in blocks)

        if self.domain == 'N':
            self.length = max(self.non_singletons) + 1
        self.trivial = False

    def set_trivial(self):
        """
        Turn this into the partition without blocks. Expects `self.domain` to be set already.
        """
        self.trivial = True
        self.blocks = []
        self.non_singleton_to_block_index = dict()
        self.non_singletons = set()
        if self.domain == 'N':
            self.length = 0

    def __repr__(self):
        return f'{type(self).__name__}({self.blocks!r}, domain={self.domain!r})'

    def __eq__(self, other):
        """
        Two partitions are equal when they have the same blocks. The domains are not compared.
        """
        if not isinstance(other, SetPart):
            return NotImplemented
        return self.blocks == other.blocks

    def element_in_domain(self, element):
        """
        :param element: Any object.
        :return: True if `element` belongs to the domain of this partition, otherwise False.
        """
        if self.domain == 'N':
            # exact type check: subclasses of int (like bool) are rejected
            return type(element) is int and element >= 0
        if self.domain == 'Z':
            return type(element) is int
        return element in self.domain

    def merge_pair(self, a, b):
        """
        When elements `a` and `b` are in different blocks, both blocks will be merged.
        A singleton is treated as a block that is not stored.
        Changes the partition. Returns nothing.
        """
        if a == b:
            return  # nothing to do
        assert self.element_in_domain(a) and self.element_in_domain(b)

        block_index_a = self.non_singleton_to_block_index.get(a)  # None for a singleton
        block_index_b = self.non_singleton_to_block_index.get(b)
        if block_index_a is not None and block_index_a == block_index_b:
            return  # nothing to do

        # Indices of the (zero, one or two) stored blocks that get absorbed into the merged block.
        absorbed = {index for index in (block_index_a, block_index_b) if index is not None}
        merged_block = {a, b}
        for index in absorbed:
            merged_block.update(self.blocks[index])
        kept_blocks = [block for index, block in enumerate(self.blocks) if index not in absorbed]

        # Reinitialize: sorts the blocks, rebuilds the lookup tables and drops the cached `pairs`.
        self.__init__(kept_blocks + [list(merged_block)], self.domain)

    def blocks_with_singletons(self, elements=None):
        """
        :param elements: Any subset of the domain that contains all non-singletons.
                         When omitted, `range(self.length)` is used for domain 'N', and the whole domain
                         for a finite domain. For domain 'Z' there is no finite default, so it is required.
        :return: Blocks with added singleton-blocks for each element in `elements` that is not in an actual block.
        """
        if elements is None:
            if self.domain == 'N':
                elements = range(self.length)
            else:
                assert self.domain != 'Z', 'elements must be given for the infinite domain Z'
                elements = self.domain
        assert isinstance(elements, (set, frozenset, list, tuple, range))
        elements = set(elements)
        assert self.non_singletons.issubset(elements)
        singleton_blocks = [[element] for element in elements - self.non_singletons]
        return sorted(self.blocks + singleton_blocks)

    @cached_property
    def pairs(self):
        """
        :return: For each block the set of all its 2-element subsets. All those in one set.
                 Each pair is a tuple with the smaller element first.
                 Slow for big blocks, because the number of pairs grows quadratically with the block size.
        """
        return {pair for block in self.blocks for pair in combinations(block, 2)}

    def join_pairs(self, other):
        """
        :param other: another set partition
        :return: The join of the two set partitions.
        This method uses the property `pairs`, so it is also slow for big blocks.
        """
        assert self.domain == other.domain
        result = SetPart([], self.domain)
        for pair in self.pairs | other.pairs:
            result.merge_pair(*pair)
        return result

    def meet(self, other):
        """
        :param other: another set partition
        :return: The meet of the two set partitions.
        Let M be the meet of partitions A and B. The blocks of M are the intersections of the blocks of A and B.
        """
        # Two elements share a block of the meet iff they share a block in both partitions,
        # so elements are grouped by their pair of block indices. Elements that are a singleton
        # in either partition are singletons in the meet and can be skipped.
        other_block_index = other.non_singleton_to_block_index
        intersections = dict()
        for element, s_index in self.non_singleton_to_block_index.items():
            o_index = other_block_index.get(element)
            if o_index is not None:
                intersections.setdefault((s_index, o_index), []).append(element)
        # one-element intersections are dropped by the constructor
        return SetPart(list(intersections.values()), self.domain)

    def join(self, other):
        """
        :param other: another set partition
        :return: The join of the two set partitions. It has the domain of `self`.
        The blocks of the join are the connected components that arise when overlapping blocks
        of both partitions are united. They are found with a disjoint-set (union-find) structure,
        so no pairs are enumerated, and the result does not depend on the kind of domain.
        """
        parent = dict()  # element -> parent element; a root is its own parent

        def find(element):
            root = element
            while parent[root] != root:
                root = parent[root]
            # path compression: link every element on the walked path directly to the root
            while parent[element] != root:
                parent[element], element = root, parent[element]
            return root

        for block in chain(self.blocks, other.blocks):
            block_root = None
            for element in block:
                parent.setdefault(element, element)
                element_root = find(element)
                if block_root is None:
                    block_root = element_root
                elif element_root != block_root:
                    parent[element_root] = block_root

        join_blocks = dict()
        for element in parent:
            join_blocks.setdefault(find(element), []).append(element)
        return SetPart(list(join_blocks.values()), domain=self.domain)
- ########################################################################################################################
# Self-checks: pairs, padding with singletons, and merging in place.
p = SetPart([[1, 2], [7, 8, 9]])
assert {(1, 2), (7, 8), (7, 9), (8, 9)} == p.pairs
assert [[1, 2], [5], [7, 8, 9]] == p.blocks_with_singletons({1, 2, 5, 7, 8, 9})

# 5 and 6 are both singletons so far, so merging them creates a new block.
p.merge_pair(5, 6)
assert SetPart([[1, 2], [5, 6], [7, 8, 9]]) == p

# 2 and 5 are in different blocks, so merging them fuses the two blocks.
p.merge_pair(2, 5)
assert SetPart([[1, 2, 5, 6], [7, 8, 9]]) == p

# Self-checks: meet and join. Both join implementations must agree.
a = SetPart([[0, 1, 2, 4], [5, 6, 9], [7, 8]])
b = SetPart([[0, 1], [2, 3, 4], [6, 8, 9]])
assert SetPart([[0, 1], [2, 4], [6, 9]]) == a.meet(b)
assert (
    a.join(b)
    == a.join_pairs(b)
    == SetPart([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement