1 module jwtd.jwt;
2 
3 import std.json;
4 import std.base64;
5 import std.algorithm;
6 import std.array : split;
7 
/// Base64 codec using the URL-safe alphabet ('-' and '_' as the last two symbols) with padding disabled.
private alias Base64URLNoPadding = Base64Impl!(
	'-',
	'_',
	Base64.NoPadding
);
9 
10 version(UseOpenSSL) {
11 	public import jwtd.jwt_openssl;
12 }
13 version(UseBotan) {
14 	public import jwtd.jwt_botan;
15 }
16 version(UsePhobos) {
17 	public import jwtd.jwt_phobos;
18 }
19 
/**
  Algorithms that can be requested when encoding a token.

  The string value of each member is what encode() writes into the "alg"
  field of the token header.
*/
enum JWTAlgorithm : string {
	NONE = "none",

	// HS family
	HS256 = "HS256",
	HS384 = "HS384",
	HS512 = "HS512",

	// RS family
	RS256 = "RS256",
	RS384 = "RS384",
	RS512 = "RS512",

	// ES family
	ES256 = "ES256",
	ES384 = "ES384",
	ES512 = "ES512",
}
32 
/**
  Exception type for signing failures.

  It is not thrown anywhere in this module; presumably it is raised by the
  backend `sign` implementations (to be confirmed against those modules).

  The constructor forwards `file`, `line` and `next` to Exception so that the
  reported location is the throw site rather than this constructor, and so a
  causing exception can be chained.
*/
class SignException : Exception {
	this(string s, string file = __FILE__, size_t line = __LINE__, Throwable next = null) @safe pure nothrow {
		super(s, file, line, next);
	}
}
36 
/**
  Thrown by decode() and verify() in this module when a token is rejected
  (wrong number of segments, unreadable header, unknown algorithm, wrong type,
  bad signature, unreadable payload).

  The constructor forwards `file`, `line` and `next` to Exception so that the
  reported location is the throw site rather than this constructor, and so a
  causing exception can be chained.
*/
class VerifyException : Exception {
	this(string s, string file = __FILE__, size_t line = __LINE__, Throwable next = null) @safe pure nothrow {
		super(s, file, line, next);
	}
}
40 
/**
  convenience version for payloads and header fields that are plain string-to-string maps

  Both maps are wrapped in JSONValue objects and passed, together with `key`
  and `algo`, to the JSONValue overload of encode.
*/
string encode(string[string] payload, string key, JWTAlgorithm algo = JWTAlgorithm.HS256, string[string] header_fields = null) {
	auto claims = JSONValue(payload);
	auto joseFields = JSONValue(header_fields);

	return encode(claims, key, algo, joseFields);
}
50 
/**
  version taking the payload as a JSONValue tree

  The payload is serialised with toString() and the resulting bytes are passed
  to the ubyte[] overload of encode; `key`, `algo` and `header_fields` are
  forwarded unchanged.
*/
string encode(ref JSONValue payload, string key, JWTAlgorithm algo = JWTAlgorithm.HS256, JSONValue header_fields = null) {
	string serialized = payload.toString();
	return encode(cast(ubyte[])serialized, key, algo, header_fields);
}
57 
/**
  full version that accepts ubyte[] as payload and JSONValue tree as header fields

  Params:
    payload = raw bytes that become the (Base64URL-encoded) second segment of the token
    key = key passed to `sign`
    algo = algorithm passed to `sign` and written to the header as "alg"
    header_fields = additional header fields; either a JSON null or a JSON object.
                    "alg" and "typ" are always set by this function, overriding
                    any value supplied here. The caller's value is not modified.

  Returns: the token, `header.payload.signature`, each segment Base64URL-encoded without padding
*/
string encode(in ubyte[] payload, string key, JWTAlgorithm algo = JWTAlgorithm.HS256, JSONValue header_fields = null) {
	import std.functional : memoize;

	auto getEncodedHeader(JWTAlgorithm algo, JSONValue fields) {
		// Work on a copy of the object: a JSONValue copy still shares its
		// associative array with the caller's value (and with the memoize key),
		// so writing "alg"/"typ" through `fields.object` would leak into both.
		JSONValue[string] headerObject;
		if(fields.type != JSONType.null_)
			headerObject = fields.object.dup;
		headerObject["alg"] = JSONValue(cast(string)algo);
		headerObject["typ"] = JSONValue("JWT");

		return Base64URLNoPadding.encode(cast(ubyte[])JSONValue(headerObject).toString()).idup;
	}

	// The encoded header depends only on (algo, header_fields), so it is cached.
	string encodedHeader = memoize!(getEncodedHeader, 64)(algo, header_fields);
	string encodedPayload = Base64URLNoPadding.encode(payload);

	string signingInput = encodedHeader ~ "." ~ encodedPayload;
	string signature = Base64URLNoPadding.encode(cast(ubyte[])sign(signingInput, key, algo));

	return signingInput ~ "." ~ signature;
}
81 
unittest {
	import jwtd.test;

	// Exercise the encode() branch taken when header_fields is a JSON null.
	auto nullHeader = JSONValue();
	assert(nullHeader.type == JSONType.null_);

	auto claims = JSONValue([ "a" : "b" ]);
	encode(claims, public256, JWTAlgorithm.HS256, nullHeader);
}
91 
/**
  version for callers that already know which key the token was signed with

  Wraps `key` in a delegate that ignores the header and returns it, then
  defers to the delegate-based decode.
*/
JSONValue decode(string token, string key) {
	string delegate(ref JSONValue jose) fixedKey = (ref JSONValue jose) => key;
	return decode(token, fixedKey);
}
98 
/**
  full version where the key is provided after decoding the JOSE header

  Params:
    token = the three-segment token, `header.payload.signature`
    lazyKey = called with the parsed header; returns the key to verify with.
              Note that the algorithm is taken from the token's own "alg"
              field (compared case-insensitively), so a caller that wants to
              pin the algorithm has to check `jose["alg"]` here.

  Returns: the parsed payload

  Throws: VerifyException when the token does not have exactly two dots, the
          header, algorithm, type, signature or payload cannot be read, or
          `verifySignature` rejects the signature. Exceptions thrown by
          `lazyKey` or `verifySignature` themselves are not translated.
*/
JSONValue decode(string token, string delegate(ref JSONValue jose) lazyKey) {
	import std.algorithm : count;
	import std.conv : to;
	import std.uni : toUpper;

	if(count(token, ".") != 2)
		throw new VerifyException("Token is incorrect.");

	string[] tokenParts = split(token, ".");

	JSONValue header;
	try {
		header = parseJSON(urlsafeB64Decode(tokenParts[0]));
	} catch(Exception e) {
		throw new VerifyException("Header is incorrect.");
	}

	JWTAlgorithm alg;
	try {
		// to!JWTAlgorithm matches member names, which are upper case ("none" -> NONE)
		alg = to!(JWTAlgorithm)(toUpper(header["alg"].str()));
	} catch(Exception e) {
		throw new VerifyException("Algorithm is incorrect.");
	}

	if (auto typ = ("typ" in header)) {
		string typ_str;
		try {
			typ_str = typ.str();
		} catch(JSONException e) {
			// "typ" is present but is not a JSON string
			throw new VerifyException("Type is incorrect.");
		}
		if(typ_str && typ_str != "JWT")
			throw new VerifyException("Type is incorrect.");
	}

	const key = lazyKey(header);

	string signature;
	try {
		signature = urlsafeB64Decode(tokenParts[2]);
	} catch(Base64Exception e) {
		// a signature segment that is not valid Base64 cannot be a valid signature
		throw new VerifyException("Signature is incorrect.");
	}

	if(!verifySignature(signature, tokenParts[0]~"."~tokenParts[1], key, alg))
		throw new VerifyException("Signature is incorrect.");

	JSONValue payload;

	try {
		payload = parseJSON(urlsafeB64Decode(tokenParts[1]));
	} catch(Exception e) {
		// Covers both a payload segment that is not valid Base64 and one that is not valid JSON.
		// Code coverage usually misses this line because the signature test above throws first.
		throw new VerifyException("Payload JSON is incorrect.");
	}

	return payload;
}
148 
unittest {
	import jwtd.test;
	import std.traits : EnumMembers;

	// Key used for signing plus the key used for verifying.
	// When only one key is given it is used for both.
	struct Keys {
		string priv;
		string pub;

		this (string priv, string pub = null) {
			this.priv = priv;
			this.pub = (pub ? pub : priv);
		}
	}

	// Algorithms exercised regardless of the selected backend.
	Keys[JWTAlgorithm] commonAlgos = [
		JWTAlgorithm.NONE  : Keys(),
		JWTAlgorithm.HS256 : Keys("my key"),
		JWTAlgorithm.HS384 : Keys("his key"),
		JWTAlgorithm.HS512 : Keys("her key"),
	];

	// Algorithms exercised only for particular backends.
	version (UseOpenSSL) {
		Keys[JWTAlgorithm] specialAlgos = [
			JWTAlgorithm.RS256 : Keys(private256, public256),
			// TODO: Find key pairs for RS384 and RS512
			// JWTAlgorithm.RS384 : Keys(private384, public384),
			// JWTAlgorithm.RS512 : Keys(private512, public512),
			JWTAlgorithm.ES256 : Keys(es256_private, es256_public),
			JWTAlgorithm.ES384 : Keys(es384_private, es384_public),
			JWTAlgorithm.ES512 : Keys(es512_private, es512_public),
		];
	}

	version (UseBotan) {
		Keys[JWTAlgorithm] specialAlgos = [
			JWTAlgorithm.RS256 : Keys(private256, public256),
			// TODO: Find key pairs for the following
			// JWTAlgorithm.RS384 : Keys(private384, public384),
			// JWTAlgorithm.RS512 : Keys(private512, public512),
			// JWTAlgorithm.ES256 : Keys(es256_private, es256_public),
			// JWTAlgorithm.ES384 : Keys(es384_private, es384_public),
			// JWTAlgorithm.ES512 : Keys(es512_private, es512_public),
		];
	}

	version (UsePhobos) {
		Keys[JWTAlgorithm] specialAlgos;
	}

	// Round trip: encode with the signing key, decode with the verifying key,
	// and expect the original payload back.
	void testWith(Keys[JWTAlgorithm] keys) {
		foreach (algorithm, pair; keys) {
			auto claims = JSONValue([ "claim" : "value" ]);
			const token = encode(claims, pair.priv, algorithm);
			const roundTripped = decode(token, pair.pub);
			assert(roundTripped == claims);
		}
	}

	testWith(commonAlgos);
	testWith(specialAlgos);
}
213 
version (unittest) {
	/**
	  Test helper: encodes a small token with the default algorithm and the
	  key "key", then overwrites `field` with `badValue` in the JSON of segment
	  number `part` and re-encodes that segment. The other segments, including
	  the signature, are left as they were.
	*/
	string corruptEncodedString(size_t part, string field, string badValue) {
		import std.conv : text;

		string[] segments = split(encode([ "my" : "payload" ], "key"), ".");

		auto tampered = parseJSON(urlsafeB64Decode(segments[part]));
		tampered[field] = badValue;
		segments[part] = urlsafeB64Encode(tampered.toString());

		return text(segments.joiner("."));
	}
}
226 
unittest {
	import std.exception : assertThrown;

	// decode() has to reject every kind of invalid token with a VerifyException.

	// Anything other than exactly two dots is rejected.
	foreach (malformed; ["nodot", "one.dot", "thr.e.e.dots"])
		assertThrown!VerifyException(decode(malformed, "key"));

	// The header segment has to be readable.
	assertThrown!VerifyException(decode("corrupt.encoding.blah", "key"));

	// The algorithm has to be a known one.
	const unknownAlg = corruptEncodedString(0, "alg", "bogus_alg");
	assertThrown!VerifyException(decode(unknownAlg, "key"));

	// The type has to be JWT.
	const wrongType = corruptEncodedString(0, "typ", "JWX");
	assertThrown!VerifyException(decode(wrongType, "key"));

	// The signature has to match: drop its last character.
	string token = encode([ "my" : "payload" ], "key");
	assertThrown!VerifyException(decode(token[0..$-1], "key"));
}
250 
/**
  Checks the structure and signature of a token without decoding its payload.

  The algorithm is taken from the token's own "alg" header field (compared
  case-insensitively).

  Params:
    token = the three-segment token, `header.payload.signature`
    key = key passed to `verifySignature`

  Returns: the result of `verifySignature`; false when the signature segment
           is not valid Base64

  Throws: VerifyException when the token does not have exactly two dots, the
          header cannot be decoded or parsed, the algorithm is unknown, or a
          "typ" field is present and is not "JWT"
*/
bool verify(string token, string key) {
	import std.algorithm : count;
	import std.conv : to;
	import std.uni : toUpper;

	if(count(token, ".") != 2)
		throw new VerifyException("Token is incorrect.");

	string[] tokenParts = split(token, ".");

	JSONValue header;
	try {
		header = parseJSON(urlsafeB64Decode(tokenParts[0]));
	} catch(Exception e) {
		// same handling as decode(): a bad header is reported as a VerifyException
		throw new VerifyException("Header is incorrect.");
	}

	JWTAlgorithm alg;
	try {
		// to!JWTAlgorithm matches member names, which are upper case ("none" -> NONE)
		alg = to!(JWTAlgorithm)(toUpper(header["alg"].str()));
	} catch(Exception e) {
		throw new VerifyException("Algorithm is incorrect.");
	}

	if (auto typ = ("typ" in header)) {
		string typ_str;
		try {
			typ_str = typ.str();
		} catch(JSONException e) {
			// "typ" is present but is not a JSON string
			throw new VerifyException("Type is incorrect.");
		}
		if(typ_str && typ_str != "JWT")
			throw new VerifyException("Type is incorrect.");
	}

	string signature;
	try {
		signature = urlsafeB64Decode(tokenParts[2]);
	} catch(Base64Exception e) {
		// a signature segment that is not valid Base64 cannot match
		return false;
	}

	return verifySignature(signature, tokenParts[0]~"."~tokenParts[1], key, alg);
}
280 
unittest {
	import std.exception : assertThrown;

	// verify() has to reject invalid tokens with a VerifyException.

	// Anything other than exactly two dots is rejected.
	foreach (malformed; ["nodot", "one.dot", "thr.e.e.dots"])
		assertThrown!VerifyException(verify(malformed, "key"));

	// Unknown algorithm and wrong type are both rejected.
	const unknownAlg = corruptEncodedString(0, "alg", "bogus_alg");
	assertThrown!VerifyException(verify(unknownAlg, "key"));

	const wrongType = corruptEncodedString(0, "typ", "JWX");
	assertThrown!VerifyException(verify(wrongType, "key"));
}
295 
/**
 * Encode the bytes of a string as URL-safe Base64 without padding.
 */
