Commit 26e1834b authored by Micaël Bergeron's avatar Micaël Bergeron

rework job to use sqlalchemy

parent c8375b0b
import psycopg2
import os
import contextlib
from sqlalchemy import create_engine, MetaData
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
db_config_keys = [
"host",
......@@ -10,6 +16,15 @@ db_config_keys = [
]
def engine_uri(**db_config):
return "postgresql://{user}:{password}@{host}:{port}/{database}".format(**db_config)
SystemModel = declarative_base(metadata=MetaData(schema='meltano'))
Model = declarative_base()
Session = sessionmaker()
class DB:
db_config = {
'host': os.getenv('PG_ADDRESS', 'localhost'),
......@@ -20,12 +35,14 @@ class DB:
}
connection_class = psycopg2.extensions.connection
_connection = None
_engine = None
@classmethod
def setup(self, open_persistent=False, **kwargs):
self.db_config.update({k: kwargs[k] for k in db_config_keys if k in kwargs})
# self._engine = create_engine(self.engine_uri())
self._connection = self.connect()
self._engine = create_engine(engine_uri(**self.db_config), creator=self.connect)
Session.configure(bind=self._engine)
@classmethod
def connect(self):
......@@ -36,12 +53,12 @@ class DB:
connection_factory=self.connection_class)
@classmethod
def open(self):
return db_open()
def session(self):
return session_open()
@classmethod
def engine_uri():
return "postgresql://{username}:{password}@{host}:{port}/{database}".format(**self.db_config)
def open(self):
return db_open()
@classmethod
def set_connection_class(self, cls):
......@@ -49,16 +66,31 @@ class DB:
@classmethod
def close(self):
if self.connection is not None:
self.connection.close()
class db_open:
def __enter__(self):
self.connection = DB.connect()
return self.connection
def __exit__(self, ex_type, ex_value, traceback):
if ex_value is None:
self.connection.commit()
else:
self.connection.rollback()
if self.engine is not None:
self._engine.dispose()
@contextlib.contextmanager
def db_open():
"""Provide a raw connection in a transaction"""
connection = DB.connect()
try:
yield connection
connection.commit()
except:
connection.rollback()
raise
@contextlib.contextmanager
def session_open():
"""Provide a transactional scope around a series of operations."""
session = Session()
try:
yield session
session.commit()
except:
session.rollback()
raise
finally:
session.close()
import psycopg2
import json
import sqlalchemy.types as types
from psycopg2.sql import Identifier, SQL, Placeholder
from enum import Enum
from elt.db import DB
from elt.schema import Schema, Column, DBType
from elt.error import Error
from functools import partial
from elt.db import DB, SystemModel
from elt.schema import Schema, Column as SchemaColumn, DBType
from elt.error import Error
from sqlalchemy import Column
PG_SCHEMA = 'meltano'
PG_TABLE = 'job_runs'
PG_TABLE = 'job'
PRIMARY_KEY = 'id'
......@@ -26,24 +28,6 @@ class ImpossibleTransitionError(Error):
"""
def describe_schema() -> Schema:
def job_column(name, data_type, is_nullable=False):
return Column(table_name=PG_TABLE,
table_schema=PG_SCHEMA,
column_name=name,
data_type=data_type.value,
is_nullable=is_nullable,
is_mapping_key=False)
return Schema(PG_SCHEMA, [
job_column('elt_uri', DBType.String),
job_column('state', DBType.String),
job_column('started_at', DBType.Timestamp, is_nullable=True),
job_column('ended_at', DBType.Timestamp, is_nullable=True),
job_column('payload', DBType.JSON, is_nullable=True),
], primary_key_name='id')
class State(Enum):
SUCCESS = (2, ())
FAIL = (3, ())
......@@ -57,53 +41,15 @@ class State(Enum):
return self.name
class Job:
"""
Represents a Job at a certain state (State).
"""
schema_name = PG_SCHEMA
table_name = PG_TABLE
@classmethod
def identifier(self):
return map(Identifier, (self.schema_name, self.table_name))
def __init__(self, elt_uri,
id=None,
state=State.IDLE,
started_at=None,
ended_at=None,
payload={}):
self.id = id
self.elt_uri = elt_uri
self._state = state
self._started_at = started_at
self._ended_at = ended_at
self.payload = payload
@property
def state(self):
return self._state
@property
def started_at(self):
return self._started_at
@started_at.setter
def started_at(self, value):
if self.state != State.RUNNING:
raise InconsistentStateError(self.state)
self._started_at = value
@property
def ended_at(self):
return self._ended_at
@ended_at.setter
def ended_at(self, value):
if self.state not in (State.SUCCESS, State.FAIL):
raise InconsistentStateError(self.state)
self._ended_at = value
class Job(SystemModel):
__tablename__ = 'job'
id = Column(types.Integer, primary_key=True)
elt_uri = Column(types.String)
state = Column(types.Enum(State))
started_at = Column(types.DateTime)
ended_at = Column(types.DateTime)
payload = Column(types.JSON)
def transit(self, state: State) -> (State, State):
transition = (self.state, state)
......@@ -111,52 +57,30 @@ class Job:
if state.name not in self.state.transitions():
raise ImpossibleTransitionError(transition)
self._state = state
self.state = state
return transition
def __dict__(self):
return {
'state': str(self.state),
'elt_uri': str(self.elt_uri),
'started_at': self.started_at,
'ended_at': self.ended_at,
'payload': json.dumps(self.payload),
}
@classmethod
def save(self, job):
job_serial = job.__dict__()
columns, values = (job_serial.keys(), job_serial.values())
insert = SQL(("INSERT INTO {}.{} ({}) "
"VALUES ({}) "))
with DB.open() as db, db.cursor() as cursor:
cursor.execute(
insert.format(
*self.identifier(),
SQL(",").join(map(Identifier, columns)),
SQL(",").join(Placeholder() * len(values)),
),
list(values),
)
return job
@classmethod
def for_elt(self, elt_uri, limit=100):
fetch = SQL(("SELECT elt_uri, state, started_at, ended_at, payload FROM {}.{} "
"WHERE elt_uri = %s "
"ORDER BY started_at DESC "
"LIMIT %s ")).format(*self.identifier())
def as_job(row):
return Job(row[0],
state=State[row[1]],
started_at=row[2],
ended_at=row[3],
payload=row[4])
with DB.open() as db, db.cursor() as cursor:
cursor.execute(fetch, (elt_uri, limit))
return list(map(as_job, cursor.fetchall()))
def __repr__(self):
return "<Job(id='%s', elt_uri='%s', state='%s')>" % (
self.id, self.elt_uri, self.state)
def describe_schema() -> Schema:
def job_column(name, data_type, is_nullable=False):
return SchemaColumn(table_name=PG_TABLE,
table_schema=PG_SCHEMA,
column_name=name,
data_type=data_type.value,
is_nullable=is_nullable,
is_mapping_key=False)
return Schema(PG_SCHEMA, [
job_column('elt_uri', DBType.String),
job_column('state', DBType.String),
job_column('started_at', DBType.Timestamp, is_nullable=True),
job_column('ended_at', DBType.Timestamp, is_nullable=True),
job_column('payload', DBType.JSON, is_nullable=True),
], primary_key_name='id')
def save(job):
with DB.session() as session:
session.add(job)
......@@ -49,11 +49,12 @@ class Schema:
def column_key(column: Column):
return (column.table_name, column.column_name)
def __init__(self, name, columns: Sequence[Column] = [], primary_key_name='__row_id'):
def __init__(self, name, columns: Sequence[Column] = [],
primary_key_name='__row_id'):
self.name = name
self.tables = set()
self.primary_key_name = primary_key_name
self.columns = OrderedDict()
self.primary_key_name = primary_key_name
for column in columns:
self.tables.add(Schema.table_key(column))
......@@ -80,7 +81,7 @@ class Schema:
return {SchemaDiff.COLUMN_OK}
def db_schema(db_conn, schema_name, primary_key_name='__row_id') -> Schema:
def db_schema(db_conn, schema_name) -> Schema:
"""
:db_conn: psycopg2 db_connection
:schema: database schema
......@@ -95,7 +96,7 @@ def db_schema(db_conn, schema_name, primary_key_name='__row_id') -> Schema:
""", (schema_name,))
columns = map(Column._make, cursor.fetchall())
return Schema(schema_name, columns, primary_key_name=primary_key_name)
return Schema(schema_name, columns)
def ensure_schema_exists(db_conn, schema_name):
......@@ -127,14 +128,14 @@ def schema_apply(db_conn, target_schema: Schema):
"""
ensure_schema_exists(db_conn, target_schema.name)
schema = db_schema(db_conn, target_schema.name,
primary_key_name=target_schema.primary_key_name)
schema = db_schema(db_conn, target_schema.name)
results = ExceptionAggregator(InapplicableChangeError)
schema_cursor = db_conn.cursor()
for name, col in target_schema.columns.items():
results.call(schema_apply_column, schema_cursor, schema, col)
results.call(schema_apply_column,
schema_cursor, schema, target_schema, col)
results.raise_aggregate()
......@@ -142,13 +143,18 @@ def schema_apply(db_conn, target_schema: Schema):
db_conn.commit()
def schema_apply_column(db_cursor, schema: Schema, column: Column) -> Set[SchemaDiff]:
def schema_apply_column(db_cursor,
schema: Schema,
target_schema: Schema,
column: Column) -> Set[SchemaDiff]:
"""
Apply the schema to the current database connection
adapting tables as it goes. Currently only supports
adding new columns.
:cursor: A database connection
:schema: Source schema (database)
:target_schema: Target schema (to apply)
:column: the column to apply
"""
diff = schema.column_diff(column)
......@@ -157,22 +163,19 @@ def schema_apply_column(db_cursor, schema: Schema, column: Column) -> Set[Schema
psycopg2.sql.Identifier(column.table_name),
)
if SchemaDiff.COLUMN_OK in diff:
logging.debug("[{}]: {}".format(column.column_name, diff))
if SchemaDiff.COLUMN_CHANGED in diff:
raise InapplicableChangeError(diff)
if SchemaDiff.TABLE_MISSING in diff:
stmt = "CREATE TABLE {}.{} ({} SERIAL PRIMARY KEY)"
sql = psycopg2.sql.SQL(stmt).format(
*identifier,
psycopg2.sql.Identifier(schema.primary_key_name)
)
logging.debug("Creating table {}.{}".format(*identifier))
sql = psycopg2.sql.SQL(stmt).format(*identifier,
psycopg2.sql.Identifier(target_schema.primary_key_name))
db_cursor.execute(sql)
schema.add_table(column)
if SchemaDiff.COLUMN_OK not in diff:
logging.debug("[{}]: {}".format(column.column_name, diff))
if SchemaDiff.COLUMN_MISSING in diff:
stmt = "ALTER TABLE {}.{} ADD COLUMN {} %s"
if not column.is_nullable:
......
import pytest
from elt.db import DB
from elt.job import Job, State
from datetime import datetime
def sample_job(payload={}):
return Job('elt://bizops/sample-elt',
return Job(elt_uri='elt://bizops/sample-elt',
state=State.IDLE,
payload=payload)
def test_save(db):
assert(Job.save(sample_job()))
job = sample_job()
Job.save(job)
assert(job)
def test_load(db):
for i in range(0, 10):
Job.save(sample_job({'key': i}))
with DB.session() as session:
[session.add(sample_job({'key': i})) for i in range(0, 10)]
import pdb; pdb.set_trace()
session.flush()
jobs = Job.for_elt('elt://bizops/sample-elt')
[print(x.__dict__()) for x in jobs]
assert(len(jobs) == 10)
import pdb; pdb.set_trace()
with DB.session() as s:
jobs = s.query(Job).filter_by(elt_uri='elt://bizops/sample-elt')
assert(len(jobs.all()) == 10)
def test_transit(db):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment