aboutsummaryrefslogtreecommitdiffstats
path: root/util/sql.lua
diff options
context:
space:
mode:
Diffstat (limited to 'util/sql.lua')
-rw-r--r--util/sql.lua316
1 files changed, 180 insertions, 136 deletions
diff --git a/util/sql.lua b/util/sql.lua
index f360d6d0..15749911 100644
--- a/util/sql.lua
+++ b/util/sql.lua
@@ -1,8 +1,9 @@
local setmetatable, getmetatable = setmetatable, getmetatable;
-local ipairs, unpack, select = ipairs, unpack, select;
+local ipairs, unpack, select = ipairs, table.unpack or unpack, select; --luacheck: ignore 113
local tonumber, tostring = tonumber, tostring;
-local assert, xpcall, debug_traceback = assert, xpcall, debug.traceback;
+local type = type;
+local assert, pcall, xpcall, debug_traceback = assert, pcall, xpcall, debug.traceback;
local t_concat = table.concat;
local s_char = string.char;
local log = require "util.logger".init("sql");
@@ -13,7 +14,7 @@ local DBI = require "DBI";
DBI.Drivers();
local build_url = require "socket.url".build;
-module("sql")
+local _ENV = nil;
local column_mt = {};
local table_mt = {};
@@ -21,42 +22,17 @@ local query_mt = {};
--local op_mt = {};
local index_mt = {};
-function is_column(x) return getmetatable(x)==column_mt; end
-function is_index(x) return getmetatable(x)==index_mt; end
-function is_table(x) return getmetatable(x)==table_mt; end
-function is_query(x) return getmetatable(x)==query_mt; end
---function is_op(x) return getmetatable(x)==op_mt; end
---function expr(...) return setmetatable({...}, op_mt); end
-function Integer(n) return "Integer()" end
-function String(n) return "String()" end
+local function is_column(x) return getmetatable(x)==column_mt; end
+local function is_index(x) return getmetatable(x)==index_mt; end
+local function is_table(x) return getmetatable(x)==table_mt; end
+local function is_query(x) return getmetatable(x)==query_mt; end
+local function Integer() return "Integer()" end
+local function String() return "String()" end
---[[local ops = {
- __add = function(a, b) return "("..a.."+"..b..")" end;
- __sub = function(a, b) return "("..a.."-"..b..")" end;
- __mul = function(a, b) return "("..a.."*"..b..")" end;
- __div = function(a, b) return "("..a.."/"..b..")" end;
- __mod = function(a, b) return "("..a.."%"..b..")" end;
- __pow = function(a, b) return "POW("..a..","..b..")" end;
- __unm = function(a) return "NOT("..a..")" end;
- __len = function(a) return "COUNT("..a..")" end;
- __eq = function(a, b) return "("..a.."=="..b..")" end;
- __lt = function(a, b) return "("..a.."<"..b..")" end;
- __le = function(a, b) return "("..a.."<="..b..")" end;
-};
-
-local functions = {
-
-};
-
-local cmap = {
- [Integer] = Integer();
- [String] = String();
-};]]
-
-function Column(definition)
+local function Column(definition)
return setmetatable(definition, column_mt);
end
-function Table(definition)
+local function Table(definition)
local c = {}
for i,col in ipairs(definition) do
if is_column(col) then
@@ -67,13 +43,13 @@ function Table(definition)
end
return setmetatable({ __table__ = definition, c = c, name = definition.name }, table_mt);
end
-function Index(definition)
+local function Index(definition)
return setmetatable(definition, index_mt);
end
function table_mt:__tostring()
local s = { 'name="'..self.__table__.name..'"' }
- for i,col in ipairs(self.__table__) do
+ for _, col in ipairs(self.__table__) do
s[#s+1] = tostring(col);
end
return 'Table{ '..t_concat(s, ", ")..' }'
@@ -94,7 +70,6 @@ function index_mt:__tostring()
return s..' }';
-- return 'Index{ name="'..self.name..'", type="'..self.type..'" }'
end
---
local function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return s_char(tonumber(c,16)); end)); end
local function parse_url(url)
@@ -121,48 +96,50 @@ local function parse_url(url)
};
end
---[[local session = {};
-
-function session.query(...)
- local rets = {...};
- local query = setmetatable({ __rets = rets, __filters }, query_mt);
- return query;
-end
---
-
-local function db2uri(params)
- return build_url{
- scheme = params.driver,
- user = params.username,
- password = params.password,
- host = params.host,
- port = params.port,
- path = params.database,
- };
-end]]
-
local engine = {};
function engine:connect()
if self.conn then return true; end
local params = self.params;
assert(params.driver, "no driver")
- local dbh, err = DBI.Connect(
+ log("debug", "Connecting to [%s] %s...", params.driver, params.database);
+ local ok, dbh, err = pcall(DBI.Connect,
params.driver, params.database,
params.username, params.password,
params.host, params.port
);
+ if not ok then return ok, dbh; end
if not dbh then return nil, err; end
dbh:autocommit(false); -- don't commit automatically
self.conn = dbh;
self.prepared = {};
+ local ok, err = self:set_encoding();
+ if not ok then
+ return ok, err;
+ end
+ local ok, err = self:onconnect();
+ if ok == false then
+ return ok, err;
+ end
return true;
end
+function engine:onconnect()
+ -- Override from create_engine()
+end
+
+function engine:prepquery(sql)
+ if self.params.driver == "MySQL" then
+ sql = sql:gsub("\"", "`");
+ end
+ return sql;
+end
+
function engine:execute(sql, ...)
local success, err = self:connect();
if not success then return success, err; end
local prepared = self.prepared;
+ sql = self:prepquery(sql);
local stmt = prepared[sql];
if not stmt then
local err;
@@ -177,22 +154,31 @@ function engine:execute(sql, ...)
end
local result_mt = { __index = {
- affected = function(self) return self.__affected; end;
- rowcount = function(self) return self.__rowcount; end;
+ affected = function(self) return self.__stmt:affected(); end;
+ rowcount = function(self) return self.__stmt:rowcount(); end;
} };
+local function debugquery(where, sql, ...)
+ local i = 0; local a = {...}
+ sql = sql:gsub("\n?\t+", " ");
+ log("debug", "[%s] %s", where, sql:gsub("%?", function ()
+ i = i + 1;
+ local v = a[i];
+ if type(v) == "string" then
+ v = ("'%s'"):format(v:gsub("'", "''"));
+ end
+ return tostring(v);
+ end));
+end
+
function engine:execute_query(sql, ...)
- if self.params.driver == "PostgreSQL" then
- sql = sql:gsub("`", "\"");
- end
+ sql = self:prepquery(sql);
local stmt = assert(self.conn:prepare(sql));
assert(stmt:execute(...));
return stmt:rows();
end
function engine:execute_update(sql, ...)
- if self.params.driver == "PostgreSQL" then
- sql = sql:gsub("`", "\"");
- end
+ sql = self:prepquery(sql);
local prepared = self.prepared;
local stmt = prepared[sql];
if not stmt then
@@ -200,22 +186,47 @@ function engine:execute_update(sql, ...)
prepared[sql] = stmt;
end
assert(stmt:execute(...));
- return setmetatable({ __affected = stmt:affected(), __rowcount = stmt:rowcount() }, result_mt);
+ return setmetatable({ __stmt = stmt }, result_mt);
end
engine.insert = engine.execute_update;
engine.select = engine.execute_query;
engine.delete = engine.execute_update;
engine.update = engine.execute_update;
+local function debugwrap(name, f)
+ return function (self, sql, ...)
+ debugquery(name, sql, ...)
+ return f(self, sql, ...)
+ end
+end
+function engine:debug(enable)
+ self._debug = enable;
+ if enable then
+ engine.insert = debugwrap("insert", engine.execute_update);
+ engine.select = debugwrap("select", engine.execute_query);
+ engine.delete = debugwrap("delete", engine.execute_update);
+ engine.update = debugwrap("update", engine.execute_update);
+ else
+ engine.insert = engine.execute_update;
+ engine.select = engine.execute_query;
+ engine.delete = engine.execute_update;
+ engine.update = engine.execute_update;
+ end
+end
+local function handleerr(err)
+ log("error", "Error in SQL transaction: %s", debug_traceback(err, 3));
+ return err;
+end
function engine:_transaction(func, ...)
if not self.conn then
- local a,b = self:connect();
- if not a then return a,b; end
+ local ok, err = self:connect();
+ if not ok then return ok, err; end
end
--assert(not self.__transaction, "Recursive transactions not allowed");
local args, n_args = {...}, select("#", ...);
local function f() return func(unpack(args, 1, n_args)); end
+ log("debug", "SQL transaction begin [%s]", tostring(func));
self.__transaction = true;
- local success, a, b, c = xpcall(f, debug_traceback);
+ local success, a, b, c = xpcall(f, handleerr);
self.__transaction = nil;
if success then
log("debug", "SQL transaction success [%s]", tostring(func));
@@ -229,51 +240,118 @@ function engine:_transaction(func, ...)
end
end
function engine:transaction(...)
- local a,b = self:_transaction(...);
- if not a then
+ local ok, ret = self:_transaction(...);
+ if not ok then
local conn = self.conn;
if not conn or not conn:ping() then
self.conn = nil;
- a,b = self:_transaction(...);
+ ok, ret = self:_transaction(...);
end
end
- return a,b;
+ return ok, ret;
end
function engine:_create_index(index)
- local sql = "CREATE INDEX `"..index.name.."` ON `"..index.table.."` (";
+ local sql = "CREATE INDEX \""..index.name.."\" ON \""..index.table.."\" (";
for i=1,#index do
- sql = sql.."`"..index[i].."`";
+ sql = sql.."\""..index[i].."\"";
if i ~= #index then sql = sql..", "; end
end
sql = sql..");"
- if self.params.driver == "PostgreSQL" then
- sql = sql:gsub("`", "\"");
- elseif self.params.driver == "MySQL" then
- sql = sql:gsub("`([,)])", "`(20)%1");
+ if self.params.driver == "MySQL" then
+ sql = sql:gsub("\"([,)])", "\"(20)%1");
+ end
+ if index.unique then
+ sql = sql:gsub("^CREATE", "CREATE UNIQUE");
+ end
+ if self._debug then
+ debugquery("create", sql);
end
- --print(sql);
return self:execute(sql);
end
function engine:_create_table(table)
- local sql = "CREATE TABLE `"..table.name.."` (";
+ local sql = "CREATE TABLE \""..table.name.."\" (";
for i,col in ipairs(table.c) do
- sql = sql.."`"..col.name.."` "..col.type;
+ local col_type = col.type;
+ if col_type == "MEDIUMTEXT" and self.params.driver ~= "MySQL" then
+ col_type = "TEXT"; -- MEDIUMTEXT is MySQL-specific
+ end
+ if col.auto_increment == true and self.params.driver == "PostgreSQL" then
+ col_type = "BIGSERIAL";
+ end
+ sql = sql.."\""..col.name.."\" "..col_type;
if col.nullable == false then sql = sql.." NOT NULL"; end
+ if col.primary_key == true then sql = sql.." PRIMARY KEY"; end
+ if col.auto_increment == true then
+ if self.params.driver == "MySQL" then
+ sql = sql.." AUTO_INCREMENT";
+ elseif self.params.driver == "SQLite3" then
+ sql = sql.." AUTOINCREMENT";
+ end
+ end
if i ~= #table.c then sql = sql..", "; end
end
sql = sql.. ");"
- if self.params.driver == "PostgreSQL" then
- sql = sql:gsub("`", "\"");
+ if self.params.driver == "MySQL" then
+ sql = sql:gsub(";$", (" CHARACTER SET '%s' COLLATE '%s_bin';"):format(self.charset, self.charset));
+ end
+ if self._debug then
+ debugquery("create", sql);
end
local success,err = self:execute(sql);
if not success then return success,err; end
- for i,v in ipairs(table.__table__) do
+ for _, v in ipairs(table.__table__) do
if is_index(v) then
self:_create_index(v);
end
end
return success;
end
+function engine:set_encoding() -- to UTF-8
+ local driver = self.params.driver;
+ if driver == "SQLite3" then
+ return self:transaction(function()
+ for encoding in self:select"PRAGMA encoding;" do
+ if encoding[1] == "UTF-8" then
+ self.charset = "utf8";
+ end
+ end
+ end);
+ end
+ local set_names_query = "SET NAMES '%s';"
+ local charset = "utf8";
+ if driver == "MySQL" then
+ self:transaction(function()
+ for row in self:select"SELECT \"CHARACTER_SET_NAME\" FROM \"information_schema\".\"CHARACTER_SETS\" WHERE \"CHARACTER_SET_NAME\" LIKE 'utf8%' ORDER BY MAXLEN DESC LIMIT 1;" do
+ charset = row and row[1] or charset;
+ end
+ end);
+ set_names_query = set_names_query:gsub(";$", (" COLLATE '%s';"):format(charset.."_bin"));
+ end
+ self.charset = charset;
+ log("debug", "Using encoding '%s' for database connection", charset);
+ local ok, err = self:transaction(function() return self:execute(set_names_query:format(charset)); end);
+ if not ok then
+ return ok, err;
+ end
+
+ if driver == "MySQL" then
+ local ok, actual_charset = self:transaction(function ()
+ return self:select"SHOW SESSION VARIABLES LIKE 'character_set_client'";
+ end);
+ local charset_ok = true;
+ for row in actual_charset do
+ if row[2] ~= charset then
+ log("error", "MySQL %s is actually %q (expected %q)", row[1], row[2], charset);
+ charset_ok = false;
+ end
+ end
+ if not charset_ok then
+ return false, "Failed to set connection encoding";
+ end
+ end
+
+ return true;
+end
local engine_mt = { __index = engine };
local function db2uri(params)
@@ -286,55 +364,21 @@ local function db2uri(params)
path = params.database,
};
end
-local engine_cache = {}; -- TODO make weak valued
-function create_engine(self, params)
- local url = db2uri(params);
- if not engine_cache[url] then
- local engine = setmetatable({ url = url, params = params }, engine_mt);
- engine_cache[url] = engine;
- end
- return engine_cache[url];
-end
-
---[[Users = Table {
- name="users";
- Column { name="user_id", type=String(), primary_key=true };
-};
-print(Users)
-print(Users.c.user_id)]]
-
---local engine = create_engine('postgresql://scott:tiger@localhost:5432/mydatabase');
---[[local engine = create_engine{ driver = "SQLite3", database = "./alchemy.sqlite" };
-
-local i = 0;
-for row in assert(engine:execute("select * from sqlite_master")):rows(true) do
- i = i+1;
- print(i);
- for k,v in pairs(row) do
- print("",k,v);
- end
+local function create_engine(self, params, onconnect)
+ return setmetatable({ url = db2uri(params), params = params, onconnect = onconnect }, engine_mt);
end
-print("---")
-Prosody = Table {
- name="prosody";
- Column { name="host", type="TEXT", nullable=false };
- Column { name="user", type="TEXT", nullable=false };
- Column { name="store", type="TEXT", nullable=false };
- Column { name="key", type="TEXT", nullable=false };
- Column { name="type", type="TEXT", nullable=false };
- Column { name="value", type="TEXT", nullable=false };
- Index { name="prosody_index", "host", "user", "store", "key" };
+return {
+ is_column = is_column;
+ is_index = is_index;
+ is_table = is_table;
+ is_query = is_query;
+ Integer = Integer;
+ String = String;
+ Column = Column;
+ Table = Table;
+ Index = Index;
+ create_engine = create_engine;
+ db2uri = db2uri;
};
---print(Prosody);
-assert(engine:transaction(function()
- assert(Prosody:create(engine));
-end));
-
-for row in assert(engine:execute("select user from prosody")):rows(true) do
- print("username:", row['username'])
-end
---result.close();]]
-
-return _M;