PyXR

c:\projects\bitpim\src\database.py



0001 ### BITPIM
0002 ###
0003 ### Copyright (C) 2004 Roger Binns <rogerb@rogerbinns.com>
0004 ###
0005 ### This program is free software; you can redistribute it and/or modify
0006 ### it under the terms of the BitPim license as detailed in the LICENSE file.
0007 ###
0008 ### $Id: database.py 4370 2007-08-21 21:39:36Z djpham $
0009 
0010 """Interface to the database"""
0011 from __future__ import with_statement
0012 import os
0013 import copy
0014 import time
0015 import sha
0016 import random
0017 import sys
0018 import apsw
0019 
0020 import common
0021 ###
0022 ### The first section of this file deals with typical objects used to
0023 ### represent data items and various methods for wrapping them.
0024 ###
0025 
0026 
0027 
0028 class basedataobject(dict):
0029     """A base object derived from dict that is used for various
0030     records.  Existing code can just continue to treat it as a dict.
0031     New code can treat it as dict, as well as access via attribute
0032     names (ie object["foo"] or object.foo).  Attribute name access
0033     will always give a result, returning None if the name is not in
0034     the dict.
0035 
0036     As a bonus this class includes checking of attribute names and
0037     types in non-production runs.  That will help catch typos etc.
0038     For production runs we may be receiving data that was written out
0039     by a newer version of BitPim so we don't check or error."""
0040     # which properties we know about
0041     _knownproperties=[]
0042     # which ones we know about that should be a list of dicts
0043     _knownlistproperties={'serials': ['sourcetype', '*']}
0044     # which ones we know about that should be a dict
0045     _knowndictproperties={}
0046 
0047     if __debug__:
0048         # in debug code we check key name and value types
0049 
0050         def _check_property(self,name,value=None):
0051             # check it
0052             assert isinstance(name, (str, unicode)), "keys must be a string type"
0053             assert name in self._knownproperties or name in self._knownlistproperties or name in self._knowndictproperties, "unknown property named '"+name+"'"
0054             if value is None: return
0055             if name in getattr(self, "_knownlistproperties"):
0056                 assert isinstance(value, list), "list properties ("+name+") must be given a list as value"
0057                 # each list member must be a dict
0058                 for v in value:
0059                     self._check_property_dictvalue(name,v)
0060                 return
0061             if name in getattr(self, "_knowndictproperties"):
0062                 assert isinstance(value, dict), "dict properties ("+name+") must be given a dict as value"
0063                 self._check_property_dictvalue(name,value)
0064                 return
0065             # the value must be a basetype supported by apsw/SQLite
0066             assert isinstance(value, (str, unicode, buffer, int, long, float)), "only serializable types supported for values"
0067 
0068         def _check_property_dictvalue(self, name, value):
0069             assert isinstance(value, dict), "item(s) in "+name+" (a list) must be dicts"
0070             assert name in self._knownlistproperties or name in self._knowndictproperties
0071             if name in self._knownlistproperties:
0072                 for key in value:
0073                     assert key in self._knownlistproperties[name] or '*' in self._knownlistproperties[name], "dict key "+key+" as member of item in list "+name+" is not known"
0074                     v=value[key]
0075                     assert isinstance(v, (str, unicode, buffer, int, long, float)), "only serializable types supported for values"
0076             elif name in self._knowndictproperties:
0077                 for key in value:
0078                     assert key in self._knowndictproperties[name] or '*' in self._knowndictproperties[name], "dict key "+key+" as member of dict in item "+name+" is not known"
0079                     v=value[key]
0080                     assert isinstance(v, (str, unicode, buffer, int, long, float)), "only serializable types supported for values"
0081                 
0082                 
0083         def update(self, items):
0084             assert isinstance(items, dict), "update only supports dicts" # Feel free to fix this code ...
0085             for k in items:
0086                 self._check_property(k, items[k])
0087             super(basedataobject, self).update(items)
0088 
0089         def __getitem__(self, name):
0090             # check when they are retrieved, not set.  I did try
0091             # catching the append method, but the layers of nested
0092             # namespaces got too confused
0093             self._check_property(name)
0094             v=super(basedataobject, self).__getitem__(name)
0095             self._check_property(name, v)
0096             return v
0097 
0098         def __setitem__(self, name, value):
0099             self._check_property(name, value)
0100             super(basedataobject,self).__setitem__(name, value)
0101 
0102         def __setattr__(self, name, value):
0103             # note that we map setattr to update the dict
0104             self._check_property(name, value)
0105             self.__setitem__(name, value)
0106 
0107         def __getattr__(self, name):
0108             if name not in self._knownproperties and name not in self._knownlistproperties and  name not in self._knowndictproperties:
0109                 raise AttributeError(name)
0110             self._check_property(name)
0111             if name in self.keys():
0112                 return self[name]
0113             return None
0114 
0115         def __delattr__(self, name):
0116             self._check_property(name)
0117             if name in self.keys():
0118                 del self[name]
0119 
0120     else:
0121         # non-debug mode - we don't do any attribute name/value type
0122         # checking as the data may (legitimately) be from a newer
0123         # version of the program.
0124         def __setattr__(self, name, value):
0125             # note that we map setattr to update the dict
0126             super(basedataobject,self).__setitem__(name, value)
0127 
0128         def __getattr__(self, name):
0129             # and getattr checks the dict
0130             if name not in self._knownproperties and name not in self._knownlistproperties and  name not in self._knowndictproperties:
0131                 raise AttributeError(name)
0132             if name in self.keys():
0133                 return self[name]
0134             return None
0135 
0136         def __delattr__(self, name):
0137             if name in self.keys():
0138                 del self[name]
0139 
0140     # various methods for managing serials
0141     def GetBitPimSerial(self):
0142         "Returns the BitPim serial for this item"
0143         if "serials" not in self:
0144             raise KeyError("no bitpim serial present")
0145         for v in self.serials:
0146             if v["sourcetype"]=="bitpim":
0147                 return v["id"]
0148         raise KeyError("no bitpim serial present")
0149 
0150     # rng seeded at startup
0151     _persistrandom=random.Random()
0152     _shathingy=None
0153     def _getnextrandomid(self, item):
0154         """Returns random ids used to give unique serial numbers to items
0155 
0156         @param item: any object - its memory location is used to help randomness
0157         @returns: a 40 character hex digest string
0158         """
0159         if basedataobject._shathingy is None:
0160             basedataobject._shathingy=sha.new()
0161             basedataobject._shathingy.update(`basedataobject._persistrandom.random()`)
0162         basedataobject._shathingy.update(`id(self)`)
0163         basedataobject._shathingy.update(`basedataobject._persistrandom.random()`)
0164         basedataobject._shathingy.update(`id(item)`)
0165         return basedataobject._shathingy.hexdigest()
0166 
0167     
0168     def EnsureBitPimSerial(self):
0169         "Ensures this entry has a serial"
0170         if self.serials is None:
0171             self.serials=[]
0172         for v in self.serials:
0173             if v["sourcetype"]=="bitpim":
0174                 return
0175         self.serials.append({'sourcetype': "bitpim", "id": self._getnextrandomid(self.serials)})
0176 
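# Illustrative sketch (not part of the original BitPim file): how a
# basedataobject subclass is typically declared and used.  The class name and
# the 'subject' property below are hypothetical.
class _examplenoteobject(basedataobject):
    _knownproperties=basedataobject._knownproperties+['subject']
    _knownlistproperties=basedataobject._knownlistproperties.copy()
    _knowndictproperties=basedataobject._knowndictproperties.copy()

def _example_basedataobject_usage():
    "Sketch only: demonstrates dict-style and attribute-style access"
    note=_examplenoteobject()
    note.subject="groceries"           # equivalent to note["subject"]="groceries"
    assert note["subject"]=="groceries"
    assert note.serials is None        # known but unset properties read back as None
    note.EnsureBitPimSerial()          # adds a {'sourcetype': 'bitpim', 'id': ...} entry
    return note.GetBitPimSerial()      # the hex id that was just generated
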
0177 class dataobjectfactory:
0178     "Called by the code to read in objects when it needs a new object container"
0179     def __init__(self, dataobjectclass=basedataobject):
0180         self.dataobjectclass=dataobjectclass
0181 
0182     if __debug__:
0183         def newdataobject(self, values={}):
0184             v=self.dataobjectclass()
0185             if len(values):
0186                 v.update(values)
0187             return v
0188     else:
0189         def newdataobject(self, values={}):
0190             return self.dataobjectclass(values)
0191 
0192 
0193 def extractbitpimserials(dict):
0194     """Returns a new dict with keys being the bitpim serial for each
0195     row.  Each item must be derived from basedataobject"""
0196 
0197     res={}
0198     
0199     for record in dict.itervalues():
0200         res[record.GetBitPimSerial()]=record
0201 
0202     return res
0203 
0204 def ensurebitpimserials(dict):
0205     """Ensures that all records have a BitPim serial.  Each item must
0206     be derived from basedataobject"""
0207     for record in dict.itervalues():
0208         record.EnsureBitPimSerial()
0209 
0210 def findentrywithbitpimserial(dict, serial):
0211     """Returns the entry from dict whose bitpim serial matches serial"""
0212     for record in dict.itervalues():
0213         if record.GetBitPimSerial()==serial:
0214             return record
0215     raise KeyError("no item with serial "+serial+" found")
0216 
0217 def ensurerecordtype(dict, factory):
0218     for key,record in dict.iteritems():
0219         if not isinstance(record, basedataobject):
0220             dict[key]=factory.newdataobject(record)
0221         
0222 
0223 
0224 # a factory that uses dicts to allocate new data objects
0225 dictdataobjectfactory=dataobjectfactory(dict)
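
# A hedged sketch (not from the original file) showing how the serial helpers
# above fit together.  'contacts' is hypothetical: a non-empty dict mapping
# arbitrary keys to basedataobject-derived records.
def _example_serial_helpers(contacts):
    "Sketch only: give every record a serial, then re-key the dict by serial"
    ensurebitpimserials(contacts)              # adds a bitpim serial where missing
    byserial=extractbitpimserials(contacts)    # new dict keyed by bitpim serial
    someserial=byserial.keys()[0]
    assert findentrywithbitpimserial(contacts, someserial) is byserial[someserial]
    return byserial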
0226 
0227 
0228 
0229 
0230 ###
0231 ### Actual database interaction is from this point on
0232 ###
0233 
0234 
0235 # Change this to True to see what is going on under the hood.  It
0236 # will produce a lot of output!
0237 TRACE=False
0238 
0239 def ExclusiveWrapper(method):
0240     """Wrap a method so that it has an exclusive lock on the database
0241     (no one else can read or write) until it has finished"""
0242 
0243     # note that the existing threading safety checks in apsw will
0244     # catch any thread abuse issues.
0245     def _transactionwrapper(*args, **kwargs):
0246         # args[0] should be a Database instance
0247         assert isinstance(args[0], Database)
0248         with args[0]:
0249             return method(*args, **kwargs)
0250 
0251     setattr(_transactionwrapper, "__doc__", getattr(method, "__doc__"))
0252     return _transactionwrapper
0253 
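# Hedged sketch (not in the original file): ExclusiveWrapper can wrap any
# callable whose first argument is a Database instance; the Database methods
# further below use it as a decorator.  The function here is illustrative only.
@ExclusiveWrapper
def _example_count_rows(db, tablename):
    "Sketch only: runs inside a single BEGIN EXCLUSIVE .. COMMIT/END block"
    if not db.doestableexist(tablename):
        return 0
    return db.sql("select count(*) from "+idquote(tablename)).next()[0]
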
0254 def sqlquote(s):
0255     "returns an sqlite quoted string (the return value will begin and end with single quotes)"
0256     return "'"+s.replace("'", "''")+"'"
0257 
0258 def idquote(s):
0259     """returns an sqlite quoted identifier (eg for when a column name is also an SQL keyword)
0260 
0261     The value returned is quoted in square brackets"""
0262     return '['+s+']'
0263 
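# Hedged examples of what the two quoting helpers above return:
def _example_quoting():
    "Sketch only"
    assert sqlquote("O'Brien")=="'O''Brien'"   # embedded quotes are doubled, result is wrapped
    assert idquote("order")=="[order]"         # an SQL keyword becomes a safe identifier
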
0264 class IntegrityCheckFailed(Exception): pass
0265 
0266 class Database:
0267 
0268     # Make this class a context manager so it can be used with WITH blocks
0269     def __enter__(self):
0270         self.excounter+=1
0271         self.transactionwrite=False
0272         if self.excounter==1:
0273             if TRACE:
0274                 print "BEGIN EXCLUSIVE TRANSACTION"
0275             self.cursor.execute("BEGIN EXCLUSIVE TRANSACTION")
0276             self._schemacache={}
0277         return self
0278 
0279     def __exit__(self, ex_type, ex_value, tb):
0280         self.excounter-=1
0281         if self.excounter==0:
0282             w=self.transactionwrite
0283             if tb is None:
0284                 # no exception, so commit
0285                 cmd="COMMIT TRANSACTION" if w else "END TRANSACTION"
0286             else:
0287                 # an exception occurred, so rollback
0288                 cmd="ROLLBACK TRANSACTION" if w else "END TRANSACTION"
0289             if TRACE:
0290                 print cmd
0291             self.cursor.execute(cmd)
0292 
0293     def __del__(self):
0294         # connections have to be closed now
0295         self.connection.close(True)
0296 
0297     def __init__(self, filename, virtualtables=None):
0298         """
0299         @param filename: database filename
0300         @param virtualtables: a list of dict specifying the virtual tables
0301             Each dict is expected to have the following keys:
0302             'tablename': the name of the virtual table
0303             'modulename': the name of the module that implements this virtual
0304                 table
0305             'moduleclass': the ModuleBase subclass that implements this
0306                 virtual table
0307             'args': arguments passed to instantiation of the module class
0308         """
0309         self.connection=apsw.Connection(filename)
0310         self.cursor=self.connection.cursor()
0311         # first tell sqlite to use the pre 3.4 format.  this will allow downgrades
0312         self.cursor.execute("PRAGMA legacy_file_format=1") # nb you don't get an error for unknown pragmas
0313         # we always do an integrity check second
0314         icheck=[]
0315         print "database integrity check"
0316         for row in self.cursor.execute("PRAGMA integrity_check"):
0317             icheck.extend(row)
0318         print "database integrity check complete"
0319         icheck="\n".join(icheck)
0320         if icheck!="ok":
0321             raise IntegrityCheckFailed(icheck)
0322         # exclusive lock counter
0323         self.excounter=0
0324         # this should be set to true by any code that writes - it is
0325         # used by the exclusivewrapper to tell if it should do a
0326         # commit/rollback or just a plain end
0327         self.transactionwrite=False
0328         # a cache of the table schemas
0329         self._schemacache={}
0330         self.sql=self.cursor.execute
0331         self.sqlmany=self.cursor.executemany
0332         if TRACE:
0333             self.cursor.setexectrace(self._sqltrace)
0334             self.cursor.setrowtrace(self._rowtrace)
0335         if virtualtables is not None:
0336             # virtual tables are specified
0337             for vtable in virtualtables:
0338                 # register the module
0339                 self.connection.createmodule(vtable['modulename'],
0340                                              vtable['moduleclass'](*vtable['args']))
0341                 if not self.doestableexist(vtable['tablename']):
0342                     # and declare the virtual table
0343                     self.sql('CREATE VIRTUAL TABLE %s USING %s;'%(idquote(vtable['tablename']),
0344                                                                   idquote(vtable['modulename'])))
0345 
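    # Hedged sketch (not part of the original file) of how a Database might be
    # opened, optionally with a virtual table backed by a ModuleBase subclass
    # (see ModuleBase and the _examplelistmodule sketch near the end of this
    # file).  The file name, table name and module name are made up.
    #
    #   db=Database("bitpim.db",
    #               virtualtables=[{'tablename': 'playlists',
    #                               'modulename': 'playlistmodule',
    #                               'moduleclass': _examplelistmodule,
    #                               'args': (('name', 'length'),)}])
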
0346     def _sqltrace(self, cmd, bindings):
0347         print "SQL:",cmd
0348         if bindings:
0349             print " bindings:",bindings
0350         return True
0351 
0352     def _rowtrace(self, *row):
0353         print "ROW:",row
0354         return row
0355 
0356     def sql(self, statement, params=()):
0357         "Executes statement and returns a generator of the results"
0358         # this is replaced in init
0359         assert False
0360 
0361     def sqlmany(self, statement, params):
0362         "execute statements repeatedly with params"
0363         # this is replaced in init
0364         assert False
0365             
0366     def doestableexist(self, tablename):
0367         if tablename in self._schemacache:
0368             return True
0369         return bool(self.sql("select count(*) from sqlite_master where type='table' and name=%s" % (sqlquote(tablename),)).next()[0])
0370 
0371     def getcolumns(self, tablename, onlynames=False):
0372         res=self._schemacache.get(tablename,None)
0373         if res is None:
0374             res=[]
0375             for colnum,name,type, _, default, primarykey in self.sql("pragma table_info("+idquote(tablename)+")"):
0376                 if primarykey:
0377                     type+=" primary key"
0378                 res.append([colnum,name,type])
0379             self._schemacache[tablename]=res
0380         if onlynames:
0381             return [name for colnum,name,type in res]
0382         return res
0383 
0384     @ExclusiveWrapper
0385     def savemajordict(self, tablename, dict, timestamp=None):
0386         """This is the entrypoint for saving a first level dictionary
0387         such as the phonebook or calendar.
0388 
0389         @param tablename: name of the table to use
0390         @param dict: The dictionary of records.  The key must be the uniqueid for each record.
0391                    The L{extractbitpimserials} function can do the conversion for you for
0392                    phonebook and similar formatted records.
0393         @param timestamp: the UTC time in seconds since the epoch.
0394         """
0395 
0396         if timestamp is None:
0397             timestamp=time.time()
0398 
0399         # work on a shallow copy of dict
0400         dict=dict.copy()
0401         
0402         # make sure the table exists first
0403         if not self.doestableexist(tablename):
0404             # create table and include meta-fields
0405             self.transactionwrite=True
0406             self.sql("create table %s (__rowid__ integer primary key, __timestamp__, __deleted__ integer, __uid__ varchar)" % (idquote(tablename),))
0407 
0408         # get the latest values for each guid ...
0409         current=self.getmajordictvalues(tablename)
0410         # compare what we have, and update/mark deleted as appropriate ...
0411         deleted=[k for k in current if k not in dict]
0412         new=[k for k in dict if k not in current]
0413         modified=[k for k in dict if k in current] # only potentially modified ...
0414 
0415         # deal with modified first
0416         dl=[]
0417         for i,k in enumerate(modified):
0418             if dict[k]==current[k]:
0419                 # unmodified!
0420                 del dict[k]
0421                 dl.append(i)
0422         dl.reverse()
0423         for i in dl:
0424             del modified[i]
0425 
0426         # add deleted entries back into dict
0427         for d in deleted:
0428             assert d not in dict
0429             dict[d]=current[d]
0430             dict[d]["__deleted__"]=1
0431             
0432         # now we only have new, changed and deleted entries left in dict
0433 
0434         # examine the keys in dict
0435         dk=[]
0436         for k in dict.keys():
0437             # make a copy since we modify values, but it doesn't matter about deleted since we own those
0438             if k not in deleted:
0439                 dict[k]=dict[k].copy()
0440             for kk in dict[k]:
0441                 if kk not in dk:
0442                     dk.append(kk)
0443         # verify that they don't start with __
0444         assert len([k for k in dk if k.startswith("__") and not k=="__deleted__"])==0
0445         # get database keys
0446         dbkeys=self.getcolumns(tablename, onlynames=True)
0447         # are any missing?
0448         missing=[k for k in dk if k not in dbkeys]
0449         if len(missing):
0450             creates=[]
0451             # for each missing key, we have to work out if the value
0452             # is a list or dict type (which we indirect to another table)
0453             for m in missing:
0454                 islist=None
0455                 isdict=None
0456                 isnotindirect=None
0457                 for r in dict.keys():
0458                     record=dict[r]
0459                     v=record.get(m,None)
0460                     if v is None:
0461                         continue
0462                     if isinstance(v, list):
0463                         islist=record
0464                     elif isinstance(v,type({})):
0465                         isdict=record
0466                     else:
0467                         isnotindirect=record
0468                     # in devel code, we check every single value
0469                     # in production, we just use the first we find
0470                     if not __debug__:
0471                         break
0472                 if islist is None and isdict is None and isnotindirect is None:
0473                     # they have the key but no record has any values, so we ignore it
0474                     del dk[dk.index(m)]
0475                     continue
0476                 # don't do this type abuse at home ...
0477                 if int(islist is not None)+int(isdict is not None)+int(isnotindirect is not None)!=int(True):
0478                     # can't have it more than one way
0479                     raise ValueError("key %s for table %s has values with inconsistent types. eg LIST: %s, DICT: %s, NOTINDIRECT: %s" % (m,tablename,`islist`,`isdict`,`isnotindirect`))
0480                 if islist is not None:
0481                     creates.append( (m, "indirectBLOB") )
0482                     continue
0483                 if isdict is not None:
0484                     creates.append( (m, "indirectdictBLOB"))
0485                     continue
0486                 if isnotindirect is not None:
0487                     creates.append( (m, "valueBLOB") )
0488                     continue
0489                 assert False, "You can't possibly get here!"
0490             if len(creates):
0491                 self._altertable(tablename, creates, [], createindex=1)
0492 
0493         # write out indirect values
0494         dbtkeys=self.getcolumns(tablename)
0495         # for every indirect, we have to replace the value with a pointer
0496         for _,n,t in dbtkeys:
0497             if t in ("indirectBLOB", "indirectdictBLOB"):
0498                 indirects={}
0499                 for r in dict.keys():
0500                     record=dict[r]
0501                     v=record.get(n,None)
0502                     if v is not None:
0503                         if not len(v): # set zero length lists/dicts to None
0504                             record[n]=None
0505                         else:
0506                             if t=="indirectdictBLOB":
0507                                 indirects[r]=[v] # make it a one item dict list
0508                             else:
0509                                 indirects[r]=v
0510                 if len(indirects):
0511                     self.updateindirecttable(tablename+"__"+n, indirects)
0512                     for r in indirects.keys():
0513                         dict[r][n]=indirects[r]
0514 
0515         # and now the main table
0516         for k in dict.keys():
0517             record=dict[k]
0518             record["__uid__"]=k
0519             rk=[x for x,y in record.items() if y is not None]
0520             rk.sort()
0521             cmd=["insert into", idquote(tablename), "( [__timestamp__],"]
0522             cmd.append(",".join([idquote(r) for r in rk]))
0523             cmd.extend([")", "values", "(?,"])
0524             cmd.append(",".join(["?" for r in rk]))
0525             cmd.append(")")
0526             self.sql(" ".join(cmd), [timestamp]+[record[r] for r in rk])
0527             self.transactionwrite=True
0528         
0529     def updateindirecttable(self, tablename, indirects):
0530         # this is mostly similar to savemajordict, except we only deal
0531         # with lists of dicts, and we find existing records with the
0532         # same value if possible
0533 
0534         # does the table even exist?
0535         if not self.doestableexist(tablename):
0536             # create table and include meta-fields
0537             self.sql("create table %s (__rowid__ integer primary key)" % (idquote(tablename),))
0538             self.transactionwrite=True
0539         # get the list of keys from indirects
0540         datakeys=[]
0541         for i in indirects.keys():
0542             assert isinstance(indirects[i], list)
0543             for v in indirects[i]:
0544                 assert isinstance(v, dict)
0545                 for k in v.keys():
0546                     if k not in datakeys:
0547                         assert not k.startswith("__")
0548                         datakeys.append(k)
0549         # get the keys from the table
0550         dbkeys=self.getcolumns(tablename, onlynames=True)
0551         # are any missing?
0552         missing=[k for k in datakeys if k not in dbkeys]
0553         if len(missing):
0554             self._altertable(tablename, [(m,"valueBLOB") for m in missing], [], createindex=2)
0555         # for each row we now work out the indirect information
0556         for r in indirects:
0557             res=tablename+","
0558             for record in indirects[r]:
0559                 cmd=["select __rowid__ from", idquote(tablename), "where"]
0560                 params=[]
0561                 coals=[]
0562                 for d in datakeys:
0563                     v=record.get(d,None)
0564                     if v is None:
0565                         coals.append(idquote(d))
0566                     else:
0567                         if cmd[-1]!="where":
0568                             cmd.append("and")
0569                         cmd.extend([idquote(d), "= ?"])
0570                         params.append(v)
0571                 assert cmd[-1]!="where" # there must be at least one non-none column!
0572                 if len(coals)==1:
0573                     cmd.extend(["and",coals[0],"isnull"])
0574                 elif len(coals)>1:
0575                     cmd.extend(["and coalesce(",",".join(coals),") isnull"])
0576 
0577                 found=None
0578                 for found in self.sql(" ".join(cmd), params):
0579                     # get matching row
0580                     found=found[0]
0581                     break
0582                 if found is None:
0583                     # add it
0584                     cmd=["insert into", idquote(tablename), "("]
0585                     params=[]
0586                     for k in record:
0587                         if cmd[-1]!="(":
0588                             cmd.append(",")
0589                         cmd.append(idquote(k))
0590                         params.append(record[k])
0591                     cmd.extend([")", "values", "("])
0592                     cmd.append(",".join(["?" for p in params]))
0593                     cmd.append("); select last_insert_rowid()")
0594                     found=self.sql(" ".join(cmd), params).next()[0]
0595                     self.transactionwrite=True
0596                 res+=`found`+","
0597             indirects[r]=res
0598                         
0599     @ExclusiveWrapper
0600     def getmajordictvalues(self, tablename, factory=dictdataobjectfactory,
0601                            at_time=None):
0602 
0603         if not self.doestableexist(tablename):
0604             return {}
0605 
0606         res={}
0607         uids=[u[0] for u in self.sql("select distinct __uid__ from %s" % (idquote(tablename),))]
0608         schema=self.getcolumns(tablename)
0609         for colnum,name,type in schema:
0610             if name=='__deleted__':
0611                 deleted=colnum
0612             elif name=='__uid__':
0613                 uid=colnum
0614         # get all relevant rows
0615         if isinstance(at_time, (int, float)):
0616             sql_string="select * from %s where __uid__=? and __timestamp__<=%f order by __rowid__ desc limit 1" % (idquote(tablename), float(at_time))
0617         else:
0618             sql_string="select * from %s where __uid__=? order by __rowid__ desc limit 1" % (idquote(tablename),)
0619         indirects={}
0620         for row in self.sqlmany(sql_string, [(u,) for u in uids]):
0621             if row[deleted]:
0622                 continue
0623             record=factory.newdataobject()
0624             for colnum,name,type in schema:
0625                 if name.startswith("__") or type not in ("valueBLOB", "indirectBLOB", "indirectdictBLOB") or row[colnum] is None:
0626                     continue
0627                 if type=="valueBLOB":
0628                     record[name]=row[colnum]
0629                     continue
0630                 assert type=="indirectBLOB" or type=="indirectdictBLOB"
0631                 if name not in indirects:
0632                     indirects[name]=[]
0633                 indirects[name].append( (row[uid], row[colnum], type) )
0634             res[row[uid]]=record
0635         # now get the indirects
0636         for name,values in indirects.iteritems():
0637             for uid,v,type in values:
0638                 fieldvalue=self._getindirect(v)
0639                 if fieldvalue:
0640                     if type=="indirectBLOB":
0641                         res[uid][name]=fieldvalue
0642                     else:
0643                         res[uid][name]=fieldvalue[0]
0644         return res
0645 
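    # Hedged usage sketch (not from the original file): a typical round trip
    # through savemajordict/getmajordictvalues.  'db' and 'contacts' are
    # hypothetical; the saved dict must be keyed by uniqueid, e.g. via
    # extractbitpimserials().
    #
    #   db.savemajordict("phonebook", extractbitpimserials(contacts))
    #   current=db.getmajordictvalues("phonebook")        # latest row per uid
    #   lastweek=db.getmajordictvalues("phonebook",
    #                                  at_time=time.time()-7*24*60*60)
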
0646     def _getindirect(self, what):
0647         """Gets a list of values (indirect) as described by what
0648         @param what: what to get - eg phonebook_serials,1,3,5,
0649                       (note there is always a trailing comma)
0650         """
0651 
0652         tablename,rows=what.split(',', 1)
0653         schema=self.getcolumns(tablename)
0654         
0655         res=[]
0656         for row in self.sqlmany("select * from %s where __rowid__=?" % (idquote(tablename),), [(int(r),) for r in rows.split(',') if len(r)]):
0657             record={}
0658             for colnum,name,type in schema:
0659                 if name.startswith("__") or type not in ("valueBLOB", "indirectBLOB", "indirectdictBLOB") or row[colnum] is None:
0660                     continue
0661                 if type=="valueBLOB":
0662                     record[name]=row[colnum]
0663                     continue
0664                 assert type=="indirectBLOB" or type=="indirectdictBLOB"
0665                 assert False, "indirect in indirect not handled"
0666             assert len(record),"Database._getindirect has zero len record"
0667             res.append(record)
0668         assert len(res), "Database._getindirect has zero len res"
0669         return res
0670         
0671     def _altertable(self, tablename, columnstoadd, columnstodel, createindex=0):
0672         """Alters the named table by deleting the specified columns, and
0673         adding the listed columns
0674 
0675         @param tablename: name of the table to alter
0676         @param columnstoadd: a list of (name,type) of the columns to add
0677         @param columnstodel: a list name of the columns to delete
0678         @param columnstodel: a list of names of the columns to delete
0679         """
0680         # indexes are automatically dropped when the table is dropped so we don't need to drop them explicitly
0681         dbtkeys=[x for x in self.getcolumns(tablename) \
0682                  if x[1] not in columnstodel]
0683         # clean out cache entry since we are about to invalidate it
0684         del self._schemacache[tablename]
0685         self.transactionwrite=True
0686         cmd=["create", "temporary", "table", idquote("backup_"+tablename),
0687              "(",
0688              ','.join(['%s %s'%(idquote(n), t) for _,n,t in dbtkeys]),
0689              ")"]
0690         self.sql(" ".join(cmd))
0691         # copy the values into the temporary table
0692         self.sql("insert into %s select %s from %s" % (idquote("backup_"+tablename),
0693                                                        ','.join([idquote(n) for _,n,_ in dbtkeys]),
0694                                                        idquote(tablename)))
0695         # drop the source table
0696         self.sql("drop table %s" % (idquote(tablename),))
0697         # recreate the source table with new columns
0698         del cmd[1] # remove temporary
0699         cmd[2]=idquote(tablename) # change tablename
0700         cmd[-2]=','.join(['%s %s'%(idquote(n), t) for _,n,t in dbtkeys]+\
0701                          ['%s %s'%(idquote(n), t) for n,t in columnstoadd]) # new list of columns
0702         self.sql(" ".join(cmd))
0703         # create index if needed
0704         if createindex:
0705             if createindex==1:
0706                 cmd=["create index", idquote("__index__"+tablename), "on", idquote(tablename), "(__uid__)"]
0707             elif createindex==2:
0708                 cmd=["create index", idquote("__index__"+tablename), "on", idquote(tablename), "("]
0709                 cols=[]
0710                 for _,n,t in dbtkeys:
0711                     if not n.startswith("__"):
0712                         cols.append(idquote(n))
0713                 for n,t in columnstoadd:
0714                     cols.append(idquote(n))
0715                 cmd.extend([",".join(cols), ")"])
0716             else:
0717                 raise ValueError("bad createindex "+`createindex`)
0718             self.sql(" ".join(cmd))
0719         # put values back in
0720         cmd=["insert into", idquote(tablename), '(',
0721              ','.join([idquote(n) for _,n,_ in dbtkeys]),
0722              ")", "select * from", idquote("backup_"+tablename)]
0723         self.sql(" ".join(cmd))
0724         self.sql("drop table "+idquote("backup_"+tablename))
0725 
0726     @ExclusiveWrapper
0727     def deleteold(self, tablename, uids=None, minvalues=3, maxvalues=5, keepoldest=93):
0728         """Deletes old entries from the database.  The deletion is based
0729         on either criterion: the maximum number of values kept, or the age of the values.
0730 
0731         @param uids: You can limit the items deleted to this list of uids,
0732              or None for all entries.
0733         @param minvalues: always keep at least this number of values
0734         @param maxvalues: maximum values to keep for any entry (you
0735              can supply None in which case no old entries will be removed
0736              based on how many there are).
0737         @param keepoldest: values older than this number of days before
0738              now are removed.  You can also supply None in which case no
0739              entries will be removed based on age.
0740         @returns: number of rows removed,number of rows remaining
0741              """
0742         if not self.doestableexist(tablename):
0743             return (0,0)
0744         
0745         timecutoff=0
0746         if keepoldest is not None:
0747             timecutoff=time.time()-(keepoldest*24*60*60)
0748         if maxvalues is None:
0749             maxvalues=sys.maxint-1
0750 
0751         if uids is None:
0752             uids=[u[0] for u in self.sql("select distinct __uid__ from %s" % (idquote(tablename),))]
0753 
0754         deleterows=[]
0755 
0756         for uid in uids:
0757             deleting=False
0758             for count, (rowid, deleted, timestamp) in enumerate(
0759                 self.sql("select __rowid__,__deleted__, __timestamp__ from %s where __uid__=? order by __rowid__ desc" % (idquote(tablename),), [uid])):
0760                 if count<minvalues:
0761                     continue
0762                 if deleting:
0763                     deleterows.append(rowid)
0764                     continue
0765                 if count>=maxvalues or timestamp<timecutoff:
0766                     deleting=True
0767                     if deleted:
0768                         # we are ok, this is an old value now deleted, so we can remove it
0769                         deleterows.append(rowid)
0770                         continue
0771                     # we don't want to delete current data (which may
0772                     # be very old and never updated)
0773                     if count>0:
0774                         deleterows.append(rowid)
0775                     continue
0776 
0777         self.sqlmany("delete from %s where __rowid__=?" % (idquote(tablename),), [(r,) for r in deleterows])
0778 
0779         return len(deleterows), self.sql("select count(*) from "+idquote(tablename)).next()[0]
0780 
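    # Hedged sketch of typical housekeeping calls; 'db' is hypothetical.
    #
    #   removed, remaining = db.deleteold("phonebook")    # the defaults documented above
    #   removed, remaining = db.deleteold("calendar", maxvalues=None, keepoldest=30)
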
0781     @ExclusiveWrapper
0782     def savelist(self, tablename, values):
0783         """Just save a list of items (eg categories).  There is no versioning or transaction history.
0784 
0785         Internally the table has two fields.  One is the actual value and the other indicates if
0786         the item is deleted.
0787         """
0788 
0789         # a tuple of the quoted table name
0790         tn=(idquote(tablename),)
0791         
0792         if not self.doestableexist(tablename):
0793             self.sql("create table %s (__rowid__ integer primary key, item, __deleted__ integer)" % tn)
0794 
0795         # some code to demonstrate my lack of experience with SQL ....
0796         delete=[]
0797         known=[]
0798         revive=[]
0799         for row, item, dead in self.sql("select __rowid__,item,__deleted__ from %s" % tn):
0800             known.append(item)
0801             if item in values:
0802                 # we need this row
0803                 if dead:
0804                     revive.append((row,))
0805                 continue
0806             if dead:
0807                 # don't need this entry and it is dead anyway
0808                 continue
0809             delete.append((row,))
0810         create=[(v,) for v in values if v not in known]
0811 
0812         # update table as appropriate
0813         self.sqlmany("update %s set __deleted__=0 where __rowid__=?" % tn, revive)
0814         self.sqlmany("update %s set __deleted__=1 where __rowid__=?" % tn, delete)
0815         self.sqlmany("insert into %s (item, __deleted__) values (?,0)" % tn, create)
0816         if __debug__:
0817             vdup=values[:]
0818             vdup.sort()
0819             vv=self.loadlist(tablename)
0820             vv.sort()
0821             assert vdup==vv
0822 
0823     @ExclusiveWrapper
0824     def loadlist(self, tablename):
0825         """Loads a list of items (eg categories)"""
0826         if not self.doestableexist(tablename):
0827             return []
0828         return [v[0] for v in self.sql("select item from %s where __deleted__=0" % (idquote(tablename),))]
0829         
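    # Hedged sketch of the list round trip; 'db' is hypothetical.
    #
    #   db.savelist("categories", ["Business", "Personal"])
    #   assert sorted(db.loadlist("categories"))==["Business", "Personal"]
    #   db.savelist("categories", ["Personal"])   # "Business" is now marked __deleted__
    #   assert db.loadlist("categories")==["Personal"]
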
0830     @ExclusiveWrapper
0831     def getchangescount(self, tablename):
0832         """Return the number of additions, deletions, and modifications
0833         made to this table over time.
0834         Expected fields contained in this table: __timestamp__,__deleted__,
0835         __uid__
0836         Assuming that both __rowid__ and __timestamp__ values are ascending
0837         """
0838         if not self.doestableexist(tablename):
0839             return {}
0840         tn=idquote(tablename)
0841         # get the unique dates of changes
0842         sql_cmd='select distinct __timestamp__ from %s' % tn
0843         # setting up the return dict
0844         res={}
0845         for t in self.sql(sql_cmd):
0846             res[t[0]]={ 'add': 0, 'del': 0, 'mod': 0 }
0847         # go through the table and count the changes
0848         existing_uid={}
0849         sql_cmd='select __timestamp__,__uid__,__deleted__ from %s order by __timestamp__ asc' % tn
0850         for e in self.sql(sql_cmd):
0851             tt=e[0]
0852             uid=e[1]
0853             del_flg=e[2]
0854             if existing_uid.has_key(uid):
0855                 if del_flg:
0856                     res[tt]['del']+=1
0857                     del existing_uid[uid]
0858                 else:
0859                     res[tt]['mod']+=1
0860             else:
0861                 existing_uid[uid]=None
0862                 res[tt]['add']+=1
0863         return res
0864 
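# Hedged sketch of the structure getchangescount returns; the timestamps and
# counts below are invented for illustration.
#
#   db.getchangescount("phonebook")
#   -> {1187732376.0: {'add': 25, 'del': 0, 'mod': 0},
#       1187818776.0: {'add': 1, 'del': 2, 'mod': 3}}
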
0865 class ModuleBase(object):
0866     """Base class to implement a specific Virtual Table module with apsw.
0867     For more info:
0868     http://www.sqlite.org/cvstrac/wiki/wiki?p=VirtualTables
0869     http://www.sqlite.org/cvstrac/wiki/wiki?p=VirtualTableMethods
0870     http://www.sqlite.org/cvstrac/wiki/wiki?p=VirtualTableBestIndexMethod
0871     """
0872     def __init__(self, field_names):
0873         self.connection=None
0874         self.table_name=None
0875         # the first field is ALWAYS __rowid__ to be consistent with Database
0876         self.field_names=('__rowid__',)+field_names
0877     def Create(self, connection, modulename, databasename, vtablename, *args):
0878         """Called when the virtual table is created.
0879         @param connection: an instance of apsw.Connection
0880         @param modulename: string name of the module being invoked
0881         @param databasename: string name of this database
0882         @param vtablename: string name of this new virtual table
0883         @param args: additional arguments sent from the CREATE VIRTUAL TABLE
0884             statement
0885         @returns: a tuple of 2 values: an sql string describing the table, and
0886             an object implementing it: Me!
0887         """
0888         self.table_name=vtablename
0889         fields=['__rowid__ integer primary key']
0890         for field in self.field_names[1:]:
0891             fields.append(idquote(field)+' valueBLOB')
0892         fields='(%s)'%','.join(fields)
0893         return ('create table %s %s;'%(idquote(vtablename), fields), self)
0894 
0895     def Connect(self, connection, modulename, databasename, vtablename, *args):
0896         """Connect to an existing virtual table; by default this is identical
0897         to Create.
0898         """
0899         return self.Create(connection, modulename, databasename, vtablename,
0900                            *args)
0901 
0902     def Destroy(self):
0903         """Release a connection to a virtual table and destroy the underlying
0904         table implementation.  By default, we do nothing.
0905         """
0906         pass
0907     def Disconnect(self):
0908         """Release a connection to a virtual table.  By default, we do nothing.
0909         """
0910         pass
0911 
0912     def BestIndex(self, constraints, orderby):
0913         """Provide information on how to best access this table.
0914         Must be overridden by a subclass.
0915         @param constraints: a tuple of (column #, op) defining the constraints
0916         @param orderby: a tuple of (column #, desc) defining the order by
0917         @returns: a tuple of up to 5 values:
0918             0: aConstraintUsage: a tuple of the same size as constraints.
0919                 Each item is either None, argv index(int), or (argv index, omit(Bool)).
0920             1: idxNum(int)
0921             2: idxStr(string)
0922             3: orderByConsumed(Bool)
0923             4: estimatedCost(float)
0924         """
0925         raise NotImplementedError
0926 
0927     def Begin(self):
0928         pass
0929     def Sync(self):
0930         pass
0931     def Commit(self):
0932         pass
0933     def Rollback(self):
0934         pass
0935 
0936     def Open(self):
0937         """Create/prepare a cursor used for subsequent reading.
0938         @returns: the implementor object: Me!
0939         """
0940         return self
0941     def Close(self):
0942         """Close a cursor previously created by Open
0943         By default, do nothing
0944         """
0945         pass
0946     def Filter(self, idxNum, idxStr, argv):
0947         """Begin a search of a virtual table.
0948         @param idxNum: int value passed by BestIndex
0949         @param idxStr: string value passed by BestIndex
0950         @param argv: constraint parameters requested by BestIndex
0951         @returns: None
0952         """
0953         raise NotImplementedError
0954     def Eof(self):
0955         """Determines if the current cursor points to a valid row.
0956         The Sqlite doc is wrong on this.
0957         @returns: True if NOT valid row, False otherwise
0958         """
0959         raise NotImplementedError
0960     def Column(self, N):
0961         """Find the value for the N-th column of the current row.
0962         @param N: the N-th column
0963         @returns: value of the N-th column
0964         """
0965         raise NotImplementedError
0966     def Next(self):
0967         """Move the cursor to the next row.
0968         @returns: None
0969         """
0970         raise NotImplementedError
0971     def Rowid(self):
0972         """Return the rowid of the current row.
0973         @returns: the rowid(int) of the current row.
0974         """
0975         raise NotImplementedError
0976     def UpdateDeleteRow(self, rowid):
0977         """Delete row rowid
0978         @param rowid: the rowid of the row to delete
0979         @returns: None
0980         """
0981         raise NotImplementedError
0982     def UpdateInsertRow(self, rowid, fields):
0983         """Insert a new row of data into the table
0984         @param rowid: if not None, use this rowid.  If None, create a new rowid
0985         @param fields: a tuple of the field values in the order declared in
0986             Create/Connect
0987         @returns: rowid of the new row.
0988         """
0989         raise NotImplementedError
0990     def UpdateChangeRow(self, rowid, newrowid, fields):
0991         """Change the row of the current rowid with the new rowid and new values
0992         @param rowid: rowid of the current row
0993         @param newrowid: new rowid
0994         @param fields: a tuple of the field values in the order declared in
0995             Create/Connect
0996         @returns: rowid of the new row
0997         """
0998         raise NotImplementedError
0999 
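# A hedged, minimal sketch (not part of the original file) of a concrete
# ModuleBase subclass.  It keeps its rows in a Python list, only does full
# table scans, and is read-only; all names below are illustrative.
class _examplelistmodule(ModuleBase):
    def __init__(self, field_names, rows=None):
        super(_examplelistmodule, self).__init__(field_names)
        # each row is a tuple laid out like self.field_names (__rowid__ first)
        self.rows=rows or []
        self._pos=0
    def BestIndex(self, constraints, orderby):
        # full scan: use no constraints and do not consume the ordering
        return ([None]*len(constraints), 0, "fullscan", False, float(len(self.rows)+1))
    def Filter(self, idxNum, idxStr, argv):
        self._pos=0
    def Eof(self):
        return self._pos>=len(self.rows)
    def Column(self, N):
        return self.rows[self._pos][N]
    def Next(self):
        self._pos+=1
    def Rowid(self):
        return self.rows[self._pos][0]
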
1000 if __name__=='__main__':
1001     import common
1002     import sys
1003     import time
1004     import os
1005 
1006     sys.excepthook=common.formatexceptioneh
1007 
1008 
1009     # our own hacked version for testing
1010     class phonebookdataobject(basedataobject):
1011         # no change to _knownproperties (all of ours are list properties)
1012         _knownlistproperties=basedataobject._knownlistproperties.copy()
1013         _knownlistproperties.update( {'names': ['title', 'first', 'middle', 'last', 'full', 'nickname'],
1014                                       'categories': ['category'],
1015                                       'emails': ['email', 'type'],
1016                                       'urls': ['url', 'type'],
1017                                       'ringtones': ['ringtone', 'use'],
1018                                       'addresses': ['type', 'company', 'street', 'street2', 'city', 'state', 'postalcode', 'country'],
1019                                       'wallpapers': ['wallpaper', 'use'],
1020                                       'flags': ['secret'],
1021                                       'memos': ['memo'],
1022                                       'numbers': ['number', 'type', 'speeddial'],
1023                                       # serials is in parent object
1024                                       })
1025         _knowndictproperties=basedataobject._knowndictproperties.copy()
1026         _knowndictproperties.update( {'repeat': ['daily', 'orange']} )
1027 
1028     phonebookobjectfactory=dataobjectfactory(phonebookdataobject)
1029     
1030     # use the phonebook out of the examples directory
1031     try:
1032         execfile(os.getenv("DBTESTFILE", "examples/phonebook-index.idx"))
1033     except UnicodeError:
1034         common.unicode_execfile(os.getenv("DBTESTFILE", "examples/phonebook-index.idx"))
1035 
1036     ensurerecordtype(phonebook, phonebookobjectfactory)
1037 
1038     phonebookmaster=phonebook
1039 
1040     def testfunc():
1041         global phonebook, TRACE, db
1042 
1043         # note that iterations increases the size of the
1044         # database/journal and will make each one take longer and
1045         # longer as the db/journal gets bigger
1046         if len(sys.argv)>=2:
1047             iterations=int(sys.argv[1])
1048         else:
1049             iterations=1
1050         if iterations >1:
1051             TRACE=False
1052 
1053         db=Database("testdb")
1054 
1055 
1056         b4=time.time()
1057         
1058 
1059         for i in xrange(iterations):
1060             phonebook=phonebookmaster.copy()
1061             
1062             # write it out
1063             db.savemajordict("phonebook", extractbitpimserials(phonebook))
1064 
1065             # check what we get back is identical
1066             v=db.getmajordictvalues("phonebook")
1067             assert v==extractbitpimserials(phonebook)
1068 
1069             # do a deletion
1070             del phonebook[17] # james bond @ microsoft
1071             db.savemajordict("phonebook", extractbitpimserials(phonebook))
1072             # and verify
1073             v=db.getmajordictvalues("phonebook")
1074             assert v==extractbitpimserials(phonebook)
1075 
1076             # modify a value
1077             phonebook[15]['addresses'][0]['city']="Bananarama"
1078             db.savemajordict("phonebook", extractbitpimserials(phonebook))
1079             # and verify
1080             v=db.getmajordictvalues("phonebook")
1081             assert v==extractbitpimserials(phonebook)
1082 
1083         after=time.time()
1084 
1085         print "time per iteration is",(after-b4)/iterations,"seconds"
1086         print "total time was",after-b4,"seconds for",iterations,"iterations"
1087 
1088         if iterations>1:
1089             print "testing repeated reads"
1090             b4=time.time()
1091             for i in xrange(iterations*10):
1092                 db.getmajordictvalues("phonebook")
1093             after=time.time()
1094             print "\ttime per iteration is",(after-b4)/(iterations*10),"seconds"
1095             print "\ttotal time was",after-b4,"seconds for",iterations*10,"iterations"
1096             print
1097             print "testing repeated writes"
1098             x=extractbitpimserials(phonebook)
1099             k=x.keys()
1100             b4=time.time()
1101             for i in xrange(iterations*10):
1102                 # we remove 1/3rd of the entries on each iteration
1103                 xcopy=x.copy()
1104                 for l in range(i,i+len(k)/3):
1105                     del xcopy[k[l%len(x)]]
1106                 db.savemajordict("phonebook",xcopy)
1107             after=time.time()
1108             print "\ttime per iteration is",(after-b4)/(iterations*10),"seconds"
1109             print "\ttotal time was",after-b4,"seconds for",iterations*10,"iterations"
1110 
1111 
1112 
1113     if len(sys.argv)==3:
1114         # also run under hotshot then
1115         def profile(filename, command):
1116             import hotshot, hotshot.stats, os
1117             file=os.path.abspath(filename)
1118             profile=hotshot.Profile(file)
1119             profile.run(command)
1120             profile.close()
1121             del profile
1122             howmany=100
1123             stats=hotshot.stats.load(file)
1124             stats.strip_dirs()
1125             stats.sort_stats('time', 'calls')
1126             stats.print_stats(100)
1127             stats.sort_stats('cum', 'calls')
1128             stats.print_stats(100)
1129             stats.sort_stats('calls', 'time')
1130             stats.print_stats(100)
1131             sys.exit(0)
1132 
1133         profile("dbprof", "testfunc()")
1134 
1135     else:
1136         testfunc()
1137         
1138 
1139 
