diff mbox series

[net-next,v2,2/4] net/tls: Use cipher sizes structs

Message ID 20220920130150.3546-3-gal@nvidia.com (mailing list archive)
State Accepted
Commit ea7a9d88ba21dd9d395d7137b0ca1743c5f5d9c2
Delegated to: Netdev Maintainers
Headers show
Series Support 256 bit TLS keys with device offload | expand

Checks

Context Check Description
netdev/tree_selection success Clearly marked for net-next
netdev/fixes_present success Fixes tag not required for -next series
netdev/subject_prefix success Link
netdev/cover_letter success Series has a cover letter
netdev/patch_count success Link
netdev/header_inline success No static functions without inline keyword in header files
netdev/build_32bit success Errors and warnings before: 0 this patch: 0
netdev/cc_maintainers warning 2 maintainers not CCed: edumazet@google.com pabeni@redhat.com
netdev/build_clang success Errors and warnings before: 0 this patch: 0
netdev/module_param success Was 0 now: 0
netdev/verify_signedoff success Signed-off-by tag matches author and committer
netdev/check_selftest success No net selftest shell script
netdev/verify_fixes success No Fixes tag
netdev/build_allmodconfig_warn success Errors and warnings before: 0 this patch: 0
netdev/checkpatch warning WARNING: line length of 81 exceeds 80 columns
netdev/kdoc success Errors and warnings before: 0 this patch: 0
netdev/source_inline success Was 0 now: 0

Commit Message

Gal Pressman Sept. 20, 2022, 1:01 p.m. UTC
Use the newly introduced cipher sizes structs instead of the repeated
switch cases churn.

Reviewed-by: Tariq Toukan <tariqt@nvidia.com>
Signed-off-by: Gal Pressman <gal@nvidia.com>
---
 net/tls/tls_device.c          | 55 +++++++++++++-------------
 net/tls/tls_device_fallback.c | 72 +++++++++++++++++++++++------------
 2 files changed, 76 insertions(+), 51 deletions(-)
diff mbox series

Patch

diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c
index 0f983e5f7dde..3f8121b8125c 100644
--- a/net/tls/tls_device.c
+++ b/net/tls/tls_device.c
@@ -902,17 +902,27 @@  static void tls_device_core_ctrl_rx_resync(struct tls_context *tls_ctx,
 }
 
 static int
-tls_device_reencrypt(struct sock *sk, struct tls_sw_context_rx *sw_ctx)
+tls_device_reencrypt(struct sock *sk, struct tls_context *tls_ctx)
 {
+	struct tls_sw_context_rx *sw_ctx = tls_sw_ctx_rx(tls_ctx);
+	const struct tls_cipher_size_desc *cipher_sz;
 	int err, offset, copy, data_len, pos;
 	struct sk_buff *skb, *skb_iter;
 	struct scatterlist sg[1];
 	struct strp_msg *rxm;
 	char *orig_buf, *buf;
 
+	switch (tls_ctx->crypto_recv.info.cipher_type) {
+	case TLS_CIPHER_AES_GCM_128:
+		break;
+	default:
+		return -EINVAL;
+	}
+	cipher_sz = &tls_cipher_size_desc[tls_ctx->crypto_recv.info.cipher_type];
+
 	rxm = strp_msg(tls_strp_msg(sw_ctx));
-	orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE +
-			   TLS_CIPHER_AES_GCM_128_IV_SIZE, sk->sk_allocation);
+	orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE + cipher_sz->iv,
+			   sk->sk_allocation);
 	if (!orig_buf)
 		return -ENOMEM;
 	buf = orig_buf;
@@ -927,10 +937,8 @@  tls_device_reencrypt(struct sock *sk, struct tls_sw_context_rx *sw_ctx)
 
 	sg_init_table(sg, 1);
 	sg_set_buf(&sg[0], buf,
-		   rxm->full_len + TLS_HEADER_SIZE +
-		   TLS_CIPHER_AES_GCM_128_IV_SIZE);
-	err = skb_copy_bits(skb, offset, buf,
-			    TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE);
+		   rxm->full_len + TLS_HEADER_SIZE + cipher_sz->iv);
+	err = skb_copy_bits(skb, offset, buf, TLS_HEADER_SIZE + cipher_sz->iv);
 	if (err)
 		goto free_buf;
 
