diff mbox series

[v7,6/9] crypto: arm64/aes-xctr: Improve readability of XCTR and CTR modes

Message ID 20220509191107.3556468-7-nhuck@google.com (mailing list archive)
State Superseded
Headers show
Series crypto: HCTR2 support | expand

Commit Message

Nathan Huckleberry May 9, 2022, 7:11 p.m. UTC
Added some clarifying comments, changed the register allocations to make
the code clearer, and added register aliases.

Signed-off-by: Nathan Huckleberry <nhuck@google.com>
Reviewed-by: Eric Biggers <ebiggers@google.com>
---
 arch/arm64/crypto/aes-glue.c  |  18 +++
 arch/arm64/crypto/aes-modes.S | 226 ++++++++++++++++++++++++----------
 2 files changed, 176 insertions(+), 68 deletions(-)

Comments

Eric Biggers May 9, 2022, 9:56 p.m. UTC | #1
On Mon, May 09, 2022 at 07:11:04PM +0000, Nathan Huckleberry wrote:
> Added some clarifying comments, changed the register allocations to make
> the code clearer, and added register aliases.
> 
> Signed-off-by: Nathan Huckleberry <nhuck@google.com>
> Reviewed-by: Eric Biggers <ebiggers@google.com>

Did you mean to add Ard's Reviewed-by that he gave on v6 as well?

One comment about the v7 changes below:

>  	/*
>  	 * aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
>  	 *		   int bytes, u8 ctr[])
> +	 *
> +	 * The input and output buffers must always be at least 16 bytes even if
> +	 * encrypting/decrypting less than 16 bytes.  Otherwise out of bounds
> +	 * accesses will occur.
>  	 */

This comment, along with the other similar ones you added, doesn't properly
describe the behavior when bytes < 16, as it's not mentioned that the extra
space needs to be before the pointed-to regions rather than after.  That's the
most unusual part of these functions, so it really should be mentioned.

Separately, applying this patch and the previous one causes the following
whitespace errors to be reported:

Applying: crypto: arm64/aes-xctr: Add accelerated implementation of XCTR
.git/rebase-apply/patch:299: space before tab in indent.
        ld1             {v5.16b-v7.16b}, [x1], #48
warning: 1 line adds whitespace errors.
Applying: crypto: arm64/aes-xctr: Improve readability of XCTR and CTR modes
.git/rebase-apply/patch:216: space before tab in indent.
        ld1             {v5.16b-v7.16b}, [IN], #48
warning: 1 line adds whitespace errors.

- Eric
diff mbox series

Patch

diff --git a/arch/arm64/crypto/aes-glue.c b/arch/arm64/crypto/aes-glue.c
index b6883288234c..b85f69195785 100644
--- a/arch/arm64/crypto/aes-glue.c
+++ b/arch/arm64/crypto/aes-glue.c
@@ -464,6 +464,15 @@  static int __maybe_unused xctr_encrypt(struct skcipher_request *req)
 		u8 *dst = walk.dst.virt.addr;
 		u8 buf[AES_BLOCK_SIZE];
 
+		/*
+		 * XCTR mode can only handle a partial block at the very end.
+		 * Since chunksize=16 we know that we always get at least 16
+		 * bytes (except possibly at the very end).
+		 *
+		 * If given less than 16 bytes, we must copy the partial block
+		 * into a temporary buffer of 16 bytes to avoid out of bounds
+		 * reads and writes.
+		 */
 		if (unlikely(nbytes < AES_BLOCK_SIZE))
 			src = dst = memcpy(buf + sizeof(buf) - nbytes,
 					   src, nbytes);
@@ -501,6 +510,15 @@  static int __maybe_unused ctr_encrypt(struct skcipher_request *req)
 		u8 *dst = walk.dst.virt.addr;
 		u8 buf[AES_BLOCK_SIZE];
 
