From 467bf52500181cb6da64f18f7a5a827c24bb66d5 Mon Sep 17 00:00:00 2001
From: Waqas Hussain <waqas20@gmail.com>
Date: Tue, 22 Jan 2013 08:21:05 +0500
Subject: util.sasl.{plain,scram,digest-md5}: nodeprep username before passing
 to callbacks, so callbacks don't have to.

---
 util/sasl/digest-md5.lua | 11 ++++++++---
 util/sasl/plain.lua      |  9 +++++++++
 util/sasl/scram.lua      | 10 ++++++++--
 3 files changed, 25 insertions(+), 5 deletions(-)

diff --git a/util/sasl/digest-md5.lua b/util/sasl/digest-md5.lua
index de2538fc..591d8537 100644
--- a/util/sasl/digest-md5.lua
+++ b/util/sasl/digest-md5.lua
@@ -23,6 +23,7 @@ local to_byte, to_char = string.byte, string.char;
 local md5 = require "util.hashes".md5;
 local log = require "util.logger".init("sasl");
 local generate_uuid = require "util.uuid".generate;
+local nodeprep = require "util.encodings".stringprep.nodeprep;
 
 module "sasl.digest-md5"
 
@@ -139,10 +140,15 @@ local function digest(self, message)
 		end
 
 		-- check for username, it's REQUIRED by RFC 2831
-		if not response["username"] then
+		local username = response["username"];
+		local _nodeprep = self.profile.nodeprep;
+		if username and _nodeprep ~= false then
+			username = (_nodeprep or nodeprep)(username); -- FIXME charset
+		end
+		if not username or username == "" then
 			return "failure", "malformed-request";
 		end
-		self["username"] = response["username"];
+		self.username = username;
 
 		-- check for nonce, ...
 		if not response["nonce"] then
@@ -178,7 +184,6 @@ local function digest(self, message)
 		end
 
 		--TODO maybe realm support
-		self.username = response["username"];
 		local Y, state;
 		if self.profile.plain then
 			local password, state = self.profile.plain(self, response["username"], self.realm)
diff --git a/util/sasl/plain.lua b/util/sasl/plain.lua
index d108a40d..c9ec2911 100644
--- a/util/sasl/plain.lua
+++ b/util/sasl/plain.lua
@@ -13,6 +13,7 @@
 
 local s_match = string.match;
 local saslprep = require "util.encodings".stringprep.saslprep;
+local nodeprep = require "util.encodings".stringprep.nodeprep;
 local log = require "util.logger".init("sasl");
 
 module "sasl.plain"
@@ -54,6 +55,14 @@ local function plain(self, message)
 		return "failure", "malformed-request", "Invalid username or password.";
 	end
 
+	local _nodeprep = self.profile.nodeprep;
+	if _nodeprep ~= false then
+		authentication = (_nodeprep or nodeprep)(authentication);
+		if not authentication or authentication == "" then
+			return "failure", "malformed-request", "Invalid username or password."
+		end
+	end
+
 	local correct, state = false, false;
 	if self.profile.plain then
 		local correct_password;
diff --git a/util/sasl/scram.lua b/util/sasl/scram.lua
index 055ba16a..d0e8987c 100644
--- a/util/sasl/scram.lua
+++ b/util/sasl/scram.lua
@@ -19,6 +19,7 @@ local hmac_sha1 = require "util.hmac".sha1;
 local sha1 = require "util.hashes".sha1;
 local generate_uuid = require "util.uuid".generate;
 local saslprep = require "util.encodings".stringprep.saslprep;
+local nodeprep = require "util.encodings".stringprep.nodeprep;
 local log = require "util.logger".init("sasl");
 local t_concat = table.concat;
 local char = string.char;
@@ -76,7 +77,7 @@ function Hi(hmac, str, salt, i)
 	return res
 end
 
-local function validate_username(username)
+local function validate_username(username, _nodeprep)
 	-- check for forbidden char sequences
 	for eq in username:gmatch("=(.?.?)") do
 		if eq ~= "2C" and eq ~= "3D" then
@@ -90,6 +91,11 @@ local function validate_username(username)
 	
 	-- apply SASLprep
 	username = saslprep(username);
+
+	if username and _nodeprep ~= false then
+		username = (_nodeprep or nodeprep)(username);
+	end
+
 	return username and #username>0 and username;
 end
 
@@ -133,7 +139,7 @@ local function scram_gen(hash_name, H_f, HMAC_f)
 				return "failure", "malformed-request", "Channel binding isn't support at this time.";
 			end
 		
-			self.state.name = validate_username(self.state.name);
+			self.state.name = validate_username(self.state.name, self.profile.nodeprep);
 			if not self.state.name then
 				log("debug", "Username violates either SASLprep or contains forbidden character sequences.")
 				return "failure", "malformed-request", "Invalid username.";
-- 
cgit v1.2.3