diff mbox series

net/tcp: Merge TCP-MD5 inbound callbacks

Message ID 20220222185006.337620-1-dima@arista.com (mailing list archive)
State Superseded
Delegated to: Netdev Maintainers
Headers show
Series net/tcp: Merge TCP-MD5 inbound callbacks | expand

Checks

Context Check Description
netdev/tree_selection success Guessing tree name failed - patch did not apply, async

Commit Message

Dmitry Safonov Feb. 22, 2022, 6:50 p.m. UTC
The functions do essentially the same work to verify TCP-MD5 sign.
Code can be merged into one family-independent function in order to
reduce copy'n'paste and generated code.
Later with TCP-AO option added, this will allow to create one function
that's responsible for segment verification, that will have all the
different checks for MD5/AO/non-signed packets, which in turn will help
to see checks for all corner-cases in one function, rather than spread
around different families and functions.

Cc: Eric Dumazet <edumazet@google.com>
Cc: "David S. Miller" <davem@davemloft.net>
Cc: Jakub Kicinski <kuba@kernel.org>
Cc: Hideaki YOSHIFUJI <yoshfuji@linux-ipv6.org>
Cc: David Ahern <dsahern@kernel.org>
Cc: netdev@vger.kernel.org
Signed-off-by: Dmitry Safonov <dima@arista.com>
---
 include/net/tcp.h   | 11 +++++++
 net/ipv4/tcp.c      | 66 ++++++++++++++++++++++++++++++++++++++++
 net/ipv4/tcp_ipv4.c | 73 ++++-----------------------------------------
 net/ipv6/tcp_ipv6.c | 57 +++--------------------------------
 4 files changed, 86 insertions(+), 121 deletions(-)


base-commit: 038101e6b2cd5c55f888f85db42ea2ad3aecb4b6

Comments

Jakub Kicinski Feb. 23, 2022, 12:18 a.m. UTC | #1
On Tue, 22 Feb 2022 18:50:06 +0000 Dmitry Safonov wrote:
> The functions do essentially the same work to verify TCP-MD5 sign.
> Code can be merged into one family-independent function in order to
> reduce copy'n'paste and generated code.
> Later with TCP-AO option added, this will allow to create one function
> that's responsible for segment verification, that will have all the
> different checks for MD5/AO/non-signed packets, which in turn will help
> to see checks for all corner-cases in one function, rather than spread
> around different families and functions.

Please rebase on top of net-next
Dmitry Safonov Feb. 23, 2022, 12:19 p.m. UTC | #2
On 2/23/22 00:18, Jakub Kicinski wrote:
> On Tue, 22 Feb 2022 18:50:06 +0000 Dmitry Safonov wrote:
>> The functions do essentially the same work to verify TCP-MD5 sign.
>> Code can be merged into one family-independent function in order to
>> reduce copy'n'paste and generated code.
>> Later with TCP-AO option added, this will allow to create one function
>> that's responsible for segment verification, that will have all the
>> different checks for MD5/AO/non-signed packets, which in turn will help
>> to see checks for all corner-cases in one function, rather than spread
>> around different families and functions.
> 
> Please rebase on top of net-next

Thanks!
I've previously checked it on linux-next, was not aware that net-next is
not integrated or integrated with some delay?

Anyway, I've resent v2, based on net-next.

Thanks,
          Dmitry
diff mbox series

Patch

diff --git a/include/net/tcp.h b/include/net/tcp.h
index b9fc978fb2ca..598f89ef9546 100644
--- a/include/net/tcp.h
+++ b/include/net/tcp.h
@@ -1673,6 +1673,10 @@  tcp_md5_do_lookup(const struct sock *sk, int l3index,
 		return NULL;
 	return __tcp_md5_do_lookup(sk, l3index, addr, family);
 }
+bool tcp_inbound_md5_hash(const struct sock *sk, const struct sk_buff *skb,
+			  const void *saddr, const void *daddr,
+			  int family, int dif, int sdif);
+
 
 #define tcp_twsk_md5_key(twsk)	((twsk)->tw_md5_key)
 #else
