import sys
from collections import defaultdict
from typing import Dict
from typing import List
from typing import Optional
from lxml.etree import QName
from xsdata.logger import logger
from xsdata.models.codegen import Attr
from xsdata.models.codegen import AttrType
from xsdata.models.codegen import Class
from xsdata.models.codegen import Extension
from xsdata.models.codegen import Restrictions
from xsdata.models.enums import DataType
from xsdata.models.enums import NamespaceType
from xsdata.models.enums import Tag
from xsdata.utils import text
class ClassUtils:
    """
    Helpers that compare, sanitize, merge and create the attributes of code
    generator classes.

    Every helper is a classmethod and most of them mutate the objects
    they receive in place instead of returning new ones.
    """

    # Result codes returned by ``compare_attributes``.
    INCLUDES_NONE = 0
    INCLUDES_SOME = 1
    INCLUDES_ALL = 2

    @classmethod
    def compare_attributes(cls, source: Class, target: Class) -> int:
        """
        Report how many of the source attribute names exist in the target.

        Returns ``INCLUDES_ALL`` when both arguments are the same object or
        no source attribute name is missing from the target,
        ``INCLUDES_SOME`` when only a part of the names is present and
        ``INCLUDES_NONE`` when the target has no attributes or none of the
        names is present.
        """
        if source is target:
            return cls.INCLUDES_ALL
        if not target.attrs:
            return cls.INCLUDES_NONE

        source_names = {attr.name for attr in source.attrs}
        target_names = {attr.name for attr in target.attrs}
        missing = source_names - target_names

        if not missing:
            return cls.INCLUDES_ALL
        if len(missing) != len(source_names):
            return cls.INCLUDES_SOME
        return cls.INCLUDES_NONE

    @classmethod
    def sanitize_attributes(cls, target: Class):
        """
        Sanitize the attributes of the target class and, recursively, of
        all its inner classes.

        The first pass normalizes every attribute and its restrictions on
        its own; the second pass runs only afterwards because the sequence
        and name checks look at the sibling attributes as well.
        """
        for attr in target.attrs:
            cls.sanitize_attribute(attr)
            cls.sanitize_restrictions(attr.restrictions)

        # The sibling-aware helpers take the list plus a position.
        for index in range(len(target.attrs)):
            cls.sanitize_attribute_sequence(target.attrs, index)
            cls.sanitize_attribute_name(target.attrs, index)

        for inner in target.inner:
            cls.sanitize_attributes(inner)

    @classmethod
    def sanitize_attribute(cls, attr: Attr):
        """
        Drop the flags that make no sense for the given attribute.

        List attributes lose the fixed flag, non list attributes lose the
        sequential restriction, and optional or xsi:type attributes lose
        both the fixed flag and the default value.
        """
        if attr.is_list:
            attr.fixed = False
        else:
            attr.restrictions.sequential = False

        if attr.is_optional or attr.is_xsi_type:
            attr.fixed = False
            attr.default = None

    @classmethod
    def sanitize_restrictions(cls, restrictions: Restrictions):
        """
        Normalize the occurrence restrictions in place.

        Missing occurrences count as zero. The three cases are mutually
        exclusive:

        - min 0 and max <= 1: required and both occurrences are cleared
        - min 1 and max 1: required is set, both occurrences are cleared
        - max > 1: min occurs is made explicit and required is cleared

        Any other combination is left untouched.
        """
        min_occurs = restrictions.min_occurs or 0
        max_occurs = restrictions.max_occurs or 0

        if min_occurs == 0 and max_occurs <= 1:
            restrictions.required = None
            restrictions.min_occurs = None
            restrictions.max_occurs = None
        elif min_occurs == 1 and max_occurs == 1:
            restrictions.required = True
            restrictions.min_occurs = None
            restrictions.max_occurs = None
        elif max_occurs > 1:
            restrictions.min_occurs = min_occurs
            restrictions.required = None

    @classmethod
    def sanitize_attribute_sequence(cls, attrs: List[Attr], index: int):
        """Reset the attribute at the given index if it has no siblings with
        the sequential restriction."""
        current = attrs[index].restrictions
        if not current.sequential:
            return

        previous_is_sequential = (
            index > 0 and attrs[index - 1].restrictions.sequential
        )
        next_is_sequential = (
            index + 1 < len(attrs) and attrs[index + 1].restrictions.sequential
        )

        # A sequence of one is not a sequence.
        if not (previous_is_sequential or next_is_sequential):
            current.sequential = False

    @classmethod
    def sanitize_attribute_name(cls, attrs: List[Attr], index: int):
        """Check if the attribute at the given index has a duplicate name and
        prepend if exists the attribute namespace."""
        current = attrs[index]
        current.name = text.suffix(current.name)

        has_duplicate = any(
            other is not current and text.suffix(other.name) == current.name
            for other in attrs
        )

        # Without a namespace there is nothing to disambiguate the name with.
        if has_duplicate and current.namespace:
            current.name = f"{current.namespace}_{current.name}"

    @classmethod
    def merge_duplicate_attributes(cls, target: Class):
        """
        Flatten duplicate attributes.

        Remove duplicate fields in case of attributes or enumerations
        otherwise convert fields to lists.
        """
        if not target.attrs:
            return

        result: List[Attr] = []
        for attr in target.attrs:
            pos = cls.find_attribute(result, attr)
            if pos == -1:
                result.append(attr)
                continue

            if attr.is_attribute or attr.is_enumeration:
                # Plain duplicate: the first occurrence wins.
                continue

            # Fold the duplicate into the first occurrence: the lowest
            # min occurs wins and the max occurs add up, a missing max
            # occurs counting as one.
            existing = result[pos]
            merged = existing.restrictions
            duplicate = attr.restrictions

            merged.min_occurs = min(
                merged.min_occurs or 0, duplicate.min_occurs or 0
            )
            merged.max_occurs = (merged.max_occurs or 1) + (
                duplicate.max_occurs or 1
            )
            merged.sequential = merged.sequential or duplicate.sequential
            existing.fixed = False

        target.attrs = result

    @classmethod
    def copy_attributes(cls, source: Class, target: Class, extension: Extension):
        """
        Copy the source attributes and inner classes to the target class
        and remove the given extension from the target.

        Source attributes whose name already exists in the target are
        skipped. Attributes with index ``sys.maxsize`` are appended to the
        end of the target attributes, the rest are inserted at the front
        in source order.
        """
        prefix = text.prefix(extension.type.name)
        target.extensions.remove(extension)
        # Snapshot taken before any copy; it is not updated by the loop.
        existing_names = {text.suffix(attr.name) for attr in target.attrs}

        # NOTE(review): the insert position advances for every source
        # attribute that is not appended, whether it was copied or
        # skipped - confirm this nesting against the upstream file.
        index = 0
        for attr in source.attrs:
            if text.suffix(attr.name) not in existing_names:
                clone = cls.clone_attribute(attr, extension.restrictions, prefix)

                if attr.index == sys.maxsize:
                    target.attrs.append(clone)
                    continue

                target.attrs.insert(index, clone)
            index += 1

        cls.copy_inner_classes(source, target)

    @classmethod
    def clone_attribute(
        cls, attr: Attr, restrictions: Restrictions, prefix: Optional[str] = None
    ) -> Attr:
        """
        Return a clone of the attribute with the given restrictions merged
        into the restrictions of the clone.

        If a prefix is given it is prepended to the name of every non
        native type of the clone that has no ``:`` in its name yet.
        """
        clone = attr.clone()
        clone.restrictions.merge(restrictions)

        if prefix:
            for attr_type in clone.types:
                if not attr_type.native and ":" not in attr_type.name:
                    attr_type.name = f"{prefix}:{attr_type.name}"

        return clone

    @classmethod
    def merge_attribute_type(
        cls, source: Class, target: Class, attr: Attr, attr_type: AttrType
    ):
        """
        Replace the given type of the attribute with the types of the only
        attribute of the source class.

        The attribute restrictions are merged on top of a clone of the
        source attribute restrictions and the source inner classes are
        copied to the target. If the source class does not have exactly one
        attribute a warning is logged and the type is reset instead.
        """
        if len(source.attrs) != 1:
            logger.warning("Missing implementation: %s", source.type.__name__)
            cls.reset_attribute_type(attr_type)
            return

        source_attr = source.attrs[0]
        index = attr.types.index(attr_type)
        # Swap the single type for clones of the source types, in order.
        attr.types[index : index + 1] = [
            source_type.clone() for source_type in source_attr.types
        ]

        restrictions = source_attr.restrictions.clone()
        restrictions.merge(attr.restrictions)
        attr.restrictions = restrictions
        # NOTE(review): assumed to run only on the successful merge path -
        # confirm this nesting against the upstream file.
        cls.copy_inner_classes(source, target)

    @classmethod
    def copy_inner_classes(cls, source: Class, target: Class):
        """
        Copy inner classes from source to target class.

        Check for duplicates by name and skip if it already exists.
        """
        for inner in source.inner:
            if all(existing.name != inner.name for existing in target.inner):
                target.inner.append(inner)

    @classmethod
    def merge_redefined_classes(cls, classes: List[Class]):
        """
        Merge original and redefined classes.

        Classes are grouped by their type name and source qname. In every
        group with more than one class the last one is kept and the rest
        are removed from the given list. When the kept class has an
        extension pointing to its own name, the attributes and the
        extensions of the removed class are copied over to it.
        """
        grouped: Dict[str, List[Class]] = defaultdict(list)
        for item in classes:
            grouped[f"{item.type.__name__}{item.source_qname()}"].append(item)

        for items in grouped.values():
            if len(items) == 1:
                continue

            winner: Class = items.pop()
            for item in items:
                classes.remove(item)

                # Looked up on every iteration: copy_attributes removes the
                # extension from the winner once it has been used.
                self_extension = next(
                    (
                        ext
                        for ext in winner.extensions
                        if text.suffix(ext.type.name) == winner.name
                    ),
                    None,
                )
                if not self_extension:
                    continue

                cls.copy_attributes(item, winner, self_extension)
                for loser_ext in item.extensions:
                    new_ext = loser_ext.clone()
                    new_ext.restrictions.merge(self_extension.restrictions)
                    winner.extensions.append(new_ext)

    @classmethod
    def update_abstract_classes(cls, classes: List[Class]):
        """
        Update classes with the same qualified name to set implied abstract
        flags.

        If a non abstract xs:element exists in the list mark the rest
        xs:complexType(s) as abstract.
        """
        element = next((obj for obj in classes if obj.is_element), None)
        if not element:
            return

        for obj in classes:
            if obj is not element and obj.is_complex:
                obj.abstract = True

    @classmethod
    def create_mixed_attribute(cls, target: Class):
        """
        Prepend a wildcard ``content`` attribute to the target class.

        Nothing happens if the class is not mixed or if it already has a
        wildcard attribute.
        """
        if not target.mixed or target.has_wild_attr:
            return

        attr = Attr(
            name="content",
            local_name="content",
            index=0,
            types=[AttrType(name=DataType.ANY_TYPE.code, native=True)],
            tag=Tag.ANY,
            namespace=NamespaceType.ANY.value,
        )
        target.attrs.insert(0, attr)

    @classmethod
    def create_default_attribute(cls, item: Class, extension: Extension):
        """
        Replace the given extension of the class with a new first attribute.

        An extension of the any type becomes a wildcard ``any_element``
        attribute, every other extension becomes a ``value`` attribute. The
        new attribute gets clones of the extension type and restrictions.
        """
        if extension.type.native_code == DataType.ANY_TYPE.code:
            attr = Attr(
                name="any_element",
                local_name="any_element",
                index=0,
                default=list if extension.restrictions.is_list else None,
                types=[extension.type.clone()],
                tag=Tag.ANY,
                namespace=NamespaceType.ANY.value,
                restrictions=extension.restrictions.clone(),
            )
        else:
            attr = Attr(
                name="value",
                local_name="value",
                index=0,
                default=None,
                types=[extension.type.clone()],
                tag=Tag.EXTENSION,
                restrictions=extension.restrictions.clone(),
            )

        item.attrs.insert(0, attr)
        item.extensions.remove(extension)

    @classmethod
    def create_reference_attribute(cls, source: Class, qname: QName) -> Attr:
        """
        Return a new attribute whose only type references the source class.

        The type name is qualified with the source prefix when the
        namespace of the given qname differs from the source namespace and
        the source has a prefix.
        """
        prefix = None
        if qname.namespace != source.source_namespace:
            prefix = source.source_prefix

        reference = f"{prefix}:{source.name}" if prefix else source.name

        return Attr(
            name=source.name,
            local_name=source.name,
            index=0,
            default=None,
            types=[AttrType(name=reference)],
            tag=source.type.__name__,
            namespace=source.namespace,
        )

    @classmethod
    def find_attribute(cls, attrs: List[Attr], attr: Attr) -> int:
        """Return the position of the first item in the list that compares
        equal to the given attribute or -1 if there is none."""
        try:
            return attrs.index(attr)
        except ValueError:
            return -1

    @classmethod
    def reset_attribute_type(cls, attr_type: AttrType):
        """Turn the given type into a native string type and clear its self
        and forward reference flags."""
        attr_type.name = DataType.STRING.code
        attr_type.native = True
        attr_type.self_ref = False
        attr_type.forward_ref = False