diff mbox series

[net-next,v3,1/3] net: Update bhash2 when socket's rcv saddr changes

Message ID 20220611021646.1578080-2-joannelkoong@gmail.com (mailing list archive)
State Changes Requested
Delegated to: Netdev Maintainers
Headers show
Series bhash2 binding table fixups | expand

Checks

Context Check Description
netdev/tree_selection success Clearly marked for net-next
netdev/fixes_present success Fixes tag not required for -next series
netdev/subject_prefix success Link
netdev/cover_letter success Series has a cover letter
netdev/patch_count success Link
netdev/header_inline success No static functions without inline keyword in header files
netdev/build_32bit success Errors and warnings before: 2055 this patch: 2055
netdev/cc_maintainers fail 1 blamed authors not CCed: kuniyu@amazon.co.jp; 4 maintainers not CCed: yoshfuji@linux-ipv6.org dccp@vger.kernel.org kuniyu@amazon.co.jp dsahern@kernel.org
netdev/build_clang success Errors and warnings before: 573 this patch: 573
netdev/module_param success Was 0 now: 0
netdev/verify_signedoff success Signed-off-by tag matches author and committer
netdev/check_selftest success No net selftest shell script
netdev/verify_fixes success Fixes tag looks correct
netdev/build_allmodconfig_warn success Errors and warnings before: 2182 this patch: 2182
netdev/checkpatch success total: 0 errors, 0 warnings, 0 checks, 275 lines checked
netdev/kdoc success Errors and warnings before: 0 this patch: 0
netdev/source_inline success Was 0 now: 0

Commit Message

Joanne Koong June 11, 2022, 2:16 a.m. UTC
Commit d5a42de8bdbe ("net: Add a second bind table hashed by port and
address") added a second bind table, bhash2, that hashes by a socket's port
and rcv address.

However, there are two cases where the socket's rcv saddr can change
after it has been bound:

1) The case where there is a bind() call on the wildcard address
(INADDR_ANY for IPv4, or "::" for IPv6) and then a connect() call.
The kernel will assign the socket an address when it handles the
connect()

2) In inet_sk_reselect_saddr(), which is called when rerouting fails
when rebuilding the sk header (invoked by inet_sk_rebuild_header)

In these two cases, we need to update the bhash2 table by removing the
entry for the old address, and adding a new entry reflecting the updated
address.

Reported-by: syzbot+015d756bbd1f8b5c8f09@syzkaller.appspotmail.com
Fixes: d5a42de8bdbe ("net: Add a second bind table hashed by port and address")
Signed-off-by: Joanne Koong <joannelkoong@gmail.com>
Reviewed-by: Eric Dumazet <edumazet@google.com>
---
 include/net/inet_hashtables.h |  6 ++-
 include/net/ipv6.h            |  2 +-
 net/dccp/ipv4.c               | 10 +++--
 net/dccp/ipv6.c               |  4 +-
 net/ipv4/af_inet.c            |  7 +++-
 net/ipv4/inet_hashtables.c    | 70 ++++++++++++++++++++++++++++++++---
 net/ipv4/tcp_ipv4.c           |  8 +++-
 net/ipv6/inet6_hashtables.c   |  4 +-
 net/ipv6/tcp_ipv6.c           |  4 +-
 9 files changed, 97 insertions(+), 18 deletions(-)

Comments

