diff mbox series

[net-next,v2,1/2] net: Update bhash2 when socket's rcv saddr changes

Message ID 20220602165101.3188482-2-joannelkoong@gmail.com (mailing list archive)
State Changes Requested
Delegated to: Netdev Maintainers
Headers show
Series Update bhash2 when socket's rcv saddr changes | expand

Checks

Context Check Description
netdev/tree_selection success Clearly marked for net-next
netdev/fixes_present success Fixes tag not required for -next series
netdev/subject_prefix success Link
netdev/cover_letter success Series has a cover letter
netdev/patch_count success Link
netdev/header_inline success No static functions without inline keyword in header files
netdev/build_32bit success Errors and warnings before: 2082 this patch: 2082
netdev/cc_maintainers fail 1 blamed authors not CCed: kuniyu@amazon.co.jp; 4 maintainers not CCed: yoshfuji@linux-ipv6.org dccp@vger.kernel.org kuniyu@amazon.co.jp dsahern@kernel.org
netdev/build_clang success Errors and warnings before: 586 this patch: 586
netdev/module_param success Was 0 now: 0
netdev/verify_signedoff success Signed-off-by tag matches author and committer
netdev/check_selftest success No net selftest shell script
netdev/verify_fixes success Fixes tag looks correct
netdev/build_allmodconfig_warn success Errors and warnings before: 2207 this patch: 2207
netdev/checkpatch success total: 0 errors, 0 warnings, 0 checks, 275 lines checked
netdev/kdoc success Errors and warnings before: 0 this patch: 0
netdev/source_inline success Was 0 now: 0

Commit Message

Joanne Koong June 2, 2022, 4:51 p.m. UTC
Commit d5a42de8bdbe ("net: Add a second bind table hashed by port and
address") added a second bind table, bhash2, that hashes by a socket's port
and rcv address.

However, there are two cases where the socket's rcv saddr can change
after it has been bound:

1) The case where there is a bind() call on the wildcard address ("::"
for IPv6, INADDR_ANY for IPv4) and then a connect() call. The kernel will
assign the socket an address when it handles the connect().

2) In inet_sk_reselect_saddr(), which is called when rerouting fails
while rebuilding the sk header (invoked by inet_sk_rebuild_header()).

In these two cases, we need to update the bhash2 table by removing the
entry for the old address, and adding a new entry reflecting the updated
address.

Reported-by: syzbot+015d756bbd1f8b5c8f09@syzkaller.appspotmail.com
Fixes: d5a42de8bdbe ("net: Add a second bind table hashed by port and address")
Signed-off-by: Joanne Koong <joannelkoong@gmail.com>
Reviewed-by: Eric Dumazet <edumazet@google.com>
---
 include/net/inet_hashtables.h |  6 ++-
 include/net/ipv6.h            |  2 +-
 net/dccp/ipv4.c               | 10 +++--
 net/dccp/ipv6.c               |  4 +-
 net/ipv4/af_inet.c            |  7 +++-
 net/ipv4/inet_hashtables.c    | 70 ++++++++++++++++++++++++++++++++---
 net/ipv4/tcp_ipv4.c           |  8 +++-
 net/ipv6/inet6_hashtables.c   |  4 +-
 net/ipv6/tcp_ipv6.c           |  4 +-
 9 files changed, 97 insertions(+), 18 deletions(-)

Comments

Paolo Abeni June 7, 2022, 8:33 a.m. UTC | #1
Hello,

On Thu, 2022-06-02 at 09:51 -0700, Joanne Koong wrote:
> Commit d5a42de8bdbe ("net: Add a second bind table hashed by port and
> address") added a second bind table, bhash2, that hashes by a socket's port
> and rcv address.
> 
> However, there are two cases where the socket's rcv saddr can change
> after it has been binded:
> 
> 1) The case where there is a bind() call on "::" (IPADDR_ANY) and then
> a connect() call. The kernel will assign the socket an address when it
> handles the connect()
> 
> 2) In inet_sk_reselect_saddr(), which is called when rerouting fails
> when rebuilding the sk header (invoked by inet_sk_rebuild_header)
> 
> In these two cases, we need to update the bhash2 table by removing the
> entry for the old address, and adding a new entry reflecting the updated
> address.
> 
> Reported-by: syzbot+015d756bbd1f8b5c8f09@syzkaller.appspotmail.com
> Fixes: d5a42de8bdbe ("net: Add a second bind table hashed by port and address")
> Signed-off-by: Joanne Koong <joannelkoong@gmail.com>
> Reviewed-by: Eric Dumazet <edumazet@google.com>
> ---
>  include/net/inet_hashtables.h |  6 ++-
>  include/net/ipv6.h            |  2 +-
>  net/dccp/ipv4.c               | 10 +++--
>  net/dccp/ipv6.c               |  4 +-
>  net/ipv4/af_inet.c            |  7 +++-
>  net/ipv4/inet_hashtables.c    | 70 ++++++++++++++++++++++++++++++++---
>  net/ipv4/tcp_ipv4.c           |  8 +++-
>  net/ipv6/inet6_hashtables.c   |  4 +-
>  net/ipv6/tcp_ipv6.c           |  4 +-
>  9 files changed, 97 insertions(+), 18 deletions(-)
> 
> diff --git a/include/net/inet_hashtables.h b/include/net/inet_hashtables.h
> index a0887b70967b..2c331ce6ca73 100644
> --- a/include/net/inet_hashtables.h
> +++ b/include/net/inet_hashtables.h
> @@ -448,11 +448,13 @@ static inline void sk_rcv_saddr_set(struct sock *sk, __be32 addr)
>  }
>  
>  int __inet_hash_connect(struct inet_timewait_death_row *death_row,
> -			struct sock *sk, u64 port_offset,
> +			struct sock *sk, u64 port_offset, bool prev_inaddr_any,
>  			int (*check_established)(struct inet_timewait_death_row *,
>  						 struct sock *, __u16,
>  						 struct inet_timewait_sock **));
>  
>  int inet_hash_connect(struct inet_timewait_death_row *death_row,
> -		      struct sock *sk);
> +		      struct sock *sk, bool prev_inaddr_any);
> +
> +int inet_bhash2_update_saddr(struct sock *sk);
>  #endif /* _INET_HASHTABLES_H */
> diff --git a/include/net/ipv6.h b/include/net/ipv6.h
> index 5b38bf1a586b..6a50aca56d50 100644
> --- a/include/net/ipv6.h
> +++ b/include/net/ipv6.h
> @@ -1187,7 +1187,7 @@ int inet6_compat_ioctl(struct socket *sock, unsigned int cmd,
>  		unsigned long arg);
>  
>  int inet6_hash_connect(struct inet_timewait_death_row *death_row,
> -			      struct sock *sk);
> +		       struct sock *sk, bool prev_inaddr_any);
>  int inet6_sendmsg(struct socket *sock, struct msghdr *msg, size_t size);
>  int inet6_recvmsg(struct socket *sock, struct msghdr *msg, size_t size,
>  		  int flags);
> diff --git a/net/dccp/ipv4.c b/net/dccp/ipv4.c
> index da6e3b20cd75..37a8bc3ee49e 100644
> --- a/net/dccp/ipv4.c
> +++ b/net/dccp/ipv4.c
> @@ -47,12 +47,13 @@ int dccp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
>  	const struct sockaddr_in *usin = (struct sockaddr_in *)uaddr;
>  	struct inet_sock *inet = inet_sk(sk);
>  	struct dccp_sock *dp = dccp_sk(sk);
> +	struct ip_options_rcu *inet_opt;
>  	__be16 orig_sport, orig_dport;
> +	bool prev_inaddr_any = false;
>  	__be32 daddr, nexthop;
>  	struct flowi4 *fl4;
>  	struct rtable *rt;
>  	int err;
> -	struct ip_options_rcu *inet_opt;
>  
>  	dp->dccps_role = DCCP_ROLE_CLIENT;
>  
> @@ -89,8 +90,11 @@ int dccp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
>  	if (inet_opt == NULL || !inet_opt->opt.srr)
>  		daddr = fl4->daddr;
>  
> -	if (inet->inet_saddr == 0)
> +	if (inet->inet_saddr == 0) {
>  		inet->inet_saddr = fl4->saddr;
> +		prev_inaddr_any = true;
> +	}
> +
>  	sk_rcv_saddr_set(sk, inet->inet_saddr);
>  	inet->inet_dport = usin->sin_port;
>  	sk_daddr_set(sk, daddr);
> @@ -105,7 +109,7 @@ int dccp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
>  	 * complete initialization after this.
>  	 */
>  	dccp_set_state(sk, DCCP_REQUESTING);
> -	err = inet_hash_connect(&dccp_death_row, sk);
> +	err = inet_hash_connect(&dccp_death_row, sk, prev_inaddr_any);
>  	if (err != 0)
>  		goto failure;
>  
> diff --git a/net/dccp/ipv6.c b/net/dccp/ipv6.c
> index fd44638ec16b..03013522acab 100644
> --- a/net/dccp/ipv6.c
> +++ b/net/dccp/ipv6.c
> @@ -824,6 +824,7 @@ static int dccp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
>  	struct ipv6_pinfo *np = inet6_sk(sk);
>  	struct dccp_sock *dp = dccp_sk(sk);
>  	struct in6_addr *saddr = NULL, *final_p, final;
> +	bool prev_inaddr_any = false;
>  	struct ipv6_txoptions *opt;
>  	struct flowi6 fl6;
>  	struct dst_entry *dst;
> @@ -936,6 +937,7 @@ static int dccp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
>  	if (saddr == NULL) {
>  		saddr = &fl6.saddr;
>  		sk->sk_v6_rcv_saddr = *saddr;
> +		prev_inaddr_any = true;
>  	}
>  
>  	/* set the source address */
> @@ -951,7 +953,7 @@ static int dccp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
>  	inet->inet_dport = usin->sin6_port;
>  
>  	dccp_set_state(sk, DCCP_REQUESTING);
> -	err = inet6_hash_connect(&dccp_death_row, sk);
> +	err = inet6_hash_connect(&dccp_death_row, sk, prev_inaddr_any);
>  	if (err)
>  		goto late_failure;
>  
> diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c
> index 93da9f783bec..ad627a99ff9d 100644
> --- a/net/ipv4/af_inet.c
> +++ b/net/ipv4/af_inet.c
> @@ -1221,10 +1221,11 @@ static int inet_sk_reselect_saddr(struct sock *sk)
>  	struct inet_sock *inet = inet_sk(sk);
>  	__be32 old_saddr = inet->inet_saddr;
>  	__be32 daddr = inet->inet_daddr;
> +	struct ip_options_rcu *inet_opt;
>  	struct flowi4 *fl4;
>  	struct rtable *rt;
>  	__be32 new_saddr;
> -	struct ip_options_rcu *inet_opt;
> +	int err;
>  
>  	inet_opt = rcu_dereference_protected(inet->inet_opt,
>  					     lockdep_sock_is_held(sk));
> @@ -1253,6 +1254,10 @@ static int inet_sk_reselect_saddr(struct sock *sk)
>  
>  	inet->inet_saddr = inet->inet_rcv_saddr = new_saddr;
>  
> +	err = inet_bhash2_update_saddr(sk);
> +	if (err)
> +		return err;
> +
>  	/*
>  	 * XXX The only one ugly spot where we need to
>  	 * XXX really change the sockets identity after
> diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c
> index e8de5e699b3f..592b70663a3b 100644
> --- a/net/ipv4/inet_hashtables.c
> +++ b/net/ipv4/inet_hashtables.c
> @@ -826,6 +826,55 @@ inet_bind2_bucket_find(struct inet_hashinfo *hinfo, struct net *net,
>  	return bhash2;
>  }
>  
> +/* the lock for the socket's corresponding bhash entry must be held */
> +static int __inet_bhash2_update_saddr(struct sock *sk,
> +				      struct inet_hashinfo *hinfo,
> +				      struct net *net, int port, int l3mdev)
> +{
> +	struct inet_bind2_hashbucket *head2;
> +	struct inet_bind2_bucket *tb2;
> +
> +	tb2 = inet_bind2_bucket_find(hinfo, net, port, l3mdev, sk,
> +				     &head2);
> +	if (!tb2) {
> +		tb2 = inet_bind2_bucket_create(hinfo->bind2_bucket_cachep,
> +					       net, head2, port, l3mdev, sk);
> +		if (!tb2)
> +			return -ENOMEM;
> +	}
> +
> +	/* Remove the socket's old entry from bhash2 */
> +	__sk_del_bind2_node(sk);
> +
> +	sk_add_bind2_node(sk, &tb2->owners);
> +	inet_csk(sk)->icsk_bind2_hash = tb2;
> +
> +	return 0;
> +}
> +
> +/* This should be called if/when a socket's rcv saddr changes after it has
> + * been binded.
> + */
> +int inet_bhash2_update_saddr(struct sock *sk)
> +{
> +	struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo;
> +	int l3mdev = inet_sk_bound_l3mdev(sk);
> +	struct inet_bind_hashbucket *head;
> +	int port = inet_sk(sk)->inet_num;
> +	struct net *net = sock_net(sk);
> +	int err;
> +
> +	head = &hinfo->bhash[inet_bhashfn(net, port, hinfo->bhash_size)];
> +
> +	spin_lock_bh(&head->lock);
> +
> +	err = __inet_bhash2_update_saddr(sk, hinfo, net, port, l3mdev);
> +
> +	spin_unlock_bh(&head->lock);
> +
> +	return err;
> +}
> +
>  /* RFC 6056 3.3.4.  Algorithm 4: Double-Hash Port Selection Algorithm
>   * Note that we use 32bit integers (vs RFC 'short integers')
>   * because 2^16 is not a multiple of num_ephemeral and this
> @@ -840,7 +889,7 @@ inet_bind2_bucket_find(struct inet_hashinfo *hinfo, struct net *net,
>  static u32 *table_perturb;
>  
>  int __inet_hash_connect(struct inet_timewait_death_row *death_row,
> -		struct sock *sk, u64 port_offset,
> +		struct sock *sk, u64 port_offset, bool prev_inaddr_any,
>  		int (*check_established)(struct inet_timewait_death_row *,
>  			struct sock *, __u16, struct inet_timewait_sock **))
>  {
> @@ -858,11 +907,24 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
>  	int l3mdev;
>  	u32 index;
>  
> +	l3mdev = inet_sk_bound_l3mdev(sk);
> +
>  	if (port) {
>  		head = &hinfo->bhash[inet_bhashfn(net, port,
>  						  hinfo->bhash_size)];
>  		tb = inet_csk(sk)->icsk_bind_hash;
> +
>  		spin_lock_bh(&head->lock);
> +
> +		if (prev_inaddr_any) {
> +			ret = __inet_bhash2_update_saddr(sk, hinfo, net, port,
> +							 l3mdev);
> +			if (ret) {
> +				spin_unlock_bh(&head->lock);
> +				return ret;
> +			}
> +		}
> +
>  		if (sk_head(&tb->owners) == sk && !sk->sk_bind_node.next) {
>  			inet_ehash_nolisten(sk, NULL, NULL);
>  			spin_unlock_bh(&head->lock);
> @@ -875,8 +937,6 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
>  		return ret;
>  	}
>  
> -	l3mdev = inet_sk_bound_l3mdev(sk);
> -
>  	inet_get_local_port_range(net, &low, &high);
>  	high++; /* [32768, 60999] -> [32768, 61000[ */
>  	remaining = high - low;
> @@ -987,13 +1047,13 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
>   * Bind a port for a connect operation and hash it.
>   */
>  int inet_hash_connect(struct inet_timewait_death_row *death_row,
> -		      struct sock *sk)
> +		      struct sock *sk, bool prev_inaddr_any)
>  {
>  	u64 port_offset = 0;
>  
>  	if (!inet_sk(sk)->inet_num)
>  		port_offset = inet_sk_port_offset(sk);
> -	return __inet_hash_connect(death_row, sk, port_offset,
> +	return __inet_hash_connect(death_row, sk, port_offset, prev_inaddr_any,
>  				   __inet_check_established);
>  }
>  EXPORT_SYMBOL_GPL(inet_hash_connect);
> diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
> index dac2650f3863..adf8d750933d 100644
> --- a/net/ipv4/tcp_ipv4.c
> +++ b/net/ipv4/tcp_ipv4.c
> @@ -203,6 +203,7 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
>  	struct inet_sock *inet = inet_sk(sk);
>  	struct tcp_sock *tp = tcp_sk(sk);
>  	__be16 orig_sport, orig_dport;
> +	bool prev_inaddr_any = false;
>  	__be32 daddr, nexthop;
>  	struct flowi4 *fl4;
>  	struct rtable *rt;
> @@ -246,8 +247,11 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
>  	if (!inet_opt || !inet_opt->opt.srr)
>  		daddr = fl4->daddr;
>  
> -	if (!inet->inet_saddr)
> +	if (!inet->inet_saddr) {
>  		inet->inet_saddr = fl4->saddr;
> +		prev_inaddr_any = true;
> +	}
> +
>  	sk_rcv_saddr_set(sk, inet->inet_saddr);
>  
>  	if (tp->rx_opt.ts_recent_stamp && inet->inet_daddr != daddr) {
> @@ -273,7 +277,7 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
>  	 * complete initialization after this.
>  	 */
>  	tcp_set_state(sk, TCP_SYN_SENT);
> -	err = inet_hash_connect(tcp_death_row, sk);
> +	err = inet_hash_connect(tcp_death_row, sk, prev_inaddr_any);
>  	if (err)
>  		goto failure;
>  
> diff --git a/net/ipv6/inet6_hashtables.c b/net/ipv6/inet6_hashtables.c
> index 7d53d62783b1..c87c5933f3be 100644
> --- a/net/ipv6/inet6_hashtables.c
> +++ b/net/ipv6/inet6_hashtables.c
> @@ -317,13 +317,13 @@ static u64 inet6_sk_port_offset(const struct sock *sk)
>  }
>  
>  int inet6_hash_connect(struct inet_timewait_death_row *death_row,
> -		       struct sock *sk)
> +		       struct sock *sk, bool prev_inaddr_any)
>  {
>  	u64 port_offset = 0;
>  
>  	if (!inet_sk(sk)->inet_num)
>  		port_offset = inet6_sk_port_offset(sk);
> -	return __inet_hash_connect(death_row, sk, port_offset,
> +	return __inet_hash_connect(death_row, sk, port_offset, prev_inaddr_any,
>  				   __inet6_check_established);
>  }
>  EXPORT_SYMBOL_GPL(inet6_hash_connect);
> diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c
> index f37dd4aa91c6..81e3312c2a97 100644
> --- a/net/ipv6/tcp_ipv6.c
> +++ b/net/ipv6/tcp_ipv6.c
> @@ -152,6 +152,7 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
>  	struct ipv6_pinfo *np = tcp_inet6_sk(sk);
>  	struct tcp_sock *tp = tcp_sk(sk);
>  	struct in6_addr *saddr = NULL, *final_p, final;
> +	bool prev_inaddr_any = false;
>  	struct ipv6_txoptions *opt;
>  	struct flowi6 fl6;
>  	struct dst_entry *dst;
> @@ -289,6 +290,7 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
>  	if (!saddr) {
>  		saddr = &fl6.saddr;
>  		sk->sk_v6_rcv_saddr = *saddr;
> +		prev_inaddr_any = true;
>  	}
>  
>  	/* set the source address */
> @@ -309,7 +311,7 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
>  
>  	tcp_set_state(sk, TCP_SYN_SENT);
>  	tcp_death_row = sock_net(sk)->ipv4.tcp_death_row;
> -	err = inet6_hash_connect(tcp_death_row, sk);
> +	err = inet6_hash_connect(tcp_death_row, sk, prev_inaddr_any);
>  	if (err)
>  		goto late_failure;
>  

I'm sorry for the late notice, but it looks like the mptcp
syzkaller instance is still hitting the warning in icsk_get_port on top
of v1 of this series:

https://github.com/multipath-tcp/mptcp_net-next/issues/279

and the changes in v2 are not expected to address that. @Mat, could you
please confirm the above?

Dumb question: I don't understand how the locking in bhash2 works.
Could you explain that? 

What happens when 2 different processes bind different sockets on
different ports (with different bhash buckets) using different
addresses so that they hit the same bhash2 bucket? AFAICS each process
will take a different bhash lock, so accesses/modifications to the same
bhash2 bucket could happen simultaneously?

Thanks!

Paolo
Mat Martineau June 7, 2022, 5:10 p.m. UTC | #2
On Tue, 7 Jun 2022, Paolo Abeni wrote:

> Hello,
>
> On Thu, 2022-06-02 at 09:51 -0700, Joanne Koong wrote:
>> Commit d5a42de8bdbe ("net: Add a second bind table hashed by port and
>> address") added a second bind table, bhash2, that hashes by a socket's port
>> and rcv address.
>>
>> However, there are two cases where the socket's rcv saddr can change
>> after it has been binded:
>>
>> 1) The case where there is a bind() call on "::" (IPADDR_ANY) and then
>> a connect() call. The kernel will assign the socket an address when it
>> handles the connect()
>>
>> 2) In inet_sk_reselect_saddr(), which is called when rerouting fails
>> when rebuilding the sk header (invoked by inet_sk_rebuild_header)
>>
>> In these two cases, we need to update the bhash2 table by removing the
>> entry for the old address, and adding a new entry reflecting the updated
>> address.
>>
>> Reported-by: syzbot+015d756bbd1f8b5c8f09@syzkaller.appspotmail.com
>> Fixes: d5a42de8bdbe ("net: Add a second bind table hashed by port and address")
>> Signed-off-by: Joanne Koong <joannelkoong@gmail.com>
>> Reviewed-by: Eric Dumazet <edumazet@google.com>
>> ---
>>  include/net/inet_hashtables.h |  6 ++-
>>  include/net/ipv6.h            |  2 +-
>>  net/dccp/ipv4.c               | 10 +++--
>>  net/dccp/ipv6.c               |  4 +-
>>  net/ipv4/af_inet.c            |  7 +++-
>>  net/ipv4/inet_hashtables.c    | 70 ++++++++++++++++++++++++++++++++---
>>  net/ipv4/tcp_ipv4.c           |  8 +++-
>>  net/ipv6/inet6_hashtables.c   |  4 +-
>>  net/ipv6/tcp_ipv6.c           |  4 +-
>>  9 files changed, 97 insertions(+), 18 deletions(-)
>>
>> diff --git a/include/net/inet_hashtables.h b/include/net/inet_hashtables.h
>> index a0887b70967b..2c331ce6ca73 100644
>> --- a/include/net/inet_hashtables.h
>> +++ b/include/net/inet_hashtables.h
>> @@ -448,11 +448,13 @@ static inline void sk_rcv_saddr_set(struct sock *sk, __be32 addr)
>>  }
>>
>>  int __inet_hash_connect(struct inet_timewait_death_row *death_row,
>> -			struct sock *sk, u64 port_offset,
>> +			struct sock *sk, u64 port_offset, bool prev_inaddr_any,
>>  			int (*check_established)(struct inet_timewait_death_row *,
>>  						 struct sock *, __u16,
>>  						 struct inet_timewait_sock **));
>>
>>  int inet_hash_connect(struct inet_timewait_death_row *death_row,
>> -		      struct sock *sk);
>> +		      struct sock *sk, bool prev_inaddr_any);
>> +
>> +int inet_bhash2_update_saddr(struct sock *sk);
>>  #endif /* _INET_HASHTABLES_H */
>> diff --git a/include/net/ipv6.h b/include/net/ipv6.h
>> index 5b38bf1a586b..6a50aca56d50 100644
>> --- a/include/net/ipv6.h
>> +++ b/include/net/ipv6.h
>> @@ -1187,7 +1187,7 @@ int inet6_compat_ioctl(struct socket *sock, unsigned int cmd,
>>  		unsigned long arg);
>>
>>  int inet6_hash_connect(struct inet_timewait_death_row *death_row,
>> -			      struct sock *sk);
>> +		       struct sock *sk, bool prev_inaddr_any);
>>  int inet6_sendmsg(struct socket *sock, struct msghdr *msg, size_t size);
>>  int inet6_recvmsg(struct socket *sock, struct msghdr *msg, size_t size,
>>  		  int flags);
>> diff --git a/net/dccp/ipv4.c b/net/dccp/ipv4.c
>> index da6e3b20cd75..37a8bc3ee49e 100644
>> --- a/net/dccp/ipv4.c
>> +++ b/net/dccp/ipv4.c
>> @@ -47,12 +47,13 @@ int dccp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
>>  	const struct sockaddr_in *usin = (struct sockaddr_in *)uaddr;
>>  	struct inet_sock *inet = inet_sk(sk);
>>  	struct dccp_sock *dp = dccp_sk(sk);
>> +	struct ip_options_rcu *inet_opt;
>>  	__be16 orig_sport, orig_dport;
>> +	bool prev_inaddr_any = false;
>>  	__be32 daddr, nexthop;
>>  	struct flowi4 *fl4;
>>  	struct rtable *rt;
>>  	int err;
>> -	struct ip_options_rcu *inet_opt;
>>
>>  	dp->dccps_role = DCCP_ROLE_CLIENT;
>>
>> @@ -89,8 +90,11 @@ int dccp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
>>  	if (inet_opt == NULL || !inet_opt->opt.srr)
>>  		daddr = fl4->daddr;
>>
>> -	if (inet->inet_saddr == 0)
>> +	if (inet->inet_saddr == 0) {
>>  		inet->inet_saddr = fl4->saddr;
>> +		prev_inaddr_any = true;
>> +	}
>> +
>>  	sk_rcv_saddr_set(sk, inet->inet_saddr);
>>  	inet->inet_dport = usin->sin_port;
>>  	sk_daddr_set(sk, daddr);
>> @@ -105,7 +109,7 @@ int dccp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
>>  	 * complete initialization after this.
>>  	 */
>>  	dccp_set_state(sk, DCCP_REQUESTING);
>> -	err = inet_hash_connect(&dccp_death_row, sk);
>> +	err = inet_hash_connect(&dccp_death_row, sk, prev_inaddr_any);
>>  	if (err != 0)
>>  		goto failure;
>>
>> diff --git a/net/dccp/ipv6.c b/net/dccp/ipv6.c
>> index fd44638ec16b..03013522acab 100644
>> --- a/net/dccp/ipv6.c
>> +++ b/net/dccp/ipv6.c
>> @@ -824,6 +824,7 @@ static int dccp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
>>  	struct ipv6_pinfo *np = inet6_sk(sk);
>>  	struct dccp_sock *dp = dccp_sk(sk);
>>  	struct in6_addr *saddr = NULL, *final_p, final;
>> +	bool prev_inaddr_any = false;
>>  	struct ipv6_txoptions *opt;
>>  	struct flowi6 fl6;
>>  	struct dst_entry *dst;
>> @@ -936,6 +937,7 @@ static int dccp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
>>  	if (saddr == NULL) {
>>  		saddr = &fl6.saddr;
>>  		sk->sk_v6_rcv_saddr = *saddr;
>> +		prev_inaddr_any = true;
>>  	}
>>
>>  	/* set the source address */
>> @@ -951,7 +953,7 @@ static int dccp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
>>  	inet->inet_dport = usin->sin6_port;
>>
>>  	dccp_set_state(sk, DCCP_REQUESTING);
>> -	err = inet6_hash_connect(&dccp_death_row, sk);
>> +	err = inet6_hash_connect(&dccp_death_row, sk, prev_inaddr_any);
>>  	if (err)
>>  		goto late_failure;
>>
>> diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c
>> index 93da9f783bec..ad627a99ff9d 100644
>> --- a/net/ipv4/af_inet.c
>> +++ b/net/ipv4/af_inet.c
>> @@ -1221,10 +1221,11 @@ static int inet_sk_reselect_saddr(struct sock *sk)
>>  	struct inet_sock *inet = inet_sk(sk);
>>  	__be32 old_saddr = inet->inet_saddr;
>>  	__be32 daddr = inet->inet_daddr;
>> +	struct ip_options_rcu *inet_opt;
>>  	struct flowi4 *fl4;
>>  	struct rtable *rt;
>>  	__be32 new_saddr;
>> -	struct ip_options_rcu *inet_opt;
>> +	int err;
>>
>>  	inet_opt = rcu_dereference_protected(inet->inet_opt,
>>  					     lockdep_sock_is_held(sk));
>> @@ -1253,6 +1254,10 @@ static int inet_sk_reselect_saddr(struct sock *sk)
>>
>>  	inet->inet_saddr = inet->inet_rcv_saddr = new_saddr;
>>
>> +	err = inet_bhash2_update_saddr(sk);
>> +	if (err)
>> +		return err;
>> +
>>  	/*
>>  	 * XXX The only one ugly spot where we need to
>>  	 * XXX really change the sockets identity after
>> diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c
>> index e8de5e699b3f..592b70663a3b 100644
>> --- a/net/ipv4/inet_hashtables.c
>> +++ b/net/ipv4/inet_hashtables.c
>> @@ -826,6 +826,55 @@ inet_bind2_bucket_find(struct inet_hashinfo *hinfo, struct net *net,
>>  	return bhash2;
>>  }
>>
>> +/* the lock for the socket's corresponding bhash entry must be held */
>> +static int __inet_bhash2_update_saddr(struct sock *sk,
>> +				      struct inet_hashinfo *hinfo,
>> +				      struct net *net, int port, int l3mdev)
>> +{
>> +	struct inet_bind2_hashbucket *head2;
>> +	struct inet_bind2_bucket *tb2;
>> +
>> +	tb2 = inet_bind2_bucket_find(hinfo, net, port, l3mdev, sk,
>> +				     &head2);
>> +	if (!tb2) {
>> +		tb2 = inet_bind2_bucket_create(hinfo->bind2_bucket_cachep,
>> +					       net, head2, port, l3mdev, sk);
>> +		if (!tb2)
>> +			return -ENOMEM;
>> +	}
>> +
>> +	/* Remove the socket's old entry from bhash2 */
>> +	__sk_del_bind2_node(sk);
>> +
>> +	sk_add_bind2_node(sk, &tb2->owners);
>> +	inet_csk(sk)->icsk_bind2_hash = tb2;
>> +
>> +	return 0;
>> +}
>> +
>> +/* This should be called if/when a socket's rcv saddr changes after it has
>> + * been binded.
>> + */
>> +int inet_bhash2_update_saddr(struct sock *sk)
>> +{
>> +	struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo;
>> +	int l3mdev = inet_sk_bound_l3mdev(sk);
>> +	struct inet_bind_hashbucket *head;
>> +	int port = inet_sk(sk)->inet_num;
>> +	struct net *net = sock_net(sk);
>> +	int err;
>> +
>> +	head = &hinfo->bhash[inet_bhashfn(net, port, hinfo->bhash_size)];
>> +
>> +	spin_lock_bh(&head->lock);
>> +
>> +	err = __inet_bhash2_update_saddr(sk, hinfo, net, port, l3mdev);
>> +
>> +	spin_unlock_bh(&head->lock);
>> +
>> +	return err;
>> +}
>> +
>>  /* RFC 6056 3.3.4.  Algorithm 4: Double-Hash Port Selection Algorithm
>>   * Note that we use 32bit integers (vs RFC 'short integers')
>>   * because 2^16 is not a multiple of num_ephemeral and this
>> @@ -840,7 +889,7 @@ inet_bind2_bucket_find(struct inet_hashinfo *hinfo, struct net *net,
>>  static u32 *table_perturb;
>>
>>  int __inet_hash_connect(struct inet_timewait_death_row *death_row,
>> -		struct sock *sk, u64 port_offset,
>> +		struct sock *sk, u64 port_offset, bool prev_inaddr_any,
>>  		int (*check_established)(struct inet_timewait_death_row *,
>>  			struct sock *, __u16, struct inet_timewait_sock **))
>>  {
>> @@ -858,11 +907,24 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
>>  	int l3mdev;
>>  	u32 index;
>>
>> +	l3mdev = inet_sk_bound_l3mdev(sk);
>> +
>>  	if (port) {
>>  		head = &hinfo->bhash[inet_bhashfn(net, port,
>>  						  hinfo->bhash_size)];
>>  		tb = inet_csk(sk)->icsk_bind_hash;
>> +
>>  		spin_lock_bh(&head->lock);
>> +
>> +		if (prev_inaddr_any) {
>> +			ret = __inet_bhash2_update_saddr(sk, hinfo, net, port,
>> +							 l3mdev);
>> +			if (ret) {
>> +				spin_unlock_bh(&head->lock);
>> +				return ret;
>> +			}
>> +		}
>> +
>>  		if (sk_head(&tb->owners) == sk && !sk->sk_bind_node.next) {
>>  			inet_ehash_nolisten(sk, NULL, NULL);
>>  			spin_unlock_bh(&head->lock);
>> @@ -875,8 +937,6 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
>>  		return ret;
>>  	}
>>
>> -	l3mdev = inet_sk_bound_l3mdev(sk);
>> -
>>  	inet_get_local_port_range(net, &low, &high);
>>  	high++; /* [32768, 60999] -> [32768, 61000[ */
>>  	remaining = high - low;
>> @@ -987,13 +1047,13 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
>>   * Bind a port for a connect operation and hash it.
>>   */
>>  int inet_hash_connect(struct inet_timewait_death_row *death_row,
>> -		      struct sock *sk)
>> +		      struct sock *sk, bool prev_inaddr_any)
>>  {
>>  	u64 port_offset = 0;
>>
>>  	if (!inet_sk(sk)->inet_num)
>>  		port_offset = inet_sk_port_offset(sk);
>> -	return __inet_hash_connect(death_row, sk, port_offset,
>> +	return __inet_hash_connect(death_row, sk, port_offset, prev_inaddr_any,
>>  				   __inet_check_established);
>>  }
>>  EXPORT_SYMBOL_GPL(inet_hash_connect);
>> diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
>> index dac2650f3863..adf8d750933d 100644
>> --- a/net/ipv4/tcp_ipv4.c
>> +++ b/net/ipv4/tcp_ipv4.c
>> @@ -203,6 +203,7 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
>>  	struct inet_sock *inet = inet_sk(sk);
>>  	struct tcp_sock *tp = tcp_sk(sk);
>>  	__be16 orig_sport, orig_dport;
>> +	bool prev_inaddr_any = false;
>>  	__be32 daddr, nexthop;
>>  	struct flowi4 *fl4;
>>  	struct rtable *rt;
>> @@ -246,8 +247,11 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
>>  	if (!inet_opt || !inet_opt->opt.srr)
>>  		daddr = fl4->daddr;
>>
>> -	if (!inet->inet_saddr)
>> +	if (!inet->inet_saddr) {
>>  		inet->inet_saddr = fl4->saddr;
>> +		prev_inaddr_any = true;
>> +	}
>> +
>>  	sk_rcv_saddr_set(sk, inet->inet_saddr);
>>
>>  	if (tp->rx_opt.ts_recent_stamp && inet->inet_daddr != daddr) {
>> @@ -273,7 +277,7 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
>>  	 * complete initialization after this.
>>  	 */
>>  	tcp_set_state(sk, TCP_SYN_SENT);
>> -	err = inet_hash_connect(tcp_death_row, sk);
>> +	err = inet_hash_connect(tcp_death_row, sk, prev_inaddr_any);
>>  	if (err)
>>  		goto failure;
>>
>> diff --git a/net/ipv6/inet6_hashtables.c b/net/ipv6/inet6_hashtables.c
>> index 7d53d62783b1..c87c5933f3be 100644
>> --- a/net/ipv6/inet6_hashtables.c
>> +++ b/net/ipv6/inet6_hashtables.c
>> @@ -317,13 +317,13 @@ static u64 inet6_sk_port_offset(const struct sock *sk)
>>  }
>>
>>  int inet6_hash_connect(struct inet_timewait_death_row *death_row,
>> -		       struct sock *sk)
>> +		       struct sock *sk, bool prev_inaddr_any)
>>  {
>>  	u64 port_offset = 0;
>>
>>  	if (!inet_sk(sk)->inet_num)
>>  		port_offset = inet6_sk_port_offset(sk);
>> -	return __inet_hash_connect(death_row, sk, port_offset,
>> +	return __inet_hash_connect(death_row, sk, port_offset, prev_inaddr_any,
>>  				   __inet6_check_established);
>>  }
>>  EXPORT_SYMBOL_GPL(inet6_hash_connect);
>> diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c
>> index f37dd4aa91c6..81e3312c2a97 100644
>> --- a/net/ipv6/tcp_ipv6.c
>> +++ b/net/ipv6/tcp_ipv6.c
>> @@ -152,6 +152,7 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
>>  	struct ipv6_pinfo *np = tcp_inet6_sk(sk);
>>  	struct tcp_sock *tp = tcp_sk(sk);
>>  	struct in6_addr *saddr = NULL, *final_p, final;
>> +	bool prev_inaddr_any = false;
>>  	struct ipv6_txoptions *opt;
>>  	struct flowi6 fl6;
>>  	struct dst_entry *dst;
>> @@ -289,6 +290,7 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
>>  	if (!saddr) {
>>  		saddr = &fl6.saddr;
>>  		sk->sk_v6_rcv_saddr = *saddr;
>> +		prev_inaddr_any = true;
>>  	}
>>
>>  	/* set the source address */
>> @@ -309,7 +311,7 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
>>
>>  	tcp_set_state(sk, TCP_SYN_SENT);
>>  	tcp_death_row = sock_net(sk)->ipv4.tcp_death_row;
>> -	err = inet6_hash_connect(tcp_death_row, sk);
>> +	err = inet6_hash_connect(tcp_death_row, sk, prev_inaddr_any);
>>  	if (err)
>>  		goto late_failure;
>>
>
> I'm sorry for the late notice, but it looks like that the mptcp
> syzkaller instance is still hitting the Warning in icsk_get_port on top
> of the v1 of this series:
>
> https://github.com/multipath-tcp/mptcp_net-next/issues/279
>
> and the change in v2 should not address that. @Mat could you please
> confirm the above?

Yes, I did see the icsk_get_port warning one time between June 1 and today 
with the v1 patch applied. I'll restart syzkaller with latest net-next and 
v2 just to be sure.

>
> Dumb question: I don't understand how the locking in bhash2 works.
> Could you explain that?
>
> What happens when 2 different processes bind different sockets on
> different ports (with different bhash buckets) using different
> addresses so that they hit the same bhash2 bucket? AFAICS each process
> will use a different lock and access/modification to bhash2 could
> happen simultaneusly?
>
> Thanks!
>
> Paolo
>
>
>

--
Mat Martineau
Intel
Joanne Koong June 7, 2022, 8:24 p.m. UTC | #3
On Tue, Jun 7, 2022 at 1:33 AM Paolo Abeni <pabeni@redhat.com> wrote:
>
> Hello,
>
> On Thu, 2022-06-02 at 09:51 -0700, Joanne Koong wrote:
> > Commit d5a42de8bdbe ("net: Add a second bind table hashed by port and
> > address") added a second bind table, bhash2, that hashes by a socket's port
> > and rcv address.
> >
> > However, there are two cases where the socket's rcv saddr can change
> > after it has been bound:
> >
> > 1) The case where there is a bind() call on "::" (IPADDR_ANY) and then
> > a connect() call. The kernel will assign the socket an address when it
> > handles the connect()
> >
> > 2) In inet_sk_reselect_saddr(), which is called when rerouting fails
> > when rebuilding the sk header (invoked by inet_sk_rebuild_header)
> >
> > In these two cases, we need to update the bhash2 table by removing the
> > entry for the old address, and adding a new entry reflecting the updated
> > address.
> >
> > Reported-by: syzbot+015d756bbd1f8b5c8f09@syzkaller.appspotmail.com
> > Fixes: d5a42de8bdbe ("net: Add a second bind table hashed by port and address")
> > Signed-off-by: Joanne Koong <joannelkoong@gmail.com>
> > Reviewed-by: Eric Dumazet <edumazet@google.com>
> > ---
> >  include/net/inet_hashtables.h |  6 ++-
> >  include/net/ipv6.h            |  2 +-
> >  net/dccp/ipv4.c               | 10 +++--
> >  net/dccp/ipv6.c               |  4 +-
> >  net/ipv4/af_inet.c            |  7 +++-
> >  net/ipv4/inet_hashtables.c    | 70 ++++++++++++++++++++++++++++++++---
> >  net/ipv4/tcp_ipv4.c           |  8 +++-
> >  net/ipv6/inet6_hashtables.c   |  4 +-
> >  net/ipv6/tcp_ipv6.c           |  4 +-
> >  9 files changed, 97 insertions(+), 18 deletions(-)
> >
> > diff --git a/include/net/inet_hashtables.h b/include/net/inet_hashtables.h
> > index a0887b70967b..2c331ce6ca73 100644
> > --- a/include/net/inet_hashtables.h
> > +++ b/include/net/inet_hashtables.h
> > @@ -448,11 +448,13 @@ static inline void sk_rcv_saddr_set(struct sock *sk, __be32 addr)
> >  }
> >
> >  int __inet_hash_connect(struct inet_timewait_death_row *death_row,
> > -                     struct sock *sk, u64 port_offset,
> > +                     struct sock *sk, u64 port_offset, bool prev_inaddr_any,
> >                       int (*check_established)(struct inet_timewait_death_row *,
> >                                                struct sock *, __u16,
> >                                                struct inet_timewait_sock **));
> >
> >  int inet_hash_connect(struct inet_timewait_death_row *death_row,
> > -                   struct sock *sk);
> > +                   struct sock *sk, bool prev_inaddr_any);
> > +
> > +int inet_bhash2_update_saddr(struct sock *sk);
> >  #endif /* _INET_HASHTABLES_H */
> > diff --git a/include/net/ipv6.h b/include/net/ipv6.h
> > index 5b38bf1a586b..6a50aca56d50 100644
> > --- a/include/net/ipv6.h
> > +++ b/include/net/ipv6.h
> > @@ -1187,7 +1187,7 @@ int inet6_compat_ioctl(struct socket *sock, unsigned int cmd,
> >               unsigned long arg);
> >
> >  int inet6_hash_connect(struct inet_timewait_death_row *death_row,
> > -                           struct sock *sk);
> > +                    struct sock *sk, bool prev_inaddr_any);
> >  int inet6_sendmsg(struct socket *sock, struct msghdr *msg, size_t size);
> >  int inet6_recvmsg(struct socket *sock, struct msghdr *msg, size_t size,
> >                 int flags);
> > diff --git a/net/dccp/ipv4.c b/net/dccp/ipv4.c
> > index da6e3b20cd75..37a8bc3ee49e 100644
> > --- a/net/dccp/ipv4.c
> > +++ b/net/dccp/ipv4.c
> > @@ -47,12 +47,13 @@ int dccp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
> >       const struct sockaddr_in *usin = (struct sockaddr_in *)uaddr;
> >       struct inet_sock *inet = inet_sk(sk);
> >       struct dccp_sock *dp = dccp_sk(sk);
> > +     struct ip_options_rcu *inet_opt;
> >       __be16 orig_sport, orig_dport;
> > +     bool prev_inaddr_any = false;
> >       __be32 daddr, nexthop;
> >       struct flowi4 *fl4;
> >       struct rtable *rt;
> >       int err;
> > -     struct ip_options_rcu *inet_opt;
> >
> >       dp->dccps_role = DCCP_ROLE_CLIENT;
> >
> > @@ -89,8 +90,11 @@ int dccp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
> >       if (inet_opt == NULL || !inet_opt->opt.srr)
> >               daddr = fl4->daddr;
> >
> > -     if (inet->inet_saddr == 0)
> > +     if (inet->inet_saddr == 0) {
> >               inet->inet_saddr = fl4->saddr;
> > +             prev_inaddr_any = true;
> > +     }
> > +
> >       sk_rcv_saddr_set(sk, inet->inet_saddr);
> >       inet->inet_dport = usin->sin_port;
> >       sk_daddr_set(sk, daddr);
> > @@ -105,7 +109,7 @@ int dccp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
> >        * complete initialization after this.
> >        */
> >       dccp_set_state(sk, DCCP_REQUESTING);
> > -     err = inet_hash_connect(&dccp_death_row, sk);
> > +     err = inet_hash_connect(&dccp_death_row, sk, prev_inaddr_any);
> >       if (err != 0)
> >               goto failure;
> >
> > diff --git a/net/dccp/ipv6.c b/net/dccp/ipv6.c
> > index fd44638ec16b..03013522acab 100644
> > --- a/net/dccp/ipv6.c
> > +++ b/net/dccp/ipv6.c
> > @@ -824,6 +824,7 @@ static int dccp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
> >       struct ipv6_pinfo *np = inet6_sk(sk);
> >       struct dccp_sock *dp = dccp_sk(sk);
> >       struct in6_addr *saddr = NULL, *final_p, final;
> > +     bool prev_inaddr_any = false;
> >       struct ipv6_txoptions *opt;
> >       struct flowi6 fl6;
> >       struct dst_entry *dst;
> > @@ -936,6 +937,7 @@ static int dccp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
> >       if (saddr == NULL) {
> >               saddr = &fl6.saddr;
> >               sk->sk_v6_rcv_saddr = *saddr;
> > +             prev_inaddr_any = true;
> >       }
> >
> >       /* set the source address */
> > @@ -951,7 +953,7 @@ static int dccp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
> >       inet->inet_dport = usin->sin6_port;
> >
> >       dccp_set_state(sk, DCCP_REQUESTING);
> > -     err = inet6_hash_connect(&dccp_death_row, sk);
> > +     err = inet6_hash_connect(&dccp_death_row, sk, prev_inaddr_any);
> >       if (err)
> >               goto late_failure;
> >
> > diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c
> > index 93da9f783bec..ad627a99ff9d 100644
> > --- a/net/ipv4/af_inet.c
> > +++ b/net/ipv4/af_inet.c
> > @@ -1221,10 +1221,11 @@ static int inet_sk_reselect_saddr(struct sock *sk)
> >       struct inet_sock *inet = inet_sk(sk);
> >       __be32 old_saddr = inet->inet_saddr;
> >       __be32 daddr = inet->inet_daddr;
> > +     struct ip_options_rcu *inet_opt;
> >       struct flowi4 *fl4;
> >       struct rtable *rt;
> >       __be32 new_saddr;
> > -     struct ip_options_rcu *inet_opt;
> > +     int err;
> >
> >       inet_opt = rcu_dereference_protected(inet->inet_opt,
> >                                            lockdep_sock_is_held(sk));
> > @@ -1253,6 +1254,10 @@ static int inet_sk_reselect_saddr(struct sock *sk)
> >
> >       inet->inet_saddr = inet->inet_rcv_saddr = new_saddr;
> >
> > +     err = inet_bhash2_update_saddr(sk);
> > +     if (err)
> > +             return err;
> > +
> >       /*
> >        * XXX The only one ugly spot where we need to
> >        * XXX really change the sockets identity after
> > diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c
> > index e8de5e699b3f..592b70663a3b 100644
> > --- a/net/ipv4/inet_hashtables.c
> > +++ b/net/ipv4/inet_hashtables.c
> > @@ -826,6 +826,55 @@ inet_bind2_bucket_find(struct inet_hashinfo *hinfo, struct net *net,
> >       return bhash2;
> >  }
> >
> > +/* the lock for the socket's corresponding bhash entry must be held */
> > +static int __inet_bhash2_update_saddr(struct sock *sk,
> > +                                   struct inet_hashinfo *hinfo,
> > +                                   struct net *net, int port, int l3mdev)
> > +{
> > +     struct inet_bind2_hashbucket *head2;
> > +     struct inet_bind2_bucket *tb2;
> > +
> > +     tb2 = inet_bind2_bucket_find(hinfo, net, port, l3mdev, sk,
> > +                                  &head2);
> > +     if (!tb2) {
> > +             tb2 = inet_bind2_bucket_create(hinfo->bind2_bucket_cachep,
> > +                                            net, head2, port, l3mdev, sk);
> > +             if (!tb2)
> > +                     return -ENOMEM;
> > +     }
> > +
> > +     /* Remove the socket's old entry from bhash2 */
> > +     __sk_del_bind2_node(sk);
> > +
> > +     sk_add_bind2_node(sk, &tb2->owners);
> > +     inet_csk(sk)->icsk_bind2_hash = tb2;
> > +
> > +     return 0;
> > +}
> > +
> > +/* This should be called if/when a socket's rcv saddr changes after it has
> > + * been binded.
> > + */
> > +int inet_bhash2_update_saddr(struct sock *sk)
> > +{
> > +     struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo;
> > +     int l3mdev = inet_sk_bound_l3mdev(sk);
> > +     struct inet_bind_hashbucket *head;
> > +     int port = inet_sk(sk)->inet_num;
> > +     struct net *net = sock_net(sk);
> > +     int err;
> > +
> > +     head = &hinfo->bhash[inet_bhashfn(net, port, hinfo->bhash_size)];
> > +
> > +     spin_lock_bh(&head->lock);
> > +
> > +     err = __inet_bhash2_update_saddr(sk, hinfo, net, port, l3mdev);
> > +
> > +     spin_unlock_bh(&head->lock);
> > +
> > +     return err;
> > +}
> > +
> >  /* RFC 6056 3.3.4.  Algorithm 4: Double-Hash Port Selection Algorithm
> >   * Note that we use 32bit integers (vs RFC 'short integers')
> >   * because 2^16 is not a multiple of num_ephemeral and this
> > @@ -840,7 +889,7 @@ inet_bind2_bucket_find(struct inet_hashinfo *hinfo, struct net *net,
> >  static u32 *table_perturb;
> >
> >  int __inet_hash_connect(struct inet_timewait_death_row *death_row,
> > -             struct sock *sk, u64 port_offset,
> > +             struct sock *sk, u64 port_offset, bool prev_inaddr_any,
> >               int (*check_established)(struct inet_timewait_death_row *,
> >                       struct sock *, __u16, struct inet_timewait_sock **))
> >  {
> > @@ -858,11 +907,24 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
> >       int l3mdev;
> >       u32 index;
> >
> > +     l3mdev = inet_sk_bound_l3mdev(sk);
> > +
> >       if (port) {
> >               head = &hinfo->bhash[inet_bhashfn(net, port,
> >                                                 hinfo->bhash_size)];
> >               tb = inet_csk(sk)->icsk_bind_hash;
> > +
> >               spin_lock_bh(&head->lock);
> > +
> > +             if (prev_inaddr_any) {
> > +                     ret = __inet_bhash2_update_saddr(sk, hinfo, net, port,
> > +                                                      l3mdev);
> > +                     if (ret) {
> > +                             spin_unlock_bh(&head->lock);
> > +                             return ret;
> > +                     }
> > +             }
> > +
> >               if (sk_head(&tb->owners) == sk && !sk->sk_bind_node.next) {
> >                       inet_ehash_nolisten(sk, NULL, NULL);
> >                       spin_unlock_bh(&head->lock);
> > @@ -875,8 +937,6 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
> >               return ret;
> >       }
> >
> > -     l3mdev = inet_sk_bound_l3mdev(sk);
> > -
> >       inet_get_local_port_range(net, &low, &high);
> >       high++; /* [32768, 60999] -> [32768, 61000[ */
> >       remaining = high - low;
> > @@ -987,13 +1047,13 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
> >   * Bind a port for a connect operation and hash it.
> >   */
> >  int inet_hash_connect(struct inet_timewait_death_row *death_row,
> > -                   struct sock *sk)
> > +                   struct sock *sk, bool prev_inaddr_any)
> >  {
> >       u64 port_offset = 0;
> >
> >       if (!inet_sk(sk)->inet_num)
> >               port_offset = inet_sk_port_offset(sk);
> > -     return __inet_hash_connect(death_row, sk, port_offset,
> > +     return __inet_hash_connect(death_row, sk, port_offset, prev_inaddr_any,
> >                                  __inet_check_established);
> >  }
> >  EXPORT_SYMBOL_GPL(inet_hash_connect);
> > diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
> > index dac2650f3863..adf8d750933d 100644
> > --- a/net/ipv4/tcp_ipv4.c
> > +++ b/net/ipv4/tcp_ipv4.c
> > @@ -203,6 +203,7 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
> >       struct inet_sock *inet = inet_sk(sk);
> >       struct tcp_sock *tp = tcp_sk(sk);
> >       __be16 orig_sport, orig_dport;
> > +     bool prev_inaddr_any = false;
> >       __be32 daddr, nexthop;
> >       struct flowi4 *fl4;
> >       struct rtable *rt;
> > @@ -246,8 +247,11 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
> >       if (!inet_opt || !inet_opt->opt.srr)
> >               daddr = fl4->daddr;
> >
> > -     if (!inet->inet_saddr)
> > +     if (!inet->inet_saddr) {
> >               inet->inet_saddr = fl4->saddr;
> > +             prev_inaddr_any = true;
> > +     }
> > +
> >       sk_rcv_saddr_set(sk, inet->inet_saddr);
> >
> >       if (tp->rx_opt.ts_recent_stamp && inet->inet_daddr != daddr) {
> > @@ -273,7 +277,7 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
> >        * complete initialization after this.
> >        */
> >       tcp_set_state(sk, TCP_SYN_SENT);
> > -     err = inet_hash_connect(tcp_death_row, sk);
> > +     err = inet_hash_connect(tcp_death_row, sk, prev_inaddr_any);
> >       if (err)
> >               goto failure;
> >
> > diff --git a/net/ipv6/inet6_hashtables.c b/net/ipv6/inet6_hashtables.c
> > index 7d53d62783b1..c87c5933f3be 100644
> > --- a/net/ipv6/inet6_hashtables.c
> > +++ b/net/ipv6/inet6_hashtables.c
> > @@ -317,13 +317,13 @@ static u64 inet6_sk_port_offset(const struct sock *sk)
> >  }
> >
> >  int inet6_hash_connect(struct inet_timewait_death_row *death_row,
> > -                    struct sock *sk)
> > +                    struct sock *sk, bool prev_inaddr_any)
> >  {
> >       u64 port_offset = 0;
> >
> >       if (!inet_sk(sk)->inet_num)
> >               port_offset = inet6_sk_port_offset(sk);
> > -     return __inet_hash_connect(death_row, sk, port_offset,
> > +     return __inet_hash_connect(death_row, sk, port_offset, prev_inaddr_any,
> >                                  __inet6_check_established);
> >  }
> >  EXPORT_SYMBOL_GPL(inet6_hash_connect);
> > diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c
> > index f37dd4aa91c6..81e3312c2a97 100644
> > --- a/net/ipv6/tcp_ipv6.c
> > +++ b/net/ipv6/tcp_ipv6.c
> > @@ -152,6 +152,7 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
> >       struct ipv6_pinfo *np = tcp_inet6_sk(sk);
> >       struct tcp_sock *tp = tcp_sk(sk);
> >       struct in6_addr *saddr = NULL, *final_p, final;
> > +     bool prev_inaddr_any = false;
> >       struct ipv6_txoptions *opt;
> >       struct flowi6 fl6;
> >       struct dst_entry *dst;
> > @@ -289,6 +290,7 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
> >       if (!saddr) {
> >               saddr = &fl6.saddr;
> >               sk->sk_v6_rcv_saddr = *saddr;
> > +             prev_inaddr_any = true;
> >       }
> >
> >       /* set the source address */
> > @@ -309,7 +311,7 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
> >
> >       tcp_set_state(sk, TCP_SYN_SENT);
> >       tcp_death_row = sock_net(sk)->ipv4.tcp_death_row;
> > -     err = inet6_hash_connect(tcp_death_row, sk);
> > +     err = inet6_hash_connect(tcp_death_row, sk, prev_inaddr_any);
> >       if (err)
> >               goto late_failure;
> >
>
> I'm sorry for the late notice, but it looks like the mptcp
> syzkaller instance is still hitting the Warning in icsk_get_port on top
> of the v1 of this series:
>
> https://github.com/multipath-tcp/mptcp_net-next/issues/279
>
> and the change in v2 should not address that. @Mat could you please
> confirm the above?
>
> Dumb question: I don't understand how the locking in bhash2 works.
> Could you explain that?
>
> What happens when 2 different processes bind different sockets on
> different ports (with different bhash buckets) using different
> addresses so that they hit the same bhash2 bucket? AFAICS each process
> will use a different lock and access/modification to bhash2 could
> happen simultaneously?
Hi Paolo. Yes, I think you are correct here that there could be a
scenario where this happens. Unfortunately, I think this means the
bhash2 table will need its own lock. I will submit a follow-up for
this.
>
> Thanks!
>
> Paolo
>
>
Paolo Abeni June 8, 2022, 7:35 a.m. UTC | #4
On Tue, 2022-06-07 at 13:24 -0700, Joanne Koong wrote:
> On Tue, Jun 7, 2022 at 1:33 AM Paolo Abeni <pabeni@redhat.com> wrote:
> > 
> > Hello,
> > 
> > On Thu, 2022-06-02 at 09:51 -0700, Joanne Koong wrote:
> > > Commit d5a42de8bdbe ("net: Add a second bind table hashed by port and
> > > address") added a second bind table, bhash2, that hashes by a socket's port
> > > and rcv address.
> > > 
> > > However, there are two cases where the socket's rcv saddr can change
> > > after it has been bound:
> > > 
> > > 1) The case where there is a bind() call on "::" (IPADDR_ANY) and then
> > > a connect() call. The kernel will assign the socket an address when it
> > > handles the connect()
> > > 
> > > 2) In inet_sk_reselect_saddr(), which is called when rerouting fails
> > > when rebuilding the sk header (invoked by inet_sk_rebuild_header)
> > > 
> > > In these two cases, we need to update the bhash2 table by removing the
> > > entry for the old address, and adding a new entry reflecting the updated
> > > address.
> > > 
> > > Reported-by: syzbot+015d756bbd1f8b5c8f09@syzkaller.appspotmail.com
> > > Fixes: d5a42de8bdbe ("net: Add a second bind table hashed by port and address")
> > > Signed-off-by: Joanne Koong <joannelkoong@gmail.com>
> > > Reviewed-by: Eric Dumazet <edumazet@google.com>
> > > ---
> > >  include/net/inet_hashtables.h |  6 ++-
> > >  include/net/ipv6.h            |  2 +-
> > >  net/dccp/ipv4.c               | 10 +++--
> > >  net/dccp/ipv6.c               |  4 +-
> > >  net/ipv4/af_inet.c            |  7 +++-
> > >  net/ipv4/inet_hashtables.c    | 70 ++++++++++++++++++++++++++++++++---
> > >  net/ipv4/tcp_ipv4.c           |  8 +++-
> > >  net/ipv6/inet6_hashtables.c   |  4 +-
> > >  net/ipv6/tcp_ipv6.c           |  4 +-
> > >  9 files changed, 97 insertions(+), 18 deletions(-)
> > > 
> > > diff --git a/include/net/inet_hashtables.h b/include/net/inet_hashtables.h
> > > index a0887b70967b..2c331ce6ca73 100644
> > > --- a/include/net/inet_hashtables.h
> > > +++ b/include/net/inet_hashtables.h
> > > @@ -448,11 +448,13 @@ static inline void sk_rcv_saddr_set(struct sock *sk, __be32 addr)
> > >  }
> > > 
> > >  int __inet_hash_connect(struct inet_timewait_death_row *death_row,
> > > -                     struct sock *sk, u64 port_offset,
> > > +                     struct sock *sk, u64 port_offset, bool prev_inaddr_any,
> > >                       int (*check_established)(struct inet_timewait_death_row *,
> > >                                                struct sock *, __u16,
> > >                                                struct inet_timewait_sock **));
> > > 
> > >  int inet_hash_connect(struct inet_timewait_death_row *death_row,
> > > -                   struct sock *sk);
> > > +                   struct sock *sk, bool prev_inaddr_any);
> > > +
> > > +int inet_bhash2_update_saddr(struct sock *sk);
> > >  #endif /* _INET_HASHTABLES_H */
> > > diff --git a/include/net/ipv6.h b/include/net/ipv6.h
> > > index 5b38bf1a586b..6a50aca56d50 100644
> > > --- a/include/net/ipv6.h
> > > +++ b/include/net/ipv6.h
> > > @@ -1187,7 +1187,7 @@ int inet6_compat_ioctl(struct socket *sock, unsigned int cmd,
> > >               unsigned long arg);
> > > 
> > >  int inet6_hash_connect(struct inet_timewait_death_row *death_row,
> > > -                           struct sock *sk);
> > > +                    struct sock *sk, bool prev_inaddr_any);
> > >  int inet6_sendmsg(struct socket *sock, struct msghdr *msg, size_t size);
> > >  int inet6_recvmsg(struct socket *sock, struct msghdr *msg, size_t size,
> > >                 int flags);
> > > diff --git a/net/dccp/ipv4.c b/net/dccp/ipv4.c
> > > index da6e3b20cd75..37a8bc3ee49e 100644
> > > --- a/net/dccp/ipv4.c
> > > +++ b/net/dccp/ipv4.c
> > > @@ -47,12 +47,13 @@ int dccp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
> > >       const struct sockaddr_in *usin = (struct sockaddr_in *)uaddr;
> > >       struct inet_sock *inet = inet_sk(sk);
> > >       struct dccp_sock *dp = dccp_sk(sk);
> > > +     struct ip_options_rcu *inet_opt;
> > >       __be16 orig_sport, orig_dport;
> > > +     bool prev_inaddr_any = false;
> > >       __be32 daddr, nexthop;
> > >       struct flowi4 *fl4;
> > >       struct rtable *rt;
> > >       int err;
> > > -     struct ip_options_rcu *inet_opt;
> > > 
> > >       dp->dccps_role = DCCP_ROLE_CLIENT;
> > > 
> > > @@ -89,8 +90,11 @@ int dccp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
> > >       if (inet_opt == NULL || !inet_opt->opt.srr)
> > >               daddr = fl4->daddr;
> > > 
> > > -     if (inet->inet_saddr == 0)
> > > +     if (inet->inet_saddr == 0) {
> > >               inet->inet_saddr = fl4->saddr;
> > > +             prev_inaddr_any = true;
> > > +     }
> > > +
> > >       sk_rcv_saddr_set(sk, inet->inet_saddr);
> > >       inet->inet_dport = usin->sin_port;
> > >       sk_daddr_set(sk, daddr);
> > > @@ -105,7 +109,7 @@ int dccp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
> > >        * complete initialization after this.
> > >        */
> > >       dccp_set_state(sk, DCCP_REQUESTING);
> > > -     err = inet_hash_connect(&dccp_death_row, sk);
> > > +     err = inet_hash_connect(&dccp_death_row, sk, prev_inaddr_any);
> > >       if (err != 0)
> > >               goto failure;
> > > 
> > > diff --git a/net/dccp/ipv6.c b/net/dccp/ipv6.c
> > > index fd44638ec16b..03013522acab 100644
> > > --- a/net/dccp/ipv6.c
> > > +++ b/net/dccp/ipv6.c
> > > @@ -824,6 +824,7 @@ static int dccp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
> > >       struct ipv6_pinfo *np = inet6_sk(sk);
> > >       struct dccp_sock *dp = dccp_sk(sk);
> > >       struct in6_addr *saddr = NULL, *final_p, final;
> > > +     bool prev_inaddr_any = false;
> > >       struct ipv6_txoptions *opt;
> > >       struct flowi6 fl6;
> > >       struct dst_entry *dst;
> > > @@ -936,6 +937,7 @@ static int dccp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
> > >       if (saddr == NULL) {
> > >               saddr = &fl6.saddr;
> > >               sk->sk_v6_rcv_saddr = *saddr;
> > > +             prev_inaddr_any = true;
> > >       }
> > > 
> > >       /* set the source address */
> > > @@ -951,7 +953,7 @@ static int dccp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
> > >       inet->inet_dport = usin->sin6_port;
> > > 
> > >       dccp_set_state(sk, DCCP_REQUESTING);
> > > -     err = inet6_hash_connect(&dccp_death_row, sk);
> > > +     err = inet6_hash_connect(&dccp_death_row, sk, prev_inaddr_any);
> > >       if (err)
> > >               goto late_failure;
> > > 
> > > diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c
> > > index 93da9f783bec..ad627a99ff9d 100644
> > > --- a/net/ipv4/af_inet.c
> > > +++ b/net/ipv4/af_inet.c
> > > @@ -1221,10 +1221,11 @@ static int inet_sk_reselect_saddr(struct sock *sk)
> > >       struct inet_sock *inet = inet_sk(sk);
> > >       __be32 old_saddr = inet->inet_saddr;
> > >       __be32 daddr = inet->inet_daddr;
> > > +     struct ip_options_rcu *inet_opt;
> > >       struct flowi4 *fl4;
> > >       struct rtable *rt;
> > >       __be32 new_saddr;
> > > -     struct ip_options_rcu *inet_opt;
> > > +     int err;
> > > 
> > >       inet_opt = rcu_dereference_protected(inet->inet_opt,
> > >                                            lockdep_sock_is_held(sk));
> > > @@ -1253,6 +1254,10 @@ static int inet_sk_reselect_saddr(struct sock *sk)
> > > 
> > >       inet->inet_saddr = inet->inet_rcv_saddr = new_saddr;
> > > 
> > > +     err = inet_bhash2_update_saddr(sk);
> > > +     if (err)
> > > +             return err;
> > > +
> > >       /*
> > >        * XXX The only one ugly spot where we need to
> > >        * XXX really change the sockets identity after
> > > diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c
> > > index e8de5e699b3f..592b70663a3b 100644
> > > --- a/net/ipv4/inet_hashtables.c
> > > +++ b/net/ipv4/inet_hashtables.c
> > > @@ -826,6 +826,55 @@ inet_bind2_bucket_find(struct inet_hashinfo *hinfo, struct net *net,
> > >       return bhash2;
> > >  }
> > > 
> > > +/* the lock for the socket's corresponding bhash entry must be held */
> > > +static int __inet_bhash2_update_saddr(struct sock *sk,
> > > +                                   struct inet_hashinfo *hinfo,
> > > +                                   struct net *net, int port, int l3mdev)
> > > +{
> > > +     struct inet_bind2_hashbucket *head2;
> > > +     struct inet_bind2_bucket *tb2;
> > > +
> > > +     tb2 = inet_bind2_bucket_find(hinfo, net, port, l3mdev, sk,
> > > +                                  &head2);
> > > +     if (!tb2) {
> > > +             tb2 = inet_bind2_bucket_create(hinfo->bind2_bucket_cachep,
> > > +                                            net, head2, port, l3mdev, sk);
> > > +             if (!tb2)
> > > +                     return -ENOMEM;
> > > +     }
> > > +
> > > +     /* Remove the socket's old entry from bhash2 */
> > > +     __sk_del_bind2_node(sk);
> > > +
> > > +     sk_add_bind2_node(sk, &tb2->owners);
> > > +     inet_csk(sk)->icsk_bind2_hash = tb2;
> > > +
> > > +     return 0;
> > > +}
> > > +
> > > +/* This should be called if/when a socket's rcv saddr changes after it has
> > > + * been binded.
> > > + */
> > > +int inet_bhash2_update_saddr(struct sock *sk)
> > > +{
> > > +     struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo;
> > > +     int l3mdev = inet_sk_bound_l3mdev(sk);
> > > +     struct inet_bind_hashbucket *head;
> > > +     int port = inet_sk(sk)->inet_num;
> > > +     struct net *net = sock_net(sk);
> > > +     int err;
> > > +
> > > +     head = &hinfo->bhash[inet_bhashfn(net, port, hinfo->bhash_size)];
> > > +
> > > +     spin_lock_bh(&head->lock);
> > > +
> > > +     err = __inet_bhash2_update_saddr(sk, hinfo, net, port, l3mdev);
> > > +
> > > +     spin_unlock_bh(&head->lock);
> > > +
> > > +     return err;
> > > +}
> > > +
> > >  /* RFC 6056 3.3.4.  Algorithm 4: Double-Hash Port Selection Algorithm
> > >   * Note that we use 32bit integers (vs RFC 'short integers')
> > >   * because 2^16 is not a multiple of num_ephemeral and this
> > > @@ -840,7 +889,7 @@ inet_bind2_bucket_find(struct inet_hashinfo *hinfo, struct net *net,
> > >  static u32 *table_perturb;
> > > 
> > >  int __inet_hash_connect(struct inet_timewait_death_row *death_row,
> > > -             struct sock *sk, u64 port_offset,
> > > +             struct sock *sk, u64 port_offset, bool prev_inaddr_any,
> > >               int (*check_established)(struct inet_timewait_death_row *,
> > >                       struct sock *, __u16, struct inet_timewait_sock **))
> > >  {
> > > @@ -858,11 +907,24 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
> > >       int l3mdev;
> > >       u32 index;
> > > 
> > > +     l3mdev = inet_sk_bound_l3mdev(sk);
> > > +
> > >       if (port) {
> > >               head = &hinfo->bhash[inet_bhashfn(net, port,
> > >                                                 hinfo->bhash_size)];
> > >               tb = inet_csk(sk)->icsk_bind_hash;
> > > +
> > >               spin_lock_bh(&head->lock);
> > > +
> > > +             if (prev_inaddr_any) {
> > > +                     ret = __inet_bhash2_update_saddr(sk, hinfo, net, port,
> > > +                                                      l3mdev);
> > > +                     if (ret) {
> > > +                             spin_unlock_bh(&head->lock);
> > > +                             return ret;
> > > +                     }
> > > +             }
> > > +
> > >               if (sk_head(&tb->owners) == sk && !sk->sk_bind_node.next) {
> > >                       inet_ehash_nolisten(sk, NULL, NULL);
> > >                       spin_unlock_bh(&head->lock);
> > > @@ -875,8 +937,6 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
> > >               return ret;
> > >       }
> > > 
> > > -     l3mdev = inet_sk_bound_l3mdev(sk);
> > > -
> > >       inet_get_local_port_range(net, &low, &high);
> > >       high++; /* [32768, 60999] -> [32768, 61000[ */
> > >       remaining = high - low;
> > > @@ -987,13 +1047,13 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
> > >   * Bind a port for a connect operation and hash it.
> > >   */
> > >  int inet_hash_connect(struct inet_timewait_death_row *death_row,
> > > -                   struct sock *sk)
> > > +                   struct sock *sk, bool prev_inaddr_any)
> > >  {
> > >       u64 port_offset = 0;
> > > 
> > >       if (!inet_sk(sk)->inet_num)
> > >               port_offset = inet_sk_port_offset(sk);
> > > -     return __inet_hash_connect(death_row, sk, port_offset,
> > > +     return __inet_hash_connect(death_row, sk, port_offset, prev_inaddr_any,
> > >                                  __inet_check_established);
> > >  }
> > >  EXPORT_SYMBOL_GPL(inet_hash_connect);
> > > diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
> > > index dac2650f3863..adf8d750933d 100644
> > > --- a/net/ipv4/tcp_ipv4.c
> > > +++ b/net/ipv4/tcp_ipv4.c
> > > @@ -203,6 +203,7 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
> > >       struct inet_sock *inet = inet_sk(sk);
> > >       struct tcp_sock *tp = tcp_sk(sk);
> > >       __be16 orig_sport, orig_dport;
> > > +     bool prev_inaddr_any = false;
> > >       __be32 daddr, nexthop;
> > >       struct flowi4 *fl4;
> > >       struct rtable *rt;
> > > @@ -246,8 +247,11 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
> > >       if (!inet_opt || !inet_opt->opt.srr)
> > >               daddr = fl4->daddr;
> > > 
> > > -     if (!inet->inet_saddr)
> > > +     if (!inet->inet_saddr) {
> > >               inet->inet_saddr = fl4->saddr;
> > > +             prev_inaddr_any = true;
> > > +     }
> > > +
> > >       sk_rcv_saddr_set(sk, inet->inet_saddr);
> > > 
> > >       if (tp->rx_opt.ts_recent_stamp && inet->inet_daddr != daddr) {
> > > @@ -273,7 +277,7 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
> > >        * complete initialization after this.
> > >        */
> > >       tcp_set_state(sk, TCP_SYN_SENT);
> > > -     err = inet_hash_connect(tcp_death_row, sk);
> > > +     err = inet_hash_connect(tcp_death_row, sk, prev_inaddr_any);
> > >       if (err)
> > >               goto failure;
> > > 
> > > diff --git a/net/ipv6/inet6_hashtables.c b/net/ipv6/inet6_hashtables.c
> > > index 7d53d62783b1..c87c5933f3be 100644
> > > --- a/net/ipv6/inet6_hashtables.c
> > > +++ b/net/ipv6/inet6_hashtables.c
> > > @@ -317,13 +317,13 @@ static u64 inet6_sk_port_offset(const struct sock *sk)
> > >  }
> > > 
> > >  int inet6_hash_connect(struct inet_timewait_death_row *death_row,
> > > -                    struct sock *sk)
> > > +                    struct sock *sk, bool prev_inaddr_any)
> > >  {
> > >       u64 port_offset = 0;
> > > 
> > >       if (!inet_sk(sk)->inet_num)
> > >               port_offset = inet6_sk_port_offset(sk);
> > > -     return __inet_hash_connect(death_row, sk, port_offset,
> > > +     return __inet_hash_connect(death_row, sk, port_offset, prev_inaddr_any,
> > >                                  __inet6_check_established);
> > >  }
> > >  EXPORT_SYMBOL_GPL(inet6_hash_connect);
> > > diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c
> > > index f37dd4aa91c6..81e3312c2a97 100644
> > > --- a/net/ipv6/tcp_ipv6.c
> > > +++ b/net/ipv6/tcp_ipv6.c
> > > @@ -152,6 +152,7 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
> > >       struct ipv6_pinfo *np = tcp_inet6_sk(sk);
> > >       struct tcp_sock *tp = tcp_sk(sk);
> > >       struct in6_addr *saddr = NULL, *final_p, final;
> > > +     bool prev_inaddr_any = false;
> > >       struct ipv6_txoptions *opt;
> > >       struct flowi6 fl6;
> > >       struct dst_entry *dst;
> > > @@ -289,6 +290,7 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
> > >       if (!saddr) {
> > >               saddr = &fl6.saddr;
> > >               sk->sk_v6_rcv_saddr = *saddr;
> > > +             prev_inaddr_any = true;
> > >       }
> > > 
> > >       /* set the source address */
> > > @@ -309,7 +311,7 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
> > > 
> > >       tcp_set_state(sk, TCP_SYN_SENT);
> > >       tcp_death_row = sock_net(sk)->ipv4.tcp_death_row;
> > > -     err = inet6_hash_connect(tcp_death_row, sk);
> > > +     err = inet6_hash_connect(tcp_death_row, sk, prev_inaddr_any);
> > >       if (err)
> > >               goto late_failure;
> > > 
> > 
> > I'm sorry for the late notice, but it looks like the mptcp
> > syzkaller instance is still hitting the warning in icsk_get_port on top
> > of v1 of this series:
> > 
> > https://github.com/multipath-tcp/mptcp_net-next/issues/279
> > 
> > and the change in v2 should not address that. @Mat could you please
> > confirm the above?
> > 
> > Dumb question: I don't understand how the locking in bhash2 works.
> > Could you explain that?
> > 
> > What happens when 2 different processes bind different sockets on
> > different ports (with different bhash buckets) using different
> > addresses so that they hit the same bhash2 bucket? AFAICS each process
> > will use a different lock and access/modification to bhash2 could
> > happen simultaneously?
> Hi Paolo. Yes, I think you are correct here that there could be a
> scenario where this happens. Unfortunately, I think this means the
> bhash2 table will need its own lock. I will submit a follow-up for
> this.
> 

I'm wondering if we could (and more importantly, if it would make any
sense to) use bhash2 only?

