# Source code for xsdata.codegen.handlers.class_extension

from dataclasses import dataclass
from typing import Optional

from xsdata.codegen.mixins import ContainerInterface
from xsdata.codegen.mixins import HandlerInterface
from xsdata.codegen.models import Attr
from xsdata.codegen.models import AttrType
from xsdata.codegen.models import Class
from xsdata.codegen.models import Extension
from xsdata.codegen.utils import ClassUtils
from xsdata.logger import logger
from xsdata.models.enums import DataType
from xsdata.models.enums import NamespaceType
from xsdata.models.enums import Tag
from xsdata.models.xsd import ComplexType
from xsdata.models.xsd import SimpleType


@dataclass
class ClassExtensionHandler(HandlerInterface):
    """
    Reduce class extensions by copying or creating new attributes.

    Every extension of a target class ends up in one of three states:
    removed, flattened (the extended class's attributes are copied into the
    target) or left untouched.
    """

    # Verdicts produced by compare_attributes(). They are intentionally not
    # annotated so the dataclass does not turn them into fields.
    REMOVE_EXTENSION = 0
    FLATTEN_EXTENSION = 1
    IGNORE_EXTENSION = 2

    container: ContainerInterface

    def process(self, target: Class):
        """
        Handle each extension of the target class, starting from the last.

        Going backwards keeps the attributes in the right order when they
        are cloned into the target class.
        """
        # The per-extension handlers may drop the current extension from
        # target.extensions mid-loop; a list reverse iterator only moves to
        # lower indexes afterwards, so that removal does not skip anything.
        pending = target.extensions
        for ext in reversed(pending):
            self.process_extension(target, ext)

    def process_extension(self, target: Class, extension: Extension):
        """
        Route an extension to the proper handler.

        Extensions of built-in schema data types go to the native handler,
        everything else (user defined types) to the dependency handler.
        """
        if extension.type.native:
            handler = self.process_native_extension
        else:
            handler = self.process_dependency_extension
        handler(target, extension)

    @classmethod
    def process_native_extension(cls, target: Class, extension: Extension):
        """
        Flatten an extension of a built-in data type.

        Non-enumeration targets receive a default value attribute built from
        the extension; enumerations get the native type appended to every
        member instead.
        """
        if not target.is_enumeration:
            cls.add_default_attribute(target, extension)
            return

        cls.copy_extension_type(target, extension)

    def process_dependency_extension(self, target: Class, extension: Extension):
        """
        Flatten an extension of a user defined type.

        An unresolved type is logged and its extension dropped. Complex,
        non-enumeration sources take the complex path; all others the simple
        one.
        """
        source = self.find_dependency(extension.type)
        if not source:
            logger.warning("Missing extension type: %s", extension.type.name)
            target.extensions.remove(extension)
            return

        if source.is_complex and not source.is_enumeration:
            self.process_complex_extension(source, target, extension)
        else:
            self.process_simple_extension(source, target, extension)

    @classmethod
    def process_simple_extension(
        cls, source: Class, target: Class, ext: Extension
    ):
        """
        Flatten an extension whose source is a simple class (e.g. SimpleType,
        Restriction).

        Cases:
            - source is the target itself: drop the extension;
            - source is an enumeration, target is not: add a default value
              attribute;
            - both or neither are enumerations: copy the source attributes;
            - only the target is an enumeration: drop the extension.
        """
        if source is target:
            target.extensions.remove(ext)
            return

        source_is_enum = source.is_enumeration
        target_is_enum = target.is_enumeration

        if source_is_enum and not target_is_enum:
            cls.add_default_attribute(target, ext)
        elif source_is_enum == target_is_enum:
            ClassUtils.copy_attributes(source, target, ext)
        else:
            # Only the target is an enumeration: nothing to inherit.
            target.extensions.remove(ext)

    @classmethod
    def process_complex_extension(
        cls, source: Class, target: Class, ext: Extension
    ):
        """
        Flatten an extension whose source is a primary class (e.g.
        ComplexType, Element).

        The verdict of compare_attributes() decides whether the extension is
        removed, flattened by copying the source attributes, or kept (only a
        debug message is logged in that case).
        """
        verdict = cls.compare_attributes(source, target)

        if verdict == cls.REMOVE_EXTENSION:
            target.extensions.remove(ext)
        elif verdict == cls.FLATTEN_EXTENSION:
            ClassUtils.copy_attributes(source, target, ext)
        else:
            logger.debug("Ignore extension: %s", ext.type.name)

    def find_dependency(self, attr_type: AttrType) -> Optional[Class]:
        """
        Look up the class an extension type refers to.

        Matches whose type is xs:SimpleType win over xs:ComplexType ones.
        Returns None when neither lookup yields a result.
        """
        by_priority = (
            lambda x: x.type is SimpleType,
            lambda x: x.type is ComplexType,
        )
        # Lazy generator: the ComplexType lookup only runs when the
        # SimpleType lookup found nothing.
        lookups = (
            self.container.find(attr_type.qname, condition=condition)
            for condition in by_priority
        )
        return next((found for found in lookups if found), None)

    @classmethod
    def compare_attributes(cls, source: Class, target: Class) -> int:
        """
        Decide what to do with an extension by comparing both classes.

        Returns:
            REMOVE_EXTENSION when the source is the target, or when both have
            attributes and the target already defines every source attribute
            name.

            FLATTEN_EXTENSION when both have attributes and their names
            partially overlap, or when any of these holds:
                - the source has strict_type set;
                - the source has a suffix attribute (one that must stay last)
                  and the target has attributes;
                - the target has a suffix attribute;
                - the source is a simple type.

            IGNORE_EXTENSION otherwise.
        """
        if source is target:
            return cls.REMOVE_EXTENSION

        if target.attrs and source.attrs:
            source_names = {item.name for item in source.attrs}
            target_names = {item.name for item in target.attrs}

            if source_names <= target_names:
                # Target already has every source attribute.
                return cls.REMOVE_EXTENSION
            if source_names & target_names:
                # Some, but not all, source attributes are already present.
                return cls.FLATTEN_EXTENSION

        must_flatten = (
            source.strict_type
            or (source.has_suffix_attr and target.attrs)
            or target.has_suffix_attr
            or source.is_simple_type
        )
        return cls.FLATTEN_EXTENSION if must_flatten else cls.IGNORE_EXTENSION

    @classmethod
    def copy_extension_type(cls, target: Class, extension: Extension):
        """
        Append a clone of the extension type to the types of every target
        attribute, then drop the extension from the target class.
        """
        ext_type = extension.type
        for member in target.attrs:
            # Each attribute gets its own copy of the type.
            member.types.append(ext_type.clone())
        target.extensions.remove(extension)

    @classmethod
    def add_default_attribute(cls, target: Class, extension: Extension):
        """
        Give the target a default value field derived from the extension
        type, then drop the extension.

        An anyType extension becomes an "any_element" wildcard attribute.
        Its namespace is NamespaceType.ANY, and its default is `list` when
        the restrictions mark it as a list. Any other native type becomes a
        plain "value" attribute with no namespace or default.
        """
        is_any_type = extension.type.native_code == DataType.ANY_TYPE.code

        if is_any_type:
            tag = Tag.ANY
            name = "any_element"
            namespace = NamespaceType.ANY
            default = list if extension.restrictions.is_list else None
        else:
            tag, name, default, namespace = Tag.EXTENSION, "value", None, None

        field_attr = cls.get_or_create_attribute(target, name, tag)
        field_attr.types.append(extension.type.clone())
        field_attr.restrictions.merge(extension.restrictions)
        field_attr.namespace = namespace
        field_attr.default = default
        target.extensions.remove(extension)

    @classmethod
    def get_or_create_attribute(cls, target: Class, name: str, tag: str) -> Attr:
        """
        Return the target attribute with the given name and tag.

        A match must have both its name and local name equal to `name`. If
        none exists, a new attribute is created and inserted as the first
        attribute of the target class.
        """
        existing = next(
            (
                candidate
                for candidate in target.attrs
                if candidate.name == candidate.local_name == name
                and candidate.tag == tag
            ),
            None,
        )
        if existing is not None:
            return existing

        created = Attr(name=name, local_name=name, tag=tag)
        target.attrs.insert(0, created)
        return created