Kuniyuki Iwashima June 12, 2022, 12:24 a.m. UTC | #1
From:   Joanne Koong <joannelkoong@gmail.com>
Date:   Fri, 10 Jun 2022 19:16:44 -0700
> Commit d5a42de8bdbe ("net: Add a second bind table hashed by port and
> address") added a second bind table, bhash2, that hashes by a socket's port
> and rcv address.
> 
> However, there are two cases where the socket's rcv saddr can change
> after it has been binded:
> 
> 1) The case where there is a bind() call on "::" (IPADDR_ANY) and then
> a connect() call. The kernel will assign the socket an address when it
> handles the connect()
> 
> 2) In inet_sk_reselect_saddr(), which is called when rerouting fails
> when rebuilding the sk header (invoked by inet_sk_rebuild_header)
> 
> In these two cases, we need to update the bhash2 table by removing the
> entry for the old address, and adding a new entry reflecting the updated
> address.
> 
> Reported-by: syzbot+015d756bbd1f8b5c8f09@syzkaller.appspotmail.com
> Fixes: d5a42de8bdbe ("net: Add a second bind table hashed by port and address")
> Signed-off-by: Joanne Koong <joannelkoong@gmail.com>
> Reviewed-by: Eric Dumazet <edumazet@google.com>
> ---
>  include/net/inet_hashtables.h |  6 ++-
>  include/net/ipv6.h            |  2 +-
>  net/dccp/ipv4.c               | 10 +++--
>  net/dccp/ipv6.c               |  4 +-
>  net/ipv4/af_inet.c            |  7 +++-
>  net/ipv4/inet_hashtables.c    | 70 ++++++++++++++++++++++++++++++++---
>  net/ipv4/tcp_ipv4.c           |  8 +++-
>  net/ipv6/inet6_hashtables.c   |  4 +-
>  net/ipv6/tcp_ipv6.c           |  4 +-
>  9 files changed, 97 insertions(+), 18 deletions(-)
> 
> diff --git a/include/net/inet_hashtables.h b/include/net/inet_hashtables.h
> index a0887b70967b..2c331ce6ca73 100644
> --- a/include/net/inet_hashtables.h
> +++ b/include/net/inet_hashtables.h
> @@ -448,11 +448,13 @@ static inline void sk_rcv_saddr_set(struct sock *sk, __be32 addr)
>  }
>  
>  int __inet_hash_connect(struct inet_timewait_death_row *death_row,
> -			struct sock *sk, u64 port_offset,
> +			struct sock *sk, u64 port_offset, bool prev_inaddr_any,
>  			int (*check_established)(struct inet_timewait_death_row *,
>  						 struct sock *, __u16,
>  						 struct inet_timewait_sock **));
>  
>  int inet_hash_connect(struct inet_timewait_death_row *death_row,
> -		      struct sock *sk);
> +		      struct sock *sk, bool prev_inaddr_any);
> +
> +int inet_bhash2_update_saddr(struct sock *sk);
>  #endif /* _INET_HASHTABLES_H */
> diff --git a/include/net/ipv6.h b/include/net/ipv6.h
> index de9dcc5652c4..735f7b4d55dc 100644
> --- a/include/net/ipv6.h
> +++ b/include/net/ipv6.h
> @@ -1187,7 +1187,7 @@ int inet6_compat_ioctl(struct socket *sock, unsigned int cmd,
>  		unsigned long arg);
>  
>  int inet6_hash_connect(struct inet_timewait_death_row *death_row,
> -			      struct sock *sk);
> +		       struct sock *sk, bool prev_inaddr_any);
>  int inet6_sendmsg(struct socket *sock, struct msghdr *msg, size_t size);
>  int inet6_recvmsg(struct socket *sock, struct msghdr *msg, size_t size,
>  		  int flags);
> diff --git a/net/dccp/ipv4.c b/net/dccp/ipv4.c
> index da6e3b20cd75..37a8bc3ee49e 100644
> --- a/net/dccp/ipv4.c
> +++ b/net/dccp/ipv4.c
> @@ -47,12 +47,13 @@ int dccp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
>  	const struct sockaddr_in *usin = (struct sockaddr_in *)uaddr;
>  	struct inet_sock *inet = inet_sk(sk);
>  	struct dccp_sock *dp = dccp_sk(sk);
> +	struct ip_options_rcu *inet_opt;
>  	__be16 orig_sport, orig_dport;
> +	bool prev_inaddr_any = false;
>  	__be32 daddr, nexthop;
>  	struct flowi4 *fl4;
>  	struct rtable *rt;
>  	int err;
> -	struct ip_options_rcu *inet_opt;
>  
>  	dp->dccps_role = DCCP_ROLE_CLIENT;
>  
> @@ -89,8 +90,11 @@ int dccp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
>  	if (inet_opt == NULL || !inet_opt->opt.srr)
>  		daddr = fl4->daddr;
>  
> -	if (inet->inet_saddr == 0)
> +	if (inet->inet_saddr == 0) {
>  		inet->inet_saddr = fl4->saddr;
> +		prev_inaddr_any = true;
> +	}
> +
>  	sk_rcv_saddr_set(sk, inet->inet_saddr);
>  	inet->inet_dport = usin->sin_port;
>  	sk_daddr_set(sk, daddr);
> @@ -105,7 +109,7 @@ int dccp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
>  	 * complete initialization after this.
>  	 */
>  	dccp_set_state(sk, DCCP_REQUESTING);
> -	err = inet_hash_connect(&dccp_death_row, sk);
> +	err = inet_hash_connect(&dccp_death_row, sk, prev_inaddr_any);
>  	if (err != 0)
>  		goto failure;
>  
> diff --git a/net/dccp/ipv6.c b/net/dccp/ipv6.c
> index fd44638ec16b..03013522acab 100644
> --- a/net/dccp/ipv6.c
> +++ b/net/dccp/ipv6.c
> @@ -824,6 +824,7 @@ static int dccp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
>  	struct ipv6_pinfo *np = inet6_sk(sk);
>  	struct dccp_sock *dp = dccp_sk(sk);
>  	struct in6_addr *saddr = NULL, *final_p, final;
> +	bool prev_inaddr_any = false;
>  	struct ipv6_txoptions *opt;
>  	struct flowi6 fl6;
>  	struct dst_entry *dst;
> @@ -936,6 +937,7 @@ static int dccp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
>  	if (saddr == NULL) {
>  		saddr = &fl6.saddr;
>  		sk->sk_v6_rcv_saddr = *saddr;
> +		prev_inaddr_any = true;
>  	}
>  
>  	/* set the source address */
> @@ -951,7 +953,7 @@ static int dccp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
>  	inet->inet_dport = usin->sin6_port;
>  
>  	dccp_set_state(sk, DCCP_REQUESTING);
> -	err = inet6_hash_connect(&dccp_death_row, sk);
> +	err = inet6_hash_connect(&dccp_death_row, sk, prev_inaddr_any);
>  	if (err)
>  		goto late_failure;
>  
> diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c
> index 30e0e8992085..9785f8f428b0 100644
> --- a/net/ipv4/af_inet.c
> +++ b/net/ipv4/af_inet.c
> @@ -1221,10 +1221,11 @@ static int inet_sk_reselect_saddr(struct sock *sk)
>  	struct inet_sock *inet = inet_sk(sk);
>  	__be32 old_saddr = inet->inet_saddr;
>  	__be32 daddr = inet->inet_daddr;
> +	struct ip_options_rcu *inet_opt;
>  	struct flowi4 *fl4;
>  	struct rtable *rt;
>  	__be32 new_saddr;
> -	struct ip_options_rcu *inet_opt;
> +	int err;
>  
>  	inet_opt = rcu_dereference_protected(inet->inet_opt,
>  					     lockdep_sock_is_held(sk));
> @@ -1253,6 +1254,10 @@ static int inet_sk_reselect_saddr(struct sock *sk)
>  
>  	inet->inet_saddr = inet->inet_rcv_saddr = new_saddr;
>  
> +	err = inet_bhash2_update_saddr(sk);
> +	if (err)
> +		return err;
> +
>  	/*
>  	 * XXX The only one ugly spot where we need to
>  	 * XXX really change the sockets identity after
> diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c
> index 545f91b6cb5e..73f18134b2d5 100644
> --- a/net/ipv4/inet_hashtables.c
> +++ b/net/ipv4/inet_hashtables.c
> @@ -826,6 +826,55 @@ inet_bind2_bucket_find(struct inet_hashinfo *hinfo, struct net *net,
>  	return bhash2;
>  }
>  
> +/* the lock for the socket's corresponding bhash entry must be held */
> +static int __inet_bhash2_update_saddr(struct sock *sk,
> +				      struct inet_hashinfo *hinfo,
> +				      struct net *net, int port, int l3mdev)
> +{
> +	struct inet_bind2_hashbucket *head2;
> +	struct inet_bind2_bucket *tb2;
> +
> +	tb2 = inet_bind2_bucket_find(hinfo, net, port, l3mdev, sk,
> +				     &head2);
> +	if (!tb2) {
> +		tb2 = inet_bind2_bucket_create(hinfo->bind2_bucket_cachep,
> +					       net, head2, port, l3mdev, sk);
> +		if (!tb2)
> +			return -ENOMEM;
> +	}
> +
> +	/* Remove the socket's old entry from bhash2 */
> +	__sk_del_bind2_node(sk);
> +
> +	sk_add_bind2_node(sk, &tb2->owners);
> +	inet_csk(sk)->icsk_bind2_hash = tb2;
> +
> +	return 0;
> +}
> +
> +/* This should be called if/when a socket's rcv saddr changes after it has
> + * been binded.
> + */
> +int inet_bhash2_update_saddr(struct sock *sk)
> +{
> +	struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo;
> +	int l3mdev = inet_sk_bound_l3mdev(sk);
> +	struct inet_bind_hashbucket *head;
> +	int port = inet_sk(sk)->inet_num;
> +	struct net *net = sock_net(sk);
> +	int err;
> +
> +	head = &hinfo->bhash[inet_bhashfn(net, port, hinfo->bhash_size)];
> +
> +	spin_lock_bh(&head->lock);

I think this patch should be the 2nd one in the series, because we know this
lock does not protect bhash2, and reordering would make the diff smaller.

Also, the series should be applied to net instead of net-next.


> +
> +	err = __inet_bhash2_update_saddr(sk, hinfo, net, port, l3mdev);
> +
> +	spin_unlock_bh(&head->lock);
> +
> +	return err;
> +}
> +
>  /* RFC 6056 3.3.4.  Algorithm 4: Double-Hash Port Selection Algorithm
>   * Note that we use 32bit integers (vs RFC 'short integers')
>   * because 2^16 is not a multiple of num_ephemeral and this
> @@ -840,7 +889,7 @@ inet_bind2_bucket_find(struct inet_hashinfo *hinfo, struct net *net,
>  static u32 *table_perturb;
>  
>  int __inet_hash_connect(struct inet_timewait_death_row *death_row,
> -		struct sock *sk, u64 port_offset,
> +		struct sock *sk, u64 port_offset, bool prev_inaddr_any,
>  		int (*check_established)(struct inet_timewait_death_row *,
>  			struct sock *, __u16, struct inet_timewait_sock **))
>  {
> @@ -858,11 +907,24 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
>  	int l3mdev;
>  	u32 index;
>  
> +	l3mdev = inet_sk_bound_l3mdev(sk);

This add/remove dance can be avoided.


> +
>  	if (port) {
>  		head = &hinfo->bhash[inet_bhashfn(net, port,
>  						  hinfo->bhash_size)];
>  		tb = inet_csk(sk)->icsk_bind_hash;
> +
>  		spin_lock_bh(&head->lock);
> +
> +		if (prev_inaddr_any) {
> +			ret = __inet_bhash2_update_saddr(sk, hinfo, net, port,
> +							 l3mdev);
> +			if (ret) {
> +				spin_unlock_bh(&head->lock);
> +				return ret;
> +			}
> +		}
> +
>  		if (sk_head(&tb->owners) == sk && !sk->sk_bind_node.next) {
>  			inet_ehash_nolisten(sk, NULL, NULL);
>  			spin_unlock_bh(&head->lock);
> @@ -875,8 +937,6 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
>  		return ret;
>  	}
>  
> -	l3mdev = inet_sk_bound_l3mdev(sk);
> -
>  	inet_get_local_port_range(net, &low, &high);
>  	high++; /* [32768, 60999] -> [32768, 61000[ */
>  	remaining = high - low;
> @@ -987,13 +1047,13 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
>   * Bind a port for a connect operation and hash it.
>   */
>  int inet_hash_connect(struct inet_timewait_death_row *death_row,
> -		      struct sock *sk)
> +		      struct sock *sk, bool prev_inaddr_any)
>  {
>  	u64 port_offset = 0;
>  
>  	if (!inet_sk(sk)->inet_num)
>  		port_offset = inet_sk_port_offset(sk);
> -	return __inet_hash_connect(death_row, sk, port_offset,
> +	return __inet_hash_connect(death_row, sk, port_offset, prev_inaddr_any,
>  				   __inet_check_established);
>  }
>  EXPORT_SYMBOL_GPL(inet_hash_connect);
> diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
> index fe8f23b95d32..70c2182c780d 100644
> --- a/net/ipv4/tcp_ipv4.c
> +++ b/net/ipv4/tcp_ipv4.c
> @@ -203,6 +203,7 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
>  	struct inet_sock *inet = inet_sk(sk);
>  	struct tcp_sock *tp = tcp_sk(sk);
>  	__be16 orig_sport, orig_dport;
> +	bool prev_inaddr_any = false;
>  	__be32 daddr, nexthop;
>  	struct flowi4 *fl4;
>  	struct rtable *rt;
> @@ -246,8 +247,11 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
>  	if (!inet_opt || !inet_opt->opt.srr)
>  		daddr = fl4->daddr;
>  
> -	if (!inet->inet_saddr)
> +	if (!inet->inet_saddr) {
>  		inet->inet_saddr = fl4->saddr;
> +		prev_inaddr_any = true;
> +	}
> +
>  	sk_rcv_saddr_set(sk, inet->inet_saddr);
>  
>  	if (tp->rx_opt.ts_recent_stamp && inet->inet_daddr != daddr) {
> @@ -273,7 +277,7 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
>  	 * complete initialization after this.
>  	 */
>  	tcp_set_state(sk, TCP_SYN_SENT);
> -	err = inet_hash_connect(tcp_death_row, sk);
> +	err = inet_hash_connect(tcp_death_row, sk, prev_inaddr_any);
>  	if (err)
>  		goto failure;
>  
> diff --git a/net/ipv6/inet6_hashtables.c b/net/ipv6/inet6_hashtables.c
> index 7d53d62783b1..c87c5933f3be 100644
> --- a/net/ipv6/inet6_hashtables.c
> +++ b/net/ipv6/inet6_hashtables.c
> @@ -317,13 +317,13 @@ static u64 inet6_sk_port_offset(const struct sock *sk)
>  }
>  
>  int inet6_hash_connect(struct inet_timewait_death_row *death_row,
> -		       struct sock *sk)
> +		       struct sock *sk, bool prev_inaddr_any)
>  {
>  	u64 port_offset = 0;
>  
>  	if (!inet_sk(sk)->inet_num)
>  		port_offset = inet6_sk_port_offset(sk);
> -	return __inet_hash_connect(death_row, sk, port_offset,
> +	return __inet_hash_connect(death_row, sk, port_offset, prev_inaddr_any,
>  				   __inet6_check_established);
>  }
>  EXPORT_SYMBOL_GPL(inet6_hash_connect);
> diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c
> index f37dd4aa91c6..81e3312c2a97 100644
> --- a/net/ipv6/tcp_ipv6.c
> +++ b/net/ipv6/tcp_ipv6.c
> @@ -152,6 +152,7 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
>  	struct ipv6_pinfo *np = tcp_inet6_sk(sk);
>  	struct tcp_sock *tp = tcp_sk(sk);
>  	struct in6_addr *saddr = NULL, *final_p, final;
> +	bool prev_inaddr_any = false;
>  	struct ipv6_txoptions *opt;
>  	struct flowi6 fl6;
>  	struct dst_entry *dst;
> @@ -289,6 +290,7 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
>  	if (!saddr) {
>  		saddr = &fl6.saddr;
>  		sk->sk_v6_rcv_saddr = *saddr;
> +		prev_inaddr_any = true;
>  	}
>  
>  	/* set the source address */
> @@ -309,7 +311,7 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
>  
>  	tcp_set_state(sk, TCP_SYN_SENT);
>  	tcp_death_row = sock_net(sk)->ipv4.tcp_death_row;
> -	err = inet6_hash_connect(tcp_death_row, sk);
> +	err = inet6_hash_connect(tcp_death_row, sk, prev_inaddr_any);
>  	if (err)
>  		goto late_failure;
>  
> -- 
> 2.30.2
diff mbox series

Patch

diff --git a/include/net/inet_hashtables.h b/include/net/inet_hashtables.h
index a0887b70967b..2c331ce6ca73 100644
--- a/include/net/inet_hashtables.h
+++ b/include/net/inet_hashtables.h
@@ -448,11 +448,13 @@  static inline void sk_rcv_saddr_set(struct sock *sk, __be32 addr)
 }
 
 int __inet_hash_connect(struct inet_timewait_death_row *death_row,
-			struct sock *sk, u64 port_offset,
+			struct sock *sk, u64 port_offset, bool prev_inaddr_any,
 			int (*check_established)(struct inet_timewait_death_row *,
 						 struct sock *, __u16,
 						 struct inet_timewait_sock **));
 
 int inet_hash_connect(struct inet_timewait_death_row *death_row,
-		      struct sock *sk);
+		      struct sock *sk, bool prev_inaddr_any);
+
+int inet_bhash2_update_saddr(struct sock *sk);
 #endif /* _INET_HASHTABLES_H */