e.g. add a spinlock per row to bhash2, use it to protect row
manipulation, always hash into bhash2 and then drop bhash, similar to
commit cae3873c5b3a4fcd9706fb461ff4e91bdf1f0120?

Cheers,

Paolo
Joanne Koong June 8, 2022, 5:47 p.m. UTC | #5
On Wed, Jun 8, 2022 at 12:35 AM Paolo Abeni <pabeni@redhat.com> wrote:
>
> On Tue, 2022-06-07 at 13:24 -0700, Joanne Koong wrote:
> > On Tue, Jun 7, 2022 at 1:33 AM Paolo Abeni <pabeni@redhat.com> wrote:
> > >
> > > Hello,
> > >
> > > On Thu, 2022-06-02 at 09:51 -0700, Joanne Koong wrote:
> > > > Commit d5a42de8bdbe ("net: Add a second bind table hashed by port and
> > > > address") added a second bind table, bhash2, that hashes by a socket's port
> > > > and rcv address.
> > > >
> > > > However, there are two cases where the socket's rcv saddr can change
> > > > after it has been bound:
> > > >
> > > > 1) The case where there is a bind() call on "::" (IPADDR_ANY) and then
> > > > a connect() call. The kernel will assign the socket an address when it
> > > > handles the connect()
> > > >
> > > > 2) In inet_sk_reselect_saddr(), which is called when rerouting fails
> > > > when rebuilding the sk header (invoked by inet_sk_rebuild_header)
> > > >
> > > > In these two cases, we need to update the bhash2 table by removing the
> > > > entry for the old address, and adding a new entry reflecting the updated
> > > > address.
> > > >
> > > > Reported-by: syzbot+015d756bbd1f8b5c8f09@syzkaller.appspotmail.com
> > > > Fixes: d5a42de8bdbe ("net: Add a second bind table hashed by port and address")
> > > > Signed-off-by: Joanne Koong <joannelkoong@gmail.com>
> > > > Reviewed-by: Eric Dumazet <edumazet@google.com>
> > > > ---
> > > >  include/net/inet_hashtables.h |  6 ++-
> > > >  include/net/ipv6.h            |  2 +-
> > > >  net/dccp/ipv4.c               | 10 +++--
> > > >  net/dccp/ipv6.c               |  4 +-
> > > >  net/ipv4/af_inet.c            |  7 +++-
> > > >  net/ipv4/inet_hashtables.c    | 70 ++++++++++++++++++++++++++++++++---
> > > >  net/ipv4/tcp_ipv4.c           |  8 +++-
> > > >  net/ipv6/inet6_hashtables.c   |  4 +-
> > > >  net/ipv6/tcp_ipv6.c           |  4 +-
> > > >  9 files changed, 97 insertions(+), 18 deletions(-)
> > > >
> > > > diff --git a/include/net/inet_hashtables.h b/include/net/inet_hashtables.h
> > > > index a0887b70967b..2c331ce6ca73 100644
> > > > --- a/include/net/inet_hashtables.h
> > > > +++ b/include/net/inet_hashtables.h
> > > > @@ -448,11 +448,13 @@ static inline void sk_rcv_saddr_set(struct sock *sk, __be32 addr)
> > > >  }
> > > >
> > > >  int __inet_hash_connect(struct inet_timewait_death_row *death_row,
> > > > -                     struct sock *sk, u64 port_offset,
> > > > +                     struct sock *sk, u64 port_offset, bool prev_inaddr_any,
> > > >                       int (*check_established)(struct inet_timewait_death_row *,
> > > >                                                struct sock *, __u16,
> > > >                                                struct inet_timewait_sock **));
> > > >
> > > >  int inet_hash_connect(struct inet_timewait_death_row *death_row,
> > > > -                   struct sock *sk);
> > > > +                   struct sock *sk, bool prev_inaddr_any);
> > > > +
> > > > +int inet_bhash2_update_saddr(struct sock *sk);
> > > >  #endif /* _INET_HASHTABLES_H */
> > > > diff --git a/include/net/ipv6.h b/include/net/ipv6.h
> > > > index 5b38bf1a586b..6a50aca56d50 100644
> > > > --- a/include/net/ipv6.h
> > > > +++ b/include/net/ipv6.h
> > > > @@ -1187,7 +1187,7 @@ int inet6_compat_ioctl(struct socket *sock, unsigned int cmd,
> > > >               unsigned long arg);
> > > >
> > > >  int inet6_hash_connect(struct inet_timewait_death_row *death_row,
> > > > -                           struct sock *sk);
> > > > +                    struct sock *sk, bool prev_inaddr_any);
> > > >  int inet6_sendmsg(struct socket *sock, struct msghdr *msg, size_t size);
> > > >  int inet6_recvmsg(struct socket *sock, struct msghdr *msg, size_t size,
> > > >                 int flags);
> > > > diff --git a/net/dccp/ipv4.c b/net/dccp/ipv4.c
> > > > index da6e3b20cd75..37a8bc3ee49e 100644
> > > > --- a/net/dccp/ipv4.c
> > > > +++ b/net/dccp/ipv4.c
> > > > @@ -47,12 +47,13 @@ int dccp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
> > > >       const struct sockaddr_in *usin = (struct sockaddr_in *)uaddr;
> > > >       struct inet_sock *inet = inet_sk(sk);
> > > >       struct dccp_sock *dp = dccp_sk(sk);
> > > > +     struct ip_options_rcu *inet_opt;
> > > >       __be16 orig_sport, orig_dport;
> > > > +     bool prev_inaddr_any = false;
> > > >       __be32 daddr, nexthop;
> > > >       struct flowi4 *fl4;
> > > >       struct rtable *rt;
> > > >       int err;
> > > > -     struct ip_options_rcu *inet_opt;
> > > >
> > > >       dp->dccps_role = DCCP_ROLE_CLIENT;
> > > >
> > > > @@ -89,8 +90,11 @@ int dccp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
> > > >       if (inet_opt == NULL || !inet_opt->opt.srr)
> > > >               daddr = fl4->daddr;
> > > >
> > > > -     if (inet->inet_saddr == 0)
> > > > +     if (inet->inet_saddr == 0) {
> > > >               inet->inet_saddr = fl4->saddr;
> > > > +             prev_inaddr_any = true;
> > > > +     }
> > > > +
> > > >       sk_rcv_saddr_set(sk, inet->inet_saddr);
> > > >       inet->inet_dport = usin->sin_port;
> > > >       sk_daddr_set(sk, daddr);
> > > > @@ -105,7 +109,7 @@ int dccp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
> > > >        * complete initialization after this.
> > > >        */
> > > >       dccp_set_state(sk, DCCP_REQUESTING);
> > > > -     err = inet_hash_connect(&dccp_death_row, sk);
> > > > +     err = inet_hash_connect(&dccp_death_row, sk, prev_inaddr_any);
> > > >       if (err != 0)
> > > >               goto failure;
> > > >
> > > > diff --git a/net/dccp/ipv6.c b/net/dccp/ipv6.c
> > > > index fd44638ec16b..03013522acab 100644
> > > > --- a/net/dccp/ipv6.c
> > > > +++ b/net/dccp/ipv6.c
> > > > @@ -824,6 +824,7 @@ static int dccp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
> > > >       struct ipv6_pinfo *np = inet6_sk(sk);
> > > >       struct dccp_sock *dp = dccp_sk(sk);
> > > >       struct in6_addr *saddr = NULL, *final_p, final;
> > > > +     bool prev_inaddr_any = false;
> > > >       struct ipv6_txoptions *opt;
> > > >       struct flowi6 fl6;
> > > >       struct dst_entry *dst;
> > > > @@ -936,6 +937,7 @@ static int dccp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
> > > >       if (saddr == NULL) {
> > > >               saddr = &fl6.saddr;
> > > >               sk->sk_v6_rcv_saddr = *saddr;
> > > > +             prev_inaddr_any = true;
> > > >       }
> > > >
> > > >       /* set the source address */
> > > > @@ -951,7 +953,7 @@ static int dccp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
> > > >       inet->inet_dport = usin->sin6_port;
> > > >
> > > >       dccp_set_state(sk, DCCP_REQUESTING);
> > > > -     err = inet6_hash_connect(&dccp_death_row, sk);
> > > > +     err = inet6_hash_connect(&dccp_death_row, sk, prev_inaddr_any);
> > > >       if (err)
> > > >               goto late_failure;
> > > >
> > > > diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c
> > > > index 93da9f783bec..ad627a99ff9d 100644
> > > > --- a/net/ipv4/af_inet.c
> > > > +++ b/net/ipv4/af_inet.c
> > > > @@ -1221,10 +1221,11 @@ static int inet_sk_reselect_saddr(struct sock *sk)
> > > >       struct inet_sock *inet = inet_sk(sk);
> > > >       __be32 old_saddr = inet->inet_saddr;
> > > >       __be32 daddr = inet->inet_daddr;
> > > > +     struct ip_options_rcu *inet_opt;
> > > >       struct flowi4 *fl4;
> > > >       struct rtable *rt;
> > > >       __be32 new_saddr;
> > > > -     struct ip_options_rcu *inet_opt;
> > > > +     int err;
> > > >
> > > >       inet_opt = rcu_dereference_protected(inet->inet_opt,
> > > >                                            lockdep_sock_is_held(sk));
> > > > @@ -1253,6 +1254,10 @@ static int inet_sk_reselect_saddr(struct sock *sk)
> > > >
> > > >       inet->inet_saddr = inet->inet_rcv_saddr = new_saddr;
> > > >
> > > > +     err = inet_bhash2_update_saddr(sk);
> > > > +     if (err)
> > > > +             return err;
> > > > +
> > > >       /*
> > > >        * XXX The only one ugly spot where we need to
> > > >        * XXX really change the sockets identity after
> > > > diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c
> > > > index e8de5e699b3f..592b70663a3b 100644
> > > > --- a/net/ipv4/inet_hashtables.c
> > > > +++ b/net/ipv4/inet_hashtables.c
> > > > @@ -826,6 +826,55 @@ inet_bind2_bucket_find(struct inet_hashinfo *hinfo, struct net *net,
> > > >       return bhash2;
> > > >  }
> > > >
> > > > +/* the lock for the socket's corresponding bhash entry must be held */
> > > > +static int __inet_bhash2_update_saddr(struct sock *sk,
> > > > +                                   struct inet_hashinfo *hinfo,
> > > > +                                   struct net *net, int port, int l3mdev)
> > > > +{
> > > > +     struct inet_bind2_hashbucket *head2;
> > > > +     struct inet_bind2_bucket *tb2;
> > > > +
> > > > +     tb2 = inet_bind2_bucket_find(hinfo, net, port, l3mdev, sk,
> > > > +                                  &head2);
> > > > +     if (!tb2) {
> > > > +             tb2 = inet_bind2_bucket_create(hinfo->bind2_bucket_cachep,
> > > > +                                            net, head2, port, l3mdev, sk);
> > > > +             if (!tb2)
> > > > +                     return -ENOMEM;
> > > > +     }
> > > > +
> > > > +     /* Remove the socket's old entry from bhash2 */
> > > > +     __sk_del_bind2_node(sk);
> > > > +
> > > > +     sk_add_bind2_node(sk, &tb2->owners);
> > > > +     inet_csk(sk)->icsk_bind2_hash = tb2;
> > > > +
> > > > +     return 0;
> > > > +}
> > > > +
> > > > +/* This should be called if/when a socket's rcv saddr changes after it has
> > > > > + * been bound.
> > > > + */
> > > > +int inet_bhash2_update_saddr(struct sock *sk)
> > > > +{
> > > > +     struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo;
> > > > +     int l3mdev = inet_sk_bound_l3mdev(sk);
> > > > +     struct inet_bind_hashbucket *head;
> > > > +     int port = inet_sk(sk)->inet_num;
> > > > +     struct net *net = sock_net(sk);
> > > > +     int err;
> > > > +
> > > > +     head = &hinfo->bhash[inet_bhashfn(net, port, hinfo->bhash_size)];
> > > > +
> > > > +     spin_lock_bh(&head->lock);
> > > > +
> > > > +     err = __inet_bhash2_update_saddr(sk, hinfo, net, port, l3mdev);
> > > > +
> > > > +     spin_unlock_bh(&head->lock);
> > > > +
> > > > +     return err;
> > > > +}
> > > > +
> > > >  /* RFC 6056 3.3.4.  Algorithm 4: Double-Hash Port Selection Algorithm
> > > >   * Note that we use 32bit integers (vs RFC 'short integers')
> > > >   * because 2^16 is not a multiple of num_ephemeral and this
> > > > @@ -840,7 +889,7 @@ inet_bind2_bucket_find(struct inet_hashinfo *hinfo, struct net *net,
> > > >  static u32 *table_perturb;
> > > >
> > > >  int __inet_hash_connect(struct inet_timewait_death_row *death_row,
> > > > -             struct sock *sk, u64 port_offset,
> > > > +             struct sock *sk, u64 port_offset, bool prev_inaddr_any,
> > > >               int (*check_established)(struct inet_timewait_death_row *,
> > > >                       struct sock *, __u16, struct inet_timewait_sock **))
> > > >  {
> > > > @@ -858,11 +907,24 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
> > > >       int l3mdev;
> > > >       u32 index;
> > > >
> > > > +     l3mdev = inet_sk_bound_l3mdev(sk);
> > > > +
> > > >       if (port) {
> > > >               head = &hinfo->bhash[inet_bhashfn(net, port,
> > > >                                                 hinfo->bhash_size)];
> > > >               tb = inet_csk(sk)->icsk_bind_hash;
> > > > +
> > > >               spin_lock_bh(&head->lock);
> > > > +
> > > > +             if (prev_inaddr_any) {
> > > > +                     ret = __inet_bhash2_update_saddr(sk, hinfo, net, port,
> > > > +                                                      l3mdev);
> > > > +                     if (ret) {
> > > > +                             spin_unlock_bh(&head->lock);
> > > > +                             return ret;
> > > > +                     }
> > > > +             }
> > > > +
> > > >               if (sk_head(&tb->owners) == sk && !sk->sk_bind_node.next) {
> > > >                       inet_ehash_nolisten(sk, NULL, NULL);
> > > >                       spin_unlock_bh(&head->lock);
> > > > @@ -875,8 +937,6 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
> > > >               return ret;
> > > >       }
> > > >
> > > > -     l3mdev = inet_sk_bound_l3mdev(sk);
> > > > -
> > > >       inet_get_local_port_range(net, &low, &high);
> > > >       high++; /* [32768, 60999] -> [32768, 61000[ */
> > > >       remaining = high - low;
> > > > @@ -987,13 +1047,13 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
> > > >   * Bind a port for a connect operation and hash it.
> > > >   */
> > > >  int inet_hash_connect(struct inet_timewait_death_row *death_row,
> > > > -                   struct sock *sk)
> > > > +                   struct sock *sk, bool prev_inaddr_any)
> > > >  {
> > > >       u64 port_offset = 0;
> > > >
> > > >       if (!inet_sk(sk)->inet_num)
> > > >               port_offset = inet_sk_port_offset(sk);
> > > > -     return __inet_hash_connect(death_row, sk, port_offset,
> > > > +     return __inet_hash_connect(death_row, sk, port_offset, prev_inaddr_any,
> > > >                                  __inet_check_established);
> > > >  }
> > > >  EXPORT_SYMBOL_GPL(inet_hash_connect);
> > > > diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
> > > > index dac2650f3863..adf8d750933d 100644
> > > > --- a/net/ipv4/tcp_ipv4.c
> > > > +++ b/net/ipv4/tcp_ipv4.c
> > > > @@ -203,6 +203,7 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
> > > >       struct inet_sock *inet = inet_sk(sk);
> > > >       struct tcp_sock *tp = tcp_sk(sk);
> > > >       __be16 orig_sport, orig_dport;
> > > > +     bool prev_inaddr_any = false;
> > > >       __be32 daddr, nexthop;
> > > >       struct flowi4 *fl4;
> > > >       struct rtable *rt;
> > > > @@ -246,8 +247,11 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
> > > >       if (!inet_opt || !inet_opt->opt.srr)
> > > >               daddr = fl4->daddr;
> > > >
> > > > -     if (!inet->inet_saddr)
> > > > +     if (!inet->inet_saddr) {
> > > >               inet->inet_saddr = fl4->saddr;
> > > > +             prev_inaddr_any = true;
> > > > +     }
> > > > +
> > > >       sk_rcv_saddr_set(sk, inet->inet_saddr);
> > > >
> > > >       if (tp->rx_opt.ts_recent_stamp && inet->inet_daddr != daddr) {
> > > > @@ -273,7 +277,7 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
> > > >        * complete initialization after this.
> > > >        */
> > > >       tcp_set_state(sk, TCP_SYN_SENT);
> > > > -     err = inet_hash_connect(tcp_death_row, sk);
> > > > +     err = inet_hash_connect(tcp_death_row, sk, prev_inaddr_any);
> > > >       if (err)
> > > >               goto failure;
> > > >
> > > > diff --git a/net/ipv6/inet6_hashtables.c b/net/ipv6/inet6_hashtables.c
> > > > index 7d53d62783b1..c87c5933f3be 100644
> > > > --- a/net/ipv6/inet6_hashtables.c
> > > > +++ b/net/ipv6/inet6_hashtables.c
> > > > @@ -317,13 +317,13 @@ static u64 inet6_sk_port_offset(const struct sock *sk)
> > > >  }
> > > >
> > > >  int inet6_hash_connect(struct inet_timewait_death_row *death_row,
> > > > -                    struct sock *sk)
> > > > +                    struct sock *sk, bool prev_inaddr_any)
> > > >  {
> > > >       u64 port_offset = 0;
> > > >
> > > >       if (!inet_sk(sk)->inet_num)
> > > >               port_offset = inet6_sk_port_offset(sk);
> > > > -     return __inet_hash_connect(death_row, sk, port_offset,
> > > > +     return __inet_hash_connect(death_row, sk, port_offset, prev_inaddr_any,
> > > >                                  __inet6_check_established);
> > > >  }
> > > >  EXPORT_SYMBOL_GPL(inet6_hash_connect);
> > > > diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c
> > > > index f37dd4aa91c6..81e3312c2a97 100644
> > > > --- a/net/ipv6/tcp_ipv6.c
> > > > +++ b/net/ipv6/tcp_ipv6.c
> > > > @@ -152,6 +152,7 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
> > > >       struct ipv6_pinfo *np = tcp_inet6_sk(sk);
> > > >       struct tcp_sock *tp = tcp_sk(sk);
> > > >       struct in6_addr *saddr = NULL, *final_p, final;
> > > > +     bool prev_inaddr_any = false;
> > > >       struct ipv6_txoptions *opt;
> > > >       struct flowi6 fl6;
> > > >       struct dst_entry *dst;
> > > > @@ -289,6 +290,7 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
> > > >       if (!saddr) {
> > > >               saddr = &fl6.saddr;
> > > >               sk->sk_v6_rcv_saddr = *saddr;
> > > > +             prev_inaddr_any = true;
> > > >       }
> > > >
> > > >       /* set the source address */
> > > > @@ -309,7 +311,7 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
> > > >
> > > >       tcp_set_state(sk, TCP_SYN_SENT);
> > > >       tcp_death_row = sock_net(sk)->ipv4.tcp_death_row;
> > > > -     err = inet6_hash_connect(tcp_death_row, sk);
> > > > +     err = inet6_hash_connect(tcp_death_row, sk, prev_inaddr_any);
> > > >       if (err)
> > > >               goto late_failure;
> > > >
> > >
> > > I'm sorry for the late notice, but it looks like the mptcp
> > > syzkaller instance is still hitting the warning in icsk_get_port on top
> > > of v1 of this series:
> > >
> > > https://github.com/multipath-tcp/mptcp_net-next/issues/279
> > >
> > > and the change in v2 should not address that. @Mat could you please
> > > confirm the above?
> > >
> > > Dumb question: I don't understand how the locking in bhash2 works.
> > > Could you explain that?
> > >
> > > What happens when 2 different processes bind different sockets on
> > > different ports (with different bhash buckets) using different
> > > addresses so that they hit the same bhash2 bucket? AFAICS each process
> > > will use a different lock and access/modification to bhash2 could
> > > happen simultaneously?
> > Hi Paolo. Yes, I think you are correct here that there could be a
> > scenario where this happens. Unfortunately, I think this means the
> > bhash2 table will need its own lock. I will submit a follow-up for
> > this.
> >
>
> I'm wondering if we could (and more importantly, if it would make any
> sense to) use bhash2 only?
>
> e.g. add a spinlock per row to bhash2, use it to protect row
> manipulation, always hash into bhash2 and then drop bhash, similar to
> commit cae3873c5b3a4fcd9706fb461ff4e91bdf1f0120?

