Source code for xsdata.generators
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Iterator, List, Optional, Tuple
from jinja2 import Environment, FileSystemLoader, Template
from xsdata.formats.dataclass.utils import replace_words
from xsdata.models.codegen import Attr, Class, Package
from xsdata.models.elements import Schema
from xsdata.models.enums import XSDType
from xsdata.resolver import DependenciesResolver
from xsdata.utils import text
class AbstractGenerator(ABC):
    """
    Base class for template driven code generators.

    Concrete subclasses must point ``templates_dir`` at the directory
    holding their ``*.jinja2`` templates and implement :meth:`render`.
    """

    # Directory the jinja2 file system loader reads templates from;
    # subclasses are required to override it.
    templates_dir: Optional[Path] = None

    def __init__(self):
        """
        Build the jinja2 environment and the dependencies resolver.

        :raises TypeError: if the subclass did not set ``templates_dir``
        """
        if self.templates_dir is None:
            raise TypeError("Missing renderer templates directory")

        loader = FileSystemLoader(str(self.templates_dir))
        self.env = Environment(loader=loader)
        self.resolver = DependenciesResolver()

    def template(self, name: str) -> Template:
        """
        Load a template from the environment by its base name.

        :param name: template file name without the ``.jinja2`` suffix
        """
        return self.env.get_template(f"{name}.jinja2")

    @abstractmethod
    def render(
        self, schema: Schema, classes: List[Class], package: str
    ) -> Iterator[Tuple[Path, str]]:
        """
        Render the given classes; to be implemented by subclasses.

        Returns an iterator of ``(Path, str)`` tuples, presumably the
        target file path and its rendered source -- confirm against the
        concrete generators.
        """
class PythonAbstractGenerator(AbstractGenerator, ABC):
    """
    Shared helpers for generators that emit Python code.

    Every helper is a classmethod that normalizes, in place, the names,
    types and default values of the codegen models to Python conventions.
    """

    @classmethod
    def process_class(
        cls, obj: Class, parents: Optional[List[str]] = None
    ) -> Class:
        """
        Normalize all class instance fields, extends, name and the inner
        classes recursively.

        :param obj: class model to normalize in place
        :param parents: normalized names of the enclosing classes,
            outermost first; omitted for top level classes
        :return: the same, now normalized, instance
        """
        parents = parents or []
        obj.name = cls.class_name(obj.name)

        for extension in obj.extensions:
            extension.name = cls.type_name(extension.name)

        # Inner classes and attributes see this class as their innermost
        # parent, which is what forward references are built from.
        curr_parents = parents + [obj.name]
        for inner in obj.inner:
            cls.process_class(inner, curr_parents)

        if obj.is_enumeration:
            for attr in obj.attrs:
                cls.process_enumeration(attr)
        else:
            for attr in obj.attrs:
                cls.process_attribute(attr, curr_parents)

        return obj

    @classmethod
    def process_attribute(cls, attr: Attr, parents: List[str]) -> None:
        """
        Normalize the name, type, local name and default of an attribute.

        The type is resolved before the default on purpose:
        :meth:`attribute_type` inspects the raw default value and
        :meth:`attribute_default` inspects the normalized type.
        """
        attr.name = cls.attribute_name(attr.name)
        attr.type = cls.attribute_type(attr, parents)
        attr.local_name = text.split(attr.local_name)[1]
        attr.default = cls.attribute_default(attr)

    @classmethod
    def process_enumeration(cls, attr: Attr, *args) -> None:
        """
        Normalize the name and default value of an enumeration member.

        Unlike :meth:`process_attribute` the member type and local name
        are left untouched; any extra positional arguments are ignored.
        """
        attr.name = cls.enumeration_name(attr.name)
        attr.default = cls.attribute_default(attr)

    @classmethod
    def process_import(cls, package: Package) -> Package:
        """Normalize import package properties (name and optional alias)."""
        package.name = cls.class_name(package.name)
        if package.alias:
            package.alias = cls.class_name(package.alias)
        return package

    @classmethod
    def class_name(cls, name: str) -> str:
        """Convert class names to pascal case."""
        return text.pascal_case(name)

    @classmethod
    def type_name(cls, name: str) -> str:
        """Convert xsd types to python or apply class name conventions after
        stripping any reference prefix."""
        return XSDType.get_local(name) or cls.class_name(text.split(name)[1])

    @classmethod
    def attribute_name(cls, name: str) -> str:
        """
        Strip reference prefix and turn to snake case.

        The local name is first looked up, lower-cased, in
        ``replace_words``; when an entry exists its replacement is used
        instead (the original note describes this as adding ``_value`` to
        Python reserved words).
        """
        local_name = text.split(name)[1]
        safe_name = replace_words.get(local_name.lower(), local_name)
        return text.snake_case(safe_name)

    @classmethod
    def enumeration_name(cls, name: str) -> str:
        """
        Build an enumeration member name.

        Applies the same rules as :meth:`attribute_name` and upper-cases
        the result.
        """
        return cls.attribute_name(name).upper()

    @classmethod
    def attribute_type(cls, attr: Attr, parents: List[str]) -> str:
        """
        Normalize attribute type.

        Steps:
            * If type alias is present use class name normalization
            * Otherwise use the type name normalization
            * Prepend outer class names and quote result for forward
              references
            * Otherwise wrap multiple distinct types with Union
            * Wrap the result with List if the field accepts a list of
              values
            * Wrap the result with Optional if the field default value is
              None

        :param attr: attribute whose ``types`` are rendered
        :param parents: names of the enclosing classes, outermost first
        """
        # Keep first-seen order while dropping duplicates produced by
        # different source names normalizing to the same python name.
        type_names: List[str] = []
        for name in attr.types:
            if name in attr.type_aliases:
                type_name = cls.class_name(attr.type_aliases[name])
            else:
                type_name = cls.type_name(name)

            if type_name not in type_names:
                type_names.append(type_name)

        result = ", ".join(type_names)
        if attr.forward_ref:
            # NOTE(review): a forward reference with several distinct types
            # is quoted as a single string and never wrapped in Union --
            # confirm that combination cannot occur.
            outer_str = ".".join(parents)
            result = f'"{outer_str}.{result}"'
        elif len(type_names) > 1:
            result = f"Union[{result}]"

        if attr.is_list:
            result = f"List[{result}]"
        elif attr.default is None:
            result = f"Optional[{result}]"

        return result

    @classmethod
    def attribute_default(cls, attr: Attr) -> Optional[Any]:
        """
        Normalize default value/factory by the attribute type.

        * List attributes always get the string ``"list"``.
        * Non string defaults (``None`` included) are returned untouched.
        * String defaults are converted when ``attr.type`` is exactly
          ``bool``, ``int`` or ``float``; anything else is wrapped in
          double quotes.

        :raises ValueError: if an ``int``/``float`` default cannot be
            parsed
        """
        if attr.is_list:
            return "list"

        default = attr.default
        if not isinstance(default, str):
            return default

        if attr.type == "bool":
            # The XSD boolean lexical space is {true, false, 1, 0}; "1"
            # must therefore evaluate to True as well. Surrounding
            # whitespace is ignored like int() and float() below do.
            return default.strip() in ("true", "1")
        if attr.type == "int":
            return int(default)
        if attr.type == "float":
            return float(default)

        # NOTE(review): the value is quoted verbatim; defaults containing
        # double quotes or backslashes are not escaped here.
        return f'"{default}"'