diff mbox series

[v5,04/15] crypto: x86/sha256-ni - add support for finup_mb

Message ID 20240611034822.36603-5-ebiggers@kernel.org (mailing list archive)
State Superseded, archived
Delegated to: Mike Snitzer
Headers show
Series Optimize dm-verity and fsverity using multibuffer hashing | expand

Commit Message

Eric Biggers June 11, 2024, 3:48 a.m. UTC
From: Eric Biggers <ebiggers@google.com>

Add an implementation of finup_mb to sha256-ni, using an interleaving
factor of 2.  It interleaves a finup operation for two equal-length
messages that share a common prefix.  dm-verity and fs-verity will take
advantage of this for greatly improved performance on capable CPUs.

This increases the throughput of SHA-256 hashing 4096-byte messages by
the following amounts on the following CPUs:

    AMD Zen 1:                  84%
    AMD Zen 4:                  98%
    Intel Ice Lake:              4%
    Intel Sapphire Rapids:      20%

For now, this seems to benefit AMD much more than Intel.  This seems to
be because current AMD CPUs support concurrent execution of the SHA-NI
instructions, but unfortunately current Intel CPUs don't, except for the
sha256msg2 instruction.  Hopefully future Intel CPUs will support SHA-NI
on more execution ports.  Zen 1 supports 2 concurrent sha256rnds2, and
Zen 4 supports 4 concurrent sha256rnds2, which suggests that even better
performance may be achievable on Zen 4 by interleaving more than two
hashes; however, doing so poses a number of trade-offs.

It's been reported that the method that achieves the highest SHA-256
throughput on Intel CPUs is actually computing 16 hashes simultaneously
using AVX512.  That method would be quite different to the SHA-NI method
used in this patch.  However, such a high interleaving factor isn't
practical for the use cases being targeted in the kernel.

Signed-off-by: Eric Biggers <ebiggers@google.com>
---
 arch/x86/crypto/sha256_ni_asm.S     | 368 ++++++++++++++++++++++++++++
 arch/x86/crypto/sha256_ssse3_glue.c |  39 +++
 2 files changed, 407 insertions(+)

Comments

Herbert Xu June 12, 2024, 9:42 a.m. UTC | #1
On Mon, Jun 10, 2024 at 08:48:11PM -0700, Eric Biggers wrote:
> From: Eric Biggers <ebiggers@google.com>
> 
> Add an implementation of finup_mb to sha256-ni, using an interleaving
> factor of 2.  It interleaves a finup operation for two equal-length
> messages that share a common prefix.  dm-verity and fs-verity will take

I think the limitation on equal length is artificial.  There is
no reason why the code couldn't handle two messages with different
lengths.  Simply execute in dual mode up until the shorter message
runs out.  Then carry on as if you have a single message.

In fact, there is no reason why the two hashes have to start from
the same initial state either.  It has no bearing on the performance
of the actual hashing as far as I can see.

Cheers,
Eric Biggers June 12, 2024, 3:27 p.m. UTC | #2
On Wed, Jun 12, 2024 at 05:42:20PM +0800, Herbert Xu wrote:
> On Mon, Jun 10, 2024 at 08:48:11PM -0700, Eric Biggers wrote:
> > From: Eric Biggers <ebiggers@google.com>
> > 
> > Add an implementation of finup_mb to sha256-ni, using an interleaving
> > factor of 2.  It interleaves a finup operation for two equal-length
> > messages that share a common prefix.  dm-verity and fs-verity will take
> 
> I think the limitation on equal length is artificial.  There is
> no reason why the code couldn't handle two messages with different
> lengths.  Simply execute in dual mode up until the shorter message
> runs out.  Then carry on as if you have a single message.

Sure, as I mentioned the algorithm could fall back to single-buffer hashing once
the messages get out of sync.  This would actually have to be implemented and
tested, of course, which gets especially tricky with your proposal to support
arbitrary scatterlists.  And there's no actual use case for adding that
complexity yet.

