diff mbox series

[v3,net-next,1/5] net: add TIME_WAIT logic to sk_to_full_sk()

Message ID 20241010174817.1543642-2-edumazet@google.com (mailing list archive)
State Accepted
Commit 78e2baf3d96edd21c6f26d8afc0e68d02ec2c51c
Delegated to: Netdev Maintainers
Headers show
Series tcp: add skb->sk to more control packets | expand

Checks

Context Check Description
netdev/series_format success Posting correctly formatted
netdev/tree_selection success Clearly marked for net-next, async
netdev/ynl success Generated files up to date; no warnings/errors; no diff in generated;
netdev/fixes_present success Fixes tag not required for -next series
netdev/header_inline success No static functions without inline keyword in header files
netdev/build_32bit success Errors and warnings before: 17 this patch: 17
netdev/build_tools success Errors and warnings before: 0 (+1) this patch: 0 (+1)
netdev/cc_maintainers warning 13 maintainers not CCed: song@kernel.org haoluo@google.com ast@kernel.org andrii@kernel.org john.fastabend@gmail.com sdf@fomichev.me martin.lau@linux.dev daniel@iogearbox.net bpf@vger.kernel.org kpsingh@kernel.org yonghong.song@linux.dev eddyz87@gmail.com jolsa@kernel.org
netdev/build_clang success Errors and warnings before: 39 this patch: 39
netdev/verify_signedoff success Signed-off-by tag matches author and committer
netdev/deprecated_api success None detected
netdev/check_selftest success No net selftest shell script
netdev/verify_fixes success No Fixes tag
netdev/build_allmodconfig_warn success Errors and warnings before: 2157 this patch: 2157
netdev/checkpatch success total: 0 errors, 0 warnings, 0 checks, 54 lines checked
netdev/build_clang_rust success No Rust files in patch. Skipping build
netdev/kdoc success Errors and warnings before: 8 this patch: 8
netdev/source_inline success Was 0 now: 0
netdev/contest success net-next-2024-10-12--12-00 (tests: 777)

Commit Message

Eric Dumazet Oct. 10, 2024, 5:48 p.m. UTC
TCP will soon attach TIME_WAIT sockets to some ACK and RST.

Make sure sk_to_full_sk() detects this and does not return
a non-full socket.

v3: also changed sk_const_to_full_sk()

Signed-off-by: Eric Dumazet <edumazet@google.com>
---
 include/linux/bpf-cgroup.h | 2 +-
 include/net/inet_sock.h    | 8 ++++++--
 net/core/filter.c          | 6 +-----
 3 files changed, 8 insertions(+), 8 deletions(-)

Comments

Kuniyuki Iwashima Oct. 11, 2024, 11:20 p.m. UTC | #1
From: Eric Dumazet <edumazet@google.com>
Date: Thu, 10 Oct 2024 17:48:13 +0000
> TCP will soon attach TIME_WAIT sockets to some ACK and RST.
> 
> Make sure sk_to_full_sk() detects this and does not return
> a non full socket.
> 
> v3: also changed sk_const_to_full_sk()
> 
> Signed-off-by: Eric Dumazet <edumazet@google.com>

Reviewed-by: Kuniyuki Iwashima <kuniyu@amazon.com>
Martin KaFai Lau Oct. 12, 2024, 3:32 a.m. UTC | #2
On 10/10/24 10:48 AM, Eric Dumazet wrote:
> TCP will soon attach TIME_WAIT sockets to some ACK and RST.
> 
> Make sure sk_to_full_sk() detects this and does not return
> a non full socket.

Reviewed-by: Martin KaFai Lau <martin.lau@kernel.org>
Brian Vazquez Oct. 14, 2024, 2:01 p.m. UTC | #3
Thanks Eric for the patch series!  I left some comments inline