diff --git a/include/net/ipv6.h b/include/net/ipv6.h
index de9dcc5652c4..735f7b4d55dc 100644
--- a/include/net/ipv6.h
+++ b/include/net/ipv6.h
@@ -1187,7 +1187,7 @@  int inet6_compat_ioctl(struct socket *sock, unsigned int cmd,
 		unsigned long arg);
 
 int inet6_hash_connect(struct inet_timewait_death_row *death_row,
-			      struct sock *sk);
+		       struct sock *sk, bool prev_inaddr_any);
 int inet6_sendmsg(struct socket *sock, struct msghdr *msg, size_t size);
 int inet6_recvmsg(struct socket *sock, struct msghdr *msg, size_t size,
 		  int flags);
diff --git a/net/dccp/ipv4.c b/net/dccp/ipv4.c
index da6e3b20cd75..37a8bc3ee49e 100644
--- a/net/dccp/ipv4.c
+++ b/net/dccp/ipv4.c
@@ -47,12 +47,13 @@  int dccp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
 	const struct sockaddr_in *usin = (struct sockaddr_in *)uaddr;
 	struct inet_sock *inet = inet_sk(sk);
 	struct dccp_sock *dp = dccp_sk(sk);
+	struct ip_options_rcu *inet_opt;
 	__be16 orig_sport, orig_dport;
