import json
import math
from collections import defaultdict
from dataclasses import dataclass
from dataclasses import field
from decimal import Decimal
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Type
from xml.etree.ElementTree import QName
from xml.sax.saxutils import quoteattr

from docformatter import format_code
from jinja2 import Environment

from xsdata.codegen.models import Attr
from xsdata.codegen.models import AttrChoice
from xsdata.codegen.models import AttrType
from xsdata.codegen.models import Class
from xsdata.formats.converter import converter
from xsdata.formats.dataclass import utils
from xsdata.models.config import GeneratorAlias
from xsdata.models.config import GeneratorConfig
from xsdata.utils import text
from xsdata.utils.collections import unique_sequence
from xsdata.utils.namespaces import clean_uri
# Bucket keys for the per-kind name cache kept in `Filters.cache`:
# class names, field names, module names and package names respectively.
CLASS, FIELD, MODULE, PACKAGE = range(4)


@dataclass
class Filters:
    """Jinja2 filters used by the dataclass code generator.

    An instance holds the naming conventions (case functions, safe prefixes)
    and the user defined aliases for class, field, package and module names,
    usually built with :meth:`from_config`. Every converted name is memoized
    in :attr:`cache`.
    """

    # Exact name overrides (source name -> target name), one map per kind.
    class_aliases: Dict = field(default_factory=dict)
    field_aliases: Dict = field(default_factory=dict)
    package_aliases: Dict = field(default_factory=dict)
    module_aliases: Dict = field(default_factory=dict)
    # Case conversion callables, applied to the output of utils.safe_snake.
    class_case: Callable = field(default=text.pascal_case)
    field_case: Callable = field(default=text.snake_case)
    package_case: Callable = field(default=text.snake_case)
    module_case: Callable = field(default=text.snake_case)
    # Prefixes handed to utils.safe_snake for each name kind.
    class_safe_prefix: str = field(default="type")
    field_safe_prefix: str = field(default="value")
    package_safe_prefix: str = field(default="pkg")
    module_safe_prefix: str = field(default="mod")
    # Converted names: bucketed by CLASS / FIELD / MODULE / PACKAGE, then
    # keyed by the original name.
    cache: Dict = field(default_factory=lambda: defaultdict(dict), init=False)

    def register(self, env: Environment):
        """Register this instance's filters on the given jinja2 environment."""
        env.filters.update(
            {
                "field_name": self.field_name,
                "field_default": self.field_default_value,
                "field_metadata": self.field_metadata,
                "field_type": self.field_type,
                "class_name": self.class_name,
                "class_docstring": self.class_docstring,
                "constant_name": self.constant_name,
                "constant_value": self.constant_value,
                "default_imports": self.default_imports,
                "format_metadata": self.format_metadata,
                "type_name": self.type_name,
            }
        )

    def class_name(self, name: str) -> str:
        """Convert the given string to a class name according to the selected
        conventions or use an existing alias."""
        cache = self.cache[CLASS]
        if name not in cache:
            # A falsy alias (e.g. empty string) falls back to the conventions.
            cache[name] = self.class_aliases.get(name) or self._class_name(name)
        return cache[name]

    def _class_name(self, name: str) -> str:
        """Apply the class safe prefix and case convention to ``name``."""
        return self.class_case(utils.safe_snake(name, self.class_safe_prefix))

    def field_name(self, name: str) -> str:
        """Convert the given string to a field name according to the selected
        conventions or use an existing alias."""
        cache = self.cache[FIELD]
        if name not in cache:
            cache[name] = self.field_aliases.get(name) or self._attribute_name(name)
        return cache[name]

    def _attribute_name(self, name: str) -> str:
        """Apply the field safe prefix and case convention to ``name``."""
        return self.field_case(utils.safe_snake(name, self.field_safe_prefix))

    def constant_name(self, name: str) -> str:
        """Apply python conventions for constant names."""
        return self.field_name(name).upper()

    def module_name(self, name: str) -> str:
        """Convert the given string to a module name according to the selected
        conventions or use an existing alias."""
        cache = self.cache[MODULE]
        if name not in cache:
            cache[name] = self.module_aliases.get(name) or self._module_name(name)
        return cache[name]

    def _module_name(self, name: str) -> str:
        """Clean the uri ``name`` and apply the module prefix and case."""
        return self.module_case(
            utils.safe_snake(clean_uri(name), self.module_safe_prefix)
        )

    def package_name(self, name: str) -> str:
        """Convert the given string to a package name according to the selected
        conventions or use an existing alias.

        A full dotted name may be aliased as a whole; otherwise every dotted
        part is aliased or converted on its own.
        """
        cache = self.cache[PACKAGE]
        if name not in cache:
            if name in self.package_aliases:
                cache[name] = self.package_aliases[name]
            else:
                cache[name] = ".".join(
                    self.package_aliases.get(part) or self._package_name(part)
                    for part in name.split(".")
                )
        return cache[name]

    def _package_name(self, part: str) -> str:
        """Apply the package safe prefix and case convention to ``part``."""
        return self.package_case(utils.safe_snake(part, self.package_safe_prefix))

    def type_name(self, attr_type: AttrType) -> str:
        """Return native python type name or apply class name conventions."""
        return attr_type.native_name or self.class_name(attr_type.name)

    def field_choices(
        self, attr: Attr, parent_namespace: Optional[str], parents: List[str]
    ) -> Optional[List]:
        """
        Return a list of metadata dictionaries for the choices of the given
        attribute.

        Return None if attribute has no choices.
        """
        if not attr.choices:
            return None

        def build(choice: AttrChoice) -> Dict:
            types = list({x.native_type for x in choice.types if x.native})
            restrictions = choice.restrictions.asdict(types)
            # Only emit the namespace when it differs from the parent's.
            namespace = (
                choice.namespace if parent_namespace != choice.namespace else None
            )

            return self.filter_metadata(
                {
                    "name": choice.name,
                    "wildcard": choice.is_wildcard,
                    "type": self.choice_type(choice, parents),
                    "namespace": namespace,
                    **restrictions,
                }
            )

        return list(map(build, attr.choices))

    def class_docstring(self, obj: Class, enum: bool = False) -> str:
        """Generate docstring for the given class and the constructor
        arguments.

        Enum members are documented as ``:cvar`` with constant names, other
        attributes as ``:ivar`` with field names. Returns an empty string
        when there is nothing to document.
        """
        lines = []
        if obj.help:
            lines.append(obj.help)

        var_type = "cvar" if enum else "ivar"
        name_func = self.constant_name if enum else self.field_name

        for attr in obj.attrs:
            description = attr.help.strip() if attr.help else ""
            lines.append(f":{var_type} {name_func(attr.name)}: {description}".strip())

        return format_code('"""\n{}\n"""'.format("\n".join(lines))) if lines else ""

    def field_default_value(self, attr: Attr, ns_map: Optional[Dict] = None) -> Any:
        """Generate the field default value/factory for the given attribute.

        List (and token attributes without a default) and dict attributes
        return the factory name ``"list"`` / ``"dict"``. Non string defaults
        are returned untouched, ``@enum@`` defaults become enum member
        references, and any other string default is converted with the
        attribute's native types and rendered as python source.
        """
        if attr.is_list or (attr.is_tokens and not attr.default):
            return "list"
        if attr.is_dict:
            return "dict"
        if not isinstance(attr.default, str):
            return attr.default
        if attr.default.startswith("@enum@"):
            return self.field_default_enum(attr)

        types = converter.sort_types(
            list(
                {attr_type.native_type for attr_type in attr.types if attr_type.native}
            )
        )

        if attr.is_tokens:
            return self.field_default_tokens(attr, types)

        return self.prepare_default_value(
            converter.from_string(attr.default, types, ns_map=ns_map)
        )

    def field_default_enum(self, attr: Attr) -> str:
        """Render an ``@enum@<source>::<member>`` default as ``Class.MEMBER``.

        The enum class name is taken from the alias of the matching attribute
        type when one is set; raises StopIteration when no attribute type
        matches ``<source>``.
        """
        source, enumeration = attr.default[6:].split("::", 1)
        source = next(x.alias or source for x in attr.types if x.name == source)
        return f"{self.class_name(source)}.{self.constant_name(enumeration)}"

    def field_default_tokens(self, attr: Attr, types: List[Type]) -> str:
        """Render a whitespace separated tokens default as a list factory."""
        assert isinstance(attr.default, str)
        tokens = ", ".join(
            str(self.prepare_default_value(converter.from_string(val, types)))
            for val in attr.default.split()
        )
        return f"lambda: [{tokens}]"

    def field_type(self, attr: Attr, parents: List[str]) -> str:
        """Generate type hints for the given attribute."""
        type_names = unique_sequence(
            self.field_type_name(x, parents) for x in attr.types
        )

        result = ", ".join(type_names)
        if len(type_names) > 1:
            result = f"Union[{result}]"

        if attr.is_tokens:
            result = f"List[{result}]"

        if attr.is_list:
            result = f"List[{result}]"
        elif attr.is_dict:
            result = "Dict"
        elif attr.default is None and not attr.is_factory:
            result = f"Optional[{result}]"

        return result

    def choice_type(self, choice: AttrChoice, parents: List[str]) -> str:
        """
        Generate type hints for the given choice.

        Choices support a subset of features from normal attributes.
        First of all we don't have a proper type hint but a type
        metadata key. That's why we always need to wrap as Type[xxx].
        The second big difference is that our choice belongs to a
        compound field that might be a list, that's why list restriction
        is also ignored.
        """
        type_names = unique_sequence(
            self.field_type_name(x, parents) for x in choice.types
        )

        result = ", ".join(type_names)
        if len(type_names) > 1:
            result = f"Union[{result}]"

        if choice.is_tokens:
            result = f"List[{result}]"

        return f"Type[{result}]"

    def field_type_name(self, attr_type: AttrType, parents: List[str]) -> str:
        """Return the type hint name of a single attribute type.

        Forward references are qualified with the outer class names and
        quoted; circular references are quoted.
        """
        name = (
            self.class_name(attr_type.alias)
            if attr_type.alias
            else self.type_name(attr_type)
        )

        if attr_type.forward and attr_type.circular:
            outer_str = ".".join(map(self.class_name, parents))
            name = f'"{outer_str}"'
        elif attr_type.forward:
            outer_str = ".".join(map(self.class_name, parents))
            name = f'"{outer_str}.{name}"'
        elif attr_type.circular:
            name = f'"{name}"'

        return name

    def constant_value(self, attr: Attr) -> str:
        """Return the attr default value or type as constant value.

        Native defaults are rendered as properly escaped python string
        literals, so embedded quotes or backslashes keep the output valid.
        """
        attr_type = attr.types[0]
        if attr_type.native:
            return self._quote_string(str(attr.default))

        if attr_type.alias:
            return self.class_name(attr_type.alias)

        return self.type_name(attr_type)

    @classmethod
    def prepare_default_value(cls, value: Any) -> Any:
        """Return the python source representation of a default value.

        - str: a double-quoted python string literal (properly escaped).
        - float: inf, -inf and nan become ``float('...')`` calls since their
          ``repr`` is not valid python source; finite floats are returned
          untouched.
        - Decimal: its ``repr``.
        - QName: a ``QName("...")`` constructor call.
        - anything else is returned untouched.
        """
        if isinstance(value, str):
            return cls._quote_string(value)
        if isinstance(value, float):
            return f"float('{value}')" if not math.isfinite(value) else value
        if isinstance(value, Decimal):
            return repr(value)
        if isinstance(value, QName):
            return f"QName({cls._quote_string(str(value.text))})"
        return value

    @staticmethod
    def _quote_string(value: str) -> str:
        """Return ``value`` as a double-quoted python string literal.

        JSON string escapes (``\\"``, ``\\\\``, ``\\n``, ``\\uXXXX``...) are
        all valid python escapes. XML attribute quoting is not suitable here:
        entities such as ``&amp;`` or ``&#10;`` would end up verbatim inside
        the generated value.
        """
        return json.dumps(value, ensure_ascii=False)

    @classmethod
    def type_is_included(cls, output: str, type_name: str) -> bool:
        """Heuristically check if ``type_name`` is used as a type in ``output``."""
        return (
            f": {type_name}" in output
            or f"[{type_name}" in output
            or f", {type_name}" in output
            or f"= {type_name}" in output
        )

    @classmethod
    def default_imports(cls, output: str) -> str:
        """Generate the default imports for the given package output.

        Detection is a plain text search over the rendered output.
        """
        result = []

        dataclasses = []
        if "@dataclass" in output:
            dataclasses.append("dataclass")
        if "field(" in output:
            dataclasses.append("field")

        if dataclasses:
            result.append(f"from dataclasses import {', '.join(dataclasses)}")

        if cls.type_is_included(output, "Decimal"):
            result.append("from decimal import Decimal")

        if "(Enum)" in output:
            result.append("from enum import Enum")

        typing_patterns = {
            "Dict": [": Dict"],
            # List may also be nested inside another generic, e.g. the
            # Type[List[...]] hints produced for token choices.
            "List": [": List[", "[List["],
            "Optional": ["Optional["],
            "Type": ["Type["],
            "Union": ["Union["],
        }

        types = [
            name
            for name, patterns in typing_patterns.items()
            if any(pattern in output for pattern in patterns)
        ]
        if types:
            result.append(f"from typing import {', '.join(types)}")

        if cls.type_is_included(output, "QName"):
            result.append("from xml.etree.ElementTree import QName")

        return "\n".join(result)

    @classmethod
    def from_config(cls, config: GeneratorConfig) -> "Filters":
        """Build a Filters instance from the generator aliases and conventions."""

        def index_aliases(aliases: List[GeneratorAlias]) -> Dict:
            return {alias.source: alias.target for alias in aliases}

        return cls(
            class_aliases=index_aliases(config.aliases.class_name),
            field_aliases=index_aliases(config.aliases.field_name),
            package_aliases=index_aliases(config.aliases.package_name),
            module_aliases=index_aliases(config.aliases.module_name),
            class_case=config.conventions.class_name.case.func,
            field_case=config.conventions.field_name.case.func,
            package_case=config.conventions.package_name.case.func,
            module_case=config.conventions.module_name.case.func,
            class_safe_prefix=config.conventions.class_name.safe_prefix,
            field_safe_prefix=config.conventions.field_name.safe_prefix,
            package_safe_prefix=config.conventions.package_name.safe_prefix,
            module_safe_prefix=config.conventions.module_name.safe_prefix,
        )