### BITPIM
###
### Copyright (C) 2004 Roger Binns <rogerb@rogerbinns.com>
###
### This program is free software; you can redistribute it and/or modify
### it under the terms of the BitPim license as detailed in the LICENSE file.
###
### $Id: database.py 4370 2007-08-21 21:39:36Z djpham $

"""Interface to the database"""
from __future__ import with_statement
import os
import copy
import time
import sha
import random

import apsw

import common
###
### The first section of this file deals with typical objects used to
### represent data items and various methods for wrapping them.
###



class basedataobject(dict):
    """A base object derived from dict that is used for various
    records.  Existing code can just continue to treat it as a dict.
    New code can treat it as dict, as well as access via attribute
    names (ie object["foo"] or object.foo).  Attribute access to a
    known property always gives a result, which is None if the name
    is not present in the dict; unknown names raise AttributeError.

    As a bonus this class includes checking of attribute names and
    types in non-production runs.  That will help catch typos etc.
    For production runs we may be receiving data that was written out
    by a newer version of BitPim so we don't check or error."""
    # which properties we know about
    _knownproperties=[]
    # which ones we know about that should be a list of dicts
    # (property name -> allowed keys of each member dict; '*' allows any key)
    _knownlistproperties={'serials': ['sourcetype', '*']}
    # which ones we know about that should be a dict
    # (property name -> allowed keys; '*' allows any key)
    _knowndictproperties={}

    if __debug__:
        # in debug code we check key name and value types

        def _check_property(self,name,value=None):
            """Asserts that name is a known property and, when value is not
            None, that value has the type required for that property."""
            assert isinstance(name, (str, unicode)), "keys must be a string type"
            assert name in self._knownproperties or name in self._knownlistproperties or name in self._knowndictproperties, "unknown property named '"+name+"'"
            if value is None: return
            if name in self._knownlistproperties:
                assert isinstance(value, list), "list properties ("+name+") must be given a list as value"
                # each list member must be a dict
                for v in value:
                    self._check_property_dictvalue(name,v)
                return
            if name in self._knowndictproperties:
                assert isinstance(value, dict), "dict properties ("+name+") must be given a dict as value"
                self._check_property_dictvalue(name,value)
                return
            # the value must be a basetype supported by apsw/SQLite
            assert isinstance(value, (str, unicode, buffer, int, long, float)), "only serializable types supported for values"

        def _check_property_dictvalue(self, name, value):
            """Asserts that value (a dict belonging to list or dict property
            name) only uses allowed keys and serializable values."""
            assert isinstance(value, dict), "item(s) in "+name+" (a list) must be dicts"
            assert name in self._knownlistproperties or name in self._knowndictproperties
            if name in self._knownlistproperties:
                allowed=self._knownlistproperties[name]
                for key in value:
                    # %s formatting so a non-string key still gives a readable message
                    assert key in allowed or '*' in allowed, "dict key %s as member of item in list %s is not known" % (key, name)
                    v=value[key]
                    assert isinstance(v, (str, unicode, buffer, int, long, float)), "only serializable types supported for values"
            elif name in self._knowndictproperties:
                allowed=self._knowndictproperties[name]
                for key in value:
                    assert key in allowed or '*' in allowed, "dict key %s as member of dict in item %s is not known" % (key, name)
                    v=value[key]
                    assert isinstance(v, (str, unicode, buffer, int, long, float)), "only serializable types supported for values"


        def update(self, items):
            """Type checked dict.update (only dicts are accepted)."""
            assert isinstance(items, dict), "update only supports dicts" # Feel free to fix this code ...
            for k in items:
                self._check_property(k, items[k])
            super(basedataobject, self).update(items)

        def __getitem__(self, name):
            # check when they are retrieved, not set.  I did try
            # catching the append method, but the layers of nested
            # namespaces got too confused
            self._check_property(name)
            v=super(basedataobject, self).__getitem__(name)
            self._check_property(name, v)
            return v

        def __setitem__(self, name, value):
            self._check_property(name, value)
            super(basedataobject,self).__setitem__(name, value)

        def __setattr__(self, name, value):
            # note that we map setattr to update the dict; __setitem__
            # already does the property check
            self.__setitem__(name, value)

        def __getattr__(self, name):
            # only called when normal attribute lookup fails
            if name not in self._knownproperties and name not in self._knownlistproperties and name not in self._knowndictproperties:
                raise AttributeError(name)
            self._check_property(name)
            if name in self:
                return self[name]
            return None

        def __delattr__(self, name):
            self._check_property(name)
            if name in self:
                del self[name]

    else:
        # non-debug mode - we don't do any attribute name/value type
        # checking as the data may (legitimately) be from a newer
        # version of the program.
        def __setattr__(self, name, value):
            # note that we map setattr to update the dict
            super(basedataobject,self).__setitem__(name, value)

        def __getattr__(self, name):
            # and getattr checks the dict
            if name not in self._knownproperties and name not in self._knownlistproperties and name not in self._knowndictproperties:
                raise AttributeError(name)
            if name in self:
                return self[name]
            return None

        def __delattr__(self, name):
            if name in self:
                del self[name]

    # various methods for manging serials
    def GetBitPimSerial(self):
        """Returns the BitPim serial for this item

        @raises KeyError: if the item has no serial with sourcetype 'bitpim'
        """
        if "serials" not in self:
            raise KeyError("no bitpim serial present")
        for v in self.serials:
            if v["sourcetype"]=="bitpim":
                return v["id"]
        raise KeyError("no bitpim serial present")

    # rng seeded at startup
    _persistrandom=random.Random()
    # running SHA-1 state shared by every instance (created lazily)
    _shathingy=None
    def _getnextrandomid(self, item):
        """Returns random ids used to give unique serial numbers to items

        @param item: any object - its memory location is used to help randomness
        @returns: a 40 character hexdigit string (a SHA-1 hexdigest)
        """
        if basedataobject._shathingy is None:
            basedataobject._shathingy=sha.new()
            basedataobject._shathingy.update(repr(basedataobject._persistrandom.random()))
        # the hash state keeps accumulating, so every call yields a new digest
        basedataobject._shathingy.update(repr(id(self)))
        basedataobject._shathingy.update(repr(basedataobject._persistrandom.random()))
        basedataobject._shathingy.update(repr(id(item)))
        return basedataobject._shathingy.hexdigest()


    def EnsureBitPimSerial(self):
        "Ensures this entry has a serial"
        if self.serials is None:
            self.serials=[]
        for v in self.serials:
            if v["sourcetype"]=="bitpim":
                return
        self.serials.append({'sourcetype': "bitpim", "id": self._getnextrandomid(self.serials)})