This would be the ideal solution but I think we need the bhash table
to handle the case where the bind request is for address INADDR_ANY.
If the bind request is for INADDR_ANY on a port, then we need to check
every socket already bound to that port to determine if there is a
conflict. Without the bhash table (which hashes only by port number),
we don't have a way of getting all the sockets that are bound to a
specified port.

>
> Cheers,
>
> Paolo
>
>
>
>
diff mbox series

Patch

diff --git a/include/net/inet_hashtables.h b/include/net/inet_hashtables.h
index a0887b70967b..2c331ce6ca73 100644
--- a/include/net/inet_hashtables.h
+++ b/include/net/inet_hashtables.h
@@ -448,11 +448,13 @@  static inline void sk_rcv_saddr_set(struct sock *sk, __be32 addr)
 }
 
 int __inet_hash_connect(struct inet_timewait_death_row *death_row,
-			struct sock *sk, u64 port_offset,
+			struct sock *sk, u64 port_offset, bool prev_inaddr_any,
 			int (*check_established)(struct inet_timewait_death_row *,
 						 struct sock *, __u16,
 						 struct inet_timewait_sock **));
 
 int inet_hash_connect(struct inet_timewait_death_row *death_row,
-		      struct sock *sk);
+		      struct sock *sk, bool prev_inaddr_any);
+
+int inet_bhash2_update_saddr(struct sock *sk);
 #endif /* _INET_HASHTABLES_H */
