import functools
import math
import re
from decimal import Decimal
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from xml.sax.saxutils import quoteattr
from docformatter import format_code
from lxml.etree import QName
from xsdata.codegen.models import Attr
from xsdata.codegen.models import AttrType
from xsdata.codegen.models import Class
from xsdata.formats.converters import to_python
from xsdata.formats.dataclass import utils
from xsdata.models.enums import DataType
from xsdata.utils import text
[docs]@functools.lru_cache(maxsize=50)
def class_name(name: str) -> str:
"""Apply python conventions for class names."""
return text.pascal_case(utils.safe_snake(name, "type"))
[docs]@functools.lru_cache(maxsize=50)
def attribute_name(name: str) -> str:
"""Apply python conventions for instance variable names."""
return text.snake_case(utils.safe_snake(text.suffix(name)))
[docs]@functools.lru_cache(maxsize=50)
def constant_name(name: str) -> str:
"""Apply python conventions for constant names."""
return text.snake_case(utils.safe_snake(name)).upper()
[docs]def type_name(attr_type: AttrType) -> str:
"""Return native python type name or apply class name conventions."""
return attr_type.native_name or class_name(text.suffix(attr_type.name))
[docs]def class_docstring(obj: Class, enum: bool = False) -> str:
"""Generate docstring for the given class and the constructor arguments."""
lines = []
if obj.help:
lines.append(obj.help)
var_type = "cvar" if enum else "ivar"
name_func = constant_name if enum else attribute_name
for attr in obj.attrs:
description = attr.help.strip() if attr.help else ""
lines.append(f":{var_type} {name_func(attr.name)}: {description}".strip())
return format_code('"""\n{}\n"""'.format("\n".join(lines))) if lines else ""
[docs]def default_imports(output: str) -> str:
"""Generate the default imports for the given package output."""
result = []
if "Decimal" in output:
result.append("from decimal import Decimal")
if "(Enum)" in output:
result.append("from enum import Enum")
dataclasses = []
if "@dataclass" in output:
dataclasses.append("dataclass")
if "field(" in output:
dataclasses.append("field")
if dataclasses:
result.append(f"from dataclasses import {', '.join(dataclasses)}")
if "QName" in output:
result.append("from lxml.etree import QName")
types = [tp for tp in ["Dict", "List", "Optional", "Union"] if f"{tp}[" in output]
if types:
result.append(f"from typing import {', '.join(types)}")
return "\n".join(result)
[docs]def attribute_default(attr: Attr, ns_map: Optional[Dict] = None) -> Any:
"""Generate the field default value/factory for the given attribute."""
if attr.is_list:
return "list"
if attr.is_map:
return "dict"
if not isinstance(attr.default, str):
return attr.default
data_types = {
attr_type.native_code: attr_type.native_type
for attr_type in attr.types
if attr_type.native
}
local_types = list(set(data_types.values()))
default_value = to_python(local_types, attr.default, ns_map, in_order=False)
if isinstance(default_value, str):
if DataType.NMTOKENS.code in data_types:
default_value = quoteattr(
" ".join(filter(None, map(str.strip, re.split(r"\s+", default_value))))
)
elif default_value.startswith("@enum@"):
source, enumeration = default_value[6:].split("::", 1)
attr_type = next(
attr_type
for attr_type in attr.types
if text.suffix(attr_type.name) == source
)
if attr_type.alias:
source = attr_type.alias
default_value = f"{class_name(source)}.{constant_name(enumeration)}"
else:
default_value = quoteattr(default_value)
elif isinstance(default_value, float) and math.isinf(default_value):
default_value = f"float('{default_value}')"
elif isinstance(default_value, Decimal):
default_value = repr(default_value)
elif isinstance(default_value, QName):
default_value = (
f'QName("{default_value.namespace}", "{default_value.localname}")'
)
return default_value
[docs]def attribute_type(attr: Attr, parents: List[str]) -> str:
"""Generate type hints for the given attribute."""
type_names: List[str] = []
for attr_type in attr.types:
name = class_name(attr_type.alias) if attr_type.alias else type_name(attr_type)
if attr_type.forward and attr_type.circular:
outer_str = ".".join(map(class_name, parents))
name = f'"{outer_str}"'
elif attr_type.forward:
outer_str = ".".join(map(class_name, parents))
name = f'"{outer_str}.{name}"'
elif attr_type.circular:
name = f'"{name}"'
if name not in type_names:
type_names.append(name)
result = ", ".join(type_names)
if len(type_names) > 1:
result = f"Union[{result}]"
if attr.is_list:
result = f"List[{result}]"
elif attr.is_map:
result = f"Dict[{result}]"
elif attr.default is None and "Dict" not in result:
result = f"Optional[{result}]"
return result
filters = {
"attribute_name": attribute_name,
"attribute_default": attribute_default,
"attribute_metadata": attribute_metadata,
"attribute_type": attribute_type,
"class_name": class_name,
"class_docstring": class_docstring,
"constant_name": constant_name,
"default_imports": default_imports,
"format_arguments": format_arguments,
"type_name": type_name,
}