aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--util/sasl/scram.lua81
1 files changed, 40 insertions, 41 deletions
diff --git a/util/sasl/scram.lua b/util/sasl/scram.lua
index 65090719..a18f025e 100644
--- a/util/sasl/scram.lua
+++ b/util/sasl/scram.lua
@@ -102,22 +102,19 @@ end
local function scram_gen(hash_name, H_f, HMAC_f)
local function scram_hash(self, message)
- if not self.state then self["state"] = {} end
local support_channel_binding = false;
if self.profile.cb then support_channel_binding = true; end
if type(message) ~= "string" or #message == 0 then return "failure", "malformed-request" end
- if not self.state.name then
+ local state = self.state;
+ if not state then
-- we are processing client_first_message
local client_first_message = message;
-- TODO: fail if authzid is provided, since we don't support them yet
- self.state["client_first_message"] = client_first_message;
- self.state["gs2_header"], self.state["gs2_cbind_flag"], self.state["gs2_cbind_name"], self.state["authzid"], self.state["name"], self.state["clientnonce"]
+ local gs2_header, gs2_cbind_flag, gs2_cbind_name, authzid, name, clientnonce
= client_first_message:match("^(([ynp])=?([%a%-]*),(.*),)n=(.*),r=([^,]*).*");
- local gs2_cbind_flag = self.state.gs2_cbind_flag;
-
if not gs2_cbind_flag then
return "failure", "malformed-request";
end
@@ -135,29 +132,24 @@ local function scram_gen(hash_name, H_f, HMAC_f)
if support_channel_binding and gs2_cbind_flag == "p" then
-- check whether we support the proposed channel binding type
- if not self.profile.cb[self.state.gs2_cbind_name] then
+ if not self.profile.cb[gs2_cbind_name] then
return "failure", "malformed-request", "Proposed channel binding type isn't supported.";
end
else
-- no channel binding,
- self.state.gs2_cbind_name = nil;
- end
-
- if not self.state.name or not self.state.clientnonce then
- return "failure", "malformed-request", "Channel binding isn't support at this time.";
+ gs2_cbind_name = nil;
end
- self.state.name = validate_username(self.state.name, self.profile.nodeprep);
- if not self.state.name then
+ name = validate_username(name, self.profile.nodeprep);
+ if not name then
log("debug", "Username violates either SASLprep or contains forbidden character sequences.")
return "failure", "malformed-request", "Invalid username.";
end
- self.state["servernonce"] = generate_uuid();
-
-- retreive credentials
+ local stored_key, server_key, salt, iteration_count;
if self.profile.plain then
- local password, state = self.profile.plain(self, self.state.name, self.realm)
+ local password, state = self.profile.plain(self, name, self.realm)
if state == nil then return "failure", "not-authorized"
elseif state == false then return "failure", "account-disabled" end
@@ -167,64 +159,71 @@ local function scram_gen(hash_name, H_f, HMAC_f)
return "failure", "not-authorized", "Invalid password."
end
- self.state.salt = generate_uuid();
- self.state.iteration_count = default_i;
+ salt = generate_uuid();
+ iteration_count = default_i;
local succ = false;
- succ, self.state.stored_key, self.state.server_key = getAuthenticationDatabaseSHA1(password, self.state.salt, default_i, self.state.iteration_count);
+ succ, stored_key, server_key = getAuthenticationDatabaseSHA1(password, salt, iteration_count);
if not succ then
- log("error", "Generating authentication database failed. Reason: %s", self.state.stored_key);
+ log("error", "Generating authentication database failed. Reason: %s", stored_key);
return "failure", "temporary-auth-failure";
end
elseif self.profile["scram_"..hashprep(hash_name)] then
- local stored_key, server_key, iteration_count, salt, state = self.profile["scram_"..hashprep(hash_name)](self, self.state.name, self.realm);
+ local state;
+ stored_key, server_key, iteration_count, salt, state = self.profile["scram_"..hashprep(hash_name)](self, name, self.realm);
if state == nil then return "failure", "not-authorized"
elseif state == false then return "failure", "account-disabled" end
-
- self.state.stored_key = stored_key;
- self.state.server_key = server_key;
- self.state.iteration_count = iteration_count;
- self.state.salt = salt
end
- local server_first_message = "r="..self.state.clientnonce..self.state.servernonce..",s="..base64.encode(self.state.salt)..",i="..self.state.iteration_count;
- self.state["server_first_message"] = server_first_message;
+ local nonce = clientnonce .. generate_uuid();
+ local server_first_message = "r="..nonce..",s="..base64.encode(salt)..",i="..iteration_count;
+ self.state = {
+ gs2_header = gs2_header;
+ gs2_cbind_name = gs2_cbind_name;
+ name = name;
+ nonce = nonce;
+
+ server_key = server_key;
+ stored_key = stored_key;
+ client_first_message = client_first_message;
+ server_first_message = server_first_message;
+ }
return "challenge", server_first_message
else
-- we are processing client_final_message
local client_final_message = message;
- self.state["channelbinding"], self.state["nonce"], self.state["proof"] = client_final_message:match("^c=(.*),r=(.*),.*p=(.*)");
+ local channelbinding, nonce, proof = client_final_message:match("^c=(.*),r=(.*),.*p=(.*)");
- if not self.state.proof or not self.state.nonce or not self.state.channelbinding then
+ if not proof or not nonce or not channelbinding then
return "failure", "malformed-request", "Missing an attribute(p, r or c) in SASL message.";
end
- local client_gs2_header = base64.decode(self.state.channelbinding)
- local our_client_gs2_header = self.state["gs2_header"]
- if self.state.gs2_cbind_name then
+ local client_gs2_header = base64.decode(channelbinding)
+ local our_client_gs2_header = state["gs2_header"]
+ if state.gs2_cbind_name then
-- we support channelbinding, so check if the value is valid
- our_client_gs2_header = our_client_gs2_header .. self.profile.cb[self.state.gs2_cbind_name](self);
+ our_client_gs2_header = our_client_gs2_header .. self.profile.cb[state.gs2_cbind_name](self);
end
if client_gs2_header ~= our_client_gs2_header then
return "failure", "malformed-request", "Invalid channel binding value.";
end
- if self.state.nonce ~= self.state.clientnonce..self.state.servernonce then
+ if nonce ~= state.nonce then
return "failure", "malformed-request", "Wrong nonce in client-final-message.";
end
- local ServerKey = self.state.server_key;
- local StoredKey = self.state.stored_key;
+ local ServerKey = state.server_key;
+ local StoredKey = state.stored_key;
- local AuthMessage = "n=" .. s_match(self.state.client_first_message,"n=(.+)") .. "," .. self.state.server_first_message .. "," .. s_match(client_final_message, "(.+),p=.+")
+ local AuthMessage = "n=" .. s_match(state.client_first_message,"n=(.+)") .. "," .. state.server_first_message .. "," .. s_match(client_final_message, "(.+),p=.+")
local ClientSignature = HMAC_f(StoredKey, AuthMessage)
- local ClientKey = binaryXOR(ClientSignature, base64.decode(self.state.proof))
+ local ClientKey = binaryXOR(ClientSignature, base64.decode(proof))
local ServerSignature = HMAC_f(ServerKey, AuthMessage)
if StoredKey == H_f(ClientKey) then
local server_final_message = "v="..base64.encode(ServerSignature);
- self["username"] = self.state.name;
+ self["username"] = state.name;
return "success", server_final_message;
else
return "failure", "not-authorized", "The response provided by the client doesn't match the one we calculated.";