diff options
-rw-r--r-- | .gitignore | 6 | ||||
-rw-r--r-- | README | 6 | ||||
-rw-r--r-- | src/Makefile | 34 | ||||
-rw-r--r-- | src/mysqlerl.app | 11 | ||||
-rw-r--r-- | src/mysqlerl.c | 190 | ||||
-rw-r--r-- | src/mysqlerl.erl | 196 | ||||
-rw-r--r-- | src/mysqlerl_app.erl | 19 | ||||
-rw-r--r-- | src/mysqlerl_connection.erl | 75 | ||||
-rw-r--r-- | src/mysqlerl_connection_sup.erl | 29 |
9 files changed, 566 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b36fc75 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +*~ +*.o +*.beam +mysqlerl +ebin/* +priv/*
\ No newline at end of file @@ -0,0 +1,6 @@ +Currently, all the MySQL drivers I've found use direct socket access +on knowledge of wire protocol in order to connect and manipulate MySQL +DBs. + +This is an attempt to fix that. This library will run as a port driver +and use libmysqlclient instead.
\ No newline at end of file diff --git a/src/Makefile b/src/Makefile new file mode 100644 index 0000000..3978bf0 --- /dev/null +++ b/src/Makefile @@ -0,0 +1,34 @@ +CFLAGS = -I/usr/local/mysql/include -O2 -g +LDFLAGS = -L/usr/local/mysql/lib +EFLAGS = -W +debug_info + +PRIVDIR = ../priv +BEAMDIR = ../ebin + +BINS = $(PRIVDIR)/mysqlerl $(BEAMDIR)/mysqlerl.app +MYSQLERLOBJS = mysqlerl.o +BEAMS = mysqlerl.beam mysqlerl_app.beam mysqlerl_connection.beam \ + mysqlerl_connection_sup.beam +LIBS = -lmysqlclient + +all: $(PRIVDIR) $(BEAMDIR) $(BINS) + +clean: + rm -rf *.o *.beam + rm -rf $(BINS) $(MYSQLERLOBJS) $(BEAMS) $(BEAMDIR)/mysqlerl.app + +%.beam: %.erl + erlc $(EFLAGS) $< + +$(PRIVDIR)/mysqlerl: $(PRIVDIR) $(MYSQLERLOBJS) + $(CC) -o $@ $(LDFLAGS) $(MYSQLERLOBJS) $(LIBS) + +$(BEAMDIR)/mysqlerl.app: $(BEAMDIR) $(BEAMS) + cp $(BEAMS) $(BEAMDIR) + cp mysqlerl.app $(BEAMDIR) + +$(PRIVDIR): + mkdir -p $(PRIVDIR) + +$(BEAMDIR): + mkdir -p $(BEAMDIR)
\ No newline at end of file diff --git a/src/mysqlerl.app b/src/mysqlerl.app new file mode 100644 index 0000000..3cc97b6 --- /dev/null +++ b/src/mysqlerl.app @@ -0,0 +1,11 @@ +%% -*- Erlang -*- + +{application, mysqlerl, + [{description, "mysqlerl"}, + {vsn, "0"}, + {modules, [mysqlerl, mysqlerl_app, mysqlerl_connection_sup, + mysqlerl_connection]}, + {registered, [mysqlerl, mysqlerl_app, mysqlerl_connection_sup]}, + {applications, [kernel, stdlib]}, + {env, []}, + {mod, {mysqlerl_app, []}}]}. diff --git a/src/mysqlerl.c b/src/mysqlerl.c new file mode 100644 index 0000000..d65dba1 --- /dev/null +++ b/src/mysqlerl.c @@ -0,0 +1,190 @@ +#include <mysql.h> + +#include <errno.h> +#include <stdio.h> +#include <stdarg.h> +#include <stdlib.h> +#include <string.h> +#include <sys/types.h> +#include <unistd.h> + +const char *LOGPATH = "/tmp/mysqlerl.log"; +const size_t BUFSIZE = 2048; +static FILE *logfile = NULL; + +typedef u_int32_t msglen_t; + +void +openlog() +{ + logfile = fopen(LOGPATH, "a"); +} + +void +closelog() +{ + fclose(logfile); +} + +void +logmsg(const char *format, ...) +{ + FILE *out = logfile; + va_list args; + + if (logfile == NULL) + logfile = stderr; + + va_start(args, format); + (void)vfprintf(logfile, format, args); + (void)fprintf(logfile, "\n"); + va_end(args); + + fflush(logfile); +} + +int +restartable_read(char *buf, size_t buflen) +{ + ssize_t rc, readb; + + rc = 0; + READLOOP: + while (rc < buflen) { + readb = read(STDIN_FILENO, buf + rc, buflen - rc); + if (readb == -1) { + if (errno == EAGAIN || errno == EINTR) + goto READLOOP; + + return -1; + } else if (readb == 0) { + logmsg("ERROR: EOF trying to read additional %d bytes from " + "standard input", buflen - rc); + return -1; + } + + rc += readb; + } + + return rc; +} + +int +restartable_write(const char *buf, size_t buflen) +{ + ssize_t rc, wroteb; + + rc = 0; + WRITELOOP: + while (rc < buflen) { + wroteb = write(STDOUT_FILENO, buf + rc, buflen - rc); + if (wroteb == -1) { + if (errno == EAGAIN || errno == EINTR) + goto WRITELOOP; + + return -1; + } + + rc += wroteb; + } + + return rc; +} + +char * +read_cmd() +{ + char *buf; + msglen_t len; + + logmsg("DEBUG: reading message length."); + if (restartable_read((char *)&len, sizeof(len)) == -1) { + logmsg("ERROR: couldn't read %d byte message prefix: %s.", + sizeof(len), strerror(errno)); + exit(2); + } + len = ntohl(len); + + buf = malloc(len); + if (buf == NULL) { + logmsg("ERROR: Couldn't malloc %d bytes: %s.", len, + strerror(errno)); + exit(2); + } + memset(buf, 0, BUFSIZE); + + logmsg("DEBUG: reading message body (len: %d).", len); + if (restartable_read(buf, len) == -1) { + logmsg("ERROR: couldn't read %d byte message: %s.", + len, strerror(errno)); + exit(2); + } + + return buf; +} + +int +write_cmd(const char *cmd, msglen_t len) +{ + msglen_t nlen; + + nlen = htonl(len + 3); + restartable_write((char *)&nlen, sizeof(nlen)); + restartable_write(" - ", 3); + restartable_write(cmd, len); +} + +void +dispatch_db_cmd(MYSQL *dbh, const char *cmd) +{ + msglen_t len, nlen; + + logmsg("DEBUG: dispatch_cmd(\"%s\")", cmd); + write_cmd(cmd, strlen(cmd)); +} + +void +usage() +{ + fprintf(stderr, "Usage: mysqlerl host port db_name user passwd\n"); + exit(1); +} + +int +main(int argc, char *argv[]) +{ + MYSQL dbh; + char *host, *port, *db_name, *user, *passwd, *cmd; + + openlog(); + logmsg("INFO: starting up."); + + if (argc < 6) + usage(); + + host = argv[1]; + port = argv[2]; + db_name = argv[3]; + user = argv[4]; + passwd = argv[5]; + + mysql_init(&dbh); + if (mysql_real_connect(&dbh, host, user, passwd, + db_name, atoi(port), NULL, 0) == NULL) { + logmsg("ERROR: Failed to connect to database %s: %s (%s:%s).", + db_name, mysql_error(&dbh), user, passwd); + exit(2); + } + + while ((cmd = read_cmd()) != NULL) { + dispatch_db_cmd(&dbh, cmd); + free(cmd); + } + + mysql_close(&dbh); + + logmsg("INFO: shutting down."); + closelog(); + + return 0; +} diff --git a/src/mysqlerl.erl b/src/mysqlerl.erl new file mode 100644 index 0000000..68185a2 --- /dev/null +++ b/src/mysqlerl.erl @@ -0,0 +1,196 @@ +%% Modeled from ODBC +%% http://www.erlang.org/doc/apps/odbc/ + +-module(mysqlerl). +-author('bjc@kublai.com'). + +-export([test_start/0, test_msg/0]). + +-export([start/0, start/1, stop/0, commit/2, commit/3, + connect/6, disconnect/1, describe_table/2, + describe_table/3, first/1, first/2, + last/1, last/2, next/1, next/2, prev/1, + prev/2, select_count/2, select_count/3, + select/3, select/4, param_query/3, param_query/4, + sql_query/2, sql_query/3]). + +-define(CONFIG, "/Users/bjc/tmp/test-server.cfg"). +-define(NOTIMPLEMENTED, {error, {not_implemented, + "This function has only been stubbed " + "out for reference."}}). + +test_start() -> + {ok, [{Host, Port, DB, User, Pass, Options}]} = file:consult(?CONFIG), + mysqlerl:connect(Host, Port, DB, User, Pass, Options). + +test_msg() -> + mysqlerl_connection:testmsg(mysqlerl_connection_sup:random_child()). + +start() -> + start(temporary). + +%% Arguments: +%% Type = permanent | transient | temporary +%% +%% Returns: +%% ok | {error, Reason} +start(Type) -> + application:start(sasl), + application:start(mysqlerl, Type). + +stop() -> + application:stop(mysqlerl). + +commit(Ref, CommitMode) -> + commit(Ref, CommitMode, infinity). + +%% Arguments: +%% Ref = connection_reference() +%% Timeout = time_out() +%% CommitMode = commit | rollback +%% Reason = not_an_explicit_commit_connection | +%% process_not_owner_of_odbc_connection | +%% common_reason() +%% ok | {error, Reason} +commit(Ref, commit, Timeout) -> + mysqlerl_connection:sql_query(Ref, "COMMIT", Timeout); +commit(Ref, rollback, Timeout) -> + mysqlerl_connection:sql_query(Ref, "ROLLBACK", Timeout). + +%% Arguments: +%% Host = string() +%% Port = integer() +%% Database = string() +%% User = string() +%% Password = string() +%% Options = list() +%% +%% Returns: +%% {ok, Ref} | {error, Reason} +%% Ref = connection_reference() +connect(Host, Port, Database, User, Password, Options) -> + mysqlerl_connection_sup:connect(Host, Port, Database, + User, Password, Options). + +%% Arguments: +%% Ref = connection_reference() +%% +%% Returns: +%% ok | {error, Reason} +disconnect(Ref) -> + mysqlerl_connection:stop(Ref). + +describe_table(Ref, Table) -> + describe_table(Ref, Table, infinity). + +%% Arguments: +%% Ref = connection_reference() +%% Table = string() +%% Timeout = time_out() +%% +%% Returns: +%% {ok, Description} | {error, Reason} +%% Description = [{col_name(), odbc_data_type()}] +describe_table(_Ref, _Table, _Timeout) -> + ?NOTIMPLEMENTED. + +first(Ref) -> + first(Ref, infinity). + +%% Arguments: +%% Ref = connection_reference() +%% Timeout = time_out() +%% Returns: +%% {selected, ColNames, Rows} | {error, Reason} +%% Rows = rows() +first(_Ref, _Timeout) -> + ?NOTIMPLEMENTED. + +last(Ref) -> + last(Ref, infinity). + +%% Arguments: +%% Ref = connection_reference() +%% Timeout = time_out() +%% Returns: +%% {selected, ColNames, Rows} | {error, Reason} +%% Rows = rows() +last(_Ref, _Timeout) -> + ?NOTIMPLEMENTED. + +next(Ref) -> + next(Ref, infinity). + +%% Arguments: +%% Ref = connection_reference() +%% Timeout = time_out() +%% Returns: +%% {selected, ColNames, Rows} | {error, Reason} +%% Rows = rows() +next(_Ref, _Timeout) -> + ?NOTIMPLEMENTED. + +prev(Ref) -> + prev(Ref, infinity). + +%% Arguments: +%% Ref = connection_reference() +%% Timeout = time_out() +%% Returns: +%% {selected, ColNames, Rows} | {error, Reason} +%% Rows = rows() +prev(_Ref, _Timeout) -> + ?NOTIMPLEMENTED. + +select_count(Ref, SQLQuery) -> + select_count(Ref, SQLQuery, infinity). + +%% Arguments: +%% Ref = connection_reference() +%% SQLQuery = string() +%% Timeout = time_out() +%% Returns: +%% {ok, NrRows} | {error, Reason} +%% NrRows = n_rows() +select_count(_Ref, _SQLQuery, _Timeout) -> + ?NOTIMPLEMENTED. + +select(Ref, Pos, N) -> + select(Ref, Pos, N, infinity). + +%% Arguments: +%% Ref = connection_reference() +%% Pos = integer() +%% Timeout = time_out() +%% Returns: +%% {selected, ColNames, Rows} | {error, Reason} +%% Rows = rows() +select(_Ref, _Pos, _N, _Timeout) -> + ?NOTIMPLEMENTED. + +param_query(Ref, SQLQuery, Params) -> + param_query(Ref, SQLQuery, Params, infinity). + +%% Arguments: +%% Ref = connection_reference() +%% SQLQuery = string() +%% Params = [{odbc_data_type(), [value()]}] +%% Timeout = time_out() +%% Returns: +%% {selected, ColNames, Rows} | {error, Reason} +%% Rows = rows() +param_query(_Ref, _SQLQuery, _Params, _Timeout) -> + ?NOTIMPLEMENTED. + +sql_query(Ref, SQLQuery) -> + sql_query(Ref, SQLQuery, infinity). + +%% Arguments: +%% Ref = connection_reference() +%% SQLQuery = string() +%% Timeout = time_out() +%% Returns: +%% {selected, ColNames, Rows} | {error, Reason} +%% Rows = rows() +sql_query(_Ref, _SQLQuery, _Timeout) -> + ?NOTIMPLEMENTED. diff --git a/src/mysqlerl_app.erl b/src/mysqlerl_app.erl new file mode 100644 index 0000000..6be4007 --- /dev/null +++ b/src/mysqlerl_app.erl @@ -0,0 +1,19 @@ +-module(mysqlerl_app). +-author('bjc@kublai.com'). + +-behavior(application). +-behavior(supervisor). + +-export([start/2, stop/1, init/1]). + +start(normal, []) -> + supervisor:start_link({local, ?MODULE}, ?MODULE, []). + +stop([]) -> + ok. + +init([]) -> + ConnectionSup = {mysqlerl_connection_sup, {mysqlerl_connection_sup, start_link, []}, + permanent, infinity, supervisor, [mysqlerl_connection_sup]}, + {ok, {{one_for_one, 10, 5}, + [ConnectionSup]}}. diff --git a/src/mysqlerl_connection.erl b/src/mysqlerl_connection.erl new file mode 100644 index 0000000..aa73d23 --- /dev/null +++ b/src/mysqlerl_connection.erl @@ -0,0 +1,75 @@ +-module(mysqlerl_connection). +-author('bjc@kublai.com'). + +-behavior(gen_server). + +-export([start_link/6, stop/1, sql_query/3, testmsg/1]). + +-export([init/1, terminate/2, code_change/3, + handle_call/3, handle_cast/2, handle_info/2]). + +-record(state, {ref}). +-record(port_closed, {reason}). +-record(sql_query, {q}). + +start_link(Host, Port, Database, User, Password, Options) -> + gen_server:start_link(?MODULE, [Host, Port, Database, + User, Password, Options], []). + +stop(Pid) -> + gen_server:cast(Pid, stop). + +sql_query(Pid, Query, Timeout) -> + gen_server:call(Pid, #sql_query{q = Query}, Timeout). + +testmsg(Pid) -> + gen_server:call(Pid, {test, "SELECT COUNT(*) FROM user;"}). + +init([Host, Port, Database, User, Password, Options]) -> + process_flag(trap_exit, true), + Cmd = lists:flatten(io_lib:format("~s ~s ~w ~s ~s ~s ~s", + [helper(), Host, Port, Database, + User, Password, Options])), + Ref = open_port({spawn, Cmd}, [{packet, 4}]), + {ok, #state{ref = Ref}}. + +terminate(#port_closed{reason = Reason}, #state{ref = Ref}) -> + io:format("DEBUG: mysqlerl connection ~p shutting down (~p).~n", + [Ref, Reason]), + ok; +terminate(Reason, State) -> + port_close(State#state.ref), + io:format("DEBUG: got terminate: ~p~n", [Reason]), + ok. + +code_change(_OldVsn, State, _Extra) -> + {ok, State}. + +handle_call({test, Str}, _From, #state{ref = Ref} = State) -> + io:format("DEBUG: got test message: ~p~n", [Str]), + port_command(Ref, Str), + receive + {Ref, {data, Res}} -> + {reply, {ok, Res}, State}; + Other -> + error_logger:warning_msg("Got unknown message: ~p~n", [Other]) + end; +handle_call(Request, From, State) -> + io:format("DEBUG: got unknown call from ~p: ~p~n", [From, Request]), + {noreply, State}. + +handle_cast(stop, State) -> + {stop, normal, State}. + +handle_info({'EXIT', _Ref, Reason}, State) -> + {stop, #port_closed{reason = Reason}, State}; +handle_info(Info, State) -> + io:format("DEBUG: got unknown info: ~p~n", [Info]), + {noreply, State}. + +helper() -> + case code:priv_dir(mysqlerl) of + PrivDir when is_list(PrivDir) -> ok; + {error, bad_name} -> PrivDir = filename:join(["..", "priv"]) + end, + filename:join([PrivDir, "mysqlerl"]). diff --git a/src/mysqlerl_connection_sup.erl b/src/mysqlerl_connection_sup.erl new file mode 100644 index 0000000..4aa3fc6 --- /dev/null +++ b/src/mysqlerl_connection_sup.erl @@ -0,0 +1,29 @@ +-module(mysqlerl_connection_sup). +-author('bjc@kublai.com'). + +-behavior(supervisor). + +-export([start_link/0, connect/6, random_child/0]). + +-export([init/1]). + +start_link() -> + supervisor:start_link({local, ?MODULE}, ?MODULE, []). + +connect(Host, Port, Database, User, Password, Options) -> + supervisor:start_child(?MODULE, [Host, Port, Database, User, Password, Options]). + +random_child() -> + case get_pids() of + [] -> {error, no_connections}; + Pids -> lists:nth(erlang:phash(now(), length(Pids)), Pids) + end. + +init([]) -> + Connection = {undefined, {mysqlerl_connection, start_link, []}, + transient, 5, worker, [mysqlerl_connection]}, + {ok, {{simple_one_for_one, 10, 5}, + [Connection]}}. + +get_pids() -> + [Pid || {_Id, Pid, _Type, _Modules} <- supervisor:which_children(?MODULE)]. |