@@ -42,17 +42,29 @@
/* tracepoints with more than 12 arguments will hit build error */
#define CAST_TO_U64(...) CONCATENATE(__CAST, COUNT_ARGS(__VA_ARGS__))(__VA_ARGS__)
-#define __BPF_DECLARE_TRACE(call, proto, args) \
+#define __BPF_DECLARE_TRACE(call, proto, args, tp_flags) \
static notrace void \
__bpf_trace_##call(void *__data, proto) \
{ \
struct bpf_prog *prog = __data; \
+ \
+ DEFINE_INACTIVE_GUARD(preempt_notrace, bpf_trace_guard); \
+ \
+ if ((tp_flags) & TRACEPOINT_MAY_FAULT) { \
+ might_fault(); \
+ activate_guard(preempt_notrace, bpf_trace_guard)(); \
+ } \
+ \
CONCATENATE(bpf_trace_run, COUNT_ARGS(args))(prog, CAST_TO_U64(args)); \
}
#undef DECLARE_EVENT_CLASS
#define DECLARE_EVENT_CLASS(call, proto, args, tstruct, assign, print) \
- __BPF_DECLARE_TRACE(call, PARAMS(proto), PARAMS(args))
+ __BPF_DECLARE_TRACE(call, PARAMS(proto), PARAMS(args), 0)
+
+#undef DECLARE_EVENT_CLASS_MAY_FAULT
+#define DECLARE_EVENT_CLASS_MAY_FAULT(call, proto, args, tstruct, assign, print) \
+ __BPF_DECLARE_TRACE(call, PARAMS(proto), PARAMS(args), TRACEPOINT_MAY_FAULT)
/*
* This part is compiled out, it is only here as a build time check
@@ -106,13 +118,13 @@ static inline void bpf_test_buffer_##call(void) \
#undef DECLARE_TRACE
#define DECLARE_TRACE(call, proto, args) \
- __BPF_DECLARE_TRACE(call, PARAMS(proto), PARAMS(args)) \
+ __BPF_DECLARE_TRACE(call, PARAMS(proto), PARAMS(args), 0) \
__DEFINE_EVENT(call, call, PARAMS(proto), PARAMS(args), 0)
#undef DECLARE_TRACE_WRITABLE
#define DECLARE_TRACE_WRITABLE(call, proto, args, size) \
__CHECK_WRITABLE_BUF_SIZE(call, PARAMS(proto), PARAMS(args), size) \
- __BPF_DECLARE_TRACE(call, PARAMS(proto), PARAMS(args)) \
+ __BPF_DECLARE_TRACE(call, PARAMS(proto), PARAMS(args), 0) \
__DEFINE_EVENT(call, call, PARAMS(proto), PARAMS(args), size)
#include TRACE_INCLUDE(TRACE_INCLUDE_FILE)
@@ -2443,9 +2443,15 @@ static int __bpf_probe_register(struct bpf_raw_event_map *btp, struct bpf_prog *
if (prog->aux->max_tp_access > btp->writable_size)
return -EINVAL;
- return tracepoint_probe_register_prio_flags(tp, (void *)btp->bpf_func,
- prog, TRACEPOINT_DEFAULT_PRIO,
- TRACEPOINT_MAY_EXIST);
+ if (tp->flags & TRACEPOINT_MAY_FAULT) {
+ return tracepoint_probe_register_prio_flags(tp, (void *)btp->bpf_func,
+ prog, TRACEPOINT_DEFAULT_PRIO,
+ TRACEPOINT_MAY_EXIST | TRACEPOINT_MAY_FAULT);
+ } else {
+ return tracepoint_probe_register_prio_flags(tp, (void *)btp->bpf_func,
+ prog, TRACEPOINT_DEFAULT_PRIO,
+ TRACEPOINT_MAY_EXIST);
+ }
}
int bpf_probe_register(struct bpf_raw_event_map *btp, struct bpf_prog *prog)