Commit c203be80 authored by Micaël Bergeron's avatar Micaël Bergeron

at last sqlalchemy plays nice with psycopg2

parent 26e1834b
import psycopg2
import os
import contextlib
import logging
from psycopg2.extras import LoggingConnection
from sqlalchemy import create_engine, MetaData
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.ext.declarative import declarative_base
......@@ -22,7 +25,9 @@ def engine_uri(**db_config):
SystemModel = declarative_base(metadata=MetaData(schema='meltano'))
Model = declarative_base()
Session = sessionmaker()
session_factory = sessionmaker()
Session = scoped_session(session_factory)
class DB:
......@@ -33,7 +38,7 @@ class DB:
'password': os.getenv('PG_PASSWORD'),
'database': os.getenv('PG_DATABASE'),
}
connection_class = psycopg2.extensions.connection
connection_class = LoggingConnection
_connection = None
_engine = None
......@@ -46,11 +51,20 @@ class DB:
@classmethod
def connect(self):
"""
Non thread-safe singleton database connection.
"""
if self._connection is not None:
return self._connection
return psycopg2.connect(**self.db_config,
conn = psycopg2.connect(**self.db_config,
connection_factory=self.connection_class)
conn.initialize(logging.getLogger(__name__))
return conn
@classmethod
def engine(self):
return self._engine
@classmethod
def session(self):
......@@ -72,8 +86,9 @@ class DB:
@contextlib.contextmanager
def db_open():
"""Provide a raw connection in a transaction"""
"""Provide a raw connection in a transaction."""
connection = DB.connect()
try:
yield connection
connection.commit()
......@@ -92,5 +107,3 @@ def session_open():
except:
session.rollback()
raise
finally:
session.close()
......@@ -4,12 +4,10 @@ import psycopg2
import psycopg2.sql as sql
import logging
from elt.db import DB
class NoCommitConnection(psycopg2.extensions.connection):
def commit(self):
print("db.commit() bypass for pytest")
from elt.db import DB, Session
from sqlalchemy import MetaData
logging.basicConfig(level=logging.INFO)
@pytest.fixture(scope='session')
def db_setup(request):
......@@ -20,15 +18,35 @@ def db_setup(request):
'user': os.getenv("PG_USERNAME"),
'password': os.getenv("PG_PASSWORD"),
}
DB.set_connection_class(NoCommitConnection)
DB.setup(**args)
truncate_tables(DB.engine(), schema='meltano')
@pytest.fixture()
@pytest.fixture(scope='function')
def db(request, db_setup):
connection = DB.connect()
def teardown():
connection.rollback()
truncate_tables(DB.engine(), schema='meltano')
request.addfinalizer(teardown)
return connection
@pytest.fixture(scope='function')
def session(request, db):
"""Creates a new database session for a test."""
return Session()
def truncate_tables(engine, schema):
# delete all table data (but keep tables)
# we do cleanup before test 'cause if previous test errored,
# DB can contain dust
con = engine.connect()
trans = con.begin()
con.execute("SET session_replication_role TO 'replica';")
meta = MetaData(bind=engine, reflect=True, schema=schema)
for table in meta.sorted_tables:
con.execute(table.delete())
con.execute("SET session_replication_role TO 'origin';")
trans.commit()
import elt.error
import logging
def do_raise(error_type, *args):
......@@ -14,7 +15,7 @@ def test_aggregate_default():
try:
aggregator.raise_aggregate()
except Exception as e:
print(str(e))
logging.info(str(e))
def test_aggregate_custom():
......@@ -30,7 +31,7 @@ def test_aggregate_custom():
try:
aggregator.raise_aggregate()
except CustomError as custom:
print(str(custom))
logging.info(str(custom))
except elt.error.Error as err:
raise "Catched by the elt.error.Error clause."
except Exception as e:
......
......@@ -10,26 +10,31 @@ def sample_job(payload={}):
payload=payload)
def test_save(db):
def test_save(session):
job = sample_job()
session.add(job)
session.commit()
assert(job.id > 0)
def test_save_raw(db):
job = sample_job()
Job.save(job)
assert(job)
assert(job.id > 0)
def test_load(db):
with DB.session() as session:
[session.add(sample_job({'key': i})) for i in range(0, 10)]
import pdb; pdb.set_trace()
session.flush()
def test_load(session):
[session.add(sample_job({'key': i})) for i in range(0, 10)]
session.commit()
import pdb; pdb.set_trace()
with DB.session() as s:
jobs = s.query(Job).filter_by(elt_uri='elt://bizops/sample-elt')
jobs = session.query(Job).filter_by(elt_uri='elt://bizops/sample-elt')
assert(len(jobs.all()) == 10)
def test_transit(db):
def test_transit(session):
j = sample_job()
transition = j.transit(State.RUNNING)
......@@ -40,4 +45,4 @@ def test_transit(db):
assert(transition == (State.RUNNING, State.SUCCESS))
j.ended_at = datetime.utcnow()
Job.save(j)
session.add(j)
from elt.db import DB
def test_connect():
db_conn = DB.connect()
assert(db_conn)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment