Set parent policy with SetFactory
I'm trying to define a combinatorial class (say, `Potato`), and I want my set factory to build each potato using a subclass chosen from its attributes (e.g. a `Potato` with `colour='yellow'` should be created as an instance of the subclass `YellowPotato`).
This works when I run `Potatoes(3)`: the elements are built with the appropriate subclass.
But if I run `Potatoes(3, colour='yellow')`, I get elements of the `Potato` class instead of `YellowPotato`.
Which policy should I set to fix this?
I would have attached a minimal example, but I can't attach a file (or post links, apparently), so here it is inline.
from sage.combinat.permutation import Permutations
from sage.categories.finite_enumerated_sets import FiniteEnumeratedSets
from sage.misc.all import lazy_attribute, lazy_class_attribute
from sage.structure.dynamic_class import DynamicInheritComparisonClasscallMetaclass
from sage.structure.unique_representation import UniqueRepresentation
from sage.structure.set_factories import SelfParentPolicy, SetFactory, ParentWithSetFactory
from six import add_metaclass
from sage.structure.list_clone import ClonableIntArray # type: ignore
from sage.structure.unique_representation import UniqueRepresentation
from sage.sets.disjoint_union_enumerated_sets import DisjointUnionEnumeratedSets
from sage.rings.all import Integer
from sage.sets.family import Family
from sage.sets.positive_integers import PositiveIntegers
@add_metaclass(DynamicInheritComparisonClasscallMetaclass)
class Potato(ClonableIntArray):
    """
    Element class for potatoes: a ``ClonableIntArray`` carrying an optional
    colour.

    NOTE(review): direct ``Potato(...)`` calls are presumably routed through
    ``__classcall_private__`` by the metaclass -- confirm against the Sage
    classcall documentation.
    """

    @staticmethod
    def __classcall_private__(cls, *args, **kwargs):
        """Forward all arguments to the element constructor of ``cls._auto_parent``."""
        default_parent = cls._auto_parent
        return default_parent._element_constructor_(*args, **kwargs)

    @lazy_class_attribute
    def _auto_parent(cls):
        """Return the parent obtained by calling ``Potatoes()`` with no arguments."""
        return Potatoes()

    def __init__(self, parent, potato, colour=None):
        """
        Initialise the underlying array, then record the data and the colour.

        INPUT:

        - ``parent`` -- the parent this element belongs to
        - ``potato`` -- the data handed to ``ClonableIntArray.__init__``
        - ``colour`` -- optional colour, ``None`` by default
        """
        ClonableIntArray.__init__(self, parent, potato)
        self.colour = colour
        self.potato = potato

    def _repr_(self):
        """
        Return ``<Element name>(<potato>)``, with ``, colour=<colour>``
        inserted before the closing parenthesis when a colour is set.

        The class name shown is that of the parent's ``Element`` attribute,
        not ``type(self)``.
        """
        head = f'{self.parent().Element.__name__}({self.potato}'
        colour_part = '' if self.colour is None else f', colour={self.colour}'
        return head + colour_part + ')'

    def check(self):
        """Accept every potato: no validation is performed."""
class YellowPotato(Potato):
    """A potato whose colour is always ``'yellow'``."""

    def __init__(self, parent, potato, colour=None):
        """
        Initialise a yellow potato.

        INPUT:

        - ``parent`` -- the parent this element belongs to
        - ``potato`` -- the underlying data, passed on to ``Potato.__init__``
        - ``colour`` -- optional; ``None`` (default) or ``'yellow'``

        The ``colour`` argument keeps this constructor compatible with the
        three-argument form ``(parent, potato, colour)`` accepted by
        ``Potato.__init__`` and used by ``Potatoes_size_yellow.__iter__``.

        Raises ``ValueError`` if ``colour`` is given and is not ``'yellow'``.
        """
        if colour is not None and colour != 'yellow':
            raise ValueError(
                f"a YellowPotato must have colour 'yellow', got {colour!r}"
            )
        super().__init__(parent, potato, 'yellow')