+	bool prev_inaddr_any = false;
 	__be32 daddr, nexthop;
 	struct flowi4 *fl4;
 	struct rtable *rt;
 	int err;
-	struct ip_options_rcu *inet_opt;
 
 	dp->dccps_role = DCCP_ROLE_CLIENT;
 
@@ -89,8 +90,11 @@  int dccp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
 	if (inet_opt == NULL || !inet_opt->opt.srr)
 		daddr = fl4->daddr;
 
-	if (inet->inet_saddr == 0)
+	if (inet->inet_saddr == 0) {
 		inet->inet_saddr = fl4->saddr;
+		prev_inaddr_any = true;
+	}
+
 	sk_rcv_saddr_set(sk, inet->inet_saddr);
 	inet->inet_dport = usin->sin_port;
 	sk_daddr_set(sk, daddr);
@@ -105,7 +109,7 @@  int dccp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
 	 * complete initialization after this.
 	 */
 	dccp_set_state(sk, DCCP_REQUESTING);
-	err = inet_hash_connect(&dccp_death_row, sk);
+	err = inet_hash_connect(&dccp_death_row, sk, prev_inaddr_any);
 	if (err != 0)
 		goto failure;
 
diff --git a/net/dccp/ipv6.c b/net/dccp/ipv6.c
index fd44638ec16b..03013522acab 100644
--- a/net/dccp/ipv6.c
+++ b/net/dccp/ipv6.c
@@ -824,6 +824,7 @@  static int dccp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
 	struct ipv6_pinfo *np = inet6_sk(sk);
 	struct dccp_sock *dp = dccp_sk(sk);
 	struct in6_addr *saddr = NULL, *final_p, final;
