"""Source code for xsdata.resolver."""

import logging
from dataclasses import dataclass
from dataclasses import field
from typing import Dict
from typing import Iterator
from typing import List
from typing import Optional
from typing import Set

from lxml import etree
from toposort import toposort_flatten

from xsdata.models.codegen import Class
from xsdata.models.codegen import Package
from xsdata.models.elements import Schema
from xsdata.utils import text
from xsdata.utils.text import split

# Logger named after this module (``__name__``).
logger: logging.Logger = logging.getLogger(name=__name__)


@dataclass
class DependenciesResolver:
    """
    Resolve the generation order and the imports of the classes of a schema.

    The same instance is reused across runs: ``processed`` is never cleared,
    so classes generated in earlier runs can be imported by later ones,
    while ``aliases`` and ``imports`` are reset at the start of every run.
    """

    # Qualified class name ("{namespace}name") -> package it was generated in.
    processed: Dict[str, str] = field(default_factory=dict)
    # Prefixed type reference -> alias, for imports that collide with a
    # class name of the current run.
    aliases: Dict[str, str] = field(default_factory=dict)
    # Packages the current run has to import.
    imports: List[Package] = field(default_factory=list)
    # Class names of the current run (local and imported) in dependency order.
    class_list: List[str] = field(init=False)
    # Classes of the current run indexed by name.
    class_map: Dict[str, Class] = field(init=False)
    # Schema of the current run.
    schema: Schema = field(init=False)
    # Target package of the current run.
    package: str = field(init=False)

    def process(self, classes: List[Class], schema: Schema, package: str) -> None:
        """
        Process a list of classes for the given schema and package.

        Reset the aliases and imports left over from any previous run, but
        keep the record of the processed class names so the classes of this
        run can import them.

        :param classes: the classes that belong to ``schema``
        :param schema: the schema the classes were generated from
        :param package: the target package of the classes
        :raises ValueError: if two classes share the same name
        :raises KeyError: if an imported class was not processed in an
            earlier run
        """
        self.imports.clear()
        self.aliases.clear()
        self.schema = schema
        self.class_map = self.create_class_map(classes)
        self.class_list = self.create_class_list(classes)
        self.package = package
        self.resolve_imports()

    def sorted_imports(self) -> List[Package]:
        """Return a new list of the import packages sorted by name."""
        return sorted(self.imports, key=lambda x: x.name)

    def sorted_classes(self) -> Iterator[Class]:
        """
        Yield the classes of the current run properly sorted for generation.

        Names in the class list that are not local classes (the imported
        ones) are skipped. Every yielded class is recorded, together with
        the current package, for future runs and has the type aliases of
        the current run applied.
        """
        for name in self.class_list:
            obj = self.class_map.get(name)
            if obj is not None:
                self.add_package(obj)
                yield self.apply_aliases(obj)

    def apply_aliases(self, obj: Class) -> Class:
        """
        Walk the class tree and set the type aliases of the current run.

        Attribute types, extension (base class) types and the types of all
        inner classes are updated; types without an alias are set to
        ``None``.

        :return: the same class instance, updated in place
        """
        for attr in obj.attrs:
            for attr_type in attr.types:
                attr_type.alias = self.aliases.get(attr_type.name)

        # Base classes are collected and imported like attribute types, so a
        # prefixed base class that collides with a local class name needs
        # the same alias or it would resolve to the local class.
        for ext in obj.extensions:
            ext.type.alias = self.aliases.get(ext.type.name)

        for inner in obj.inner:
            self.apply_aliases(inner)

        return obj

    def resolve_imports(self) -> None:
        """
        Build the list of import packages for the current run.

        Every imported reference that has a prefix and whose name is also a
        class of the current run gets a type alias, so it does not shadow
        the local class.
        """
        for ref in self.import_classes():
            prefix, name = split(ref)
            package = self.find_package(prefix, name)
            alias = ref if prefix and name in self.class_map else None
            self.add_import(name=name, package=package, alias=alias)

    def add_import(self, name: str, package: str, alias: Optional[str]) -> None:
        """
        Append an import package to the list of imports.

        When an alias is given it is also recorded in the aliases map, which
        is applied later while the classes are generated.
        """
        if alias is not None:
            self.aliases[alias] = alias
        self.imports.append(Package(name=name, source=package, alias=alias))

    def add_package(self, obj: Class) -> None:
        """
        Record the given class as processed in the current package.

        The key is the qname of the class in the schema target namespace,
        eg {http://www.namespace/name}ClassName
        """
        qname = etree.QName(self.schema.target_namespace, obj.name)
        self.processed[qname.text] = self.package

    def find_package(self, prefix: Optional[str], name: str) -> str:
        """
        Use the schema namespaces map to find the package a class belongs to.

        Example:
            * Schema nsmap {"common": "http://www.common/ns"}
            * Resolved processed {"{http://www.common/ns}address": "source.package"}
            * Request for (common, address) will return source.package

        Without a prefix the schema target namespace is used.

        :raises KeyError: if no processed class matches the qualified name
        """
        namespace = (
            self.schema.nsmap.get(prefix) if prefix else self.schema.target_namespace
        )
        qname = etree.QName(namespace, name)
        try:
            return self.processed[qname.text]
        except KeyError:
            # Same exception type as before, but say what could not be found.
            raise KeyError(
                f"Unresolved dependency `{qname.text}` "
                f"(prefix: `{prefix}`, name: `{name}`)"
            ) from None

    def import_classes(self) -> List[str]:
        """Return the names in the class list that are not local classes."""
        return [name for name in self.class_list if name not in self.class_map]

    def create_class_list(self, classes: List[Class]) -> List[str]:
        """
        Return a flat, topologically sorted list of all the class names.

        Dependencies (including the ones that must be imported) come before
        the classes that depend on them.
        """
        prefix = self.schema.target_prefix
        return toposort_flatten(
            {obj.name: self.collect_deps(obj, prefix) for obj in classes}
        )

    def collect_deps(self, obj: Class, prefix: Optional[str]) -> Set[str]:
        """
        Return the set of dependency names of the given class.

        Collect:
            * base classes
            * attribute types
            * recursively the dependencies of the inner classes

        Forward references (inner class references) and native types are
        ignored, and the target namespace prefix is stripped from the names.
        """
        deps: Set[str] = {
            attr_type.name
            for attr in obj.attrs
            for attr_type in attr.types
            if not attr_type.forward_ref and not attr_type.native
        }
        deps.update(ext.type.name for ext in obj.extensions if not ext.type.native)

        for inner in obj.inner:
            deps.update(self.collect_deps(inner, prefix))

        return {text.strip_prefix(dep, prefix) for dep in deps}

    @staticmethod
    def create_class_map(classes: List[Class]) -> Dict[str, Class]:
        """
        Index the list of classes by name.

        :raises ValueError: if two classes share the same name
        """
        result: Dict[str, Class] = dict()
        for obj in classes:
            if obj.name in result:
                raise ValueError(f"Duplicate class name`{obj.name}`")
            result[obj.name] = obj

        return result