diff --git a/include/net/ipv6.h b/include/net/ipv6.h
index 5b38bf1a586b..6a50aca56d50 100644
--- a/include/net/ipv6.h
+++ b/include/net/ipv6.h
@@ -1187,7 +1187,7 @@  int inet6_compat_ioctl(struct socket *sock, unsigned int cmd,
 		unsigned long arg);
 
 int inet6_hash_connect(struct inet_timewait_death_row *death_row,
-			      struct sock *sk);
+		       struct sock *sk, bool prev_inaddr_any);
 int inet6_sendmsg(struct socket *sock, struct msghdr *msg, size_t size);
 int inet6_recvmsg(struct socket *sock, struct msghdr *msg, size_t size,
 		  int flags);
diff --git a/net/dccp/ipv4.c b/net/dccp/ipv4.c
index da6e3b20cd75..37a8bc3ee49e 100644
--- a/net/dccp/ipv4.c
+++ b/net/dccp/ipv4.c
@@ -47,12 +47,13 @@  int dccp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
 	const struct sockaddr_in *usin = (struct sockaddr_in *)uaddr;
 	struct inet_sock *inet = inet_sk(sk);
 	struct dccp_sock *dp = dccp_sk(sk);
+	struct ip_options_rcu *inet_opt;
 	__be16 orig_sport, orig_dport;