string urlsafeB64Encode(string inp) pure nothrow {
	auto rawBytes = cast(ubyte[])inp;
	return Base64URLNoPadding.encode(rawBytes);
}
302 
/**
 * Decode URL-safe, unpadded Base64 text and return the resulting bytes as a string.
 */
string urlsafeB64Decode(string inp) pure {
	ubyte[] rawBytes = Base64URLNoPadding.decode(inp);
	return cast(string)rawBytes;
}
309 
unittest {
	import jwtd.test;

	string hmacSecret = "secret";

	// Algorithm "none": the signature segment is empty and only the empty key verifies.
	string unsignedToken = encode(["language": "D"], "", JWTAlgorithm.NONE);
	assert(unsignedToken == "eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.eyJsYW5ndWFnZSI6IkQifQ.");
	assert(verify(unsignedToken, ""));
	assert(!verify(unsignedToken, "somesecret"));

	// HS256: compare against a known token, then verify it.
	string tokenHs256 = encode(["language": "D"], hmacSecret, JWTAlgorithm.HS256);
	assert(tokenHs256 == "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJsYW5ndWFnZSI6IkQifQ.utQLevAUK97y-e6B3-EnSofvTNAfSXNuSbu4moAh-hY");
	assert(verify(tokenHs256, hmacSecret));

	// HS512: compare against a known token, then verify it.
	string tokenHs512 = encode(["language": "D"], hmacSecret, JWTAlgorithm.HS512);
	assert(tokenHs512 == "eyJhbGciOiJIUzUxMiIsInR5cCI6IkpXVCJ9.eyJsYW5ndWFnZSI6IkQifQ.tDRXngYs15t6Q-9AortMxXNfvTgVjaQGD9VTlwL3JD6Xxab8ass2ekCoom8uOiRdpZ772ajLQD42RXMuALct1Q");
	assert(verify(tokenHs512, hmacSecret));

	version(UsePhobos) {
		// Not supported
	} else {
		// RS256: compare against a known token, then verify with the public key.
		string tokenRs256 = encode(["language": "D"], private256, JWTAlgorithm.RS256);
		assert(tokenRs256 == "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJsYW5ndWFnZSI6IkQifQ.BYpRNUNsho1Yquq7Uolp31K2Ng90h0hRlMV6J6d9WSSIYf7s2MBX2xgDlBuHtB-Yb9dkbkfdxqjYCQdWejiMc_II6dn72ZSBwBCyWdPPRNbTRA2DNlsoKFBS5WMp7iYordfD9KE0LowK61n_Z7AHNAiOop5Ka1xTKH8cqEo8s3ItgoxZt8mzAfhIYNogGown6sYytqg1I72UHsEX9KAuP7sCxCbxZ9cSVg2f4afEuwwo08AdG3hW_LXhT7VD-EweDmvF2JLAyf1_rW66PMgiZZCLQ6kf2hQRsa56xRDmo5qC98wDseBHx9f3PsTsracTKojwQUdezDmbHv90vCt-Iw");
		assert(verify(tokenRs256, public256));

		// ES256: the token text is not compared, only verified with the public key.
		string tokenEs256 = encode(["language": "D"], es256_private, JWTAlgorithm.ES256);
		assert(verify(tokenEs256, es256_public));
	}
}