@@ -941,7 +949,7 @@  tls_device_reencrypt(struct sock *sk, struct tls_sw_context_rx *sw_ctx)
 	else
 		err = 0;
 
-	data_len = rxm->full_len - TLS_CIPHER_AES_GCM_128_TAG_SIZE;
+	data_len = rxm->full_len - cipher_sz->tag;
 
 	if (skb_pagelen(skb) > offset) {
 		copy = min_t(int, skb_pagelen(skb) - offset, data_len);
@@ -1024,7 +1032,7 @@  int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx)
 		 * likely have initial fragments decrypted, and final ones not
 		 * decrypted. We need to reencrypt that single SKB.
 		 */
-		return tls_device_reencrypt(sk, sw_ctx);
+		return tls_device_reencrypt(sk, tls_ctx);
 	}
 
 	/* Return immediately if the record is either entirely plaintext or
@@ -1041,7 +1049,7 @@  int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx)
 	}
 
 	ctx->resync_nh_reset = 1;
-	return tls_device_reencrypt(sk, sw_ctx);
+	return tls_device_reencrypt(sk, tls_ctx);
 }
 
 static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
@@ -1062,9 +1070,9 @@  static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
 
 int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
 {
-	u16 nonce_size, tag_size, iv_size, rec_seq_size, salt_size;
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_prot_info *prot = &tls_ctx->prot_info;
+	const struct tls_cipher_size_desc *cipher_sz;
 	struct tls_record_info *start_marker_record;
 	struct tls_offload_context_tx *offload_ctx;
 	struct tls_crypto_info *crypto_info;
@@ -1099,12 +1107,7 @@  int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
 
 	switch (crypto_info->cipher_type) {
 	case TLS_CIPHER_AES_GCM_128:
-		nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
-		tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
-		iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
 		iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
-		rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
-		salt_size = TLS_CIPHER_AES_GCM_128_SALT_SIZE;
 		rec_seq =
 		 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
 		break;
@@ -1112,31 +1115,31 @@  int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
 		rc = -EINVAL;
 		goto release_netdev;
 	}
+	cipher_sz = &tls_cipher_size_desc[crypto_info->cipher_type];
 
 	/* Sanity-check the rec_seq_size for stack allocations */
-	if (rec_seq_size > TLS_MAX_REC_SEQ_SIZE) {
+	if (cipher_sz->rec_seq > TLS_MAX_REC_SEQ_SIZE) {
 		rc = -EINVAL;
 		goto release_netdev;
 	}
 
 	prot->version = crypto_info->version;
 	prot->cipher_type = crypto_info->cipher_type;
-	prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
-	prot->tag_size = tag_size;
+	prot->prepend_size = TLS_HEADER_SIZE + cipher_sz->iv;
+	prot->tag_size = cipher_sz->tag;
 	prot->overhead_size = prot->prepend_size + prot->tag_size;
-	prot->iv_size = iv_size;
-	prot->salt_size = salt_size;
-	ctx->tx.iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
-			     GFP_KERNEL);
+	prot->iv_size = cipher_sz->iv;
+	prot->salt_size = cipher_sz->salt;
+	ctx->tx.iv = kmalloc(cipher_sz->iv + cipher_sz->salt, GFP_KERNEL);
 	if (!ctx->tx.iv) {
 		rc = -ENOMEM;
 		goto release_netdev;
 	}
 
