Message ID | 20241010174817.1543642-2-edumazet@google.com (mailing list archive) |
---|---|
State | Accepted |
Commit | 78e2baf3d96edd21c6f26d8afc0e68d02ec2c51c |
Delegated to: | Netdev Maintainers |
Series | tcp: add skb->sk to more control packets |
From: Eric Dumazet <edumazet@google.com>
Date: Thu, 10 Oct 2024 17:48:13 +0000

TCP will soon attach TIME_WAIT sockets to some ACK and RST.

Make sure sk_to_full_sk() detects this and does not return
a non full socket.

v3: also changed sk_const_to_full_sk()

Signed-off-by: Eric Dumazet <edumazet@google.com>
Reviewed-by: Kuniyuki Iwashima <kuniyu@amazon.com>
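For readers skimming the archive: the new sk_to_full_sk() mapping can be modeled in plain userspace C. This is a minimal sketch, assuming stand-in types, illustrative state-constant values, and hypothetical model_* names; it is not the kernel code (the real change is in the diff at the bottom of this thread):

#include <assert.h>
#include <stddef.h>

/* Stand-in state constants mirroring the TCP states involved. */
enum { TCP_ESTABLISHED = 1, TCP_TIME_WAIT = 6, TCP_LISTEN = 10,
       TCP_NEW_SYN_RECV = 12 };

/* Stand-in socket; in the kernel, rsk_listener lives in request_sock. */
struct sock { int sk_state; struct sock *rsk_listener; };

/* Model of the patched sk_to_full_sk() semantics:
 * - a request socket maps to its listener,
 * - a TIME_WAIT socket maps to NULL (the new behavior),
 * - anything else is returned unchanged.
 */
static struct sock *model_sk_to_full_sk(struct sock *sk)
{
	if (sk && sk->sk_state == TCP_NEW_SYN_RECV)
		sk = sk->rsk_listener;
	if (sk && sk->sk_state == TCP_TIME_WAIT)
		sk = NULL;
	return sk;
}

int main(void)
{
	struct sock listener = { TCP_LISTEN, NULL };
	struct sock req = { TCP_NEW_SYN_RECV, &listener };
	struct sock tw = { TCP_TIME_WAIT, NULL };
	struct sock est = { TCP_ESTABLISHED, NULL };

	assert(model_sk_to_full_sk(&req) == &listener);
	assert(model_sk_to_full_sk(&tw) == NULL);  /* new behavior */
	assert(model_sk_to_full_sk(&est) == &est);
	assert(model_sk_to_full_sk(NULL) == NULL);
	return 0;
}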
On 10/10/24 10:48 AM, Eric Dumazet wrote:
> TCP will soon attach TIME_WAIT sockets to some ACK and RST.
>
> Make sure sk_to_full_sk() detects this and does not return
> a non full socket.

Reviewed-by: Martin KaFai Lau <martin.lau@kernel.org>
Thanks Eric for the patch series! I left some comments inline

On Thu, Oct 10, 2024 at 1:48 PM Eric Dumazet <edumazet@google.com> wrote:
>
> TCP will soon attach TIME_WAIT sockets to some ACK and RST.
>
> Make sure sk_to_full_sk() detects this and does not return
> a non full socket.
>
> v3: also changed sk_const_to_full_sk()
>
> Signed-off-by: Eric Dumazet <edumazet@google.com>
[...]
> @@ -6778,8 +6778,6 @@ __bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
>  	/* sk_to_full_sk() may return (sk)->rsk_listener, so make sure the original sk
>  	 * sock refcnt is decremented to prevent a request_sock leak.
>  	 */
> -	if (!sk_fullsock(sk2))
> -		sk2 = NULL;

IIUC, we still want the condition above since sk_to_full_sk can return
the request socket in which case the helper should return NULL, so we
still need the refcnt decrement?

>  	if (sk2 != sk) {
>  		sock_gen_put(sk);
>  		/* Ensure there is no need to bump sk2 refcnt */
> @@ -6826,8 +6824,6 @@ bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
>  	/* sk_to_full_sk() may return (sk)->rsk_listener, so make sure the original sk
>  	 * sock refcnt is decremented to prevent a request_sock leak.
>  	 */
> -	if (!sk_fullsock(sk2))
> -		sk2 = NULL;

Same as above.
On Mon, Oct 14, 2024 at 4:01 PM Brian Vazquez <brianvv@google.com> wrote:
>
> Thanks Eric for the patch series! I left some comments inline
>
[...]
> > -	if (!sk_fullsock(sk2))
> > -		sk2 = NULL;
>
> IIUC, we still want the condition above since sk_to_full_sk can return
> the request socket in which case the helper should return NULL, so we
> still need the refcnt decrement?
>
> >  	if (sk2 != sk) {
> >  		sock_gen_put(sk);

Note that we call sock_gen_put(sk) here, not sock_gen_put(sk2);

sk is not NULL here, so if sk2 is NULL, we will take this branch.

[...]
> > -	if (!sk_fullsock(sk2))
> > -		sk2 = NULL;
>
> Same as above.

Should be fine I think.
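To make the point about sock_gen_put(sk) concrete: a minimal userspace model of the refcount handling at the tail of __bpf_sk_lookup(), assuming stand-in types, hypothetical model_* names, and a plain integer standing in for the socket refcount (sock_gen_put() and the real lookup do more than this):

#include <assert.h>
#include <stddef.h>

enum { TCP_TIME_WAIT = 6, TCP_LISTEN = 10, TCP_NEW_SYN_RECV = 12 };

struct sock { int sk_state; int refcnt; struct sock *rsk_listener; };

static struct sock *model_sk_to_full_sk(struct sock *sk)
{
	if (sk && sk->sk_state == TCP_NEW_SYN_RECV)
		sk = sk->rsk_listener;
	if (sk && sk->sk_state == TCP_TIME_WAIT)
		sk = NULL;
	return sk;
}

/* Model of the tail of __bpf_sk_lookup(): the lookup returned sk with
 * a reference held.  If the full-socket conversion gives back anything
 * else (including NULL), the original reference must be released.
 */
static struct sock *model_lookup_tail(struct sock *sk)
{
	struct sock *sk2 = model_sk_to_full_sk(sk);

	if (sk2 != sk)
		sk->refcnt--;	/* models sock_gen_put(sk): drop the lookup ref */
	return sk2;
}

int main(void)
{
	/* TIME_WAIT case: sk2 == NULL != sk, the branch is taken and the
	 * original reference is released -- no leak. */
	struct sock tw = { TCP_TIME_WAIT, 1, NULL };
	assert(model_lookup_tail(&tw) == NULL && tw.refcnt == 0);

	/* Request-socket case: sk2 is the listener, branch also taken. */
	struct sock listener = { TCP_LISTEN, 1, NULL };
	struct sock req = { TCP_NEW_SYN_RECV, 1, &listener };
	assert(model_lookup_tail(&req) == &listener && req.refcnt == 0);
	return 0;
}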
On Mon, Oct 14, 2024 at 10:28 AM Eric Dumazet <edumazet@google.com> wrote:
>
[...]
> > >  	if (sk2 != sk) {
> > >  		sock_gen_put(sk);
>
> Note that we call sock_gen_put(sk) here, not sock_gen_put(sk2);
>
> sk is not NULL here, so if sk2 is NULL, we will take this branch.

IIUC __bpf_sk_lookup calls __bpf_skc_lookup which can return a request
listener socket and takes a refcnt, but __bpf_sk_lookup should only
return full_sk (no request nor time_wait).

That's why the function tries to detect whether req or time_wait was
retrieved by __bpf_skc_lookup and if so, we invalidate the return:
sk = NULL, and decrement the refcnt. This is done by having sk2 and
then comparing vs sk, and if sk2 is invalid because time_wait or
listener, then we decrement sk (the original return from
__bpf_skc_lookup, which took a refcnt)

I agree that after the change to sk_to_full_sk, for time_wait it will
return NULL, hence the condition is repetitive.

	if (!sk_fullsock(sk2))
		sk2 = NULL;

but sk_to_full_sk can still retrieve the listener:
sk = inet_reqsk(sk)->rsk_listener; in which case we would like to
still use

	if (!sk_fullsock(sk2))
		sk2 = NULL;

to invalidate the request socket, decrement the refcount and
sk = sk2; // which makes sk == NULL?

I think removing that condition allows __bpf_sk_lookup to return the
req socket, which wasn't possible before?
On Mon, Oct 14, 2024 at 5:03 PM Brian Vazquez <brianvv@google.com> wrote:
>
[...]
>
> I think removing that condition allows __bpf_sk_lookup to return the
> req socket, which wasn't possible before?

It was not possible before, and not possible after :

static inline struct sock *sk_to_full_sk(struct sock *sk)
{
#ifdef CONFIG_INET
	if (sk && READ_ONCE(sk->sk_state) == TCP_NEW_SYN_RECV)
		sk = inet_reqsk(sk)->rsk_listener;
	if (sk && READ_ONCE(sk->sk_state) == TCP_TIME_WAIT) // NEW CODE
		sk = NULL;                                   // NEW CODE
#endif
	return sk;
}

if sk was a request socket, sk2 would be the listener.

sk2 being a listener means that sk_fullsock(sk2) is true.

	if (!sk_fullsock(sk2))
		sk2 = NULL;

So really this check was only meant for TIME_WAIT, and it is now done
directly from sk_to_full_sk()

Therefore we can delete this dead code.
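The dead-code claim can be checked exhaustively over the states involved. A minimal sketch, assuming sk_fullsock() amounts to "neither a request socket nor TIME_WAIT" and reusing the same stand-in types; all model_* names and constant values here are illustrative:

#include <assert.h>
#include <stddef.h>

enum { TCP_ESTABLISHED = 1, TCP_TIME_WAIT = 6, TCP_LISTEN = 10,
       TCP_NEW_SYN_RECV = 12 };

struct sock { int sk_state; struct sock *rsk_listener; };

/* Assumption: a full socket is anything that is neither a request
 * socket nor a TIME_WAIT minisock. */
static int model_sk_fullsock(const struct sock *sk)
{
	return sk->sk_state != TCP_NEW_SYN_RECV &&
	       sk->sk_state != TCP_TIME_WAIT;
}

static struct sock *model_sk_to_full_sk(struct sock *sk)
{
	if (sk && sk->sk_state == TCP_NEW_SYN_RECV)
		sk = sk->rsk_listener;
	if (sk && sk->sk_state == TCP_TIME_WAIT)
		sk = NULL;
	return sk;
}

int main(void)
{
	struct sock listener = { TCP_LISTEN, NULL };
	struct sock socks[] = {
		{ TCP_ESTABLISHED, NULL },
		{ TCP_LISTEN, NULL },
		{ TCP_NEW_SYN_RECV, &listener },
		{ TCP_TIME_WAIT, NULL },
	};

	/* Whatever state goes in, the result is NULL or a full socket,
	 * so a later !sk_fullsock(sk2) check can never fire. */
	for (unsigned i = 0; i < sizeof(socks) / sizeof(socks[0]); i++) {
		struct sock *sk2 = model_sk_to_full_sk(&socks[i]);
		assert(!sk2 || model_sk_fullsock(sk2));
	}
	return 0;
}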
On Mon, Oct 14, 2024 at 11:24 AM Eric Dumazet <edumazet@google.com> wrote:
>
[...]
> It was not possible before, and not possible after :
[...]
> if sk was a request socket, sk2 would be the listener.

This is the part that I missed, I got misled by the comment above the
dead code. Thanks for clarifying!

> sk2 being a listener means that sk_fullsock(sk2) is true.
>
> 	if (!sk_fullsock(sk2))
> 		sk2 = NULL;
>
> So really this check was only meant for TIME_WAIT, and it is now done
> directly from sk_to_full_sk()
>
> Therefore we can delete this dead code.

Reviewed-by: Brian Vazquez <brianvv@google.com>
diff --git a/include/linux/bpf-cgroup.h b/include/linux/bpf-cgroup.h
index ce91d9b2acb9f8991150ceead4475b130bead438..f0f219271daf4afea2666c4d09fd4d1a8091f844 100644
--- a/include/linux/bpf-cgroup.h
+++ b/include/linux/bpf-cgroup.h
@@ -209,7 +209,7 @@ static inline bool cgroup_bpf_sock_enabled(struct sock *sk,
 	int __ret = 0;							\
 	if (cgroup_bpf_enabled(CGROUP_INET_EGRESS) && sk) {		\
 		typeof(sk) __sk = sk_to_full_sk(sk);			\
-		if (sk_fullsock(__sk) && __sk == skb_to_full_sk(skb) &&	\
+		if (__sk && __sk == skb_to_full_sk(skb) &&		\
 		    cgroup_bpf_sock_enabled(__sk, CGROUP_INET_EGRESS))	\
 			__ret = __cgroup_bpf_run_filter_skb(__sk, skb,	\
 						      CGROUP_INET_EGRESS); \
diff --git a/include/net/inet_sock.h b/include/net/inet_sock.h
index f01dd273bea69d2eaf7a1d28274d7f980942b78a..56d8bc5593d3dfffd5f94cf7c6383948881917df 100644
--- a/include/net/inet_sock.h
+++ b/include/net/inet_sock.h
@@ -321,8 +321,10 @@ static inline unsigned long inet_cmsg_flags(const struct inet_sock *inet)
 static inline struct sock *sk_to_full_sk(struct sock *sk)
 {
 #ifdef CONFIG_INET
-	if (sk && sk->sk_state == TCP_NEW_SYN_RECV)
+	if (sk && READ_ONCE(sk->sk_state) == TCP_NEW_SYN_RECV)
 		sk = inet_reqsk(sk)->rsk_listener;
+	if (sk && READ_ONCE(sk->sk_state) == TCP_TIME_WAIT)
+		sk = NULL;
 #endif
 	return sk;
 }
@@ -331,8 +333,10 @@ static inline struct sock *sk_to_full_sk(struct sock *sk)
 static inline const struct sock *sk_const_to_full_sk(const struct sock *sk)
 {
 #ifdef CONFIG_INET
-	if (sk && sk->sk_state == TCP_NEW_SYN_RECV)
+	if (sk && READ_ONCE(sk->sk_state) == TCP_NEW_SYN_RECV)
 		sk = ((const struct request_sock *)sk)->rsk_listener;
+	if (sk && READ_ONCE(sk->sk_state) == TCP_TIME_WAIT)
+		sk = NULL;
 #endif
 	return sk;
 }
diff --git a/net/core/filter.c b/net/core/filter.c
index bd0d08bf76bb8de39ca2ca89cda99a97c9b0a034..202c1d386e19599e9fc6e0a0d4a95986ba6d0ea8 100644
--- a/net/core/filter.c
+++ b/net/core/filter.c
@@ -6778,8 +6778,6 @@ __bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
 	/* sk_to_full_sk() may return (sk)->rsk_listener, so make sure the original sk
 	 * sock refcnt is decremented to prevent a request_sock leak.
 	 */
-	if (!sk_fullsock(sk2))
-		sk2 = NULL;
 	if (sk2 != sk) {
 		sock_gen_put(sk);
 		/* Ensure there is no need to bump sk2 refcnt */
@@ -6826,8 +6824,6 @@ bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
 	/* sk_to_full_sk() may return (sk)->rsk_listener, so make sure the original sk
 	 * sock refcnt is decremented to prevent a request_sock leak.
 	 */
-	if (!sk_fullsock(sk2))
-		sk2 = NULL;
 	if (sk2 != sk) {
 		sock_gen_put(sk);
 		/* Ensure there is no need to bump sk2 refcnt */
@@ -7276,7 +7272,7 @@ BPF_CALL_1(bpf_get_listener_sock, struct sock *, sk)
 {
 	sk = sk_to_full_sk(sk);
 
-	if (sk->sk_state == TCP_LISTEN && sock_flag(sk, SOCK_RCU_FREE))
+	if (sk && sk->sk_state == TCP_LISTEN && sock_flag(sk, SOCK_RCU_FREE))
 		return (unsigned long)sk;
 
 	return (unsigned long)NULL;
TCP will soon attach TIME_WAIT sockets to some ACK and RST.

Make sure sk_to_full_sk() detects this and does not return
a non full socket.

v3: also changed sk_const_to_full_sk()

Signed-off-by: Eric Dumazet <edumazet@google.com>
---
 include/linux/bpf-cgroup.h | 2 +-
 include/net/inet_sock.h    | 8 ++++++--
 net/core/filter.c          | 6 +-----
 3 files changed, 8 insertions(+), 8 deletions(-)
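One hunk in the diff above is worth a closing note: bpf_get_listener_sock() gains a NULL check because sk_to_full_sk() can now return NULL for a TIME_WAIT input. A minimal model of the patched behavior, with stand-in types as in the earlier sketches (the SOCK_RCU_FREE test is deliberately omitted here):

#include <assert.h>
#include <stddef.h>

enum { TCP_TIME_WAIT = 6, TCP_LISTEN = 10, TCP_NEW_SYN_RECV = 12 };

struct sock { int sk_state; struct sock *rsk_listener; };

static struct sock *model_sk_to_full_sk(struct sock *sk)
{
	if (sk && sk->sk_state == TCP_NEW_SYN_RECV)
		sk = sk->rsk_listener;
	if (sk && sk->sk_state == TCP_TIME_WAIT)
		sk = NULL;
	return sk;
}

/* Model of the patched bpf_get_listener_sock(): the added "sk &&"
 * keeps the sk_state dereference safe when a TIME_WAIT socket comes
 * in and the conversion yields NULL. */
static struct sock *model_get_listener_sock(struct sock *sk)
{
	sk = model_sk_to_full_sk(sk);
	if (sk && sk->sk_state == TCP_LISTEN)
		return sk;
	return NULL;
}

int main(void)
{
	struct sock tw = { TCP_TIME_WAIT, NULL };
	struct sock listener = { TCP_LISTEN, NULL };

	/* Without the added NULL check, the TIME_WAIT case would
	 * dereference a NULL pointer. */
	assert(model_get_listener_sock(&tw) == NULL);
	assert(model_get_listener_sock(&listener) == &listener);
	return 0;
}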