+	bool prev_inaddr_any = false;
 	__be32 daddr, nexthop;
 	struct flowi4 *fl4;
 	struct rtable *rt;
 	int err;
-	struct ip_options_rcu *inet_opt;
 
 	dp->dccps_role = DCCP_ROLE_CLIENT;
 
@@ -89,8 +90,11 @@  int dccp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
 	if (inet_opt == NULL || !inet_opt->opt.srr)
 		daddr = fl4->daddr;
 
-	if (inet->inet_saddr == 0)
+	if (inet->inet_saddr == 0) {
 		inet->inet_saddr = fl4->saddr;
+		prev_inaddr_any = true;
+	}
+
 	sk_rcv_saddr_set(sk, inet->inet_saddr);
 	inet->inet_dport = usin->sin_port;
 	sk_daddr_set(sk, daddr);
@@ -105,7 +109,7 @@  int dccp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
 	 * complete initialization after this.
 	 */
 	dccp_set_state(sk, DCCP_REQUESTING);
-	err = inet_hash_connect(&dccp_death_row, sk);
+	err = inet_hash_connect(&dccp_death_row, sk, prev_inaddr_any);
 	if (err != 0)
 		goto failure;
 
diff --git a/net/dccp/ipv6.c b/net/dccp/ipv6.c
index fd44638ec16b..03013522acab 100644
--- a/net/dccp/ipv6.c
+++ b/net/dccp/ipv6.c
@@ -824,6 +824,7 @@  static int dccp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
 	struct ipv6_pinfo *np = inet6_sk(sk);
 	struct dccp_sock *dp = dccp_sk(sk);
 	struct in6_addr *saddr = NULL, *final_p, final;
+	bool prev_inaddr_any = false;
 	struct ipv6_txoptions *opt;
 	struct flowi6 fl6;
 	struct dst_entry *dst;
@@ -936,6 +937,7 @@  static int dccp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
 	if (saddr == NULL) {
 		saddr = &fl6.saddr;
 		sk->sk_v6_rcv_saddr = *saddr;
+		prev_inaddr_any = true;
 	}
 
 	/* set the source address */
@@ -951,7 +953,7 @@  static int dccp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
 	inet->inet_dport = usin->sin6_port;
 
 	dccp_set_state(sk, DCCP_REQUESTING);
-	err = inet6_hash_connect(&dccp_death_row, sk);
+	err = inet6_hash_connect(&dccp_death_row, sk, prev_inaddr_any);
 	if (err)
 		goto late_failure;
 
diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c
index 93da9f783bec..ad627a99ff9d 100644
--- a/net/ipv4/af_inet.c
+++ b/net/ipv4/af_inet.c
@@ -1221,10 +1221,11 @@  static int inet_sk_reselect_saddr(struct sock *sk)
 	struct inet_sock *inet = inet_sk(sk);
 	__be32 old_saddr = inet->inet_saddr;
 	__be32 daddr = inet->inet_daddr;
+	struct ip_options_rcu *inet_opt;
 	struct flowi4 *fl4;
 	struct rtable *rt;
 	__be32 new_saddr;