+	bool prev_inaddr_any = false;
 	struct ipv6_txoptions *opt;
 	struct flowi6 fl6;
 	struct dst_entry *dst;
@@ -936,6 +937,7 @@  static int dccp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
 	if (saddr == NULL) {
 		saddr = &fl6.saddr;
 		sk->sk_v6_rcv_saddr = *saddr;
+		prev_inaddr_any = true;
 	}
 
 	/* set the source address */
@@ -951,7 +953,7 @@  static int dccp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
 	inet->inet_dport = usin->sin6_port;
 
 	dccp_set_state(sk, DCCP_REQUESTING);
-	err = inet6_hash_connect(&dccp_death_row, sk);
+	err = inet6_hash_connect(&dccp_death_row, sk, prev_inaddr_any);
 	if (err)
 		goto late_failure;
 
diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c
index 30e0e8992085..9785f8f428b0 100644
--- a/net/ipv4/af_inet.c
+++ b/net/ipv4/af_inet.c
@@ -1221,10 +1221,11 @@  static int inet_sk_reselect_saddr(struct sock *sk)
 	struct inet_sock *inet = inet_sk(sk);
 	__be32 old_saddr = inet->inet_saddr;
 	__be32 daddr = inet->inet_daddr;
+	struct ip_options_rcu *inet_opt;
 	struct flowi4 *fl4;
 	struct rtable *rt;
 	__be32 new_saddr;
