Skip to content

Commit

Permalink
Auto gen basic schema: str, int, float. #2
Browse files Browse the repository at this point in the history
  • Loading branch information
fernandojunior committed Mar 28, 2016
1 parent dab7347 commit 5338836
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 43 deletions.
16 changes: 4 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,12 @@ A Python object relational mapper for SQLite.
from orm import Model

class Post(Model):
def __init__(self, text, id=None):
self.id = id
self.text = text

@classmethod
def schema(cls):
return '''
drop table if exists post;
text = str

def __init__(self, text):
self.text = text

create table post (
id integer primary key autoincrement,
text text not null
);
'''
```

* Import `Database` to create a data access object (DAO).
Expand Down
46 changes: 28 additions & 18 deletions orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,16 @@ def cut_attrs(obj, keys):
return dict(i for i in obj.__dict__.items() if i[0] not in keys)


def render_schema(model):
schema = 'create table {table} (id integer primary key autoincrement, {columns});' # noqa
datatypes = {str: 'text', int: 'integer', float: 'real'}
iscol = lambda key, value: key[0] is not '_' and value in datatypes.keys()
colrender = lambda key, value: '%s %s' % (key, datatypes[value])
cols = [colrender(*i) for i in model.__dict__.items() if iscol(*i)]
values = {'table': model.__name__, 'columns': ', '.join(cols)}
return schema.format(**values)


class Database(object):

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -50,53 +60,58 @@ class Manager(object):
def __init__(self, db, model):
self.db = db
self.model = model
self.table_name = model.__name__.lower()
self.tablename = model.__name__
if not self._hastable():
self.db.executescript(self.model.schema())
self.db.executescript(render_schema(self.model))

def all(self):
cursor = self.db.execute('select * from %s' % self.table_name)
return (self.model(**row) for row in cursor.fetchall())
cursor = self.db.execute('select * from %s' % self.tablename)
return (self.create(**row) for row in cursor.fetchall())

def create(self, **kwargs):
obj = object.__new__(self.model)
obj.__dict__ = kwargs
return obj

def delete(self, obj):
sql = 'DELETE from %s WHERE id = ?'
self.db.execute(sql % self.table_name, obj.id)
self.db.execute(sql % self.tablename, obj.id)

def get(self, id):
sql = 'select * from %s where id = ?' % self.table_name
sql = 'select * from %s where id = ?' % self.tablename
cursor = self.db.execute(sql, id)
row = cursor.fetchone()
if not row:
msg = 'Object%s with id does not exist: %s' % (self.model, id)
raise ValueError(msg)
return self.model(**row)
return self.create(**row)

def has(self, id):
sql = 'select id from %s where id = ?' % self.table_name
sql = 'select id from %s where id = ?' % self.tablename
cursor = self.db.execute(sql, id)
return True if cursor.fetchall() else False

def save(self, obj):
if obj.id and self.has(obj.id):
if hasattr(obj, 'id') and self.has(obj.id):
msg = 'Object%s id already registred: %s' % (self.model, obj.id)
raise ValueError(msg)
copy_ = cut_attrs(obj, 'id')
keys = '(%s)' % ', '.join(copy_.keys()) # (key1, ...)
refs = '(%s)' % ', '.join('?' for i in range(len(copy_))) # (?, ...)
sql = 'insert into %s %s values %s' % (self.table_name, keys, refs)
sql = 'insert into %s %s values %s' % (self.tablename, keys, refs)
cursor = self.db.execute(sql, *copy_.values())
obj.id = cursor.lastrowid
return obj

def update(self, obj):
copy_ = cut_attrs(obj, 'id')
keys = '= ?, '.join(copy_.keys()) + '= ?' # key1 = ?, ...
sql = 'UPDATE %s SET %s WHERE id = ?' % (self.table_name, keys)
sql = 'UPDATE %s SET %s WHERE id = ?' % (self.tablename, keys)
self.db.execute(sql, *(copy_.values() + [obj.id]))

def _hastable(self):
sql = 'select name len FROM sqlite_master where type = ? AND name = ?'
cursor = self.db.execute(sql, 'table', self.table_name)
cursor = self.db.execute(sql, 'table', self.tablename)
return True if cursor.fetchall() else False


Expand All @@ -122,9 +137,4 @@ def __repr__(self):

@classmethod
def manager(cls, db=None):
db = db if db else cls.db
return Manager(cls.db, cls)

@classmethod
def schema(cls):
raise NotImplementedError
return Manager(db if db else cls.db, cls)
21 changes: 8 additions & 13 deletions tests.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,23 @@
import os
from random import random
from orm import Database

db = Database('db.sqlite.test')


class Post(db.Model):

def __init__(self, text, id=None):
self.id = id
self.text = text
random = float
text = str

@classmethod
def schema(cls):
return '''
drop table if exists post;
create table post (
id integer primary key autoincrement,
text text not null
);
'''
def __init__(self, text):
self.text = text
self.random = random()

try:
post = Post('Hello World').save()
assert(post.id == 1)
assert(isinstance(post.random, float))
post.text = 'Hello Mundo'
post.update()
db.commit()
Expand All @@ -31,7 +26,7 @@ def schema(cls):
db.commit()
objects = Post.manager()
objects.save(Post('Hello World'))
assert(objects.get(2).public == {'text': 'Hello World', 'id': 2})
assert(objects.get(2).public.keys() == ['text', 'random', 'id'])
db.close()
assert(list(objects.all()) == [])
finally:
Expand Down

0 comments on commit 5338836

Please sign in to comment.