-	struct ip_options_rcu *inet_opt;
+	int err;
 
 	inet_opt = rcu_dereference_protected(inet->inet_opt,
 					     lockdep_sock_is_held(sk));
@@ -1253,6 +1254,10 @@  static int inet_sk_reselect_saddr(struct sock *sk)
 
 	inet->inet_saddr = inet->inet_rcv_saddr = new_saddr;
 
+	err = inet_bhash2_update_saddr(sk);
+	if (err)
+		return err;
+
 	/*
 	 * XXX The only one ugly spot where we need to
 	 * XXX really change the sockets identity after
diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c
index e8de5e699b3f..592b70663a3b 100644
--- a/net/ipv4/inet_hashtables.c
+++ b/net/ipv4/inet_hashtables.c
@@ -826,6 +826,55 @@  inet_bind2_bucket_find(struct inet_hashinfo *hinfo, struct net *net,
 	return bhash2;
 }
 
+/* the lock for the socket's corresponding bhash entry must be held */
+static int __inet_bhash2_update_saddr(struct sock *sk,
+				      struct inet_hashinfo *hinfo,
+				      struct net *net, int port, int l3mdev)
+{
+	struct inet_bind2_hashbucket *head2;
+	struct inet_bind2_bucket *tb2;
+
+	tb2 = inet_bind2_bucket_find(hinfo, net, port, l3mdev, sk,
+				     &head2);
+	if (!tb2) {
+		tb2 = inet_bind2_bucket_create(hinfo->bind2_bucket_cachep,
+					       net, head2, port, l3mdev, sk);
+		if (!tb2)
+			return -ENOMEM;
+	}
+
+	/* Remove the socket's old entry from bhash2 */
+	__sk_del_bind2_node(sk);
+
+	sk_add_bind2_node(sk, &tb2->owners);
+	inet_csk(sk)->icsk_bind2_hash = tb2;
+
+	return 0;
+}
+
+/* This should be called if/when a socket's rcv saddr changes after it has
+ * been bound.
+ */
+int inet_bhash2_update_saddr(struct sock *sk)
+{
+	struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo;
+	int l3mdev = inet_sk_bound_l3mdev(sk);
+	struct inet_bind_hashbucket *head;
+	int port = inet_sk(sk)->inet_num;
+	struct net *net = sock_net(sk);
+	int err;
+
+	head = &hinfo->bhash[inet_bhashfn(net, port, hinfo->bhash_size)];
+
+	spin_lock_bh(&head->lock);
+
+	err = __inet_bhash2_update_saddr(sk, hinfo, net, port, l3mdev);
+
+	spin_unlock_bh(&head->lock);
+
+	return err;
+}
+
 /* RFC 6056 3.3.4.  Algorithm 4: Double-Hash Port Selection Algorithm
  * Note that we use 32bit integers (vs RFC 'short integers')
  * because 2^16 is not a multiple of num_ephemeral and this
@@ -840,7 +889,7 @@  inet_bind2_bucket_find(struct inet_hashinfo *hinfo, struct net *net,
 static u32 *table_perturb;
 
 int __inet_hash_connect(struct inet_timewait_death_row *death_row,
-		struct sock *sk, u64 port_offset,
+		struct sock *sk, u64 port_offset, bool prev_inaddr_any,
 		int (*check_established)(struct inet_timewait_death_row *,
 			struct sock *, __u16, struct inet_timewait_sock **))
 {
@@ -858,11 +907,24 @@  int __inet_hash_connect(struct inet_timewait_death_row *death_row,
 	int l3mdev;
 	u32 index;
 
+	l3mdev = inet_sk_bound_l3mdev(sk);
+
 	if (port) {
 		head = &hinfo->bhash[inet_bhashfn(net, port,
 						  hinfo->bhash_size)];
 		tb = inet_csk(sk)->icsk_bind_hash;
+
 		spin_lock_bh(&head->lock);
+
+		if (prev_inaddr_any) {
+			ret = __inet_bhash2_update_saddr(sk, hinfo, net, port,
+							 l3mdev);
+			if (ret) {
+				spin_unlock_bh(&head->lock);
+				return ret;
+			}
+		}
+
 		if (sk_head(&tb->owners) == sk && !sk->sk_bind_node.next) {
 			inet_ehash_nolisten(sk, NULL, NULL);
 			spin_unlock_bh(&head->lock);
@@ -875,8 +937,6 @@  int __inet_hash_connect(struct inet_timewait_death_row *death_row,
 		return ret;
 	}
 
-	l3mdev = inet_sk_bound_l3mdev(sk);
-
 	inet_get_local_port_range(net, &low, &high);
 	high++; /* [32768, 60999] -> [32768, 61000[ */
 	remaining = high - low;
@@ -987,13 +1047,13 @@  int __inet_hash_connect(struct inet_timewait_death_row *death_row,
  * Bind a port for a connect operation and hash it.
  */
 int inet_hash_connect(struct inet_timewait_death_row *death_row,
-		      struct sock *sk)
+		      struct sock *sk, bool prev_inaddr_any)
 {
 	u64 port_offset = 0;
 
 	if (!inet_sk(sk)->inet_num)
 		port_offset = inet_sk_port_offset(sk);
-	return __inet_hash_connect(death_row, sk, port_offset,
+	return __inet_hash_connect(death_row, sk, port_offset, prev_inaddr_any,
 				   __inet_check_established);
 }
 EXPORT_SYMBOL_GPL(inet_hash_connect);
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index dac2650f3863..adf8d750933d 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -203,6 +203,7 @@  int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
 	struct inet_sock *inet = inet_sk(sk);
 	struct tcp_sock *tp = tcp_sk(sk);
 	__be16 orig_sport, orig_dport;
+	bool prev_inaddr_any = false;
 	__be32 daddr, nexthop;
 	struct flowi4 *fl4;
 	struct rtable *rt;
@@ -246,8 +247,11 @@  int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
 	if (!inet_opt || !inet_opt->opt.srr)
 		daddr = fl4->daddr;
 
-	if (!inet->inet_saddr)
+	if (!inet->inet_saddr) {
 		inet->inet_saddr = fl4->saddr;
+		prev_inaddr_any = true;
+	}
+
 	sk_rcv_saddr_set(sk, inet->inet_saddr);
 
 	if (tp->rx_opt.ts_recent_stamp && inet->inet_daddr != daddr) {
@@ -273,7 +277,7 @@  int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
 	 * complete initialization after this.
 	 */
 	tcp_set_state(sk, TCP_SYN_SENT);
-	err = inet_hash_connect(tcp_death_row, sk);
+	err = inet_hash_connect(tcp_death_row, sk, prev_inaddr_any);
 	if (err)
 		goto failure;
 
diff --git a/net/ipv6/inet6_hashtables.c b/net/ipv6/inet6_hashtables.c
index 7d53d62783b1..c87c5933f3be 100644
--- a/net/ipv6/inet6_hashtables.c
+++ b/net/ipv6/inet6_hashtables.c
@@ -317,13 +317,13 @@  static u64 inet6_sk_port_offset(const struct sock *sk)
 }
 
 int inet6_hash_connect(struct inet_timewait_death_row *death_row,
-		       struct sock *sk)
+		       struct sock *sk, bool prev_inaddr_any)
 {
 	u64 port_offset = 0;
 
 	if (!inet_sk(sk)->inet_num)
 		port_offset = inet6_sk_port_offset(sk);
-	return __inet_hash_connect(death_row, sk, port_offset,
+	return __inet_hash_connect(death_row, sk, port_offset, prev_inaddr_any,
 				   __inet6_check_established);
 }
 EXPORT_SYMBOL_GPL(inet6_hash_connect);
diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c
index f37dd4aa91c6..81e3312c2a97 100644
--- a/net/ipv6/tcp_ipv6.c
+++ b/net/ipv6/tcp_ipv6.c
@@ -152,6 +152,7 @@  static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
 	struct ipv6_pinfo *np = tcp_inet6_sk(sk);
 	struct tcp_sock *tp = tcp_sk(sk);
 	struct in6_addr *saddr = NULL, *final_p, final;
+	bool prev_inaddr_any = false;
 	struct ipv6_txoptions *opt;
 	struct flowi6 fl6;
 	struct dst_entry *dst;
@@ -289,6 +290,7 @@  static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
 	if (!saddr) {
 		saddr = &fl6.saddr;
 		sk->sk_v6_rcv_saddr = *saddr;
+		prev_inaddr_any = true;
 	}
 
 	/* set the source address */
@@ -309,7 +311,7 @@  static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
 
 	tcp_set_state(sk, TCP_SYN_SENT);
 	tcp_death_row = sock_net(sk)->ipv4.tcp_death_row;
-	err = inet6_hash_connect(tcp_death_row, sk);
+	err = inet6_hash_connect(tcp_death_row, sk, prev_inaddr_any);
 	if (err)
 		goto late_failure;