# Source code for xsdata.reducer

import copy
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional

from lxml import etree

from xsdata.logger import logger
from xsdata.models.codegen import Attr, Class, Extension
from xsdata.models.elements import Schema
from xsdata.models.enums import XSDType
from xsdata.utils.text import split


@dataclass
class ClassReducer:
    """
    The purpose of this class is to minimize the number of generated classes
    because of excess verbosity in the given xsd schema and duplicate types.
    """

    # Registry of common types (enumerations, abstract and common classes),
    # keyed by the text of their qualified name. Entries are only ever added
    # here, never removed, so the registry accumulates across successive
    # ``process`` calls on the same instance.
    common_types: Dict[str, Class] = field(default_factory=dict)

    def process(self, schema: Schema, classes: List[Class]) -> List[Class]:
        """
        Process class list in steps.

        Steps:
            * Separate common/enumerations from class list
            * Add all common/enumerations to registry
            * Flatten the current common types
            * Flatten the generation types

        :param schema: The source schema; its target namespace and namespace
            map are used to qualify and look up type names.
        :param classes: The generated classes. This list is mutated in place:
            the common types are popped out of it.
        :return: The final list of normalized classes/enumerations
        """
        namespace = schema.target_namespace
        nsmap = schema.nsmap
        classes, common = self.separate_common_types(classes)
        self.add_common_types(common, namespace)
        self.flatten_classes(common, nsmap)
        self.flatten_classes(classes, nsmap)
        # Of the common group only enumerations are returned; the remaining
        # common types are left out of the result.
        return [obj for obj in common if obj.is_enumeration] + classes

    def flatten_classes(self, classes: List[Class], nsmap: Dict):
        """Flatten, in place, every class of the given list."""
        for obj in classes:
            self.flatten_class(obj, nsmap)

    def add_common_types(self, classes: List[Class], namespace: Optional[str]):
        """Add class to the common types registry with its qualified name with
        the target namespace."""
        self.common_types.update(
            {etree.QName(namespace, obj.name).text: obj for obj in classes}
        )

    def find_common_type(self, name: str, nsmap: Dict) -> Optional[Class]:
        """
        Find a common type by the qualified named with the namespace prefix.

        :param name: A type name, optionally in ``prefix:name`` form.
        :param nsmap: Prefix to namespace map used to resolve the prefix; an
            unprefixed name is looked up under the ``None`` key.
        :return: The registered class or None if it is not registered.
        """
        prefix = None
        split_name = name.split(":")
        if len(split_name) == 2:
            prefix, name = split_name

        namespace = nsmap.get(prefix)
        qname = etree.QName(namespace, name)
        return self.common_types.get(qname.text)

    def flatten_class(self, item: Class, nsmap: Dict):
        """
        Flatten class traits from the common types registry.

        Steps:
            * Parent classes
            * Attributes
            * Inner classes
        """
        # Iterate over a copy: flatten_extension removes resolved extensions.
        for extension in list(item.extensions):
            self.flatten_extension(item, extension, nsmap)

        for attr in item.attrs:
            self.flatten_attribute(attr, nsmap)

        for inner in item.inner:
            self.flatten_class(inner, nsmap)

    def flatten_extension(
        self, item: Class, extension: Extension, nsmap: Dict
    ):
        """
        If the extension class is found in the registry copy its inner classes
        and attributes into the given class and remove the extension.

        The attribute list is deep cloned and inserted before the first
        existing attribute whose index is greater than the extension index,
        or at the start when there is no such attribute. If the extension
        name has a prefix, every unprefixed member of each cloned attribute
        type is qualified with it.
        """
        common = self.find_common_type(extension.name, nsmap)
        if common is None:
            return

        prefix, _ = split(extension.name)
        item.inner.extend(copy.deepcopy(common.inner))

        # NOTE(review): when every existing attribute precedes the extension
        # the clones still go first (default 0); confirm that is intended.
        position = next(
            (
                index
                for index, attr in enumerate(item.attrs)
                if attr.index > extension.index
            ),
            0,
        )
        for attr in copy.deepcopy(common.attrs):
            if prefix:
                attr.type = self._prefix_type(attr.type, prefix)
            item.attrs.insert(position, attr)
            position += 1

        item.extensions.remove(extension)

    @staticmethod
    def _prefix_type(type_name: str, prefix: str) -> str:
        """Qualify every unprefixed member of a space separated type string."""
        return " ".join(
            name if ":" in name else f"{prefix}:{name}"
            for name in type_name.split()
        )

    def flatten_attribute(self, attr: Attr, nsmap: Dict):
        """
        Replace attribute type members found in the common types registry.

        Each member of the attribute type that is a registered, non
        enumeration common type is replaced by the type of that common type's
        single attribute, and that attribute's restrictions are set on the
        given attribute. If the common type doesn't have exactly one
        attribute, fallback to the default xsd type xs:string and log a
        warning. Members that are not resolved are kept, in their original
        order. The attribute type is left untouched when nothing resolves.
        """
        types: List[str] = []
        resolved = False
        for type_name in attr.types:
            common = self.find_common_type(type_name, nsmap)
            if common is None or common.is_enumeration:
                # Keep unresolved members instead of dropping them.
                types.append(type_name)
                continue

            if len(common.attrs) == 1:
                type_name = common.attrs[0].type
                restrictions = common.attrs[0].restrictions
            else:
                type_name = XSDType.STRING.code
                restrictions = {}
                logger.warning(
                    f"Missing type implementation: {common.type.__name__}"
                )

            types.append(type_name)
            for key, value in restrictions.items():
                setattr(attr, key, value)
            resolved = True

        if resolved:
            attr.type = " ".join(types)

    def separate_common_types(self, classes: List[Class]):
        """
        Pop the enumerations, abstract and common classes out of the list.

        :return: A tuple of the (mutated) remaining classes and the popped
            common classes, both in their original relative order.
        """

        def condition(x: Class):
            return x.is_enumeration or x.is_abstract or x.is_common

        matches = self.pop_classes(classes, condition=condition)
        return classes, matches

    @staticmethod
    def pop_classes(classes: List[Class], condition: Callable) -> List[Class]:
        """Pop and return the objects matching the given condition from the
        given list of classes, preserving their original order."""
        matches = []
        # Walk backwards so popping doesn't shift the indices still to visit.
        for i in range(len(classes) - 1, -1, -1):
            if condition(classes[i]):
                matches.append(classes.pop(i))

        return list(reversed(matches))
# Shared module-level reducer. Its ``common_types`` registry lives on this
# instance and is never cleared by ClassReducer, so registered types persist
# between ``process`` calls.
reducer: ClassReducer = ClassReducer()