Source code for xsdata.transformer
from dataclasses import dataclass
from dataclasses import field
from pathlib import Path
from typing import List
from typing import Optional
from typing import Union
from xsdata.analyzer import ClassAnalyzer
from xsdata.builder import ClassBuilder
from xsdata.logger import logger
from xsdata.models.codegen import Class
from xsdata.models.elements import Import
from xsdata.models.elements import Include
from xsdata.models.elements import Override
from xsdata.models.elements import Redefine
from xsdata.models.elements import Schema
from xsdata.parser import SchemaParser
from xsdata.writer import writer
[docs]@dataclass
class SchemaTransformer:
print: bool
output: str
processed: List[Path] = field(init=False, default_factory=list)
[docs] def process(self, schema_path: Path, package: str):
classes = self.process_schema(schema_path, package)
classes = self.analyze_classes(classes)
class_num, inner_num = self.count_classes(classes)
if not class_num:
return logger.warning("Analyzer returned zero classes!")
logger.info("Analyzer: %d main and %d inner classes", class_num, inner_num)
writer.designate(classes, self.output)
if self.print:
writer.print(classes, self.output)
else:
writer.write(classes, self.output)
[docs] def process_schema(
self, schema_path: Path, package: str, target_namespace: Optional[str] = None,
) -> List[Class]:
"""Recursively parse the given schema and all it's included schemas and
generate a list of classes."""
classes = []
if schema_path not in self.processed:
self.processed.append(schema_path)
logger.info("Parsing schema...")
schema = self.parse_schema(schema_path, target_namespace)
target_namespace = schema.target_namespace
for sub in schema.included():
included_classes = self.process_included(sub, package, target_namespace)
classes.extend(included_classes)
classes.extend(self.generate_classes(schema, package))
else:
logger.debug("Already processed skipping: %s", schema_path.name)
return classes
[docs] def process_included(
self,
included: Union[Import, Include, Redefine, Override],
package: str,
target_namespace: Optional[str],
) -> List[Class]:
"""Prepare the given included schema location and send it for
processing."""
classes = []
if not included.location:
logger.warning(
"%s: %s unresolved schema location..",
included.class_name,
included.schema_location,
)
elif included.location in self.processed:
logger.debug(
"%s: %s already included skipping..",
included.class_name,
included.schema_location,
)
else:
package = self.adjust_package(package, included.schema_location)
classes = self.process_schema(included.location, package, target_namespace)
return classes
[docs] def generate_classes(self, schema: Schema, package: str):
"""Convert the given schema tree to codegen classes and use the writer
factory to either generate or print the result code."""
logger.info("Compiling schema...")
classes = ClassBuilder(schema=schema, package=package).build()
class_num, inner_num = self.count_classes(classes)
if class_num > 0:
logger.info("Builder: %d main and %d inner classes", class_num, inner_num)
return classes
[docs] @staticmethod
def parse_schema(schema_path: Path, target_namespace: Optional[str]) -> Schema:
"""
Parse the given schema path and return the schema tree object.
Optionally add the target namespace if the schema is included
and is missing a target namespace.
"""
parser = SchemaParser(target_namespace=target_namespace)
return parser.from_xsd_path(schema_path)
[docs] @staticmethod
def analyze_classes(classes: List[Class]):
"""Analyzer the given class list and simplify attributes and
extensions."""
analyzer = ClassAnalyzer()
return analyzer.process(classes)
[docs] @staticmethod
def adjust_package(package: str, location: Optional[str]) -> str:
"""
Adjust if possible the package name relatively to the schema location
to make sense.
eg. foo.bar, ../common/schema.xsd -> foo.common
"""
if location and not location.startswith("http"):
pp = package.split(".")
for part in Path(location).parent.parts:
if part == "..":
pp.pop()
else:
pp.append(part)
if pp:
return ".".join(pp)
return package
[docs] def count_classes(self, classes: List[Class]):
"""Return a tuple of counters for the main and inner classes."""
main = len(classes)
inner = 0
for cls in classes:
inner += sum(self.count_classes(cls.inner))
return main, inner