1 | /* Copyright 2015, 2016 OpenMarket Ltd |
2 | * |
3 | * Licensed under the Apache License, Version 2.0 (the "License"); |
4 | * you may not use this file except in compliance with the License. |
5 | * You may obtain a copy of the License at |
6 | * |
7 | * http://www.apache.org/licenses/LICENSE-2.0 |
8 | * |
9 | * Unless required by applicable law or agreed to in writing, software |
10 | * distributed under the License is distributed on an "AS IS" BASIS, |
11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | * See the License for the specific language governing permissions and |
13 | * limitations under the License. |
14 | */ |
15 | #include "olm/session.hh" |
16 | #include "olm/cipher.h" |
17 | #include "olm/crypto.h" |
18 | #include "olm/account.hh" |
19 | #include "olm/memory.hh" |
20 | #include "olm/message.hh" |
21 | #include "olm/pickle.hh" |
22 | |
23 | #include <cstring> |
24 | #include <stdio.h> |
25 | |
26 | namespace { |
27 | |
28 | static const std::uint8_t PROTOCOL_VERSION = 0x3; |
29 | |
30 | static const std::uint8_t ROOT_KDF_INFO[] = "OLM_ROOT" ; |
31 | static const std::uint8_t RATCHET_KDF_INFO[] = "OLM_RATCHET" ; |
32 | static const std::uint8_t CIPHER_KDF_INFO[] = "OLM_KEYS" ; |
33 | |
34 | static const olm::KdfInfo OLM_KDF_INFO = { |
35 | .root_info: ROOT_KDF_INFO, .root_info_length: sizeof(ROOT_KDF_INFO) - 1, |
36 | .ratchet_info: RATCHET_KDF_INFO, .ratchet_info_length: sizeof(RATCHET_KDF_INFO) - 1 |
37 | }; |
38 | |
39 | static const struct _olm_cipher_aes_sha_256 OLM_CIPHER = |
40 | OLM_CIPHER_INIT_AES_SHA_256(CIPHER_KDF_INFO); |
41 | |
42 | } // namespace |
43 | |
44 | olm::Session::Session( |
45 | ) : ratchet(OLM_KDF_INFO, OLM_CIPHER_BASE(&OLM_CIPHER)), |
46 | last_error(OlmErrorCode::OLM_SUCCESS), |
47 | received_message(false) { |
48 | |
49 | } |
50 | |
51 | |
52 | std::size_t olm::Session::new_outbound_session_random_length() const { |
53 | return CURVE25519_RANDOM_LENGTH * 2; |
54 | } |
55 | |
56 | |
57 | std::size_t olm::Session::new_outbound_session( |
58 | olm::Account const & local_account, |
59 | _olm_curve25519_public_key const & identity_key, |
60 | _olm_curve25519_public_key const & one_time_key, |
61 | std::uint8_t const * random, std::size_t random_length |
62 | ) { |
63 | if (random_length < new_outbound_session_random_length()) { |
64 | last_error = OlmErrorCode::OLM_NOT_ENOUGH_RANDOM; |
65 | return std::size_t(-1); |
66 | } |
67 | |
68 | _olm_curve25519_key_pair base_key; |
69 | _olm_crypto_curve25519_generate_key(random_32_bytes: random, output: &base_key); |
70 | |
71 | _olm_curve25519_key_pair ratchet_key; |
72 | _olm_crypto_curve25519_generate_key(random_32_bytes: random + CURVE25519_RANDOM_LENGTH, output: &ratchet_key); |
73 | |
74 | _olm_curve25519_key_pair const & alice_identity_key_pair = ( |
75 | local_account.identity_keys.curve25519_key |
76 | ); |
77 | |
78 | received_message = false; |
79 | alice_identity_key = alice_identity_key_pair.public_key; |
80 | alice_base_key = base_key.public_key; |
81 | bob_one_time_key = one_time_key; |
82 | |
83 | // Calculate the shared secret S via triple DH |
84 | std::uint8_t secret[3 * CURVE25519_SHARED_SECRET_LENGTH]; |
85 | std::uint8_t * pos = secret; |
86 | |
87 | _olm_crypto_curve25519_shared_secret(our_key: &alice_identity_key_pair, their_key: &one_time_key, output: pos); |
88 | pos += CURVE25519_SHARED_SECRET_LENGTH; |
89 | _olm_crypto_curve25519_shared_secret(our_key: &base_key, their_key: &identity_key, output: pos); |
90 | pos += CURVE25519_SHARED_SECRET_LENGTH; |
91 | _olm_crypto_curve25519_shared_secret(our_key: &base_key, their_key: &one_time_key, output: pos); |
92 | |
93 | ratchet.initialise_as_alice(shared_secret: secret, shared_secret_length: sizeof(secret), our_ratchet_key: ratchet_key); |
94 | |
95 | olm::unset(value&: base_key); |
96 | olm::unset(value&: ratchet_key); |
97 | olm::unset(value&: secret); |
98 | |
99 | return std::size_t(0); |
100 | } |
101 | |
102 | namespace { |
103 | |
104 | static bool check_message_fields( |
105 | olm::PreKeyMessageReader & reader, bool have_their_identity_key |
106 | ) { |
107 | bool ok = true; |
108 | ok = ok && (have_their_identity_key || reader.identity_key); |
109 | if (reader.identity_key) { |
110 | ok = ok && reader.identity_key_length == CURVE25519_KEY_LENGTH; |
111 | } |
112 | ok = ok && reader.message; |
113 | ok = ok && reader.base_key; |
114 | ok = ok && reader.base_key_length == CURVE25519_KEY_LENGTH; |
115 | ok = ok && reader.one_time_key; |
116 | ok = ok && reader.one_time_key_length == CURVE25519_KEY_LENGTH; |
117 | return ok; |
118 | } |
119 | |
120 | } // namespace |
121 | |
122 | |
123 | std::size_t olm::Session::new_inbound_session( |
124 | olm::Account & local_account, |
125 | _olm_curve25519_public_key const * their_identity_key, |
126 | std::uint8_t const * one_time_key_message, std::size_t message_length |
127 | ) { |
128 | olm::PreKeyMessageReader reader; |
129 | decode_one_time_key_message(reader, input: one_time_key_message, input_length: message_length); |
130 | |
131 | if (!check_message_fields(reader, have_their_identity_key: their_identity_key)) { |
132 | last_error = OlmErrorCode::OLM_BAD_MESSAGE_FORMAT; |
133 | return std::size_t(-1); |
134 | } |
135 | |
136 | if (reader.identity_key && their_identity_key) { |
137 | bool same = 0 == std::memcmp( |
138 | s1: their_identity_key->public_key, s2: reader.identity_key, CURVE25519_KEY_LENGTH |
139 | ); |
140 | if (!same) { |
141 | last_error = OlmErrorCode::OLM_BAD_MESSAGE_KEY_ID; |
142 | return std::size_t(-1); |
143 | } |
144 | } |
145 | |
146 | olm::load_array(destination&: alice_identity_key.public_key, source: reader.identity_key); |
147 | olm::load_array(destination&: alice_base_key.public_key, source: reader.base_key); |
148 | olm::load_array(destination&: bob_one_time_key.public_key, source: reader.one_time_key); |
149 | |
150 | olm::MessageReader message_reader; |
151 | decode_message( |
152 | reader&: message_reader, input: reader.message, input_length: reader.message_length, |
153 | mac_length: ratchet.ratchet_cipher->ops->mac_length(ratchet.ratchet_cipher) |
154 | ); |
155 | |
156 | if (!message_reader.ratchet_key |
157 | || message_reader.ratchet_key_length != CURVE25519_KEY_LENGTH) { |
158 | last_error = OlmErrorCode::OLM_BAD_MESSAGE_FORMAT; |
159 | return std::size_t(-1); |
160 | } |
161 | |
162 | _olm_curve25519_public_key ratchet_key; |
163 | olm::load_array(destination&: ratchet_key.public_key, source: message_reader.ratchet_key); |
164 | |
165 | olm::OneTimeKey const * our_one_time_key = local_account.lookup_key( |
166 | public_key: bob_one_time_key |
167 | ); |
168 | |
169 | if (!our_one_time_key) { |
170 | last_error = OlmErrorCode::OLM_BAD_MESSAGE_KEY_ID; |
171 | return std::size_t(-1); |
172 | } |
173 | |
174 | _olm_curve25519_key_pair const & bob_identity_key = ( |
175 | local_account.identity_keys.curve25519_key |
176 | ); |
177 | _olm_curve25519_key_pair const & bob_one_time_key = our_one_time_key->key; |
178 | |
179 | // Calculate the shared secret S via triple DH |
180 | std::uint8_t secret[CURVE25519_SHARED_SECRET_LENGTH * 3]; |
181 | std::uint8_t * pos = secret; |
182 | _olm_crypto_curve25519_shared_secret(our_key: &bob_one_time_key, their_key: &alice_identity_key, output: pos); |
183 | pos += CURVE25519_SHARED_SECRET_LENGTH; |
184 | _olm_crypto_curve25519_shared_secret(our_key: &bob_identity_key, their_key: &alice_base_key, output: pos); |
185 | pos += CURVE25519_SHARED_SECRET_LENGTH; |
186 | _olm_crypto_curve25519_shared_secret(our_key: &bob_one_time_key, their_key: &alice_base_key, output: pos); |
187 | |
188 | ratchet.initialise_as_bob(shared_secret: secret, shared_secret_length: sizeof(secret), their_ratchet_key: ratchet_key); |
189 | |
190 | olm::unset(value&: secret); |
191 | |
192 | return std::size_t(0); |
193 | } |
194 | |
195 | |
196 | std::size_t olm::Session::session_id_length() const { |
197 | return SHA256_OUTPUT_LENGTH; |
198 | } |
199 | |
200 | |
201 | std::size_t olm::Session::session_id( |
202 | std::uint8_t * id, std::size_t id_length |
203 | ) { |
204 | if (id_length < session_id_length()) { |
205 | last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL; |
206 | return std::size_t(-1); |
207 | } |
208 | std::uint8_t tmp[CURVE25519_KEY_LENGTH * 3]; |
209 | std::uint8_t * pos = tmp; |
210 | pos = olm::store_array(destination: pos, source: alice_identity_key.public_key); |
211 | pos = olm::store_array(destination: pos, source: alice_base_key.public_key); |
212 | pos = olm::store_array(destination: pos, source: bob_one_time_key.public_key); |
213 | _olm_crypto_sha256(input: tmp, input_length: sizeof(tmp), output: id); |
214 | return session_id_length(); |
215 | } |
216 | |
217 | |
218 | bool olm::Session::matches_inbound_session( |
219 | _olm_curve25519_public_key const * their_identity_key, |
220 | std::uint8_t const * one_time_key_message, std::size_t message_length |
221 | ) const { |
222 | olm::PreKeyMessageReader reader; |
223 | decode_one_time_key_message(reader, input: one_time_key_message, input_length: message_length); |
224 | |
225 | if (!check_message_fields(reader, have_their_identity_key: their_identity_key)) { |
226 | return false; |
227 | } |
228 | |
229 | bool same = true; |
230 | if (reader.identity_key) { |
231 | same = same && 0 == std::memcmp( |
232 | s1: reader.identity_key, s2: alice_identity_key.public_key, CURVE25519_KEY_LENGTH |
233 | ); |
234 | } |
235 | if (their_identity_key) { |
236 | same = same && 0 == std::memcmp( |
237 | s1: their_identity_key->public_key, s2: alice_identity_key.public_key, |
238 | CURVE25519_KEY_LENGTH |
239 | ); |
240 | } |
241 | same = same && 0 == std::memcmp( |
242 | s1: reader.base_key, s2: alice_base_key.public_key, CURVE25519_KEY_LENGTH |
243 | ); |
244 | same = same && 0 == std::memcmp( |
245 | s1: reader.one_time_key, s2: bob_one_time_key.public_key, CURVE25519_KEY_LENGTH |
246 | ); |
247 | return same; |
248 | } |
249 | |
250 | |
251 | olm::MessageType olm::Session::encrypt_message_type() const { |
252 | if (received_message) { |
253 | return olm::MessageType::MESSAGE; |
254 | } else { |
255 | return olm::MessageType::PRE_KEY; |
256 | } |
257 | } |
258 | |
259 | |
260 | std::size_t olm::Session::encrypt_message_length( |
261 | std::size_t plaintext_length |
262 | ) const { |
263 | std::size_t message_length = ratchet.encrypt_output_length( |
264 | plaintext_length |
265 | ); |
266 | |
267 | if (received_message) { |
268 | return message_length; |
269 | } |
270 | |
271 | return encode_one_time_key_message_length( |
272 | CURVE25519_KEY_LENGTH, |
273 | CURVE25519_KEY_LENGTH, |
274 | CURVE25519_KEY_LENGTH, |
275 | message_length |
276 | ); |
277 | } |
278 | |
279 | |
280 | std::size_t olm::Session::encrypt_random_length() const { |
281 | return ratchet.encrypt_random_length(); |
282 | } |
283 | |
284 | |
285 | std::size_t olm::Session::encrypt( |
286 | std::uint8_t const * plaintext, std::size_t plaintext_length, |
287 | std::uint8_t const * random, std::size_t random_length, |
288 | std::uint8_t * message, std::size_t message_length |
289 | ) { |
290 | if (message_length < encrypt_message_length(plaintext_length)) { |
291 | last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL; |
292 | return std::size_t(-1); |
293 | } |
294 | std::uint8_t * message_body; |
295 | std::size_t message_body_length = ratchet.encrypt_output_length( |
296 | plaintext_length |
297 | ); |
298 | |
299 | if (received_message) { |
300 | message_body = message; |
301 | } else { |
302 | olm::PreKeyMessageWriter writer; |
303 | encode_one_time_key_message( |
304 | writer, |
305 | version: PROTOCOL_VERSION, |
306 | CURVE25519_KEY_LENGTH, |
307 | CURVE25519_KEY_LENGTH, |
308 | CURVE25519_KEY_LENGTH, |
309 | message_length: message_body_length, |
310 | output: message |
311 | ); |
312 | olm::store_array(destination: writer.one_time_key, source: bob_one_time_key.public_key); |
313 | olm::store_array(destination: writer.identity_key, source: alice_identity_key.public_key); |
314 | olm::store_array(destination: writer.base_key, source: alice_base_key.public_key); |
315 | message_body = writer.message; |
316 | } |
317 | |
318 | std::size_t result = ratchet.encrypt( |
319 | plaintext, plaintext_length, |
320 | random, random_length, |
321 | output: message_body, max_output_length: message_body_length |
322 | ); |
323 | |
324 | if (result == std::size_t(-1)) { |
325 | last_error = ratchet.last_error; |
326 | ratchet.last_error = OlmErrorCode::OLM_SUCCESS; |
327 | return result; |
328 | } |
329 | |
330 | return result; |
331 | } |
332 | |
333 | |
334 | std::size_t olm::Session::decrypt_max_plaintext_length( |
335 | MessageType message_type, |
336 | std::uint8_t const * message, std::size_t message_length |
337 | ) { |
338 | std::uint8_t const * message_body; |
339 | std::size_t message_body_length; |
340 | if (message_type == olm::MessageType::MESSAGE) { |
341 | message_body = message; |
342 | message_body_length = message_length; |
343 | } else { |
344 | olm::PreKeyMessageReader reader; |
345 | decode_one_time_key_message(reader, input: message, input_length: message_length); |
346 | if (!reader.message) { |
347 | last_error = OlmErrorCode::OLM_BAD_MESSAGE_FORMAT; |
348 | return std::size_t(-1); |
349 | } |
350 | message_body = reader.message; |
351 | message_body_length = reader.message_length; |
352 | } |
353 | |
354 | std::size_t result = ratchet.decrypt_max_plaintext_length( |
355 | input: message_body, input_length: message_body_length |
356 | ); |
357 | |
358 | if (result == std::size_t(-1)) { |
359 | last_error = ratchet.last_error; |
360 | ratchet.last_error = OlmErrorCode::OLM_SUCCESS; |
361 | } |
362 | return result; |
363 | } |
364 | |
365 | |
366 | std::size_t olm::Session::decrypt( |
367 | olm::MessageType message_type, |
368 | std::uint8_t const * message, std::size_t message_length, |
369 | std::uint8_t * plaintext, std::size_t max_plaintext_length |
370 | ) { |
371 | std::uint8_t const * message_body; |
372 | std::size_t message_body_length; |
373 | if (message_type == olm::MessageType::MESSAGE) { |
374 | message_body = message; |
375 | message_body_length = message_length; |
376 | } else { |
377 | olm::PreKeyMessageReader reader; |
378 | decode_one_time_key_message(reader, input: message, input_length: message_length); |
379 | if (!reader.message) { |
380 | last_error = OlmErrorCode::OLM_BAD_MESSAGE_FORMAT; |
381 | return std::size_t(-1); |
382 | } |
383 | message_body = reader.message; |
384 | message_body_length = reader.message_length; |
385 | } |
386 | |
387 | std::size_t result = ratchet.decrypt( |
388 | input: message_body, input_length: message_body_length, plaintext, max_plaintext_length |
389 | ); |
390 | |
391 | if (result == std::size_t(-1)) { |
392 | last_error = ratchet.last_error; |
393 | ratchet.last_error = OlmErrorCode::OLM_SUCCESS; |
394 | return result; |
395 | } |
396 | |
397 | received_message = true; |
398 | return result; |
399 | } |
400 | |
// Mark a truncated description: overwrite the three characters before `end`
// with "..." and terminate the string at `end`, so the text does not stop
// abruptly with no warning. The caller must guarantee that end[-3] through
// end[0] are all inside the buffer.
void elide_description(char *end) {
    // The literal's own terminating NUL is copied too, giving 4 bytes in all.
    std::memcpy(end - 3, "...", 4);
}
409 | |
410 | void olm::Session::describe(char *describe_buffer, size_t buflen) { |
411 | // how much of the buffer is remaining (this is an int rather than a size_t |
412 | // because it will get compared to the return value from snprintf) |
413 | int remaining = buflen; |
414 | // do nothing if we have a zero-length buffer, or if buflen > INT_MAX, |
415 | // resulting in an overflow |
416 | if (remaining <= 0) return; |
417 | |
418 | describe_buffer[0] = '\0'; |
419 | // we need at least 23 characters to get any sort of meaningful |
420 | // information, so bail if we don't have that. (But more importantly, we |
421 | // need it to be at least 4 so that elide_description doesn't go out of |
422 | // bounds.) |
423 | if (remaining < 23) return; |
424 | |
425 | int size; |
426 | |
427 | // check that snprintf didn't return an error or reach the end of the buffer |
428 | #define CHECK_SIZE_AND_ADVANCE \ |
429 | if (size > remaining) { \ |
430 | return elide_description(describe_buffer + remaining - 1); \ |
431 | } else if (size > 0) { \ |
432 | describe_buffer += size; \ |
433 | remaining -= size; \ |
434 | } else { \ |
435 | return; \ |
436 | } |
437 | |
438 | size = snprintf( |
439 | s: describe_buffer, maxlen: remaining, |
440 | format: "sender chain index: %d " , ratchet.sender_chain[0].chain_key.index |
441 | ); |
442 | CHECK_SIZE_AND_ADVANCE; |
443 | |
444 | size = snprintf(s: describe_buffer, maxlen: remaining, format: "receiver chain indices:" ); |
445 | CHECK_SIZE_AND_ADVANCE; |
446 | |
447 | for (size_t i = 0; i < ratchet.receiver_chains.size(); ++i) { |
448 | size = snprintf( |
449 | s: describe_buffer, maxlen: remaining, |
450 | format: " %d" , ratchet.receiver_chains[i].chain_key.index |
451 | ); |
452 | CHECK_SIZE_AND_ADVANCE; |
453 | } |
454 | |
455 | size = snprintf(s: describe_buffer, maxlen: remaining, format: " skipped message keys:" ); |
456 | CHECK_SIZE_AND_ADVANCE; |
457 | |
458 | for (size_t i = 0; i < ratchet.skipped_message_keys.size(); ++i) { |
459 | size = snprintf( |
460 | s: describe_buffer, maxlen: remaining, |
461 | format: " %d" , ratchet.skipped_message_keys[i].message_key.index |
462 | ); |
463 | CHECK_SIZE_AND_ADVANCE; |
464 | } |
465 | #undef CHECK_SIZE_AND_ADVANCE |
466 | } |
467 | |
namespace {

// Version number written at the start of a pickled session. The master branch
// writes pickle version 1; the logging_enabled branch writes 0x80000001.
const std::uint32_t SESSION_PICKLE_VERSION = 1;

} // namespace
473 | |
474 | std::size_t olm::pickle_length( |
475 | Session const & value |
476 | ) { |
477 | std::size_t length = 0; |
478 | length += olm::pickle_length(value: SESSION_PICKLE_VERSION); |
479 | length += olm::pickle_length(value: value.received_message); |
480 | length += olm::pickle_length(value: value.alice_identity_key); |
481 | length += olm::pickle_length(value: value.alice_base_key); |
482 | length += olm::pickle_length(value: value.bob_one_time_key); |
483 | length += olm::pickle_length(value: value.ratchet); |
484 | return length; |
485 | } |
486 | |
487 | |
488 | std::uint8_t * olm::pickle( |
489 | std::uint8_t * pos, |
490 | Session const & value |
491 | ) { |
492 | pos = olm::pickle(pos, value: SESSION_PICKLE_VERSION); |
493 | pos = olm::pickle(pos, value: value.received_message); |
494 | pos = olm::pickle(pos, value: value.alice_identity_key); |
495 | pos = olm::pickle(pos, value: value.alice_base_key); |
496 | pos = olm::pickle(pos, value: value.bob_one_time_key); |
497 | pos = olm::pickle(pos, value: value.ratchet); |
498 | return pos; |
499 | } |
500 | |
501 | |
502 | std::uint8_t const * olm::unpickle( |
503 | std::uint8_t const * pos, std::uint8_t const * end, |
504 | Session & value |
505 | ) { |
506 | uint32_t pickle_version; |
507 | pos = olm::unpickle(pos, end, value&: pickle_version); UNPICKLE_OK(pos); |
508 | |
509 | bool includes_chain_index; |
510 | switch (pickle_version) { |
511 | case 1: |
512 | includes_chain_index = false; |
513 | break; |
514 | |
515 | case 0x80000001UL: |
516 | includes_chain_index = true; |
517 | break; |
518 | |
519 | default: |
520 | value.last_error = OlmErrorCode::OLM_UNKNOWN_PICKLE_VERSION; |
521 | return nullptr; |
522 | } |
523 | |
524 | pos = olm::unpickle(pos, end, value&: value.received_message); UNPICKLE_OK(pos); |
525 | pos = olm::unpickle(pos, end, value&: value.alice_identity_key); UNPICKLE_OK(pos); |
526 | pos = olm::unpickle(pos, end, value&: value.alice_base_key); UNPICKLE_OK(pos); |
527 | pos = olm::unpickle(pos, end, value&: value.bob_one_time_key); UNPICKLE_OK(pos); |
528 | pos = olm::unpickle(pos, end, value&: value.ratchet, includes_chain_index); UNPICKLE_OK(pos); |
529 | |
530 | return pos; |
531 | } |
532 | |