diff mbox series

[v2] net/tcp: Merge TCP-MD5 inbound callbacks

Message ID 20220223121746.421327-1-dima@arista.com (mailing list archive)
State Superseded
Delegated to: Netdev Maintainers
Headers show
Series [v2] net/tcp: Merge TCP-MD5 inbound callbacks | expand

Checks

Context Check Description
netdev/fixes_present success Fixes tag not required for -next series
netdev/subject_prefix warning Target tree name not specified in the subject
netdev/cover_letter success Single patches do not need cover letters
netdev/patch_count success Link
netdev/header_inline success No static functions without inline keyword in header files
netdev/build_32bit success Errors and warnings before: 1107 this patch: 1107
netdev/cc_maintainers success CCed 6 of 6 maintainers
netdev/build_clang success Errors and warnings before: 159 this patch: 159
netdev/module_param success Was 0 now: 0
netdev/verify_signedoff success Signed-off-by tag matches author and committer
netdev/verify_fixes success No Fixes tag
netdev/build_allmodconfig_warn success Errors and warnings before: 1117 this patch: 1117
netdev/checkpatch warning CHECK: Alignment should match open parenthesis CHECK: Please use a blank line after function/struct/union/enum declarations WARNING: networking block comments don't use an empty /* line, use /* Comment...
netdev/kdoc success Errors and warnings before: 0 this patch: 0
netdev/source_inline success Was 0 now: 0
netdev/tree_selection success Guessing tree name failed - patch did not apply, async

Commit Message

Dmitry Safonov Feb. 23, 2022, 12:17 p.m. UTC
The functions do essentially the same work to verify TCP-MD5 sign.
Code can be merged into one family-independent function in order to
reduce copy'n'paste and generated code.
Later with TCP-AO option added, this will allow to create one function
that's responsible for segment verification, that will have all the
different checks for MD5/AO/non-signed packets, which in turn will help
to see checks for all corner-cases in one function, rather than spread
around different families and functions.

Cc: Eric Dumazet <edumazet@google.com>
Cc: "David S. Miller" <davem@davemloft.net>
Cc: Jakub Kicinski <kuba@kernel.org>
Cc: Hideaki YOSHIFUJI <yoshfuji@linux-ipv6.org>
Cc: David Ahern <dsahern@kernel.org>
Cc: netdev@vger.kernel.org
Signed-off-by: Dmitry Safonov <dima@arista.com>
---
v2: Rebased on net-next

 include/net/tcp.h   | 12 +++++++
 net/ipv4/tcp.c      | 70 ++++++++++++++++++++++++++++++++++++++++
 net/ipv4/tcp_ipv4.c | 78 +++------------------------------------------
 net/ipv6/tcp_ipv6.c | 62 +++--------------------------------
 4 files changed, 91 insertions(+), 131 deletions(-)


base-commit: 922ea87ff6f2b63f413c6afa2c25b287dce76639
prerequisite-patch-id: f4dc7ba51eadb9fccabb755c41854bedeeaf8954

Comments

Dmitry Safonov Feb. 23, 2022, 5:57 p.m. UTC | #1
On 2/23/22 12:17, Dmitry Safonov wrote:
> The functions do essentially the same work to verify TCP-MD5 sign.
> Code can be merged into one family-independent function in order to
> reduce copy'n'paste and generated code.
> Later with TCP-AO option added, this will allow to create one function
> that's responsible for segment verification, that will have all the
> different checks for MD5/AO/non-signed packets, which in turn will help
> to see checks for all corner-cases in one function, rather than spread
> around different families and functions.
> 
> Cc: Eric Dumazet <edumazet@google.com>
> Cc: "David S. Miller" <davem@davemloft.net>
> Cc: Jakub Kicinski <kuba@kernel.org>
> Cc: Hideaki YOSHIFUJI <yoshfuji@linux-ipv6.org>
> Cc: David Ahern <dsahern@kernel.org>
> Cc: netdev@vger.kernel.org
> Signed-off-by: Dmitry Safonov <dima@arista.com>
> ---
> v2: Rebased on net-next

On rebase I didn't check !CONFIG_TCP_MD5, and managed to forget to
change the function declaration for the stub. Duh!
I've sent version3.

