aboutsummaryrefslogtreecommitdiffstats
path: root/util/sqlite3.lua
diff options
context:
space:
mode:
Diffstat (limited to 'util/sqlite3.lua')
-rw-r--r--util/sqlite3.lua413
1 files changed, 413 insertions, 0 deletions
diff --git a/util/sqlite3.lua b/util/sqlite3.lua
new file mode 100644
index 00000000..83c0d26d
--- /dev/null
+++ b/util/sqlite3.lua
@@ -0,0 +1,413 @@
+
+-- luacheck: ignore 113/unpack 211 212 411 213
+local setmetatable, getmetatable = setmetatable, getmetatable;
+local ipairs, unpack, select = ipairs, table.unpack or unpack, select;
+local tonumber, tostring = tonumber, tostring;
+local assert, xpcall, debug_traceback = assert, xpcall, debug.traceback;
+local error = error
+local type = type
+local t_concat = table.concat;
+local t_insert = table.insert;
+local s_char = string.char;
+local log = require "prosody.util.logger".init("sql");
+
+local lsqlite3 = require "lsqlite3";
+local build_url = require "socket.url".build;
+local ROW, DONE = lsqlite3.ROW, lsqlite3.DONE;
+
+-- from sqlite3.h, no copyright claimed
+local sqlite_errors = require"prosody.util.error".init("util.sqlite3", {
+ -- FIXME xmpp error conditions?
+ [1] = { code = 1; type = "modify"; condition = "ERROR"; text = "Generic error" };
+ [2] = { code = 2; type = "cancel"; condition = "INTERNAL"; text = "Internal logic error in SQLite" };
+ [3] = { code = 3; type = "auth"; condition = "PERM"; text = "Access permission denied" };
+ [4] = { code = 4; type = "cancel"; condition = "ABORT"; text = "Callback routine requested an abort" };
+ [5] = { code = 5; type = "wait"; condition = "BUSY"; text = "The database file is locked" };
+ [6] = { code = 6; type = "wait"; condition = "LOCKED"; text = "A table in the database is locked" };
+ [7] = { code = 7; type = "wait"; condition = "NOMEM"; text = "A malloc() failed" };
+ [8] = { code = 8; type = "cancel"; condition = "READONLY"; text = "Attempt to write a readonly database" };
+ [9] = { code = 9; type = "cancel"; condition = "INTERRUPT"; text = "Operation terminated by sqlite3_interrupt()" };
+ [10] = { code = 10; type = "wait"; condition = "IOERR"; text = "Some kind of disk I/O error occurred" };
+ [11] = { code = 11; type = "cancel"; condition = "CORRUPT"; text = "The database disk image is malformed" };
+ [12] = { code = 12; type = "modify"; condition = "NOTFOUND"; text = "Unknown opcode in sqlite3_file_control()" };
+ [13] = { code = 13; type = "wait"; condition = "FULL"; text = "Insertion failed because database is full" };
+ [14] = { code = 14; type = "auth"; condition = "CANTOPEN"; text = "Unable to open the database file" };
+ [15] = { code = 15; type = "cancel"; condition = "PROTOCOL"; text = "Database lock protocol error" };
+ [16] = { code = 16; type = "continue"; condition = "EMPTY"; text = "Internal use only" };
+ [17] = { code = 17; type = "modify"; condition = "SCHEMA"; text = "The database schema changed" };
+ [18] = { code = 18; type = "modify"; condition = "TOOBIG"; text = "String or BLOB exceeds size limit" };
+ [19] = { code = 19; type = "modify"; condition = "CONSTRAINT"; text = "Abort due to constraint violation" };
+ [20] = { code = 20; type = "modify"; condition = "MISMATCH"; text = "Data type mismatch" };
+ [21] = { code = 21; type = "modify"; condition = "MISUSE"; text = "Library used incorrectly" };
+ [22] = { code = 22; type = "cancel"; condition = "NOLFS"; text = "Uses OS features not supported on host" };
+ [23] = { code = 23; type = "auth"; condition = "AUTH"; text = "Authorization denied" };
+ [24] = { code = 24; type = "modify"; condition = "FORMAT"; text = "Not used" };
+ [25] = { code = 25; type = "modify"; condition = "RANGE"; text = "2nd parameter to sqlite3_bind out of range" };
+ [26] = { code = 26; type = "cancel"; condition = "NOTADB"; text = "File opened that is not a database file" };
+ [27] = { code = 27; type = "continue"; condition = "NOTICE"; text = "Notifications from sqlite3_log()" };
+ [28] = { code = 28; type = "continue"; condition = "WARNING"; text = "Warnings from sqlite3_log()" };
+ [100] = { code = 100; type = "continue"; condition = "ROW"; text = "sqlite3_step() has another row ready" };
+ [101] = { code = 101; type = "continue"; condition = "DONE"; text = "sqlite3_step() has finished executing" };
+});
+
+local assert = function(cond, errno, err)
+ return assert(sqlite_errors.coerce(cond, err or errno));
+end
+local _ENV = nil;
+-- luacheck: std none
+
+local column_mt = {};
+local table_mt = {};
+local query_mt = {};
+--local op_mt = {};
+local index_mt = {};
+
+local function is_column(x) return getmetatable(x)==column_mt; end
+local function is_index(x) return getmetatable(x)==index_mt; end
+local function is_table(x) return getmetatable(x)==table_mt; end
+local function is_query(x) return getmetatable(x)==query_mt; end
+local function Integer(n) return "Integer()" end
+local function String(n) return "String()" end
+
+local function Column(definition)
+ return setmetatable(definition, column_mt);
+end
+local function Table(definition)
+ local c = {}
+ for i,col in ipairs(definition) do
+ if is_column(col) then
+ c[i], c[col.name] = col, col;
+ elseif is_index(col) then
+ col.table = definition.name;
+ end
+ end
+ return setmetatable({ __table__ = definition, c = c, name = definition.name }, table_mt);
+end
+local function Index(definition)
+ return setmetatable(definition, index_mt);
+end
+
+function table_mt:__tostring()
+ local s = { 'name="'..self.__table__.name..'"' }
+ for i,col in ipairs(self.__table__) do
+ s[#s+1] = tostring(col);
+ end
+ return 'Table{ '..t_concat(s, ", ")..' }'
+end
+table_mt.__index = {};
+function table_mt.__index:create(engine)
+ return engine:_create_table(self);
+end
+function table_mt:__call(...)
+ -- TODO
+end
+function column_mt:__tostring()
+ return 'Column{ name="'..self.name..'", type="'..self.type..'" }'
+end
+function index_mt:__tostring()
+ local s = 'Index{ name="'..self.name..'"';
+ for i=1,#self do s = s..', "'..self[i]:gsub("[\\\"]", "\\%1")..'"'; end
+ return s..' }';
+-- return 'Index{ name="'..self.name..'", type="'..self.type..'" }'
+end
+
+local function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return s_char(tonumber(c,16)); end)); end
+local function parse_url(url)
+ local scheme, secondpart, database = url:match("^([%w%+]+)://([^/]*)/?(.*)");
+ assert(scheme, "Invalid URL format");
+ local username, password, host, port;
+ local authpart, hostpart = secondpart:match("([^@]+)@([^@+])");
+ if not authpart then hostpart = secondpart; end
+ if authpart then
+ username, password = authpart:match("([^:]*):(.*)");
+ username = username or authpart;
+ password = password and urldecode(password);
+ end
+ if hostpart then
+ host, port = hostpart:match("([^:]*):(.*)");
+ host = host or hostpart;
+ port = port and assert(tonumber(port), "Invalid URL format");
+ end
+ return {
+ scheme = scheme:lower();
+ username = username; password = password;
+ host = host; port = port;
+ database = #database > 0 and database or nil;
+ };
+end
+
+local engine = {};
+function engine:connect()
+ if self.conn then return true; end
+
+ local params = self.params;
+ assert(params.driver == "SQLite3", "Only sqlite3 is supported");
+ local dbh, err = sqlite_errors.coerce(lsqlite3.open(params.database));
+ if not dbh then return nil, err; end
+ self.conn = dbh;
+ self.prepared = {};
+ local ok, err = self:set_encoding();
+ if not ok then
+ return ok, err;
+ end
+ local ok, err = self:onconnect();
+ if ok == false then
+ return ok, err;
+ end
+ return true;
+end
+function engine:onconnect()
+ -- Override from create_engine()
+end
+function engine:ondisconnect() -- luacheck: ignore 212/self
+ -- Override from create_engine()
+end
+function engine:execute(sql, ...)
+ local success, err = self:connect();
+ if not success then return success, err; end
+ local prepared = self.prepared;
+
+ if select('#', ...) == 0 then
+ local ret = self.conn:exec(sql);
+ if ret ~= lsqlite3.OK then
+ local err = sqlite_errors.new(err);
+ err.text = self.conn:errmsg();
+ return err;
+ end
+ return true;
+ end
+
+ local stmt = prepared[sql];
+ if not stmt then
+ local err;
+ stmt, err = self.conn:prepare(sql);
+ if not stmt then
+ err = sqlite_errors.new(err);
+ err.text = self.conn:errmsg();
+ return stmt, err;
+ end
+ prepared[sql] = stmt;
+ end
+
+ local ret = stmt:bind_values(...);
+ if ret ~= lsqlite3.OK then return nil, sqlite_errors.new(ret, { message = self.conn:errmsg() }); end
+ return stmt;
+end
+
+local result_mt = {
+ __index = {
+ affected = function(self) return self.__affected; end;
+ rowcount = function(self) return self.__rowcount; end;
+ },
+};
+
+local function iterator(table)
+ local i=0;
+ return function()
+ i=i+1;
+ local item=table[i];
+ if item ~= nil then
+ return item;
+ end
+ end
+end
+
+local function debugquery(where, sql, ...)
+ local i = 0; local a = {...}
+ sql = sql:gsub("\n?\t+", " ");
+ log("debug", "[%s] %s", where, (sql:gsub("%?", function ()
+ i = i + 1;
+ local v = a[i];
+ if type(v) == "string" then
+ v = ("'%s'"):format(v:gsub("'", "''"));
+ end
+ return tostring(v);
+ end)));
+end
+
+function engine:execute_query(sql, ...)
+ local prepared = self.prepared;
+ local stmt = prepared[sql];
+ if stmt and stmt:isopen() then
+ prepared[sql] = nil; -- Can't be used concurrently
+ else
+ stmt = assert(self.conn:prepare(sql));
+ end
+ local ret = stmt:bind_values(...);
+ if ret ~= lsqlite3.OK then error(self.conn:errmsg()); end
+ local data, ret = {}
+ while stmt:step() == ROW do
+ t_insert(data, stmt:get_values());
+ end
+ -- FIXME Error handling, BUSY, ERROR, MISUSE
+ if stmt:reset() == lsqlite3.OK then
+ prepared[sql] = stmt;
+ end
+ return setmetatable({ __data = data }, { __index = result_mt.__index, __call = iterator(data) });
+end
+function engine:execute_update(sql, ...)
+ local prepared = self.prepared;
+ local stmt = prepared[sql];
+ if not stmt or not stmt:isopen() then
+ stmt = assert(self.conn:prepare(sql));
+ else
+ prepared[sql] = nil;
+ end
+ local ret = stmt:bind_values(...);
+ if ret ~= lsqlite3.OK then error(self.conn:errmsg()); end
+ local rowcount = 0;
+ repeat
+ ret = stmt:step();
+ if ret == lsqlite3.ROW then
+ rowcount = rowcount + 1;
+ end
+ until ret ~= lsqlite3.ROW;
+ local affected = self.conn:changes();
+ if stmt:reset() == lsqlite3.OK then
+ prepared[sql] = stmt;
+ end
+ return setmetatable({ __affected = affected, __rowcount = rowcount }, result_mt);
+end
+engine.insert = engine.execute_update;
+engine.select = engine.execute_query;
+engine.delete = engine.execute_update;
+engine.update = engine.execute_update;
+local function debugwrap(name, f)
+ return function (self, sql, ...)
+ debugquery(name, sql, ...)
+ return f(self, sql, ...)
+ end
+end
+function engine:debug(enable)
+ self._debug = enable;
+ if enable then
+ engine.insert = debugwrap("insert", engine.execute_update);
+ engine.select = debugwrap("select", engine.execute_query);
+ engine.delete = debugwrap("delete", engine.execute_update);
+ engine.update = debugwrap("update", engine.execute_update);
+ else
+ engine.insert = engine.execute_update;
+ engine.select = engine.execute_query;
+ engine.delete = engine.execute_update;
+ engine.update = engine.execute_update;
+ end
+end
+function engine:_(word)
+ local ret = self.conn:exec(word);
+ if ret ~= lsqlite3.OK then return nil, self.conn:errmsg(); end
+ return true;
+end
+function engine:_transaction(func, ...)
+ if not self.conn then
+ local a,b = self:connect();
+ if not a then return a,b; end
+ end
+ --assert(not self.__transaction, "Recursive transactions not allowed");
+ local ok, err = self:_"BEGIN";
+ if not ok then return ok, err; end
+ self.__transaction = true;
+ local success, a, b, c = xpcall(func, debug_traceback, ...);
+ self.__transaction = nil;
+ if success then
+ log("debug", "SQL transaction success [%s]", tostring(func));
+ local ok, err = self:_"COMMIT";
+ if not ok then return ok, err; end -- commit failed
+ return success, a, b, c;
+ else
+ log("debug", "SQL transaction failure [%s]: %s", tostring(func), a);
+ if self.conn then self:_"ROLLBACK"; end
+ return success, a;
+ end
+end
+function engine:transaction(...)
+ local ok, ret = self:_transaction(...);
+ if not ok then
+ local conn = self.conn;
+ if not conn or not conn:isopen() then
+ self.conn = nil;
+ self:ondisconnect();
+ ok, ret = self:_transaction(...);
+ end
+ end
+ return ok, ret;
+end
+function engine:_create_index(index)
+ local sql = "CREATE INDEX IF NOT EXISTS \""..index.name.."\" ON \""..index.table.."\" (";
+ for i=1,#index do
+ sql = sql.."\""..index[i].."\"";
+ if i ~= #index then sql = sql..", "; end
+ end
+ sql = sql..");"
+ if index.unique then
+ sql = sql:gsub("^CREATE", "CREATE UNIQUE");
+ end
+ if self._debug then
+ debugquery("create", sql);
+ end
+ return self:execute(sql);
+end
+function engine:_create_table(table)
+ local sql = "CREATE TABLE IF NOT EXISTS \""..table.name.."\" (";
+ for i,col in ipairs(table.c) do
+ local col_type = col.type;
+ sql = sql.."\""..col.name.."\" "..col_type;
+ if col.nullable == false then sql = sql.." NOT NULL"; end
+ if col.primary_key == true then sql = sql.." PRIMARY KEY"; end
+ if col.auto_increment == true then
+ sql = sql.." AUTOINCREMENT";
+ end
+ if i ~= #table.c then sql = sql..", "; end
+ end
+ sql = sql.. ");"
+ if self._debug then
+ debugquery("create", sql);
+ end
+ local success,err = self:execute(sql);
+ if not success then return success,err; end
+ for i,v in ipairs(table.__table__) do
+ if is_index(v) then
+ self:_create_index(v);
+ end
+ end
+ return success;
+end
+function engine:set_encoding() -- to UTF-8
+ return self:transaction(function()
+ for encoding in self:select"PRAGMA encoding;" do
+ if encoding[1] == "UTF-8" then
+ self.charset = "utf8";
+ end
+ end
+ end);
+end
+local engine_mt = { __index = engine };
+
+local function db2uri(params)
+ return build_url{
+ scheme = params.driver,
+ user = params.username,
+ password = params.password,
+ host = params.host,
+ port = params.port,
+ path = params.database,
+ };
+end
+
+local function create_engine(_, params, onconnect, ondisconnect)
+ assert(params.driver == "SQLite3", "Only SQLite3 is supported without LuaDBI");
+ return setmetatable({ url = db2uri(params); params = params; onconnect = onconnect; ondisconnect = ondisconnect }, engine_mt);
+end
+
+return {
+ is_column = is_column;
+ is_index = is_index;
+ is_table = is_table;
+ is_query = is_query;
+ Integer = Integer;
+ String = String;
+ Column = Column;
+ Table = Table;
+ Index = Index;
+ create_engine = create_engine;
+ db2uri = db2uri;
+};