diff options
Diffstat (limited to 'util/sql.lua')
-rw-r--r-- | util/sql.lua | 221 |
1 files changed, 107 insertions, 114 deletions
diff --git a/util/sql.lua b/util/sql.lua index f360d6d0..2d5e1774 100644 --- a/util/sql.lua +++ b/util/sql.lua @@ -13,7 +13,7 @@ local DBI = require "DBI"; DBI.Drivers(); local build_url = require "socket.url".build; -module("sql") +local _ENV = nil; local column_mt = {}; local table_mt = {}; @@ -21,42 +21,17 @@ local query_mt = {}; --local op_mt = {}; local index_mt = {}; -function is_column(x) return getmetatable(x)==column_mt; end -function is_index(x) return getmetatable(x)==index_mt; end -function is_table(x) return getmetatable(x)==table_mt; end -function is_query(x) return getmetatable(x)==query_mt; end ---function is_op(x) return getmetatable(x)==op_mt; end ---function expr(...) return setmetatable({...}, op_mt); end -function Integer(n) return "Integer()" end -function String(n) return "String()" end +local function is_column(x) return getmetatable(x)==column_mt; end +local function is_index(x) return getmetatable(x)==index_mt; end +local function is_table(x) return getmetatable(x)==table_mt; end +local function is_query(x) return getmetatable(x)==query_mt; end +local function Integer(n) return "Integer()" end +local function String(n) return "String()" end ---[[local ops = { - __add = function(a, b) return "("..a.."+"..b..")" end; - __sub = function(a, b) return "("..a.."-"..b..")" end; - __mul = function(a, b) return "("..a.."*"..b..")" end; - __div = function(a, b) return "("..a.."/"..b..")" end; - __mod = function(a, b) return "("..a.."%"..b..")" end; - __pow = function(a, b) return "POW("..a..","..b..")" end; - __unm = function(a) return "NOT("..a..")" end; - __len = function(a) return "COUNT("..a..")" end; - __eq = function(a, b) return "("..a.."=="..b..")" end; - __lt = function(a, b) return "("..a.."<"..b..")" end; - __le = function(a, b) return "("..a.."<="..b..")" end; -}; - -local functions = { - -}; - -local cmap = { - [Integer] = Integer(); - [String] = String(); -};]] - -function Column(definition) +local function Column(definition) return setmetatable(definition, column_mt); end -function Table(definition) +local function Table(definition) local c = {} for i,col in ipairs(definition) do if is_column(col) then @@ -67,7 +42,7 @@ function Table(definition) end return setmetatable({ __table__ = definition, c = c, name = definition.name }, table_mt); end -function Index(definition) +local function Index(definition) return setmetatable(definition, index_mt); end @@ -94,7 +69,6 @@ function index_mt:__tostring() return s..' }'; -- return 'Index{ name="'..self.name..'", type="'..self.type..'" }' end --- local function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return s_char(tonumber(c,16)); end)); end local function parse_url(url) @@ -121,32 +95,13 @@ local function parse_url(url) }; end ---[[local session = {}; - -function session.query(...) - local rets = {...}; - local query = setmetatable({ __rets = rets, __filters }, query_mt); - return query; -end --- - -local function db2uri(params) - return build_url{ - scheme = params.driver, - user = params.username, - password = params.password, - host = params.host, - port = params.port, - path = params.database, - }; -end]] - local engine = {}; function engine:connect() if self.conn then return true; end local params = self.params; assert(params.driver, "no driver") + log("debug", "Connecting to [%s] %s...", params.driver, params.database); local dbh, err = DBI.Connect( params.driver, params.database, params.username, params.password, @@ -156,8 +111,19 @@ function engine:connect() dbh:autocommit(false); -- don't commit automatically self.conn = dbh; self.prepared = {}; + local ok, err = self:set_encoding(); + if not ok then + return ok, err; + end + local ok, err = self:onconnect(); + if ok == false then + return ok, err; + end return true; end +function engine:onconnect() + -- Override from create_engine() +end function engine:execute(sql, ...) local success, err = self:connect(); if not success then return success, err; end @@ -177,8 +143,8 @@ function engine:execute(sql, ...) end local result_mt = { __index = { - affected = function(self) return self.__affected; end; - rowcount = function(self) return self.__rowcount; end; + affected = function(self) return self.__stmt:affected(); end; + rowcount = function(self) return self.__stmt:rowcount(); end; } }; function engine:execute_query(sql, ...) @@ -200,7 +166,7 @@ function engine:execute_update(sql, ...) prepared[sql] = stmt; end assert(stmt:execute(...)); - return setmetatable({ __affected = stmt:affected(), __rowcount = stmt:rowcount() }, result_mt); + return setmetatable({ __stmt = stmt }, result_mt); end engine.insert = engine.execute_update; engine.select = engine.execute_query; @@ -208,12 +174,13 @@ engine.delete = engine.execute_update; engine.update = engine.execute_update; function engine:_transaction(func, ...) if not self.conn then - local a,b = self:connect(); - if not a then return a,b; end + local ok, err = self:connect(); + if not ok then return ok, err; end end --assert(not self.__transaction, "Recursive transactions not allowed"); local args, n_args = {...}, select("#", ...); local function f() return func(unpack(args, 1, n_args)); end + log("debug", "SQL transaction begin [%s]", tostring(func)); self.__transaction = true; local success, a, b, c = xpcall(f, debug_traceback); self.__transaction = nil; @@ -229,15 +196,15 @@ function engine:_transaction(func, ...) end end function engine:transaction(...) - local a,b = self:_transaction(...); - if not a then + local ok, ret = self:_transaction(...); + if not ok then local conn = self.conn; if not conn or not conn:ping() then self.conn = nil; - a,b = self:_transaction(...); + ok, ret = self:_transaction(...); end end - return a,b; + return ok, ret; end function engine:_create_index(index) local sql = "CREATE INDEX `"..index.name.."` ON `"..index.table.."` ("; @@ -251,19 +218,39 @@ function engine:_create_index(index) elseif self.params.driver == "MySQL" then sql = sql:gsub("`([,)])", "`(20)%1"); end + if index.unique then + sql = sql:gsub("^CREATE", "CREATE UNIQUE"); + end --print(sql); return self:execute(sql); end function engine:_create_table(table) local sql = "CREATE TABLE `"..table.name.."` ("; for i,col in ipairs(table.c) do - sql = sql.."`"..col.name.."` "..col.type; + local col_type = col.type; + if col_type == "MEDIUMTEXT" and self.params.driver ~= "MySQL" then + col_type = "TEXT"; -- MEDIUMTEXT is MySQL-specific + end + if col.auto_increment == true and self.params.driver == "PostgreSQL" then + col_type = "BIGSERIAL"; + end + sql = sql.."`"..col.name.."` "..col_type; if col.nullable == false then sql = sql.." NOT NULL"; end + if col.primary_key == true then sql = sql.." PRIMARY KEY"; end + if col.auto_increment == true then + if self.params.driver == "MySQL" then + sql = sql.." AUTO_INCREMENT"; + elseif self.params.driver == "SQLite3" then + sql = sql.." AUTOINCREMENT"; + end + end if i ~= #table.c then sql = sql..", "; end end sql = sql.. ");" if self.params.driver == "PostgreSQL" then sql = sql:gsub("`", "\""); + elseif self.params.driver == "MySQL" then + sql = sql:gsub(";$", (" CHARACTER SET '%s' COLLATE '%s_bin';"):format(self.charset, self.charset)); end local success,err = self:execute(sql); if not success then return success,err; end @@ -274,6 +261,46 @@ function engine:_create_table(table) end return success; end +function engine:set_encoding() -- to UTF-8 + local driver = self.params.driver; + if driver == "SQLite3" then + return self:transaction(function() + if self:select"PRAGMA encoding;"()[1] == "UTF-8" then + self.charset = "utf8"; + end + end); + end + local set_names_query = "SET NAMES '%s';" + local charset = "utf8"; + if driver == "MySQL" then + local ok, charsets = self:transaction(function() + return self:select"SELECT `CHARACTER_SET_NAME` FROM `information_schema`.`CHARACTER_SETS` WHERE `CHARACTER_SET_NAME` LIKE 'utf8%' ORDER BY MAXLEN DESC LIMIT 1;"; + end); + local row = ok and charsets(); + charset = row and row[1] or charset; + set_names_query = set_names_query:gsub(";$", (" COLLATE '%s';"):format(charset.."_bin")); + end + self.charset = charset; + log("debug", "Using encoding '%s' for database connection", charset); + local ok, err = self:transaction(function() return self:execute(set_names_query:format(charset)); end); + if not ok then + return ok, err; + end + + if driver == "MySQL" then + local ok, actual_charset = self:transaction(function () + return self:select"SHOW SESSION VARIABLES LIKE 'character_set_client'"; + end); + for row in actual_charset do + if row[2] ~= charset then + log("error", "MySQL %s is actually %q (expected %q)", row[1], row[2], charset); + return false, "Failed to set connection encoding"; + end + end + end + + return true; +end local engine_mt = { __index = engine }; local function db2uri(params) @@ -286,55 +313,21 @@ local function db2uri(params) path = params.database, }; end -local engine_cache = {}; -- TODO make weak valued -function create_engine(self, params) - local url = db2uri(params); - if not engine_cache[url] then - local engine = setmetatable({ url = url, params = params }, engine_mt); - engine_cache[url] = engine; - end - return engine_cache[url]; -end - - ---[[Users = Table { - name="users"; - Column { name="user_id", type=String(), primary_key=true }; -}; -print(Users) -print(Users.c.user_id)]] ---local engine = create_engine('postgresql://scott:tiger@localhost:5432/mydatabase'); ---[[local engine = create_engine{ driver = "SQLite3", database = "./alchemy.sqlite" }; - -local i = 0; -for row in assert(engine:execute("select * from sqlite_master")):rows(true) do - i = i+1; - print(i); - for k,v in pairs(row) do - print("",k,v); - end +local function create_engine(self, params, onconnect) + return setmetatable({ url = db2uri(params), params = params, onconnect = onconnect }, engine_mt); end -print("---") -Prosody = Table { - name="prosody"; - Column { name="host", type="TEXT", nullable=false }; - Column { name="user", type="TEXT", nullable=false }; - Column { name="store", type="TEXT", nullable=false }; - Column { name="key", type="TEXT", nullable=false }; - Column { name="type", type="TEXT", nullable=false }; - Column { name="value", type="TEXT", nullable=false }; - Index { name="prosody_index", "host", "user", "store", "key" }; +return { + is_column = is_column; + is_index = is_index; + is_table = is_table; + is_query = is_query; + Integer = Integer; + String = String; + Column = Column; + Table = Table; + Index = Index; + create_engine = create_engine; + db2uri = db2uri; }; ---print(Prosody); -assert(engine:transaction(function() - assert(Prosody:create(engine)); -end)); - -for row in assert(engine:execute("select user from prosody")):rows(true) do - print("username:", row['username']) -end ---result.close();]] - -return _M; |