Thanks, sorry for the noise,
          Dmitry
diff mbox series

Patch

diff --git a/include/net/tcp.h b/include/net/tcp.h
index 04f4650e0ff0..2e60864d7865 100644
--- a/include/net/tcp.h
+++ b/include/net/tcp.h
@@ -1674,6 +1674,11 @@  tcp_md5_do_lookup(const struct sock *sk, int l3index,
 		return NULL;
 	return __tcp_md5_do_lookup(sk, l3index, addr, family);
 }
+bool tcp_inbound_md5_hash(const struct sock *sk, const struct sk_buff *skb,
+			  enum skb_drop_reason *reason,
+			  const void *saddr, const void *daddr,
+			  int family, int dif, int sdif);
+
 
 #define tcp_twsk_md5_key(twsk)	((twsk)->tw_md5_key)
 #else
@@ -1683,6 +1688,13 @@  tcp_md5_do_lookup(const struct sock *sk, int l3index,
 {
 	return NULL;
 }
+static inline bool tcp_inbound_md5_hash(const struct sock *sk,
+					const struct sk_buff *skb,
+					const void *saddr, const void *daddr,
+					int family, int dif, int sdif)
+{
+	return false;
+}
 #define tcp_twsk_md5_key(twsk)	NULL
 #endif
 
diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c
index 760e8221d321..68f1236b2858 100644
--- a/net/ipv4/tcp.c
+++ b/net/ipv4/tcp.c
@@ -4431,6 +4431,76 @@  int tcp_md5_hash_key(struct tcp_md5sig_pool *hp, const struct tcp_md5sig_key *ke
 }
 EXPORT_SYMBOL(tcp_md5_hash_key);
 
+/* Called with rcu_read_lock() */
+bool tcp_inbound_md5_hash(const struct sock *sk, const struct sk_buff *skb,
+			  enum skb_drop_reason *reason,
+			  const void *saddr, const void *daddr,
+			  int family, int dif, int sdif)
+{
+	/*
+	 * This gets called for each TCP segment that arrives
+	 * so we want to be efficient.
+	 * We have 3 drop cases:
+	 * o No MD5 hash and one expected.
+	 * o MD5 hash and we're not expecting one.
+	 * o MD5 hash and its wrong.
+	 */
+	const __u8 *hash_location = NULL;
+	struct tcp_md5sig_key *hash_expected;
+	const struct tcphdr *th = tcp_hdr(skb);
+	struct tcp_sock *tp = tcp_sk(sk);
+	int genhash, l3index;
+	u8 newhash[16];
+
+	/* sdif set, means packet ingressed via a device
+	 * in an L3 domain and dif is set to the l3mdev
+	 */
+	l3index = sdif ? dif : 0;
+
+	hash_expected = tcp_md5_do_lookup(sk, l3index, saddr, family);
+	hash_location = tcp_parse_md5sig_option(th);
+
+	/* We've parsed the options - do we have a hash? */
+	if (!hash_expected && !hash_location)
+		return false;
+
+	if (hash_expected && !hash_location) {
+		*reason = SKB_DROP_REASON_TCP_MD5NOTFOUND;
+		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5NOTFOUND);
+		return true;
+	}
+
+	if (!hash_expected && hash_location) {
+		*reason = SKB_DROP_REASON_TCP_MD5UNEXPECTED;
+		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5UNEXPECTED);
+		return true;
+	}
+
+	/* check the signature */
+	genhash = tp->af_specific->calc_md5_hash(newhash, hash_expected,
+						 NULL, skb);
+
+	if (genhash || memcmp(hash_location, newhash, 16) != 0) {
+		*reason = SKB_DROP_REASON_TCP_MD5FAILURE;
+		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5FAILURE);
+		if (family == AF_INET) {
+			net_info_ratelimited("MD5 Hash failed for (%pI4, %d)->(%pI4, %d)%s L3 index %d\n",
+					saddr, ntohs(th->source),
+					daddr, ntohs(th->dest),
+					genhash ? " tcp_v4_calc_md5_hash failed"
+					: "", l3index);
+		} else {
+			net_info_ratelimited("MD5 Hash %s for [%pI6c]:%u->[%pI6c]:%u L3 index %d\n",
+					genhash ? "failed" : "mismatch",
+					saddr, ntohs(th->source),
+					daddr, ntohs(th->dest), l3index);
+		}
+		return true;
+	}
+	return false;
+}
+EXPORT_SYMBOL(tcp_inbound_md5_hash);
+
 #endif
 
 void tcp_done(struct sock *sk)
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index d42824aedc36..411357ad9757 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -1409,76 +1409,6 @@  EXPORT_SYMBOL(tcp_v4_md5_hash_skb);
 
 #endif
 
