@@ -106,6 +106,7 @@ struct sk_psock {
struct mutex work_mutex;
struct sk_psock_work_state work_state;
struct delayed_work work;
+ struct sock *skpair;
struct rcu_work rwork;
};
@@ -75,6 +75,7 @@ struct unix_sock {
};
#define unix_sk(ptr) container_of_const(ptr, struct unix_sock, sk)
+#define unix_peer(sk) (unix_sk(sk)->peer)
#define peer_wait peer_wq.wait
@@ -826,6 +826,8 @@ static void sk_psock_destroy(struct work_struct *work)
if (psock->sk_redir)
sock_put(psock->sk_redir);
+ if (psock->skpair)
+ sock_put(psock->skpair);
sock_put(psock->sk);
kfree(psock);
}
@@ -213,8 +213,6 @@ static inline bool unix_secdata_eq(struct scm_cookie *scm, struct sk_buff *skb)
}
#endif /* CONFIG_SECURITY_NETWORK */
-#define unix_peer(sk) (unix_sk(sk)->peer)
-
static inline int unix_our_peer(struct sock *sk, struct sock *osk)
{
return unix_peer(osk) == sk;
@@ -159,12 +159,17 @@ int unix_dgram_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool re
int unix_stream_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
{
+ struct sock *skpair;
+
if (restore) {
sk->sk_write_space = psock->saved_write_space;
sock_replace_proto(sk, psock->sk_proto);
return 0;
}
+ skpair = unix_peer(sk);
+ sock_hold(skpair);
+ psock->skpair = skpair;
unix_stream_bpf_check_needs_rebuild(psock->sk_proto);
sock_replace_proto(sk, &unix_stream_bpf_prot);
return 0;