# Source code for mixbox.fields

# Copyright (c) 2015, The MITRE Corporation. All rights reserved.
# See LICENSE.txt for complete terms.
"""
Entity field data descriptors (TypedFields) and associated classes.
"""
import functools
import inspect

from .datautils import is_sequence, resolve_class
from .typedlist import TypedList
from .dates import parse_date, parse_datetime, serialize_date, serialize_datetime
from .xml import strip_cdata, cdata
from .vendor import six
from .compat import long


def unset(entity, *types):
    """Remove TypedField values from `entity`.

    Every key of ``entity._fields`` that is an instance of one of `types` is
    deleted, so the corresponding field reads as unset afterwards.

    Args:
        entity: A mixbox.Entity object (anything with a ``_fields`` dict).
        *types: TypedField subclasses to clear. When none are given, every
            TypedField is cleared.
    """
    field_types = types or (TypedField,)

    # Collect the matches first so the dict is not mutated while iterating it.
    stale = [key for key in entity._fields if isinstance(key, field_types)]

    for key in stale:
        del entity._fields[key]


def _matches(field, params):
    """Return True if the input TypedField `field` contains instance attributes
    that match the input parameters.

    Args:
        field: A TypedField instance.
        params: A dictionary of TypedField instance attribute-to-value mappings.

    Returns:
        True if the input TypedField matches the input parameters.
    """
    fieldattrs = six.iteritems(params)
    return all(getattr(field, attr) == val for attr, val in fieldattrs)


def iterfields(klass):
    """Lazily yield the TypedFields declared on `klass`.

    Uses ``inspect.getmembers``, so inherited members are included and pairs
    come out sorted by attribute name.

    Args:
        klass: A class (usually an Entity subclass).

    Yields:
        (class attribute name, TypedField instance) tuples.
    """
    def _is_typed_field(member):
        return isinstance(member, TypedField)

    for attr_name, descriptor in inspect.getmembers(klass, _is_typed_field):
        yield attr_name, descriptor


def find(entity, **kwargs):
    """Return all TypedFields found on the input `Entity` that were initialized
    with the input **kwargs.

    Example:
        >>> find(myentity, multiple=True, type_=Foo)

    Note:
        TypedFields.__init__() can accept a string or a class as a type_
        argument, but this method expects a class.

    Args:
        entity: The object to search. Its fields come from
            ``entity.typed_fields()``. If that call raises AttributeError, the
            TypedFields declared on ``entity.__class__`` are used instead.
        **kwargs: TypedField attribute names mapped to the values they must
            equal.

    Returns:
        A list of TypedFields with matching **kwarg values.
    """
    try:
        typedfields = entity.typed_fields()
    except AttributeError:
        # iterfields() yields (attribute name, field) pairs. Keep only the
        # field so _matches() inspects the TypedField itself and the result
        # holds TypedFields, not tuples.
        typedfields = (field for _, field in iterfields(entity.__class__))

    return [field for field in typedfields if _matches(field, kwargs)]


