Source code for xsdata.reducer

from collections import defaultdict
from dataclasses import dataclass
from dataclasses import field
from typing import Dict
from typing import List
from typing import Optional

from lxml import etree

from xsdata.logger import logger
from xsdata.models.codegen import Attr
from xsdata.models.codegen import AttrType
from xsdata.models.codegen import Class
from xsdata.models.codegen import Extension
from xsdata.models.elements import Schema
from xsdata.models.enums import DataType
from xsdata.models.enums import TagType
from xsdata.utils import text


def simple_type(item: Class):
    """
    Tell whether the given class is an enumeration, abstract or common type.

    Used as the default ``condition`` filter of
    :meth:`ClassReducer.find_class`.

    :param item: The class to inspect.
    :return: The first truthy flag among ``is_enumeration``, ``is_abstract``
        and ``is_common``, otherwise the (falsy) value of ``is_common``.
    """
    return (
        item.is_enumeration
        or item.is_abstract
        or item.is_common
    )
@dataclass
class ClassReducer:
    """
    Minimize the number of generated classes.

    The given xsd schema may be excessively verbose and contain duplicate
    types; this class flattens such types into the classes referencing them.

    :param schema: The schema currently being processed; it is not an init
        argument, it is assigned by :meth:`process`.
    :param common_types: Registry of common/abstract classes keyed by qname.
        It is never cleared here, so entries stay available to later
        :meth:`process` calls.
    :param processed: Keys of the classes that have already been flattened.
    :param class_index: Classes of the current schema grouped by qname.
    """

    schema: Schema = field(init=False)
    common_types: Dict[str, Class] = field(default_factory=dict)
    processed: Dict = field(default_factory=dict)
    class_index: Dict[str, List[Class]] = field(
        default_factory=lambda: defaultdict(list)
    )

    def process(self, schema: Schema, classes: List[Class]) -> List[Class]:
        """
        Process class list in steps.

        Steps:
            * Merge redefined classes
            * Create a class qname index
            * Mark as abstract classes with the same qname
            * Flatten classes
            * Return a final class list for code generators.

        :param schema: The schema the classes were generated from.
        :param classes: The generated classes; the list is modified in place
            when redefined classes are merged.
        :return: The classes qualified for code generation.
        """
        self.schema = schema
        self.merge_redefined_classes(classes)
        self.create_class_qname_index(classes)
        self.mark_abstract_duplicate_classes()
        self.flatten_classes()
        return self.fetch_classes_for_generation()

    def fetch_classes_for_generation(self) -> List[Class]:
        """
        Return the qualified classes to continue for code generation.

        The rest of the classes are stored as common types to be used later
        by the next schemas in the process.

        Qualifications:
            * not an abstract
            * type: element | complexType | simpleType with enumerations
        """
        result = []
        for classes in self.class_index.values():
            for item in classes:
                should_store = item.is_common or item.is_abstract
                if should_store:
                    self.common_types[self.qname(item.name)] = item
                # Enumerations are both stored and generated.
                if not should_store or item.is_enumeration:
                    result.append(item)
        return result

    def create_class_qname_index(self, classes: List[Class]):
        """
        Reset the per-schema state and group the given classes by qname.

        Both ``class_index`` and ``processed`` are cleared first.
        """
        self.class_index.clear()
        self.processed.clear()
        for item in classes:
            self.class_index[self.qname(item.name)].append(item)

    def flatten_classes(self):
        """Flatten every indexed class that has not been processed yet."""
        for classes in self.class_index.values():
            for obj in classes:
                if obj.key not in self.processed:
                    self.flatten_class(obj)

    def is_self_referencing(self, item: Class, dependency: AttrType) -> bool:
        """Return whether the dependency resolves to the given class itself."""
        found = self.find_class(dependency, condition=lambda x: x is item)
        return found is not None

    def find_class(
        self, dependency: AttrType, condition=simple_type
    ) -> Optional[Class]:
        """
        Find the class the given dependency type points to.

        The classes of the current schema are searched first and the common
        types registry is used as a fallback.

        :param dependency: The attribute/extension type to resolve.
        :param condition: Predicate a candidate has to satisfy.
        :return: The matching class or ``None``.
        """
        qname = self.qname(dependency.name)
        item = self.find_schema_class(qname, condition=condition)
        return item or self.find_common_class(qname, condition=condition)

    def find_common_class(self, qname: str, condition=None) -> Optional[Class]:
        """
        Look up a class in the common types registry.

        :param qname: The qualified name to look up.
        :param condition: Optional predicate; when given the registered class
            is only returned if it satisfies it.
        :return: The registered class or ``None``.
        """
        candidate = self.common_types.get(qname)
        if candidate is None:
            return None
        if not condition or condition(candidate):
            return candidate
        return None

    def find_schema_class(self, qname: str, condition=None) -> Optional[Class]:
        """
        Look up a class of the current schema by qname.

        The first candidate that passes the condition is returned; a warning
        is logged if there are more. The candidate is flattened before being
        returned if it has not been processed yet.

        :param qname: The qualified name to look up.
        :param condition: Optional predicate passed to :func:`filter`.
        :return: The first matching class or ``None``.
        """
        # ``get`` is used on purpose: indexing the defaultdict would insert
        # an empty list for unknown qnames.
        candidates = list(filter(condition, self.class_index.get(qname, [])))
        if not candidates:
            return None

        candidate = candidates[0]
        if len(candidates) > 1:
            logger.warning("More than one candidate found for %s", qname)

        if candidate.key not in self.processed:
            self.flatten_class(candidate)
        return candidate

    def merge_redefined_classes(self, classes: List[Class]):
        """
        Merge original and redefined classes.

        Classes are grouped by type name and class name. For a group of two,
        the last class wins and the other one is removed from the given list
        after its types, restrictions and attributes have been merged into
        the winner wherever the winner references its own name.

        :param classes: The class list, modified in place.
        :raises NotImplementedError: If a class appears more than twice.
        """
        grouped: Dict[str, List[Class]] = defaultdict(list)
        for item in classes:
            grouped[f"{item.type.__name__}{item.name}"].append(item)

        for items in grouped.values():
            if len(items) == 1:
                continue

            if len(items) > 2:
                raise NotImplementedError(
                    f"Redefined class `{items[0].name}` more than once."
                )

            winner: Class = items.pop()
            looser: Class = items.pop()
            classes.remove(looser)

            # Attributes typed after the class itself take over the types and
            # restrictions of the attribute at the same position in the
            # original class.
            for index, attr in enumerate(winner.attrs):
                type_name = attr.types[0].name
                if type_name == winner.name or type_name.endswith(
                    f":{winner.name}"
                ):
                    original = looser.attrs[index]
                    restrictions = original.restrictions
                    attr.types = original.types
                    attr.restrictions.update(restrictions)

            # Walk backwards so popping an extension does not shift the
            # positions that are still to be visited.
            for index in reversed(range(len(winner.extensions))):
                extension = winner.extensions[index]
                type_name = extension.type.name
                if type_name == winner.name or type_name.endswith(
                    f":{winner.name}"
                ):
                    winner.extensions.pop(index)
                    self.copy_attributes(looser, winner, extension)

    def mark_abstract_duplicate_classes(self):
        """Search for groups with more than one class and mark as abstract any
        complex type with the same name as an element."""
        for classes in self.class_index.values():
            if len(classes) == 1:
                continue

            element = next(
                (obj for obj in classes if obj.is_element and not obj.is_abstract),
                None,
            )
            if not element:
                continue

            for obj in classes:
                if obj is not element and not obj.is_common:
                    obj.is_abstract = True

    def flatten_class(self, item: Class):
        """
        Flatten class traits from the common types registry.

        Steps:
            * Enum unions
            * Parent classes
            * Attributes
            * Inner classes
        """
        # Mark first, so lookups triggered while flattening do not try to
        # flatten this class again.
        self.processed[item.key] = True

        if item.is_common:
            self.flatten_enumeration_unions(item)

        # Iterate over copies: flattening may alter the underlying lists.
        for extension in list(item.extensions):
            self.flatten_extension(item, extension)

        for attr in list(item.attrs):
            self.flatten_attribute(item, attr)

        for inner in item.inner:
            self.flatten_class(inner)

    def flatten_enumeration_unions(self, item: Class):
        """
        Replace a union of enumerations with the merged enumeration members.

        Applies only to classes with a single attribute named ``value``. If
        every type of that attribute resolves to an enumeration (the single
        inner class for forward references, a class found through
        :meth:`find_class` otherwise), the class attributes are replaced by
        the collected attributes of those enumerations.
        """
        if len(item.attrs) != 1 or item.attrs[0].name != "value":
            return

        all_enums = True
        attrs = []
        for attr_type in item.attrs[0].types:
            is_enumeration = False
            if attr_type.forward_ref and len(item.inner) == 1:
                if item.inner[0].is_enumeration:
                    is_enumeration = True
                    attrs.extend(item.inner[0].attrs)
            elif not attr_type.forward_ref and not attr_type.native:
                common = self.find_class(attr_type)
                if common is not None and common.is_enumeration:
                    is_enumeration = True
                    attrs.extend(common.attrs)

            if not is_enumeration:
                all_enums = False

        if all_enums:
            item.attrs = attrs

    def flatten_extension(self, item: Class, extension: Extension):
        """
        If the extension class is found in the registry prepend its attributes
        to the given class.

        The attribute list is deep cloned and each attribute type is
        prepended with the extension prefix if it isn't a reference to
        another schema.

        Native extensions and extensions that cannot be resolved are left
        untouched; every resolved extension is removed from the class.
        """
        if extension.type.native:
            return

        common = self.find_class(extension.type)
        if common is None:
            return
        elif common is item:
            # Self reference: nothing to copy, the extension is only dropped.
            pass
        elif not item.is_enumeration and common.is_enumeration:
            self.create_default_attribute(item, extension)
        elif not item.is_enumeration or common.is_enumeration:
            self.copy_attributes(common, item, extension)

        item.extensions.remove(extension)

    def flatten_attribute(self, item: Class, attr: Attr):
        """
        If the attribute type is found in the registry overwrite the given
        attribute type and merge the restrictions.

        If the common type doesn't have just one attribute fallback to
        the default xsd type xs:string
        """
        types = []
        for attr_type in attr.types:
            common = None
            if not attr_type.native:
                common = self.find_class(attr_type)

            if common is None:
                attr_type.self_ref = self.is_self_referencing(item, attr_type)
                types.append(attr_type)
            elif common.is_enumeration:
                types.append(attr_type)
            elif len(common.attrs) == 1:
                common_attr = common.attrs[0]
                types.extend(common_attr.types)
                attr.restrictions.update(common_attr.restrictions)
                self.copy_inner_classes(common, item)
            else:
                types.append(AttrType(name=DataType.STRING.code, native=True))
                logger.warning(
                    "Missing type implementation: %s", common.type.__name__
                )

        attr.types = types

    def qname(self, name: str) -> str:
        """
        Return the qualified name text for the given, optionally prefixed,
        name.

        Prefixed names use the namespace registered for the prefix in the
        schema namespace map, the rest use the schema target namespace.
        """
        prefix, suffix = text.split(name)
        namespace = self.schema.target_namespace
        if prefix:
            name = suffix
            namespace = self.schema.nsmap.get(prefix)
        return etree.QName(namespace, name).text

    @staticmethod
    def copy_attributes(source: Class, target: Class, extension: Extension):
        """
        Copy the attributes and inner classes of the source to the target.

        The cloned attributes are inserted before the first target attribute
        whose index is greater than the extension type index, or at the
        start if there is none. The extension restrictions are forced upon
        every clone, and non-native unprefixed types get the prefix of the
        extension type name, if any.
        """
        prefix = text.prefix(extension.type.name)
        target.inner.extend(source.inner)

        position = next(
            (
                index
                for index, attr in enumerate(target.attrs)
                if attr.index > extension.type.index
            ),
            0,
        )
        for attr in source.attrs:
            new_attr = attr.clone()
            new_attr.restrictions.update(extension.restrictions, force=True)
            if prefix:
                for attr_type in new_attr.types:
                    if not attr_type.native and ":" not in attr_type.name:
                        attr_type.name = f"{prefix}:{attr_type.name}"

            target.attrs.insert(position, new_attr)
            position += 1

    @staticmethod
    def copy_inner_classes(source: Class, target: Class):
        """Append to the target the source inner classes whose name is not
        already used by one of the target inner classes."""
        for inner in source.inner:
            if not any(found.name == inner.name for found in target.inner):
                target.inner.append(inner)

    @staticmethod
    def create_default_attribute(item: Class, extension: Extension):
        """Append to the class a ``value`` attribute typed after the extension
        and carrying a clone of the extension restrictions."""
        item.attrs.append(
            Attr(
                name="value",
                index=0,
                default=None,
                types=[extension.type.clone()],
                local_type=TagType.EXTENSION,
                restrictions=extension.restrictions.clone(),
            )
        )
# Module-level ClassReducer instance, created when the module is imported.
reducer = ClassReducer()