1#include <cstdint>
2#include <nlohmann/json.hpp>
3#include <utility>
4
5#include <openssl/aes.h>
6#include <openssl/sha.h>
7
8#include "mtxclient/crypto/client.hpp"
9#include "mtxclient/crypto/types.hpp"
10#include "mtxclient/crypto/utils.hpp"
11
12#include "mtx/log.hpp"
13
14using json = nlohmann::json;
15using namespace mtx::crypto;
16
//! Number of random salt bytes (used when encrypting exported session keys).
static constexpr unsigned int pwhash_SALTBYTES = 16;

using namespace std::string_view_literals;

//! Error names that olm_exception::ec_from_string translates back into OlmErrorCode values.
//! The position of each name is cast directly to mtx::crypto::OlmErrorCode, so the order of
//! this table must stay in step with that enum; never reorder or insert in the middle.
static const std::array<std::string_view, 17> olmErrorStrings = {
  "SUCCESS"sv,                    // 0
  "NOT_ENOUGH_RANDOM"sv,          // 1
  "OUTPUT_BUFFER_TOO_SMALL"sv,    // 2
  "BAD_MESSAGE_VERSION"sv,        // 3
  "BAD_MESSAGE_FORMAT"sv,         // 4
  "BAD_MESSAGE_MAC"sv,            // 5
  "BAD_MESSAGE_KEY_ID"sv,         // 6
  "INVALID_BASE64"sv,             // 7
  "BAD_ACCOUNT_KEY"sv,            // 8
  "UNKNOWN_PICKLE_VERSION"sv,     // 9
  "CORRUPTED_PICKLE"sv,           // 10
  "BAD_SESSION_KEY"sv,            // 11
  "UNKNOWN_MESSAGE_INDEX"sv,      // 12
  "BAD_LEGACY_ACCOUNT_PICKLE"sv,  // 13
  "BAD_SIGNATURE"sv,              // 14
  "OLM_INPUT_BUFFER_TOO_SMALL"sv, // 15
  "OLM_SAS_THEIR_KEY_NOT_SET"sv,  // 16
};
41
42mtx::crypto::OlmErrorCode
43olm_exception::ec_from_string(std::string_view error)
44{
45 for (size_t i = 0; i < olmErrorStrings.size(); i++) {
46 if (olmErrorStrings[i] == error)
47 return static_cast<mtx::crypto::OlmErrorCode>(i);
48 }
49
50 return mtx::crypto::OlmErrorCode::UNKNOWN_ERROR;
51}
52
53void
54OlmClient::create_new_account()
55{
56 account_ = create_olm_object<AccountObject>();
57
58 auto tmp_buf = create_buffer(nbytes: olm_create_account_random_length(account: account_.get()));
59 auto ret = olm_create_account(account: account_.get(), random: tmp_buf.data(), random_length: tmp_buf.size());
60
61 if (ret == olm_error())
62 throw olm_exception("create_new_account", account_.get());
63}
64
65void
66OlmClient::restore_account(const std::string &saved_data, const std::string &key)
67{
68 account_ = unpickle<AccountObject>(pickled: saved_data, key);
69}
70
71mtx::crypto::IdentityKeys
72OlmClient::identity_keys() const
73{
74 if (!account_)
75 throw olm_exception("identity_keys", account_.get());
76
77 auto tmp_buf = create_buffer(nbytes: olm_account_identity_keys_length(account: account_.get()));
78 auto ret = olm_account_identity_keys(account: account_.get(), identity_keys: (void *)tmp_buf.data(), identity_key_length: tmp_buf.size());
79
80 if (ret == olm_error())
81 throw olm_exception("identity_keys", account_.get());
82
83 return json::parse(i: std::string(tmp_buf.begin(), tmp_buf.end()))
84 .get<mtx::crypto::IdentityKeys>();
85}
86
87std::string
88OlmClient::sign_message(const std::string &msg) const
89{
90 auto signature_buf = create_buffer(nbytes: olm_account_signature_length(account: account_.get()));
91 olm_account_sign(
92 account: account_.get(), message: msg.data(), message_length: msg.size(), signature: signature_buf.data(), signature_length: signature_buf.size());
93
94 return std::string(signature_buf.begin(), signature_buf.end());
95}
96
97std::string
98OlmClient::sign_identity_keys()
99{
100 auto keys = identity_keys();
101
102 json body{{"algorithms", {"m.olm.v1.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"}},
103 {"user_id", user_id_},
104 {"device_id", device_id_},
105 {"keys",
106 {
107 {"curve25519:" + device_id_, keys.curve25519},
108 {"ed25519:" + device_id_, keys.ed25519},
109 }}};
110
111 return sign_message(msg: body.dump());
112}
113
114std::size_t
115OlmClient::generate_one_time_keys(std::size_t number_of_keys, bool generate_fallback)
116{
117 const std::size_t nbytes =
118 olm_account_generate_one_time_keys_random_length(account: account_.get(), number_of_keys);
119
120 auto buf = create_buffer(nbytes);
121
122 auto ret =
123 olm_account_generate_one_time_keys(account: account_.get(), number_of_keys, random: buf.data(), random_length: buf.size());
124
125 if (ret == olm_error())
126 throw olm_exception("generate_one_time_keys", account_.get());
127
128 if (generate_fallback) {
129 const std::size_t fnbytes = olm_account_generate_fallback_key_random_length(account: account_.get());
130 buf = create_buffer(nbytes: fnbytes);
131 auto temp = olm_account_generate_fallback_key(account: account_.get(), random: buf.data(), random_length: buf.size());
132 if (temp == olm_error())
133 throw olm_exception("generate_fallback_keys", account_.get());
134 }
135
136 return ret;
137}
138
139mtx::crypto::OneTimeKeys
140OlmClient::one_time_keys()
141{
142 auto buf = create_buffer(nbytes: olm_account_one_time_keys_length(account: account_.get()));
143
144 const auto ret = olm_account_one_time_keys(account: account_.get(), one_time_keys: buf.data(), one_time_keys_length: buf.size());
145
146 if (ret == olm_error())
147 throw olm_exception("one_time_keys", account_.get());
148
149 return json::parse(i: std::string(buf.begin(), buf.end())).get<mtx::crypto::OneTimeKeys>();
150}
151
152mtx::crypto::OneTimeKeys
153OlmClient::unpublished_fallback_keys()
154{
155 auto fbuf = create_buffer(nbytes: olm_account_unpublished_fallback_key_length(account: account_.get()));
156
157 const auto fret =
158 olm_account_unpublished_fallback_key(account: account_.get(), fallback_key: fbuf.data(), fallback_key_size: fbuf.size());
159 if (fret == olm_error())
160 throw olm_exception("unpublished_fallback_keys", account_.get());
161
162 return json::parse(i: std::string(fbuf.begin(), fbuf.end())).get<mtx::crypto::OneTimeKeys>();
163}
164
165std::string
166OlmClient::sign_one_time_key(const std::string &key, bool fallback)
167{
168 json j{{"key", key}};
169 if (fallback)
170 j["fallback"] = true;
171 return sign_message(msg: j.dump());
172}
173
174std::map<std::string, mtx::requests::SignedOneTimeKey>
175OlmClient::sign_one_time_keys(const OneTimeKeys &keys, bool fallback)
176{
177 // Sign & append the one time keys.
178 std::map<std::string, mtx::requests::SignedOneTimeKey> signed_one_time_keys;
179 for (const auto &elem : keys.curve25519) {
180 const auto key_id = elem.first;
181 const auto one_time_key = elem.second;
182
183 auto sig = sign_one_time_key(key: one_time_key, fallback);
184
185 signed_one_time_keys["signed_curve25519:" + key_id] =
186 signed_one_time_key(key: one_time_key, signature: sig, fallback);
187 }
188
189 return signed_one_time_keys;
190}
191
192mtx::requests::SignedOneTimeKey
193OlmClient::signed_one_time_key(const std::string &key, const std::string &signature, bool fallback)
194{
195 mtx::requests::SignedOneTimeKey sign{};
196 sign.key = key;
197 sign.fallback = fallback;
198 sign.signatures = {{user_id_, {{"ed25519:" + device_id_, signature}}}};
199 return sign;
200}
201
202mtx::requests::UploadKeys
203OlmClient::create_upload_keys_request()
204{
205 return create_upload_keys_request(keys: one_time_keys(), fallback_keys: unpublished_fallback_keys());
206}
207
208mtx::requests::UploadKeys
209OlmClient::create_upload_keys_request(const mtx::crypto::OneTimeKeys &one_time_keys,
210 const mtx::crypto::OneTimeKeys &fallback_keys)
211{
212 mtx::requests::UploadKeys req;
213 req.device_keys.user_id = user_id_;
214 req.device_keys.device_id = device_id_;
215
216 auto id_keys = identity_keys();
217
218 req.device_keys.keys["curve25519:" + device_id_] = id_keys.curve25519;
219 req.device_keys.keys["ed25519:" + device_id_] = id_keys.ed25519;
220
221 // Generate and add the signature to the request.
222 auto sig = sign_identity_keys();
223
224 req.device_keys.signatures[user_id_]["ed25519:" + device_id_] = sig;
225
226 // Sign & append the one time keys.
227 auto temp = sign_one_time_keys(keys: one_time_keys);
228 for (const auto &[key_id, key] : temp)
229 req.one_time_keys[key_id] = key;
230
231 temp = sign_one_time_keys(keys: fallback_keys, fallback: true);
232 for (const auto &[key_id, key] : temp) {
233 req.fallback_keys[key_id] = key;
234 }
235
236 return req;
237}
238
239std::optional<OlmClient::CrossSigningSetup>
240OlmClient::create_crosssigning_keys()
241{
242 auto master = PkSigning::new_key();
243 auto user_signing = PkSigning::new_key();
244 auto self_signing = PkSigning::new_key();
245
246 CrossSigningSetup setup{};
247 setup.private_master_key = master.seed();
248 setup.private_user_signing_key = user_signing.seed();
249 setup.private_self_signing_key = self_signing.seed();
250
251 // master key
252 setup.master_key.usage = {"master"};
253 setup.master_key.user_id = user_id_;
254 setup.master_key.keys["ed25519:" + master.public_key()] = master.public_key();
255
256 nlohmann::json master_j = setup.master_key;
257 master_j.erase(key: "unsigned");
258 master_j.erase(key: "signatures");
259 setup.master_key.signatures[user_id_]["ed25519:" + master.public_key()] =
260 master.sign(message: master_j.dump());
261 setup.master_key.signatures[user_id_]["ed25519:" + device_id_] = sign_message(msg: master_j.dump());
262
263 // user_signing_key
264 setup.user_signing_key.usage = {"user_signing"};
265 setup.user_signing_key.user_id = user_id_;
266 setup.user_signing_key.keys["ed25519:" + user_signing.public_key()] = user_signing.public_key();
267
268 nlohmann::json user_signing_j = setup.user_signing_key;
269 user_signing_j.erase(key: "unsigned");
270 user_signing_j.erase(key: "signatures");
271 setup.user_signing_key.signatures[user_id_]["ed25519:" + user_signing.public_key()] =
272 user_signing.sign(message: user_signing_j.dump());
273 setup.user_signing_key.signatures[user_id_]["ed25519:" + master.public_key()] =
274 master.sign(message: user_signing_j.dump());
275
276 // self_signing_key
277 setup.self_signing_key.usage = {"self_signing"};
278 setup.self_signing_key.user_id = user_id_;
279 setup.self_signing_key.keys["ed25519:" + self_signing.public_key()] = self_signing.public_key();
280
281 nlohmann::json self_signing_j = setup.self_signing_key;
282 self_signing_j.erase(key: "unsigned");
283 self_signing_j.erase(key: "signatures");
284 setup.self_signing_key.signatures[user_id_]["ed25519:" + self_signing.public_key()] =
285 self_signing.sign(message: self_signing_j.dump());
286 setup.self_signing_key.signatures[user_id_]["ed25519:" + master.public_key()] =
287 master.sign(message: self_signing_j.dump());
288
289 return setup;
290}
291
292std::optional<OlmClient::OnlineKeyBackupSetup>
293OlmClient::create_online_key_backup(const std::string &masterKey)
294{
295 OnlineKeyBackupSetup setup{};
296
297 auto key = create_buffer(nbytes: olm_pk_private_key_length());
298 setup.privateKey = key;
299
300 json auth_data;
301 auth_data["public_key"] = bin2base64_unpadded(bin: CURVE25519_public_key_from_private(privateKey: key));
302 auto master = PkSigning::from_seed(seed: masterKey);
303
304 auto sig = master.sign(message: auth_data.dump());
305 auth_data["signatures"][user_id_]["ed25519:" + master.public_key()] = sig;
306
307 setup.backupVersion.auth_data = auth_data.dump();
308 setup.backupVersion.algorithm = "m.megolm_backup.v1.curve25519-aes-sha2";
309
310 return setup;
311}
312
313std::optional<OlmClient::SSSSSetup>
314OlmClient::create_ssss_key(const std::string &password)
315{
316 OlmClient::SSSSSetup setup{};
317
318 if (password.empty()) {
319 setup.privateKey = create_buffer(nbytes: 32);
320 } else {
321 mtx::secret_storage::PBKDF2 pbkdf2{};
322 pbkdf2.algorithm = "m.pbkdf2";
323 // OWASP recommends 210'000 in 2023
324 // https://cheatsheetseries.owasp.org/cheatsheets/Password_Storage_Cheat_Sheet.html#pbkdf2
325 // We started out with 500'000 iterations, so we should still have a long time until we need
326 // to prompt users to upgrade and then we might want to go argon2 directly.
327 pbkdf2.iterations = 630'000;
328 pbkdf2.bits = 256; // 32 * 8
329 pbkdf2.salt = bin2base64(bin: to_string(buf: create_buffer(nbytes: 32)));
330
331 setup.privateKey = mtx::crypto::PBKDF2_HMAC_SHA_512(
332 pass: password, salt: to_binary_buf(str: pbkdf2.salt), iterations: pbkdf2.iterations, keylen: pbkdf2.bits / 8);
333 setup.keyDescription.passphrase = pbkdf2;
334 }
335
336 setup.keyDescription.algorithm = "m.secret_storage.v1.aes-hmac-sha2";
337 setup.keyDescription.name = bin2base58(bin: to_string(buf: create_buffer(nbytes: 16))); // create a random name
338 setup.keyDescription.iv = bin2base64(bin: to_string(buf: compatible_iv(incompatible_iv: create_buffer(nbytes: 32))));
339
340 auto testKeys = HKDF_SHA256(key: setup.privateKey, salt: BinaryBuf(32, 0), info: BinaryBuf{});
341
342 auto encrypted = AES_CTR_256_Encrypt(
343 plaintext: std::string(32, '\0'), aes256Key: testKeys.aes, iv: to_binary_buf(str: base642bin(b64: setup.keyDescription.iv)));
344
345 setup.keyDescription.mac = bin2base64(bin: to_string(buf: HMAC_SHA256(hmacKey: testKeys.mac, data: encrypted)));
346
347 return setup;
348}
349
350OutboundGroupSessionPtr
351OlmClient::init_outbound_group_session()
352{
353 auto session = create_olm_object<OutboundSessionObject>();
354 auto tmp_buf = create_buffer(nbytes: olm_init_outbound_group_session_random_length(session: session.get()));
355
356 const auto ret = olm_init_outbound_group_session(session: session.get(), random: tmp_buf.data(), random_length: tmp_buf.size());
357
358 if (ret == olm_error())
359 throw olm_exception("init_outbound_group_session", session.get());
360
361 return session;
362}
363
364InboundGroupSessionPtr
365OlmClient::init_inbound_group_session(const std::string &session_key)
366{
367 auto session = create_olm_object<InboundSessionObject>();
368
369 auto temp = session_key;
370 const auto ret = olm_init_inbound_group_session(
371 session: session.get(), session_key: reinterpret_cast<const uint8_t *>(temp.data()), session_key_length: temp.size());
372
373 if (ret == olm_error())
374 throw olm_exception("init_inbound_group_session", session.get());
375
376 return session;
377}
378
379InboundGroupSessionPtr
380OlmClient::import_inbound_group_session(const std::string &session_key)
381{
382 auto session = create_olm_object<InboundSessionObject>();
383
384 auto temp = session_key;
385 const auto ret = olm_import_inbound_group_session(
386 session: session.get(), session_key: reinterpret_cast<const uint8_t *>(temp.data()), session_key_length: temp.size());
387
388 if (ret == olm_error())
389 throw olm_exception("init_inbound_group_session", session.get());
390
391 return session;
392}
393
394GroupPlaintext
395OlmClient::decrypt_group_message(OlmInboundGroupSession *session,
396 const std::string &message,
397 uint32_t message_index)
398{
399 if (!session)
400 throw olm_exception("decrypt_group_message", session);
401
402 auto tmp_msg = create_buffer(nbytes: message.size());
403 std::copy(first: message.begin(), last: message.end(), result: tmp_msg.begin());
404
405 auto plaintext_len =
406 olm_group_decrypt_max_plaintext_length(session, message: tmp_msg.data(), message_length: tmp_msg.size());
407 if (plaintext_len == olm_error())
408 throw olm_exception("olm_group_decrypt_max_plaintext_length: invalid ciphertext", session);
409 auto plaintext = create_buffer(nbytes: plaintext_len);
410
411 tmp_msg = create_buffer(nbytes: message.size());
412 std::copy(first: message.begin(), last: message.end(), result: tmp_msg.begin());
413
414 const std::size_t nbytes = olm_group_decrypt(
415 session, message: tmp_msg.data(), message_length: tmp_msg.size(), plaintext: plaintext.data(), max_plaintext_length: plaintext.size(), message_index: &message_index);
416
417 if (nbytes == olm_error())
418 throw olm_exception("olm_group_decrypt", session);
419
420 auto output = create_buffer(nbytes);
421 std::memcpy(dest: output.data(), src: plaintext.data(), n: nbytes);
422
423 return GroupPlaintext{.data: std::move(output), .message_index: message_index};
424}
425
426BinaryBuf
427OlmClient::encrypt_group_message(OlmOutboundGroupSession *session, const std::string &plaintext)
428{
429 auto encrypted_len = olm_group_encrypt_message_length(session, plaintext_length: plaintext.size());
430 auto encrypted_message = create_buffer(nbytes: encrypted_len);
431
432 const std::size_t nbytes =
433 olm_group_encrypt(session,
434 plaintext: reinterpret_cast<const uint8_t *>(plaintext.data()),
435 plaintext_length: plaintext.size(),
436 message: encrypted_message.data(),
437 message_length: encrypted_message.size());
438
439 if (nbytes == olm_error())
440 throw olm_exception("olm_group_encrypt", session);
441
442 return encrypted_message;
443}
444
445BinaryBuf
446OlmClient::decrypt_message(OlmSession *session,
447 size_t msgtype,
448 const std::string &one_time_key_message)
449{
450 auto tmp = create_buffer(nbytes: one_time_key_message.size());
451 std::copy(first: one_time_key_message.begin(), last: one_time_key_message.end(), result: tmp.begin());
452
453 auto declen =
454 olm_decrypt_max_plaintext_length(session, message_type: msgtype, message: (void *)tmp.data(), message_length: tmp.size());
455
456 auto decrypted = create_buffer(nbytes: declen);
457 std::copy(first: one_time_key_message.begin(), last: one_time_key_message.end(), result: tmp.begin());
458
459 const std::size_t nbytes = olm_decrypt(
460 session, message_type: msgtype, message: (void *)tmp.data(), message_length: tmp.size(), plaintext: decrypted.data(), max_plaintext_length: decrypted.size());
461
462 if (nbytes == olm_error())
463 throw olm_exception("olm_decrypt", session);
464
465 // Removing the extra padding from the origial buffer.
466 auto output = create_buffer(nbytes);
467 std::memcpy(dest: output.data(), src: decrypted.data(), n: nbytes);
468
469 return output;
470}
471
472BinaryBuf
473OlmClient::encrypt_message(OlmSession *session, const std::string &msg)
474{
475 auto ciphertext = create_buffer(nbytes: olm_encrypt_message_length(session, plaintext_length: msg.size()));
476 auto random_buf = create_buffer(nbytes: olm_encrypt_random_length(session));
477
478 const auto ret = olm_encrypt(session,
479 plaintext: msg.data(),
480 plaintext_length: msg.size(),
481 random: random_buf.data(),
482 random_length: random_buf.size(),
483 message: ciphertext.data(),
484 message_length: ciphertext.size());
485 if (ret == olm_error())
486 throw olm_exception("olm_encrypt", session);
487
488 return ciphertext;
489}
490
491OlmSessionPtr
492OlmClient::create_inbound_session_from(const std::string &their_curve25519,
493 const std::string &one_time_key_message)
494{
495 BinaryBuf tmp(one_time_key_message.size());
496 memcpy(dest: tmp.data(), src: one_time_key_message.data(), n: one_time_key_message.size());
497
498 return create_inbound_session_from(their_curve25519, one_time_key_message: tmp);
499}
500
501OlmSessionPtr
502OlmClient::create_inbound_session_from(const std::string &their_curve25519,
503 const BinaryBuf &one_time_key_message)
504{
505 auto session = create_olm_object<SessionObject>();
506
507 auto tmp = create_buffer(nbytes: one_time_key_message.size());
508 std::copy(first: one_time_key_message.begin(), last: one_time_key_message.end(), result: tmp.begin());
509
510 std::size_t ret = olm_create_inbound_session_from(session: session.get(),
511 account: account(),
512 their_identity_key: their_curve25519.data(),
513 their_identity_key_length: their_curve25519.size(),
514 one_time_key_message: (void *)tmp.data(),
515 message_length: tmp.size());
516
517 if (ret == olm_error())
518 throw olm_exception("create_inbound_session_from", session.get());
519
520 ret = olm_remove_one_time_keys(account: account_.get(), session: session.get());
521
522 if (ret == olm_error())
523 throw olm_exception("inbound_session_from_remove_one_time_keys", account_.get());
524
525 return session;
526}
527
528OlmSessionPtr
529OlmClient::create_inbound_session(const std::string &one_time_key_message)
530{
531 BinaryBuf tmp(one_time_key_message.size());
532 memcpy(dest: tmp.data(), src: one_time_key_message.data(), n: one_time_key_message.size());
533
534 return create_inbound_session(one_time_key_message: tmp);
535}
536
537OlmSessionPtr
538OlmClient::create_inbound_session(const BinaryBuf &one_time_key_message)
539{
540 auto session = create_olm_object<SessionObject>();
541
542 auto tmp = create_buffer(nbytes: one_time_key_message.size());
543 std::copy(first: one_time_key_message.begin(), last: one_time_key_message.end(), result: tmp.begin());
544
545 std::size_t ret =
546 olm_create_inbound_session(session: session.get(), account: account(), one_time_key_message: (void *)tmp.data(), message_length: tmp.size());
547
548 if (ret == olm_error())
549 throw olm_exception("create_inbound_session", session.get());
550
551 ret = olm_remove_one_time_keys(account: account_.get(), session: session.get());
552
553 if (ret == olm_error())
554 throw olm_exception("inbound_session_remove_one_time_keys", account_.get());
555
556 return session;
557}
558
559OlmSessionPtr
560OlmClient::create_outbound_session(const std::string &identity_key, const std::string &one_time_key)
561{
562 auto session = create_olm_object<SessionObject>();
563 auto random_buf = create_buffer(nbytes: olm_create_outbound_session_random_length(session: session.get()));
564
565 const auto ret = olm_create_outbound_session(session: session.get(),
566 account: account(),
567 their_identity_key: identity_key.data(),
568 their_identity_key_length: identity_key.size(),
569 their_one_time_key: one_time_key.data(),
570 their_one_time_key_length: one_time_key.size(),
571 random: random_buf.data(),
572 random_length: random_buf.size());
573
574 if (ret == olm_error())
575 throw olm_exception("create_outbound_session", session.get());
576
577 return session;
578}
579
580std::unique_ptr<SAS>
581OlmClient::sas_init()
582{
583 return std::make_unique<SAS>();
584}
585
586//! constructor which create a new Curve25519 key pair which is stored in SASObject
587SAS::SAS()
588{
589 this->sas = create_olm_object<SASObject>();
590 auto random_buf = BinaryBuf(olm_create_sas_random_length(sas: sas.get()));
591
592 const auto ret = olm_create_sas(sas: this->sas.get(), random: random_buf.data(), random_length: random_buf.size());
593
594 if (ret == olm_error())
595 throw olm_exception("create_sas_instance", this->sas.get());
596}
597
598std::string
599SAS::public_key()
600{
601 auto pub_key_buffer = create_buffer(nbytes: olm_sas_pubkey_length(sas: this->sas.get()));
602
603 const auto ret =
604 olm_sas_get_pubkey(sas: this->sas.get(), pubkey: pub_key_buffer.data(), pubkey_length: pub_key_buffer.size());
605
606 if (ret == olm_error())
607 throw olm_exception("get_public_key", this->sas.get());
608
609 return to_string(buf: pub_key_buffer);
610}
611
612void
613SAS::set_their_key(const std::string &their_public_key)
614{
615 auto pub_key_buffer = to_binary_buf(str: their_public_key);
616
617 const auto ret =
618 olm_sas_set_their_key(sas: this->sas.get(), their_key: pub_key_buffer.data(), their_key_length: pub_key_buffer.size());
619
620 if (ret == olm_error())
621 throw olm_exception("get_public_key", this->sas.get());
622}
623
624std::vector<int>
625SAS::generate_bytes_decimal(const std::string &info)
626{
627 auto input_info_buffer = to_binary_buf(str: info);
628 auto output_buffer = BinaryBuf(5);
629
630 std::vector<int> output_list;
631 output_list.resize(new_size: 3);
632
633 const auto ret = olm_sas_generate_bytes(sas: this->sas.get(),
634 info: input_info_buffer.data(),
635 info_length: input_info_buffer.size(),
636 output: output_buffer.data(),
637 output_length: output_buffer.size());
638
639 if (ret == olm_error())
640 throw olm_exception("get_bytes_decimal", this->sas.get());
641
642 output_list[0] = (((output_buffer[0] << 5) | (output_buffer[1] >> 3)) + 1000);
643 output_list[1] =
644 ((((output_buffer[1] & 0x07) << 10) | (output_buffer[2] << 2) | (output_buffer[3] >> 6)) +
645 1000);
646 output_list[2] = (((((output_buffer[3] & 0x3F) << 7)) | ((output_buffer[4] >> 1))) + 1000);
647
648 return output_list;
649}
650
651//! generates and returns a vector of number(int) ranging from 0 to 63, to be used only after using
652//! `set_their_key`
653std::vector<int>
654SAS::generate_bytes_emoji(const std::string &info)
655{
656 auto input_info_buffer = to_binary_buf(str: info);
657 auto output_buffer = BinaryBuf(6);
658
659 std::vector<int> output_list;
660 output_list.resize(new_size: 7);
661
662 const auto ret = olm_sas_generate_bytes(sas: this->sas.get(),
663 info: input_info_buffer.data(),
664 info_length: input_info_buffer.size(),
665 output: output_buffer.data(),
666 output_length: output_buffer.size());
667
668 if (ret == olm_error())
669 throw olm_exception("get_bytes_emoji", this->sas.get());
670
671 output_list[0] = (output_buffer[0] >> 2);
672 output_list[1] = (((output_buffer[0] << 4) & 0x3f) | (output_buffer[1] >> 4));
673 output_list[2] = (((output_buffer[1] << 2) & 0x3f) | (output_buffer[2] >> 6));
674 output_list[3] = (output_buffer[2] & 0x3f);
675 output_list[4] = (output_buffer[3] >> 2);
676 output_list[5] = (((output_buffer[3] << 4) & 0x3f) | (output_buffer[4] >> 4));
677 output_list[6] = (((output_buffer[4] << 2) & 0x3f) | (output_buffer[5] >> 6));
678
679 return output_list;
680}
681
682//! calculates the mac based on the given input and info using the shared secret produced after
683//! `set_their_key`
684std::string
685SAS::calculate_mac(const std::string &input_data, const std::string &info)
686{
687 auto input_data_buffer = to_binary_buf(str: input_data);
688 auto info_buffer = to_binary_buf(str: info);
689 auto output_buffer = BinaryBuf(olm_sas_mac_length(sas: this->sas.get()));
690
691 const auto ret = olm_sas_calculate_mac(sas: this->sas.get(),
692 input: input_data_buffer.data(),
693 input_length: input_data_buffer.size(),
694 info: info_buffer.data(),
695 info_length: info_buffer.size(),
696 mac: output_buffer.data(),
697 mac_length: output_buffer.size());
698
699 if (ret == olm_error())
700 throw olm_exception("calculate_mac", this->sas.get());
701
702 return to_string(buf: output_buffer);
703}
704
705PkSigning
706PkSigning::new_key()
707{
708 auto priv_seed = bin2base64(bin: to_string(buf: create_buffer(nbytes: olm_pk_signing_seed_length())));
709 return from_seed(seed: priv_seed);
710}
711
712PkSigning
713PkSigning::from_seed(const std::string &seed)
714{
715 PkSigning s{};
716 s.seed_ = seed;
717 s.signing = create_olm_object<PkSigningObject>();
718
719 auto seed_ = base642bin(b64: seed);
720
721 auto pub_key_buffer = BinaryBuf(olm_pk_signing_public_key_length());
722 auto ret = olm_pk_signing_key_from_seed(
723 sign: s.signing.get(), pubkey: pub_key_buffer.data(), pubkey_length: pub_key_buffer.size(), seed: seed_.data(), seed_length: seed_.size());
724
725 if (ret == olm_error())
726 throw olm_exception("signing_from_seed", s.signing.get());
727
728 s.public_key_ = to_string(buf: pub_key_buffer);
729
730 return s;
731}
732
733std::string
734PkSigning::sign(const std::string &message)
735{
736 auto signature = BinaryBuf(olm_pk_signature_length());
737 auto message_ = to_binary_buf(str: message);
738
739 auto ret = olm_pk_sign(
740 sign: signing.get(), message: message_.data(), message_length: message_.size(), signature: signature.data(), signature_length: signature.size());
741
742 if (ret == olm_error())
743 throw olm_exception("olm_pk_sign", signing.get());
744
745 return to_string(buf: signature);
746}
747
748nlohmann::json
749OlmClient::create_olm_encrypted_content(OlmSession *session,
750 nlohmann::json event,
751 const UserId &recipient,
752 const std::string &recipient_ed25519_key,
753 const std::string &recipient_curve25519_key)
754{
755 event["keys"]["ed25519"] = identity_keys().ed25519;
756 event["sender"] = user_id_;
757 event["sender_device"] = device_id_;
758
759 event["recipient"] = recipient.get();
760 event["recipient_keys"]["ed25519"] = recipient_ed25519_key;
761
762 size_t msg_type = olm_encrypt_message_type(session);
763 auto encrypted = encrypt_message(session, msg: json(event).dump());
764 auto encrypted_str = std::string((char *)encrypted.data(), encrypted.size());
765
766 return json{
767 {"algorithm", "m.olm.v1.curve25519-aes-sha2"},
768 {"sender_key", identity_keys().curve25519},
769 {"ciphertext", {{recipient_curve25519_key, {{"body", encrypted_str}, {"type", msg_type}}}}}};
770}
771
772std::string
773OlmClient::save(const std::string &key)
774{
775 if (!account_)
776 return std::string();
777
778 return pickle<AccountObject>(object: account(), key);
779}
780
781std::string
782mtx::crypto::session_id(OlmSession *s)
783{
784 auto tmp = create_buffer(nbytes: olm_session_id_length(session: s));
785 olm_session_id(session: s, id: tmp.data(), id_length: tmp.size());
786
787 return std::string(tmp.begin(), tmp.end());
788}
789
790std::string
791mtx::crypto::session_id(OlmOutboundGroupSession *s)
792{
793 auto tmp = create_buffer(nbytes: olm_outbound_group_session_id_length(session: s));
794 olm_outbound_group_session_id(session: s, id: tmp.data(), id_length: tmp.size());
795
796 return std::string(tmp.begin(), tmp.end());
797}
798
799std::string
800mtx::crypto::session_key(OlmOutboundGroupSession *s)
801{
802 auto tmp = create_buffer(nbytes: olm_outbound_group_session_key_length(session: s));
803 olm_outbound_group_session_key(session: s, key: tmp.data(), key_length: tmp.size());
804
805 return std::string(tmp.begin(), tmp.end());
806}
807
808std::string
809mtx::crypto::export_session(OlmInboundGroupSession *s, uint32_t at_index)
810{
811 const size_t len = olm_export_inbound_group_session_length(session: s);
812 const uint32_t index =
813 at_index == uint32_t(-1) ? olm_inbound_group_session_first_known_index(session: s) : at_index;
814
815 auto session_key = create_buffer(nbytes: len);
816 const std::size_t ret =
817 olm_export_inbound_group_session(session: s, key: session_key.data(), key_length: session_key.size(), message_index: index);
818
819 if (ret == olm_error())
820 throw olm_exception("session_key", s);
821
822 return std::string(session_key.begin(), session_key.end());
823}
824
825InboundGroupSessionPtr
826mtx::crypto::import_session(const std::string &session_key)
827{
828 auto session = create_olm_object<InboundSessionObject>();
829
830 const std::size_t ret = olm_import_inbound_group_session(
831 session: session.get(), session_key: reinterpret_cast<const uint8_t *>(session_key.data()), session_key_length: session_key.size());
832
833 if (ret == olm_error())
834 throw olm_exception("import_session", session.get());
835
836 return session;
837}
838
839bool
840mtx::crypto::matches_inbound_session(OlmSession *session, const std::string &one_time_key_message)
841{
842 auto tmp = create_buffer(nbytes: one_time_key_message.size());
843 std::copy(first: one_time_key_message.begin(), last: one_time_key_message.end(), result: tmp.begin());
844
845 return olm_matches_inbound_session(session, one_time_key_message: (void *)tmp.data(), message_length: tmp.size());
846}
847
848bool
849mtx::crypto::matches_inbound_session_from(OlmSession *session,
850 const std::string &id_key,
851 const std::string &one_time_key_message)
852{
853 auto tmp = create_buffer(nbytes: one_time_key_message.size());
854 std::copy(first: one_time_key_message.begin(), last: one_time_key_message.end(), result: tmp.begin());
855
856 return olm_matches_inbound_session_from(
857 session, their_identity_key: id_key.data(), their_identity_key_length: id_key.size(), one_time_key_message: (void *)tmp.data(), message_length: tmp.size());
858}
859
860bool
861mtx::crypto::verify_identity_signature(const DeviceKeys &device_keys,
862 const DeviceId &device_id,
863 const UserId &user_id)
864{
865 try {
866 const auto sign_key_id = "ed25519:" + device_id.get();
867 const auto signing_key = device_keys.keys.at(k: sign_key_id);
868 const auto signature = device_keys.signatures.at(k: user_id.get()).at(k: sign_key_id);
869
870 if (signature.empty())
871 return false;
872
873 return ed25519_verify_signature(signing_key, obj: nlohmann::json(device_keys), signature);
874
875 } catch (const nlohmann::json::exception &e) {
876 mtx::utils::log::log()->error(fmt: "verify_identity_signature: {}", args: e.what());
877 }
878
879 return false;
880}
881
882//! checks if the signature is signed by the signing_key
883bool
884mtx::crypto::ed25519_verify_signature(std::string signing_key,
885 nlohmann::json obj,
886 std::string signature)
887{
888 try {
889 if (signature.empty())
890 return false;
891
892 obj.erase(key: "unsigned");
893 obj.erase(key: "signatures");
894
895 std::string canonical_json = obj.dump();
896
897 auto utility = create_olm_object<UtilityObject>();
898 auto ret = olm_ed25519_verify(utility: utility.get(),
899 key: signing_key.data(),
900 key_length: signing_key.size(),
901 message: canonical_json.data(),
902 message_length: canonical_json.size(),
903 signature: (void *)signature.data(),
904 signature_length: signature.size());
905
906 // the signature is wrong
907 if (ret != 0)
908 return false;
909
910 return true;
911 } catch (const nlohmann::json::exception &e) {
912 mtx::utils::log::log()->error(fmt: "verify_signature: {}", args: e.what());
913 }
914
915 return false;
916}
917
918std::string
919mtx::crypto::encrypt_exported_sessions(const mtx::crypto::ExportedSessionKeys &keys,
920 const std::string &pass)
921{
922 const auto plaintext = json(keys).dump();
923
924 auto nonce = create_buffer(AES_BLOCK_SIZE);
925 constexpr std::uint8_t mask = static_cast<std::uint8_t>(~(1U << (63 / 8)));
926 nonce[15 - 63 % 8] &= mask;
927
928 auto salt = create_buffer(nbytes: pwhash_SALTBYTES);
929
930 auto buf = create_buffer(nbytes: 64U);
931
932 uint32_t iterations = 100000;
933 buf = mtx::crypto::PBKDF2_HMAC_SHA_512(pass, salt, iterations);
934
935 BinaryBuf aes256 = BinaryBuf(buf.begin(), buf.begin() + 32);
936
937 BinaryBuf hmac256 = BinaryBuf(buf.begin() + 32, buf.begin() + (2UL * 32));
938
939 auto ciphertext = mtx::crypto::AES_CTR_256_Encrypt(plaintext, aes256Key: aes256, iv: nonce);
940
941 uint8_t iterationsArr[4];
942 mtx::crypto::uint32_to_uint8(b: iterationsArr, u32: iterations);
943
944 // Format of the output buffer: (0x01 + salt + IV + number of rounds + ciphertext +
945 // hmac-sha-256)
946 BinaryBuf output{
947 0x01,
948 };
949 output.reserve(n: 1 + salt.size() + nonce.size() + 4 + ciphertext.size());
950 output.insert(position: output.end(), first: salt.begin(), last: salt.end());
951 output.insert(position: output.end(), first: nonce.begin(), last: nonce.end());
952 output.insert(position: output.end(), first: &iterationsArr[0], last: &iterationsArr[4]);
953 output.insert(position: output.end(), first: ciphertext.begin(), last: ciphertext.end());
954
955 // Need to hmac-sha256 our string so far, and then use that to finish making the output.
956 auto hmacSha256 = mtx::crypto::HMAC_SHA256(hmacKey: hmac256, data: output);
957
958 output.insert(position: output.end(), first: hmacSha256.begin(), last: hmacSha256.end());
959 auto encrypted = std::string(output.begin(), output.end());
960
961 return encrypted;
962}
963
964mtx::crypto::ExportedSessionKeys
965mtx::crypto::decrypt_exported_sessions(const std::string &data, const std::string &pass)
966{
967 // Parse the data into a base64 string without the header and footer
968 std::string unpacked = mtx::crypto::unpack_key_file(data);
969
970 std::string binary_str = base642bin(b64: unpacked);
971
972 if (binary_str.size() <
973 1 + pwhash_SALTBYTES + AES_BLOCK_SIZE + sizeof(uint32_t) + SHA256_DIGEST_LENGTH + 2)
974 throw crypto_exception("decrypt_exported_sessions", "Invalid session file: too short");
975
976 const auto binary_start = binary_str.begin();
977 const auto binary_end = binary_str.end();
978
979 // Format version 0x01, 1 byte
980 const auto format_end = binary_start + 1;
981 auto format = BinaryBuf(binary_start, format_end);
982 if (format[0] != 0x01)
983 throw crypto_exception("decrypt_exported_sessions", "Unsupported backup file format.");
984
985 // Salt, 16 bytes
986 const auto salt_end = format_end + pwhash_SALTBYTES;
987 auto salt = BinaryBuf(format_end, salt_end);
988
989 // IV, 16 bytes
990 const auto iv_end = salt_end + AES_BLOCK_SIZE;
991 auto iv = BinaryBuf(salt_end, iv_end);
992
993 // Number of rounds, 4 bytes
994 const auto rounds_end = iv_end + sizeof(uint32_t);
995 auto rounds_buff = BinaryBuf(iv_end, rounds_end);
996 uint8_t rounds_arr[4];
997 std::copy(first: rounds_buff.begin(), last: rounds_buff.end(), result: rounds_arr);
998 uint32_t rounds;
999 mtx::crypto::uint8_to_uint32(b: rounds_arr, u32&: rounds);
1000
1001 // Variable-length JSON object...
1002 const auto json_end = binary_end - SHA256_DIGEST_LENGTH;
1003 auto json = BinaryBuf(rounds_end, json_end);
1004
1005 // HMAC of the above, 32 bytes
1006 auto hmac = BinaryBuf(json_end, binary_end);
1007
1008 // derive the keys
1009 auto buf = mtx::crypto::PBKDF2_HMAC_SHA_512(pass, salt, iterations: rounds);
1010
1011 BinaryBuf aes256 = BinaryBuf(buf.begin(), buf.begin() + 32);
1012
1013 BinaryBuf hmac256 = BinaryBuf(buf.begin() + 32, buf.begin() + (2 * 32));
1014
1015 // get hmac and verify they match
1016 auto hmacSha256 = mtx::crypto::HMAC_SHA256(hmacKey: hmac256, data: BinaryBuf(binary_start, json_end));
1017
1018 if (hmacSha256 != hmac) {
1019 throw crypto_exception{"decrypt_exported_sessions", "HMAC doesn't match"};
1020 }
1021
1022 const std::string ciphertext(json.begin(), json.end());
1023 auto decrypted = mtx::crypto::AES_CTR_256_Decrypt(ciphertext, aes256Key: aes256, iv);
1024
1025 std::string plaintext(decrypted.begin(), decrypted.end());
1026 return json::parse(i&: plaintext).get<mtx::crypto::ExportedSessionKeys>();
1027}
1028