class TypedField(object):
    """Data descriptor that stores a per-instance value in ``instance._fields``.

    The field object itself is the dictionary key under which its value is
    stored, so a single TypedField declared on a class serves every instance
    of that class. Values are cleaned (type checked and optionally cast) on
    assignment.
    """

    def __init__(self, name, type_=None,
                 key_name=None, comparable=True, multiple=False,
                 preset_hook=None, postset_hook=None, factory=None,
                 listfunc=None):
        """
        Create a new field.

        Args:
            `name` (str): name of the field as contained in the binding class.
            `type_` (type/str): Required type for values assigned to this
                field. If ``None``, no type checking is performed. String
                values are treated as fully qualified package paths to a class
                (e.g., "A.B.C" would be the full path to the type "C").
            `key_name` (str): name for the field when represented as a
                dictionary. (Optional) If omitted or empty, ``name.lower()``
                is used.
            `comparable` (boolean): whether this field should be considered
                when checking Entities for equality. Default is True. If
                False, this field is not considered.
            `multiple` (boolean): whether multiple instances of this field
                can exist on the Entity.
            `preset_hook` (callable): called as ``preset_hook(instance,
                value)`` after type checking/cleaning but before the value is
                stored. Typically used for additional validation based on the
                current state of the instance.
            `postset_hook` (callable): takes the same arguments as
                `preset_hook`, but is called after the value is stored, e.g.
                to update other fields and keep an invariant.
            `factory` (type/str): class used for from_dict() and from_obj()
                calls for this field. Like `type_`, it may be a dotted path
                string and is resolved on first access.
            `listfunc` (callable): a datatype or function that creates the
                mutable sequence used to store a multiple field's values,
                e.g. ``list``. Defaults to a TypedList of `type_` when `type_`
                is given, otherwise ``list``.
        """
        self.name = name
        self.comparable = comparable
        self.multiple = multiple
        self.preset_hook = preset_hook
        self.postset_hook = postset_hook

        # `type_` and `factory` may be dotted-path strings; the `type_` and
        # `factory` properties resolve them to classes on first access.
        self._unresolved_type = type_
        self._unresolved_factory = factory

        self._key_name = key_name or name.lower()

        # Constructor for the sequence that backs a "multiple" field.
        if listfunc:
            storage = listfunc
        elif type_:
            storage = functools.partial(TypedList, type=type_)
        else:
            storage = list
        self._listfunc = storage

    def __get__(self, instance, owner=None):
        """Return this field's value on `instance`.

        When accessed on the class (``instance is None``), the descriptor
        returns itself. For an unset "multiple" field, an empty sequence is
        built with the list function, stored on the instance, and returned.
        An unset single-valued field yields None.

        Args:
            instance: An instance of the `owner` class this field belongs to,
                or None for class-level access.
            owner: The class that owns this TypedField.
        """
        if instance is None:
            return self

        stored = instance._fields
        if self in stored:
            return stored[self]
        if not self.multiple:
            return None
        return stored.setdefault(self, self._listfunc())

    def _clean(self, value):
        """Validate `value` for this field and return what should be stored.

        None passes through unchanged. A value is accepted as-is when the
        field has no type or check_type() approves it. Otherwise the value is
        converted with ``type_(value)`` if the type opts in to casting via a
        truthy ``_try_cast`` attribute. Any other value raises TypeError.
        """
        if value is None:
            return None

        expected = self.type_
        if expected is None or self.check_type(value):
            return value
        if self.is_type_castable:
            return expected(value)

        raise TypeError(
            "%s must be a %s, not a %s" % (self.name, expected, type(value))
        )

    def __set__(self, instance, value):
        """Clean `value` and store it on `instance` under this field.

        A multiple field always stores a sequence built by the list function:
        None becomes an empty sequence, a non-sequence value is wrapped in a
        one-element sequence, and None items inside a sequence are dropped.
        Each remaining item is cleaned individually. The pre- and post-set
        hooks (if any) are called with ``(instance, cleaned_value)``.
        """
        if not self.multiple:
            cleaned = self._clean(value)
        elif value is None:
            cleaned = self._listfunc()
        elif is_sequence(value):
            cleaned = self._listfunc(
                self._clean(item) for item in value if item is not None
            )
        else:
            cleaned = self._listfunc([self._clean(value)])

        if self.preset_hook:
            self.preset_hook(instance, cleaned)

        instance._fields[self] = cleaned

        if self.postset_hook:
            self.postset_hook(instance, cleaned)

    def __str__(self):
        return self.name

    def check_type(self, value):
        """Return whether `value` is acceptable for this field's type.

        A field without a type accepts anything. A type that exposes
        ``istypeof`` decides for itself; otherwise isinstance() is used.
        """
        expected = self.type_
        if not expected:
            return True
        if hasattr(expected, "istypeof"):
            return expected.istypeof(value)
        return isinstance(value, expected)

    @property
    def key_name(self):
        """Key used for this field in dictionary representations."""
        return self._key_name

    @property
    def type_(self):
        """The field's type, resolved via resolve_class() on first access."""
        try:
            resolved = self._resolved_type
        except AttributeError:
            resolved = self._resolved_type = resolve_class(
                self._unresolved_type
            )
        return resolved

    @type_.setter
    def type_(self, value):
        self._resolved_type = value

    @property
    def factory(self):
        """The field's factory class, resolved via resolve_class() on first access."""
        try:
            resolved = self._resolved_factory
        except AttributeError:
            resolved = self._resolved_factory = resolve_class(
                self._unresolved_factory
            )
        return resolved

    @factory.setter
    def factory(self, value):
        self._resolved_factory = value

    @property
    def transformer(self):
        """Return the class for this field that transforms non-Entity objects
        (e.g., dicts or binding objects) into Entity instances.

        Any non-None value returned from this property should implement a
        from_obj() and from_dict() method.

        Returns:
            The factory if it is set, otherwise the type if it is set,
            otherwise None.
        """
        return self.factory or self.type_ or None

    @property
    def is_type_castable(self):
        """The type's ``_try_cast`` attribute, or False if it has none."""
        return getattr(self.type_, "_try_cast", False)

    def binding_value(self, value):
        """Return `value` unchanged for binding output (subclasses may override)."""
        return value

    def dict_value(self, value):
        """Return `value` unchanged for dictionary output (subclasses may override)."""
        return value

    def __copy__(self):
        """Return the field itself; see __deepcopy__ for the rationale."""
        return self

    def __deepcopy__(self, memo):
        """Return the field itself instead of making a copy.

        TypedFields are the keys of an Entity's ``_fields`` dict and look
        their values up by identity. A real deepcopy would replace those keys
        with copies, and the descriptor would then never find its own value
        in the copied Entity. Fields are class-level descriptors, so sharing
        them between copies is the desired behaviour.
        """
        memo[id(self)] = self  # Register in the memo so deepcopy reuses it.
        return self