+		/*
+		 * CTR mode can only handle a partial block at the very end.
+		 * Since chunksize=16 we know that we always get at least 16
+		 * bytes (except possibly at the very end).
+		 *
+		 * If given less than 16 bytes, we must copy the partial block
+		 * into a temporary buffer of 16 bytes to avoid out of bounds
+		 * reads and writes.
+		 */
 		if (unlikely(nbytes < AES_BLOCK_SIZE))
 			src = dst = memcpy(buf + sizeof(buf) - nbytes,
 					   src, nbytes);
diff --git a/arch/arm64/crypto/aes-modes.S b/arch/arm64/crypto/aes-modes.S
index 9a027200bdba..8260721e2df4 100644
--- a/arch/arm64/crypto/aes-modes.S
+++ b/arch/arm64/crypto/aes-modes.S
@@ -322,32 +322,60 @@  AES_FUNC_END(aes_cbc_cts_decrypt)
 	 * This macro generates the code for CTR and XCTR mode.
 	 */
 .macro ctr_encrypt xctr
+	// Arguments
+	OUT		.req x0
+	IN		.req x1
+	KEY		.req x2
+	ROUNDS_W	.req w3
+	BYTES_W		.req w4
+	IV		.req x5
+	BYTE_CTR_W 	.req w6		// XCTR only
+	// Intermediate values
+	CTR_W		.req w11	// XCTR only
+	CTR		.req x11	// XCTR only
+	IV_PART		.req x12
+	BLOCKS		.req x13
+	BLOCKS_W	.req w13
+
 	stp		x29, x30, [sp, #-16]!
 	mov		x29, sp
 
-	enc_prepare	w3, x2, x12
-	ld1		{vctr.16b}, [x5]
+	enc_prepare	ROUNDS_W, KEY, IV_PART
+	ld1		{vctr.16b}, [IV]
 
+	/*
+	 * Keep 64 bits of the IV in a register.  For CTR mode this lets us
+	 * easily increment the IV.  For XCTR mode this lets us efficiently XOR
+	 * the 64-bit counter with the IV.
+	 */
 	.if \xctr
-		umov		x12, vctr.d[0]
-		lsr		w11, w6, #4
+		umov		IV_PART, vctr.d[0]
+		lsr		CTR_W, BYTE_CTR_W, #4
 	.else
-		umov		x12, vctr.d[1] /* keep swabbed ctr in reg */
-		rev		x12, x12
+		umov		IV_PART, vctr.d[1]
+		rev		IV_PART, IV_PART
 	.endif
 
 .LctrloopNx\xctr:
-	add		w7, w4, #15
-	sub		w4, w4, #MAX_STRIDE << 4
-	lsr		w7, w7, #4
+	add		BLOCKS_W, BYTES_W, #15
+	sub		BYTES_W, BYTES_W, #MAX_STRIDE << 4
+	lsr		BLOCKS_W, BLOCKS_W, #4
 	mov		w8, #MAX_STRIDE
-	cmp		w7, w8
-	csel		w7, w7, w8, lt
+	cmp		BLOCKS_W, w8
+	csel		BLOCKS_W, BLOCKS_W, w8, lt
 
+	/*
+	 * Set up the counter values in v0-v{MAX_STRIDE-1}.
+	 *
+	 * If we are encrypting less than MAX_STRIDE blocks, the tail block
+	 * handling code expects the last keystream block to be in
+	 * v{MAX_STRIDE-1}.  For example: if encrypting two blocks with
+	 * MAX_STRIDE=5, then v3 and v4 should have the next two counter blocks.
+	 */
 	.if \xctr
-		add		x11, x11, x7
+		add		CTR, CTR, BLOCKS
 	.else
-		adds		x12, x12, x7
+		adds		IV_PART, IV_PART, BLOCKS
 	.endif
 	mov		v0.16b, vctr.16b
 	mov		v1.16b, vctr.16b
@@ -355,16 +383,16 @@  AES_FUNC_END(aes_cbc_cts_decrypt)
 	mov		v3.16b, vctr.16b
 ST5(	mov		v4.16b, vctr.16b		)
 	.if \xctr
-		sub		x6, x11, #MAX_STRIDE - 1
-		sub		x7, x11, #MAX_STRIDE - 2
-		sub		x8, x11, #MAX_STRIDE - 3
-		sub		x9, x11, #MAX_STRIDE - 4
-ST5(		sub		x10, x11, #MAX_STRIDE - 5	)
-		eor		x6, x6, x12
-		eor		x7, x7, x12
-		eor		x8, x8, x12
-		eor		x9, x9, x12
-ST5(		eor		x10, x10, x12			)
+		sub		x6, CTR, #MAX_STRIDE - 1
+		sub		x7, CTR, #MAX_STRIDE - 2
+		sub		x8, CTR, #MAX_STRIDE - 3
+		sub		x9, CTR, #MAX_STRIDE - 4
+ST5(		sub		x10, CTR, #MAX_STRIDE - 5	)
+		eor		x6, x6, IV_PART
+		eor		x7, x7, IV_PART
+		eor		x8, x8, IV_PART
+		eor		x9, x9, IV_PART
+ST5(		eor		x10, x10, IV_PART		)
 		mov		v0.d[0], x6
 		mov		v1.d[0], x7
 		mov		v2.d[0], x8
@@ -373,17 +401,32 @@  ST5(		mov		v4.d[0], x10			)
 	.else
 		bcs		0f
 		.subsection	1
-		/* apply carry to outgoing counter */
+		/*
+		 * This subsection handles carries.
+		 *
+		 * Conditional branching here is allowed with respect to time
+		 * invariance since the branches are dependent on the IV instead
+		 * of the plaintext or key.  This code is rarely executed in
+		 * practice anyway.
+		 */
+
+		/* Apply carry to outgoing counter. */
 0:		umov		x8, vctr.d[0]
 		rev		x8, x8
 		add		x8, x8, #1
 		rev		x8, x8
 		ins		vctr.d[0], x8
 
-		/* apply carry to N counter blocks for N := x12 */
-		cbz		x12, 2f
+		/*
+		 * Apply carry to counter blocks if needed.
+		 *
+		 * Since the carry flag was set, we know 0 <= IV_PART <
+		 * MAX_STRIDE.  Using the value of IV_PART we can determine how
+		 * many counter blocks need to be updated.
+		 */
+		cbz		IV_PART, 2f
 		adr		x16, 1f
-		sub		x16, x16, x12, lsl #3
+		sub		x16, x16, IV_PART, lsl #3
 		br		x16
 		bti		c
 		mov		v0.d[0], vctr.d[0]
@@ -398,71 +441,88 @@  ST5(		mov		v4.d[0], vctr.d[0]		)
 1:		b		2f
 		.previous
 
-2:		rev		x7, x12
+2:		rev		x7, IV_PART
 		ins		vctr.d[1], x7
-		sub		x7, x12, #MAX_STRIDE - 1
-		sub		x8, x12, #MAX_STRIDE - 2
-		sub		x9, x12, #MAX_STRIDE - 3
+		sub		x7, IV_PART, #MAX_STRIDE - 1
+		sub		x8, IV_PART, #MAX_STRIDE - 2
+		sub		x9, IV_PART, #MAX_STRIDE - 3
 		rev		x7, x7
 		rev		x8, x8
 		mov		v1.d[1], x7
 		rev		x9, x9
-ST5(		sub		x10, x12, #MAX_STRIDE - 4	)
+ST5(		sub		x10, IV_PART, #MAX_STRIDE - 4	)
 		mov		v2.d[1], x8
 ST5(		rev		x10, x10			)
 		mov		v3.d[1], x9
 ST5(		mov		v4.d[1], x10			)
 	.endif
-	tbnz		w4, #31, .Lctrtail\xctr
-    	ld1		{v5.16b-v7.16b}, [x1], #48
+
+	/*
+	 * If there are at least MAX_STRIDE blocks left, XOR the data with
+	 * keystream and store.  Otherwise jump to tail handling.
+	 */
+	tbnz		BYTES_W, #31, .Lctrtail\xctr
+    	ld1		{v5.16b-v7.16b}, [IN], #48
 ST4(	bl		aes_encrypt_block4x		)
 ST5(	bl		aes_encrypt_block5x		)
 	eor		v0.16b, v5.16b, v0.16b
-ST4(	ld1		{v5.16b}, [x1], #16		)
+ST4(	ld1		{v5.16b}, [IN], #16		)
 	eor		v1.16b, v6.16b, v1.16b
-ST5(	ld1		{v5.16b-v6.16b}, [x1], #32	)
+ST5(	ld1		{v5.16b-v6.16b}, [IN], #32	)
 	eor		v2.16b, v7.16b, v2.16b
 	eor		v3.16b, v5.16b, v3.16b
 ST5(	eor		v4.16b, v6.16b, v4.16b		)
-	st1		{v0.16b-v3.16b}, [x0], #64
-ST5(	st1		{v4.16b}, [x0], #16		)
-	cbz		w4, .Lctrout\xctr
+	st1		{v0.16b-v3.16b}, [OUT], #64
+ST5(	st1		{v4.16b}, [OUT], #16		)
+	cbz		BYTES_W, .Lctrout\xctr
 	b		.LctrloopNx\xctr
 
 .Lctrout\xctr:
 	.if !\xctr
-		st1		{vctr.16b}, [x5] /* return next CTR value */
+		st1		{vctr.16b}, [IV] /* return next CTR value */
 	.endif
 	ldp		x29, x30, [sp], #16
 	ret
 
 .Lctrtail\xctr:
+	/*
+	 * Handle up to MAX_STRIDE * 16 - 1 bytes of plaintext
+	 *
+	 * This code expects the last keystream block to be in v{MAX_STRIDE-1}.
+	 * For example: if encrypting two blocks with MAX_STRIDE=5, then v3 and
+	 * v4 should have the next two counter blocks.
+	 *
+	 * This allows us to store the ciphertext by writing to overlapping
+	 * regions of memory.  Any invalid ciphertext blocks get overwritten by
+	 * correctly computed blocks.  This approach greatly simplifies the
+	 * logic for storing the ciphertext.
+	 */
 	mov		x16, #16
-	ands		x6, x4, #0xf
-	csel		x13, x6, x16, ne
+	ands		w7, BYTES_W, #0xf
+	csel		x13, x7, x16, ne
 
-ST5(	cmp		w4, #64 - (MAX_STRIDE << 4)	)
+ST5(	cmp		BYTES_W, #64 - (MAX_STRIDE << 4))
 ST5(	csel		x14, x16, xzr, gt		)
-	cmp		w4, #48 - (MAX_STRIDE << 4)
+	cmp		BYTES_W, #48 - (MAX_STRIDE << 4)
 	csel		x15, x16, xzr, gt
-	cmp		w4, #32 - (MAX_STRIDE << 4)
+	cmp		BYTES_W, #32 - (MAX_STRIDE << 4)
 	csel		x16, x16, xzr, gt
-	cmp		w4, #16 - (MAX_STRIDE << 4)
+	cmp		BYTES_W, #16 - (MAX_STRIDE << 4)
 
-	adr_l		x12, .Lcts_permute_table
-	add		x12, x12, x13
+	adr_l		x9, .Lcts_permute_table
+	add		x9, x9, x13
 	ble		.Lctrtail1x\xctr
 
-ST5(	ld1		{v5.16b}, [x1], x14		)
-	ld1		{v6.16b}, [x1], x15
-	ld1		{v7.16b}, [x1], x16
+ST5(	ld1		{v5.16b}, [IN], x14		)
+	ld1		{v6.16b}, [IN], x15
+	ld1		{v7.16b}, [IN], x16
 
 ST4(	bl		aes_encrypt_block4x		)
 ST5(	bl		aes_encrypt_block5x		)
 
-	ld1		{v8.16b}, [x1], x13
-	ld1		{v9.16b}, [x1]
-	ld1		{v10.16b}, [x12]
+	ld1		{v8.16b}, [IN], x13
+	ld1		{v9.16b}, [IN]
+	ld1		{v10.16b}, [x9]
 
 ST4(	eor		v6.16b, v6.16b, v0.16b		)
 ST4(	eor		v7.16b, v7.16b, v1.16b		)
@@ -477,35 +537,61 @@  ST5(	eor		v7.16b, v7.16b, v2.16b		)
 ST5(	eor		v8.16b, v8.16b, v3.16b		)
 ST5(	eor		v9.16b, v9.16b, v4.16b		)
 
-ST5(	st1		{v5.16b}, [x0], x14		)
-	st1		{v6.16b}, [x0], x15
-	st1		{v7.16b}, [x0], x16
-	add		x13, x13, x0
+ST5(	st1		{v5.16b}, [OUT], x14		)
+	st1		{v6.16b}, [OUT], x15
+	st1		{v7.16b}, [OUT], x16
+	add		x13, x13, OUT
 	st1		{v9.16b}, [x13]		// overlapping stores
-	st1		{v8.16b}, [x0]
+	st1		{v8.16b}, [OUT]
 	b		.Lctrout\xctr
 
 .Lctrtail1x\xctr:
-	sub		x7, x6, #16
-	csel		x6, x6, x7, eq
-	add		x1, x1, x6
-	add		x0, x0, x6
-	ld1		{v5.16b}, [x1]
-	ld1		{v6.16b}, [x0]
+	/*
+	 * Handle <= 16 bytes of plaintext
+	 *
+	 * This code always reads and writes 16 bytes. To avoid out of bounds
+	 * accesses, XCTR and CTR modes must use a temporary buffer when
+	 * encrypting/decrypting less than 16 bytes.
+	 */
+	sub		x8, x7, #16
+	csel		x7, x7, x8, eq
+	add		IN, IN, x7
+	add		OUT, OUT, x7
+	ld1		{v5.16b}, [IN]
+	ld1		{v6.16b}, [OUT]
 ST5(	mov		v3.16b, v4.16b			)
-	encrypt_block	v3, w3, x2, x8, w7
-	ld1		{v10.16b-v11.16b}, [x12]
+	encrypt_block	v3, ROUNDS_W, KEY, x8, w7
+	ld1		{v10.16b-v11.16b}, [x9]
 	tbl		v3.16b, {v3.16b}, v10.16b
 	sshr		v11.16b, v11.16b, #7
 	eor		v5.16b, v5.16b, v3.16b
 	bif		v5.16b, v6.16b, v11.16b
-	st1		{v5.16b}, [x0]
+	st1		{v5.16b}, [OUT]
 	b		.Lctrout\xctr
+
+	// Arguments
+	.unreq OUT
+	.unreq IN
+	.unreq KEY
+	.unreq ROUNDS_W
+	.unreq BYTES_W
+	.unreq IV
+	.unreq BYTE_CTR_W	// XCTR only
+	// Intermediate values
+	.unreq CTR_W		// XCTR only
+	.unreq CTR		// XCTR only
+	.unreq IV_PART
+	.unreq BLOCKS
+	.unreq BLOCKS_W
 .endm
 
 	/*
 	 * aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
 	 *		   int bytes, u8 ctr[])
+	 *
+	 * The input and output buffers must always be at least 16 bytes even if
+	 * encrypting/decrypting less than 16 bytes.  Otherwise out of bounds
+	 * accesses will occur.
 	 */
 
 AES_FUNC_START(aes_ctr_encrypt)
@@ -515,6 +601,10 @@  AES_FUNC_END(aes_ctr_encrypt)
 	/*
 	 * aes_xctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
 	 *		   int bytes, u8 const iv[], int byte_ctr)
+	 *
+	 * The input and output buffers must always be at least 16 bytes even if
+	 * encrypting/decrypting less than 16 bytes.  Otherwise out of bounds
+	 * accesses will occur.
 	 */
 
 AES_FUNC_START(aes_xctr_encrypt)