-	struct ip_options_rcu *inet_opt;
+	int err;
 
 	inet_opt = rcu_dereference_protected(inet->inet_opt,
 					     lockdep_sock_is_held(sk));
@@ -1253,6 +1254,10 @@  static int inet_sk_reselect_saddr(struct sock *sk)
 
 	inet->inet_saddr = inet->inet_rcv_saddr = new_saddr;
 
+	err = inet_bhash2_update_saddr(sk);
+	if (err)
+		return err;
+
 	/*
 	 * XXX The only one ugly spot where we need to
 	 * XXX really change the sockets identity after
diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c
index 545f91b6cb5e..73f18134b2d5 100644
--- a/net/ipv4/inet_hashtables.c
+++ b/net/ipv4/inet_hashtables.c
@@ -826,6 +826,55 @@  inet_bind2_bucket_find(struct inet_hashinfo *hinfo, struct net *net,
 	return bhash2;
 }
 
+/* the lock for the socket's corresponding bhash entry must be held */
+static int __inet_bhash2_update_saddr(struct sock *sk,
+				      struct inet_hashinfo *hinfo,
+				      struct net *net, int port, int l3mdev)
+{
+	struct inet_bind2_hashbucket *head2;
+	struct inet_bind2_bucket *tb2;
+
+	tb2 = inet_bind2_bucket_find(hinfo, net, port, l3mdev, sk,
+				     &head2);
+	if (!tb2) {
+		tb2 = inet_bind2_bucket_create(hinfo->bind2_bucket_cachep,
+					       net, head2, port, l3mdev, sk);
+		if (!tb2)
+			return -ENOMEM;
+	}
+
+	/* Remove the socket's old entry from bhash2 */
+	__sk_del_bind2_node(sk);
+
+	sk_add_bind2_node(sk, &tb2->owners);
+	inet_csk(sk)->icsk_bind2_hash = tb2;
+
+	return 0;
+}
+
+/* This should be called if/when a socket's rcv saddr changes after it has
+ * been binded.
+ */
+int inet_bhash2_update_saddr(struct sock *sk)
+{
+	struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo;
+	int l3mdev = inet_sk_bound_l3mdev(sk);
+	struct inet_bind_hashbucket *head;
+	int port = inet_sk(sk)->inet_num;
+	struct net *net = sock_net(sk);
+	int err;
+
+	head = &hinfo->bhash[inet_bhashfn(net, port, hinfo->bhash_size)];
+
+	spin_lock_bh(&head->lock);
+
+	err = __inet_bhash2_update_saddr(sk, hinfo, net, port, l3mdev);
+
+	spin_unlock_bh(&head->lock);
+
+	return err;
+}
+
 /* RFC 6056 3.3.4.  Algorithm 4: Double-Hash Port Selection Algorithm
  * Note that we use 32bit integers (vs RFC 'short integers')
  * because 2^16 is not a multiple of num_ephemeral and this
@@ -840,7 +889,7 @@  inet_bind2_bucket_find(struct inet_hashinfo *hinfo, struct net *net,
 static u32 *table_perturb;
 
 int __inet_hash_connect(struct inet_timewait_death_row *death_row,
-		struct sock *sk, u64 port_offset,
+		struct sock *sk, u64 port_offset, bool prev_inaddr_any,
 		int (*check_established)(struct inet_timewait_death_row *,
 			struct sock *, __u16, struct inet_timewait_sock **))
 {
@@ -858,11 +907,24 @@  int __inet_hash_connect(struct inet_timewait_death_row *death_row,
 	int l3mdev;
 	u32 index;
 
+	l3mdev = inet_sk_bound_l3mdev(sk);
+
 	if (port) {
 		head = &hinfo->bhash[inet_bhashfn(net, port,
 						  hinfo->bhash_size)];
 		tb = inet_csk(sk)->icsk_bind_hash;
+
 		spin_lock_bh(&head->lock);
+
+		if (prev_inaddr_any) {
+			ret = __inet_bhash2_update_saddr(sk, hinfo, net, port,
+							 l3mdev);
+			if (ret) {
+				spin_unlock_bh(&head->lock);
+				return ret;
+			}
+		}
+
 		if (sk_head(&tb->owners) == sk && !sk->sk_bind_node.next) {
 			inet_ehash_nolisten(sk, NULL, NULL);
 			spin_unlock_bh(&head->lock);
@@ -875,8 +937,6 @@  int __inet_hash_connect(struct inet_timewait_death_row *death_row,
 		return ret;
 	}
 
-	l3mdev = inet_sk_bound_l3mdev(sk);
-
 	inet_get_local_port_range(net, &low, &high);
 	high++; /* [32768, 60999] -> [32768, 61000[ */
 	remaining = high - low;
