1 module jwtd.jwt_openssl;
2 
3 version(UseOpenSSL) {
4 
5 	import deimos.openssl.ssl;
6 	import deimos.openssl.pem;
7 	import deimos.openssl.rsa;
8 	import deimos.openssl.hmac;
9 	import deimos.openssl.err;
10 
11 	import jwtd.jwt : JWTAlgorithm, SignException, VerifyException;
12 
	// OpenSSL 1.1 replaced HMAC_CTX_init/HMAC_CTX_cleanup with HMAC_CTX_reset.
	// It is declared by hand here; NOTE(review): presumably because the
	// imported deimos bindings predate 1.1 and do not provide it — confirm.
	version(OpenSSL11) {
		extern(C) nothrow void HMAC_CTX_reset(HMAC_CTX * ctx);
	}
16 
17 	EC_KEY* getESKeypair(uint curve_type, string key) {
18 		EC_GROUP* curve;
19 		EVP_PKEY* pktmp;
20 		BIO* bpo;
21 		EC_POINT* pub;
22 
23 		if(null == (curve = EC_GROUP_new_by_curve_name(curve_type)))
24 			throw new Exception("Unsupported curve.");
25 		scope(exit) EC_GROUP_free(curve);
26 
27 		bpo = BIO_new_mem_buf(cast(char*)key.ptr, -1);
28 		if(bpo is null) {
29 			throw new Exception("Can't load the key.");
30 		}
31 		scope(exit) BIO_free(bpo);
32 
33 		pktmp = PEM_read_bio_PrivateKey(bpo, null, null, null);
34 		if(pktmp is null) {
35 			throw new Exception("Can't load the evp_pkey.");
36 		}
37 		scope(exit) EVP_PKEY_free(pktmp);
38 
39 		EC_KEY* eckey;
40 		eckey = EVP_PKEY_get1_EC_KEY(pktmp);
41 		if(eckey is null) {
42 			throw new Exception("Can't convert evp_pkey to EC_KEY.");
43 		}
44 		scope(failure) EC_KEY_free(eckey);
45 
46 		if(1 != EC_KEY_set_group(eckey, curve)) {
47 			throw new Exception("Can't associate group with the key.");
48 		}
49 
50 		const BIGNUM *prv = EC_KEY_get0_private_key(eckey);
51 		if(null == prv) {
52 			throw new Exception("Can't get private key.");
53 		}
54 
55 		pub = EC_POINT_new(curve);
56 		if(null == pub) {
57 			throw new Exception("Can't allocate EC point.");
58 		}
59 		scope(exit) EC_POINT_free(pub);
60 
61 		if (1 != EC_POINT_mul(curve, pub, prv, null, null, null)) {
62 			throw new Exception("Can't calculate public key.");
63 		}
64 
65 		if(1 != EC_KEY_set_public_key(eckey, pub)) {
66 			throw new Exception("Can't set public key.");
67 		}
68 
69 		return eckey;
70 	}
71 
72     unittest {
73         import jwtd.test;
74         import std.exception : assertThrown;
75 
76         assertThrown(getESKeypair(0, "key"));
77         assertThrown(getESKeypair(NID_secp256k1, "bogus_key"));
78         assertThrown(getESKeypair(NID_secp256k1, null));
79         assertThrown(getESKeypair(NID_secp256k1, private256));
80     }
81 
82 	EC_KEY* getESPrivateKey(uint curve_type, string key) {
83 		EC_GROUP* curve;
84 		EVP_PKEY* pktmp;
85 		BIO* bpo;
86 
87 		if(null == (curve = EC_GROUP_new_by_curve_name(curve_type)))
88 			throw new Exception("Unsupported curve.");
89 		scope(exit) EC_GROUP_free(curve);
90 
91 		bpo = BIO_new_mem_buf(cast(char*)key.ptr, -1);
92 		if(bpo is null) {
93 			throw new Exception("Can't load the key.");
94 		}
95 		scope(exit) BIO_free(bpo);
96 
97 		pktmp = PEM_read_bio_PrivateKey(bpo, null, null, null);
98 		if(pktmp is null) {
99 			throw new Exception("Can't load the evp_pkey.");
100 		}
101 		scope(exit) EVP_PKEY_free(pktmp);
102 
103 		EC_KEY * eckey;
104 	 	eckey = EVP_PKEY_get1_EC_KEY(pktmp);
105 		if(eckey is null) {
106 			throw new Exception("Can't convert evp_pkey to EC_KEY.");
107 		}
108 
109 		scope(failure) EC_KEY_free(eckey);
110 		if(1 != EC_KEY_set_group(eckey, curve)) {
111 			throw new Exception("Can't associate group with the key.");
112 		}
113 
114 		return eckey;
115 	}
116 
117     unittest {
118         import std.exception : assertThrown;
119         assertThrown(getESPrivateKey(0, "key"));
120         assertThrown(getESPrivateKey(NID_secp256k1, "bogus_key"));
121         assertThrown(getESPrivateKey(NID_secp256k1, null));
122     }
123 
124 	EC_KEY* getESPublicKey(uint curve_type, string key) {
125 		EC_GROUP* curve;
126 
127 		if(null == (curve = EC_GROUP_new_by_curve_name(curve_type)))
128 			throw new Exception("Unsupported curve.");
129 		scope(exit) EC_GROUP_free(curve);
130 
131 		EC_KEY* eckey;
132 
133 		BIO* bpo = BIO_new_mem_buf(cast(char*)key.ptr, -1);
134 		if(bpo is null) {
135 			throw new Exception("Can't load the key.");
136 		}
137 		scope(exit) BIO_free(bpo);
138 
139 		eckey = PEM_read_bio_EC_PUBKEY(bpo, null, null, null);
140 		scope(failure) EC_KEY_free(eckey);
141 
142 		if(1 != EC_KEY_set_group(eckey, curve)) {
143 			throw new Exception("Can't associate group with the key.");
144 		}
145 
146 		if(0 == EC_KEY_check_key(eckey))
147 			throw new Exception("Public key is not valid.");
148 
149 		return eckey;
150 	}
151 
152     unittest {
153         import jwtd.test;
154         import std.exception : assertThrown;
155 
156         assertThrown(getESPublicKey(0, "key"));
157 
158         auto eckey = getESPublicKey(NID_secp256k1, es256_public);
159         EC_KEY_free(eckey);
160         assertThrown(getESPublicKey(NID_secp256k1, null));
161     }
162 
163 	string sign(string msg, string key, JWTAlgorithm algo = JWTAlgorithm.HS256) {
164 		ubyte[] sign;
165 
166 		void sign_hs(const(EVP_MD)* evp, uint signLen) {
167 			sign = new ubyte[signLen];
168 
169 			HMAC_CTX ctx;
170 			version(OpenSSL11) {
171 				scope(exit) HMAC_CTX_reset(&ctx);
172 				HMAC_CTX_reset(&ctx);
173 			}
174 			else {
175 				scope(exit) HMAC_CTX_cleanup(&ctx);
176 				HMAC_CTX_init(&ctx);
177 			}
178 			if(0 == HMAC_Init_ex(&ctx, key.ptr, cast(int)key.length, evp, null)) {
179 				throw new Exception("Can't initialize HMAC context.");
180 			}
181 			if(0 == HMAC_Update(&ctx, cast(const(ubyte)*)msg.ptr, cast(ulong)msg.length)) {
182 				throw new Exception("Can't update HMAC.");
183 			}
184 			if(0 == HMAC_Final(&ctx, cast(ubyte*)sign.ptr, &signLen)) {
185 				throw new Exception("Can't finalize HMAC.");
186 			}
187 		}
188 
189 		void sign_rs(ubyte* hash, int type, uint len, uint signLen) {
190 			sign = new ubyte[len];
191 
192 			RSA* rsa_private = RSA_new();
193 			scope(exit) RSA_free(rsa_private);
194 
195 			BIO* bpo = BIO_new_mem_buf(cast(char*)key.ptr, -1);
196 			if(bpo is null)
197 				throw new Exception("Can't load the key.");
198 			scope(exit) BIO_free(bpo);
199 
200 			RSA* rsa = PEM_read_bio_RSAPrivateKey(bpo, &rsa_private, null, null);
201 			if(rsa is null) {
202 				throw new Exception("Can't create RSA key.");
203 			}
204 			if(0 == RSA_sign(type, hash, signLen, sign.ptr, &signLen, rsa_private)) {
205 				throw new Exception("Can't sign RSA message digest.");
206 			}
207 		}
208 
209 		void sign_es(uint curve_type, ubyte* hash, int hashLen) {
210 			EC_KEY* eckey = getESPrivateKey(curve_type, key);
211 			scope(exit) EC_KEY_free(eckey);
212 
213 			ECDSA_SIG* sig = ECDSA_do_sign(hash, hashLen, eckey);
214 			if(sig is null) {
215 				throw new Exception("Digest sign failed.");
216 			}
217 			scope(exit) ECDSA_SIG_free(sig);
218 
219 			sign = new ubyte[ECDSA_size(eckey)];
220 			ubyte* c = sign.ptr;
221 			if(!i2d_ECDSA_SIG(sig, &c)) {
222 				throw new Exception("Convert sign to DER format failed.");
223 			}
224 		}
225 
226 		switch(algo) {
227 			case JWTAlgorithm.NONE: {
228 				break;
229 			}
230 			case JWTAlgorithm.HS256: {
231 				sign_hs(EVP_sha256(), SHA256_DIGEST_LENGTH);
232 				break;
233 			}
234 			case JWTAlgorithm.HS384: {
235 				sign_hs(EVP_sha384(), SHA384_DIGEST_LENGTH);
236 				break;
237 			}
238 			case JWTAlgorithm.HS512: {
239 				sign_hs(EVP_sha512(), SHA512_DIGEST_LENGTH);
240 				break;
241 			}
242 			case JWTAlgorithm.RS256: {
243 				ubyte[] hash = new ubyte[SHA256_DIGEST_LENGTH];
244 				SHA256(cast(const(ubyte)*)msg.ptr, msg.length, hash.ptr);
245 				sign_rs(hash.ptr, NID_sha256, 256, SHA256_DIGEST_LENGTH);
246 				break;
247 			}
248 			case JWTAlgorithm.RS384: {
249 				ubyte[] hash = new ubyte[SHA384_DIGEST_LENGTH];
250 				SHA384(cast(const(ubyte)*)msg.ptr, msg.length, hash.ptr);
251 				sign_rs(hash.ptr, NID_sha384, 384, SHA384_DIGEST_LENGTH);
252 				break;
253 			}
254 			case JWTAlgorithm.RS512: {
255 				ubyte[] hash = new ubyte[SHA512_DIGEST_LENGTH];
256 				SHA512(cast(const(ubyte)*)msg.ptr, msg.length, hash.ptr);
257 				sign_rs(hash.ptr, NID_sha512, 512, SHA512_DIGEST_LENGTH);
258 				break;
259 			}
260 			case JWTAlgorithm.ES256: {
261 				ubyte[] hash = new ubyte[SHA256_DIGEST_LENGTH];
262 				SHA256(cast(const(ubyte)*)msg.ptr, msg.length, hash.ptr);
263 				sign_es(NID_secp256k1, hash.ptr, SHA256_DIGEST_LENGTH);
264 				break;
265 			}
266 			case JWTAlgorithm.ES384: {
267 				ubyte[] hash = new ubyte[SHA384_DIGEST_LENGTH];
268 				SHA384(cast(const(ubyte)*)msg.ptr, msg.length, hash.ptr);
269 				sign_es(NID_secp384r1, hash.ptr, SHA384_DIGEST_LENGTH);
270 				break;
271 			}
272 			case JWTAlgorithm.ES512: {
273 				ubyte[] hash = new ubyte[SHA512_DIGEST_LENGTH];
274 				SHA512(cast(const(ubyte)*)msg.ptr, msg.length, hash.ptr);
275 				sign_es(NID_secp521r1, hash.ptr, SHA512_DIGEST_LENGTH);
276 				break;
277 			}
278 
279 			default:
280 				throw new SignException("Wrong algorithm.");
281 		}
282 
283 		return cast(string)sign;
284 	}
285 
286 
287 	bool verifySignature(string signature, string signing_input, string key, JWTAlgorithm algo = JWTAlgorithm.HS256) {
288 
289 		bool verify_rs(ubyte* hash, int type, uint len, uint signLen) {
290 			RSA* rsa_public = RSA_new();
291 			scope(exit) RSA_free(rsa_public);
292 
293 			BIO* bpo = BIO_new_mem_buf(cast(char*)key.ptr, -1);
294 			if(bpo is null)
295 				throw new Exception("Can't load key to the BIO.");
296 			scope(exit) BIO_free(bpo);
297 
298 			RSA* rsa = PEM_read_bio_RSA_PUBKEY(bpo, &rsa_public, null, null);
299 			if(rsa is null) {
300 				throw new Exception("Can't create RSA key.");
301 			}
302 
303 			ubyte[] sign = cast(ubyte[])signature;
304 			int ret = RSA_verify(type, hash, signLen, sign.ptr, len, rsa_public);
305 			return ret == 1;
306 		}
307 
308 		bool verify_es(uint curve_type, ubyte* hash, int hashLen ) {
309 			EC_KEY* eckey = getESPublicKey(curve_type, key);
310 			scope(exit) EC_KEY_free(eckey);
311 
312 			ubyte* c = cast(ubyte*)signature.ptr;
313 			ECDSA_SIG* sig = null;
314 			sig = d2i_ECDSA_SIG(&sig, cast(const (ubyte)**)&c, cast(int) key.length);
315 			if (sig is null) {
316 				throw new Exception("Can't decode ECDSA signature.");
317 			}
318 			scope(exit) ECDSA_SIG_free(sig);
319 
320 			int ret =  ECDSA_do_verify(hash, hashLen, sig, eckey);
321 			return ret == 1;
322 		}
323 
324 		switch(algo) {
325 			case JWTAlgorithm.NONE: {
326 				return key.length == 0;
327 			}
328 			case JWTAlgorithm.HS256:
329 			case JWTAlgorithm.HS384:
330 			case JWTAlgorithm.HS512: {
331 				return signature == sign(signing_input, key, algo);
332 			}
333 			case JWTAlgorithm.RS256: {
334 				ubyte[] hash = new ubyte[SHA256_DIGEST_LENGTH];
335 				SHA256(cast(const(ubyte)*)signing_input.ptr, signing_input.length, hash.ptr);
336 				return verify_rs(hash.ptr, NID_sha256, 256, SHA256_DIGEST_LENGTH);
337 			}
338 			case JWTAlgorithm.RS384: {
339 				ubyte[] hash = new ubyte[SHA384_DIGEST_LENGTH];
340 				SHA384(cast(const(ubyte)*)signing_input.ptr, signing_input.length, hash.ptr);
341 				return verify_rs(hash.ptr, NID_sha384, 384, SHA384_DIGEST_LENGTH);
342 			}
343 			case JWTAlgorithm.RS512: {
344 				ubyte[] hash = new ubyte[SHA512_DIGEST_LENGTH];
345 				SHA512(cast(const(ubyte)*)signing_input.ptr, signing_input.length, hash.ptr);
346 				return verify_rs(hash.ptr, NID_sha512, 512, SHA512_DIGEST_LENGTH);
347 			}
348 
349 			case JWTAlgorithm.ES256:{
350 				ubyte[] hash = new ubyte[SHA256_DIGEST_LENGTH];
351 				SHA256(cast(const(ubyte)*)signing_input.ptr, signing_input.length, hash.ptr);
352 				return verify_es(NID_secp256k1, hash.ptr, SHA256_DIGEST_LENGTH );
353 			}
354 			case JWTAlgorithm.ES384:{
355 				ubyte[] hash = new ubyte[SHA384_DIGEST_LENGTH];
356 				SHA384(cast(const(ubyte)*)signing_input.ptr, signing_input.length, hash.ptr);
357 				return verify_es(NID_secp384r1, hash.ptr, SHA384_DIGEST_LENGTH );
358 			}
359 			case JWTAlgorithm.ES512: {
360 				ubyte[] hash = new ubyte[SHA512_DIGEST_LENGTH];
361 				SHA512(cast(const(ubyte)*)signing_input.ptr, signing_input.length, hash.ptr);
362 				return verify_es(NID_secp521r1, hash.ptr, SHA512_DIGEST_LENGTH );
363 			}
364 
365 			default:
366 				throw new VerifyException("Wrong algorithm.");
367 		}
368 	}
369 }
370 
unittest {
    version (UseOpenSSL) {
        import std.exception : assertThrown;

        // A value outside the JWTAlgorithm enum must be rejected by both
        // entry points with their specific exception types.
        auto bogus = cast(JWTAlgorithm)"bogus_algo";
        assertThrown!SignException(sign("message", "key", bogus));
        assertThrown!VerifyException(verifySignature("signature", "signing_input", "key", bogus));
    }
}