On Thu, Oct 10, 2024 at 1:48 PM Eric Dumazet <edumazet@google.com> wrote:
>
> TCP will soon attach TIME_WAIT sockets to some ACK and RST.
>
> Make sure sk_to_full_sk() detects this and does not return
> a non full socket.
>
> v3: also changed sk_const_to_full_sk()
>
> Signed-off-by: Eric Dumazet <edumazet@google.com>
> ---
>  include/linux/bpf-cgroup.h | 2 +-
>  include/net/inet_sock.h    | 8 ++++++--
>  net/core/filter.c          | 6 +-----
>  3 files changed, 8 insertions(+), 8 deletions(-)
>
> diff --git a/include/linux/bpf-cgroup.h b/include/linux/bpf-cgroup.h
> index ce91d9b2acb9f8991150ceead4475b130bead438..f0f219271daf4afea2666c4d09fd4d1a8091f844 100644
> --- a/include/linux/bpf-cgroup.h
> +++ b/include/linux/bpf-cgroup.h
> @@ -209,7 +209,7 @@ static inline bool cgroup_bpf_sock_enabled(struct sock *sk,
>         int __ret = 0;                                                         \
>         if (cgroup_bpf_enabled(CGROUP_INET_EGRESS) && sk) {                    \
>                 typeof(sk) __sk = sk_to_full_sk(sk);                           \
> -               if (sk_fullsock(__sk) && __sk == skb_to_full_sk(skb) &&        \
> +               if (__sk && __sk == skb_to_full_sk(skb) &&             \
>                     cgroup_bpf_sock_enabled(__sk, CGROUP_INET_EGRESS))         \
>                         __ret = __cgroup_bpf_run_filter_skb(__sk, skb,         \
>                                                       CGROUP_INET_EGRESS); \
> diff --git a/include/net/inet_sock.h b/include/net/inet_sock.h
> index f01dd273bea69d2eaf7a1d28274d7f980942b78a..56d8bc5593d3dfffd5f94cf7c6383948881917df 100644
> --- a/include/net/inet_sock.h
> +++ b/include/net/inet_sock.h
> @@ -321,8 +321,10 @@ static inline unsigned long inet_cmsg_flags(const struct inet_sock *inet)
>  static inline struct sock *sk_to_full_sk(struct sock *sk)
>  {
>  #ifdef CONFIG_INET
> -       if (sk && sk->sk_state == TCP_NEW_SYN_RECV)
> +       if (sk && READ_ONCE(sk->sk_state) == TCP_NEW_SYN_RECV)
>                 sk = inet_reqsk(sk)->rsk_listener;
> +       if (sk && READ_ONCE(sk->sk_state) == TCP_TIME_WAIT)
> +               sk = NULL;
>  #endif
>         return sk;
>  }
> @@ -331,8 +333,10 @@ static inline struct sock *sk_to_full_sk(struct sock *sk)
>  static inline const struct sock *sk_const_to_full_sk(const struct sock *sk)
>  {
>  #ifdef CONFIG_INET
> -       if (sk && sk->sk_state == TCP_NEW_SYN_RECV)
> +       if (sk && READ_ONCE(sk->sk_state) == TCP_NEW_SYN_RECV)
>                 sk = ((const struct request_sock *)sk)->rsk_listener;
> +       if (sk && READ_ONCE(sk->sk_state) == TCP_TIME_WAIT)
> +               sk = NULL;
>  #endif
>         return sk;
>  }
> diff --git a/net/core/filter.c b/net/core/filter.c
> index bd0d08bf76bb8de39ca2ca89cda99a97c9b0a034..202c1d386e19599e9fc6e0a0d4a95986ba6d0ea8 100644
> --- a/net/core/filter.c
> +++ b/net/core/filter.c
> @@ -6778,8 +6778,6 @@ __bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
>                 /* sk_to_full_sk() may return (sk)->rsk_listener, so make sure the original sk
>                  * sock refcnt is decremented to prevent a request_sock leak.
>                  */
> -               if (!sk_fullsock(sk2))
> -                       sk2 = NULL;

IIUC, we still want the condition above since sk_to_full_sk can return
the request socket in which case the helper should return NULL, so we
still need the refcnt decrement?

>                 if (sk2 != sk) {
>                         sock_gen_put(sk);
>                         /* Ensure there is no need to bump sk2 refcnt */
> @@ -6826,8 +6824,6 @@ bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
>                 /* sk_to_full_sk() may return (sk)->rsk_listener, so make sure the original sk
>                  * sock refcnt is decremented to prevent a request_sock leak.
>                  */
> -               if (!sk_fullsock(sk2))
> -                       sk2 = NULL;

Same as above.

>                 if (sk2 != sk) {
>                         sock_gen_put(sk);
>                         /* Ensure there is no need to bump sk2 refcnt */
> @@ -7276,7 +7272,7 @@ BPF_CALL_1(bpf_get_listener_sock, struct sock *, sk)
>  {
>         sk = sk_to_full_sk(sk);
>
> -       if (sk->sk_state == TCP_LISTEN && sock_flag(sk, SOCK_RCU_FREE))
> +       if (sk && sk->sk_state == TCP_LISTEN && sock_flag(sk, SOCK_RCU_FREE))
>                 return (unsigned long)sk;
>
>         return (unsigned long)NULL;
> --
> 2.47.0.rc1.288.g06298d1525-goog
>
Eric Dumazet Oct. 14, 2024, 2:27 p.m. UTC | #4
On Mon, Oct 14, 2024 at 4:01 PM Brian Vazquez <brianvv@google.com> wrote:
>
> Thanks Eric for the patch series!  I left some comments inline
>
>
> On Thu, Oct 10, 2024 at 1:48 PM Eric Dumazet <edumazet@google.com> wrote:
> >
> > TCP will soon attach TIME_WAIT sockets to some ACK and RST.
> >
> > Make sure sk_to_full_sk() detects this and does not return
> > a non full socket.
> >
> > v3: also changed sk_const_to_full_sk()
> >
> > Signed-off-by: Eric Dumazet <edumazet@google.com>
> > ---
> >  include/linux/bpf-cgroup.h | 2 +-
> >  include/net/inet_sock.h    | 8 ++++++--
> >  net/core/filter.c          | 6 +-----
> >  3 files changed, 8 insertions(+), 8 deletions(-)
> >
> > diff --git a/include/linux/bpf-cgroup.h b/include/linux/bpf-cgroup.h
> > index ce91d9b2acb9f8991150ceead4475b130bead438..f0f219271daf4afea2666c4d09fd4d1a8091f844 100644
> > --- a/include/linux/bpf-cgroup.h
> > +++ b/include/linux/bpf-cgroup.h
> > @@ -209,7 +209,7 @@ static inline bool cgroup_bpf_sock_enabled(struct sock *sk,
> >         int __ret = 0;                                                         \
> >         if (cgroup_bpf_enabled(CGROUP_INET_EGRESS) && sk) {                    \
> >                 typeof(sk) __sk = sk_to_full_sk(sk);                           \
> > -               if (sk_fullsock(__sk) && __sk == skb_to_full_sk(skb) &&        \
> > +               if (__sk && __sk == skb_to_full_sk(skb) &&             \
> >                     cgroup_bpf_sock_enabled(__sk, CGROUP_INET_EGRESS))         \
> >                         __ret = __cgroup_bpf_run_filter_skb(__sk, skb,         \
> >                                                       CGROUP_INET_EGRESS); \
> > diff --git a/include/net/inet_sock.h b/include/net/inet_sock.h
> > index f01dd273bea69d2eaf7a1d28274d7f980942b78a..56d8bc5593d3dfffd5f94cf7c6383948881917df 100644
> > --- a/include/net/inet_sock.h
> > +++ b/include/net/inet_sock.h
> > @@ -321,8 +321,10 @@ static inline unsigned long inet_cmsg_flags(const struct inet_sock *inet)
> >  static inline struct sock *sk_to_full_sk(struct sock *sk)
> >  {
> >  #ifdef CONFIG_INET
> > -       if (sk && sk->sk_state == TCP_NEW_SYN_RECV)
> > +       if (sk && READ_ONCE(sk->sk_state) == TCP_NEW_SYN_RECV)
> >                 sk = inet_reqsk(sk)->rsk_listener;
> > +       if (sk && READ_ONCE(sk->sk_state) == TCP_TIME_WAIT)
> > +               sk = NULL;
> >  #endif
> >         return sk;
> >  }
> > @@ -331,8 +333,10 @@ static inline struct sock *sk_to_full_sk(struct sock *sk)
> >  static inline const struct sock *sk_const_to_full_sk(const struct sock *sk)
> >  {
> >  #ifdef CONFIG_INET
> > -       if (sk && sk->sk_state == TCP_NEW_SYN_RECV)
> > +       if (sk && READ_ONCE(sk->sk_state) == TCP_NEW_SYN_RECV)
> >                 sk = ((const struct request_sock *)sk)->rsk_listener;
> > +       if (sk && READ_ONCE(sk->sk_state) == TCP_TIME_WAIT)
> > +               sk = NULL;
> >  #endif
> >         return sk;
> >  }
> > diff --git a/net/core/filter.c b/net/core/filter.c
> > index bd0d08bf76bb8de39ca2ca89cda99a97c9b0a034..202c1d386e19599e9fc6e0a0d4a95986ba6d0ea8 100644
> > --- a/net/core/filter.c
> > +++ b/net/core/filter.c
> > @@ -6778,8 +6778,6 @@ __bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
> >                 /* sk_to_full_sk() may return (sk)->rsk_listener, so make sure the original sk
> >                  * sock refcnt is decremented to prevent a request_sock leak.
> >                  */
> > -               if (!sk_fullsock(sk2))
> > -                       sk2 = NULL;
>
> IIUC, we still want the condition above since sk_to_full_sk can return
> the request socket in which case the helper should return NULL, so we
> still need the refcnt decrement?
>
> >                 if (sk2 != sk) {
> >                         sock_gen_put(sk);

Note that we call sock_gen_put(sk) here, not sock_gen_put(sk2);

sk is not NULL here, so if sk2 is NULL, we will take this branch.

> >                         /* Ensure there is no need to bump sk2 refcnt */
> > @@ -6826,8 +6824,6 @@ bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
> >                 /* sk_to_full_sk() may return (sk)->rsk_listener, so make sure the original sk
> >                  * sock refcnt is decremented to prevent a request_sock leak.
> >                  */
> > -               if (!sk_fullsock(sk2))
> > -                       sk2 = NULL;
>
> Same as above.

Should be fine I think.
Brian Vazquez Oct. 14, 2024, 3:08 p.m. UTC | #5
On Mon, Oct 14, 2024 at 10:28 AM Eric Dumazet <edumazet@google.com> wrote:
>
> On Mon, Oct 14, 2024 at 4:01 PM Brian Vazquez <brianvv@google.com> wrote:
> >
> > Thanks Eric for the patch series!  I left some comments inline
> >
> >
> > On Thu, Oct 10, 2024 at 1:48 PM Eric Dumazet <edumazet@google.com> wrote:
> > >
> > > TCP will soon attach TIME_WAIT sockets to some ACK and RST.
> > >
> > > Make sure sk_to_full_sk() detects this and does not return
> > > a non full socket.
> > >
> > > v3: also changed sk_const_to_full_sk()
> > >
> > > Signed-off-by: Eric Dumazet <edumazet@google.com>
> > > ---
> > >  include/linux/bpf-cgroup.h | 2 +-
> > >  include/net/inet_sock.h    | 8 ++++++--
> > >  net/core/filter.c          | 6 +-----
> > >  3 files changed, 8 insertions(+), 8 deletions(-)
> > >
> > > diff --git a/include/linux/bpf-cgroup.h b/include/linux/bpf-cgroup.h
> > > index ce91d9b2acb9f8991150ceead4475b130bead438..f0f219271daf4afea2666c4d09fd4d1a8091f844 100644
> > > --- a/include/linux/bpf-cgroup.h
> > > +++ b/include/linux/bpf-cgroup.h
> > > @@ -209,7 +209,7 @@ static inline bool cgroup_bpf_sock_enabled(struct sock *sk,
> > >         int __ret = 0;                                                         \
> > >         if (cgroup_bpf_enabled(CGROUP_INET_EGRESS) && sk) {                    \
> > >                 typeof(sk) __sk = sk_to_full_sk(sk);                           \
> > > -               if (sk_fullsock(__sk) && __sk == skb_to_full_sk(skb) &&        \
> > > +               if (__sk && __sk == skb_to_full_sk(skb) &&             \
> > >                     cgroup_bpf_sock_enabled(__sk, CGROUP_INET_EGRESS))         \
> > >                         __ret = __cgroup_bpf_run_filter_skb(__sk, skb,         \
> > >                                                       CGROUP_INET_EGRESS); \
> > > diff --git a/include/net/inet_sock.h b/include/net/inet_sock.h
> > > index f01dd273bea69d2eaf7a1d28274d7f980942b78a..56d8bc5593d3dfffd5f94cf7c6383948881917df 100644
> > > --- a/include/net/inet_sock.h
> > > +++ b/include/net/inet_sock.h
> > > @@ -321,8 +321,10 @@ static inline unsigned long inet_cmsg_flags(const struct inet_sock *inet)
> > >  static inline struct sock *sk_to_full_sk(struct sock *sk)
> > >  {
> > >  #ifdef CONFIG_INET
> > > -       if (sk && sk->sk_state == TCP_NEW_SYN_RECV)
> > > +       if (sk && READ_ONCE(sk->sk_state) == TCP_NEW_SYN_RECV)
> > >                 sk = inet_reqsk(sk)->rsk_listener;
> > > +       if (sk && READ_ONCE(sk->sk_state) == TCP_TIME_WAIT)
> > > +               sk = NULL;
> > >  #endif
> > >         return sk;
> > >  }
> > > @@ -331,8 +333,10 @@ static inline struct sock *sk_to_full_sk(struct sock *sk)
> > >  static inline const struct sock *sk_const_to_full_sk(const struct sock *sk)
> > >  {
> > >  #ifdef CONFIG_INET
> > > -       if (sk && sk->sk_state == TCP_NEW_SYN_RECV)
> > > +       if (sk && READ_ONCE(sk->sk_state) == TCP_NEW_SYN_RECV)
> > >                 sk = ((const struct request_sock *)sk)->rsk_listener;
> > > +       if (sk && READ_ONCE(sk->sk_state) == TCP_TIME_WAIT)
> > > +               sk = NULL;
> > >  #endif
> > >         return sk;
> > >  }
> > > diff --git a/net/core/filter.c b/net/core/filter.c
> > > index bd0d08bf76bb8de39ca2ca89cda99a97c9b0a034..202c1d386e19599e9fc6e0a0d4a95986ba6d0ea8 100644
> > > --- a/net/core/filter.c
> > > +++ b/net/core/filter.c
> > > @@ -6778,8 +6778,6 @@ __bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
> > >                 /* sk_to_full_sk() may return (sk)->rsk_listener, so make sure the original sk
> > >                  * sock refcnt is decremented to prevent a request_sock leak.
> > >                  */
> > > -               if (!sk_fullsock(sk2))
> > > -                       sk2 = NULL;
> >
> > IIUC, we still want the condition above since sk_to_full_sk can return
> > the request socket in which case the helper should return NULL, so we
> > still need the refcnt decrement?
> >
> > >                 if (sk2 != sk) {
> > >                         sock_gen_put(sk);
>
> Note that we call sock_gen_put(sk) here, not sock_gen_put(sk2);
>
> sk is not NULL here, so if sk2 is NULL, we will take this branch.

IIUC __bpf_sk_lookup calls __bpf_skc_lookup which can return a request
listener socket and takes a refcnt, but  __bpf_sk_lookup should only
return full_sk (no request nor time_wait).

I agree that after the change to sk_to_full_sk, for time_wait it will
return NULL, hence the condition is repetitive.

if (!sk_fullsock(sk2))
  sk2 = NULL;

but sk_to_full_sk can still retrieve the listener:   sk =
inet_reqsk(sk)->rsk_listener; in which case we would like to still use

if (!sk_fullsock(sk2))
  sk2 = NULL;

to invalidate the request socket, decrement the refcount and  sk = sk2
; // which makes sk == NULL?

I think removing that condition allows __bpf_sk_lookup to return the
req socket, which wasn't possible before?

>
>
> > >                         /* Ensure there is no need to bump sk2 refcnt */
> > > @@ -6826,8 +6824,6 @@ bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
> > >                 /* sk_to_full_sk() may return (sk)->rsk_listener, so make sure the original sk
> > >                  * sock refcnt is decremented to prevent a request_sock leak.
> > >                  */
> > > -               if (!sk_fullsock(sk2))
> > > -                       sk2 = NULL;
> >
> > Same as above.
>
> Should be fine I think.
Eric Dumazet Oct. 14, 2024, 3:23 p.m. UTC | #6
On Mon, Oct 14, 2024 at 5:03 PM Brian Vazquez <brianvv@google.com> wrote:
>
> On Mon, Oct 14, 2024 at 10:28 AM Eric Dumazet <edumazet@google.com> wrote:
>>
>> On Mon, Oct 14, 2024 at 4:01 PM Brian Vazquez <brianvv@google.com> wrote:
>> >
>> > Thanks Eric for the patch series!  I left some comments inline
>> >
>> >
>> > On Thu, Oct 10, 2024 at 1:48 PM Eric Dumazet <edumazet@google.com> wrote:
>> > >
>> > > TCP will soon attach TIME_WAIT sockets to some ACK and RST.
>> > >
>> > > Make sure sk_to_full_sk() detects this and does not return
>> > > a non full socket.
>> > >
>> > > v3: also changed sk_const_to_full_sk()
>> > >
>> > > Signed-off-by: Eric Dumazet <edumazet@google.com>
>> > > ---
>> > >  include/linux/bpf-cgroup.h | 2 +-
>> > >  include/net/inet_sock.h    | 8 ++++++--
>> > >  net/core/filter.c          | 6 +-----
>> > >  3 files changed, 8 insertions(+), 8 deletions(-)
>> > >
>> > > diff --git a/include/linux/bpf-cgroup.h b/include/linux/bpf-cgroup.h
>> > > index ce91d9b2acb9f8991150ceead4475b130bead438..f0f219271daf4afea2666c4d09fd4d1a8091f844 100644
>> > > --- a/include/linux/bpf-cgroup.h
>> > > +++ b/include/linux/bpf-cgroup.h
>> > > @@ -209,7 +209,7 @@ static inline bool cgroup_bpf_sock_enabled(struct sock *sk,
>> > >         int __ret = 0;                                                         \
>> > >         if (cgroup_bpf_enabled(CGROUP_INET_EGRESS) && sk) {                    \
>> > >                 typeof(sk) __sk = sk_to_full_sk(sk);                           \
>> > > -               if (sk_fullsock(__sk) && __sk == skb_to_full_sk(skb) &&        \
>> > > +               if (__sk && __sk == skb_to_full_sk(skb) &&             \
>> > >                     cgroup_bpf_sock_enabled(__sk, CGROUP_INET_EGRESS))         \
>> > >                         __ret = __cgroup_bpf_run_filter_skb(__sk, skb,         \
>> > >                                                       CGROUP_INET_EGRESS); \
>> > > diff --git a/include/net/inet_sock.h b/include/net/inet_sock.h
>> > > index f01dd273bea69d2eaf7a1d28274d7f980942b78a..56d8bc5593d3dfffd5f94cf7c6383948881917df 100644
>> > > --- a/include/net/inet_sock.h
>> > > +++ b/include/net/inet_sock.h
>> > > @@ -321,8 +321,10 @@ static inline unsigned long inet_cmsg_flags(const struct inet_sock *inet)
>> > >  static inline struct sock *sk_to_full_sk(struct sock *sk)
>> > >  {
>> > >  #ifdef CONFIG_INET
>> > > -       if (sk && sk->sk_state == TCP_NEW_SYN_RECV)
>> > > +       if (sk && READ_ONCE(sk->sk_state) == TCP_NEW_SYN_RECV)
>> > >                 sk = inet_reqsk(sk)->rsk_listener;
>> > > +       if (sk && READ_ONCE(sk->sk_state) == TCP_TIME_WAIT)
>> > > +               sk = NULL;
>> > >  #endif
>> > >         return sk;
>> > >  }
>> > > @@ -331,8 +333,10 @@ static inline struct sock *sk_to_full_sk(struct sock *sk)
>> > >  static inline const struct sock *sk_const_to_full_sk(const struct sock *sk)
>> > >  {
>> > >  #ifdef CONFIG_INET
>> > > -       if (sk && sk->sk_state == TCP_NEW_SYN_RECV)
>> > > +       if (sk && READ_ONCE(sk->sk_state) == TCP_NEW_SYN_RECV)
>> > >                 sk = ((const struct request_sock *)sk)->rsk_listener;
>> > > +       if (sk && READ_ONCE(sk->sk_state) == TCP_TIME_WAIT)
>> > > +               sk = NULL;
>> > >  #endif
>> > >         return sk;
>> > >  }
>> > > diff --git a/net/core/filter.c b/net/core/filter.c
>> > > index bd0d08bf76bb8de39ca2ca89cda99a97c9b0a034..202c1d386e19599e9fc6e0a0d4a95986ba6d0ea8 100644
>> > > --- a/net/core/filter.c
>> > > +++ b/net/core/filter.c
>> > > @@ -6778,8 +6778,6 @@ __bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
>> > >                 /* sk_to_full_sk() may return (sk)->rsk_listener, so make sure the original sk
>> > >                  * sock refcnt is decremented to prevent a request_sock leak.
>> > >                  */
>> > > -               if (!sk_fullsock(sk2))
>> > > -                       sk2 = NULL;
>> >
>> > IIUC, we still want the condition above since sk_to_full_sk can return
>> > the request socket in which case the helper should return NULL, so we
>> > still need the refcnt decrement?
>> >
>> > >                 if (sk2 != sk) {
>> > >                         sock_gen_put(sk);
>>
>> Note that we call sock_gen_put(sk) here, not sock_gen_put(sk2);
>>
>>
>> sk is not NULL here, so if sk2 is NULL, we will take this branch.
>
>
> IIUC __bpf_sk_lookup calls __bpf_skc_lookup which can return a request listener socket and takes a refcnt, but  __bpf_sk_lookup should only return full_sk (no request nor time_wait).
>
> That's why the function tries to detect whether req or time_wait was retrieved by __bpf_skc_lookup and if so, we invalidate the return:  sk = NULL, and decrement the refcnt. This is done by having sk2 and then comparing vs sk, and if sk2 is invalid because time_wait or listener, then we decrement sk (the original return from __bpf_skc_lookup, which took a refcnt)
>
> I agree that after the change to sk_to_full_sk, for time_wait it will return NULL, hence the condition is repetitive.
>
> if (!sk_fullsock(sk2))
>   sk2 = NULL;
>
> but sk_to_full_sk can still retrieve the listener:   sk = inet_reqsk(sk)->rsk_listener; in which case we would like to still use
> if (!sk_fullsock(sk2))
>   sk2 = NULL;
>
> to invalidate the request socket, decrement the refcount and  sk = sk2 ; // which makes sk == NULL?
>
> I think removing that condition allows __bpf_sk_lookup to return the req socket, which wasn't possible before?

It was not possible before, and not possible after :

static inline struct sock *sk_to_full_sk(struct sock *sk)
{
#ifdef CONFIG_INET
    if (sk && READ_ONCE(sk->sk_state) == TCP_NEW_SYN_RECV)
        sk = inet_reqsk(sk)->rsk_listener;
    if (sk && READ_ONCE(sk->sk_state) == TCP_TIME_WAIT)    // NEW CODE
        sk = NULL;  // NEW CODE
#endif
    return sk;
}

if sk was a request socket, sk2 would be the listener.

sk2 being a listener means that sk_fullsock(sk2) is true.

if (!sk_fullsock(sk2))
   sk2 = NULL;

So really this check was only meant for TIME_WAIT, and it is now done
directly from sk_to_full_sk().

Therefore we can delete this dead code.
Brian Vazquez Oct. 14, 2024, 3:39 p.m. UTC | #7
On Mon, Oct 14, 2024 at 11:24 AM Eric Dumazet <edumazet@google.com> wrote:
>
> On Mon, Oct 14, 2024 at 5:03 PM Brian Vazquez <brianvv@google.com> wrote:
> >
> > On Mon, Oct 14, 2024 at 10:28 AM Eric Dumazet <edumazet@google.com> wrote:
> >>
> >> On Mon, Oct 14, 2024 at 4:01 PM Brian Vazquez <brianvv@google.com> wrote:
> >> >
> >> > Thanks Eric for the patch series!  I left some comments inline
> >> >
> >> >
> >> > On Thu, Oct 10, 2024 at 1:48 PM Eric Dumazet <edumazet@google.com> wrote:
> >> > >
> >> > > TCP will soon attach TIME_WAIT sockets to some ACK and RST.
> >> > >
> >> > > Make sure sk_to_full_sk() detects this and does not return
> >> > > a non full socket.
> >> > >
> >> > > v3: also changed sk_const_to_full_sk()
> >> > >
> >> > > Signed-off-by: Eric Dumazet <edumazet@google.com>
> >> > > ---
> >> > >  include/linux/bpf-cgroup.h | 2 +-
> >> > >  include/net/inet_sock.h    | 8 ++++++--
> >> > >  net/core/filter.c          | 6 +-----
> >> > >  3 files changed, 8 insertions(+), 8 deletions(-)
> >> > >
> >> > > diff --git a/include/linux/bpf-cgroup.h b/include/linux/bpf-cgroup.h
> >> > > index ce91d9b2acb9f8991150ceead4475b130bead438..f0f219271daf4afea2666c4d09fd4d1a8091f844 100644
> >> > > --- a/include/linux/bpf-cgroup.h
> >> > > +++ b/include/linux/bpf-cgroup.h
> >> > > @@ -209,7 +209,7 @@ static inline bool cgroup_bpf_sock_enabled(struct sock *sk,
> >> > >         int __ret = 0;                                                         \
> >> > >         if (cgroup_bpf_enabled(CGROUP_INET_EGRESS) && sk) {                    \
> >> > >                 typeof(sk) __sk = sk_to_full_sk(sk);                           \
> >> > > -               if (sk_fullsock(__sk) && __sk == skb_to_full_sk(skb) &&        \
> >> > > +               if (__sk && __sk == skb_to_full_sk(skb) &&             \
> >> > >                     cgroup_bpf_sock_enabled(__sk, CGROUP_INET_EGRESS))         \
> >> > >                         __ret = __cgroup_bpf_run_filter_skb(__sk, skb,         \
> >> > >                                                       CGROUP_INET_EGRESS); \
> >> > > diff --git a/include/net/inet_sock.h b/include/net/inet_sock.h
> >> > > index f01dd273bea69d2eaf7a1d28274d7f980942b78a..56d8bc5593d3dfffd5f94cf7c6383948881917df 100644
> >> > > --- a/include/net/inet_sock.h
> >> > > +++ b/include/net/inet_sock.h
> >> > > @@ -321,8 +321,10 @@ static inline unsigned long inet_cmsg_flags(const struct inet_sock *inet)
> >> > >  static inline struct sock *sk_to_full_sk(struct sock *sk)
> >> > >  {
> >> > >  #ifdef CONFIG_INET
> >> > > -       if (sk && sk->sk_state == TCP_NEW_SYN_RECV)
> >> > > +       if (sk && READ_ONCE(sk->sk_state) == TCP_NEW_SYN_RECV)
> >> > >                 sk = inet_reqsk(sk)->rsk_listener;
> >> > > +       if (sk && READ_ONCE(sk->sk_state) == TCP_TIME_WAIT)
> >> > > +               sk = NULL;
> >> > >  #endif
> >> > >         return sk;
> >> > >  }
> >> > > @@ -331,8 +333,10 @@ static inline struct sock *sk_to_full_sk(struct sock *sk)
> >> > >  static inline const struct sock *sk_const_to_full_sk(const struct sock *sk)
> >> > >  {
> >> > >  #ifdef CONFIG_INET
> >> > > -       if (sk && sk->sk_state == TCP_NEW_SYN_RECV)
> >> > > +       if (sk && READ_ONCE(sk->sk_state) == TCP_NEW_SYN_RECV)
> >> > >                 sk = ((const struct request_sock *)sk)->rsk_listener;
> >> > > +       if (sk && READ_ONCE(sk->sk_state) == TCP_TIME_WAIT)
> >> > > +               sk = NULL;
> >> > >  #endif
> >> > >         return sk;
> >> > >  }
> >> > > diff --git a/net/core/filter.c b/net/core/filter.c
> >> > > index bd0d08bf76bb8de39ca2ca89cda99a97c9b0a034..202c1d386e19599e9fc6e0a0d4a95986ba6d0ea8 100644
> >> > > --- a/net/core/filter.c
> >> > > +++ b/net/core/filter.c
> >> > > @@ -6778,8 +6778,6 @@ __bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
> >> > >                 /* sk_to_full_sk() may return (sk)->rsk_listener, so make sure the original sk
> >> > >                  * sock refcnt is decremented to prevent a request_sock leak.
> >> > >                  */
> >> > > -               if (!sk_fullsock(sk2))
> >> > > -                       sk2 = NULL;
> >> >
> >> > IIUC, we still want the condition above since sk_to_full_sk can return
> >> > the request socket in which case the helper should return NULL, so we
> >> > still need the refcnt decrement?
> >> >
> >> > >                 if (sk2 != sk) {
> >> > >                         sock_gen_put(sk);
> >>
> >> Note that we call sock_gen_put(sk) here, not sock_gen_put(sk2);
> >>
> >>
> >> sk is not NULL here, so if sk2 is NULL, we will take this branch.
> >
> >
> > IIUC __bpf_sk_lookup calls __bpf_skc_lookup which can return a request listener socket and takes a refcnt, but  __bpf_sk_lookup should only return full_sk (no request nor time_wait).
> >
> > That's why the function tries to detect whether req or time_wait was retrieved by __bpf_skc_lookup and if so, we invalidate the return:  sk = NULL, and decrement the refcnt. This is done by having sk2 and then comparing vs sk, and if sk2 is invalid because time_wait or listener, then we decrement sk (the original return from __bpf_skc_lookup, which took a refcnt)
> >
> > I agree that after the change to sk_to_full_sk, for time_wait it will return NULL, hence the condition is repetitive.
> >
> > if (!sk_fullsock(sk2))
> >   sk2 = NULL;
> >
> > but sk_to_full_sk can still retrieve the listener:   sk = inet_reqsk(sk)->rsk_listener; in which case we would like to still use
> > if (!sk_fullsock(sk2))
> >   sk2 = NULL;
> >
> > to invalidate the request socket, decrement the refcount and  sk = sk2 ; // which makes sk == NULL?
> >
> > I think removing that condition allows __bpf_sk_lookup to return the req socket, which wasn't possible before?
>
> It was not possible before, and not possible after :
>
> static inline struct sock *sk_to_full_sk(struct sock *sk)
> {
> #ifdef CONFIG_INET
>     if (sk && READ_ONCE(sk->sk_state) == TCP_NEW_SYN_RECV)
>         sk = inet_reqsk(sk)->rsk_listener;
>     if (sk && READ_ONCE(sk->sk_state) == TCP_TIME_WAIT)    // NEW CODE
>         sk = NULL;  // NEW CODE
> #endif
>     return sk;
> }
>
> if sk was a request socket, sk2 would be the listener.

This is the part that I missed; I got misled by the comment above the dead code.

Thanks for clarifying!

>
> sk2 being a listener means that sk_fullsock(sk2) is true.
>
> if (!sk_fullsock(sk2))
>    sk2 = NULL;
>
> So really this check was only meant for TIME_WAIT, and it is now done
> directly from sk_to_full_sk()
>
> Therefore we can delete this dead code.

Reviewed-by: Brian Vazquez <brianvv@google.com>
diff mbox series

Patch

diff --git a/include/linux/bpf-cgroup.h b/include/linux/bpf-cgroup.h
index ce91d9b2acb9f8991150ceead4475b130bead438..f0f219271daf4afea2666c4d09fd4d1a8091f844 100644
--- a/include/linux/bpf-cgroup.h
+++ b/include/linux/bpf-cgroup.h
@@ -209,7 +209,7 @@  static inline bool cgroup_bpf_sock_enabled(struct sock *sk,
 	int __ret = 0;							       \
 	if (cgroup_bpf_enabled(CGROUP_INET_EGRESS) && sk) {		       \
 		typeof(sk) __sk = sk_to_full_sk(sk);			       \
-		if (sk_fullsock(__sk) && __sk == skb_to_full_sk(skb) &&	       \
+		if (__sk && __sk == skb_to_full_sk(skb) &&	       \
 		    cgroup_bpf_sock_enabled(__sk, CGROUP_INET_EGRESS))	       \
 			__ret = __cgroup_bpf_run_filter_skb(__sk, skb,	       \
 						      CGROUP_INET_EGRESS); \
diff --git a/include/net/inet_sock.h b/include/net/inet_sock.h
index f01dd273bea69d2eaf7a1d28274d7f980942b78a..56d8bc5593d3dfffd5f94cf7c6383948881917df 100644
--- a/include/net/inet_sock.h
+++ b/include/net/inet_sock.h
@@ -321,8 +321,10 @@  static inline unsigned long inet_cmsg_flags(const struct inet_sock *inet)
 static inline struct sock *sk_to_full_sk(struct sock *sk)
 {
 #ifdef CONFIG_INET
-	if (sk && sk->sk_state == TCP_NEW_SYN_RECV)
+	if (sk && READ_ONCE(sk->sk_state) == TCP_NEW_SYN_RECV)
 		sk = inet_reqsk(sk)->rsk_listener;
+	if (sk && READ_ONCE(sk->sk_state) == TCP_TIME_WAIT)
+		sk = NULL;
 #endif
 	return sk;
 }
@@ -331,8 +333,10 @@  static inline struct sock *sk_to_full_sk(struct sock *sk)
 static inline const struct sock *sk_const_to_full_sk(const struct sock *sk)
 {
 #ifdef CONFIG_INET
-	if (sk && sk->sk_state == TCP_NEW_SYN_RECV)
+	if (sk && READ_ONCE(sk->sk_state) == TCP_NEW_SYN_RECV)
 		sk = ((const struct request_sock *)sk)->rsk_listener;
+	if (sk && READ_ONCE(sk->sk_state) == TCP_TIME_WAIT)
+		sk = NULL;
 #endif
 	return sk;
 }
diff --git a/net/core/filter.c b/net/core/filter.c
index bd0d08bf76bb8de39ca2ca89cda99a97c9b0a034..202c1d386e19599e9fc6e0a0d4a95986ba6d0ea8 100644
--- a/net/core/filter.c
+++ b/net/core/filter.c
@@ -6778,8 +6778,6 @@  __bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
 		/* sk_to_full_sk() may return (sk)->rsk_listener, so make sure the original sk
 		 * sock refcnt is decremented to prevent a request_sock leak.
 		 */
-		if (!sk_fullsock(sk2))
-			sk2 = NULL;
 		if (sk2 != sk) {
 			sock_gen_put(sk);
 			/* Ensure there is no need to bump sk2 refcnt */
@@ -6826,8 +6824,6 @@  bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
 		/* sk_to_full_sk() may return (sk)->rsk_listener, so make sure the original sk
 		 * sock refcnt is decremented to prevent a request_sock leak.
 		 */
-		if (!sk_fullsock(sk2))
-			sk2 = NULL;
 		if (sk2 != sk) {
 			sock_gen_put(sk);
 			/* Ensure there is no need to bump sk2 refcnt */
@@ -7276,7 +7272,7 @@  BPF_CALL_1(bpf_get_listener_sock, struct sock *, sk)
 {
 	sk = sk_to_full_sk(sk);
 
-	if (sk->sk_state == TCP_LISTEN && sock_flag(sk, SOCK_RCU_FREE))
+	if (sk && sk->sk_state == TCP_LISTEN && sock_flag(sk, SOCK_RCU_FREE))
 		return (unsigned long)sk;
 
 	return (unsigned long)NULL;