> In fact, there is no reason why the two hashes have to start from
> the same initial state either.  It has no bearing on the performance
> of the actual hashing as far as I can see.

The SHA-256 inner loop would indeed be the same, but the single state has
several advantages:

- The caller only needs to allocate and prepare a single state.  This saves
  per-IO memory and reduces overhead.
- The glue code doesn't need to check that the number of internally buffered
  bytes are synced up.
- The assembly code only needs to load from the one state.

All of this simplifies the code slightly and boosts performance slightly.

These advantages aren't *too* large, of course, and if a use case for supporting
update arose, then support for multiple states would be added.  But it doesn't
make sense to add this functionality prematurely before it actually has a user.

- Eric
diff mbox series

Patch

diff --git a/arch/x86/crypto/sha256_ni_asm.S b/arch/x86/crypto/sha256_ni_asm.S
index d515a55a3bc1..5e97922a24e4 100644
--- a/arch/x86/crypto/sha256_ni_asm.S
+++ b/arch/x86/crypto/sha256_ni_asm.S
@@ -172,10 +172,378 @@  SYM_TYPED_FUNC_START(sha256_ni_transform)
 .Ldone_hash:
 
 	RET
 SYM_FUNC_END(sha256_ni_transform)
 
+#undef DIGEST_PTR
+#undef DATA_PTR
+#undef NUM_BLKS
+#undef SHA256CONSTANTS
+#undef MSG
+#undef STATE0
+#undef STATE1
+#undef MSG0
+#undef MSG1
+#undef MSG2
+#undef MSG3
+#undef TMP
+#undef SHUF_MASK
+#undef ABEF_SAVE
+#undef CDGH_SAVE
+
+// parameters for __sha256_ni_finup2x()
+#define SCTX		%rdi
+#define DATA1		%rsi
+#define DATA2		%rdx
+#define LEN		%ecx
+#define LEN8		%cl
+#define LEN64		%rcx
+#define OUT1		%r8
+#define OUT2		%r9
+
+// other scalar variables
+#define SHA256CONSTANTS	%rax
+#define COUNT		%r10
+#define COUNT32		%r10d
+#define FINAL_STEP	%r11d
+
+// rbx is used as a temporary.
+
+#define MSG		%xmm0	// sha256rnds2 implicit operand
+#define STATE0_A	%xmm1
+#define STATE1_A	%xmm2
+#define STATE0_B	%xmm3
+#define STATE1_B	%xmm4
+#define TMP_A		%xmm5
+#define TMP_B		%xmm6
+#define MSG0_A		%xmm7
+#define MSG1_A		%xmm8
+#define MSG2_A		%xmm9
+#define MSG3_A		%xmm10
+#define MSG0_B		%xmm11
+#define MSG1_B		%xmm12
+#define MSG2_B		%xmm13
+#define MSG3_B		%xmm14
+#define SHUF_MASK	%xmm15
+
+#define OFFSETOF_STATE	0	// offsetof(struct sha256_state, state)
+#define OFFSETOF_COUNT	32	// offsetof(struct sha256_state, count)
+#define OFFSETOF_BUF	40	// offsetof(struct sha256_state, buf)
+
+// Do 4 rounds of SHA-256 for each of two messages (interleaved).  m0_a and m0_b
+// contain the current 4 message schedule words for the first and second message
+// respectively.
+//
+// If not all the message schedule words have been computed yet, then this also
+// computes 4 more message schedule words for each message.  m1_a-m3_a contain
+// the next 3 groups of 4 message schedule words for the first message, and
+// likewise m1_b-m3_b for the second.  After consuming the current value of
+// m0_a, this macro computes the group after m3_a and writes it to m0_a, and
+// likewise for *_b.  This means that the next (m0_a, m1_a, m2_a, m3_a) is the
+// current (m1_a, m2_a, m3_a, m0_a), and likewise for *_b, so the caller must
+// cycle through the registers accordingly.
+.macro	do_4rounds_2x	i, m0_a, m1_a, m2_a, m3_a,  m0_b, m1_b, m2_b, m3_b
+	movdqa		(\i-32)*4(SHA256CONSTANTS), TMP_A
+	movdqa		TMP_A, TMP_B
+	paddd		\m0_a, TMP_A
+	paddd		\m0_b, TMP_B
+.if \i < 48
+	sha256msg1	\m1_a, \m0_a
+	sha256msg1	\m1_b, \m0_b
+.endif
+	movdqa		TMP_A, MSG
+	sha256rnds2	STATE0_A, STATE1_A
+	movdqa		TMP_B, MSG
+	sha256rnds2	STATE0_B, STATE1_B
+	pshufd 		$0x0E, TMP_A, MSG
+	sha256rnds2	STATE1_A, STATE0_A
+	pshufd 		$0x0E, TMP_B, MSG
+	sha256rnds2	STATE1_B, STATE0_B
+.if \i < 48
+	movdqa		\m3_a, TMP_A
+	movdqa		\m3_b, TMP_B
+	palignr		$4, \m2_a, TMP_A
+	palignr		$4, \m2_b, TMP_B
+	paddd		TMP_A, \m0_a
+	paddd		TMP_B, \m0_b
+	sha256msg2	\m3_a, \m0_a
+	sha256msg2	\m3_b, \m0_b
+.endif
+.endm
+
+//
+// void __sha256_ni_finup2x(const struct sha256_state *sctx,
+//			    const u8 *data1, const u8 *data2, int len,
+//			    u8 out1[SHA256_DIGEST_SIZE],
+//			    u8 out2[SHA256_DIGEST_SIZE]);
+//
+// This function computes the SHA-256 digests of two messages |data1| and
+// |data2| that are both |len| bytes long, starting from the initial state
+// |sctx|.  |len| must be at least SHA256_BLOCK_SIZE.
+//
+// The instructions for the two SHA-256 operations are interleaved.  On many
+// CPUs, this is almost twice as fast as hashing each message individually due
+// to taking better advantage of the CPU's SHA-256 and SIMD throughput.
+//
+SYM_FUNC_START(__sha256_ni_finup2x)
+	// Allocate 128 bytes of stack space, 16-byte aligned.
+	push		%rbx
+	push		%rbp
+	mov		%rsp, %rbp
+	sub		$128, %rsp
+	and		$~15, %rsp
+
+	// Load the shuffle mask for swapping the endianness of 32-bit words.
+	movdqa		PSHUFFLE_BYTE_FLIP_MASK(%rip), SHUF_MASK
+
+	// Set up pointer to the round constants.
+	lea		K256+32*4(%rip), SHA256CONSTANTS
+
+	// Initially we're not processing the final blocks.
+	xor		FINAL_STEP, FINAL_STEP
+
+	// Load the initial state from sctx->state.
+	movdqu		OFFSETOF_STATE+0*16(SCTX), STATE0_A	// DCBA
+	movdqu		OFFSETOF_STATE+1*16(SCTX), STATE1_A	// HGFE
+	movdqa		STATE0_A, TMP_A
+	punpcklqdq	STATE1_A, STATE0_A			// FEBA
+	punpckhqdq	TMP_A, STATE1_A				// DCHG
+	pshufd		$0x1B, STATE0_A, STATE0_A		// ABEF
+	pshufd		$0xB1, STATE1_A, STATE1_A		// CDGH
+
+	// Load sctx->count.  Take the mod 64 of it to get the number of bytes
+	// that are buffered in sctx->buf.  Also save it in a register with LEN
+	// added to it.
+	mov		LEN, LEN
+	mov		OFFSETOF_COUNT(SCTX), %rbx
+	lea		(%rbx, LEN64, 1), COUNT
+	and		$63, %ebx
+	jz		.Lfinup2x_enter_loop	// No bytes buffered?
+
+	// %ebx bytes (1 to 63) are currently buffered in sctx->buf.  Load them
+	// followed by the first 64 - %ebx bytes of data.  Since LEN >= 64, we
+	// just load 64 bytes from each of sctx->buf, DATA1, and DATA2
+	// unconditionally and rearrange the data as needed.
+
+	movdqu		OFFSETOF_BUF+0*16(SCTX), MSG0_A
+	movdqu		OFFSETOF_BUF+1*16(SCTX), MSG1_A
+	movdqu		OFFSETOF_BUF+2*16(SCTX), MSG2_A
+	movdqu		OFFSETOF_BUF+3*16(SCTX), MSG3_A
+	movdqa		MSG0_A, 0*16(%rsp)
+	movdqa		MSG1_A, 1*16(%rsp)
+	movdqa		MSG2_A, 2*16(%rsp)
+	movdqa		MSG3_A, 3*16(%rsp)
+
+	movdqu		0*16(DATA1), MSG0_A
+	movdqu		1*16(DATA1), MSG1_A
+	movdqu		2*16(DATA1), MSG2_A
+	movdqu		3*16(DATA1), MSG3_A
+	movdqu		MSG0_A, 0*16(%rsp,%rbx)
+	movdqu		MSG1_A, 1*16(%rsp,%rbx)
+	movdqu		MSG2_A, 2*16(%rsp,%rbx)
+	movdqu		MSG3_A, 3*16(%rsp,%rbx)
+	movdqa		0*16(%rsp), MSG0_A
+	movdqa		1*16(%rsp), MSG1_A
+	movdqa		2*16(%rsp), MSG2_A
+	movdqa		3*16(%rsp), MSG3_A
+
+	movdqu		0*16(DATA2), MSG0_B
+	movdqu		1*16(DATA2), MSG1_B
+	movdqu		2*16(DATA2), MSG2_B
+	movdqu		3*16(DATA2), MSG3_B
+	movdqu		MSG0_B, 0*16(%rsp,%rbx)
+	movdqu		MSG1_B, 1*16(%rsp,%rbx)
+	movdqu		MSG2_B, 2*16(%rsp,%rbx)
+	movdqu		MSG3_B, 3*16(%rsp,%rbx)
+	movdqa		0*16(%rsp), MSG0_B
+	movdqa		1*16(%rsp), MSG1_B
+	movdqa		2*16(%rsp), MSG2_B
+	movdqa		3*16(%rsp), MSG3_B
+
+	sub		$64, %rbx 	// rbx = buffered - 64
+	sub		%rbx, DATA1	// DATA1 += 64 - buffered
+	sub		%rbx, DATA2	// DATA2 += 64 - buffered
+	add		%ebx, LEN	// LEN += buffered - 64
+	movdqa		STATE0_A, STATE0_B
+	movdqa		STATE1_A, STATE1_B
+	jmp		.Lfinup2x_loop_have_data
+
+.Lfinup2x_enter_loop:
+	sub		$64, LEN
+	movdqa		STATE0_A, STATE0_B
+	movdqa		STATE1_A, STATE1_B
+.Lfinup2x_loop:
+	// Load the next two data blocks.
+	movdqu		0*16(DATA1), MSG0_A
+	movdqu		0*16(DATA2), MSG0_B
+	movdqu		1*16(DATA1), MSG1_A
+	movdqu		1*16(DATA2), MSG1_B
+	movdqu		2*16(DATA1), MSG2_A
+	movdqu		2*16(DATA2), MSG2_B
+	movdqu		3*16(DATA1), MSG3_A
+	movdqu		3*16(DATA2), MSG3_B
+	add		$64, DATA1
+	add		$64, DATA2
+.Lfinup2x_loop_have_data:
+	// Convert the words of the data blocks from big endian.
+	pshufb		SHUF_MASK, MSG0_A
+	pshufb		SHUF_MASK, MSG0_B
+	pshufb		SHUF_MASK, MSG1_A
+	pshufb		SHUF_MASK, MSG1_B
+	pshufb		SHUF_MASK, MSG2_A
+	pshufb		SHUF_MASK, MSG2_B
+	pshufb		SHUF_MASK, MSG3_A
+	pshufb		SHUF_MASK, MSG3_B
+.Lfinup2x_loop_have_bswapped_data:
+
+	// Save the original state for each block.
+	movdqa		STATE0_A, 0*16(%rsp)
+	movdqa		STATE0_B, 1*16(%rsp)
+	movdqa		STATE1_A, 2*16(%rsp)
+	movdqa		STATE1_B, 3*16(%rsp)
+
+	// Do the SHA-256 rounds on each block.
+.irp i, 0, 16, 32, 48
+	do_4rounds_2x	(\i + 0),  MSG0_A, MSG1_A, MSG2_A, MSG3_A, \
+				   MSG0_B, MSG1_B, MSG2_B, MSG3_B
+	do_4rounds_2x	(\i + 4),  MSG1_A, MSG2_A, MSG3_A, MSG0_A, \
+				   MSG1_B, MSG2_B, MSG3_B, MSG0_B
+	do_4rounds_2x	(\i + 8),  MSG2_A, MSG3_A, MSG0_A, MSG1_A, \
+				   MSG2_B, MSG3_B, MSG0_B, MSG1_B
+	do_4rounds_2x	(\i + 12), MSG3_A, MSG0_A, MSG1_A, MSG2_A, \
+				   MSG3_B, MSG0_B, MSG1_B, MSG2_B
+.endr
+
+	// Add the original state for each block.
+	paddd		0*16(%rsp), STATE0_A
+	paddd		1*16(%rsp), STATE0_B
+	paddd		2*16(%rsp), STATE1_A
+	paddd		3*16(%rsp), STATE1_B
+
+	// Update LEN and loop back if more blocks remain.
+	sub		$64, LEN
+	jge		.Lfinup2x_loop
+
+	// Check if any final blocks need to be handled.
+	// FINAL_STEP = 2: all done
+	// FINAL_STEP = 1: need to do count-only padding block
+	// FINAL_STEP = 0: need to do the block with 0x80 padding byte
+	cmp		$1, FINAL_STEP
+	jg		.Lfinup2x_done
+	je		.Lfinup2x_finalize_countonly
+	add		$64, LEN
+	jz		.Lfinup2x_finalize_blockaligned
+
+	// Not block-aligned; 1 <= LEN <= 63 data bytes remain.  Pad the block.
+	// To do this, write the padding starting with the 0x80 byte to
+	// &sp[64].  Then for each message, copy the last 64 data bytes to sp
+	// and load from &sp[64 - LEN] to get the needed padding block.  This
+	// code relies on the data buffers being >= 64 bytes in length.
+	mov		$64, %ebx
+	sub		LEN, %ebx		// ebx = 64 - LEN
+	sub		%rbx, DATA1		// DATA1 -= 64 - LEN
+	sub		%rbx, DATA2		// DATA2 -= 64 - LEN
+	mov		$0x80, FINAL_STEP   // using FINAL_STEP as a temporary
+	movd		FINAL_STEP, MSG0_A
+	pxor		MSG1_A, MSG1_A
+	movdqa		MSG0_A, 4*16(%rsp)
+	movdqa		MSG1_A, 5*16(%rsp)
+	movdqa		MSG1_A, 6*16(%rsp)
+	movdqa		MSG1_A, 7*16(%rsp)
+	cmp		$56, LEN
+	jge		1f	// will COUNT spill into its own block?
+	shl		$3, COUNT
+	bswap		COUNT
+	mov		COUNT, 56(%rsp,%rbx)
+	mov		$2, FINAL_STEP	// won't need count-only block
+	jmp		2f
+1:
+	mov		$1, FINAL_STEP	// will need count-only block
+2:
+	movdqu		0*16(DATA1), MSG0_A
+	movdqu		1*16(DATA1), MSG1_A
+	movdqu		2*16(DATA1), MSG2_A
+	movdqu		3*16(DATA1), MSG3_A
+	movdqa		MSG0_A, 0*16(%rsp)
+	movdqa		MSG1_A, 1*16(%rsp)
+	movdqa		MSG2_A, 2*16(%rsp)
+	movdqa		MSG3_A, 3*16(%rsp)
+	movdqu		0*16(%rsp,%rbx), MSG0_A
+	movdqu		1*16(%rsp,%rbx), MSG1_A
+	movdqu		2*16(%rsp,%rbx), MSG2_A
+	movdqu		3*16(%rsp,%rbx), MSG3_A
+
+	movdqu		0*16(DATA2), MSG0_B
+	movdqu		1*16(DATA2), MSG1_B
+	movdqu		2*16(DATA2), MSG2_B
+	movdqu		3*16(DATA2), MSG3_B
+	movdqa		MSG0_B, 0*16(%rsp)
+	movdqa		MSG1_B, 1*16(%rsp)
+	movdqa		MSG2_B, 2*16(%rsp)
+	movdqa		MSG3_B, 3*16(%rsp)
+	movdqu		0*16(%rsp,%rbx), MSG0_B
+	movdqu		1*16(%rsp,%rbx), MSG1_B
+	movdqu		2*16(%rsp,%rbx), MSG2_B
+	movdqu		3*16(%rsp,%rbx), MSG3_B
+	jmp		.Lfinup2x_loop_have_data
+
+	// Prepare a padding block, either:
+	//
+	//	{0x80, 0, 0, 0, ..., count (as __be64)}
+	//	This is for a block aligned message.
+	//
+	//	{   0, 0, 0, 0, ..., count (as __be64)}
+	//	This is for a message whose length mod 64 is >= 56.
+	//
+	// Pre-swap the endianness of the words.
+.Lfinup2x_finalize_countonly:
+	pxor		MSG0_A, MSG0_A
+	jmp		1f
+
+.Lfinup2x_finalize_blockaligned:
+	mov		$0x80000000, %ebx
+	movd		%ebx, MSG0_A
+1:
+	pxor		MSG1_A, MSG1_A
+	pxor		MSG2_A, MSG2_A
+	ror		$29, COUNT
+	movq		COUNT, MSG3_A
+	pslldq		$8, MSG3_A
+	movdqa		MSG0_A, MSG0_B
+	pxor		MSG1_B, MSG1_B
+	pxor		MSG2_B, MSG2_B
+	movdqa		MSG3_A, MSG3_B
+	mov		$2, FINAL_STEP
+	jmp		.Lfinup2x_loop_have_bswapped_data
+
+.Lfinup2x_done:
+	// Write the two digests with all bytes in the correct order.
+	movdqa		STATE0_A, TMP_A
+	movdqa		STATE0_B, TMP_B
+	punpcklqdq	STATE1_A, STATE0_A		// GHEF
+	punpcklqdq	STATE1_B, STATE0_B
+	punpckhqdq	TMP_A, STATE1_A			// ABCD
+	punpckhqdq	TMP_B, STATE1_B
+	pshufd		$0xB1, STATE0_A, STATE0_A	// HGFE
+	pshufd		$0xB1, STATE0_B, STATE0_B
+	pshufd		$0x1B, STATE1_A, STATE1_A	// DCBA
+	pshufd		$0x1B, STATE1_B, STATE1_B
+	pshufb		SHUF_MASK, STATE0_A
+	pshufb		SHUF_MASK, STATE0_B
+	pshufb		SHUF_MASK, STATE1_A
+	pshufb		SHUF_MASK, STATE1_B
+	movdqu		STATE0_A, 1*16(OUT1)
+	movdqu		STATE0_B, 1*16(OUT2)
+	movdqu		STATE1_A, 0*16(OUT1)
+	movdqu		STATE1_B, 0*16(OUT2)
+
+	mov		%rbp, %rsp
+	pop		%rbp
+	pop		%rbx
+	RET
+SYM_FUNC_END(__sha256_ni_finup2x)
+
 .section	.rodata.cst256.K256, "aM", @progbits, 256
 .align 64
 K256:
 	.long	0x428a2f98,0x71374491,0xb5c0fbcf,0xe9b5dba5
 	.long	0x3956c25b,0x59f111f1,0x923f82a4,0xab1c5ed5