class dataobjectfactory:
    "Called by the code to read in objects when it needs a new object container"
    def __init__(self, dataobjectclass=basedataobject):
        self.dataobjectclass=dataobjectclass

    if __debug__:
        def newdataobject(self, values=None):
            """Returns a new data object, populated from values (a dict or None).

            In debug runs the values are applied with update() so that every
            key/value goes through the class's checks.
            """
            v=self.dataobjectclass()
            if values:
                v.update(values)
            return v
    else:
        def newdataobject(self, values=None):
            "Returns a new data object initialised from values (a dict or None)"
            if values is None:
                values={}
            return self.dataobjectclass(values)


def extractbitpimserials(dict):
    """Returns a new dict with keys being the bitpim serial for each
    row. Each item must be derived from basedataobject"""

    res={}

    for record in dict.itervalues():
        res[record.GetBitPimSerial()]=record

    return res

def ensurebitpimserials(dict):
    """Ensures that all records have a BitPim serial.  Each item must
    be derived from basedataobject"""
    for record in dict.itervalues():
        record.EnsureBitPimSerial()

def findentrywithbitpimserial(dict, serial):
    """Returns the entry from dict whose bitpim serial matches serial

    @raises KeyError: if no entry has that serial
    """
    for record in dict.itervalues():
        if record.GetBitPimSerial()==serial:
            return record
    # %s so that non-string serials still produce the intended KeyError
    raise KeyError("no item with serial %s found" % (serial,))

def ensurerecordtype(dict, factory):
    """Replaces (in place) every value of dict that is not a basedataobject
    with factory.newdataobject(value)."""
    for key,record in dict.iteritems():
        if not isinstance(record, basedataobject):
            # rebinding an existing key does not change the dict's size,
            # so this is safe during iteration
            dict[key]=factory.newdataobject(record)



# a factory that uses dicts to allocate new data objects
dictdataobjectfactory=dataobjectfactory(dict)




###
### Actual database interaction is from this point on
###


