Commit 31d61758 authored by Nick Zaccardi's avatar Nick Zaccardi
Browse files

Improve Crud Setup

parents 51214dfe d63507d5
Loading
Loading
Loading
Loading
Loading
+17 −8
Original line number Diff line number Diff line
import logging
import falcon
import ujson
from marshmallow.schema import MarshalResult

import falcon_helpers.sqla.db as db

log = logging.getLogger(__name__)


class MarshmallowMiddleware:

    def _default_load(self, data, req, resource, params):
        schema = resource.schema()
        return schema.load(data, session=resource.session)

    def process_resource(self, req, resp, resource, params):
        # Try to get an instance from the `get_object` method on the resource so we can populate
        # already existing instances
        if hasattr(resource, 'get_object'):
            instance = resource.get_object(req=req, **params)
        else:
            instance = None

        return schema.load(data, session=db.session, instance=instance)

    def process_resource(self, req, resp, resource, params):
        should_parse = (
            # Veriy that it is safe to parse this resource
            req.method in ('POST', 'PUT'),
@@ -29,11 +41,8 @@ class MarshmallowMiddleware:
            req.context['_marshalled'] = False
            return

        stream = req.context['marshalled_stream'] = req.bounded_stream.read()
        data = ujson.loads(stream)

        if req.method == 'PUT':
            data['id'] = params['obj_id']
        req.context['marshalled_stream'] = req.stream.read()
        data = req._media = ujson.loads(req.context['marshalled_stream'])

        loaded = (self._default_load(data, req, resource, params)
                  if not hasattr(resource, 'schema_loader')
+28 −6
Original line number Diff line number Diff line
import logging
import ujson
import sqlalchemy as sa
import falcon
@@ -5,6 +6,9 @@ import falcon
from ..utils import flatten


log = logging.getLogger(__name__)


class ListBase:
    """A base class for returning list of objects.

@@ -184,10 +188,22 @@ class CrudBase:
    def get_object(self, req, **kwargs):
        try:
            obj_id = kwargs[self.default_param_name]
            return self.session.query(self.db_cls).get(obj_id)
        except KeyError:
            log.error(
                f'The resource {self.__class__.__name__} route is not using the correct parameter '
                f'name for the object identifier. Expecting `{self.default_param_name}` but it was '
                f'not in the matched route parameters. Add a `default_param_name` to the resource '
                f'which matches the route variable. Found these items: {",".join(kwargs.keys())}.'
            )
            raise falcon.HTTPInternalServerError("Misconfigured route")

        try:
            return self.session.query(self.db_cls).get(obj_id)
        except sa.exc.DataError as e:
            self.session.rollback()
            log.warning(f'Bad primary key given to  {self.__class__.__name__}')
            return None

    def on_get(self, req, resp, **kwargs):
        result = self.get_object(req, **kwargs)

@@ -198,7 +214,6 @@ class CrudBase:
        resp.body = schema.dump(result)
        resp.status = falcon.HTTP_200


    def on_put(self, req, resp, **kwargs):
        self.session.add(req.context['dto'].data)
        self.session.flush()
@@ -206,14 +221,21 @@ class CrudBase:
        resp.status = falcon.HTTP_200
        resp.body = self.schema().dump(req.context['dto'].data)


    def on_post(self, req, resp, **kwargs):
        self.session.add(req.context['dto'].data)

        try:
            self.session.flush()
        except sa.exc.IntegrityError as e:
            self.session.rollback()
            resp.status = falcon.HTTP_409
            resp.media = {
                'errors': ['An object with that identifier already exists.']
            }
            return

        resp.status = falcon.HTTP_201
        resp.body = self.schema().dump(req.context['dto'].data)

        resp.media = self.schema().dump(req.context['dto'].data).data

    def on_delete(self, req, resp, **kwargs):
        try:
+0 −0

Empty file added.

+189 −0
Original line number Diff line number Diff line
import falcon.testing
import pytest
import sqlalchemy as sa
import marshmallow as mm
import marshmallow_sqlalchemy as mms

import falcon_helpers.sqla.orm as orm
import falcon_helpers.sqla.db as db
from falcon_helpers.middlewares.marshmallow import MarshmallowMiddleware


@pytest.fixture()
def client():
    api = falcon.API(middleware=[
        MarshmallowMiddleware()
    ])

    return falcon.testing.TestClient(api)


class ObjEntity(orm.BaseColumns, orm.ModelBase):
    __tablename__ = 'obj_entity'

    name = sa.Column(sa.String)


class Obj(mms.ModelSchema):
    name = mm.fields.String()

    class Meta:
        model = ObjEntity


class WithoutSchemaResc(falcon.testing.SimpleTestResource):
    pass


class WithSchemaResc(falcon.testing.SimpleTestResource):
    schema = Obj


class WithCustomLoader(WithSchemaResc):
    def schema_loader(self, data, req, resource, params):
        result = self.schema().load(data, session=db.session)
        result.data.name = 'other'
        return result


def test_get_is_not_marshalled(client):
    resource = falcon.testing.SimpleTestResource()

    client.app.add_route('/test', resource)
    resp = client.simulate_get('/test')

    assert resp.status_code == 200
    assert not resource.captured_req.context['_marshalled']


def test_post_is_not_marshalled_without_body(client):
    resource = falcon.testing.SimpleTestResource()

    client.app.add_route('/test', resource)
    resp = client.simulate_post('/test')

    assert resp.status_code == 200
    assert not resource.captured_req.context['_marshalled']


def test_post_is_not_marshalled_with_json_but_has_schema(client):
    resource = WithoutSchemaResc()

    client.app.add_route('/test', resource)
    resp = client.simulate_post(
        '/test',
        json={}
    )

    assert resp.status_code == 200
    assert not resource.captured_req.context['_marshalled']


def test_turning_off_auto_marshalling(client):
    resource = WithSchemaResc()
    resource.auto_marshall = False

    client.app.add_route('/test', resource)
    resp = client.simulate_post(
        '/test',
        json={}
    )

    assert resp.status_code == 200
    assert not resource.captured_req.context['_marshalled']


def test_turning_verify_content_type_is_json(client):
    resource = WithSchemaResc()

    client.app.add_route('/test', resource)
    resp = client.simulate_post(
        '/test',
        body='["looks like json"]'
    )

    assert resp.status_code == 200
    assert not resource.captured_req.context['_marshalled']


def test_turning_verify_content_length(client):
    resource = WithSchemaResc()

    client.app.add_route('/test', resource)
    resp = client.simulate_post(
        '/test',
        headers={'Content-Type': 'application/json'},
        body=''
    )

    assert resp.status_code == 200
    assert not resource.captured_req.context['_marshalled']


def test_the_happy_path(client):
    resource = WithSchemaResc()
    client.app.add_route('/test', resource)
    resp = client.simulate_post(
        '/test',
        json={'name': 'john'}
    )

    assert resp.status_code == 200
    assert resource.captured_req.context['_marshalled']


def test_keeps_the_media_and_populates_the_raw_stream(client):
    resource = WithSchemaResc()
    client.app.add_route('/test', resource)

    resp = client.simulate_post(
        '/test',
        json={'name': 'john'}
    )

    assert resp.status_code == 200
    assert resource.captured_req.context['_marshalled']
    assert resource.captured_req.media == {'name': 'john'}
    assert resource.captured_req.context['marshalled_stream'] == b'{"name":"john"}'


def test_support_default_loader(client):
    resource = WithSchemaResc()
    client.app.add_route('/test', resource)

    resp = client.simulate_post(
        '/test',
        json={'name': 'john'}
    )

    assert resp.status_code == 200
    assert resource.captured_req.context['_marshalled']
    assert resource.captured_req.context['dto'].data.name == 'john'


def test_support_custom_loader(client):
    resource = WithCustomLoader()
    client.app.add_route('/test', resource)

    resp = client.simulate_post(
        '/test',
        json={'name': 'john'}
    )

    assert resp.status_code == 200
    assert resource.captured_req.context['_marshalled']
    assert resource.captured_req.context['dto'].data.name == 'other'


def test_errors_during_loading(client):
    resource = WithSchemaResc()
    client.app.add_route('/test', resource)

    resp = client.simulate_post(
        '/test',
        json={'name': 1}
    )

    assert resp.status_code == 400
    assert resp.json == {'errors': {'name': ['Not a valid string.']}}
    assert resource.captured_req is None
+21 −12
Original line number Diff line number Diff line
@@ -27,6 +27,7 @@ class ModelTest(Base, BaseColumns, BaseFunctions, Testable):
    __tablename__ = 'mtest'

    name = sa.Column(sa.Unicode, nullable=False)
    uni = sa.Column(sa.Unicode, unique=True)
    other = sa.orm.relationship("ModelOther")


@@ -55,6 +56,7 @@ class ListSub(ListBase):
        if kwargs['objid'] == 'zero':
            return []


@pytest.fixture
def app():
    Base.metadata.drop_all()
@@ -84,26 +86,36 @@ class TestCrudBase:
        resp = client.simulate_get('/crud/1')
        assert resp.status_code == 404


    def test_crud_base_get_500_with_misconfigured_route(self, client):
        resp = client.simulate_get('/bad')
        assert resp.status_code == 500


    def test_crud_base_get_200_with_object(self, client):
        m1 = ModelTest.testing_create()
        session.add(m1)
        session.commit()
        resp = client.simulate_get(f'/crud/{m1.id}')

        assert resp.status_code == 200
        assert resp.json == {
            'id': m1.id,
            'name': m1.name,
            'uni': m1.uni,
            'created_ts': m1.created_ts.replace(tzinfo=tz.utc).isoformat(),
            'updated_ts': m1.updated_ts.replace(tzinfo=tz.utc).isoformat(),
        }

    def test_crud_base_get_404_with_bad_primary_key(self, client):
        assert client.simulate_get(f'/crud/abs').status_code == 404

    def test_crud_base_post_duplicate_object(self, client):
        ModelTest.testing_create(uni='test')
        resp = client.simulate_post(
            f'/crud/new',
            json={
                'uni': 'test',
                'name': 'thing'
            })

        assert resp.status_code == 409

    def test_crud_base_get_404_with_wrong_id(self, client):
        m1 = ModelTest.testing_create()
@@ -113,7 +125,6 @@ class TestCrudBase:

        assert resp.status_code == 404


    def test_crud_base_post(self, client):
        resp = client.simulate_post(
            f'/crud/new',
@@ -125,7 +136,6 @@ class TestCrudBase:
        assert session.query(ModelTest).get(resp.json['id']).name == 'thing'
        assert resp.json['name'] == 'thing'


    def test_crud_base_delete(self, client):
        m1 = ModelTest.testing_create()
        session.add(m1)
@@ -134,7 +144,7 @@ class TestCrudBase:
        resp = client.simulate_delete(f'/crud/{m1.id}')

        assert resp.status_code == 204
        assert session.query(ModelTest).get(m1.id) == None
        assert session.query(ModelTest).get(m1.id) is None

    def test_crud_base_delete_with_relationship(self, client):
        m1 = ModelTest.testing_create()
@@ -152,7 +162,9 @@ class TestCrudBase:
        assert session.query(ModelOther).get(mo1.id) == mo1

        assert 'errors' in resp.json
        assert resp.json['errors'] == ['Unable to delete because the object is connected to other objects']
        assert resp.json['errors'] == [
            'Unable to delete because the object is connected to other objects'
        ]


class TestListBase:
@@ -164,7 +176,6 @@ class TestListBase:
        assert result.right.value == 'name'
        assert result.operator.__name__ == 'contains_op'


    def test_default_filter_for_column(self):
        lb = ListSub()
        lb.column_filters = {
@@ -218,15 +229,13 @@ class TestListBase:
        assert resp.json[0]['id'] == m1.id
        assert resp.json[0]['name'] == m1.name


    def test_listbase_get_sends_404_for_subobj_with_none_respose(self, client):
        resp = client.simulate_get(f'/list/missing/other')
        assert resp.status_code == 404
        assert 'error' in resp.json


    def test_listbase_get_sends_200_for_subobj_with_empty_respose(self, client):
        m1 = ModelTest.testing_create()
        ModelTest.testing_create()

        resp = client.simulate_get(f'/list/zero/other')
        assert resp.status_code == 200