-/* Called with rcu_read_lock() */
-static bool tcp_v4_inbound_md5_hash(const struct sock *sk,
-				    const struct sk_buff *skb,
-				    int dif, int sdif,
-				    enum skb_drop_reason *reason)
-{
-#ifdef CONFIG_TCP_MD5SIG
-	/*
-	 * This gets called for each TCP segment that arrives
-	 * so we want to be efficient.
-	 * We have 3 drop cases:
-	 * o No MD5 hash and one expected.
-	 * o MD5 hash and we're not expecting one.
-	 * o MD5 hash and its wrong.
-	 */
-	const __u8 *hash_location = NULL;
-	struct tcp_md5sig_key *hash_expected;
-	const struct iphdr *iph = ip_hdr(skb);
-	const struct tcphdr *th = tcp_hdr(skb);
-	const union tcp_md5_addr *addr;
-	unsigned char newhash[16];
-	int genhash, l3index;
-
-	/* sdif set, means packet ingressed via a device
-	 * in an L3 domain and dif is set to the l3mdev
-	 */
-	l3index = sdif ? dif : 0;
-
-	addr = (union tcp_md5_addr *)&iph->saddr;
-	hash_expected = tcp_md5_do_lookup(sk, l3index, addr, AF_INET);
-	hash_location = tcp_parse_md5sig_option(th);
-
-	/* We've parsed the options - do we have a hash? */
-	if (!hash_expected && !hash_location)
-		return false;
-
-	if (hash_expected && !hash_location) {
-		*reason = SKB_DROP_REASON_TCP_MD5NOTFOUND;
-		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5NOTFOUND);
-		return true;
-	}
-
-	if (!hash_expected && hash_location) {
-		*reason = SKB_DROP_REASON_TCP_MD5UNEXPECTED;
-		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5UNEXPECTED);
-		return true;
-	}
-
-	/* Okay, so this is hash_expected and hash_location -
-	 * so we need to calculate the checksum.
-	 */
-	genhash = tcp_v4_md5_hash_skb(newhash,
-				      hash_expected,
-				      NULL, skb);
-
-	if (genhash || memcmp(hash_location, newhash, 16) != 0) {
-		*reason = SKB_DROP_REASON_TCP_MD5FAILURE;
-		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5FAILURE);
-		net_info_ratelimited("MD5 Hash failed for (%pI4, %d)->(%pI4, %d)%s L3 index %d\n",
-				     &iph->saddr, ntohs(th->source),
-				     &iph->daddr, ntohs(th->dest),
-				     genhash ? " tcp_v4_calc_md5_hash failed"
-				     : "", l3index);
-		return true;
-	}
-	return false;
-#endif
-	return false;
-}
-
 static void tcp_v4_init_req(struct request_sock *req,
 			    const struct sock *sk_listener,
 			    struct sk_buff *skb)
@@ -2035,8 +1965,9 @@  int tcp_v4_rcv(struct sk_buff *skb)
 		struct sock *nsk;
 
 		sk = req->rsk_listener;
