# Source code for xsdata.models.mixins

import re
import sys
from abc import ABC, abstractmethod
from dataclasses import MISSING, Field, dataclass, fields
from typing import Any, Dict, Iterator, Optional, Type, TypeVar

from lxml import etree

from xsdata.models.enums import FormType, Namespace
from xsdata.utils import text


class TypedField(ABC):
    """Abstract interface for schema nodes that expose a resolved type name."""

    @property
    @abstractmethod
    def real_type(self) -> Optional[str]:
        """Return the resolved type name, or ``None`` when there is none."""
class NamedField:
    """Mixin exposing name, abstractness and namespace lookups for schema nodes."""

    @property
    def real_name(self) -> str:
        """Return the ``name`` attribute, falling back to ``ref``.

        :raises NotImplementedError: if neither attribute holds a truthy value
        """
        for attr_name in ("name", "ref"):
            candidate = getattr(self, attr_name, None)
            if candidate:
                return candidate
        raise NotImplementedError("Element has no name: {}".format(self))

    @property
    def is_abstract(self) -> bool:
        """Return the ``abstract`` attribute, defaulting to ``False`` when absent."""
        return getattr(self, "abstract", False)

    @property
    def namespace(self):
        """Resolve the namespace URI for qualified nodes.

        Unqualified nodes (the default when no ``form`` attribute exists) have
        no namespace and yield ``None``. Otherwise the prefix of ``ref`` (or of
        ``name`` when ``ref`` is unset) is looked up in ``self.nsmap``.
        """
        form: FormType = getattr(self, "form", FormType.UNQUALIFIED)
        if form != FormType.UNQUALIFIED:
            qualified = getattr(self, "ref", None) or getattr(self, "name")
            prefix, _ = text.split(qualified or "")
            return self.nsmap.get(prefix)
        return None
class RestrictedField(ABC):
    """Abstract interface for schema nodes that report value restrictions."""

    @abstractmethod
    def get_restrictions(self) -> Dict[str, Any]:
        """Return this node's restrictions as a name -> value mapping."""
class OccurrencesMixin:
    """Mixin deriving restrictions from ``min_occurs`` / ``max_occurs``."""

    def get_restrictions(self) -> Dict[str, Any]:
        """Return the restrictions implied by the occurrence bounds.

        Missing bounds default to 1. Exactly-one yields ``{"required": True}``.
        A range whose upper bound exceeds both the lower bound and 1 yields
        both bounds. Any other combination yields an empty dict.
        """
        lower = getattr(self, "min_occurs", 1)
        upper = getattr(self, "max_occurs", 1)

        if lower == upper == 1:
            return {"required": True}

        is_repeatable_range = upper > lower and upper > 1
        if not is_repeatable_range:
            return {}
        return {"min_occurs": lower, "max_occurs": upper}
