Source code for xsdata.resolver

import logging
from dataclasses import dataclass
from dataclasses import field
from typing import Dict
from typing import List

from lxml.etree import QName
from toposort import toposort_flatten

from xsdata.exceptions import ResolverValueError
from xsdata.models.codegen import Class
from xsdata.models.codegen import Package

logger = logging.getLogger(__name__)


[docs]@dataclass class DependenciesResolver: packages: Dict[QName, str] = field(default_factory=dict) aliases: Dict[QName, str] = field(default_factory=dict) imports: List[Package] = field(default_factory=list) class_list: List[QName] = field(init=False, default_factory=list) class_map: Dict[QName, Class] = field(init=False, default_factory=dict) package: str = field(init=False)
[docs] def process(self, classes: List[Class]): """ Resolve the dependencies for the given list of classes and the target package. Reset aliases and imports from any previous runs keep the record of the processed class names """ self.imports.clear() self.aliases.clear() self.class_map = self.create_class_map(classes) self.class_list = self.create_class_list(classes) self.resolve_imports()
[docs] def sorted_imports(self) -> List[Package]: """Return a new sorted by name list of import packages.""" return sorted(self.imports, key=lambda x: x.name)
[docs] def sorted_classes(self) -> List[Class]: """Return an iterator of classes property sorted for generation and apply import aliases.""" result = [] for name in self.class_list: obj = self.class_map.get(name) if obj is not None: self.apply_aliases(obj) result.append(obj) return result
[docs] def apply_aliases(self, obj: Class): """Walk the attributes tree and set the type aliases.""" for attr in obj.attrs: for attr_type in attr.types: attr_type_qname = obj.source_qname(attr_type.name) attr_type.alias = self.aliases.get(attr_type_qname) for inner in obj.inner: self.apply_aliases(inner)
[docs] def resolve_imports(self): """Walk the import qualified names, check for naming collisions and add the necessary code generator import instance.""" local_names = [qname.localname for qname in self.class_map.keys()] for qname in self.import_classes(): package = self.find_package(qname) exists = qname.localname in local_names self.add_import(qname=qname, package=package, exists=exists)
[docs] def add_import(self, qname: QName, package: str, exists: bool = False): """Append an import package to the list of imports with any if necessary aliases if the import name exists in the local module.""" alias = None if exists: module = package.split(".")[-1] alias = f"{module}:{qname.localname}" self.aliases[qname] = alias self.imports.append(Package(name=qname.localname, source=package, alias=alias))
[docs] def find_package(self, qname: QName) -> str: """ Return the package name for the given qualified class name. :raises ResolverValueError: if name doesn't exist. """ if qname not in self.packages: raise ResolverValueError(f"Unknown dependency: {qname.text}") return self.packages[qname]
[docs] def import_classes(self) -> List[QName]: """Return a list of class that need to be imported.""" return [qname for qname in self.class_list if qname not in self.class_map]
[docs] @staticmethod def create_class_list(classes: List[Class]): """Use topology sort to return a flat list for all the dependencies.""" return toposort_flatten( {obj.source_qname(): obj.dependencies() for obj in classes} )
[docs] @staticmethod def create_class_map(classes: List[Class]): """Index the list of classes by name.""" result: Dict[QName, Class] = dict() for obj in classes: qname = obj.source_qname() if qname in result: raise ResolverValueError(f"Duplicate class: `{obj.name}`") result[qname] = obj return result