diff --git a/arch/x86/crypto/sha256_ssse3_glue.c b/arch/x86/crypto/sha256_ssse3_glue.c
index e04a43d9f7d5..ff688bb1d560 100644
--- a/arch/x86/crypto/sha256_ssse3_glue.c
+++ b/arch/x86/crypto/sha256_ssse3_glue.c
@@ -331,10 +331,15 @@  static void unregister_sha256_avx2(void)
 
 #ifdef CONFIG_AS_SHA256_NI
 asmlinkage void sha256_ni_transform(struct sha256_state *digest,
 				    const u8 *data, int rounds);
 
+asmlinkage void __sha256_ni_finup2x(const struct sha256_state *sctx,
+				    const u8 *data1, const u8 *data2, int len,
+				    u8 out1[SHA256_DIGEST_SIZE],
+				    u8 out2[SHA256_DIGEST_SIZE]);
+
 static int sha256_ni_update(struct shash_desc *desc, const u8 *data,
 			 unsigned int len)
 {
 	return _sha256_update(desc, data, len, sha256_ni_transform);
 }
@@ -355,18 +360,52 @@  static int sha256_ni_digest(struct shash_desc *desc, const u8 *data,
 {
 	return sha256_base_init(desc) ?:
 	       sha256_ni_finup(desc, data, len, out);
 }
 
+static int sha256_ni_finup_mb(struct shash_desc *desc,
+			      const u8 * const data[], unsigned int len,
+			      u8 * const outs[], unsigned int num_msgs)
+{
+	struct sha256_state *sctx = shash_desc_ctx(desc);
+
+	/*
+	 * num_msgs != 2 should not happen here, since this algorithm sets
+	 * mb_max_msgs=2, and the crypto API handles num_msgs <= 1 before
+	 * calling into the algorithm's finup_mb method.
+	 */
+	if (WARN_ON_ONCE(num_msgs != 2))
+		return -EOPNOTSUPP;
+
+	if (unlikely(!crypto_simd_usable()))
+		return -EOPNOTSUPP;
+
+	/* __sha256_ni_finup2x() assumes SHA256_BLOCK_SIZE <= len <= INT_MAX. */
+	if (unlikely(len < SHA256_BLOCK_SIZE || len > INT_MAX))
+		return -EOPNOTSUPP;
+
+	/* __sha256_ni_finup2x() assumes the following offsets. */
+	BUILD_BUG_ON(offsetof(struct sha256_state, state) != 0);
+	BUILD_BUG_ON(offsetof(struct sha256_state, count) != 32);
+	BUILD_BUG_ON(offsetof(struct sha256_state, buf) != 40);
+
+	kernel_fpu_begin();
+	__sha256_ni_finup2x(sctx, data[0], data[1], len, outs[0], outs[1]);
+	kernel_fpu_end();
+	return 0;
+}
+
 static struct shash_alg sha256_ni_algs[] = { {
 	.digestsize	=	SHA256_DIGEST_SIZE,
 	.init		=	sha256_base_init,
 	.update		=	sha256_ni_update,
 	.final		=	sha256_ni_final,
 	.finup		=	sha256_ni_finup,
 	.digest		=	sha256_ni_digest,
+	.finup_mb	=	sha256_ni_finup_mb,
 	.descsize	=	sizeof(struct sha256_state),
+	.mb_max_msgs	=	2,
 	.base		=	{
 		.cra_name	=	"sha256",
 		.cra_driver_name =	"sha256-ni",
 		.cra_priority	=	250,
 		.cra_blocksize	=	SHA256_BLOCK_SIZE,