from dataclasses import dataclass
from dataclasses import field
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from lxml import etree
from xsdata.logger import logger
from xsdata.models.codegen import Attr
from xsdata.models.codegen import AttrType
from xsdata.models.codegen import Class
from xsdata.models.elements import Schema
from xsdata.models.enums import DataType
from xsdata.utils import text
@dataclass
class ClassReducer:
    """
    Minimize the number of generated classes.

    XSD schemas are often verbose and contain duplicate types. This reducer
    keeps a registry of "common" classes (enumerations, abstract and common
    types) and flattens their traits into the classes that reference them.
    """

    # Registry of common classes keyed by ``etree.QName(namespace, name).text``.
    # Entries accumulate across calls to ``process`` and are never cleared here.
    common_types: Dict[str, Class] = field(default_factory=dict)

    def process(self, schema: Schema, classes: List[Class]) -> List[Class]:
        """
        Process the class list in steps.

        Steps:
            * Separate common types/enumerations from the class list
            * Add all common types/enumerations to the registry
            * Flatten the current common types
            * Flatten the generation types

        :param schema: The schema the classes belong to; used to resolve
            prefixed type names.
        :param classes: The classes to reduce. Common types are popped out of
            this list in place.
        :return: The enumerations among the common types followed by the
            remaining, flattened classes.
        """
        classes, common = self.separate_common_types(classes)
        self.add_common_types(common, schema.target_namespace)
        self.flatten_classes(common, schema)
        self.flatten_classes(classes, schema)
        return [obj for obj in common if obj.is_enumeration] + classes

    def flatten_classes(self, classes: List[Class], schema: Schema):
        """Flatten every class of the given list, in order."""
        for obj in classes:
            self.flatten_class(obj, schema)

    def add_common_types(self, classes: List[Class], namespace: Optional[str]):
        """Add the classes to the common types registry, keyed by their
        qualified name within the given target namespace."""
        self.common_types.update(
            {etree.QName(namespace, obj.name).text: obj for obj in classes}
        )

    def find_common_type(self, name: str, schema: Schema) -> Optional[Class]:
        """
        Find a common type by its qualified name.

        The name prefix (if any) is looked up in ``schema.nsmap``. When the
        prefix is not mapped there, the schema target namespace is used.

        :return: The registered class or ``None`` if there is no match.
        """
        prefix, suffix = text.split(name)
        local_name = suffix if prefix else name
        namespace = schema.nsmap.get(prefix, schema.target_namespace)
        qname = etree.QName(namespace, local_name)
        return self.common_types.get(qname.text)

    def flatten_class(self, item: Class, schema: Schema):
        """
        Flatten class traits from the common types registry.

        Steps:
            * Enum unions
            * Parent classes
            * Attributes
            * Inner classes
        """
        self.flatten_enumeration_unions(item, schema)

        # Iterate over a copy: flatten_extension removes resolved extensions.
        for extension in list(item.extensions):
            self.flatten_extension(item, extension, schema)

        for attr in item.attrs:
            self.flatten_attribute(item, attr, schema)

        for inner in item.inner:
            self.flatten_class(inner, schema)

    def flatten_enumeration_unions(self, item: Class, schema: Schema):
        """
        Replace a union of enumerations with the enumeration members.

        This only applies to classes with a single attribute named ``value``.
        Every type of that attribute must resolve to an enumeration, either
        through a forward reference to the sole inner class or through a
        registered common type. In that case the class attributes are
        replaced by the members of all those enumerations. Otherwise the
        class is left untouched.
        """
        if len(item.attrs) != 1 or item.attrs[0].name != "value":
            return

        attr_types = item.attrs[0].types
        # Without this guard an empty type list would vacuously count as
        # "all enumerations" and wipe the class attributes.
        if not attr_types:
            return

        attrs = []
        for attr_type in attr_types:
            if attr_type.forward_ref:
                enumeration = item.inner[0] if len(item.inner) == 1 else None
            elif not attr_type.native:
                enumeration = self.find_common_type(attr_type.name, schema)
            else:
                enumeration = None

            if enumeration is None or not enumeration.is_enumeration:
                return

            attrs.extend(enumeration.attrs)

        item.attrs = attrs

    def flatten_extension(self, item: Class, extension: AttrType, schema: Schema):
        """
        Flatten an extension that is found in the common types registry.

        The extension's attributes are cloned and inserted into the given
        class (see ``copy_attributes``) unless the class is an enumeration.
        In both cases the extension is then removed from the class. Native
        extensions and extensions missing from the registry are left as-is.
        """
        if extension.native:
            return

        common = self.find_common_type(extension.name, schema)
        if common is None:
            return

        if not item.is_enumeration:
            self.copy_attributes(common, item, extension)

        item.extensions.remove(extension)

    def flatten_attribute(self, item: Class, attr: Attr, schema: Schema):
        """
        Resolve the attribute types against the common types registry.

        Consider each non-native type that resolves to a registered,
        non-enumeration common type:

        * If the common type has exactly one attribute, its types replace the
          original type. Its restrictions fill in any fields of ``attr`` that
          are still ``None``, and its inner classes are copied to ``item``.
        * Otherwise the type falls back to xs:string and a warning is logged.

        Native types, unresolved types and enumerations are kept as-is.
        """
        types = []
        for attr_type in attr.types:
            common = None
            if not attr_type.native:
                common = self.find_common_type(attr_type.name, schema)

            restrictions = {}
            if common is None or common.is_enumeration:
                types.append(attr_type)
            elif len(common.attrs) == 1:
                common_attr = common.attrs[0]
                types.extend(common_attr.types)
                restrictions = common_attr.restrictions
                self.copy_inner_classes(common, item)
            else:
                types.append(AttrType(name=DataType.STRING.code, native=True))
                logger.warning(
                    "Missing type implementation: %s", common.type.__name__
                )

            # Only fill gaps: values already set on the attribute win.
            for key, value in restrictions.items():
                if getattr(attr, key) is None:
                    setattr(attr, key, value)

        attr.types = types

    def separate_common_types(self, classes: List[Class]):
        """
        Split enumerations, abstract and common classes out of the list.

        :return: A ``(classes, common)`` tuple. ``classes`` is the input list,
            mutated in place to hold only the remaining classes.
        """

        def condition(x: Class):
            return x.is_enumeration or x.is_abstract or x.is_common

        matches = self.pop_classes(classes, condition=condition)
        return classes, matches

    @staticmethod
    def pop_classes(classes: List[Class], condition: Callable) -> List[Class]:
        """
        Pop and return the objects of the given list matching the condition.

        The input list is modified in place so that only the non-matching
        objects remain. Both the returned and the remaining objects keep
        their original relative order.
        """
        matches: List[Class] = []
        remaining: List[Class] = []
        for obj in classes:
            (matches if condition(obj) else remaining).append(obj)

        # Slice assignment keeps mutating the caller's list object in place.
        classes[:] = remaining
        return matches

    @staticmethod
    def copy_attributes(source: Class, target: Class, extension: AttrType):
        """
        Insert clones of the source attributes into the target class.

        The clones are inserted in order, before the first target attribute
        whose ``index`` is greater than the extension's ``index``. If there is
        no such attribute they are inserted at the start of the list.

        When the extension name has a prefix, that prefix is added to every
        non-native, unprefixed type of the clones. This way they are resolved
        with the same prefix as the extension itself. The source inner
        classes are copied to the target without duplicating names.
        """
        prefix = text.prefix(extension.name)
        ClassReducer.copy_inner_classes(source, target)

        position = next(
            (
                index
                for index, attr in enumerate(target.attrs)
                if attr.index > extension.index
            ),
            0,
        )

        for attr in source.attrs:
            new_attr = attr.clone()
            if prefix:
                for attr_type in new_attr.types:
                    if not attr_type.native and ":" not in attr_type.name:
                        attr_type.name = f"{prefix}:{attr_type.name}"

            target.attrs.insert(position, new_attr)
            position += 1

    @staticmethod
    def copy_inner_classes(source: Class, target: Class):
        """Append the source inner classes to the target, skipping any whose
        name already exists among the target inner classes."""
        for inner in source.inner:
            if not any(found.name == inner.name for found in target.inner):
                target.inner.append(inner)
# Shared module-level reducer. Its ``common_types`` registry is never cleared,
# so types registered by one ``process`` call remain visible to later calls
# made through this instance.
reducer = ClassReducer(common_types={})