aboutsummaryrefslogtreecommitdiffstats
path: root/util/sasl.lua
blob: de39a6d77b3bd6b6eaf56645dc21a94c34479754 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173

-- Dependencies and the standard-library functions used below are captured as
-- locals before the module() call: in Lua 5.1, module() switches this chunk's
-- environment to the module table, so globals are no longer reachable by bare
-- name afterwards (no package.seeall is passed here).
local md5 = require "md5"
local log = require "util.logger".init("sasl");
local tostring = tostring;
local st = require "util.stanza";
local generate_uuid = require "util.uuid".generate;
local s_match = string.match;
local gmatch = string.gmatch
local string = string
local math = require "math"
local type = type
local error = error
local print = print

-- From here on, non-local definitions (such as `new`) become fields of the
-- "sasl" module table.
module "sasl"

-- Build a handler object for the SASL PLAIN mechanism (RFC 4616).
--
-- realm: realm passed through to the password handler.
-- password_handler: called as password_handler(username, realm, "PLAIN"); must return
--   (password_encoding, correct_password). When password_encoding is non-nil it is
--   applied to the client-supplied password before the comparison.
--
-- The returned object exposes feed(self, message), which returns "success", or
-- "failure" plus an error condition ("malformed-request" / "not-authorized").
-- After a well-formed message, self.username holds the claimed authentication identity.
local function new_plain(realm, password_handler)
	local object = { mechanism = "PLAIN", realm = realm, password_handler = password_handler}
	object.feed = function(self, message)
		if message == "" or message == nil then return "failure", "malformed-request" end

		-- Message layout: [authzid] NUL authcid NUL passwd.
		-- The authzid is skipped (not supported here). Only NUL delimits the fields,
		-- so credentials may contain any other character, including "&".
		local authentication, password = s_match(message, "^[^%z]*%z([^%z]+)%z([^%z]+)")
		if authentication == nil or password == nil then
			-- Reject early: otherwise a nil password could compare equal to the nil
			-- "correct password" of an unknown account.
			return "failure", "malformed-request"
		end

		local password_encoding, correct_password = self.password_handler(authentication, self.realm, "PLAIN")

		local claimed_password = password
		if password_encoding ~= nil then claimed_password = password_encoding(password) end

		self.username = authentication
		-- A missing stored password must never authenticate anyone.
		if correct_password ~= nil and claimed_password == correct_password then
			log("debug", "success")
			return "success"
		else
			log("debug", "failure")
			return "failure", "not-authorized"
		end
	end
	return object
end