-	memcpy(ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
+	memcpy(ctx->tx.iv + cipher_sz->salt, iv, cipher_sz->iv);
 
-	prot->rec_seq_size = rec_seq_size;
-	ctx->tx.rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
+	prot->rec_seq_size = cipher_sz->rec_seq;
+	ctx->tx.rec_seq = kmemdup(rec_seq, cipher_sz->rec_seq, GFP_KERNEL);
 	if (!ctx->tx.rec_seq) {
 		rc = -ENOMEM;
 		goto free_iv;
diff --git a/net/tls/tls_device_fallback.c b/net/tls/tls_device_fallback.c
index 7dfc8023e0f1..0d2b6518b877 100644
--- a/net/tls/tls_device_fallback.c
+++ b/net/tls/tls_device_fallback.c
@@ -54,13 +54,24 @@  static int tls_enc_record(struct aead_request *aead_req,
 			  struct scatter_walk *out, int *in_len,
 			  struct tls_prot_info *prot)
 {
-	unsigned char buf[TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE];
+	unsigned char buf[TLS_HEADER_SIZE + MAX_IV_SIZE];
+	const struct tls_cipher_size_desc *cipher_sz;
 	struct scatterlist sg_in[3];
 	struct scatterlist sg_out[3];
+	unsigned int buf_size;
 	u16 len;
 	int rc;
 
-	len = min_t(int, *in_len, ARRAY_SIZE(buf));
+	switch (prot->cipher_type) {
+	case TLS_CIPHER_AES_GCM_128:
+		break;
+	default:
+		return -EINVAL;
+	}
+	cipher_sz = &tls_cipher_size_desc[prot->cipher_type];
+
+	buf_size = TLS_HEADER_SIZE + cipher_sz->iv;
+	len = min_t(int, *in_len, buf_size);
 
 	scatterwalk_copychunks(buf, in, len, 0);
 	scatterwalk_copychunks(buf, out, len, 1);
@@ -73,13 +84,11 @@  static int tls_enc_record(struct aead_request *aead_req,
 	scatterwalk_pagedone(out, 1, 1);
 
 	len = buf[4] | (buf[3] << 8);
-	len -= TLS_CIPHER_AES_GCM_128_IV_SIZE;
+	len -= cipher_sz->iv;
 
-	tls_make_aad(aad, len - TLS_CIPHER_AES_GCM_128_TAG_SIZE,
-		(char *)&rcd_sn, buf[0], prot);
+	tls_make_aad(aad, len - cipher_sz->tag, (char *)&rcd_sn, buf[0], prot);
 
-	memcpy(iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, buf + TLS_HEADER_SIZE,
-	       TLS_CIPHER_AES_GCM_128_IV_SIZE);
+	memcpy(iv + cipher_sz->salt, buf + TLS_HEADER_SIZE, cipher_sz->iv);
 
 	sg_init_table(sg_in, ARRAY_SIZE(sg_in));
 	sg_init_table(sg_out, ARRAY_SIZE(sg_out));
@@ -90,7 +99,7 @@  static int tls_enc_record(struct aead_request *aead_req,
 
 	*in_len -= len;
 	if (*in_len < 0) {
-		*in_len += TLS_CIPHER_AES_GCM_128_TAG_SIZE;
+		*in_len += cipher_sz->tag;
 		/* the input buffer doesn't contain the entire record.
 		 * trim len accordingly. The resulting authentication tag
 		 * will contain garbage, but we don't care, so we won't
@@ -111,7 +120,7 @@  static int tls_enc_record(struct aead_request *aead_req,
 		scatterwalk_pagedone(out, 1, 1);
 	}
 
-	len -= TLS_CIPHER_AES_GCM_128_TAG_SIZE;
+	len -= cipher_sz->tag;
 	aead_request_set_crypt(aead_req, sg_in, sg_out, len, iv);
 
 	rc = crypto_aead_encrypt(aead_req);
@@ -299,11 +308,14 @@  static void fill_sg_out(struct scatterlist sg_out[3], void *buf,
 			int sync_size,
 			void *dummy_buf)
 {
+	const struct tls_cipher_size_desc *cipher_sz =
+		&tls_cipher_size_desc[tls_ctx->crypto_send.info.cipher_type];
+
 	sg_set_buf(&sg_out[0], dummy_buf, sync_size);
 	sg_set_buf(&sg_out[1], nskb->data + tcp_payload_offset, payload_len);
 	/* Add room for authentication tag produced by crypto */
 	dummy_buf += sync_size;
-	sg_set_buf(&sg_out[2], dummy_buf, TLS_CIPHER_AES_GCM_128_TAG_SIZE);
+	sg_set_buf(&sg_out[2], dummy_buf, cipher_sz->tag);
 }
 
 static struct sk_buff *tls_enc_skb(struct tls_context *tls_ctx,
@@ -315,7 +327,8 @@  static struct sk_buff *tls_enc_skb(struct tls_context *tls_ctx,
 	struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
 	int tcp_payload_offset = skb_tcp_all_headers(skb);
 	int payload_len = skb->len - tcp_payload_offset;
-	void *buf, *iv, *aad, *dummy_buf;
+	const struct tls_cipher_size_desc *cipher_sz;
+	void *buf, *iv, *aad, *dummy_buf, *salt;
 	struct aead_request *aead_req;
 	struct sk_buff *nskb = NULL;
 	int buf_len;
@@ -324,20 +337,23 @@  static struct sk_buff *tls_enc_skb(struct tls_context *tls_ctx,
 	if (!aead_req)
 		return NULL;
 
-	buf_len = TLS_CIPHER_AES_GCM_128_SALT_SIZE +
-		  TLS_CIPHER_AES_GCM_128_IV_SIZE +
-		  TLS_AAD_SPACE_SIZE +
-		  sync_size +
-		  TLS_CIPHER_AES_GCM_128_TAG_SIZE;
+	switch (tls_ctx->crypto_send.info.cipher_type) {
+	case TLS_CIPHER_AES_GCM_128:
+		salt = tls_ctx->crypto_send.aes_gcm_128.salt;
+		break;
+	default:
+		return NULL;
+	}
+	cipher_sz = &tls_cipher_size_desc[tls_ctx->crypto_send.info.cipher_type];
+	buf_len = cipher_sz->salt + cipher_sz->iv + TLS_AAD_SPACE_SIZE +
+		  sync_size + cipher_sz->tag;
 	buf = kmalloc(buf_len, GFP_ATOMIC);
 	if (!buf)
 		goto free_req;
 
 	iv = buf;
-	memcpy(iv, tls_ctx->crypto_send.aes_gcm_128.salt,
-	       TLS_CIPHER_AES_GCM_128_SALT_SIZE);
-	aad = buf + TLS_CIPHER_AES_GCM_128_SALT_SIZE +
-	      TLS_CIPHER_AES_GCM_128_IV_SIZE;
+	memcpy(iv, salt, cipher_sz->salt);
+	aad = buf + cipher_sz->salt + cipher_sz->iv;
 	dummy_buf = aad + TLS_AAD_SPACE_SIZE;
 
 	nskb = alloc_skb(skb_headroom(skb) + skb->len, GFP_ATOMIC);
@@ -451,6 +467,7 @@  int tls_sw_fallback_init(struct sock *sk,
 			 struct tls_offload_context_tx *offload_ctx,
 			 struct tls_crypto_info *crypto_info)
 {
+	const struct tls_cipher_size_desc *cipher_sz;
 	const u8 *key;
 	int rc;
 
@@ -463,15 +480,20 @@  int tls_sw_fallback_init(struct sock *sk,
 		goto err_out;
 	}
 
-	key = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->key;
+	switch (crypto_info->cipher_type) {
+	case TLS_CIPHER_AES_GCM_128:
+		key = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->key;
+		break;
+	default:
+		return -EINVAL;
+	}
+	cipher_sz = &tls_cipher_size_desc[crypto_info->cipher_type];
 
-	rc = crypto_aead_setkey(offload_ctx->aead_send, key,
-				TLS_CIPHER_AES_GCM_128_KEY_SIZE);
+	rc = crypto_aead_setkey(offload_ctx->aead_send, key, cipher_sz->key);
 	if (rc)
 		goto free_aead;
 
-	rc = crypto_aead_setauthsize(offload_ctx->aead_send,
-				     TLS_CIPHER_AES_GCM_128_TAG_SIZE);
+	rc = crypto_aead_setauthsize(offload_ctx->aead_send, cipher_sz->tag);
 	if (rc)
 		goto free_aead;