1#include "mtx/pushrules.hpp"
2
3#include <charconv>
4
5#include <nlohmann/json.hpp>
6#include <re2/re2.h>
7
8#include "mtx/events/collections.hpp"
9#include "mtx/log.hpp"
10
namespace {
//! Flattened related events, kept apart by whether they are regular or fallback related events.
struct RelatedEvents
{
    //! A flattened event: string keys mapped to string values.
    using FlatEvent = std::unordered_map<std::string, std::string>;

    std::vector<FlatEvent> fallbacks; //!< fallback related events
    std::vector<FlatEvent> events;    //!< related events
};
} // namespace
19
20namespace mtx {
21namespace pushrules {
22
23void
24to_json(nlohmann::json &obj, const PushCondition &condition)
25{
26 obj["kind"] = condition.kind;
27 if (!condition.key.empty())
28 obj["key"] = condition.key;
29 if (!condition.pattern.empty())
30 obj["pattern"] = condition.pattern;
31 if (!condition.is.empty())
32 obj["is"] = condition.is;
33 if (condition.rel_type != mtx::common::RelationType::Unsupported)
34 obj["rel_type"] = condition.rel_type;
35}
36
37void
38from_json(const nlohmann::json &obj, PushCondition &condition)
39{
40 condition.kind = obj["kind"].get<std::string>();
41 condition.key = obj.value(key: "key", default_value: "");
42 condition.pattern = obj.value(key: "pattern", default_value: "");
43 condition.is = obj.value(key: "is", default_value: "");
44 condition.rel_type = obj.value(key: "rel_type", default_value: mtx::common::RelationType::Unsupported);
45 condition.include_fallback = obj.value(key: "include_fallback", default_value: false);
46}
47
48namespace actions {
49void
50to_json(nlohmann::json &obj, const Action &action)
51{
52 if (std::holds_alternative<notify>(v: action))
53 obj = "notify";
54 else if (std::holds_alternative<dont_notify>(v: action))
55 obj = "dont_notify";
56 else if (auto n = std::get_if<set_tweak_sound>(ptr: &action)) {
57 obj["set_tweak"] = "sound";
58 obj["value"] = n->value;
59 } else if (auto h = std::get_if<set_tweak_highlight>(ptr: &action)) {
60 obj["set_tweak"] = "highlight";
61 if (h->value == false)
62 obj["value"] = false;
63 }
64}
65
66void
67from_json(const nlohmann::json &obj, Action &action)
68{
69 if (obj.is_string()) {
70 if (obj == "notify")
71 action = notify{};
72 else if (obj == "dont_notify")
73 action = dont_notify{};
74 } else if (obj.contains(key: "set_tweak")) {
75 if (obj["set_tweak"] == "sound")
76 action = set_tweak_sound{.value: obj.value(key: "value", default_value: "default")};
77 else if (obj["set_tweak"] == "highlight")
78 action = set_tweak_highlight{.value: obj.value(key: "value", default_value: true)};
79 }
80}
81
82void
83to_json(nlohmann::json &obj, const Actions &action)
84{
85 obj["actions"] = action.actions;
86}
87
88void
89from_json(const nlohmann::json &obj, Actions &action)
90{
91 action.actions = obj["actions"].get<std::vector<Action>>();
92}
93}
94
95void
96to_json(nlohmann::json &obj, const PushRule &rule)
97{
98 if (rule.default_)
99 obj["default"] = rule.default_;
100
101 if (!rule.enabled)
102 obj["enabled"] = rule.enabled;
103
104 for (const auto &action : rule.actions)
105 obj["actions"].push_back(val: action);
106
107 if (!rule.rule_id.empty())
108 obj["rule_id"] = rule.rule_id;
109
110 if (!rule.pattern.empty())
111 obj["pattern"] = rule.pattern;
112
113 for (const auto &condition : rule.conditions)
114 obj["conditions"].push_back(val: condition);
115}
116
117void
118from_json(const nlohmann::json &obj, PushRule &rule)
119{
120 rule.rule_id = obj.value(key: "rule_id", default_value: "");
121 rule.default_ = obj.value(key: "default", default_value: false);
122 rule.enabled = obj.value(key: "enabled", default_value: true);
123
124 if (obj.contains(key: "actions"))
125 for (const auto &action : obj["actions"])
126 rule.actions.push_back(x: action.get<actions::Action>());
127
128 rule.pattern = obj.value(key: "pattern", default_value: "");
129
130 if (obj.contains(key: "conditions"))
131 for (const auto &condition : obj["conditions"])
132 rule.conditions.push_back(x: condition.get<PushCondition>());
133}
134
135void
136to_json(nlohmann::json &obj, const Ruleset &set)
137{
138 obj["override"] = set.override_;
139 obj["content"] = set.content;
140 obj["room"] = set.room;
141 obj["sender"] = set.sender;
142 obj["underride"] = set.underride;
143}
144
145void
146from_json(const nlohmann::json &obj, Ruleset &set)
147{
148 if (obj.contains(key: "override"))
149 for (const auto &e : obj["override"])
150 set.override_.push_back(x: e.get<PushRule>());
151 if (obj.contains(key: "content"))
152 for (const auto &e : obj["content"])
153 set.content.push_back(x: e.get<PushRule>());
154 if (obj.contains(key: "room"))
155 for (const auto &e : obj["room"])
156 set.room.push_back(x: e.get<PushRule>());
157 if (obj.contains(key: "sender"))
158 for (const auto &e : obj["sender"])
159 set.sender.push_back(x: e.get<PushRule>());
160 if (obj.contains(key: "underride"))
161 for (const auto &e : obj["underride"])
162 set.underride.push_back(x: e.get<PushRule>());
163}
164void
165to_json(nlohmann::json &obj, const GlobalRuleset &set)
166{
167 obj["global"] = set.global;
168}
169
170void
171from_json(const nlohmann::json &obj, GlobalRuleset &set)
172{
173 set.global = obj["global"].get<Ruleset>();
174}
175
176void
177to_json(nlohmann::json &obj, const Enabled &enabled)
178{
179 obj["enabled"] = enabled.enabled;
180}
181
182void
183from_json(const nlohmann::json &obj, Enabled &enabled)
184{
185 enabled.enabled = obj.value(key: "enabled", default_value: true);
186}
187
188struct PushRuleEvaluator::OptimizedRules
189{
190 //! The individual rule to apply
191 struct OptimizedRule
192 {
193 //! a pattern condition to match
194 struct PatternCondition
195 {
196 std::unique_ptr<re2::RE2> pattern; //!< the pattern
197 std::string field; //!< the field to match with pattern
198
199 bool matches(const std::unordered_map<std::string, std::string> &ev) const
200 {
201 if (auto it = ev.find(x: field); it != ev.end()) {
202 if (pattern) {
203 if (field == "content.body") {
204 if (!re2::RE2::PartialMatch(text: it->second, re: *pattern))
205 return false;
206 } else {
207 if (!re2::RE2::FullMatch(text: it->second, re: *pattern))
208 return false;
209 }
210 }
211 } else {
212 return false;
213 }
214
215 return true;
216 }
217 };
218 // TODO(Nico): Sort by field for faster matching?
219 std::vector<PatternCondition> patterns; //!< conditions that match on a field
220
221 //! a pattern condition to match on a related event
222 struct RelatedEventCondition
223 {
224 PatternCondition ev_match;
225 mtx::common::RelationType rel_type = mtx::common::RelationType::Unsupported;
226 bool include_fallbacks = false;
227 };
228 std::vector<RelatedEventCondition>
229 related_event_patterns; //!< conditions that match on fields of the related event.
230
231 //! a member count condition
232 struct MemberCountCondition
233 {
234 //! the count to compare against
235 std::size_t count = 0;
236 //! the comparison operation
237 enum Comp
238 {
239 Eq, //< ==
240 Lt, //< <
241 Le, //< <=
242 Ge, //< >=
243 Gt, //< >
244 };
245
246 Comp op = Comp::Eq;
247 };
248 std::vector<MemberCountCondition> membercounts; //< conditions that match on member count
249
250 std::vector<std::string> notification_levels;
251
252 //! evaluate contains_display_name condition
253 bool check_displayname = false;
254
255 std::vector<actions::Action> actions; //< the actions to apply on match
256
257 [[nodiscard]] bool matches(
258 const std::unordered_map<std::string, std::string> &ev,
259 const PushRuleEvaluator::RoomContext &ctx,
260 const std::map<mtx::common::RelationType, RelatedEvents> &relatedEventsFlat) const
261 {
262 for (const auto &cond : membercounts) {
263 if (![&cond, &ctx] {
264 switch (cond.op) {
265 case MemberCountCondition::Eq:
266 return ctx.member_count == cond.count;
267 case MemberCountCondition::Le:
268 return ctx.member_count <= cond.count;
269 case MemberCountCondition::Ge:
270 return ctx.member_count >= cond.count;
271 case MemberCountCondition::Lt:
272 return ctx.member_count < cond.count;
273 case MemberCountCondition::Gt:
274 return ctx.member_count > cond.count;
275 default:
276 return false;
277 }
278 }())
279 return false;
280 }
281
282 if (!notification_levels.empty()) {
283 auto sender_ = ev.find(x: "sender");
284 if (sender_ == ev.end())
285 return false;
286
287 auto sender_level = ctx.power_levels.user_level(user_id: sender_->second);
288
289 for (const auto &n : notification_levels) {
290 if (sender_level < ctx.power_levels.notification_level(notification_key: n))
291 return false;
292 }
293 }
294
295 for (const auto &cond : patterns) {
296 if (!cond.matches(ev))
297 return false;
298 }
299
300 for (const auto &cond : related_event_patterns) {
301 bool matched = false;
302 for (const auto &[rel_type, rel_ev] : relatedEventsFlat) {
303 if (cond.rel_type == rel_type) {
304 for (const auto &e : rel_ev.events) {
305 if (cond.ev_match.field.empty() || !cond.ev_match.pattern ||
306 cond.ev_match.matches(ev: e)) {
307 matched = true;
308 break;
309 }
310 }
311 if (cond.include_fallbacks) {
312 for (const auto &e : rel_ev.fallbacks) {
313 if (cond.ev_match.field.empty() || !cond.ev_match.pattern ||
314 cond.ev_match.matches(ev: e)) {
315 matched = true;
316 break;
317 }
318 }
319 }
320 }
321 }
322 if (!matched)
323 return false;
324 }
325
326 if (check_displayname) {
327 if (ctx.user_display_name.empty())
328 return false;
329
330 if (auto it = ev.find(x: "content.body"); it != ev.end()) {
331 re2::RE2::Options opts;
332 opts.set_case_sensitive(false);
333
334 if (!re2::RE2::PartialMatch(
335 text: it->second,
336 re: re2::RE2("(\\W|^)" + re2::RE2::QuoteMeta(unquoted: ctx.user_display_name) +
337 "(\\W|$)",
338 opts)))
339 return false;
340 } else {
341 return false;
342 }
343 }
344
345 return true;
346 }
347 };
348
349 std::vector<OptimizedRule> override_;
350 std::unordered_map<std::string, OptimizedRule> room;
351 std::unordered_map<std::string, OptimizedRule> sender;
352 std::vector<OptimizedRule> content;
353 std::vector<OptimizedRule> underride;
354};
355
356static std::unique_ptr<re2::RE2>
357construct_re_from_pattern(std::string pat, const std::string &field)
358{
359 pat = re2::RE2::QuoteMeta(unquoted: pat);
360
361 // Quote also espaces the globs, so we need to match them including the backslash
362 static re2::RE2 matchGlobStar("\\*");
363 re2::RE2::GlobalReplace(str: &pat, re: matchGlobStar, rewrite: ".*");
364
365 static re2::RE2 matchGlobQuest("\\?");
366 re2::RE2::GlobalReplace(str: &pat, re: matchGlobQuest, rewrite: ".");
367
368 re2::RE2::Options opts;
369 opts.set_case_sensitive(false);
370
371 if (field == "content.body")
372 return std::make_unique<re2::RE2>(args: "(\\W|^)" + pat + "(\\W|$)", args&: opts);
373 else
374 return std::make_unique<re2::RE2>(args&: pat, args&: opts);
375}
376
377PushRuleEvaluator::~PushRuleEvaluator() = default;
378PushRuleEvaluator::PushRuleEvaluator(const Ruleset &rules_)
379 : rules(std::make_unique<OptimizedRules>())
380{
381 auto add_conditions_to_rule = [](OptimizedRules::OptimizedRule &rule,
382 const std::vector<PushCondition> &conditions) {
383 for (const auto &cond : conditions) {
384 if (cond.kind == "event_match") {
385 OptimizedRules::OptimizedRule::PatternCondition c;
386 c.field = cond.key;
387 c.pattern = construct_re_from_pattern(pat: cond.pattern, field: cond.key);
388 if (c.pattern)
389 rule.patterns.push_back(x: std::move(c));
390 } else if (cond.kind == "im.nheko.msc3664.related_event_match") {
391 OptimizedRules::OptimizedRule::RelatedEventCondition c;
392
393 if (cond.rel_type != mtx::common::RelationType::Unsupported) {
394 c.rel_type = cond.rel_type;
395 c.include_fallbacks = cond.include_fallback;
396
397 if (!cond.key.empty() && !cond.pattern.empty()) {
398 c.ev_match.field = cond.key;
399 c.ev_match.pattern = construct_re_from_pattern(pat: cond.pattern, field: cond.key);
400 }
401 rule.related_event_patterns.push_back(x: std::move(c));
402 } else {
403 mtx::utils::log::log()->info(
404 msg: "Skipping rel_event_match rule with unknown rel_type.");
405 return false;
406 }
407 } else if (cond.kind == "contains_display_name") {
408 rule.check_displayname = true;
409 } else if (cond.kind == "room_member_count") {
410 OptimizedRules::OptimizedRule::MemberCountCondition c;
411 std::string_view is = cond.is;
412 if (is.starts_with(x: "==")) {
413 c.op = c.Comp::Eq;
414 is = is.substr(pos: 2);
415 } else if (is.starts_with(x: ">=")) {
416 c.op = c.Comp::Ge;
417 is = is.substr(pos: 2);
418 } else if (is.starts_with(x: "<=")) {
419 c.op = c.Comp::Le;
420 is = is.substr(pos: 2);
421 } else if (is.starts_with(x: '<')) {
422 c.op = c.Comp::Lt;
423 is = is.substr(pos: 1);
424 } else if (is.starts_with(x: '>')) {
425 c.op = c.Comp::Gt;
426 is = is.substr(pos: 1);
427 }
428
429 std::from_chars(first: is.data(), last: is.data() + is.size(), value&: c.count);
430 rule.membercounts.push_back(x: c);
431 } else if (cond.kind == "sender_notification_permission") {
432 rule.notification_levels.push_back(x: cond.key);
433 } else {
434 mtx::utils::log::log()->info(fmt: "Skipping rule with unknown condition type: {}",
435 args: cond.kind);
436 return false;
437 }
438 }
439
440 return true;
441 };
442
443 for (const auto &rule_ : rules_.override_) {
444 if (!rule_.enabled)
445 continue;
446
447 OptimizedRules::OptimizedRule rule;
448 rule.actions = rule_.actions;
449
450 if (!add_conditions_to_rule(rule, rule_.conditions))
451 continue;
452
453 rules->override_.push_back(x: std::move(rule));
454 }
455
456 for (const auto &rule_ : rules_.underride) {
457 if (!rule_.enabled)
458 continue;
459
460 OptimizedRules::OptimizedRule rule;
461 rule.actions = rule_.actions;
462
463 if (!add_conditions_to_rule(rule, rule_.conditions))
464 continue;
465
466 rules->underride.push_back(x: std::move(rule));
467 }
468
469 for (const auto &rule_ : rules_.room) {
470 if (!rule_.enabled)
471 continue;
472
473 if (!rule_.rule_id.starts_with(x: "!"))
474 continue;
475
476 OptimizedRules::OptimizedRule rule;
477 rule.actions = rule_.actions;
478 rules->room[rule_.rule_id] = std::move(rule);
479 }
480
481 for (const auto &rule_ : rules_.sender) {
482 if (!rule_.enabled)
483 continue;
484
485 if (!rule_.rule_id.starts_with(x: "@"))
486 continue;
487
488 OptimizedRules::OptimizedRule rule;
489 rule.actions = rule_.actions;
490 rules->sender[rule_.rule_id] = std::move(rule);
491 }
492
493 for (const auto &rule_ : rules_.content) {
494 if (!rule_.enabled)
495 continue;
496
497 // Work around construct sending invalid content rules.
498 // Also this seems like just a sane thing to do, an empty pattern will always match and that
499 // is usually not what you want.
500 if (rule_.pattern.empty())
501 continue;
502
503 OptimizedRules::OptimizedRule rule;
504 rule.actions = rule_.actions;
505
506 std::vector<PushCondition> conditions{
507 PushCondition{.kind = "event_match", .key = "content.body", .pattern = rule_.pattern},
508 };
509
510 if (!add_conditions_to_rule(rule, conditions))
511 continue;
512
513 rules->content.push_back(x: std::move(rule));
514 }
515}
516
517static void
518flatten_impl(const nlohmann::json &value,
519 std::unordered_map<std::string, std::string> &result,
520 const std::string &current_path,
521 int current_depth)
522{
523 if (current_depth > 100)
524 return;
525
526 switch (value.type()) {
527 case nlohmann::json::value_t::object: {
528 // iterate object and use keys as reference string
529 std::string prefix;
530 if (!current_path.empty())
531 prefix = current_path + ".";
532 for (const auto &element : value.items()) {
533 flatten_impl(value: element.value(), result, current_path: prefix + element.key(), current_depth: current_depth + 1);
534 }
535 break;
536 }
537
538 case nlohmann::json::value_t::string: {
539 // add primitive value with its reference string
540 result[current_path] = value.get<std::string>();
541 break;
542 }
543
544 // currently we only match strings
545 case nlohmann::json::value_t::array:
546 case nlohmann::json::value_t::null:
547 case nlohmann::json::value_t::boolean:
548 case nlohmann::json::value_t::number_integer:
549 case nlohmann::json::value_t::number_unsigned:
550 case nlohmann::json::value_t::number_float:
551 case nlohmann::json::value_t::binary:
552 case nlohmann::json::value_t::discarded:
553 default:
554 break;
555 }
556}
557
558static std::unordered_map<std::string, std::string>
559flatten_event(const nlohmann::json &j)
560{
561 std::unordered_map<std::string, std::string> flat;
562 flatten_impl(value: j, result&: flat, current_path: "", current_depth: 0);
563 return flat;
564}
565
566std::vector<actions::Action>
567PushRuleEvaluator::evaluate(
568 const mtx::events::collections::TimelineEvent &event,
569 const RoomContext &ctx,
570 const std::vector<std::pair<mtx::common::Relation, mtx::events::collections::TimelineEvent>>
571 &relatedEvents) const
572{
573 auto event_json = nlohmann::json(event);
574 auto flat_event = flatten_event(j: event_json);
575
576 std::map<mtx::common::RelationType, RelatedEvents> relatedEventsFlat;
577 for (const auto &[rel, ev] : relatedEvents) {
578 if (rel.rel_type != mtx::common::RelationType::Unsupported) {
579 if (rel.is_fallback)
580 relatedEventsFlat[rel.rel_type].fallbacks.push_back(
581 x: flatten_event(j: nlohmann::json(ev)));
582 else
583 relatedEventsFlat[rel.rel_type].events.push_back(x: flatten_event(j: nlohmann::json(ev)));
584 }
585 }
586
587 for (const auto &rule : rules->override_) {
588 if (rule.matches(ev: flat_event, ctx, relatedEventsFlat))
589 return rule.actions;
590 }
591
592 for (const auto &rule : rules->content) {
593 if (rule.matches(ev: flat_event, ctx, relatedEventsFlat))
594 return rule.actions;
595 }
596
597 // room rule always matches if present
598 if (auto room_rule = rules->room.find(x: event_json.value(key: "room_id", default_value: ""));
599 room_rule != rules->room.end()) {
600 return room_rule->second.actions;
601 }
602
603 // sender rule always matches if present
604 if (auto sender_rule = rules->sender.find(x: event_json.value(key: "sender", default_value: ""));
605 sender_rule != rules->sender.end()) {
606 return sender_rule->second.actions;
607 }
608
609 for (const auto &rule : rules->underride) {
610 if (rule.matches(ev: flat_event, ctx, relatedEventsFlat))
611 return rule.actions;
612 }
613 return {};
614}
615
616}
617}
618