@@ -540,6 +540,8 @@ struct sock {
int (*sk_backlog_rcv)(struct sock *sk,
struct sk_buff *skb);
void (*sk_destruct)(struct sock *sk);
+ void (*sk_lock_sock)(struct sock *sk);
+ void (*sk_release_sock)(struct sock *sk);
struct sock_reuseport __rcu *sk_reuseport_cb;
#ifdef CONFIG_BPF_SYSCALL
struct bpf_local_storage __rcu *sk_bpf_storage;
@@ -1843,10 +1843,10 @@ int __cgroup_bpf_run_filter_setsockopt(struct sock *sk, int *level,
goto out;
}
- lock_sock(sk);
+ sk->sk_lock_sock ? sk->sk_lock_sock(sk) : lock_sock(sk);
ret = bpf_prog_run_array_cg(&cgrp->bpf, CGROUP_SETSOCKOPT,
&ctx, bpf_prog_run, 0, NULL);
- release_sock(sk);
+ sk->sk_release_sock ? sk->sk_release_sock(sk) : release_sock(sk);
if (ret)
goto out;
@@ -1952,10 +1952,10 @@ int __cgroup_bpf_run_filter_getsockopt(struct sock *sk, int level,
}
}
- lock_sock(sk);
+ sk->sk_lock_sock ? sk->sk_lock_sock(sk) : lock_sock(sk);
ret = bpf_prog_run_array_cg(&cgrp->bpf, CGROUP_GETSOCKOPT,
&ctx, bpf_prog_run, retval, NULL);
- release_sock(sk);
+ sk->sk_release_sock ? sk->sk_release_sock(sk) : release_sock(sk);
if (ret < 0)
goto out;
@@ -2712,6 +2712,18 @@ static void mptcp_worker(struct work_struct *work)
sock_put(sk);
}
+static void mptcp_sk_lock_sock(struct sock *sk)
+{
+ lock_sock(sk);
+ mptcp_set_bpf_iter_task(mptcp_sk(sk));
+}
+
+static void mptcp_sk_release_sock(struct sock *sk)
+{
+ mptcp_clear_bpf_iter_task(mptcp_sk(sk));
+ release_sock(sk);
+}
+
static void __mptcp_init_sock(struct sock *sk)
{
struct mptcp_sock *msk = mptcp_sk(sk);
@@ -2741,6 +2753,9 @@ static void __mptcp_init_sock(struct sock *sk)
/* re-use the csk retrans timer for MPTCP-level retrans */
timer_setup(&msk->sk.icsk_retransmit_timer, mptcp_retransmit_timer, 0);
timer_setup(&sk->sk_timer, mptcp_tout_timer, 0);
+
+ sk->sk_lock_sock = mptcp_sk_lock_sock;
+ sk->sk_release_sock = mptcp_sk_release_sock;
}
static void mptcp_ca_reset(struct sock *sk)