class RedPotato(Potato):
    """A potato whose colour is always ``'red'``."""

    def __init__(self, parent, potato, colour=None):
        """
        Initialise a red potato.

        INPUT:

        - ``parent`` -- the parent this element belongs to
        - ``potato`` -- the underlying data, passed on to ``Potato.__init__``
        - ``colour`` -- optional; ``None`` (default) or ``'red'``

        The ``colour`` argument keeps this constructor compatible with the
        three-argument form ``(parent, potato, colour)`` accepted by
        ``Potato.__init__`` and used by ``Potatoes_size_red.__iter__``.

        Raises ``ValueError`` if ``colour`` is given and is not ``'red'``.
        """
        if colour is not None and colour != 'red':
            raise ValueError(
                f"a RedPotato must have colour 'red', got {colour!r}"
            )
        super().__init__(parent, potato, 'red')
class PotatoesFactory(SetFactory):
    """
    Set factory returning a potato parent chosen from ``size`` and ``colour``.
    """
    Element = Potato

    def __call__(self, size=None, colour=None, policy=None):
        """
        Return the parent matching the given constraints.

        INPUT:

        - ``size`` -- ``None`` or an integer (``Integer`` or ``int``)
        - ``colour`` -- optional string; only consulted when ``size`` is an
          integer
        - ``policy`` -- optional policy; ``self._default_policy`` when ``None``

        OUTPUT:

        - ``Potatoes_all(policy)`` when ``size`` is ``None``
        - ``_potatoes_size_colour(size, colour, policy)`` when ``size`` is an
          integer and ``colour`` is a string
        - ``Potatoes_size(size, policy)`` when ``size`` is an integer and
          ``colour`` is not a string

        Raises ``ValueError`` when ``size`` is neither ``None`` nor an integer.
        """
        if policy is None:
            policy = self._default_policy
        if size is None:
            return Potatoes_all(policy)
        if not isinstance(size, (Integer, int)):
            # Previously the exception was constructed but never raised, so
            # an invalid size silently produced ``None``.
            raise ValueError('Invalid size.')
        if isinstance(colour, str):
            return _potatoes_size_colour(size, colour, policy)
        return Potatoes_size(size, policy)

    def add_constraints(self, constraints, options):
        """
        Merge ``options`` into ``constraints`` and return the result as a tuple.

        INPUT:

        - ``constraints`` -- sequence of existing constraints; padded with
          ``None`` up to length 2 (slot 0 is the size, slot 1 the colour)
        - ``options`` -- a pair ``(args, kwargs)``; positional values override
          the slot at the same position, then the keywords ``size`` and
          ``colour`` override slots 0 and 1
        """
        merged = list(constraints) + [None] * (2 - len(constraints))
        args, kwargs = options
        for position, value in enumerate(args):
            if value != merged[position]:
                merged[position] = value
        if 'size' in kwargs:
            merged[0] = kwargs['size']
        if 'colour' in kwargs:
            merged[1] = kwargs['colour']
        return tuple(merged)

    @lazy_attribute
    def _default_policy(self):
        """Return a ``SelfParentPolicy`` built from this factory and ``self.Element``."""
        return SelfParentPolicy(self, self.Element)
# Module-level factory instance; ``Potato._auto_parent`` calls it with no
# arguments to obtain its parent.
Potatoes = PotatoesFactory()
class Potatoes_all(ParentWithSetFactory, DisjointUnionEnumeratedSets):
    """
    Parent of all potatoes: the facade disjoint union of ``Potatoes_size(n)``
    for ``n`` ranging over ``PositiveIntegers()``.
    """

    def __init__(self, policy):
        """
        Initialise the parent with empty constraints and the given ``policy``.

        The union is indexed by all positive integers, so the parent is
        placed in the category of infinite enumerated sets.
        """
        # Imported locally so that this block does not depend on a change to
        # the module's import section.
        from sage.categories.infinite_enumerated_sets import (
            InfiniteEnumeratedSets,
        )

        ParentWithSetFactory.__init__(
            self, (), policy, category=InfiniteEnumeratedSets()
        )
        DisjointUnionEnumeratedSets.__init__(
            self, Family(
                PositiveIntegers(),
                # Each component is built with this parent's facade policy.
                lambda n: Potatoes_size(
                    n, policy=self.facade_policy()
                )
            ),
            facade=True, keepkey=False, category=self.category()
        )

    def _repr_(self):
        """Return the string ``'Potatoes'``."""
        return 'Potatoes'

    def check_element(self, el, check):
        """Accept every element: no validation is performed."""
        pass