-		if (unlikely(tcp_v4_inbound_md5_hash(sk, skb, dif, sdif,
-						     &drop_reason))) {
+		if (unlikely(tcp_inbound_md5_hash(sk, skb, &drop_reason,
+						  &iph->saddr, &iph->daddr,
+						  AF_INET, dif, sdif))) {
 			sk_drops_add(sk, skb);
 			reqsk_put(req);
 			goto discard_it;
@@ -2110,7 +2041,8 @@  int tcp_v4_rcv(struct sk_buff *skb)
 		goto discard_and_relse;
 	}
 
-	if (tcp_v4_inbound_md5_hash(sk, skb, dif, sdif, &drop_reason))
+	if (tcp_inbound_md5_hash(sk, skb, &drop_reason, &iph->saddr,
+				 &iph->daddr, AF_INET, dif, sdif))
 		goto discard_and_relse;
 
 	nf_reset_ct(skb);
diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c
index 749de8529c83..e98af869ff3a 100644
--- a/net/ipv6/tcp_ipv6.c
+++ b/net/ipv6/tcp_ipv6.c
@@ -773,61 +773,6 @@  static int tcp_v6_md5_hash_skb(char *md5_hash,
 
 #endif
 
-static bool tcp_v6_inbound_md5_hash(const struct sock *sk,
-				    const struct sk_buff *skb,
-				    int dif, int sdif,
-				    enum skb_drop_reason *reason)
-{
-#ifdef CONFIG_TCP_MD5SIG
-	const __u8 *hash_location = NULL;
-	struct tcp_md5sig_key *hash_expected;
-	const struct ipv6hdr *ip6h = ipv6_hdr(skb);
-	const struct tcphdr *th = tcp_hdr(skb);
-	int genhash, l3index;
-	u8 newhash[16];
-
-	/* sdif set, means packet ingressed via a device
-	 * in an L3 domain and dif is set to the l3mdev
-	 */
-	l3index = sdif ? dif : 0;
-
-	hash_expected = tcp_v6_md5_do_lookup(sk, &ip6h->saddr, l3index);
-	hash_location = tcp_parse_md5sig_option(th);
-
-	/* We've parsed the options - do we have a hash? */
-	if (!hash_expected && !hash_location)
-		return false;
-
-	if (hash_expected && !hash_location) {
-		*reason = SKB_DROP_REASON_TCP_MD5NOTFOUND;
-		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5NOTFOUND);
-		return true;
-	}
-
-	if (!hash_expected && hash_location) {
-		*reason = SKB_DROP_REASON_TCP_MD5UNEXPECTED;
-		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5UNEXPECTED);
-		return true;
-	}
-
-	/* check the signature */
-	genhash = tcp_v6_md5_hash_skb(newhash,
-				      hash_expected,
-				      NULL, skb);
-
-	if (genhash || memcmp(hash_location, newhash, 16) != 0) {
-		*reason = SKB_DROP_REASON_TCP_MD5FAILURE;
-		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5FAILURE);
-		net_info_ratelimited("MD5 Hash %s for [%pI6c]:%u->[%pI6c]:%u L3 index %d\n",
-				     genhash ? "failed" : "mismatch",
-				     &ip6h->saddr, ntohs(th->source),
-				     &ip6h->daddr, ntohs(th->dest), l3index);
-		return true;
-	}
-#endif
-	return false;
-}
-
 static void tcp_v6_init_req(struct request_sock *req,
 			    const struct sock *sk_listener,
 			    struct sk_buff *skb)
@@ -1687,8 +1632,8 @@  INDIRECT_CALLABLE_SCOPE int tcp_v6_rcv(struct sk_buff *skb)
 		struct sock *nsk;
 
 		sk = req->rsk_listener;
-		if (tcp_v6_inbound_md5_hash(sk, skb, dif, sdif,
-					    &drop_reason)) {
+		if (tcp_inbound_md5_hash(sk, skb, &drop_reason, &hdr->saddr,
+					 &hdr->daddr, AF_INET6, dif, sdif)) {
 			sk_drops_add(sk, skb);
 			reqsk_put(req);
 			goto discard_it;
@@ -1759,7 +1704,8 @@  INDIRECT_CALLABLE_SCOPE int tcp_v6_rcv(struct sk_buff *skb)
 		goto discard_and_relse;
 	}
 
-	if (tcp_v6_inbound_md5_hash(sk, skb, dif, sdif, &drop_reason))
+	if (tcp_inbound_md5_hash(sk, skb, &drop_reason, &hdr->saddr,
+				 &hdr->daddr, AF_INET6, dif, sdif))
 		goto discard_and_relse;
 
 	if (tcp_filter(sk, skb)) {