@@ -1682,6 +1686,13 @@  tcp_md5_do_lookup(const struct sock *sk, int l3index,
 {
 	return NULL;
 }
+static inline bool tcp_inbound_md5_hash(const struct sock *sk,
+					const struct sk_buff *skb,
+					const void *saddr, const void *daddr,
+					int family, int dif, int sdif)
+{
+	return false;
+}
 #define tcp_twsk_md5_key(twsk)	NULL
 #endif
 
diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c
index 02cb275e5487..546647534381 100644
--- a/net/ipv4/tcp.c
+++ b/net/ipv4/tcp.c
@@ -4432,6 +4432,72 @@  int tcp_md5_hash_key(struct tcp_md5sig_pool *hp, const struct tcp_md5sig_key *ke
 }
 EXPORT_SYMBOL(tcp_md5_hash_key);
 
+/* Called with rcu_read_lock() */
+bool tcp_inbound_md5_hash(const struct sock *sk, const struct sk_buff *skb,
+			  const void *saddr, const void *daddr,
+			  int family, int dif, int sdif)
+{
+	/*
+	 * This gets called for each TCP segment that arrives
+	 * so we want to be efficient.
+	 * We have 3 drop cases:
+	 * o No MD5 hash and one expected.
+	 * o MD5 hash and we're not expecting one.
+	 * o MD5 hash and its wrong.
+	 */
+	const __u8 *hash_location = NULL;
+	struct tcp_md5sig_key *hash_expected;
+	const struct tcphdr *th = tcp_hdr(skb);
+	struct tcp_sock *tp = tcp_sk(sk);
+	int genhash, l3index;
+	u8 newhash[16];
+
+	/* sdif set, means packet ingressed via a device
+	 * in an L3 domain and dif is set to the l3mdev
+	 */
+	l3index = sdif ? dif : 0;
+
+	hash_expected = tcp_md5_do_lookup(sk, l3index, saddr, family);
+	hash_location = tcp_parse_md5sig_option(th);
+
+	/* We've parsed the options - do we have a hash? */
+	if (!hash_expected && !hash_location)
+		return false;
+
+	if (hash_expected && !hash_location) {
+		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5NOTFOUND);
+		return true;
+	}
+
+	if (!hash_expected && hash_location) {
+		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5UNEXPECTED);
+		return true;
+	}
+
+	/* check the signature */
+	genhash = tp->af_specific->calc_md5_hash(newhash, hash_expected,
+						 NULL, skb);
+
+	if (genhash || memcmp(hash_location, newhash, 16) != 0) {
+		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5FAILURE);
+		if (family == AF_INET) {
+			net_info_ratelimited("MD5 Hash failed for (%pI4, %d)->(%pI4, %d)%s L3 index %d\n",
+					saddr, ntohs(th->source),
+					daddr, ntohs(th->dest),
+					genhash ? " tcp_v4_calc_md5_hash failed"
+					: "", l3index);
+		} else {
+			net_info_ratelimited("MD5 Hash %s for [%pI6c]:%u->[%pI6c]:%u L3 index %d\n",
+					genhash ? "failed" : "mismatch",
+					saddr, ntohs(th->source),
+					daddr, ntohs(th->dest), l3index);
+		}
+		return true;
+	}
+	return false;
+}
+EXPORT_SYMBOL(tcp_inbound_md5_hash);
+
 #endif
 
 void tcp_done(struct sock *sk)
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index fec656f5a39e..bf2f6aff146d 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -1403,72 +1403,6 @@  EXPORT_SYMBOL(tcp_v4_md5_hash_skb);
 
 #endif
 
-/* Called with rcu_read_lock() */
-static bool tcp_v4_inbound_md5_hash(const struct sock *sk,
-				    const struct sk_buff *skb,
-				    int dif, int sdif)
-{
-#ifdef CONFIG_TCP_MD5SIG
-	/*
-	 * This gets called for each TCP segment that arrives
-	 * so we want to be efficient.
-	 * We have 3 drop cases:
-	 * o No MD5 hash and one expected.
-	 * o MD5 hash and we're not expecting one.
-	 * o MD5 hash and its wrong.
-	 */
-	const __u8 *hash_location = NULL;
-	struct tcp_md5sig_key *hash_expected;
-	const struct iphdr *iph = ip_hdr(skb);
-	const struct tcphdr *th = tcp_hdr(skb);
-	const union tcp_md5_addr *addr;
-	unsigned char newhash[16];
-	int genhash, l3index;
-
-	/* sdif set, means packet ingressed via a device
-	 * in an L3 domain and dif is set to the l3mdev
-	 */
-	l3index = sdif ? dif : 0;
-
-	addr = (union tcp_md5_addr *)&iph->saddr;
-	hash_expected = tcp_md5_do_lookup(sk, l3index, addr, AF_INET);
-	hash_location = tcp_parse_md5sig_option(th);
-
-	/* We've parsed the options - do we have a hash? */
-	if (!hash_expected && !hash_location)
-		return false;
-
-	if (hash_expected && !hash_location) {
-		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5NOTFOUND);
-		return true;
-	}
-
-	if (!hash_expected && hash_location) {
-		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5UNEXPECTED);
-		return true;
-	}
-
-	/* Okay, so this is hash_expected and hash_location -
-	 * so we need to calculate the checksum.
-	 */
-	genhash = tcp_v4_md5_hash_skb(newhash,
-				      hash_expected,
-				      NULL, skb);
-
-	if (genhash || memcmp(hash_location, newhash, 16) != 0) {
-		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5FAILURE);
-		net_info_ratelimited("MD5 Hash failed for (%pI4, %d)->(%pI4, %d)%s L3 index %d\n",
-				     &iph->saddr, ntohs(th->source),
-				     &iph->daddr, ntohs(th->dest),
-				     genhash ? " tcp_v4_calc_md5_hash failed"
-				     : "", l3index);
-		return true;
-	}
-	return false;
-#endif
-	return false;
-}
-
 static void tcp_v4_init_req(struct request_sock *req,
 			    const struct sock *sk_listener,
 			    struct sk_buff *skb)
