1/* Copyright 2015, 2016 OpenMarket Ltd
2 *
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15#include "olm/ratchet.hh"
16#include "olm/message.hh"
17#include "olm/memory.hh"
18#include "olm/cipher.h"
19#include "olm/pickle.hh"
20
21#include <cstring>
22
23namespace {
24
25static const std::uint8_t PROTOCOL_VERSION = 3;
26static const std::uint8_t MESSAGE_KEY_SEED[1] = {0x01};
27static const std::uint8_t CHAIN_KEY_SEED[1] = {0x02};
28static const std::size_t MAX_MESSAGE_GAP = 2000;
29
30
31/**
32 * Advance the root key, creating a new message chain.
33 *
34 * @param root_key previous root key R(n-1)
35 * @param our_key our new ratchet key T(n)
36 * @param their_key their most recent ratchet key T(n-1)
37 * @param info table of constants for the ratchet function
38 * @param new_root_key[out] returns the new root key R(n)
39 * @param new_chain_key[out] returns the first chain key in the new chain
40 * C(n,0)
41 */
42static void create_chain_key(
43 olm::SharedKey const & root_key,
44 _olm_curve25519_key_pair const & our_key,
45 _olm_curve25519_public_key const & their_key,
46 olm::KdfInfo const & info,
47 olm::SharedKey & new_root_key,
48 olm::ChainKey & new_chain_key
49) {
50 olm::SharedKey secret;
51 _olm_crypto_curve25519_shared_secret(our_key: &our_key, their_key: &their_key, output: secret);
52 std::uint8_t derived_secrets[2 * olm::OLM_SHARED_KEY_LENGTH];
53 _olm_crypto_hkdf_sha256(
54 input: secret, input_length: sizeof(secret),
55 info: root_key, info_length: sizeof(root_key),
56 salt: info.ratchet_info, salt_length: info.ratchet_info_length,
57 output: derived_secrets, output_length: sizeof(derived_secrets)
58 );
59 std::uint8_t const * pos = derived_secrets;
60 pos = olm::load_array(destination&: new_root_key, source: pos);
61 pos = olm::load_array(destination&: new_chain_key.key, source: pos);
62 new_chain_key.index = 0;
63 olm::unset(value&: derived_secrets);
64 olm::unset(value&: secret);
65}
66
67
68static void advance_chain_key(
69 olm::ChainKey const & chain_key,
70 olm::ChainKey & new_chain_key
71) {
72 _olm_crypto_hmac_sha256(
73 key: chain_key.key, key_length: sizeof(chain_key.key),
74 input: CHAIN_KEY_SEED, input_length: sizeof(CHAIN_KEY_SEED),
75 output: new_chain_key.key
76 );
77 new_chain_key.index = chain_key.index + 1;
78}
79
80
81static void create_message_keys(
82 olm::ChainKey const & chain_key,
83 olm::KdfInfo const & info,
84 olm::MessageKey & message_key) {
85 _olm_crypto_hmac_sha256(
86 key: chain_key.key, key_length: sizeof(chain_key.key),
87 input: MESSAGE_KEY_SEED, input_length: sizeof(MESSAGE_KEY_SEED),
88 output: message_key.key
89 );
90 message_key.index = chain_key.index;
91}
92
93
94static std::size_t verify_mac_and_decrypt(
95 _olm_cipher const *cipher,
96 olm::MessageKey const & message_key,
97 olm::MessageReader const & reader,
98 std::uint8_t * plaintext, std::size_t max_plaintext_length
99) {
100 return cipher->ops->decrypt(
101 cipher,
102 message_key.key, sizeof(message_key.key),
103 reader.input, reader.input_length,
104 reader.ciphertext, reader.ciphertext_length,
105 plaintext, max_plaintext_length
106 );
107}
108
109
110static std::size_t verify_mac_and_decrypt_for_existing_chain(
111 olm::Ratchet const & session,
112 olm::ChainKey const & chain,
113 olm::MessageReader const & reader,
114 std::uint8_t * plaintext, std::size_t max_plaintext_length
115) {
116 if (reader.counter < chain.index) {
117 return std::size_t(-1);
118 }
119
120 /* Limit the number of hashes we're prepared to compute */
121 if (reader.counter - chain.index > MAX_MESSAGE_GAP) {
122 return std::size_t(-1);
123 }
124
125 olm::ChainKey new_chain = chain;
126
127 while (new_chain.index < reader.counter) {
128 advance_chain_key(chain_key: new_chain, new_chain_key&: new_chain);
129 }
130
131 olm::MessageKey message_key;
132 create_message_keys(chain_key: new_chain, info: session.kdf_info, message_key);
133
134 std::size_t result = verify_mac_and_decrypt(
135 cipher: session.ratchet_cipher, message_key, reader,
136 plaintext, max_plaintext_length
137 );
138
139 olm::unset(value&: new_chain);
140 return result;
141}
142
143
144static std::size_t verify_mac_and_decrypt_for_new_chain(
145 olm::Ratchet const & session,
146 olm::MessageReader const & reader,
147 std::uint8_t * plaintext, std::size_t max_plaintext_length
148) {
149 olm::SharedKey new_root_key;
150 olm::ReceiverChain new_chain;
151
152 /* They shouldn't move to a new chain until we've sent them a message
153 * acknowledging the last one */
154 if (session.sender_chain.empty()) {
155 return std::size_t(-1);
156 }
157
158 /* Limit the number of hashes we're prepared to compute */
159 if (reader.counter > MAX_MESSAGE_GAP) {
160 return std::size_t(-1);
161 }
162 olm::load_array(destination&: new_chain.ratchet_key.public_key, source: reader.ratchet_key);
163
164 create_chain_key(
165 root_key: session.root_key, our_key: session.sender_chain[0].ratchet_key,
166 their_key: new_chain.ratchet_key, info: session.kdf_info,
167 new_root_key, new_chain_key&: new_chain.chain_key
168 );
169 std::size_t result = verify_mac_and_decrypt_for_existing_chain(
170 session, chain: new_chain.chain_key, reader,
171 plaintext, max_plaintext_length
172 );
173 olm::unset(value&: new_root_key);
174 olm::unset(value&: new_chain);
175 return result;
176}
177
178} // namespace
179
180
181olm::Ratchet::Ratchet(
182 olm::KdfInfo const & kdf_info,
183 _olm_cipher const * ratchet_cipher
184) : kdf_info(kdf_info),
185 ratchet_cipher(ratchet_cipher),
186 last_error(OlmErrorCode::OLM_SUCCESS) {
187}
188
189
190void olm::Ratchet::initialise_as_bob(
191 std::uint8_t const * shared_secret, std::size_t shared_secret_length,
192 _olm_curve25519_public_key const & their_ratchet_key
193) {
194 std::uint8_t derived_secrets[2 * olm::OLM_SHARED_KEY_LENGTH];
195 _olm_crypto_hkdf_sha256(
196 input: shared_secret, input_length: shared_secret_length,
197 info: nullptr, info_length: 0,
198 salt: kdf_info.root_info, salt_length: kdf_info.root_info_length,
199 output: derived_secrets, output_length: sizeof(derived_secrets)
200 );
201 receiver_chains.insert();
202 receiver_chains[0].chain_key.index = 0;
203 std::uint8_t const * pos = derived_secrets;
204 pos = olm::load_array(destination&: root_key, source: pos);
205 pos = olm::load_array(destination&: receiver_chains[0].chain_key.key, source: pos);
206 receiver_chains[0].ratchet_key = their_ratchet_key;
207 olm::unset(value&: derived_secrets);
208}
209
210
211void olm::Ratchet::initialise_as_alice(
212 std::uint8_t const * shared_secret, std::size_t shared_secret_length,
213 _olm_curve25519_key_pair const & our_ratchet_key
214) {
215 std::uint8_t derived_secrets[2 * olm::OLM_SHARED_KEY_LENGTH];
216 _olm_crypto_hkdf_sha256(
217 input: shared_secret, input_length: shared_secret_length,
218 info: nullptr, info_length: 0,
219 salt: kdf_info.root_info, salt_length: kdf_info.root_info_length,
220 output: derived_secrets, output_length: sizeof(derived_secrets)
221 );
222 sender_chain.insert();
223 sender_chain[0].chain_key.index = 0;
224 std::uint8_t const * pos = derived_secrets;
225 pos = olm::load_array(destination&: root_key, source: pos);
226 pos = olm::load_array(destination&: sender_chain[0].chain_key.key, source: pos);
227 sender_chain[0].ratchet_key = our_ratchet_key;
228 olm::unset(value&: derived_secrets);
229}
230
231namespace olm {
232
233
234static std::size_t pickle_length(
235 const olm::SharedKey & value
236) {
237 return olm::OLM_SHARED_KEY_LENGTH;
238}
239
240
241static std::uint8_t * pickle(
242 std::uint8_t * pos,
243 const olm::SharedKey & value
244) {
245 return olm::pickle_bytes(pos, bytes: value, bytes_length: olm::OLM_SHARED_KEY_LENGTH);
246}
247
248
249static std::uint8_t const * unpickle(
250 std::uint8_t const * pos, std::uint8_t const * end,
251 olm::SharedKey & value
252) {
253 return olm::unpickle_bytes(pos, end, bytes: value, bytes_length: olm::OLM_SHARED_KEY_LENGTH);
254}
255
256
257static std::size_t pickle_length(
258 const olm::SenderChain & value
259) {
260 std::size_t length = 0;
261 length += olm::pickle_length(value: value.ratchet_key);
262 length += olm::pickle_length(value: value.chain_key.key);
263 length += olm::pickle_length(value: value.chain_key.index);
264 return length;
265}
266
267
268static std::uint8_t * pickle(
269 std::uint8_t * pos,
270 const olm::SenderChain & value
271) {
272 pos = olm::pickle(pos, value: value.ratchet_key);
273 pos = olm::pickle(pos, value: value.chain_key.key);
274 pos = olm::pickle(pos, value: value.chain_key.index);
275 return pos;
276}
277
278
279static std::uint8_t const * unpickle(
280 std::uint8_t const * pos, std::uint8_t const * end,
281 olm::SenderChain & value
282) {
283 pos = olm::unpickle(pos, end, value&: value.ratchet_key); UNPICKLE_OK(pos);
284 pos = olm::unpickle(pos, end, value&: value.chain_key.key); UNPICKLE_OK(pos);
285 pos = olm::unpickle(pos, end, value&: value.chain_key.index); UNPICKLE_OK(pos);
286 return pos;
287}
288
289static std::size_t pickle_length(
290 const olm::ReceiverChain & value
291) {
292 std::size_t length = 0;
293 length += olm::pickle_length(value: value.ratchet_key);
294 length += olm::pickle_length(value: value.chain_key.key);
295 length += olm::pickle_length(value: value.chain_key.index);
296 return length;
297}
298
299
300static std::uint8_t * pickle(
301 std::uint8_t * pos,
302 const olm::ReceiverChain & value
303) {
304 pos = olm::pickle(pos, value: value.ratchet_key);
305 pos = olm::pickle(pos, value: value.chain_key.key);
306 pos = olm::pickle(pos, value: value.chain_key.index);
307 return pos;
308}
309
310
311static std::uint8_t const * unpickle(
312 std::uint8_t const * pos, std::uint8_t const * end,
313 olm::ReceiverChain & value
314) {
315 pos = olm::unpickle(pos, end, value&: value.ratchet_key); UNPICKLE_OK(pos);
316 pos = olm::unpickle(pos, end, value&: value.chain_key.key); UNPICKLE_OK(pos);
317 pos = olm::unpickle(pos, end, value&: value.chain_key.index); UNPICKLE_OK(pos);
318 return pos;
319}
320
321
322static std::size_t pickle_length(
323 const olm::SkippedMessageKey & value
324) {
325 std::size_t length = 0;
326 length += olm::pickle_length(value: value.ratchet_key);
327 length += olm::pickle_length(value: value.message_key.key);
328 length += olm::pickle_length(value: value.message_key.index);
329 return length;
330}
331
332
333static std::uint8_t * pickle(
334 std::uint8_t * pos,
335 const olm::SkippedMessageKey & value
336) {
337 pos = olm::pickle(pos, value: value.ratchet_key);
338 pos = olm::pickle(pos, value: value.message_key.key);
339 pos = olm::pickle(pos, value: value.message_key.index);
340 return pos;
341}
342
343
344static std::uint8_t const * unpickle(
345 std::uint8_t const * pos, std::uint8_t const * end,
346 olm::SkippedMessageKey & value
347) {
348 pos = olm::unpickle(pos, end, value&: value.ratchet_key); UNPICKLE_OK(pos);
349 pos = olm::unpickle(pos, end, value&: value.message_key.key); UNPICKLE_OK(pos);
350 pos = olm::unpickle(pos, end, value&: value.message_key.index); UNPICKLE_OK(pos);
351 return pos;
352}
353
354
355} // namespace olm
356
357
358std::size_t olm::pickle_length(
359 olm::Ratchet const & value
360) {
361 std::size_t length = 0;
362 length += olm::OLM_SHARED_KEY_LENGTH;
363 length += olm::pickle_length(list: value.sender_chain);
364 length += olm::pickle_length(list: value.receiver_chains);
365 length += olm::pickle_length(list: value.skipped_message_keys);
366 return length;
367}
368
369std::uint8_t * olm::pickle(
370 std::uint8_t * pos,
371 olm::Ratchet const & value
372) {
373 pos = pickle(pos, value: value.root_key);
374 pos = pickle(pos, list: value.sender_chain);
375 pos = pickle(pos, list: value.receiver_chains);
376 pos = pickle(pos, list: value.skipped_message_keys);
377 return pos;
378}
379
380
381std::uint8_t const * olm::unpickle(
382 std::uint8_t const * pos, std::uint8_t const * end,
383 olm::Ratchet & value,
384 bool includes_chain_index
385) {
386 pos = unpickle(pos, end, value&: value.root_key); UNPICKLE_OK(pos);
387 pos = unpickle(pos, end, list&: value.sender_chain); UNPICKLE_OK(pos);
388 pos = unpickle(pos, end, list&: value.receiver_chains); UNPICKLE_OK(pos);
389 pos = unpickle(pos, end, list&: value.skipped_message_keys); UNPICKLE_OK(pos);
390
391 // pickle v 0x80000001 includes a chain index; pickle v1 does not.
392 if (includes_chain_index) {
393 std::uint32_t dummy;
394 pos = unpickle(pos, end, value&: dummy); UNPICKLE_OK(pos);
395 }
396 return pos;
397}
398
399
400std::size_t olm::Ratchet::encrypt_output_length(
401 std::size_t plaintext_length
402) const {
403 std::size_t counter = 0;
404 if (!sender_chain.empty()) {
405 counter = sender_chain[0].chain_key.index;
406 }
407 std::size_t padded = ratchet_cipher->ops->encrypt_ciphertext_length(
408 ratchet_cipher,
409 plaintext_length
410 );
411 return olm::encode_message_length(
412 counter, CURVE25519_KEY_LENGTH, ciphertext_length: padded, mac_length: ratchet_cipher->ops->mac_length(ratchet_cipher)
413 );
414}
415
416
417std::size_t olm::Ratchet::encrypt_random_length() const {
418 return sender_chain.empty() ? CURVE25519_RANDOM_LENGTH : 0;
419}
420
421
422std::size_t olm::Ratchet::encrypt(
423 std::uint8_t const * plaintext, std::size_t plaintext_length,
424 std::uint8_t const * random, std::size_t random_length,
425 std::uint8_t * output, std::size_t max_output_length
426) {
427 std::size_t output_length = encrypt_output_length(plaintext_length);
428
429 if (random_length < encrypt_random_length()) {
430 last_error = OlmErrorCode::OLM_NOT_ENOUGH_RANDOM;
431 return std::size_t(-1);
432 }
433 if (max_output_length < output_length) {
434 last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
435 return std::size_t(-1);
436 }
437
438 if (sender_chain.empty()) {
439 sender_chain.insert();
440 _olm_crypto_curve25519_generate_key(random_32_bytes: random, output: &sender_chain[0].ratchet_key);
441 create_chain_key(
442 root_key,
443 our_key: sender_chain[0].ratchet_key,
444 their_key: receiver_chains[0].ratchet_key,
445 info: kdf_info,
446 new_root_key&: root_key, new_chain_key&: sender_chain[0].chain_key
447 );
448 }
449
450 MessageKey keys;
451 create_message_keys(chain_key: sender_chain[0].chain_key, info: kdf_info, message_key&: keys);
452 advance_chain_key(chain_key: sender_chain[0].chain_key, new_chain_key&: sender_chain[0].chain_key);
453
454 std::size_t ciphertext_length = ratchet_cipher->ops->encrypt_ciphertext_length(
455 ratchet_cipher,
456 plaintext_length
457 );
458 std::uint32_t counter = keys.index;
459 _olm_curve25519_public_key const & ratchet_key =
460 sender_chain[0].ratchet_key.public_key;
461
462 olm::MessageWriter writer;
463
464 olm::encode_message(
465 writer, version: PROTOCOL_VERSION, counter, CURVE25519_KEY_LENGTH,
466 ciphertext_length,
467 output
468 );
469
470 olm::store_array(destination: writer.ratchet_key, source: ratchet_key.public_key);
471
472 ratchet_cipher->ops->encrypt(
473 ratchet_cipher,
474 keys.key, sizeof(keys.key),
475 plaintext, plaintext_length,
476 writer.ciphertext, ciphertext_length,
477 output, output_length
478 );
479
480 olm::unset(value&: keys);
481 return output_length;
482}
483
484
485std::size_t olm::Ratchet::decrypt_max_plaintext_length(
486 std::uint8_t const * input, std::size_t input_length
487) {
488 olm::MessageReader reader;
489 olm::decode_message(
490 reader, input, input_length,
491 mac_length: ratchet_cipher->ops->mac_length(ratchet_cipher)
492 );
493
494 if (!reader.ciphertext) {
495 last_error = OlmErrorCode::OLM_BAD_MESSAGE_FORMAT;
496 return std::size_t(-1);
497 }
498
499 return ratchet_cipher->ops->decrypt_max_plaintext_length(
500 ratchet_cipher, reader.ciphertext_length);
501}
502
503
504std::size_t olm::Ratchet::decrypt(
505 std::uint8_t const * input, std::size_t input_length,
506 std::uint8_t * plaintext, std::size_t max_plaintext_length
507) {
508 olm::MessageReader reader;
509 olm::decode_message(
510 reader, input, input_length,
511 mac_length: ratchet_cipher->ops->mac_length(ratchet_cipher)
512 );
513
514 if (reader.version != PROTOCOL_VERSION) {
515 last_error = OlmErrorCode::OLM_BAD_MESSAGE_VERSION;
516 return std::size_t(-1);
517 }
518
519 if (!reader.has_counter || !reader.ratchet_key || !reader.ciphertext) {
520 last_error = OlmErrorCode::OLM_BAD_MESSAGE_FORMAT;
521 return std::size_t(-1);
522 }
523
524 std::size_t max_length = ratchet_cipher->ops->decrypt_max_plaintext_length(
525 ratchet_cipher,
526 reader.ciphertext_length
527 );
528
529 if (max_plaintext_length < max_length) {
530 last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
531 return std::size_t(-1);
532 }
533
534 if (reader.ratchet_key_length != CURVE25519_KEY_LENGTH) {
535 last_error = OlmErrorCode::OLM_BAD_MESSAGE_FORMAT;
536 return std::size_t(-1);
537 }
538
539 ReceiverChain * chain = nullptr;
540
541 for (olm::ReceiverChain & receiver_chain : receiver_chains) {
542 if (0 == std::memcmp(
543 s1: receiver_chain.ratchet_key.public_key, s2: reader.ratchet_key,
544 CURVE25519_KEY_LENGTH
545 )) {
546 chain = &receiver_chain;
547 break;
548 }
549 }
550
551 std::size_t result = std::size_t(-1);
552
553 if (!chain) {
554 result = verify_mac_and_decrypt_for_new_chain(
555 session: *this, reader, plaintext, max_plaintext_length
556 );
557 } else if (chain->chain_key.index > reader.counter) {
558 /* Chain already advanced beyond the key for this message
559 * Check if the message keys are in the skipped key list. */
560 for (olm::SkippedMessageKey & skipped : skipped_message_keys) {
561 if (reader.counter == skipped.message_key.index
562 && 0 == std::memcmp(
563 s1: skipped.ratchet_key.public_key, s2: reader.ratchet_key,
564 CURVE25519_KEY_LENGTH
565 )
566 ) {
567 /* Found the key for this message. Check the MAC. */
568
569 result = verify_mac_and_decrypt(
570 cipher: ratchet_cipher, message_key: skipped.message_key, reader,
571 plaintext, max_plaintext_length
572 );
573
574 if (result != std::size_t(-1)) {
575 /* Remove the key from the skipped keys now that we've
576 * decoded the message it corresponds to. */
577 olm::unset(value&: skipped);
578 skipped_message_keys.erase(pos: &skipped);
579 return result;
580 }
581 }
582 }
583 } else {
584 result = verify_mac_and_decrypt_for_existing_chain(
585 session: *this, chain: chain->chain_key,
586 reader, plaintext, max_plaintext_length
587 );
588 }
589
590 if (result == std::size_t(-1)) {
591 last_error = OlmErrorCode::OLM_BAD_MESSAGE_MAC;
592 return std::size_t(-1);
593 }
594
595 if (!chain) {
596 /* They have started using a new ephemeral ratchet key.
597 * We need to derive a new set of chain keys.
598 * We can discard our previous ephemeral ratchet key.
599 * We will generate a new key when we send the next message. */
600
601 chain = receiver_chains.insert();
602 olm::load_array(destination&: chain->ratchet_key.public_key, source: reader.ratchet_key);
603
604 // TODO: we've already done this once, in
605 // verify_mac_and_decrypt_for_new_chain(). we could reuse the result.
606 create_chain_key(
607 root_key, our_key: sender_chain[0].ratchet_key, their_key: chain->ratchet_key,
608 info: kdf_info, new_root_key&: root_key, new_chain_key&: chain->chain_key
609 );
610
611 olm::unset(value&: sender_chain[0]);
612 sender_chain.erase(pos: sender_chain.begin());
613 }
614
615 while (chain->chain_key.index < reader.counter) {
616 olm::SkippedMessageKey & key = *skipped_message_keys.insert();
617 create_message_keys(chain_key: chain->chain_key, info: kdf_info, message_key&: key.message_key);
618 key.ratchet_key = chain->ratchet_key;
619 advance_chain_key(chain_key: chain->chain_key, new_chain_key&: chain->chain_key);
620 }
621
622 advance_chain_key(chain_key: chain->chain_key, new_chain_key&: chain->chain_key);
623
624 return result;
625}
626