class BytesField(TypedField):
    """TypedField whose values are coerced to bytes (``six.binary_type``)."""

    def _clean(self, value):
        """Return `value` as bytes, or None when `value` is None.

        None is passed through, as the base TypedField._clean and the numeric
        fields do, so assigning None clears the field. Previously,
        ``binary_type(None)`` raised TypeError on Python 3 and stored the
        string "None" on Python 2.
        """
        if value is None:
            return None
        return six.binary_type(value)


class TextField(TypedField):
    """TypedField whose values are coerced to text (``six.text_type``)."""

    def _clean(self, value):
        """Return `value` as text, or None when `value` is None.

        None is passed through, as the base TypedField._clean and the numeric
        fields do. Otherwise assigning None would store the literal string
        "None" instead of clearing the field.
        """
        if value is None:
            return None
        return six.text_type(value)


class BooleanField(TypedField):
    """TypedField whose values are coerced to their truth value."""

    def _clean(self, value):
        """Return True if `value` is truthy, else False.

        Note: None is not special-cased here, so assigning None stores False.
        """
        return True if value else False


class IntegerField(TypedField):
    """TypedField whose values are converted to ``int``."""

    def _clean(self, value):
        """Convert `value` to an int; None and "" become None.

        Strings are parsed as base 10 first, so decimal text with leading
        zeros (e.g. "007") is read as decimal. ``int(s, 0)`` rejects such
        text on Python 3 and treats it as octal on Python 2. Only when
        base-10 parsing fails is the string re-parsed with base 0, which
        accepts prefixed literals such as "0x1F", "0o17" and "0b101".

        Raises:
            ValueError: if a string is not a valid integer literal.
        """
        if value in (None, ""):
            return None
        if isinstance(value, six.string_types):
            try:
                return int(value, 10)
            except ValueError:
                # Not plain decimal text; allow Python-style prefixed literals.
                return int(value, 0)
        return int(value)


class LongField(TypedField):
    """TypedField whose values are converted with ``long`` (from .compat)."""

    def _clean(self, value):
        """Convert `value` with ``long``; None and "" become None.

        Strings are parsed as base 10 first, so decimal text with leading
        zeros (e.g. "007") is read as decimal rather than rejected (Python 3)
        or read as octal (Python 2) by base 0. Only when base-10 parsing fails
        is the string re-parsed with base 0, which accepts prefixed literals
        such as "0x1F".

        Raises:
            ValueError: if a string is not a valid integer literal.
        """
        if value in (None, ""):
            return None
        if isinstance(value, six.string_types):
            try:
                return long(value, 10)
            except ValueError:
                # Not plain decimal text; allow Python-style prefixed literals.
                return long(value, 0)
        return long(value)


class FloatField(TypedField):
    """TypedField whose values are converted with ``float()``."""

    def _clean(self, value):
        """Return ``float(value)``, or None when `value` is None or ""."""
        return None if value in (None, "") else float(value)


class DateTimeField(TypedField):
    """TypedField holding datetimes: parsed on assignment, serialized on output."""

    def _clean(self, value):
        """Parse `value` with ``parse_datetime`` before it is stored."""
        return parse_datetime(value)

    def dict_value(self, value):
        """Serialize `value` with ``serialize_datetime``."""
        return serialize_datetime(value)

    # Binding output uses exactly the same serialization as dict output.
    binding_value = dict_value


class DateField(TypedField):
    """TypedField holding dates: parsed on assignment, serialized on output."""

    def _clean(self, value):
        """Parse `value` with ``parse_date`` before it is stored."""
        return parse_date(value)

    def dict_value(self, value):
        """Serialize `value` with ``serialize_date``."""
        return serialize_date(value)

    # Binding output uses exactly the same serialization as dict output.
    binding_value = dict_value


class CDATAField(TypedField):
    """TypedField for text that may be wrapped in CDATA.

    Values pass through ``strip_cdata`` on assignment and are wrapped with
    ``cdata`` for binding output. Dict output is inherited unchanged.
    """

    def binding_value(self, value):
        """Wrap `value` with ``cdata`` for binding output."""
        return cdata(value)

    def _clean(self, value):
        """Pass `value` through ``strip_cdata`` before it is stored."""
        return strip_cdata(value)


class IdField(TypedField):
    """TypedField for an id; setting a truthy id clears the idref fields."""

    def __set__(self, instance, value):
        """Store `value` as usual. If it is truthy (not None or an empty
        string), also remove every IdrefField value from `instance`.
        """
        super(IdField, self).__set__(instance, value)

        if not value:
            return
        unset(instance, IdrefField)


class IdrefField(TypedField):
    """TypedField for an idref; setting a truthy idref clears the id fields."""

    def __set__(self, instance, value):
        """Store `value` as usual. If it is truthy (not None or an empty
        string), also remove every IdField value from `instance`.
        """
        super(IdrefField, self).__set__(instance, value)

        if not value:
            return
        unset(instance, IdField)