diff --git a/lib/rest.py b/lib/rest.py index 0a75600..ab9a8f9 100644 --- a/lib/rest.py +++ b/lib/rest.py @@ -22,7 +22,7 @@ from sqlalchemy.orm.exc import NoResultFound from pyramid.httpexceptions import HTTPNotFound, HTTPBadRequest, HTTPServerError from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker, Session, scoped_session +from sqlalchemy.orm import sessionmaker, Session, scoped_session, class_mapper from pyramid.config import Configurator from pyramid.request import Request @@ -595,7 +595,7 @@ def description(self, request): """ return self.model.description(self.dictionary) - def m_to_n_handling(self, key, value): + def m_to_n_handling(self, key, value, session): """ Small helper method. It checks if the passed key is defined as an m_to_n relation ship to other tables. If it is: The method gets the corresponding objects from database by the passed (comma seperated) id's and return @@ -605,6 +605,8 @@ def m_to_n_handling(self, key, value): :type key: str :param value: The effective value which is intended to be set to the column :type value: + :param session: The session instance which should be used by this method + :type session: Session :return: The found results of corresponding datasets or False if the checked column was not a m_to_n one :rtype: bool or list """ @@ -612,12 +614,12 @@ def m_to_n_handling(self, key, value): # found the column which is m:n if column.get('is_m_to_n') and key == column.get('column_name'): value_list = str(value).split(',') - bound_model = getattr(self.model, key).argument + bound_model = class_mapper(self.model).get_property(key).argument pk_name = bound_model.description(self.dictionary).get('pk_name') pk_column = getattr(bound_model, pk_name) relation_list = [] for identifier in value_list: - result = self.session.query(bound_model).filter(pk_column == identifier).one() + result = session.query(bound_model).filter(pk_column == identifier).one() relation_list.append(result) return relation_list return False @@ -645,7 +647,7 @@ def create(self, request): m_to_n = self.m_to_n_handling(key, value) if m_to_n: value = m_to_n - setattr(new_record, key, value) + setattr(new_record, key, value, session) session.add(new_record) session.flush() request.response.status_int = 201 @@ -683,7 +685,7 @@ def update(self, request): for key, value in data.iteritems(): if key == 'geom': value = WKBSpatialElement(buffer(wkt.loads(value).wkb), srid=2056) - m_to_n = self.m_to_n_handling(key, value) + m_to_n = self.m_to_n_handling(key, value, session) if m_to_n: value = m_to_n setattr(element, key, value)