Commit b3a610cb authored by Micaël Bergeron's avatar Micaël Bergeron

load the manifest schema into the zuora schema

parent 00a25029
from .schema import *
...@@ -43,6 +43,10 @@ Column = namedtuple('Column', [ ...@@ -43,6 +43,10 @@ Column = namedtuple('Column', [
class Schema: class Schema:
def mapping_key_name(column: Column):
return "{}_{}_mapping_key".format(column.table_name,
column.column_name)
def table_key(column: Column): def table_key(column: Column):
return column.table_name return column.table_name
...@@ -92,6 +96,7 @@ def db_schema(db_conn, schema_name) -> Schema: ...@@ -92,6 +96,7 @@ def db_schema(db_conn, schema_name) -> Schema:
SELECT table_schema, table_name, column_name, udt_name::regtype as data_type, is_nullable = 'YES', NULL as is_mapping_key SELECT table_schema, table_name, column_name, udt_name::regtype as data_type, is_nullable = 'YES', NULL as is_mapping_key
FROM information_schema.columns FROM information_schema.columns
WHERE table_schema = %s WHERE table_schema = %s
AND column_name != '__row_id'
ORDER BY ordinal_position; ORDER BY ordinal_position;
""", (schema_name,)) """, (schema_name,))
...@@ -135,7 +140,10 @@ def schema_apply(db_conn, target_schema: Schema): ...@@ -135,7 +140,10 @@ def schema_apply(db_conn, target_schema: Schema):
for name, col in target_schema.columns.items(): for name, col in target_schema.columns.items():
results.call(schema_apply_column, results.call(schema_apply_column,
schema_cursor, schema, target_schema, col) schema_cursor,
schema,
target_schema,
col)
results.raise_aggregate() results.raise_aggregate()
...@@ -157,6 +165,7 @@ def schema_apply_column(db_cursor, ...@@ -157,6 +165,7 @@ def schema_apply_column(db_cursor,
:target_schema: Target schema (to apply) :target_schema: Target schema (to apply)
:column: the column to apply :column: the column to apply
""" """
diff = schema.column_diff(column) diff = schema.column_diff(column)
identifier = ( identifier = (
psycopg2.sql.Identifier(column.table_schema), psycopg2.sql.Identifier(column.table_schema),
...@@ -188,12 +197,10 @@ def schema_apply_column(db_cursor, ...@@ -188,12 +197,10 @@ def schema_apply_column(db_cursor,
db_cursor.execute(sql) db_cursor.execute(sql)
if column.is_mapping_key: if column.is_mapping_key:
constraint = "{table}_{column}_mapping_key".format(table=column.table_name,
column=column.column_name)
stmt = "ALTER TABLE {}.{} ADD CONSTRAINT {} UNIQUE ({})" stmt = "ALTER TABLE {}.{} ADD CONSTRAINT {} UNIQUE ({})"
sql = psycopg2.sql.SQL(stmt).format( sql = psycopg2.sql.SQL(stmt).format(
*identifier, *identifier,
psycopg2.sql.Identifier(constraint), psycopg2.sql.Identifier(Schema.mapping_key_name(column)),
psycopg2.sql.Identifier(column.column_name), psycopg2.sql.Identifier(column.column_name),
) )
db_cursor.execute(sql) db_cursor.execute(sql)
......
import yaml
import pathlib
import re
from typing import Generator
from functools import partial
from elt.schema import Schema, Column
def tables(schema) -> Generator[dict, None, None]:
col_in_table = lambda table, col: col.table_name == table
for table in schema.tables:
in_table = partial(col_in_table, table)
table_columns = list(filter(in_table, schema.columns.values()))
column_defs = {
col.column_name: col.data_type \
for col in table_columns
}
mapping_keys = {
Schema.mapping_key_name(col): col.column_name \
for col in table_columns \
if col.is_mapping_key
}
yield {table: {
**column_defs,
**mapping_keys,
}}
def dumps(schema: Schema) -> str:
schema_def = dict()
for table_def in tables(schema):
schema_def.update(table_def)
return yaml.dump(schema_def, default_flow_style=False)
def dump(writer, schema: Schema):
writer.write(dumps(schema))
def load(schema_name: str, reader) -> Schema:
return loads(schema_name, reader.read())
def loads(schema_name: str, yaml_str: str) -> Schema:
raw = yaml.load(yaml_str)
columns = []
for table, table_data in raw.items():
for column, data_type in table_data.items():
if column.endswith("_mapping_key"):
continue
# HACK: we should reformat this manifest file
mapping_key = "{}_{}_mapping_key".format(table, column)
is_mapping_key = mapping_key in table_data
column = Column(table_schema=schema_name,
table_name=table,
column_name=column,
data_type=data_type,
is_nullable=not is_mapping_key,
is_mapping_key=is_mapping_key)
columns.append(column)
return Schema(schema_name, columns)
import elt.schema.serializer as serializer
from functools import partial
from itertools import chain
from elt.schema import Schema, DBType, Column
def sample_schema(table_names=()):
table_schema = 'pytest'
# curry the Column object
def column(table_name, column_name, data_type, *,
is_nullable=True,
is_mapping_key=False):
return Column(table_schema=table_schema,
table_name=table_name,
column_name=column_name,
data_type=data_type.value,
is_nullable=is_nullable,
is_mapping_key=is_mapping_key)
def table(table_name):
table_column = partial(column, table_name)
return [
table_column("id", DBType.Integer, is_mapping_key=True),
table_column("string", DBType.String),
table_column("long", DBType.Long, is_nullable=False),
table_column("bool", DBType.Boolean),
table_column("date", DBType.Date),
table_column("ao_strings", DBType.ArrayOfString),
table_column("json", DBType.JSON),
table_column("ao_long", DBType.ArrayOfLong),
]
return Schema(table_schema, chain(*(table(name) for name in table_names)))
def test_dumps():
schema = sample_schema(table_names=('entity01', 'entity02'))
yaml = serializer.dumps(schema)
assert(yaml)
def test_loads():
yaml_schema = """
entity01:
id: integer
long: bigint
text: text
entity01_long_mapping_key: long
"""
schema = serializer.loads('yaml', yaml_schema)
assert(len(schema.columns.values()) == 3)
def test_idempotent():
schema = sample_schema(table_names=('entity01', 'entity02'))
schema2 = serializer.loads('pytest', serializer.dumps(schema))
assert(len(schema.tables) == len(schema2.tables))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment