Set parent policy with SetFactory

asked 2020-11-29 11:50:41 -0600

SashaIr gravatar image

I'm trying to define a combinatorial class (e.g. Potato), and I want my set factory to build potatoes using a subclass that depends on an attribute of the instance (e.g. a Potato with colour='yellow' should be created as an instance of a subclass YellowPotato). Now, this works when I run Potatoes(3): the elements are built with the appropriate subclass. But if I run Potatoes(3, colour='yellow'), then I get elements of the Potato class instead of YellowPotato. Which policy should I set to fix this?

I would have attached a minimal example, but I can't attach a file (or send links, apparently), so here it is.

from sage.combinat.permutation import Permutations
from sage.categories.finite_enumerated_sets import FiniteEnumeratedSets
from sage.misc.all import lazy_attribute, lazy_class_attribute
from sage.structure.dynamic_class import DynamicInheritComparisonClasscallMetaclass
from sage.structure.unique_representation import UniqueRepresentation
from sage.structure.set_factories import SelfParentPolicy, SetFactory, ParentWithSetFactory
from six import add_metaclass
from sage.structure.list_clone import ClonableIntArray  # type: ignore
from sage.structure.unique_representation import UniqueRepresentation
from sage.sets.disjoint_union_enumerated_sets import DisjointUnionEnumeratedSets
from sage.rings.all import Integer
from sage.sets.family import Family
from sage.sets.positive_integers import PositiveIntegers


@add_metaclass(DynamicInheritComparisonClasscallMetaclass)
class Potato(ClonableIntArray):
    """Element class for potatoes, built on ``ClonableIntArray``.

    Besides the array data, each instance stores the raw ``potato`` value
    and an optional ``colour``.
    """

    @staticmethod
    def __classcall_private__(cls, *args, **kwargs):
        """Delegate construction to the element constructor of ``cls._auto_parent``."""
        auto_parent = cls._auto_parent
        return auto_parent._element_constructor_(*args, **kwargs)

    @lazy_class_attribute
    def _auto_parent(cls):
        """Return the parent used for direct construction: ``Potatoes()``."""
        return Potatoes()

    def __init__(self, parent, potato, colour=None):
        """Initialise the array from ``potato`` and record ``potato`` and ``colour``.

        :param parent: parent handed to ``ClonableIntArray.__init__``
        :param potato: data handed to ``ClonableIntArray.__init__``
        :param colour: optional colour; ``None`` means no colour
        """
        # Explicit base-class call (not ``super()``), as in the original.
        ClonableIntArray.__init__(self, parent, potato)
        self.potato = potato
        self.colour = colour

    def _repr_(self):
        """Return ``Name(potato)`` or ``Name(potato, colour=...)``.

        ``Name`` is taken from the parent's ``Element`` attribute, not from
        the type of ``self``.
        """
        class_name = self.parent().Element.__name__
        if self.colour is None:
            return f'{class_name}({self.potato})'
        return f'{class_name}({self.potato}, colour={self.colour})'

    def check(self):
        """Perform no validation."""
        return None


class YellowPotato(Potato):
    """A ``Potato`` whose colour is fixed to ``'yellow'``."""

    def __init__(self, parent, potato):
        """Initialise as a ``Potato`` with ``colour='yellow'``."""
        super().__init__(parent, potato, colour='yellow')


class RedPotato(Potato):
    """A ``Potato`` whose colour is fixed to ``'red'``."""

    def __init__(self, parent, potato):
        """Initialise as a ``Potato`` with ``colour='red'``."""
        super().__init__(parent, potato, colour='red')


class PotatoesFactory(SetFactory):
    """Set factory dispatching to the parents of potatoes.

    Calling the factory returns ``Potatoes_all``, ``Potatoes_size`` or a
    colour-specific parent, depending on ``size`` and ``colour``.
    """
    Element = Potato

    def __call__(self, size=None, colour=None, policy=None):
        """Return the parent matching the given constraints.

        :param size: ``None`` for all potatoes, or an ``Integer``/``int``
        :param colour: if a ``str`` (and ``size`` is an integer), select the
            parent for that colour; ignored when ``size`` is ``None``
        :param policy: parent policy; defaults to ``self._default_policy``
        :raises ValueError: if ``size`` is neither ``None`` nor an integer
        """
        if policy is None:
            policy = self._default_policy

        if size is None:
            return Potatoes_all(policy)
        if not isinstance(size, (Integer, int)):
            # The original built this exception without raising it, so an
            # invalid size silently returned ``None``.
            raise ValueError('Invalid size.')
        if isinstance(colour, str):
            return _potatoes_size_colour(size, colour, policy)
        return Potatoes_size(size, policy)

    def add_constraints(self, constraints, options):
        """Merge ``options`` into ``constraints`` and return a ``(size, colour)`` tuple.

        :param constraints: existing constraints, at most two entries; missing
            entries are padded with ``None``
        :param options: a pair ``(args, kwargs)``; positional entries override
            constraints by position, then the keywords ``size`` and ``colour``
            override slots 0 and 1
        """
        constraints = list(constraints) + [None]*(2-len(constraints))
        args, kwargs = options

        for i, arg in enumerate(args):
            if arg != constraints[i]:
                constraints[i] = arg

        if 'size' in kwargs:
            constraints[0] = kwargs['size']
        if 'colour' in kwargs:
            constraints[1] = kwargs['colour']

        return tuple(constraints)

    @lazy_attribute
    def _default_policy(self):
        """Return the policy used when the caller supplies none.

        NOTE(review): the policy is built from ``self.Element`` (``Potato``),
        not from the colour-specific element classes; presumably this is why
        ``Potatoes(3, colour='yellow')`` yields plain ``Potato`` elements --
        confirm against the ``SelfParentPolicy`` documentation.
        """
        return SelfParentPolicy(self, self.Element)


# Module-level factory instance. Calling it (e.g. ``Potatoes()``, as
# ``Potato._auto_parent`` does) goes through ``PotatoesFactory.__call__``.
Potatoes = PotatoesFactory()


class Potatoes_all(ParentWithSetFactory, DisjointUnionEnumeratedSets):
    """Parent of all potatoes: the disjoint union of ``Potatoes_size(n)``."""

    def __init__(self, policy):
        """Set up the parent with no constraints and the given ``policy``."""
        ParentWithSetFactory.__init__(
            self, (), policy, category=FiniteEnumeratedSets()
        )
        # NOTE(review): the family below is indexed by PositiveIntegers()
        # while the category is FiniteEnumeratedSets() -- confirm intended.
        # A lambda is kept so that ``self.facade_policy()`` is only evaluated
        # when a component is actually requested.
        components = Family(
            PositiveIntegers(),
            lambda size: Potatoes_size(size, policy=self.facade_policy()),
        )
        DisjointUnionEnumeratedSets.__init__(
            self, components,
            facade=True, keepkey=False, category=self.category()
        )

    def _repr_(self):
        """Return the fixed name ``'Potatoes'``."""
        return 'Potatoes'

    def check_element(self, el, check):
        """Accept every element without checking."""
        return None


class Potatoes_size(ParentWithSetFactory, DisjointUnionEnumeratedSets):
    """Parent of the potatoes of a given size: union over 'yellow' and 'red'."""

    def __init__(self, size, policy):
        """Set up the parent with constraint ``(size,)`` and the given ``policy``."""
        self._size = size
        ParentWithSetFactory.__init__(
            self, (size, ), policy, category=FiniteEnumeratedSets()
        )
        colours = ('yellow', 'red')
        # One component per colour, each built with this parent's facade policy.
        components = Family(
            colours,
            lambda colour: _potatoes_size_colour(
                size, colour, policy=self.facade_policy()
            ),
        )
        DisjointUnionEnumeratedSets.__init__(
            self, components,
            facade=True, keepkey=False, category=self.category()
        )

    def _repr_(self):
        """Return ``'Potatoes of size <size>'``."""
        return f'Potatoes of size {self._size}'


class Potatoes_size_red(ParentWithSetFactory, UniqueRepresentation):
    """Parent of the red potatoes of a given size."""
    Element = RedPotato

    def __init__(self, size, policy):
        """Set up the parent with constraints ``(size, 'red')``."""
        self._size = size
        constraints = (size, 'red')
        ParentWithSetFactory.__init__(
            self, constraints, policy, category=FiniteEnumeratedSets()
        )

    def __iter__(self):
        """Yield one element per permutation of ``self._size``."""
        # NOTE(review): the colour is passed positionally, which matches
        # ``Potato.__init__`` but not ``RedPotato.__init__`` (no colour
        # parameter) -- confirm which class ``element_class`` resolves to.
        yield from (
            self.element_class(self, perm, 'red')
            for perm in Permutations(self._size)
        )

    def __repr__(self):
        """Return ``'Red potatoes of size <size>'``."""
        return f'Red potatoes of size {self._size}'


class Potatoes_size_yellow(ParentWithSetFactory, UniqueRepresentation):
    """Parent of the yellow potatoes of a given size."""
    Element = YellowPotato

    def __init__(self, size, policy):
        """Set up the parent with constraints ``(size, 'yellow')``."""
        self._size = size
        constraints = (size, 'yellow')
        ParentWithSetFactory.__init__(
            self, constraints, policy, category=FiniteEnumeratedSets()
        )

    def __iter__(self):
        """Yield one element per permutation of ``self._size``."""
        # NOTE(review): the colour is passed positionally, which matches
        # ``Potato.__init__`` but not ``YellowPotato.__init__`` (no colour
        # parameter) -- confirm which class ``element_class`` resolves to.
        yield from (
            self.element_class(self, perm, 'yellow')
            for perm in Permutations(self._size)
        )

    def __repr__(self):
        """Return ``'Yellow potatoes of size <size>'``."""
        return f'Yellow potatoes of size {self._size}'


def _potatoes_size_colour(size, colour, policy):
    if colour == 'red':
        return Potatoes_size_red(size, policy)
    elif colour == 'yellow':
        return Potatoes_size_yellow(size, policy)
edit retag flag offensive close merge delete