@@ -2019,7 +1953,9 @@  int tcp_v4_rcv(struct sk_buff *skb)
 		struct sock *nsk;
 
 		sk = req->rsk_listener;
-		if (unlikely(tcp_v4_inbound_md5_hash(sk, skb, dif, sdif))) {
+		if (unlikely(tcp_inbound_md5_hash(sk, skb,
+						  &iph->saddr, &iph->daddr,
+						  AF_INET, dif, sdif))) {
 			sk_drops_add(sk, skb);
 			reqsk_put(req);
 			goto discard_it;
@@ -2089,7 +2025,8 @@  int tcp_v4_rcv(struct sk_buff *skb)
 	if (!xfrm4_policy_check(sk, XFRM_POLICY_IN, skb))
 		goto discard_and_relse;
 
-	if (tcp_v4_inbound_md5_hash(sk, skb, dif, sdif))
+	if (tcp_inbound_md5_hash(sk, skb, &iph->saddr, &iph->daddr,
+				AF_INET, dif, sdif))
 		goto discard_and_relse;
 
 	nf_reset_ct(skb);
diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c
index 075ee8a2df3b..30cd0a074c2c 100644
--- a/net/ipv6/tcp_ipv6.c
+++ b/net/ipv6/tcp_ipv6.c
@@ -772,57 +772,6 @@  static int tcp_v6_md5_hash_skb(char *md5_hash,
 
 #endif
 
-static bool tcp_v6_inbound_md5_hash(const struct sock *sk,
-				    const struct sk_buff *skb,
-				    int dif, int sdif)
-{
-#ifdef CONFIG_TCP_MD5SIG
-	const __u8 *hash_location = NULL;
-	struct tcp_md5sig_key *hash_expected;
-	const struct ipv6hdr *ip6h = ipv6_hdr(skb);
-	const struct tcphdr *th = tcp_hdr(skb);
-	int genhash, l3index;
-	u8 newhash[16];
-
-	/* sdif set, means packet ingressed via a device
-	 * in an L3 domain and dif is set to the l3mdev
-	 */
-	l3index = sdif ? dif : 0;
-
-	hash_expected = tcp_v6_md5_do_lookup(sk, &ip6h->saddr, l3index);
-	hash_location = tcp_parse_md5sig_option(th);
-
-	/* We've parsed the options - do we have a hash? */
-	if (!hash_expected && !hash_location)
-		return false;
-
-	if (hash_expected && !hash_location) {
-		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5NOTFOUND);
-		return true;
-	}
-
-	if (!hash_expected && hash_location) {
-		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5UNEXPECTED);
-		return true;
-	}
-
-	/* check the signature */
-	genhash = tcp_v6_md5_hash_skb(newhash,
-				      hash_expected,
-				      NULL, skb);
-
-	if (genhash || memcmp(hash_location, newhash, 16) != 0) {
-		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5FAILURE);
-		net_info_ratelimited("MD5 Hash %s for [%pI6c]:%u->[%pI6c]:%u L3 index %d\n",
-				     genhash ? "failed" : "mismatch",
-				     &ip6h->saddr, ntohs(th->source),
-				     &ip6h->daddr, ntohs(th->dest), l3index);
-		return true;
-	}
-#endif
-	return false;
-}
-
 static void tcp_v6_init_req(struct request_sock *req,
 			    const struct sock *sk_listener,
 			    struct sk_buff *skb)
@@ -1676,7 +1625,8 @@  INDIRECT_CALLABLE_SCOPE int tcp_v6_rcv(struct sk_buff *skb)
 		struct sock *nsk;
 
 		sk = req->rsk_listener;
-		if (tcp_v6_inbound_md5_hash(sk, skb, dif, sdif)) {
+		if (tcp_inbound_md5_hash(sk, skb, &hdr->saddr, &hdr->daddr,
+					 AF_INET6, dif, sdif)) {
 			sk_drops_add(sk, skb);
 			reqsk_put(req);
 			goto discard_it;
@@ -1743,7 +1693,8 @@  INDIRECT_CALLABLE_SCOPE int tcp_v6_rcv(struct sk_buff *skb)
 	if (!xfrm6_policy_check(sk, XFRM_POLICY_IN, skb))
 		goto discard_and_relse;
 
-	if (tcp_v6_inbound_md5_hash(sk, skb, dif, sdif))
+	if (tcp_inbound_md5_hash(sk, skb, &hdr->saddr, &hdr->daddr,
+				 AF_INET6, dif, sdif))
 		goto discard_and_relse;
 
 	if (tcp_filter(sk, skb))