import sys
from collections import defaultdict
from dataclasses import dataclass
from dataclasses import Field
from dataclasses import field
from dataclasses import fields
from dataclasses import is_dataclass
from dataclasses import MISSING
from typing import _eval_type # type: ignore
from typing import Any
from typing import Callable
from typing import Dict
from typing import get_type_hints
from typing import Iterator
from typing import List
from typing import Optional
from typing import Set
from typing import Type
from xsdata.exceptions import XmlContextError
from xsdata.formats.bindings import T
from xsdata.formats.converter import converter
from xsdata.formats.dataclass.models.constants import XmlType
from xsdata.formats.dataclass.models.elements import XmlMeta
from xsdata.formats.dataclass.models.elements import XmlVar
from xsdata.models.enums import DataType
from xsdata.models.enums import NamespaceType
from xsdata.utils.constants import EMPTY_SEQUENCE
from xsdata.utils.namespaces import build_qname
@dataclass
class XmlContext:
    """
    Generate and cache the necessary metadata to bind an xml document data to a
    dataclass model.

    :param element_name: Default callable to convert field names to element tags
    :param attribute_name: Default callable to convert field names to attribute tags
    :param cache: Cache models metadata
    :param xsi_cache: Index models by xsi:type
    """

    element_name: Callable = field(default=lambda x: x)
    attribute_name: Callable = field(default=lambda x: x)
    cache: Dict[Type, XmlMeta] = field(default_factory=dict)
    xsi_cache: Dict[str, List[Type]] = field(default_factory=lambda: defaultdict(list))
    # Number of entries in ``sys.modules`` when the xsi cache was last built;
    # ``build_xsi_cache`` skips the rebuild while this count is unchanged.
    sys_modules: int = field(default=0, init=False)

    def fetch(
        self,
        clazz: Type,
        parent_ns: Optional[str] = None,
        xsi_type: Optional[str] = None,
    ) -> XmlMeta:
        """
        Fetch the model metadata of the given dataclass type, namespace and xsi
        type.

        :param clazz: A dataclass model
        :param parent_ns: The parent dataclass namespace if present.
        :param xsi_type: if present it means that the given clazz is derived and the
            lookup procedure needs to check and match a dataclass model to the
            qualified name instead.
        :return: The metadata of the matching subclass when one is found for the
            xsi type, otherwise the metadata of the given class.
        """
        meta = self.build(clazz, parent_ns)
        subclass = None
        if xsi_type and meta.source_qname != xsi_type:
            subclass = self.find_subclass(clazz, xsi_type)

        return self.build(subclass, parent_ns) if subclass else meta

    def build_xsi_cache(self):
        """
        Index all imported dataclasses by their xsi:type qualified name.

        The index is rebuilt from scratch only when the number of entries in
        ``sys.modules`` differs from the count recorded by the previous build.
        """
        if len(sys.modules) == self.sys_modules:
            return

        self.xsi_cache.clear()

        for clazz in self.get_subclasses(object):
            if is_dataclass(clazz):
                # Only honour a Meta declared on the class itself, never an
                # inherited one.
                meta = clazz.Meta if "Meta" in clazz.__dict__ else None
                name = getattr(meta, "name", None) or self.local_name(clazz.__name__)
                # This scan visits every dataclass in the process, so a class
                # whose module is not registered in sys.modules must not abort
                # the whole index: treat it as having no source namespace.
                module = sys.modules.get(clazz.__module__)
                source_namespace = getattr(module, "__NAMESPACE__", None)
                source_qname = build_qname(source_namespace, name)
                self.xsi_cache[source_qname].append(clazz)

        self.sys_modules = len(sys.modules)

    def find_types(self, qname: str) -> List[Type[T]]:
        """
        Find all classes that match the given xsi:type qname.

        - Ignores native schema types, xs:string, xs:float, xs:int, ...
        - Rebuild cache if new modules were imported since last run

        :return: The indexed list of classes for the qname, or an empty list.
        """
        if not DataType.from_qname(qname):
            self.build_xsi_cache()
            # Membership test first: xsi_cache may be a defaultdict and a plain
            # subscript would insert an empty entry for unknown qnames.
            if qname in self.xsi_cache:
                return self.xsi_cache[qname]

        return []

    def find_type(self, qname: str) -> Optional[Type[T]]:
        """
        Return the last class indexed under the given xsi:type qname, or None
        when no class matches.

        NOTE(review): presumably the last indexed class is the most recently
        imported one; the order comes from ``get_subclasses`` -- confirm.
        """
        types: List[Type] = self.find_types(qname)
        return types[-1] if types else None

    def find_type_by_fields(self, field_names: Set) -> Optional[Type[T]]:
        """
        Return the first indexed dataclass whose set of field names is exactly
        equal to the given set, or None when no class matches.

        :param field_names: The complete set of dataclass field names to match
        """
        self.build_xsi_cache()
        for types in self.xsi_cache.values():
            for clazz in types:
                if field_names == {attr.name for attr in fields(clazz)}:
                    return clazz

        return None

    def find_subclass(self, clazz: Type, qname: str) -> Optional[Type]:
        """Compare all classes that match the given xsi:type qname and return
        the first one that is either a subclass or shares the same parent class
        as the original class."""
        types: List[Type] = self.find_types(qname)
        for tp in types:
            # Any common ancestor other than ``object`` counts as related.
            for tp_mro in tp.__mro__:
                if tp_mro is not object and tp_mro in clazz.__mro__:
                    return tp

        return None

    def build(self, clazz: Type, parent_ns: Optional[str] = None) -> XmlMeta:
        """
        Fetch from cache or build the metadata object for the given class and
        parent namespace.

        :param clazz: A dataclass model
        :param parent_ns: Namespace used when the class Meta declares none
        :raises XmlContextError: if the given class is not a dataclass
        """
        if clazz not in self.cache:
            # Ensure the given type is a dataclass.
            if not is_dataclass(clazz):
                raise XmlContextError(f"Object {clazz} is not a dataclass.")

            # Fetch the dataclass meta settings and make sure we don't inherit
            # the parent class meta.
            meta = clazz.Meta if "Meta" in clazz.__dict__ else None
            name = getattr(meta, "name", None) or self.local_name(clazz.__name__)
            nillable = getattr(meta, "nillable", False)
            namespace = getattr(meta, "namespace", parent_ns)
            module = sys.modules[clazz.__module__]
            source_namespace = getattr(module, "__NAMESPACE__", None)

            self.cache[clazz] = XmlMeta(
                clazz=clazz,
                qname=build_qname(namespace, name),
                source_qname=build_qname(source_namespace, name),
                nillable=nillable,
                vars=list(self.get_type_hints(clazz, namespace)),
            )

        return self.cache[clazz]

    def get_type_hints(self, clazz: Type, parent_ns: Optional[str]) -> Iterator[XmlVar]:
        """
        Build the model class fields metadata.

        Yields one xml var per dataclass field, in field declaration order.

        :param clazz: A dataclass model
        :param parent_ns: The namespace of the dataclass, used as fallback for
            its fields
        """
        type_hints = get_type_hints(clazz)
        default_xml_type = self.default_xml_type(clazz)

        for var in fields(clazz):
            type_hint = type_hints[var.name]
            types = self.real_types(type_hint)
            is_tokens = var.metadata.get("tokens", False)
            is_element_list = self.is_element_list(type_hint, is_tokens)
            is_class = any(is_dataclass(tp) for tp in types)
            xml_type = var.metadata.get("type")
            local_name = var.metadata.get("name")

            # Fields binding to dataclasses default to elements, everything
            # else falls back to the class-wide default xml type.
            if not xml_type:
                xml_type = default_xml_type if not is_class else "Element"

            if not local_name:
                local_name = self.local_name(var.name, xml_type)

            xml_clazz = XmlType.to_xml_class(xml_type)
            namespace = var.metadata.get("namespace")
            namespaces = self.resolve_namespaces(xml_type, namespace, parent_ns)
            default_namespace = self.default_namespace(namespaces)
            qname = build_qname(default_namespace, local_name)
            choices = list(
                self.build_choices(
                    clazz,
                    var.name,
                    parent_ns,
                    var.metadata.get("choices", EMPTY_SEQUENCE),
                )
            )

            yield xml_clazz(
                name=var.name,
                qname=qname,
                namespaces=namespaces,
                init=var.init,
                mixed=var.metadata.get("mixed", False),
                nillable=var.metadata.get("nillable", False),
                dataclass=is_class,
                sequential=var.metadata.get("sequential", False),
                tokens=is_tokens,
                list_element=is_element_list,
                types=types,
                default=self.default_value(var),
                choices=choices,
            )

    def build_choices(
        self,
        clazz: Type,
        parent_name: str,
        parent_namespace: Optional[str],
        choices: List[Dict],
    ) -> Iterator[XmlVar]:
        """
        Build the xml vars for the choices of a compound field.

        :param clazz: The dataclass model that owns the field; its module globals
            are used to evaluate the choice types
        :param parent_name: The name of the owning field, shared by all choices
        :param parent_namespace: The namespace of the owning dataclass
        :param choices: The choice definitions from the field metadata
        """
        existing = set()
        globalns = sys.modules[clazz.__module__].__dict__

        for choice in choices:
            xml_type = XmlType.WILDCARD if choice.get("wildcard") else XmlType.ELEMENT
            namespace = choice.get("namespace")
            namespaces = self.resolve_namespaces(xml_type, namespace, parent_namespace)
            default_namespace = self.default_namespace(namespaces)
            types = self.real_types(_eval_type(choice["type"], globalns, None))
            # A choice is derived when one of its types was already claimed by
            # a previous choice of the same field.
            derived = any(tp in existing for tp in types)
            is_class = any(is_dataclass(tp) for tp in types)
            xml_clazz = XmlType.to_xml_class(xml_type)
            qname = build_qname(default_namespace, choice.get("name", "any"))
            nillable = choice.get("nillable", False)

            # An element choice typed only as ``object`` is always derived.
            if xml_type == XmlType.ELEMENT and len(types) == 1 and types[0] == object:
                derived = True

            yield xml_clazz(
                name=parent_name,
                qname=qname,
                namespaces=namespaces,
                nillable=nillable,
                dataclass=is_class,
                tokens=choice.get("tokens", False),
                derived=derived or nillable,
                types=types,
            )

            existing.update(types)

    @classmethod
    def resolve_namespaces(
        cls,
        xml_type: Optional[str],
        namespace: Optional[str],
        parent_namespace: Optional[str],
    ) -> List[str]:
        """
        Resolve the namespace(s) for the given xml type and the parent
        namespace.

        Only elements and wildcards are allowed to inherit the parent namespace if
        the given namespace is empty.

        In case of wildcard try to decode the ##any, ##other, ##local, ##target.

        :return: The unique resolved namespaces in the order they were declared,
            or an empty list when there is no namespace to resolve.
        """
        if xml_type in (XmlType.ELEMENT, XmlType.WILDCARD) and namespace is None:
            namespace = parent_namespace

        if not namespace:
            return []

        # Keep declaration order while removing duplicates: callers pick the
        # first valid entry as the default namespace, so the order must not
        # depend on set/hash iteration order.
        result: List[str] = []
        for ns in namespace.split():
            if ns == NamespaceType.TARGET_NS:
                resolved = parent_namespace or NamespaceType.ANY_NS
            elif ns == NamespaceType.LOCAL_NS:
                resolved = ""
            elif ns == NamespaceType.OTHER_NS:
                resolved = f"!{parent_namespace or ''}"
            else:
                resolved = ns

            if resolved not in result:
                result.append(resolved)

        return result

    @classmethod
    def default_namespace(cls, namespaces: List[str]) -> Optional[str]:
        """Return the first non-empty namespace that does not start with ``#``,
        or None when there is none."""
        for namespace in namespaces:
            if namespace and not namespace.startswith("#"):
                return namespace

        return None

    @classmethod
    def default_value(cls, var: Field) -> Any:
        """Return the default factory of the given field if it has one,
        otherwise its default value, otherwise None."""
        if var.default_factory is not MISSING:  # type: ignore
            return var.default_factory  # type: ignore

        if var.default is not MISSING:
            return var.default

        return None

    @classmethod
    def real_types(cls, type_hint: Any) -> List:
        """
        Return a list of real types that can be used to bind or cast data.

        Generic hints are unwrapped while they carry a single generic argument,
        ``NoneType`` is dropped from the remaining arguments and the result is
        passed through ``converter.sort_types``.
        """
        types = []
        if type_hint is Dict:
            types.append(type_hint)
        elif hasattr(type_hint, "__origin__"):
            # Peel nested single-argument generics, e.g. List[List[X]].
            while len(type_hint.__args__) == 1 and hasattr(
                type_hint.__args__[0], "__origin__"
            ):
                type_hint = type_hint.__args__[0]

            types = [x for x in type_hint.__args__ if x is not type(None)]
        else:
            types.append(type_hint)

        return converter.sort_types(types)

    @classmethod
    def is_derived(cls, obj: Any, clazz: Type) -> bool:
        """Return whether the given obj is an instance of the given class or of
        any of its direct base classes other than ``object``."""
        if obj is None:
            return False

        if isinstance(obj, clazz):
            return True

        return any(x is not object and isinstance(obj, x) for x in clazz.__bases__)

    @classmethod
    def is_element_list(cls, type_hint: Any, is_tokens: bool) -> bool:
        """
        Return whether the given type hint describes a list of elements.

        A list hint qualifies directly unless the field is a tokens field, in
        which case its first type argument must itself be a list.
        """
        if getattr(type_hint, "__origin__", None) in (list, List):
            if not is_tokens:
                return True

            type_hint = type_hint.__args__[0]
            if getattr(type_hint, "__origin__", None) in (list, List):
                return True

        return False

    @classmethod
    def default_xml_type(cls, clazz: Type) -> str:
        """
        Return the default xml type for the fields of the given dataclass with
        an undefined type.

        :raises XmlContextError: if more than one field is declared as text
        """
        counters: Dict[str, int] = defaultdict(int)
        for var in fields(clazz):
            xml_type = var.metadata.get("type")
            counters[xml_type or "undefined"] += 1

        if counters[XmlType.TEXT] > 1:
            raise XmlContextError(
                f"Dataclass `{clazz.__name__}` includes more than one text node!"
            )

        # A single untyped field with no explicit text field is the text node.
        if counters["undefined"] == 1 and counters[XmlType.TEXT] == 0:
            return XmlType.TEXT

        return XmlType.ELEMENT

    @classmethod
    def get_subclasses(cls, clazz: Type) -> Iterator[Type]:
        """Recursively yield all the subclasses of the given class, descendants
        before the subclass they derive from."""
        try:
            subclasses = clazz.__subclasses__()
        except TypeError:
            # ``type.__subclasses__()`` cannot be called unbound on ``type``
            # itself, which is reached when walking down from ``object``.
            return

        for subclass in subclasses:
            yield from cls.get_subclasses(subclass)
            yield subclass

    def local_name(self, name: str, xml_type: Optional[str] = None) -> str:
        """Convert the given name with ``attribute_name`` when the xml type is
        ``Attribute`` and with ``element_name`` in any other case."""
        if xml_type == "Attribute":
            return self.attribute_name(name)

        return self.element_name(name)