diff options
author | Kim Alvefur <zash@zash.se> | 2022-08-01 15:23:33 +0200 |
---|---|---|
committer | Kim Alvefur <zash@zash.se> | 2022-08-01 15:23:33 +0200 |
commit | 2dada39e39863f29e64129395057fb07e4ee6046 (patch) | |
tree | 930b3985281392768e018640812db54feffbef8c /util | |
parent | 732d5ad8a8fb23c2a7f75fbff90c923090898957 (diff) | |
download | prosody-2dada39e39863f29e64129395057fb07e4ee6046.tar.gz prosody-2dada39e39863f29e64129395057fb07e4ee6046.zip |
util.sqlite3: SQLite3-only variant of util.sql using LuaSQLite3
http://lua.sqlite.org/
Diffstat (limited to 'util')
-rw-r--r-- | util/sqlite3.lua | 390 |
1 files changed, 390 insertions, 0 deletions
diff --git a/util/sqlite3.lua b/util/sqlite3.lua new file mode 100644 index 00000000..7e0ff94d --- /dev/null +++ b/util/sqlite3.lua @@ -0,0 +1,390 @@ + +-- luacheck: ignore 113/unpack 211 212 411 213 +local setmetatable, getmetatable = setmetatable, getmetatable; +local ipairs, unpack, select = ipairs, table.unpack or unpack, select; +local tonumber, tostring = tonumber, tostring; +local assert, xpcall, debug_traceback = assert, xpcall, debug.traceback; +local error = error +local type = type +local t_concat = table.concat; +local t_insert = table.insert; +local s_char = string.char; +local log = require "util.logger".init("sql"); + +local lsqlite3 = require "lsqlite3"; +local build_url = require "socket.url".build; +local ROW, DONE = lsqlite3.ROW, lsqlite3.DONE; +local err2str = { + [0] = "OK"; + "ERROR"; + "INTERNAL"; + "PERM"; + "ABORT"; + "BUSY"; + "LOCKED"; + "NOMEM"; + "READONLY"; + "INTERRUPT"; + "IOERR"; + "CORRUPT"; + "NOTFOUND"; + "FULL"; + "CANTOPEN"; + "PROTOCOL"; + "EMPTY"; + "SCHEMA"; + "TOOBIG"; + "CONSTRAINT"; + "MISMATCH"; + "MISUSE"; + "NOLFS"; + [24] = "FORMAT"; + [25] = "RANGE"; + [26] = "NOTADB"; + [100] = "ROW"; + [101] = "DONE"; +}; + +local assert = function(cond, errno, err) + return assert(cond, err or err2str[errno]); +end +local _ENV = nil; +-- luacheck: std none + +local column_mt = {}; +local table_mt = {}; +local query_mt = {}; +--local op_mt = {}; +local index_mt = {}; + +local function is_column(x) return getmetatable(x)==column_mt; end +local function is_index(x) return getmetatable(x)==index_mt; end +local function is_table(x) return getmetatable(x)==table_mt; end +local function is_query(x) return getmetatable(x)==query_mt; end +local function Integer(n) return "Integer()" end +local function String(n) return "String()" end + +local function Column(definition) + return setmetatable(definition, column_mt); +end +local function Table(definition) + local c = {} + for i,col in ipairs(definition) do + if is_column(col) then + c[i], c[col.name] = col, col; + elseif is_index(col) then + col.table = definition.name; + end + end + return setmetatable({ __table__ = definition, c = c, name = definition.name }, table_mt); +end +local function Index(definition) + return setmetatable(definition, index_mt); +end + +function table_mt:__tostring() + local s = { 'name="'..self.__table__.name..'"' } + for i,col in ipairs(self.__table__) do + s[#s+1] = tostring(col); + end + return 'Table{ '..t_concat(s, ", ")..' }' +end +table_mt.__index = {}; +function table_mt.__index:create(engine) + return engine:_create_table(self); +end +function table_mt:__call(...) + -- TODO +end +function column_mt:__tostring() + return 'Column{ name="'..self.name..'", type="'..self.type..'" }' +end +function index_mt:__tostring() + local s = 'Index{ name="'..self.name..'"'; + for i=1,#self do s = s..', "'..self[i]:gsub("[\\\"]", "\\%1")..'"'; end + return s..' }'; +-- return 'Index{ name="'..self.name..'", type="'..self.type..'" }' +end + +local function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return s_char(tonumber(c,16)); end)); end +local function parse_url(url) + local scheme, secondpart, database = url:match("^([%w%+]+)://([^/]*)/?(.*)"); + assert(scheme, "Invalid URL format"); + local username, password, host, port; + local authpart, hostpart = secondpart:match("([^@]+)@([^@+])"); + if not authpart then hostpart = secondpart; end + if authpart then + username, password = authpart:match("([^:]*):(.*)"); + username = username or authpart; + password = password and urldecode(password); + end + if hostpart then + host, port = hostpart:match("([^:]*):(.*)"); + host = host or hostpart; + port = port and assert(tonumber(port), "Invalid URL format"); + end + return { + scheme = scheme:lower(); + username = username; password = password; + host = host; port = port; + database = #database > 0 and database or nil; + }; +end + +local engine = {}; +function engine:connect() + if self.conn then return true; end + + local params = self.params; + assert(params.driver == "SQLite3", "Only sqlite3 is supported"); + local dbh, err = lsqlite3.open(params.database); + if not dbh then return nil, err2str[err]; end + self.conn = dbh; + self.prepared = {}; + local ok, err = self:set_encoding(); + if not ok then + return ok, err; + end + local ok, err = self:onconnect(); + if ok == false then + return ok, err; + end + return true; +end +function engine:onconnect() + -- Override from create_engine() +end +function engine:execute(sql, ...) + local success, err = self:connect(); + if not success then return success, err; end + local prepared = self.prepared; + + local stmt = prepared[sql]; + if not stmt then + local err; + stmt, err = self.conn:prepare(sql); + if not stmt then return stmt, err; end + prepared[sql] = stmt; + end + + local ret = stmt:bind_values(...); + if ret ~= lsqlite3.OK then return nil, self.conn:errmsg(); end + return stmt; +end + +local result_mt = { + __index = { + affected = function(self) return self.__affected; end; + rowcount = function(self) return self.__rowcount; end; + }, +}; + +local function iterator(table) + local i=0; + return function() + i=i+1; + local item=table[i]; + if item ~= nil then + return item; + end + end +end + +local function debugquery(where, sql, ...) + local i = 0; local a = {...} + sql = sql:gsub("\n?\t+", " "); + log("debug", "[%s] %s", where, (sql:gsub("%?", function () + i = i + 1; + local v = a[i]; + if type(v) == "string" then + v = ("'%s'"):format(v:gsub("'", "''")); + end + return tostring(v); + end))); +end + +function engine:execute_query(sql, ...) + local prepared = self.prepared; + local stmt = prepared[sql]; + if stmt and stmt:isopen() then + prepared[sql] = nil; -- Can't be used concurrently + else + stmt = assert(self.conn:prepare(sql)); + end + local ret = stmt:bind_values(...); + if ret ~= lsqlite3.OK then error(self.conn:errmsg()); end + local data, ret = {} + while stmt:step() == ROW do + t_insert(data, stmt:get_values()); + end + -- FIXME Error handling, BUSY, ERROR, MISUSE + if stmt:reset() == lsqlite3.OK then + prepared[sql] = stmt; + end + return setmetatable({ __data = data }, { __index = result_mt.__index, __call = iterator(data) }); +end +function engine:execute_update(sql, ...) + local prepared = self.prepared; + local stmt = prepared[sql]; + if not stmt or not stmt:isopen() then + stmt = assert(self.conn:prepare(sql)); + else + prepared[sql] = nil; + end + local ret = stmt:bind_values(...); + if ret ~= lsqlite3.OK then error(self.conn:errmsg()); end + local rowcount = 0; + repeat + ret = stmt:step(); + if ret == lsqlite3.ROW then + rowcount = rowcount + 1; + end + until ret ~= lsqlite3.ROW; + local affected = self.conn:changes(); + if stmt:reset() == lsqlite3.OK then + prepared[sql] = stmt; + end + return setmetatable({ __affected = affected, __rowcount = rowcount }, result_mt); +end +engine.insert = engine.execute_update; +engine.select = engine.execute_query; +engine.delete = engine.execute_update; +engine.update = engine.execute_update; +local function debugwrap(name, f) + return function (self, sql, ...) + debugquery(name, sql, ...) + return f(self, sql, ...) + end +end +function engine:debug(enable) + self._debug = enable; + if enable then + engine.insert = debugwrap("insert", engine.execute_update); + engine.select = debugwrap("select", engine.execute_query); + engine.delete = debugwrap("delete", engine.execute_update); + engine.update = debugwrap("update", engine.execute_update); + else + engine.insert = engine.execute_update; + engine.select = engine.execute_query; + engine.delete = engine.execute_update; + engine.update = engine.execute_update; + end +end +function engine:_(word) + local ret = self.conn:exec(word); + if ret ~= lsqlite3.OK then return nil, self.conn:errmsg(); end + return true; +end +function engine:_transaction(func, ...) + if not self.conn then + local a,b = self:connect(); + if not a then return a,b; end + end + --assert(not self.__transaction, "Recursive transactions not allowed"); + local ok, err = self:_"BEGIN"; + if not ok then return ok, err; end + self.__transaction = true; + local success, a, b, c = xpcall(func, debug_traceback, ...); + self.__transaction = nil; + if success then + log("debug", "SQL transaction success [%s]", tostring(func)); + local ok, err = self:_"COMMIT"; + if not ok then return ok, err; end -- commit failed + return success, a, b, c; + else + log("debug", "SQL transaction failure [%s]: %s", tostring(func), a); + if self.conn then self:_"ROLLBACK"; end + return success, a; + end +end +function engine:transaction(...) + local ok, ret = self:_transaction(...); + if not ok then + local conn = self.conn; + if not conn or not conn:isopen() then + self.conn = nil; + ok, ret = self:_transaction(...); + end + end + return ok, ret; +end +function engine:_create_index(index) + local sql = "CREATE INDEX IF NOT EXISTS \""..index.name.."\" ON \""..index.table.."\" ("; + for i=1,#index do + sql = sql.."\""..index[i].."\""; + if i ~= #index then sql = sql..", "; end + end + sql = sql..");" + if index.unique then + sql = sql:gsub("^CREATE", "CREATE UNIQUE"); + end + if self._debug then + debugquery("create", sql); + end + return self:execute(sql); +end +function engine:_create_table(table) + local sql = "CREATE TABLE IF NOT EXISTS \""..table.name.."\" ("; + for i,col in ipairs(table.c) do + local col_type = col.type; + sql = sql.."\""..col.name.."\" "..col_type; + if col.nullable == false then sql = sql.." NOT NULL"; end + if col.primary_key == true then sql = sql.." PRIMARY KEY"; end + if col.auto_increment == true then + sql = sql.." AUTOINCREMENT"; + end + if i ~= #table.c then sql = sql..", "; end + end + sql = sql.. ");" + if self._debug then + debugquery("create", sql); + end + local success,err = self:execute(sql); + if not success then return success,err; end + for i,v in ipairs(table.__table__) do + if is_index(v) then + self:_create_index(v); + end + end + return success; +end +function engine:set_encoding() -- to UTF-8 + return self:transaction(function() + for encoding in self:select"PRAGMA encoding;" do + if encoding[1] == "UTF-8" then + self.charset = "utf8"; + end + end + end); +end +local engine_mt = { __index = engine }; + +local function db2uri(params) + return build_url{ + scheme = params.driver, + user = params.username, + password = params.password, + host = params.host, + port = params.port, + path = params.database, + }; +end + +local function create_engine(_, params, onconnect) + assert(params.driver == "SQLite3", "Only SQLite3 is supported without LuaDBI"); + return setmetatable({ url = db2uri(params), params = params, onconnect = onconnect }, engine_mt); +end + +return { + is_column = is_column; + is_index = is_index; + is_table = is_table; + is_query = is_query; + Integer = Integer; + String = String; + Column = Column; + Table = Table; + Index = Index; + create_engine = create_engine; + db2uri = db2uri; +}; |