# Change this to True to see what is going on under the hood. It
# will produce a lot of output!
0237 TRACE=False 0238 0239 def ExclusiveWrapper(method): 0240 """Wrap a method so that it has an exclusive lock on the database 0241 (noone else can read or write) until it has finished""" 0242 0243 # note that the existing threading safety checks in apsw will 0244 # catch any thread abuse issues. 0245 def _transactionwrapper(*args, **kwargs): 0246 # arg[0] should be a Database instance 0247 assert isinstance(args[0], Database) 0248 with args[0]: 0249 return method(*args, **kwargs) 0250 0251 setattr(_transactionwrapper, "__doc__", getattr(method, "__doc__")) 0252 return _transactionwrapper 0253 0254 def sqlquote(s): 0255 "returns an sqlite quoted string (the return value will begin and end with single quotes)" 0256 return "'"+s.replace("'", "''")+"'" 0257 0258 def idquote(s): 0259 """returns an sqlite quoted identifier (eg for when a column name is also an SQL keyword 0260 0261 The value returned is quoted in square brackets""" 0262 return '['+s+']' 0263 0264 class IntegrityCheckFailed(Exception): pass 0265 0266 class Database: 0267 0268 # Make this class a context manager so it can be used with WITH blocks 0269 def __enter__(self): 0270 self.excounter+=1 0271 self.transactionwrite=False 0272 if self.excounter==1: 0273 if TRACE: 0274 print "BEGIN EXCLUSIVE TRANSACTION" 0275 self.cursor.execute("BEGIN EXCLUSIVE TRANSACTION") 0276 self._schemacache={} 0277 return self 0278 0279 def __exit__(self, ex_type, ex_value, tb): 0280 self.excounter-=1 0281 if self.excounter==0: 0282 w=self.transactionwrite 0283 if tb is None: 0284 # no exception, so commit 0285 cmd="COMMIT TRANSACTION" if w else "END TRANSACTION" 0286 else: 0287 # an exception occurred, so rollback 0288 cmd="ROLLBACK TRANSACTION" if w else "END TRANSACTION" 0289 if TRACE: 0290 print cmd 0291 self.cursor.execute(cmd) 0292 0293 def __del__(self): 0294 # connections have to be closed now 0295 self.connection.close(True) 0296 0297 def __init__(self, filename, virtualtables=None): 0298 """ 0299 @param filename: 
database filename 0300 @param virtualtables: a list of dict specifying the virtual tables 0301 Each dict is expected to have the following keys: 0302 'tablename': the name of the virtual table 0303 'modulename': the name of the module that implements this virtual 0304 table 0305 'moduleclass': the ModuleBase subclass that implements this 0306 virtual table 0307 'args': arguments passed to instantiaion of the module class 0308 """ 0309 self.connection=apsw.Connection(filename) 0310 self.cursor=self.connection.cursor() 0311 # first tell sqlite to use the pre 3.4 format. this will allow downgrades 0312 self.cursor.execute("PRAGMA legacy_file_format=1") # nb you don't get an error for unknown pragmas 0313 # we always do an integrity check second 0314 icheck=[] 0315 print "database integrity check" 0316 for row in self.cursor.execute("PRAGMA integrity_check"): 0317 icheck.extend(row) 0318 print "database integrity check complete" 0319 icheck="\n".join(icheck) 0320 if icheck!="ok": 0321 raise IntegrityCheckFailed(icheck) 0322 # exclusive lock counter 0323 self.excounter=0 0324 # this should be set to true by any code that writes - it is 0325 # used by the exclusivewrapper to tell if it should do a 0326 # commit/rollback or just a plain end 0327 self.transactionwrite=False 0328 # a cache of the table schemas 0329 self._schemacache={} 0330 self.sql=self.cursor.execute 0331 self.sqlmany=self.cursor.executemany 0332 if TRACE: 0333 self.cursor.setexectrace(self._sqltrace) 0334 self.cursor.setrowtrace(self._rowtrace) 0335 if virtualtables is not None: 0336 # virtual tables are specified 0337 for vtable in virtualtables: 0338 # register the module 0339 self.connection.createmodule(vtable['modulename'], 0340 vtable['moduleclass'](*vtable['args'])) 0341 if not self.doestableexist(vtable['tablename']): 0342 # and declare the virtual table 0343 self.sql('CREATE VIRTUAL TABLE %s USING %s;'%(idquote(vtable['tablename']), 0344 idquote(vtable['modulename']))) 0345 0346 def 
_sqltrace(self, cmd, bindings): 0347 print "SQL:",cmd 0348 if bindings: 0349 print " bindings:",bindings 0350 return True 0351 0352 def _rowtrace(self, *row): 0353 print "ROW:",row 0354 return row 0355 0356 def sql(self, statement, params=()): 0357 "Executes statement and return a generator of the results" 0358 # this is replaced in init 0359 assert False 0360 0361 def sqlmany(self, statement, params): 0362 "execute statements repeatedly with params" 0363 # this is replaced in init 0364 assert False 0365 0366 def doestableexist(self, tablename): 0367 if tablename in self._schemacache: 0368 return True 0369 return bool(self.sql("select count(*) from sqlite_master where type='table' and name=%s" % (sqlquote(tablename),)).next()[0]) 0370 0371 def getcolumns(self, tablename, onlynames=False): 0372 res=self._schemacache.get(tablename,None) 0373 if res is None: 0374 res=[] 0375 for colnum,name,type, _, default, primarykey in self.sql("pragma table_info("+idquote(tablename)+")"): 0376 if primarykey: 0377 type+=" primary key" 0378 res.append([colnum,name,type]) 0379 self._schemacache[tablename]=res 0380 if onlynames: 0381 return [name for colnum,name,type in res] 0382 return res 0383 0384 @ExclusiveWrapper 0385 def savemajordict(self, tablename, dict, timestamp=None): 0386 """This is the entrypoint for saving a first level dictionary 0387 such as the phonebook or calendar. 0388 0389 @param tablename: name of the table to use 0390 @param dict: The dictionary of record. The key must be the uniqueid for each record. 0391 The @L{extractbitpimserials} function can do the conversion for you for 0392 phonebook and similar formatted records. 0393 @param timestamp: the UTC time in seconds since the epoch. 
This is 0394 """ 0395 0396 if timestamp is None: 0397 timestamp=time.time() 0398 0399 # work on a shallow copy of dict 0400 dict=dict.copy() 0401 0402 # make sure the table exists first 0403 if not self.doestableexist(tablename): 0404 # create table and include meta-fields 0405 self.transactionwrite=True 0406 self.sql("create table %s (__rowid__ integer primary key, __timestamp__, __deleted__ integer, __uid__ varchar)" % (idquote(tablename),)) 0407 0408 # get the latest values for each guid ... 0409 current=self.getmajordictvalues(tablename) 0410 # compare what we have, and update/mark deleted as appropriate ... 0411 deleted=[k for k in current if k not in dict] 0412 new=[k for k in dict if k not in current] 0413 modified=[k for k in dict if k in current] # only potentially modified ... 0414 0415 # deal with modified first 0416 dl=[] 0417 for i,k in enumerate(modified): 0418 if dict[k]==current[k]: 0419 # unmodified! 0420 del dict[k] 0421 dl.append(i) 0422 dl.reverse() 0423 for i in dl: 0424 del modified[i] 0425 0426 # add deleted entries back into dict 0427 for d in deleted: 0428 assert d not in dict 0429 dict[d]=current[d] 0430 dict[d]["__deleted__"]=1 0431 0432 # now we only have new, changed and deleted entries left in dict 0433 0434 # examine the keys in dict 0435 dk=[] 0436 for k in dict.keys(): 0437 # make a copy since we modify values, but it doesn't matter about deleted since we own those 0438 if k not in deleted: 0439 dict[k]=dict[k].copy() 0440 for kk in dict[k]: 0441 if kk not in dk: 0442 dk.append(kk) 0443 # verify that they don't start with __ 0444 assert len([k for k in dk if k.startswith("__") and not k=="__deleted__"])==0 0445 # get database keys 0446 dbkeys=self.getcolumns(tablename, onlynames=True) 0447 # are any missing? 
0448 missing=[k for k in dk if k not in dbkeys] 0449 if len(missing): 0450 creates=[] 0451 # for each missing key, we have to work out if the value 0452 # is a list or dict type (which we indirect to another table) 0453 for m in missing: 0454 islist=None 0455 isdict=None 0456 isnotindirect=None 0457 for r in dict.keys(): 0458 record=dict[r] 0459 v=record.get(m,None) 0460 if v is None: 0461 continue 0462 if isinstance(v, list): 0463 islist=record 0464 elif isinstance(v,type({})): 0465 isdict=record 0466 else: 0467 isnotindirect=record 0468 # in devel code, we check every single value 0469 # in production, we just use the first we find 0470 if not __debug__: 0471 break 0472 if islist is None and isdict is None and isnotindirect is None: 0473 # they have the key but no record has any values, so we ignore it 0474 del dk[dk.index(m)] 0475 continue 0476 # don't do this type abuse at home ... 0477 if int(islist is not None)+int(isdict is not None)+int(isnotindirect is not None)!=int(True): 0478 # can't have it more than one way 0479 raise ValueError("key %s for table %s has values with inconsistent types. eg LIST: %s, DICT: %s, NOTINDIRECT: %s" % (m,tablename,`islist`,`isdict`,`isnotindirect`)) 0480 if islist is not None: 0481 creates.append( (m, "indirectBLOB") ) 0482 continue 0483 if isdict: 0484 creates.append( (m, "indirectdictBLOB")) 0485 continue 0486 if isnotindirect is not None: 0487 creates.append( (m, "valueBLOB") ) 0488 continue 0489 assert False, "You can't possibly get here!" 
0490 if len(creates): 0491 self._altertable(tablename, creates, [], createindex=1) 0492 0493 # write out indirect values 0494 dbtkeys=self.getcolumns(tablename) 0495 # for every indirect, we have to replace the value with a pointer 0496 for _,n,t in dbtkeys: 0497 if t in ("indirectBLOB", "indirectdictBLOB"): 0498 indirects={} 0499 for r in dict.keys(): 0500 record=dict[r] 0501 v=record.get(n,None) 0502 if v is not None: 0503 if not len(v): # set zero length lists/dicts to None 0504 record[n]=None 0505 else: 0506 if t=="indirectdictBLOB": 0507 indirects[r]=[v] # make it a one item dict list 0508 else: 0509 indirects[r]=v 0510 if len(indirects): 0511 self.updateindirecttable(tablename+"__"+n, indirects) 0512 for r in indirects.keys(): 0513 dict[r][n]=indirects[r] 0514 0515 # and now the main table 0516 for k in dict.keys(): 0517 record=dict[k] 0518 record["__uid__"]=k 0519 rk=[x for x,y in record.items() if y is not None] 0520 rk.sort() 0521 cmd=["insert into", idquote(tablename), "( [__timestamp__],"] 0522 cmd.append(",".join([idquote(r) for r in rk])) 0523 cmd.extend([")", "values", "(?,"]) 0524 cmd.append(",".join(["?" for r in rk])) 0525 cmd.append(")") 0526 self.sql(" ".join(cmd), [timestamp]+[record[r] for r in rk]) 0527 self.transactionwrite=True 0528 0529 def updateindirecttable(self, tablename, indirects): 0530 # this is mostly similar to savemajordict, except we only deal 0531 # with lists of dicts, and we find existing records with the 0532 # same value if possible 0533 0534 # does the table even exist? 
0535 if not self.doestableexist(tablename): 0536 # create table and include meta-fields 0537 self.sql("create table %s (__rowid__ integer primary key)" % (idquote(tablename),)) 0538 self.transactionwrite=True 0539 # get the list of keys from indirects 0540 datakeys=[] 0541 for i in indirects.keys(): 0542 assert isinstance(indirects[i], list) 0543 for v in indirects[i]: 0544 assert isinstance(v, dict) 0545 for k in v.keys(): 0546 if k not in datakeys: 0547 assert not k.startswith("__") 0548 datakeys.append(k) 0549 # get the keys from the table 0550 dbkeys=self.getcolumns(tablename, onlynames=True) 0551 # are any missing? 0552 missing=[k for k in datakeys if k not in dbkeys] 0553 if len(missing): 0554 self._altertable(tablename, [(m,"valueBLOB") for m in missing], [], createindex=2) 0555 # for each row we now work out the indirect information 0556 for r in indirects: 0557 res=tablename+"," 0558 for record in indirects[r]: 0559 cmd=["select __rowid__ from", idquote(tablename), "where"] 0560 params=[] 0561 coals=[] 0562 for d in datakeys: 0563 v=record.get(d,None) 0564 if v is None: 0565 coals.append(idquote(d)) 0566 else: 0567 if cmd[-1]!="where": 0568 cmd.append("and") 0569 cmd.extend([idquote(d), "= ?"]) 0570 params.append(v) 0571 assert cmd[-1]!="where" # there must be at least one non-none column! 0572 if len(coals)==1: 0573 cmd.extend(["and",coals[0],"isnull"]) 0574 elif len(coals)>1: 0575 cmd.extend(["and coalesce(",",".join(coals),") isnull"]) 0576 0577 found=None 0578 for found in self.sql(" ".join(cmd), params): 0579 # get matching row 0580 found=found[0] 0581 break 0582 if found is None: 0583 # add it 0584 cmd=["insert into", idquote(tablename), "("] 0585 params=[] 0586 for k in record: 0587 if cmd[-1]!="(": 0588 cmd.append(",") 0589 cmd.append(k) 0590 params.append(record[k]) 0591 cmd.extend([")", "values", "("]) 0592 cmd.append(",".join(["?" 
for p in params])) 0593 cmd.append("); select last_insert_rowid()") 0594 found=self.sql(" ".join(cmd), params).next()[0] 0595 self.transactionwrite=True 0596 res+=`found`+"," 0597 indirects[r]=res 0598 0599 @ExclusiveWrapper 0600 def getmajordictvalues(self, tablename, factory=dictdataobjectfactory, 0601 at_time=None): 0602 0603 if not self.doestableexist(tablename): 0604 return {} 0605 0606 res={} 0607 uids=[u[0] for u in self.sql("select distinct __uid__ from %s" % (idquote(tablename),))] 0608 schema=self.getcolumns(tablename) 0609 for colnum,name,type in schema: 0610 if name=='__deleted__': 0611 deleted=colnum 0612 elif name=='__uid__': 0613 uid=colnum 0614 # get all relevant rows 0615 if isinstance(at_time, (int, float)): 0616 sql_string="select * from %s where __uid__=? and __timestamp__<=%f order by __rowid__ desc limit 1" % (idquote(tablename), float(at_time)) 0617 else: 0618 sql_string="select * from %s where __uid__=? order by __rowid__ desc limit 1" % (idquote(tablename),) 0619 indirects={} 0620 for row in self.sqlmany(sql_string, [(u,) for u in uids]): 0621 if row[deleted]: 0622 continue 0623 record=factory.newdataobject() 0624 for colnum,name,type in schema: 0625 if name.startswith("__") or type not in ("valueBLOB", "indirectBLOB", "indirectdictBLOB") or row[colnum] is None: 0626 continue 0627 if type=="valueBLOB": 0628 record[name]=row[colnum] 0629 continue 0630 assert type=="indirectBLOB" or type=="indirectdictBLOB" 0631 if name not in indirects: 0632 indirects[name]=[] 0633 indirects[name].append( (row[uid], row[colnum], type) ) 0634 res[row[uid]]=record 0635 # now get the indirects 0636 for name,values in indirects.iteritems(): 0637 for uid,v,type in values: 0638 fieldvalue=self._getindirect(v) 0639 if fieldvalue: 0640 if type=="indirectBLOB": 0641 res[uid][name]=fieldvalue 0642 else: 0643 res[uid][name]=fieldvalue[0] 0644 return res 0645 0646 def _getindirect(self, what): 0647 """Gets a list of values (indirect) as described by what 0648 @param 
what: what to get - eg phonebook_serials,1,3,5, 0649 (note there is always a trailing comma) 0650 """ 0651 0652 tablename,rows=what.split(',', 1) 0653 schema=self.getcolumns(tablename) 0654 0655 res=[] 0656 for row in self.sqlmany("select * from %s where __rowid__=?" % (idquote(tablename),), [(int(r),) for r in rows.split(',') if len(r)]): 0657 record={} 0658 for colnum,name,type in schema: 0659 if name.startswith("__") or type not in ("valueBLOB", "indirectBLOB", "indirectdictBLOB") or row[colnum] is None: 0660 continue 0661 if type=="valueBLOB": 0662 record[name]=row[colnum] 0663 continue 0664 assert type=="indirectBLOB" or type=="indirectdictBLOB" 0665 assert False, "indirect in indirect not handled" 0666 assert len(record),"Database._getindirect has zero len record" 0667 res.append(record) 0668 assert len(res), "Database._getindirect has zero len res" 0669 return res 0670 0671 def _altertable(self, tablename, columnstoadd, columnstodel, createindex=0): 0672 """Alters the named table by deleting the specified columns, and 0673 adding the listed columns 0674 0675 @param tablename: name of the table to alter 0676 @param columnstoadd: a list of (name,type) of the columns to add 0677 @param columnstodel: a list name of the columns to delete 0678 @param createindex: what sort of index to create. 
0 means none, 1 means on just __uid__ and 2 is on all data columns 0679 """ 0680 # indexes are automatically dropped when table is dropped so we don't need to 0681 dbtkeys=[x for x in self.getcolumns(tablename) \ 0682 if x[1] not in columnstodel] 0683 # clean out cache entry since we are about to invalidate it 0684 del self._schemacache[tablename] 0685 self.transactionwrite=True 0686 cmd=["create", "temporary", "table", idquote("backup_"+tablename), 0687 "(", 0688 ','.join(['%s %s'%(idquote(n), t) for _,n,t in dbtkeys]), 0689 ")"] 0690 self.sql(" ".join(cmd)) 0691 # copy the values into the temporary table 0692 self.sql("insert into %s select %s from %s" % (idquote("backup_"+tablename), 0693 ','.join([idquote(n) for _,n,_ in dbtkeys]), 0694 idquote(tablename))) 0695 # drop the source table 0696 self.sql("drop table %s" % (idquote(tablename),)) 0697 # recreate the source table with new columns 0698 del cmd[1] # remove temporary 0699 cmd[2]=idquote(tablename) # change tablename 0700 cmd[-2]=','.join(['%s %s'%(idquote(n), t) for _,n,t in dbtkeys]+\ 0701 ['%s %s'%(idquote(n), t) for n,t in columnstoadd]) # new list of columns 0702 self.sql(" ".join(cmd)) 0703 # create index if needed 0704 if createindex: 0705 if createindex==1: 0706 cmd=["create index", idquote("__index__"+tablename), "on", idquote(tablename), "(__uid__)"] 0707 elif createindex==2: 0708 cmd=["create index", idquote("__index__"+tablename), "on", idquote(tablename), "("] 0709 cols=[] 0710 for _,n,t in dbtkeys: 0711 if not n.startswith("__"): 0712 cols.append(idquote(n)) 0713 for n,t in columnstoadd: 0714 cols.append(idquote(n)) 0715 cmd.extend([",".join(cols), ")"]) 0716 else: 0717 raise ValueError("bad createindex "+`createindex`) 0718 self.sql(" ".join(cmd)) 0719 # put values back in 0720 cmd=["insert into", idquote(tablename), '(', 0721 ','.join([idquote(n) for _,n,_ in dbtkeys]), 0722 ")", "select * from", idquote("backup_"+tablename)] 0723 self.sql(" ".join(cmd)) 0724 self.sql("drop table 
"+idquote("backup_"+tablename)) 0725 0726 @ExclusiveWrapper 0727 def deleteold(self, tablename, uids=None, minvalues=3, maxvalues=5, keepoldest=93): 0728 """Deletes old entries from the database. The deletion is based 0729 on either criterion of maximum values or age of values matching. 0730 0731 @param uids: You can limit the items deleted to this list of uids, 0732 or None for all entries. 0733 @param minvalues: always keep at least this number of values 0734 @param maxvalues: maximum values to keep for any entry (you 0735 can supply None in which case no old entries will be removed 0736 based on how many there are). 0737 @param keepoldest: values older than this number of days before 0738 now are removed. You can also supply None in which case no 0739 entries will be removed based on age. 0740 @returns: number of rows removed,number of rows remaining 0741 """ 0742 if not self.doestableexist(tablename): 0743 return (0,0) 0744 0745 timecutoff=0 0746 if keepoldest is not None: 0747 timecutoff=time.time()-(keepoldest*24*60*60) 0748 if maxvalues is None: 0749 maxvalues=sys.maxint-1 0750 0751 if uids is None: 0752 uids=[u[0] for u in self.sql("select distinct __uid__ from %s" % (idquote(tablename),))] 0753 0754 deleterows=[] 0755 0756 for uid in uids: 0757 deleting=False 0758 for count, (rowid, deleted, timestamp) in enumerate( 0759 self.sql("select __rowid__,__deleted__, __timestamp__ from %s where __uid__=? 
order by __rowid__ desc" % (idquote(tablename),), [uid])): 0760 if count<minvalues: 0761 continue 0762 if deleting: 0763 deleterows.append(rowid) 0764 continue 0765 if count>=maxvalues or timestamp<timecutoff: 0766 deleting=True 0767 if deleted: 0768 # we are ok, this is an old value now deleted, so we can remove it 0769 deleterows.append(rowid) 0770 continue 0771 # we don't want to delete current data (which may 0772 # be very old and never updated) 0773 if count>0: 0774 deleterows.append(rowid) 0775 continue 0776 0777 self.sqlmany("delete from %s where __rowid__=?" % (idquote(tablename),), [(r,) for r in deleterows]) 0778 0779 return len(deleterows), self.sql("select count(*) from "+idquote(tablename)).next()[0] 0780 0781 @ExclusiveWrapper 0782 def savelist(self, tablename, values): 0783 """Just save a list of items (eg categories). There is no versioning or transaction history. 0784 0785 Internally the table has two fields. One is the actual value and the other indicates if 0786 the item is deleted. 0787 """ 0788 0789 # a tuple of the quoted table name 0790 tn=(idquote(tablename),) 0791 0792 if not self.doestableexist(tablename): 0793 self.sql("create table %s (__rowid__ integer primary key, item, __deleted__ integer)" % tn) 0794 0795 # some code to demonstrate my lack of experience with SQL .... 0796 delete=[] 0797 known=[] 0798 revive=[] 0799 for row, item, dead in self.sql("select __rowid__,item,__deleted__ from %s" % tn): 0800 known.append(item) 0801 if item in values: 0802 # we need this row 0803 if dead: 0804 revive.append((row,)) 0805 continue 0806 if dead: 0807 # don't need this entry and it is dead anyway 0808 continue 0809 delete.append((row,)) 0810 create=[(v,) for v in values if v not in known] 0811 0812 # update table as appropriate 0813 self.sqlmany("update %s set __deleted__=0 where __rowid__=?" % tn, revive) 0814 self.sqlmany("update %s set __deleted__=1 where __rowid__=?" 
% tn, delete) 0815 self.sqlmany("insert into %s (item, __deleted__) values (?,0)" % tn, create) 0816 if __debug__: 0817 vdup=values[:] 0818 vdup.sort() 0819 vv=self.loadlist(tablename) 0820 vv.sort() 0821 assert vdup==vv 0822 0823 @ExclusiveWrapper 0824 def loadlist(self, tablename): 0825 """Loads a list of items (eg categories)""" 0826 if not self.doestableexist(tablename): 0827 return [] 0828 return [v[0] for v in self.sql("select item from %s where __deleted__=0" % (idquote(tablename),))] 0829 0830 @ExclusiveWrapper 0831 def getchangescount(self, tablename): 0832 """Return the number of additions, deletions, and modifications 0833 made to this table over time. 0834 Expected fields containted in this table: __timestamp__,__deleted__, 0835 __uid__ 0836 Assuming that both __rowid__ and __timestamp__ values are both ascending 0837 """ 0838 if not self.doestableexist(tablename): 0839 return {} 0840 tn=idquote(tablename) 0841 # get the unique dates of changes 0842 sql_cmd='select distinct __timestamp__ from %s' % tn 0843 # setting up the return dict 0844 res={} 0845 for t in self.sql(sql_cmd): 0846 res[t[0]]={ 'add': 0, 'del': 0, 'mod': 0 } 0847 # go through the table and count the changes 0848 existing_uid={} 0849 sql_cmd='select __timestamp__,__uid__,__deleted__ from %s order by __timestamp__ asc' % tn 0850 for e in self.sql(sql_cmd): 0851 tt=e[0] 0852 uid=e[1] 0853 del_flg=e[2] 0854 if existing_uid.has_key(uid): 0855 if del_flg: 0856 res[tt]['del']+=1 0857 del existing_uid[uid] 0858 else: 0859 res[tt]['mod']+=1 0860 else: 0861 existing_uid[uid]=None 0862 res[tt]['add']+=1 0863 return res 0864 0865 class ModuleBase(object): 0866 """Base class to implement a specific Virtual Table module with apsw. 
0867 For more info: 0868 http://www.sqlite.org/cvstrac/wiki/wiki?p=VirtualTables 0869 http://www.sqlite.org/cvstrac/wiki/wiki?p=VirtualTableMethods 0870 http://www.sqlite.org/cvstrac/wiki/wiki?p=VirtualTableBestIndexMethod 0871 """ 0872 def __init__(self, field_names): 0873 self.connection=None 0874 self.table_name=None 0875 # the first field is ALWAYS __rowid__ to be consistent with Database 0876 self.field_names=('__rowid__',)+field_names 0877 def Create(self, connection, modulename, databasename, vtablename, *args): 0878 """Called when the virtual table is created. 0879 @param connection: an instance of apsw.Connection 0880 @param modulename: string name of the module being invoked 0881 @param databasename: string name of this database 0882 @param vtablename: string name of this new virtual table 0883 @param args: additional arguments sent from the CREATE VIRTUAL TABLE 0884 statement 0885 @returns: a tuple of 2 values: an sql string describing the table, and 0886 an object implementing it: Me! 0887 """ 0888 self.table_name=vtablename 0889 fields=['__rowid__ integer primary key'] 0890 for field in self.field_names[1:]: 0891 fields.append(idquote(field)+' valueBLOB') 0892 fields='(%s)'%','.join(fields) 0893 return ('create table %s %s;'%(idquote(vtablename), fields), self) 0894 0895 def Connect(self, connection, modulename, databasename, vtablename, *args): 0896 """Connect to an existing virtual table, by default it is identical 0897 to Create 0898 """ 0899 return self.Create(connection, modulename, databasename, vtablename, 0900 *args) 0901 0902 def Destroy(self): 0903 """Release a connection to a virtual table and destroy the underlying 0904 table implementation. By default, we do nothing. 0905 """ 0906 pass 0907 def Disconnect(self): 0908 """Release a connection to a virtual table. By default, we do nothing. 0909 """ 0910 pass 0911 0912 def BestIndex(self, constraints, orderby): 0913 """Provide information on how to best access this table. 
0914 Must be overriden by subclass. 0915 @param constraints: a tuple of (column #, op) defining a constraints 0916 @param orderby: a tuple of (column #, desc) defining the order by 0917 @returns a tuple of up to 5 values: 0918 0: aConstraingUsage: a tuple of the same size as constraints. 0919 Each item is either None, argv index(int), or (argv index, omit(Bool)). 0920 1: idxNum(int) 0921 2: idxStr(string) 0922 3: orderByConsumed(Bool) 0923 4: estimatedCost(float) 0924 """ 0925 raise NotImplementedError 0926 0927 def Begin(self): 0928 pass 0929 def Sync(self): 0930 pass 0931 def Commit(self): 0932 pass 0933 def Rollback(self): 0934 pass 0935 0936 def Open(self): 0937 """Create/prepare a cursor used for subsequent reading. 0938 @returns: the implementor object: Me! 0939 """ 0940 return self 0941 def Close(self): 0942 """Close a cursor previously created by Open 0943 By default, do nothing 0944 """ 0945 pass 0946 def Filter(self, idxNum, idxStr, argv): 0947 """Begin a search of a virtual table. 0948 @param idxNum: int value passed by BestIndex 0949 @param idxStr: string valued passed by BestIndex 0950 @param argv: constraint parameters requested by BestIndex 0951 @returns: None 0952 """ 0953 raise NotImplementedError 0954 def Eof(self): 0955 """Determines if the current cursor points to a valid row. 0956 The Sqlite doc is wrong on this. 0957 @returns: True if NOT valid row, False otherwise 0958 """ 0959 raise NotImplementedError 0960 def Column(self, N): 0961 """Find the value for the N-th column of the current row. 0962 @param N: the N-th column 0963 @returns: value of the N-th column 0964 """ 0965 raise NotImplementedError 0966 def Next(self): 0967 """Move the cursor to the next row. 0968 @returns: None 0969 """ 0970 raise NotImplementedError 0971 def Rowid(self): 0972 """Return the rowid of the current row. 0973 @returns: the rowid(int) of the current row. 
0974 """ 0975 raise NotImplementedError 0976 def UpdateDeleteRow(self, rowid): 0977 """Delete row rowid 0978 @param rowid: 0979 @returns: None 0980 """ 0981 raise NotImplementedError 0982 def UpdateInsertRow(self, rowid, fields): 0983 """Insert a new row of data into the table 0984 @param rowid: if not None, use this rowid. If None, create a new rowid 0985 @param fields: a tuple of the field values in the order declared in 0986 Create/Connet 0987 @returns: rowid of the new row. 0988 """ 0989 raise NotImplementedError 0990 def UpdateChangeRow(self, rowid, newrowid, fields): 0991 """Change the row of the current rowid with the new rowid and new values 0992 @param rowid: rowid of the current row 0993 @param newrowid: new rowid 0994 @param fields: a tuple of the field values in the order declared in 0995 Create/Connect 0996 @returns: rowid of the new row 0997 """ 0998 raise NotImplementedError 0999 1000 if __name__=='__main__': 1001 import common 1002 import sys 1003 import time 1004 import os 1005 1006 sys.excepthook=common.formatexceptioneh 1007 1008 1009 # our own hacked version for testing 1010 class phonebookdataobject(basedataobject): 1011 # no change to _knownproperties (all of ours are list properties) 1012 _knownlistproperties=basedataobject._knownlistproperties.copy() 1013 _knownlistproperties.update( {'names': ['title', 'first', 'middle', 'last', 'full', 'nickname'], 1014 'categories': ['category'], 1015 'emails': ['email', 'type'], 1016 'urls': ['url', 'type'], 1017 'ringtones': ['ringtone', 'use'], 1018 'addresses': ['type', 'company', 'street', 'street2', 'city', 'state', 'postalcode', 'country'], 1019 'wallpapers': ['wallpaper', 'use'], 1020 'flags': ['secret'], 1021 'memos': ['memo'], 1022 'numbers': ['number', 'type', 'speeddial'], 1023 # serials is in parent object 1024 }) 1025 _knowndictproperties=basedataobject._knowndictproperties.copy() 1026 _knowndictproperties.update( {'repeat': ['daily', 'orange']} ) 1027 1028 
phonebookobjectfactory=dataobjectfactory(phonebookdataobject) 1029 1030 # use the phonebook out of the examples directory 1031 try: 1032 execfile(os.getenv("DBTESTFILE", "examples/phonebook-index.idx")) 1033 except UnicodeError: 1034 common.unicode_execfile(os.getenv("DBTESTFILE", "examples/phonebook-index.idx")) 1035 1036 ensurerecordtype(phonebook, phonebookobjectfactory) 1037 1038 phonebookmaster=phonebook 1039 1040 def testfunc(): 1041 global phonebook, TRACE, db 1042 1043 # note that iterations increases the size of the 1044 # database/journal and will make each one take longer and 1045 # longer as the db/journal gets bigger 1046 if len(sys.argv)>=2: 1047 iterations=int(sys.argv[1]) 1048 else: 1049 iterations=1 1050 if iterations >1: 1051 TRACE=False 1052 1053 db=Database("testdb") 1054 1055 1056 b4=time.time() 1057 1058 1059 for i in xrange(iterations): 1060 phonebook=phonebookmaster.copy() 1061 1062 # write it out 1063 db.savemajordict("phonebook", extractbitpimserials(phonebook)) 1064 1065 # check what we get back is identical 1066 v=db.getmajordictvalues("phonebook") 1067 assert v==extractbitpimserials(phonebook) 1068 1069 # do a deletion 1070 del phonebook[17] # james bond @ microsoft 1071 db.savemajordict("phonebook", extractbitpimserials(phonebook)) 1072 # and verify 1073 v=db.getmajordictvalues("phonebook") 1074 assert v==extractbitpimserials(phonebook) 1075 1076 # modify a value 1077 phonebook[15]['addresses'][0]['city']="Bananarama" 1078 db.savemajordict("phonebook", extractbitpimserials(phonebook)) 1079 # and verify 1080 v=db.getmajordictvalues("phonebook") 1081 assert v==extractbitpimserials(phonebook) 1082 1083 after=time.time() 1084 1085 print "time per iteration is",(after-b4)/iterations,"seconds" 1086 print "total time was",after-b4,"seconds for",iterations,"iterations" 1087 1088 if iterations>1: 1089 print "testing repeated reads" 1090 b4=time.time() 1091 for i in xrange(iterations*10): 1092 db.getmajordictvalues("phonebook") 1093 
after=time.time() 1094 print "\ttime per iteration is",(after-b4)/(iterations*10),"seconds" 1095 print "\ttotal time was",after-b4,"seconds for",iterations*10,"iterations" 1096 print 1097 print "testing repeated writes" 1098 x=extractbitpimserials(phonebook) 1099 k=x.keys() 1100 b4=time.time() 1101 for i in xrange(iterations*10): 1102 # we remove 1/3rd of the entries on each iteration 1103 xcopy=x.copy() 1104 for l in range(i,i+len(k)/3): 1105 del xcopy[k[l%len(x)]] 1106 db.savemajordict("phonebook",xcopy) 1107 after=time.time() 1108 print "\ttime per iteration is",(after-b4)/(iterations*10),"seconds" 1109 print "\ttotal time was",after-b4,"seconds for",iterations*10,"iterations" 1110 1111 1112 1113 if len(sys.argv)==3: 1114 # also run under hotspot then 1115 def profile(filename, command): 1116 import hotshot, hotshot.stats, os 1117 file=os.path.abspath(filename) 1118 profile=hotshot.Profile(file) 1119 profile.run(command) 1120 profile.close() 1121 del profile 1122 howmany=100 1123 stats=hotshot.stats.load(file) 1124 stats.strip_dirs() 1125 stats.sort_stats('time', 'calls') 1126 stats.print_stats(100) 1127 stats.sort_stats('cum', 'calls') 1128 stats.print_stats(100) 1129 stats.sort_stats('calls', 'time') 1130 stats.print_stats(100) 1131 sys.exit(0) 1132 1133 profile("dbprof", "testfunc()") 1134 1135 else: 1136 testfunc() 1137 1138 1139
Generated by PyXR 0.9.4