1 | /* Copyright 2015-2016 OpenMarket Ltd |
2 | * |
3 | * Licensed under the Apache License, Version 2.0 (the "License"); |
4 | * you may not use this file except in compliance with the License. |
5 | * You may obtain a copy of the License at |
6 | * |
7 | * http://www.apache.org/licenses/LICENSE-2.0 |
8 | * |
9 | * Unless required by applicable law or agreed to in writing, software |
10 | * distributed under the License is distributed on an "AS IS" BASIS, |
11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | * See the License for the specific language governing permissions and |
13 | * limitations under the License. |
14 | */ |
15 | #include "olm/message.hh" |
16 | |
17 | #include "olm/memory.hh" |
18 | |
namespace {

/**
 * Number of bytes varint_encode() writes for `value`: one byte per 7-bit
 * group, with at least one byte even for zero.
 */
template<typename T>
std::size_t varint_length(
    T value
) {
    std::size_t result = 1;
    while (value >= 128U) {
        ++result;
        value >>= 7;
    }
    return result;
}


/**
 * Writes `value` to `output` as a base-128 varint.
 *
 * Each byte carries 7 payload bits, with the least-significant group first.
 * The high bit is set on every byte except the last. Returns a pointer one
 * past the last byte written. The caller must provide at least
 * varint_length(value) bytes.
 */
template<typename T>
std::uint8_t * varint_encode(
    std::uint8_t * output,
    T value
) {
    while (value >= 128U) {
        // Low 7 bits plus the continuation flag.
        *(output++) = static_cast<std::uint8_t>((value & 0x7F) | 0x80);
        value >>= 7;
    }
    *(output++) = static_cast<std::uint8_t>(value);
    return output;
}


/**
 * Decodes the varint occupying [varint_start, varint_end), typically the
 * range delimited by varint_skip().
 *
 * Continuation bits are masked off, and groups are combined starting from
 * the last (most significant) byte. Returns 0 for an empty range. For the
 * unsigned T used in this file, groups that do not fit in T are silently
 * shifted out.
 */
template<typename T>
T varint_decode(
    std::uint8_t const * varint_start,
    std::uint8_t const * varint_end
) {
    T value = 0;
    if (varint_end == varint_start) {
        return 0;
    }
    do {
        value <<= 7;
        value |= 0x7F & *(--varint_end);
    } while (varint_end != varint_start);
    return value;
}


/**
 * Advances past one varint starting at `input`.
 *
 * Returns a pointer just past the first byte whose high (continuation) bit
 * is clear, or `input_end` if no such byte exists before the end.
 */
std::uint8_t const * varint_skip(
    std::uint8_t const * input,
    std::uint8_t const * input_end
) {
    while (input != input_end) {
        std::uint8_t tmp = *(input++);
        if ((tmp & 0x80) == 0) {
            return input;
        }
    }
    return input;
}


/** Bytes needed for a length-prefixed string: varint length + payload. */
std::size_t varstring_length(
    std::size_t string_length
) {
    return varint_length(string_length) + string_length;
}

std::size_t const VERSION_LENGTH = 1;
// Field tags are (field_number << 3) | wire_type, written in octal so the
// last digit is the wire type (0 = varint, 2 = length-delimited; see
// skip_unknown) and the leading digit(s) are the field number.
std::uint8_t const RATCHET_KEY_TAG = 012;
std::uint8_t const COUNTER_TAG = 020;
std::uint8_t const CIPHERTEXT_TAG = 042;

/** Writes `tag` followed by `value` as a varint; returns the new position. */
std::uint8_t * encode(
    std::uint8_t * pos,
    std::uint8_t tag,
    std::uint32_t value
) {
    *(pos++) = tag;
    return varint_encode(pos, value);
}

/**
 * Writes `tag` and a varint length prefix, then reserves `value_length`
 * bytes for the payload.
 *
 * `value` is pointed at the reserved region so the caller can fill it.
 * Returns a pointer just past the reserved region.
 */
std::uint8_t * encode(
    std::uint8_t * pos,
    std::uint8_t tag,
    std::uint8_t * & value, std::size_t value_length
) {
    *(pos++) = tag;
    pos = varint_encode(pos, value_length);
    value = pos;
    return pos + value_length;
}

/**
 * Decodes a varint field tagged `tag` when `pos` points at that tag.
 *
 * On a match, consumes the tag and varint, stores the result in `value` and
 * sets `has_value`. Otherwise returns `pos` unchanged.
 *
 * NOTE: if the tag is the last byte, or the varint is unterminated, the
 * field is still reported present. It decodes whatever bytes exist, which
 * is 0 if there are none.
 */
std::uint8_t const * decode(
    std::uint8_t const * pos, std::uint8_t const * end,
    std::uint8_t tag,
    std::uint32_t & value, bool & has_value
) {
    if (pos != end && *pos == tag) {
        ++pos;
        std::uint8_t const * value_start = pos;
        pos = varint_skip(pos, end);
        value = varint_decode<std::uint32_t>(value_start, pos);
        has_value = true;
    }
    return pos;
}


/**
 * Decodes a length-delimited field tagged `tag` when `pos` points at that
 * tag.
 *
 * On a match, `value` points at the payload inside the input buffer (no
 * copy) and `value_length` holds its size. If the declared length runs past
 * `end`, returns `end` and leaves `value` and `value_length` untouched.
 */
std::uint8_t const * decode(
    std::uint8_t const * pos, std::uint8_t const * end,
    std::uint8_t tag,
    std::uint8_t const * & value, std::size_t & value_length
) {
    if (pos != end && *pos == tag) {
        ++pos;
        std::uint8_t const * len_start = pos;
        pos = varint_skip(pos, end);
        std::size_t len = varint_decode<std::size_t>(len_start, pos);
        // Compare against the remaining space rather than computing pos + len,
        // which could point past the buffer.
        if (len > std::size_t(end - pos)) return end;
        value = pos;
        value_length = len;
        pos += len;
    }
    return pos;
}

/**
 * Skips one field at `pos` that the caller did not recognise.
 *
 * Supported wire types:
 * - 0 (varint): skips the varint tag, then the varint value.
 * - 2 (length-delimited): skips the tag, the varint length and the payload.
 *
 * Returns `end` for any other wire type, or for a length running past `end`,
 * which stops the caller's parse loop.
 */
std::uint8_t const * skip_unknown(
    std::uint8_t const * pos, std::uint8_t const * end
) {
    if (pos != end) {
        std::uint8_t tag = *pos;
        if ((tag & 0x7) == 0) {
            pos = varint_skip(pos, end);
            pos = varint_skip(pos, end);
        } else if ((tag & 0x7) == 2) {
            pos = varint_skip(pos, end);
            std::uint8_t const * len_start = pos;
            pos = varint_skip(pos, end);
            std::size_t len = varint_decode<std::size_t>(len_start, pos);
            if (len > std::size_t(end - pos)) return end;
            pos += len;
        } else {
            return end;
        }
    }
    return pos;
}

} // namespace
167 | |
168 | |
169 | std::size_t olm::encode_message_length( |
170 | std::uint32_t counter, |
171 | std::size_t ratchet_key_length, |
172 | std::size_t ciphertext_length, |
173 | std::size_t mac_length |
174 | ) { |
175 | std::size_t length = VERSION_LENGTH; |
176 | length += 1 + varstring_length(string_length: ratchet_key_length); |
177 | length += 1 + varint_length(value: counter); |
178 | length += 1 + varstring_length(string_length: ciphertext_length); |
179 | length += mac_length; |
180 | return length; |
181 | } |
182 | |
183 | |
184 | void olm::encode_message( |
185 | olm::MessageWriter & writer, |
186 | std::uint8_t version, |
187 | std::uint32_t counter, |
188 | std::size_t ratchet_key_length, |
189 | std::size_t ciphertext_length, |
190 | std::uint8_t * output |
191 | ) { |
192 | std::uint8_t * pos = output; |
193 | *(pos++) = version; |
194 | pos = encode(pos, tag: RATCHET_KEY_TAG, value&: writer.ratchet_key, value_length: ratchet_key_length); |
195 | pos = encode(pos, tag: COUNTER_TAG, value: counter); |
196 | pos = encode(pos, tag: CIPHERTEXT_TAG, value&: writer.ciphertext, value_length: ciphertext_length); |
197 | } |
198 | |
199 | |
200 | void olm::decode_message( |
201 | olm::MessageReader & reader, |
202 | std::uint8_t const * input, std::size_t input_length, |
203 | std::size_t mac_length |
204 | ) { |
205 | std::uint8_t const * pos = input; |
206 | std::uint8_t const * end = input + input_length - mac_length; |
207 | std::uint8_t const * unknown = nullptr; |
208 | |
209 | reader.version = 0; |
210 | reader.has_counter = false; |
211 | reader.counter = 0; |
212 | reader.input = input; |
213 | reader.input_length = input_length; |
214 | reader.ratchet_key = nullptr; |
215 | reader.ratchet_key_length = 0; |
216 | reader.ciphertext = nullptr; |
217 | reader.ciphertext_length = 0; |
218 | |
219 | if (input_length < mac_length) return; |
220 | |
221 | if (pos == end) return; |
222 | reader.version = *(pos++); |
223 | |
224 | while (pos != end) { |
225 | unknown = pos; |
226 | pos = decode( |
227 | pos, end, tag: RATCHET_KEY_TAG, |
228 | value&: reader.ratchet_key, value_length&: reader.ratchet_key_length |
229 | ); |
230 | pos = decode( |
231 | pos, end, tag: COUNTER_TAG, |
232 | value&: reader.counter, has_value&: reader.has_counter |
233 | ); |
234 | pos = decode( |
235 | pos, end, tag: CIPHERTEXT_TAG, |
236 | value&: reader.ciphertext, value_length&: reader.ciphertext_length |
237 | ); |
238 | if (unknown == pos) { |
239 | pos = skip_unknown(pos, end); |
240 | } |
241 | } |
242 | } |
243 | |
244 | |
namespace {

// Field tags of the pre-key (one-time-key) message. Each tag is
// (field_number << 3) | wire_type. All four fields use wire type 2
// (length-delimited), because they are written with the length-prefixed
// encode() overload.
std::uint8_t const ONE_TIME_KEY_ID_TAG = (1 << 3) | 2;
std::uint8_t const BASE_KEY_TAG        = (2 << 3) | 2;
std::uint8_t const IDENTITY_KEY_TAG    = (3 << 3) | 2;
std::uint8_t const MESSAGE_TAG         = (4 << 3) | 2;

} // namespace
253 | |
254 | |
255 | std::size_t olm::encode_one_time_key_message_length( |
256 | std::size_t one_time_key_length, |
257 | std::size_t identity_key_length, |
258 | std::size_t base_key_length, |
259 | std::size_t message_length |
260 | ) { |
261 | std::size_t length = VERSION_LENGTH; |
262 | length += 1 + varstring_length(string_length: one_time_key_length); |
263 | length += 1 + varstring_length(string_length: identity_key_length); |
264 | length += 1 + varstring_length(string_length: base_key_length); |
265 | length += 1 + varstring_length(string_length: message_length); |
266 | return length; |
267 | } |
268 | |
269 | |
270 | void olm::encode_one_time_key_message( |
271 | olm::PreKeyMessageWriter & writer, |
272 | std::uint8_t version, |
273 | std::size_t identity_key_length, |
274 | std::size_t base_key_length, |
275 | std::size_t one_time_key_length, |
276 | std::size_t message_length, |
277 | std::uint8_t * output |
278 | ) { |
279 | std::uint8_t * pos = output; |
280 | *(pos++) = version; |
281 | pos = encode(pos, tag: ONE_TIME_KEY_ID_TAG, value&: writer.one_time_key, value_length: one_time_key_length); |
282 | pos = encode(pos, tag: BASE_KEY_TAG, value&: writer.base_key, value_length: base_key_length); |
283 | pos = encode(pos, tag: IDENTITY_KEY_TAG, value&: writer.identity_key, value_length: identity_key_length); |
284 | pos = encode(pos, tag: MESSAGE_TAG, value&: writer.message, value_length: message_length); |
285 | } |
286 | |
287 | |
288 | void olm::decode_one_time_key_message( |
289 | PreKeyMessageReader & reader, |
290 | std::uint8_t const * input, std::size_t input_length |
291 | ) { |
292 | std::uint8_t const * pos = input; |
293 | std::uint8_t const * end = input + input_length; |
294 | std::uint8_t const * unknown = nullptr; |
295 | |
296 | reader.version = 0; |
297 | reader.one_time_key = nullptr; |
298 | reader.one_time_key_length = 0; |
299 | reader.identity_key = nullptr; |
300 | reader.identity_key_length = 0; |
301 | reader.base_key = nullptr; |
302 | reader.base_key_length = 0; |
303 | reader.message = nullptr; |
304 | reader.message_length = 0; |
305 | |
306 | if (pos == end) return; |
307 | reader.version = *(pos++); |
308 | |
309 | while (pos != end) { |
310 | unknown = pos; |
311 | pos = decode( |
312 | pos, end, tag: ONE_TIME_KEY_ID_TAG, |
313 | value&: reader.one_time_key, value_length&: reader.one_time_key_length |
314 | ); |
315 | pos = decode( |
316 | pos, end, tag: BASE_KEY_TAG, |
317 | value&: reader.base_key, value_length&: reader.base_key_length |
318 | ); |
319 | pos = decode( |
320 | pos, end, tag: IDENTITY_KEY_TAG, |
321 | value&: reader.identity_key, value_length&: reader.identity_key_length |
322 | ); |
323 | pos = decode( |
324 | pos, end, tag: MESSAGE_TAG, |
325 | value&: reader.message, value_length&: reader.message_length |
326 | ); |
327 | if (unknown == pos) { |
328 | pos = skip_unknown(pos, end); |
329 | } |
330 | } |
331 | } |
332 | |
333 | |
334 | |
// Field tags of the group message, as (field_number << 3) | wire_type.
static std::uint8_t const GROUP_MESSAGE_INDEX_TAG = (1 << 3) | 0; // field 1, varint
static std::uint8_t const GROUP_CIPHERTEXT_TAG    = (2 << 3) | 2; // field 2, length-delimited
337 | |
338 | size_t _olm_encode_group_message_length( |
339 | uint32_t message_index, |
340 | size_t ciphertext_length, |
341 | size_t mac_length, |
342 | size_t signature_length |
343 | ) { |
344 | size_t length = VERSION_LENGTH; |
345 | length += 1 + varint_length(value: message_index); |
346 | length += 1 + varstring_length(string_length: ciphertext_length); |
347 | length += mac_length; |
348 | length += signature_length; |
349 | return length; |
350 | } |
351 | |
352 | |
353 | size_t _olm_encode_group_message( |
354 | uint8_t version, |
355 | uint32_t message_index, |
356 | size_t ciphertext_length, |
357 | uint8_t *output, |
358 | uint8_t **ciphertext_ptr |
359 | ) { |
360 | std::uint8_t * pos = output; |
361 | |
362 | *(pos++) = version; |
363 | pos = encode(pos, tag: GROUP_MESSAGE_INDEX_TAG, value: message_index); |
364 | pos = encode(pos, tag: GROUP_CIPHERTEXT_TAG, value&: *ciphertext_ptr, value_length: ciphertext_length); |
365 | return pos-output; |
366 | } |
367 | |
368 | void _olm_decode_group_message( |
369 | const uint8_t *input, size_t input_length, |
370 | size_t mac_length, size_t signature_length, |
371 | struct _OlmDecodeGroupMessageResults *results |
372 | ) { |
373 | std::uint8_t const * pos = input; |
374 | std::size_t trailer_length = mac_length + signature_length; |
375 | std::uint8_t const * end = input + input_length - trailer_length; |
376 | std::uint8_t const * unknown = nullptr; |
377 | |
378 | bool has_message_index = false; |
379 | results->version = 0; |
380 | results->message_index = 0; |
381 | results->has_message_index = (int)has_message_index; |
382 | results->ciphertext = nullptr; |
383 | results->ciphertext_length = 0; |
384 | |
385 | if (input_length < trailer_length) return; |
386 | |
387 | if (pos == end) return; |
388 | results->version = *(pos++); |
389 | |
390 | while (pos != end) { |
391 | unknown = pos; |
392 | pos = decode( |
393 | pos, end, tag: GROUP_MESSAGE_INDEX_TAG, |
394 | value&: results->message_index, has_value&: has_message_index |
395 | ); |
396 | pos = decode( |
397 | pos, end, tag: GROUP_CIPHERTEXT_TAG, |
398 | value&: results->ciphertext, value_length&: results->ciphertext_length |
399 | ); |
400 | if (unknown == pos) { |
401 | pos = skip_unknown(pos, end); |
402 | } |
403 | } |
404 | |
405 | results->has_message_index = (int)has_message_index; |
406 | } |
407 | |