-- Build a handler object for the SASL DIGEST-MD5 mechanism (RFC 2831).
--
-- realm: realm advertised in the challenge.
-- password_handler: called as password_handler(username, realm, "DIGEST-MD5"); its
--   second return value is used as the secret part of A1 (presumably the hash of
--   username:realm:password -- confirm against the handler implementation).
--
-- The returned object exposes feed(self, message), to be called once per round trip:
--   step 1: returns "challenge", <challenge string>
--   step 2: returns "challenge", <rspauth string> when the client response verifies,
--           otherwise "failure", <condition>[, <text>]
--   step 3: returns "success" if step 2 verified, otherwise "failure", "not-authorized"
-- After step 2, self.username holds the username supplied by the client.
local function new_digest_md5(realm, password_handler)
	--TODO maybe support for authzid

	-- Render the known directives present in `message` as a comma-separated string.
	local function serialize(message)
		if type(message) ~= "table" then error("serialize needs an argument of type table.") end

		local data = ""

		-- testing all possible values
		if message["nonce"] then data = data..[[nonce="]]..message.nonce..[[",]] end
		if message["qop"] then data = data..[[qop="]]..message.qop..[[",]] end
		if message["charset"] then data = data..[[charset=]]..message.charset.."," end
		if message["algorithm"] then data = data..[[algorithm=]]..message.algorithm.."," end
		if message["realm"] then data = data..[[realm="]]..message.realm..[[",]] end
		if message["rspauth"] then data = data..[[rspauth=]]..message.rspauth.."," end
		-- drop the trailing separator
		data = data:gsub(",$", "")
		return data
	end

	-- Split a directive list (key=value or key="value", comma separated) into a table.
	-- Non-string input yields an empty table so the caller's required-field checks
	-- reject it as malformed.
	local function parse(data)
		-- must be local: a bare assignment here would write to the module table
		-- and be shared between all sessions
		local message = {}
		if type(data) ~= "string" then return message end
		log("debug", "parse-message: "..data)
		for k, v in gmatch(data, [[([%w%-]+)="?([%w%-%/%.%+=]+)"?,?]]) do
			message[k] = v
			log("debug", "               "..k.." = "..v)
		end
		return message
	end

	local object = { mechanism = "DIGEST-MD5", realm = realm, password_handler = password_handler}

	--TODO: something better than math.random would be nice, maybe OpenSSL's random number generator
	object.nonce = generate_uuid()
	object.step = 0
	-- NOTE(review): nothing in this block ever records a value here, so the replay
	-- check in step 2 currently never triggers.
	object.nonce_count = {}

	function object.feed(self, message)
		log("debug", "SASL step: "..self.step)
		self.step = self.step + 1
		if (self.step == 1) then
			local challenge = serialize({	nonce = self.nonce,
											qop = "auth",
											charset = "utf-8",
											algorithm = "md5-sess",
											realm = self.realm});
			log("debug", "challenge: "..challenge)
			return "challenge", challenge
		elseif (self.step == 2) then
			local response = parse(message)
			-- check for replay attack
			if response["nc"] then
				if self.nonce_count[response["nc"]] then return "failure", "not-authorized" end
			end

			-- check for username, it's REQUIRED by RFC 2831
			if not response["username"] then
				return "failure", "malformed-request"
			end
			self["username"] = response["username"]

			-- check for nonce, and that it is the one we issued
			if not response["nonce"] then
				return "failure", "malformed-request"
			elseif response["nonce"] ~= tostring(self.nonce) then
				return "failure", "malformed-request"
			end

			if not response["cnonce"] then return "failure", "malformed-request" end
			-- nc and response both feed the digest computation below; without them
			-- the string concatenations would raise instead of failing cleanly
			if not response["nc"] then return "failure", "malformed-request" end
			if not response["response"] then return "failure", "malformed-request" end
			if not response["qop"] then response["qop"] = "auth" end

			if response["realm"] == nil then response["realm"] = "" end

			if not response["digest-uri"] then
				return "failure", "malformed-request", "Missing entry for digest-uri in SASL message."
			end
			local protocol, domain = response["digest-uri"]:match("(%w+)/(.*)$")
			if not protocol or not domain then
				-- digest-uri present but not of the form serv-type/host
				return "failure", "malformed-request"
			end

			--TODO maybe realm support
			local password_encoding, Y = self.password_handler(response["username"], response["realm"], "DIGEST-MD5")
			if not Y then
				-- no stored secret for this user (e.g. unknown account)
				return "failure", "not-authorized"
			end

			local A1 = Y..":"..response["nonce"]..":"..response["cnonce"]--:authzid
			local A2 = "AUTHENTICATE:"..protocol.."/"..domain

			local HA1 = md5.sumhexa(A1)
			local HA2 = md5.sumhexa(A2)

			local KD = HA1..":"..response["nonce"]..":"..response["nc"]..":"..response["cnonce"]..":"..response["qop"]..":"..HA2
			local response_value = md5.sumhexa(KD)

			log("debug", "response_value: "..response_value);
			log("debug", "response:       "..response["response"]);
			if response_value == response["response"] then
				-- calculate rspauth: same digest, but A2 has an empty method part
				A2 = ":"..protocol.."/"..domain
				HA2 = md5.sumhexa(A2)

				KD = HA1..":"..response["nonce"]..":"..response["nc"]..":"..response["cnonce"]..":"..response["qop"]..":"..HA2
				local rspauth = md5.sumhexa(KD)

				-- remember that the response verified; step 3 relies on this
				self.authenticated = true
				return "challenge", serialize({rspauth = rspauth})
			else
				return "failure", "not-authorized", "The response provided by the client doesn't match the one we calculated."
			end
		elseif self.step == 3 then
			-- The step counter advances even when step 2 fails, so success must
			-- depend on the verification result, not merely on reaching step 3.
			if self.authenticated then
				return "success"
			end
			return "failure", "not-authorized"
		end
	end
	return object
end

-- Create a SASL handler object for the given mechanism name.
--
-- mechanism: "PLAIN" or "DIGEST-MD5".
-- realm, password_handler: forwarded unchanged to the mechanism's constructor.
--
-- Returns the handler object, or nil (after logging a debug message) when the
-- mechanism is not supported.
function new(mechanism, realm, password_handler)
	local constructors = {
		["PLAIN"] = new_plain,
		["DIGEST-MD5"] = new_digest_md5,
	}

	local constructor = constructors[mechanism]
	if constructor == nil then
		log("debug", "Unsupported SASL mechanism: "..tostring(mechanism));
		return nil
	end

	local object = constructor(realm, password_handler)
	return object
end

return _M;