@@ -987,13 +1047,13 @@  int __inet_hash_connect(struct inet_timewait_death_row *death_row,
  * Bind a port for a connect operation and hash it.
  */
 int inet_hash_connect(struct inet_timewait_death_row *death_row,
-		      struct sock *sk)
+		      struct sock *sk, bool prev_inaddr_any)
 {
 	u64 port_offset = 0;
 
 	if (!inet_sk(sk)->inet_num)
 		port_offset = inet_sk_port_offset(sk);
-	return __inet_hash_connect(death_row, sk, port_offset,
+	return __inet_hash_connect(death_row, sk, port_offset, prev_inaddr_any,
 				   __inet_check_established);
 }
 EXPORT_SYMBOL_GPL(inet_hash_connect);
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index fe8f23b95d32..70c2182c780d 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -203,6 +203,7 @@  int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
 	struct inet_sock *inet = inet_sk(sk);
 	struct tcp_sock *tp = tcp_sk(sk);
 	__be16 orig_sport, orig_dport;
+	bool prev_inaddr_any = false;
 	__be32 daddr, nexthop;
 	struct flowi4 *fl4;
 	struct rtable *rt;
@@ -246,8 +247,11 @@  int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
 	if (!inet_opt || !inet_opt->opt.srr)
 		daddr = fl4->daddr;
 
-	if (!inet->inet_saddr)
+	if (!inet->inet_saddr) {
 		inet->inet_saddr = fl4->saddr;
+		prev_inaddr_any = true;
+	}
+
 	sk_rcv_saddr_set(sk, inet->inet_saddr);
 
 	if (tp->rx_opt.ts_recent_stamp && inet->inet_daddr != daddr) {
@@ -273,7 +277,7 @@  int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
 	 * complete initialization after this.
 	 */
 	tcp_set_state(sk, TCP_SYN_SENT);
-	err = inet_hash_connect(tcp_death_row, sk);
+	err = inet_hash_connect(tcp_death_row, sk, prev_inaddr_any);
 	if (err)
 		goto failure;
 
diff --git a/net/ipv6/inet6_hashtables.c b/net/ipv6/inet6_hashtables.c
index 7d53d62783b1..c87c5933f3be 100644
--- a/net/ipv6/inet6_hashtables.c
+++ b/net/ipv6/inet6_hashtables.c
@@ -317,13 +317,13 @@  static u64 inet6_sk_port_offset(const struct sock *sk)
 }
 
 int inet6_hash_connect(struct inet_timewait_death_row *death_row,
-		       struct sock *sk)
+		       struct sock *sk, bool prev_inaddr_any)
 {
 	u64 port_offset = 0;
 
 	if (!inet_sk(sk)->inet_num)
 		port_offset = inet6_sk_port_offset(sk);
-	return __inet_hash_connect(death_row, sk, port_offset,
+	return __inet_hash_connect(death_row, sk, port_offset, prev_inaddr_any,
 				   __inet6_check_established);
 }
 EXPORT_SYMBOL_GPL(inet6_hash_connect);
diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c
index f37dd4aa91c6..81e3312c2a97 100644
--- a/net/ipv6/tcp_ipv6.c
+++ b/net/ipv6/tcp_ipv6.c
@@ -152,6 +152,7 @@  static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
 	struct ipv6_pinfo *np = tcp_inet6_sk(sk);
 	struct tcp_sock *tp = tcp_sk(sk);
 	struct in6_addr *saddr = NULL, *final_p, final;
+	bool prev_inaddr_any = false;
 	struct ipv6_txoptions *opt;
 	struct flowi6 fl6;
 	struct dst_entry *dst;
@@ -289,6 +290,7 @@  static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
 	if (!saddr) {
 		saddr = &fl6.saddr;
 		sk->sk_v6_rcv_saddr = *saddr;
+		prev_inaddr_any = true;
 	}
 
 	/* set the source address */
@@ -309,7 +311,7 @@  static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
 
 	tcp_set_state(sk, TCP_SYN_SENT);
 	tcp_death_row = sock_net(sk)->ipv4.tcp_death_row;
-	err = inet6_hash_connect(tcp_death_row, sk);
+	err = inet6_hash_connect(tcp_death_row, sk, prev_inaddr_any);
 	if (err)
 		goto late_failure;