class Potatoes_size(ParentWithSetFactory, DisjointUnionEnumeratedSets):
    """
    Parent of the potatoes of a fixed size: the facade disjoint union of the
    yellow and the red potatoes of that size.
    """

    def __init__(self, size, policy):
        """
        Initialise the parent with constraints ``(size,)`` and ``policy``.

        INPUT:

        - ``size`` -- the common size of the potatoes
        - ``policy`` -- the policy handed to ``ParentWithSetFactory``
        """
        self._size = size
        ParentWithSetFactory.__init__(
            self,
            (size,),
            policy,
            category=FiniteEnumeratedSets(),
        )
        # Kept after the ParentWithSetFactory initialisation above: the
        # callback asks this parent for its facade policy.
        build_component = lambda colour: _potatoes_size_colour(
            size, colour, policy=self.facade_policy()
        )
        components = Family(('yellow', 'red'), build_component)
        DisjointUnionEnumeratedSets.__init__(
            self,
            components,
            facade=True,
            keepkey=False,
            category=self.category(),
        )

    def _repr_(self):
        """Return ``'Potatoes of size <size>'``."""
        description = f'Potatoes of size {self._size}'
        return description
class Potatoes_size_red(ParentWithSetFactory, UniqueRepresentation):
    """Parent of the red potatoes of a fixed size."""
    Element = RedPotato

    def __init__(self, size, policy):
        """
        Initialise the parent with constraints ``(size, 'red')`` and ``policy``.
        """
        self._size = size
        ParentWithSetFactory.__init__(
            self, (size, 'red'), policy, category=FiniteEnumeratedSets()
        )

    def __iter__(self):
        """
        Yield one element per permutation in ``Permutations(self._size)``.

        The colour is passed as a third positional argument, so whichever
        class ``element_class`` resolves to must accept it.
        """
        for x in Permutations(self._size):
            yield self.element_class(self, x, 'red')

    def _repr_(self):
        """
        Return ``'Red potatoes of size <size>'``.

        Named ``_repr_`` (not ``__repr__``) to match the other parents in
        this file and the Sage convention for string representations.
        """
        return f'Red potatoes of size {self._size}'
class Potatoes_size_yellow(ParentWithSetFactory, UniqueRepresentation):
    """Parent of the yellow potatoes of a fixed size."""
    Element = YellowPotato

    def __init__(self, size, policy):
        """
        Initialise the parent with constraints ``(size, 'yellow')`` and
        ``policy``.
        """
        self._size = size
        ParentWithSetFactory.__init__(
            self, (size, 'yellow'), policy, category=FiniteEnumeratedSets()
        )

    def __iter__(self):
        """
        Yield one element per permutation in ``Permutations(self._size)``.

        The colour is passed as a third positional argument, so whichever
        class ``element_class`` resolves to must accept it.
        """
        for x in Permutations(self._size):
            yield self.element_class(self, x, 'yellow')

    def _repr_(self):
        """
        Return ``'Yellow potatoes of size <size>'``.

        Named ``_repr_`` (not ``__repr__``) to match the other parents in
        this file and the Sage convention for string representations.
        """
        return f'Yellow potatoes of size {self._size}'
def _potatoes_size_colour(size, colour, policy):
if colour == 'red':
return Potatoes_size_red(size, policy)
elif colour == 'yellow':
return Potatoes_size_yellow(size, policy)