# Type variable bound to BaseModel. The BaseModel classmethods annotate
# ``cls: Type[T]`` and return ``T``, so type checkers see them as returning
# the subclass they are invoked on.
T = TypeVar("T", bound="BaseModel")
class BaseModel:
    """Base for schema models populated from lxml elements or keyword args.

    Subclasses are expected to be dataclasses. The alternate constructors
    below fill every ``init`` field, converting raw XML attribute strings to
    the field's annotated type.
    """

    def __init__(self, *args, **kwargs):
        pass

    @classmethod
    def from_element(cls: Type[T], el: etree.Element, index: int) -> T:
        """Build an instance from an lxml element.

        Attribute names are reduced to their local name and snake_cased so
        they match field names. Fields without a matching attribute receive
        their dataclass default. When the model declares ``nsmap``,
        ``prefix`` or ``text`` fields, those are taken from the element.

        :param el: the source element
        :param index: value stored in the ``index`` field
        :return: a new instance of ``cls``
        """
        attrs = {
            text.snake_case(etree.QName(key).localname): value
            for key, value in el.attrib.items()
        }
        data = {
            attr.name: cls.xsd_value(attr, attrs)
            if attr.name in attrs
            else cls.default_value(attr)
            for attr in fields(cls)
            if attr.init
        }

        if "nsmap" in data:
            data["nsmap"] = el.nsmap
        if "prefix" in data:
            data["prefix"] = el.prefix
        if "text" in data and el.text:
            # Collapse every whitespace run (incl. newlines) into one space.
            data["text"] = re.sub(r"\s+", " ", el.text).strip()

        data["index"] = index
        return cls(**data)

    @classmethod
    def default_value(cls: Type[T], field: Field) -> Any:
        """Return the dataclass default for ``field``.

        Returns the result of ``default_factory()`` when a factory is set,
        otherwise ``default``, otherwise ``None``.
        """
        factory = field.default_factory
        if factory is not MISSING:
            return factory()  # type: ignore
        return None if field.default is MISSING else field.default

    @classmethod
    def xsd_value(cls, field: Field, kwargs: Dict) -> Any:
        """Convert the raw attribute string for ``field`` to the field's type.

        :param field: the dataclass field being populated
        :param kwargs: raw attribute values keyed by field name; must contain
            ``field.name``
        :return: ``sys.maxsize`` for ``max_occurs="unbounded"``; a bool for
            boolean fields; an int/float when the string parses (the raw
            string otherwise); ``clazz(value)`` for any other type.
        """
        name = field.name
        value = kwargs[name]
        clazz = field.type

        if name == "max_occurs" and value == "unbounded":
            return sys.maxsize

        # Unwrap typing generics such as Optional[X] to their first argument.
        if hasattr(clazz, "__origin__"):
            clazz = clazz.__args__[0]

        if clazz == bool:
            # xs:boolean's lexical space is {"true", "false", "1", "0"} with
            # collapsed whitespace. Comparing against "true" alone misread "1".
            return str(value).strip() in ("true", "1")

        try:
            if clazz == int:
                return int(value)
            if clazz == float:
                return float(value)
        except ValueError:
            return str(value)

        return clazz(value)

    @classmethod
    def create(cls: Type[T], **kwargs) -> T:
        """Build an instance from keyword arguments.

        Init fields missing from ``kwargs`` get their dataclass default. Extra
        keyword arguments that are not init fields are ignored. When both
        ``prefix`` and ``nsmap`` are missing or falsy, they default to the
        ``xs`` prefix mapped to ``Namespace.SCHEMA``.
        """
        if not kwargs.get("prefix") and not kwargs.get("nsmap"):
            kwargs.update({"prefix": "xs", "nsmap": {"xs": Namespace.SCHEMA}})

        data = {
            attr.name: kwargs[attr.name]
            if attr.name in kwargs
            else cls.default_value(attr)
            for attr in fields(cls)
            if attr.init
        }
        return cls(**data)
@dataclass
class ElementBase(BaseModel):
    """Common fields and traversal helpers shared by schema element models."""

    id: Optional[str]
    prefix: str
    nsmap: dict
    index: int

    def children(self):
        """Yield the ElementBase values held directly by this element's fields.

        A non-empty list field whose first item is an ElementBase contributes
        all of its items. A field that is itself an ElementBase contributes
        that value.
        """
        for dc_field in fields(self):
            member = getattr(self, dc_field.name)
            if isinstance(member, list):
                if member and isinstance(member[0], ElementBase):
                    yield from member
            elif isinstance(member, ElementBase):
                yield member

    @property
    def is_attribute(self) -> bool:
        """Return ``False``; this base class is not an attribute node."""
        return False

    @property
    def extends(self) -> Optional[str]:
        """Return the space-separated base type names; ``None`` here."""
        return None

    @property
    def extensions(self) -> Iterator[str]:
        """Iterate the non-empty space-separated names from ``extends``."""
        base_names = self.extends or ""
        return filter(None, base_names.split(" "))

    @property
    def num(self):
        """Return the combined length of every list-valued field."""
        total = 0
        for dc_field in fields(self):
            member = getattr(self, dc_field.name)
            if isinstance(member, list):
                total += len(member)
        return total