"""Source code for xsdata.formats.dataclass.generator (dataclass code generator)."""

from collections import defaultdict
from pathlib import Path
from typing import DefaultDict, Dict, Iterator, List, Tuple

from jinja2 import Environment, FileSystemLoader, Template

from xsdata.formats.dataclass.filters import filters
from xsdata.generators import PythonAbstractGenerator
from xsdata.models.codegen import Class, Package
from xsdata.models.elements import Schema
from xsdata.resolver import DependenciesResolver
from xsdata.utils.text import snake_case

class DataclassGenerator(PythonAbstractGenerator):
    """Render resolved codegen classes into python source modules.

    Rendering is driven by the ``*.jinja2`` templates located in the
    ``templates`` directory next to this file, with the dataclass
    ``filters`` registered on the jinja2 environment. Ordering of classes
    and imports is delegated to a :class:`DependenciesResolver`.
    """

    def __init__(self):
        templates_dir = Path(__file__).parent.joinpath("templates")
        self.env = Environment(loader=FileSystemLoader(str(templates_dir)))
        self.env.filters.update(filters)
        self.resolver = DependenciesResolver()

    def template(self, name: str) -> Template:
        """Load the ``<name>.jinja2`` template from the templates directory."""
        return self.env.get_template(f"{name}.jinja2")

    def render_module(
        self, output: str, imports: Dict[str, List[Package]]
    ) -> str:
        """Render the ``module`` template.

        :param output: the already rendered classes of the module
        :param imports: packages to import, grouped by source package
        :return: the rendered module source code
        """
        return self.template("module").render(output=output, imports=imports)

    def render_class(self, obj: Class) -> str:
        """Render a single class, using the ``enum`` template for
        enumerations and the ``class`` template for everything else."""
        name = "enum" if obj.is_enumeration else "class"
        return self.template(name).render(obj=obj)

    def render(
        self, schema: Schema, classes: List[Class], package: str
    ) -> Iterator[Tuple[Path, str]]:
        """Given a schema, a list of classes and a target package return to
        the writer factory the target file path and the rendered code.

        The target directory ``<cwd>/<package parts>`` is created if it does
        not exist yet; the file itself is not written here.

        :param schema: the schema the classes were generated from
        :param classes: the classes to render
        :param package: dotted target package, e.g. ``"foo.bar"``
        :return: an iterator yielding a single ``(file_path, code)`` pair
        """
        module, package_arr, package = self._resolve_package(schema, package)
        target = Path.cwd().joinpath(*package_arr)
        file_path = target.joinpath(f"{module}.py")

        self.resolver.process(classes=classes, schema=schema, package=package)
        target.mkdir(parents=True, exist_ok=True)

        imports = self.prepare_imports()
        output = self.render_classes()
        yield file_path, self.render_module(imports=imports, output=output)

    def print(
        self, schema: Schema, classes: List[Class], package: str
    ) -> Iterator[Tuple[str, Class]]:
        """Resolve the given classes and yield ``(module_path, class)``
        pairs sorted by class name, without rendering or creating any
        directory.

        :param schema: the schema the classes were generated from
        :param classes: the classes to resolve
        :param package: dotted target package, e.g. ``"foo.bar"``
        """
        _, _, package = self._resolve_package(schema, package)
        self.resolver.process(classes=classes, schema=schema, package=package)
        # NOTE(review): the sort key was truncated in this copy of the file;
        # restored as the class name — confirm Class exposes ``name``.
        for obj in sorted(self.prepare_classes(), key=lambda x: x.name):
            yield package, obj

    def render_classes(self) -> str:
        """Get a list of sorted classes from the imports resolver, apply the
        python code conventions and return the rendered output."""
        output = "\n".join(map(self.render_class, self.prepare_classes())).strip()
        # Leading blank lines separate the classes from the import section.
        return f"\n\n{output}\n"

    def prepare_classes(self) -> Iterator[Class]:
        """Yield the resolver's sorted classes after ``process_class``.

        ``process_class`` is not defined in this class; presumably it is
        provided by ``PythonAbstractGenerator`` — confirm there.
        """
        for obj in self.resolver.sorted_classes():
            yield self.process_class(obj)

    def prepare_imports(self) -> Dict[str, List[Package]]:
        """Get a list of sorted packages from the imports resolver apply the
        python code conventions, group them by the source package and
        return them."""
        imports: DefaultDict[str, List[Package]] = defaultdict(list)
        for obj in self.resolver.sorted_imports():
            imports[obj.source].append(self.process_import(obj))
        return imports

    @staticmethod
    def _resolve_package(
        schema: Schema, package: str
    ) -> Tuple[str, List[str], str]:
        """Return the snake-cased module name, the snake-cased package parts
        and the fully qualified dotted module path for ``schema``."""
        module = snake_case(schema.module)
        package_arr = list(map(snake_case, package.split(".")))
        return module, package_arr, ".".join([*package_arr, module])