# Source code for xsdata.formats.dataclass.context
from dataclasses import dataclass
from dataclasses import Field
from dataclasses import field
from dataclasses import fields
from dataclasses import is_dataclass
from dataclasses import MISSING
from typing import Any
from typing import Callable
from typing import Dict
from typing import get_type_hints
from typing import Iterator
from typing import List
from typing import Optional
from typing import Type
from lxml.etree import QName
from xsdata.exceptions import XmlContextError
from xsdata.formats.converters import sort_types
from xsdata.formats.dataclass.models.constants import XmlType
from xsdata.formats.dataclass.models.elements import XmlMeta
from xsdata.formats.dataclass.models.elements import XmlVar
from xsdata.models.enums import NamespaceType
@dataclass
class XmlContext:
    """Build and cache the XML binding metadata of dataclasses.

    :param name_generator: callable used to derive an XML name from a class
        or field name when none is given explicitly; the default returns
        the name unchanged.
    :param cache: already built ``XmlMeta`` instances keyed by class.
    """

    name_generator: Callable = field(default=lambda x: x)
    cache: Dict[Type, XmlMeta] = field(default_factory=dict)

    def build(self, clazz: Type, parent_ns: Optional[str] = None) -> XmlMeta:
        """Return the ``XmlMeta`` of ``clazz``, building it on first use.

        The result is stored in ``self.cache`` keyed by ``clazz`` only, so a
        cached entry is returned as is, whatever ``parent_ns`` is passed.

        :param clazz: the dataclass to inspect.
        :param parent_ns: namespace used when the class ``Meta`` does not
            define a ``namespace`` attribute.
        :raises XmlContextError: if ``clazz`` is not a dataclass.
        """
        if clazz not in self.cache:
            if not is_dataclass(clazz):
                raise XmlContextError(f"Object {clazz} is not a dataclass.")

            meta = getattr(clazz, "Meta", None)
            # ``getattr`` also finds a ``Meta`` inherited from a base class;
            # only accept the one defined directly on ``clazz``. Comparing
            # against ``__qualname__`` (not ``__name__``) keeps this working
            # for nested and function-local dataclasses.
            if meta and meta.__qualname__ != f"{clazz.__qualname__}.Meta":
                meta = None

            name = getattr(meta, "name", self.name_generator(clazz.__name__))
            nillable = getattr(meta, "nillable", False)
            namespace = getattr(meta, "namespace", parent_ns)

            self.cache[clazz] = XmlMeta(
                name=name,
                clazz=clazz,
                qname=QName(namespace, name),
                nillable=nillable,
                vars=list(self.get_type_hints(clazz, namespace)),
            )
        return self.cache[clazz]

    def get_type_hints(self, clazz, parent_ns: Optional[str]) -> Iterator[XmlVar]:
        """Yield one xml var per dataclass field of ``clazz``.

        The concrete var class is chosen by ``XmlType.to_xml_class`` from
        the ``type`` entry of the field metadata.

        :param clazz: the dataclass whose fields are inspected.
        :param parent_ns: namespace of the owning class, forwarded to
            :meth:`resolve_namespaces`.
        """
        type_hints = get_type_hints(clazz)

        for var in fields(clazz):
            types = self.real_types(type_hints[var.name])

            xml_type = var.metadata.get("type")
            xml_clazz = XmlType.to_xml_class(xml_type)
            namespaces = self.resolve_namespaces(
                xml_type, var.metadata.get("namespace"), parent_ns
            )
            local_name = var.metadata.get("name") or self.name_generator(var.name)
            is_class = any(is_dataclass(tp) for tp in types)

            # The qname is qualified with the first resolved namespace,
            # unless that entry is empty or starts with "#".
            first_namespace = None
            if namespaces and namespaces[0] and namespaces[0][0] != "#":
                first_namespace = namespaces[0]

            yield xml_clazz(
                name=var.name,
                qname=QName(first_namespace, local_name),
                namespaces=namespaces,
                init=var.init,
                nillable=var.metadata.get("nillable", False),
                dataclass=is_class,
                sequential=var.metadata.get("sequential", False),
                types=types,
                default=self.default_value(var),
            )

    @staticmethod
    def resolve_namespaces(xml_type, namespace, parent_namespace):
        """Resolve a whitespace separated namespace list into a list.

        Elements and wildcards without an explicit namespace fall back to
        ``parent_namespace``. Each token is classified with
        ``NamespaceType.get_enum``:

        - ``TARGET`` becomes ``parent_namespace``, or
          ``NamespaceType.ANY.value`` when there is no parent namespace,
        - ``LOCAL`` becomes the empty string,
        - ``OTHER`` becomes ``"!"`` followed by ``parent_namespace``,
        - anything else is kept verbatim.

        :return: the distinct resolved namespaces in order of first
            appearance; an empty list when there is no namespace.
        """
        if xml_type in (XmlType.ELEMENT, XmlType.WILDCARD) and namespace is None:
            namespace = parent_namespace

        if not namespace:
            return []

        # A dict is used as an insertion-ordered set: callers read the first
        # entry, so the order must not depend on string hashing.
        resolved: Dict[str, None] = {}
        for ns in namespace.split():
            ns_type = NamespaceType.get_enum(ns)
            if ns_type == NamespaceType.TARGET:
                resolved[parent_namespace or NamespaceType.ANY.value] = None
            elif ns_type == NamespaceType.LOCAL:
                resolved[""] = None
            elif ns_type == NamespaceType.OTHER:
                resolved[f"!{parent_namespace or ''}"] = None
            else:
                resolved[ns] = None

        return list(resolved)

    @staticmethod
    def default_value(var: Field) -> Any:
        """Return the default declared on a dataclass field.

        :return: the ``default_factory`` callable itself (it is not called)
            when one is set, otherwise the ``default`` value, otherwise
            ``None``.
        """
        if var.default_factory is not MISSING:  # type: ignore
            return var.default_factory  # type: ignore

        if var.default is not MISSING:
            return var.default

        return None

    @staticmethod
    def real_types(type_hint) -> List:
        """Return the concrete types behind a type hint.

        A bare ``Dict`` and any hint without ``__origin__`` are returned as
        a single-item list. For generic hints, wrappers holding exactly one
        generic argument are unwrapped first; the arguments of the innermost
        hint are then collected, leaving out ``NoneType``.

        :return: the collected types, passed through ``sort_types``.
        """
        if type_hint is Dict or not hasattr(type_hint, "__origin__"):
            return sort_types([type_hint])

        # e.g. List[Optional[int]] -> Optional[int]
        while len(type_hint.__args__) == 1 and hasattr(
            type_hint.__args__[0], "__origin__"
        ):
            type_hint = type_hint.__args__[0]

        none_type = type(None)
        return sort_types([arg for arg in type_hint.__args__ if arg is not none_type])