Commit e7822fb7 authored by Micaël Bergeron

Make the serializers reusable

This way a single serializer instance can be used to load multiple files into the same schema.
parent 085d9ebc
......@@ -5,12 +5,12 @@ from enum import Enum
from elt.cli import ActionEnum, OptionEnum, parser_logging
from elt.utils import setup_logging
from elt.schema.serializers.singer import load
from elt.schema.serializers.meltano import dump
from elt.schema.serializers.singer import SingerSerializer
from elt.schema.serializers.meltano import MeltanoSerializer
def action_convert(args):
    """Convert a Singer schema read from stdin into a Meltano manifest on stdout.

    `args` is the parsed command-line namespace; it is not used here.
    """
    # The schema is collected under the name "singer", as the previous
    # module-level `load("singer", ...)` call did.
    schema = SingerSerializer("singer").loads(sys.stdin.read()).schema
    sys.stdout.write(MeltanoSerializer(schema).dumps())
......
......@@ -3,7 +3,7 @@ import psycopg2
import psycopg2.sql
import psycopg2.extras
from typing import Sequence, Callable, Set
from typing import Sequence, Callable, Set, Union
from enum import Enum
from collections import OrderedDict, namedtuple
from elt.error import ExceptionAggregator, SchemaError, InapplicableChangeError
......@@ -44,6 +44,8 @@ Column = namedtuple('Column', [
class Schema:
Basis = Union[str, 'Schema']
@staticmethod
def mapping_key_name(column: Column):
return "{}_{}_mapping_key".format(column.table_name,
......@@ -57,6 +59,13 @@ class Schema:
def column_key(column: Column):
return (column.table_name, column.column_name)
@classmethod
def extend(cls, schema_or_name: Basis):
if isinstance(schema_or_name, Schema):
return schema_or_name
return Schema(schema_or_name)
def __init__(self, name, columns: Sequence[Column] = [],
primary_key_name='__row_id'):
self.name = name
......
import abc
from typing.io import TextIO
from typing import Union
from elt.schema import Schema
class Serializer:
    """Base class for converting a Schema to and from a text format.

    A serializer wraps exactly one Schema, obtained through
    ``Schema.extend`` from either a schema name or an existing Schema.
    ``loads``/``load`` are expected to return the serializer itself, so
    several documents can be loaded into the same schema by chaining calls.

    Subclasses implement ``loads`` and/or ``dumps``; ``load`` and ``dump``
    are the stream-based wrappers around them.
    """

    def __init__(self, schema_or_name: 'Schema.Basis'):
        self._schema = Schema.extend(schema_or_name)

    @property
    def schema(self):
        """The Schema this serializer reads into and writes from."""
        return self._schema

    def loads(self, raw: str) -> 'Serializer':
        """Parse `raw` into the wrapped schema. Must be overridden."""
        raise NotImplementedError

    def load(self, reader: TextIO) -> 'Serializer':
        """Read all of `reader` and delegate to ``loads``."""
        return self.loads(reader.read())

    def dumps(self) -> str:
        """Render the wrapped schema as text. Must be overridden."""
        raise NotImplementedError

    def dump(self, writer: TextIO):
        """Write the output of ``dumps`` to `writer`."""
        writer.write(self.dumps())
......@@ -2,6 +2,7 @@ import logging
from xml.etree import ElementTree
from elt.schema import Schema, Column, DBType
from .base import Serializer
data_type_map = {
......@@ -12,40 +13,40 @@ data_type_map = {
}
class KettleSerializer(Serializer):
    """Load column definitions from a Kettle transformation XML document.

    Only the step whose ``type`` is ``SalesforceInput`` is read: its
    ``module`` element gives the table name and every ``fields/field``
    child becomes one column of the wrapped schema.
    """

    def loads(self, raw: str) -> Serializer:
        """Parse `raw` and add its fields to the wrapped schema.

        Returns the serializer itself so that further documents can be
        loaded into the same schema by chaining calls.
        """
        tree = ElementTree.fromstring(raw)

        sfdc_input_step = tree.find("step[type='SalesforceInput']")
        table_name = sfdc_input_step.find("module").text

        for field in sfdc_input_step.iterfind("fields/field"):
            self.schema.add_column(self.field_column(table_name, field))

        return self

    def field_column(self, table_name, element):
        """Build the Column described by one ``field`` element."""
        # A field flagged as an id lookup is the mapping key and is the
        # only kind of column that is made non-nullable.
        is_mapping_key = element.find("idlookup").text == "Y"

        return Column(table_schema=self.schema.name,
                      table_name=table_name,
                      column_name=element.find("field").text,
                      data_type=self.field_data_type(element).value,
                      is_nullable=not is_mapping_key,
                      is_mapping_key=is_mapping_key)

    def field_data_type(self, element):
        """Map the field's ``type``/``format`` elements to a DBType member."""
        raw_type = element.find("type").text
        raw_format = element.find("format").text

        dt_type = data_type_map[raw_type]

        # Only a Date with the plain day format stays a Date; any other
        # format is treated as a Timestamp.
        if dt_type == DBType.Date and raw_format != "yyyy-MM-dd":
            return DBType.Timestamp

        return dt_type
......@@ -5,64 +5,61 @@ import re
from typing import Generator
from functools import partial
from elt.schema import Schema, Column
from .base import Serializer
class MeltanoSerializer(Serializer):
    """Read and write a schema as a Meltano YAML manifest.

    The manifest maps each table name to a mapping of
    ``column_name: data_type`` entries, plus one ``*_mapping_key`` entry
    (named by ``Schema.mapping_key_name``) per mapping-key column whose
    value is that column's name.
    """

    @staticmethod
    def tables(schema) -> Generator[dict, None, None]:
        """Yield one single-entry ``{table: definition}`` dict per table."""
        for table in schema.tables:
            table_columns = [
                col for col in schema.columns.values()
                if col.table_name == table
            ]

            column_defs = {
                col.column_name: col.data_type
                for col in table_columns
            }

            mapping_keys = {
                Schema.mapping_key_name(col): col.column_name
                for col in table_columns
                if col.is_mapping_key
            }

            yield {table: {
                **column_defs,
                **mapping_keys,
            }}

    def dumps(self) -> str:
        """Render the wrapped schema as block-style YAML."""
        schema_def = {}
        for table_def in self.tables(self.schema):
            schema_def.update(table_def)

        return yaml.dump(schema_def, default_flow_style=False)

    def loads(self, yaml_str: str) -> Serializer:
        """Add every column described by `yaml_str` to the wrapped schema.

        Returns the serializer itself so that further manifests can be
        loaded into the same schema by chaining calls.
        """
        # safe_load instead of yaml.load: the manifest only holds plain
        # mappings and strings, and yaml.load without an explicit Loader
        # can construct arbitrary objects (newer PyYAML rejects the call).
        # An empty document parses to None, hence the fallback.
        raw = yaml.safe_load(yaml_str) or {}

        for table, table_data in raw.items():
            for column_name, data_type in table_data.items():
                # Mapping-key entries annotate another column; they are
                # not columns themselves.
                if column_name.endswith("_mapping_key"):
                    continue

                # HACK: we should reformat this manifest file
                mapping_key = "{}_{}_mapping_key".format(table, column_name)
                is_mapping_key = mapping_key in table_data

                self.schema.add_column(
                    Column(table_schema=self.schema.name,
                           table_name=table,
                           column_name=column_name,
                           data_type=data_type,
                           is_nullable=not is_mapping_key,
                           is_mapping_key=is_mapping_key)
                )

        return self
import elt.schema.serializers.kettle as kettle
from elt.schema.serializers.kettle import KettleSerializer
def test_config():
    """Two Kettle configs loaded through one serializer share one schema."""
    schema = KettleSerializer("sfdc") \
        .loads(SAMPLE_KETTLE_CONFIG) \
        .loads(SAMPLE_KETTLE_CONFIG_EXTRA) \
        .schema

    # The field that only exists in the second document must have been
    # added to the schema built from the first one.
    column_names = {col.column_name for col in schema.columns.values()}
    assert "extra_field" in column_names
......@@ -95,3 +98,26 @@ SAMPLE_KETTLE_CONFIG = """<?xml version="1.0" encoding="UTF-8"?>
</step>
</transformation>
"""
# Second Kettle transformation loaded by test_config after
# SAMPLE_KETTLE_CONFIG: its SalesforceInput step targets the `User` module
# and declares a single non-lookup String field named `extra_field`.
SAMPLE_KETTLE_CONFIG_EXTRA = """<?xml version="1.0" encoding="UTF-8"?>
<transformation>
<step>
<name>Insert / Update</name>
<type>InsertUpdate</type>
</step>
<step>
<name>Salesforce Input</name>
<type>SalesforceInput</type>
<module>User</module>
<fields>
<field>
<name>Extra field</name>
<field>extra_field</field>
<idlookup>N</idlookup>
<type>String</type>
<format />
</field>
</fields>
</step>
</transformation>
"""
import elt.schema.serializers.meltano as serializer
from functools import partial
from itertools import chain
from elt.schema import Schema, DBType, Column
from elt.schema.serializers.meltano import MeltanoSerializer
def sample_schema(table_names=()):
......@@ -37,7 +36,7 @@ def sample_schema(table_names=()):
def test_dumps():
    """Dumping a populated sample schema yields non-empty YAML text."""
    schema = sample_schema(table_names=('entity01', 'entity02'))
    dumped = MeltanoSerializer(schema).dumps()

    assert dumped
......@@ -48,13 +47,14 @@ entity01:
long: bigint
text: text
entity01_long_mapping_key: long
"""
schema = serializer.loads('yaml', yaml_schema)
""
schema = MeltanoSerializer('yaml').loads(yaml_schema).schema
assert(len(schema.columns.values()) == 3)
def test_idempotent():
    """A schema dumped and loaded back keeps the same number of tables."""
    schema = sample_schema(table_names=('entity01', 'entity02'))

    dumped = MeltanoSerializer(schema).dumps()
    # Load into a fresh schema of the same name so the two can be compared.
    schema2 = MeltanoSerializer(schema.name).loads(dumped).schema

    assert len(schema.tables) == len(schema2.tables)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment