diff mbox series

[v10,18/50] crypto: ccp: Handle the legacy SEV command when SNP is enabled

Message ID 20231016132819.1002933-19-michael.roth@amd.com (mailing list archive)
State Not Applicable
Delegated to: Herbert Xu
Headers show
Series Add AMD Secure Nested Paging (SEV-SNP) Hypervisor Support | expand

Commit Message

Michael Roth Oct. 16, 2023, 1:27 p.m. UTC
From: Brijesh Singh <brijesh.singh@amd.com>

The behavior of the SEV-legacy commands is altered when the SNP firmware
is in the INIT state. When SNP is in the INIT state, any memory that an
SEV-legacy command causes the firmware to write to must be in the
firmware state before the command is issued.

A command buffer may contain a system physical address that the firmware
may write to. There are two cases that need to be handled:

1) the system physical address points to guest memory
2) the system physical address points to host memory

To handle case #1, change the page state to firmware in the RMP
table before issuing the command and restore the state to shared after the
command completes.

For case #2, use a bounce buffer to complete the request.

Signed-off-by: Brijesh Singh <brijesh.singh@amd.com>
Signed-off-by: Ashish Kalra <ashish.kalra@amd.com>
Signed-off-by: Michael Roth <michael.roth@amd.com>
---
 drivers/crypto/ccp/sev-dev.c | 346 ++++++++++++++++++++++++++++++++++-
 drivers/crypto/ccp/sev-dev.h |  12 ++
 2 files changed, 348 insertions(+), 10 deletions(-)

Comments

Borislav Petkov Dec. 9, 2023, 3:36 p.m. UTC | #1
On Mon, Oct 16, 2023 at 08:27:47AM -0500, Michael Roth wrote:
> From: Brijesh Singh <brijesh.singh@amd.com>
> 
> The behavior of the SEV-legacy commands is altered when the SNP firmware
> is in the INIT state. When SNP is in INIT state, all the SEV-legacy
> commands that cause the firmware to write to memory must be in the
> firmware state before issuing the command..

I think this is trying to say that the *memory* must be in firmware
state before the command. Needs massaging.

> A command buffer may contains a system physical address that the firmware

"contain"

> may write to. There are two cases that need to be handled:
> 
> 1) system physical address points to a guest memory
> 2) system physical address points to a host memory

s/a //g
> 
> To handle the case #1, change the page state to the firmware in the RMP
> table before issuing the command and restore the state to shared after the
> command completes.
> 
> For the case #2, use a bounce buffer to complete the request.
> 
> Signed-off-by: Brijesh Singh <brijesh.singh@amd.com>
> Signed-off-by: Ashish Kalra <ashish.kalra@amd.com>
> Signed-off-by: Michael Roth <michael.roth@amd.com>
> ---
>  drivers/crypto/ccp/sev-dev.c | 346 ++++++++++++++++++++++++++++++++++-
>  drivers/crypto/ccp/sev-dev.h |  12 ++
>  2 files changed, 348 insertions(+), 10 deletions(-)
> 
> diff --git a/drivers/crypto/ccp/sev-dev.c b/drivers/crypto/ccp/sev-dev.c
> index ea21307a2b34..b574b0ef2b1f 100644
> --- a/drivers/crypto/ccp/sev-dev.c
> +++ b/drivers/crypto/ccp/sev-dev.c
> @@ -462,12 +462,295 @@ static int sev_write_init_ex_file_if_required(int cmd_id)
>  	return sev_write_init_ex_file();
>  }
>  
> +static int alloc_snp_host_map(struct sev_device *sev)

If this is allocating intermediary bounce buffers, then call the
function that it does exactly that. Or what "host_map" is the name
referring to?

> +{
> +	struct page *page;
> +	int i;
> +
> +	for (i = 0; i < MAX_SNP_HOST_MAP_BUFS; i++) {
> +		struct snp_host_map *map = &sev->snp_host_map[i];
> +
> +		memset(map, 0, sizeof(*map));
> +
> +		page = alloc_pages(GFP_KERNEL_ACCOUNT, get_order(SEV_FW_BLOB_MAX_SIZE));
> +		if (!page)
> +			return -ENOMEM;

If the second allocation fails, you just leaked the first one.

> +		map->host = page_address(page);
> +	}
> +
> +	return 0;
> +}
> +
> +static void free_snp_host_map(struct sev_device *sev)
> +{
> +	int i;
> +
> +	for (i = 0; i < MAX_SNP_HOST_MAP_BUFS; i++) {
> +		struct snp_host_map *map = &sev->snp_host_map[i];
> +
> +		if (map->host) {
> +			__free_pages(virt_to_page(map->host), get_order(SEV_FW_BLOB_MAX_SIZE));
> +			memset(map, 0, sizeof(*map));
> +		}
> +	}
> +}
> +
> +static int map_firmware_writeable(u64 *paddr, u32 len, bool guest, struct snp_host_map *map)

Why is paddr a pointer? You simply pass a "unsigned long paddr" like the
rest of the gazillion functions dealing with addresses.

And then you do the ERR_PTR, PTR_ERR thing for the return value of this
function, see include/linux/err.h.

> +{
> +	unsigned int npages = PAGE_ALIGN(len) >> PAGE_SHIFT;
> +
> +	map->active = false;

This toggling of active on function entry and exit is silly.

The usual way to do those things is to mark it as active as the last
step of the map function, when everything has succeeded and to mark it
as inactive (active == false) as the first step in the unmap function.

> +
> +	if (!paddr || !len)
> +		return 0;
> +
> +	map->paddr = *paddr;
> +	map->len = len;
> +
> +	/* If paddr points to a guest memory then change the page state to firmwware. */
> +	if (guest) {
> +		if (rmp_mark_pages_firmware(*paddr, npages, true))
> +			return -EFAULT;
> +
> +		goto done;
> +	}

This is where it tells you that this function wants splitting:

map_guest_firmware_pages
map_host_firmware_pages

or so.

And then you lose the @guest argument too and you call the different
functions depending on the SEV cmd.

> +
> +	if (!map->host)

What in the hell is ->host?! SPA is host memory?

Comments please.

> +		return -ENOMEM;
> +
> +	/* Check if the pre-allocated buffer can be used to fullfil the request. */

"fulfill"

> +	if (len > SEV_FW_BLOB_MAX_SIZE)
> +		return -EINVAL;
> +
> +	/* Transition the pre-allocated buffer to the firmware state. */
> +	if (rmp_mark_pages_firmware(__pa(map->host), npages, true))
> +		return -EFAULT;
> +
> +	/* Set the paddr to use pre-allocated firmware buffer */
> +	*paddr = __psp_pa(map->host);
> +
> +done:
> +	map->active = true;
> +	return 0;
> +}


> +
> +static int unmap_firmware_writeable(u64 *paddr, u32 len, bool guest, struct snp_host_map *map)
> +{
> +	unsigned int npages = PAGE_ALIGN(len) >> PAGE_SHIFT;
> +
> +	if (!map->active)

Same comments as above for that one.

> +		return 0;
> +
> +	/* If paddr points to a guest memory then restore the page state to hypervisor. */
> +	if (guest) {
> +		if (snp_reclaim_pages(*paddr, npages, true))
> +			return -EFAULT;
> +
> +		goto done;
> +	}
> +
> +	/*
> +	 * Transition the pre-allocated buffer to hypervisor state before the access.
> +	 *
> +	 * This is because while changing the page state to firmware, the kernel unmaps
> +	 * the pages from the direct map, and to restore the direct map the pages must
> +	 * be transitioned back to the shared state.
> +	 */
> +	if (snp_reclaim_pages(__pa(map->host), npages, true))
> +		return -EFAULT;
> +
> +	/* Copy the response data firmware buffer to the callers buffer. */
> +	memcpy(__va(__sme_clr(map->paddr)), map->host, min_t(size_t, len, map->len));

This is not testing whether map->host is NULL as the above counterpart.

> +	*paddr = map->paddr;
> +
> +done:
> +	map->active = false;
> +	return 0;
> +}
> +
> +static bool sev_legacy_cmd_buf_writable(int cmd)
> +{
> +	switch (cmd) {
> +	case SEV_CMD_PLATFORM_STATUS:
> +	case SEV_CMD_GUEST_STATUS:
> +	case SEV_CMD_LAUNCH_START:
> +	case SEV_CMD_RECEIVE_START:
> +	case SEV_CMD_LAUNCH_MEASURE:
> +	case SEV_CMD_SEND_START:
> +	case SEV_CMD_SEND_UPDATE_DATA:
> +	case SEV_CMD_SEND_UPDATE_VMSA:
> +	case SEV_CMD_PEK_CSR:
> +	case SEV_CMD_PDH_CERT_EXPORT:
> +	case SEV_CMD_GET_ID:
> +	case SEV_CMD_ATTESTATION_REPORT:
> +		return true;
> +	default:
> +		return false;
> +	}
> +}
> +
> +#define prep_buffer(name, addr, len, guest, map) \
> +	func(&((typeof(name *))cmd_buf)->addr, ((typeof(name *))cmd_buf)->len, guest, map)
> +
> +static int __snp_cmd_buf_copy(int cmd, void *cmd_buf, bool to_fw, int fw_err)
> +{
> +	int (*func)(u64 *paddr, u32 len, bool guest, struct snp_host_map *map);
> +	struct sev_device *sev = psp_master->sev_data;
> +	bool from_fw = !to_fw;
> +
> +	/*
> +	 * After the command is completed, change the command buffer memory to
> +	 * hypervisor state.
> +	 *
> +	 * The immutable bit is automatically cleared by the firmware, so
> +	 * no not need to reclaim the page.
> +	 */
> +	if (from_fw && sev_legacy_cmd_buf_writable(cmd)) {
> +		if (snp_reclaim_pages(__pa(cmd_buf), 1, true))
> +			return -EFAULT;
> +
> +		/* No need to go further if firmware failed to execute command. */
> +		if (fw_err)
> +			return 0;
> +	}
> +
> +	if (to_fw)
> +		func = map_firmware_writeable;
> +	else
> +		func = unmap_firmware_writeable;

Eww, ugly and with the macro above even worse. And completely
unnecessary.

Define prep_buffer() as a normal function which selects which @func to
call and then does it. Not like this.

...

> +static inline bool need_firmware_copy(int cmd)
> +{
> +	struct sev_device *sev = psp_master->sev_data;
> +
> +	/* After SNP is INIT'ed, the behavior of legacy SEV command is changed. */

"initialized"

> +	return ((cmd < SEV_CMD_SNP_INIT) && sev->snp_initialized) ? true : false;

redundant ternary conditional:

	return cmd < SEV_CMD_SNP_INIT && sev->snp_initialized;

> +}
> +
> +static int snp_aware_copy_to_firmware(int cmd, void *data)

What does "SNP aware" even mean?

> +{
> +	return __snp_cmd_buf_copy(cmd, data, true, 0);
> +}
> +
> +static int snp_aware_copy_from_firmware(int cmd, void *data, int fw_err)
> +{
> +	return __snp_cmd_buf_copy(cmd, data, false, fw_err);
> +}
> +
>  static int __sev_do_cmd_locked(int cmd, void *data, int *psp_ret)
>  {
>  	struct psp_device *psp = psp_master;
>  	struct sev_device *sev;
>  	unsigned int phys_lsb, phys_msb;
>  	unsigned int reg, ret = 0;
> +	void *cmd_buf;
>  	int buf_len;
>  
>  	if (!psp || !psp->sev_data)
> @@ -487,12 +770,28 @@ static int __sev_do_cmd_locked(int cmd, void *data, int *psp_ret)
>  	 * work for some memory, e.g. vmalloc'd addresses, and @data may not be
>  	 * physically contiguous.
>  	 */
> -	if (data)
> -		memcpy(sev->cmd_buf, data, buf_len);
> +	if (data) {
> +		if (sev->cmd_buf_active > 2)

What is that silly counter supposed to mean?

Nested SNP commands?

> +			return -EBUSY;
> +
> +		cmd_buf = sev->cmd_buf_active ? sev->cmd_buf_backup : sev->cmd_buf;
> +
> +		memcpy(cmd_buf, data, buf_len);
> +		sev->cmd_buf_active++;
> +
> +		/*
> +		 * The behavior of the SEV-legacy commands is altered when the
> +		 * SNP firmware is in the INIT state.
> +		 */
> +		if (need_firmware_copy(cmd) && snp_aware_copy_to_firmware(cmd, cmd_buf))

Move that need_firmware_copy() check inside snp_aware_copy_to_firmware()
and the other one.

> +			return -EFAULT;
> +	} else {
> +		cmd_buf = sev->cmd_buf;
> +	}
>  
>  	/* Get the physical address of the command buffer */
> -	phys_lsb = data ? lower_32_bits(__psp_pa(sev->cmd_buf)) : 0;
> -	phys_msb = data ? upper_32_bits(__psp_pa(sev->cmd_buf)) : 0;
> +	phys_lsb = data ? lower_32_bits(__psp_pa(cmd_buf)) : 0;
> +	phys_msb = data ? upper_32_bits(__psp_pa(cmd_buf)) : 0;
>  
>  	dev_dbg(sev->dev, "sev command id %#x buffer 0x%08x%08x timeout %us\n",
>  		cmd, phys_msb, phys_lsb, psp_timeout);

...

> @@ -639,6 +947,14 @@ static int ___sev_platform_init_locked(int *error, bool probe)
>  	if (probe && !psp_init_on_probe)
>  		return 0;
>  
> +	/*
> +	 * Allocate the intermediate buffers used for the legacy command handling.
> +	 */
> +	if (rc != -ENODEV && alloc_snp_host_map(sev)) {

Why isn't this

	if (!rc && ...)

> +		dev_notice(sev->dev, "Failed to alloc host map (disabling legacy SEV)\n");
> +		goto skip_legacy;

No need for that skip_legacy silly label. Just "return 0" here.

...

Thx.
Michael Roth Dec. 29, 2023, 9:38 p.m. UTC | #2
On Sat, Dec 09, 2023 at 04:36:56PM +0100, Borislav Petkov wrote:
> > +static int __snp_cmd_buf_copy(int cmd, void *cmd_buf, bool to_fw, int fw_err)
> > +{
> > +	int (*func)(u64 *paddr, u32 len, bool guest, struct snp_host_map *map);
> > +	struct sev_device *sev = psp_master->sev_data;
> > +	bool from_fw = !to_fw;
> > +
> > +	/*
> > +	 * After the command is completed, change the command buffer memory to
> > +	 * hypervisor state.
> > +	 *
> > +	 * The immutable bit is automatically cleared by the firmware, so
> > +	 * no not need to reclaim the page.
> > +	 */
> > +	if (from_fw && sev_legacy_cmd_buf_writable(cmd)) {
> > +		if (snp_reclaim_pages(__pa(cmd_buf), 1, true))
> > +			return -EFAULT;
> > +
> > +		/* No need to go further if firmware failed to execute command. */
> > +		if (fw_err)
> > +			return 0;
> > +	}
> > +
> > +	if (to_fw)
> > +		func = map_firmware_writeable;
> > +	else
> > +		func = unmap_firmware_writeable;
> 
> Eww, ugly and with the macro above even worse. And completely
> unnecessary.
> 
> Define prep_buffer() as a normal function which selects which @func to
> call and then does it. Not like this.

I've rewritten this using a descriptor array to handle buffers for
various command parameters, and switched to allocating bounce buffers
on-demand to avoid some of the init/cleanup coordination. I don't think
any of these are really performance critical and it's only for legacy
support, but would be straightforward to add a cache of pre-allocated
buffers later if needed.

I've tried to document/name the helpers so the flow is a bit clearer.

-Mike

> 
> ...
> 
> > +static inline bool need_firmware_copy(int cmd)
> > +{
> > +	struct sev_device *sev = psp_master->sev_data;
> > +
> > +	/* After SNP is INIT'ed, the behavior of legacy SEV command is changed. */
> 
> "initialized"
> 
> > +	return ((cmd < SEV_CMD_SNP_INIT) && sev->snp_initialized) ? true : false;
> 
> redundant ternary conditional:
> 
> 	return cmd < SEV_CMD_SNP_INIT && sev->snp_initialized;
> 
> > +}
> > +
> > +static int snp_aware_copy_to_firmware(int cmd, void *data)
> 
> What does "SNP aware" even mean?
> 
> > +{
> > +	return __snp_cmd_buf_copy(cmd, data, true, 0);
> > +}
> > +
> > +static int snp_aware_copy_from_firmware(int cmd, void *data, int fw_err)
> > +{
> > +	return __snp_cmd_buf_copy(cmd, data, false, fw_err);
> > +}
> > +
> >  static int __sev_do_cmd_locked(int cmd, void *data, int *psp_ret)
> >  {
> >  	struct psp_device *psp = psp_master;
> >  	struct sev_device *sev;
> >  	unsigned int phys_lsb, phys_msb;
> >  	unsigned int reg, ret = 0;
> > +	void *cmd_buf;
> >  	int buf_len;
> >  
> >  	if (!psp || !psp->sev_data)
> > @@ -487,12 +770,28 @@ static int __sev_do_cmd_locked(int cmd, void *data, int *psp_ret)
> >  	 * work for some memory, e.g. vmalloc'd addresses, and @data may not be
> >  	 * physically contiguous.
> >  	 */
> > -	if (data)
> > -		memcpy(sev->cmd_buf, data, buf_len);
> > +	if (data) {
> > +		if (sev->cmd_buf_active > 2)
> 
> What is that silly counter supposed to mean?
> 
> Nested SNP commands?
> 
> > +			return -EBUSY;
> > +
> > +		cmd_buf = sev->cmd_buf_active ? sev->cmd_buf_backup : sev->cmd_buf;
> > +
> > +		memcpy(cmd_buf, data, buf_len);
> > +		sev->cmd_buf_active++;
> > +
> > +		/*
> > +		 * The behavior of the SEV-legacy commands is altered when the
> > +		 * SNP firmware is in the INIT state.
> > +		 */
> > +		if (need_firmware_copy(cmd) && snp_aware_copy_to_firmware(cmd, cmd_buf))
> 
> Move that need_firmware_copy() check inside snp_aware_copy_to_firmware()
> and the other one.
> 
> > +			return -EFAULT;
> > +	} else {
> > +		cmd_buf = sev->cmd_buf;
> > +	}
> >  
> >  	/* Get the physical address of the command buffer */
> > -	phys_lsb = data ? lower_32_bits(__psp_pa(sev->cmd_buf)) : 0;
> > -	phys_msb = data ? upper_32_bits(__psp_pa(sev->cmd_buf)) : 0;
> > +	phys_lsb = data ? lower_32_bits(__psp_pa(cmd_buf)) : 0;
> > +	phys_msb = data ? upper_32_bits(__psp_pa(cmd_buf)) : 0;
> >  
> >  	dev_dbg(sev->dev, "sev command id %#x buffer 0x%08x%08x timeout %us\n",
> >  		cmd, phys_msb, phys_lsb, psp_timeout);
> 
> ...
> 
> > @@ -639,6 +947,14 @@ static int ___sev_platform_init_locked(int *error, bool probe)
> >  	if (probe && !psp_init_on_probe)
> >  		return 0;
> >  
> > +	/*
> > +	 * Allocate the intermediate buffers used for the legacy command handling.
> > +	 */
> > +	if (rc != -ENODEV && alloc_snp_host_map(sev)) {
> 
> Why isn't this
> 
> 	if (!rc && ...)
> 
> > +		dev_notice(sev->dev, "Failed to alloc host map (disabling legacy SEV)\n");
> > +		goto skip_legacy;
> 
> No need for that skip_legacy silly label. Just "return 0" here.
> 
> ...
> 
> Thx.
> 
> -- 
> Regards/Gruss,
>     Boris.
> 
> https://people.kernel.org/tglx/notes-about-netiquette
>
diff mbox series

Patch

diff --git a/drivers/crypto/ccp/sev-dev.c b/drivers/crypto/ccp/sev-dev.c
index ea21307a2b34..b574b0ef2b1f 100644
--- a/drivers/crypto/ccp/sev-dev.c
+++ b/drivers/crypto/ccp/sev-dev.c
@@ -462,12 +462,295 @@  static int sev_write_init_ex_file_if_required(int cmd_id)
 	return sev_write_init_ex_file();
 }
 
+static int alloc_snp_host_map(struct sev_device *sev)
+{
+	struct page *page;
+	int i;
+
+	for (i = 0; i < MAX_SNP_HOST_MAP_BUFS; i++) {
+		struct snp_host_map *map = &sev->snp_host_map[i];
+
+		memset(map, 0, sizeof(*map));
+		/* Bounce buffer sized for the largest accepted request (SEV_FW_BLOB_MAX_SIZE). */
+		page = alloc_pages(GFP_KERNEL_ACCOUNT, get_order(SEV_FW_BLOB_MAX_SIZE));
+		if (!page)
+			return -ENOMEM;	/* NOTE(review): maps allocated so far are not freed here */
+
+		map->host = page_address(page);
+	}
+
+	return 0;
+}
+
+static void free_snp_host_map(struct sev_device *sev)
+{
+	int i;
+
+	for (i = 0; i < MAX_SNP_HOST_MAP_BUFS; i++) {
+		struct snp_host_map *map = &sev->snp_host_map[i];
+
+		if (map->host) {
+			__free_pages(virt_to_page(map->host), get_order(SEV_FW_BLOB_MAX_SIZE));
+			memset(map, 0, sizeof(*map));	/* host = NULL: repeat calls are no-ops */
+		}
+	}
+}
+
+static int map_firmware_writeable(u64 *paddr, u32 len, bool guest, struct snp_host_map *map)
+{
+	unsigned int npages = PAGE_ALIGN(len) >> PAGE_SHIFT;
+
+	map->active = false;
+
+	if (!*paddr || !len)	/* no address or no length: nothing to map */
+		return 0;
+
+	map->paddr = *paddr;
+	map->len = len;
+
+	/* If paddr points to guest memory, change the page state to firmware. */
+	if (guest) {
+		if (rmp_mark_pages_firmware(*paddr, npages, true))
+			return -EFAULT;
+
+		goto done;
+	}
+
+	if (!map->host)
+		return -ENOMEM;
+
+	/* Check if the pre-allocated buffer can be used to fulfill the request. */
+	if (len > SEV_FW_BLOB_MAX_SIZE)
+		return -EINVAL;
+
+	/* Transition the pre-allocated buffer to the firmware state. */
+	if (rmp_mark_pages_firmware(__pa(map->host), npages, true))
+		return -EFAULT;
+
+	/* Redirect the firmware to the pre-allocated (bounce) buffer. */
+	*paddr = __psp_pa(map->host);
+
+done:
+	map->active = true;
+	return 0;
+}
+
+static int unmap_firmware_writeable(u64 *paddr, u32 len, bool guest, struct snp_host_map *map)
+{
+	unsigned int npages = PAGE_ALIGN(map->len) >> PAGE_SHIFT; /* firmware may rewrite len */
+
+	if (!map->active)
+		return 0;
+
+	/* If paddr points to guest memory, restore the page state to hypervisor. */
+	if (guest) {
+		if (snp_reclaim_pages(map->paddr, npages, true))
+			return -EFAULT;
+
+		goto done;
+	}
+
+	/*
+	 * Transition the pre-allocated buffer to hypervisor state before the access.
+	 *
+	 * This is because while changing the page state to firmware, the kernel unmaps
+	 * the pages from the direct map, and to restore the direct map the pages must
+	 * be transitioned back to the shared state.
+	 */
+	if (snp_reclaim_pages(__pa(map->host), npages, true))
+		return -EFAULT;
+
+	/* Copy the response data from the firmware buffer to the caller's buffer. */
+	memcpy(__va(__sme_clr(map->paddr)), map->host, min_t(size_t, len, map->len));
+	*paddr = map->paddr;
+
+done:
+	map->active = false;
+	return 0;
+}
+
+static bool sev_legacy_cmd_buf_writable(int cmd)
+{
+	switch (cmd) {
+	case SEV_CMD_PLATFORM_STATUS:
+	case SEV_CMD_GUEST_STATUS:
+	case SEV_CMD_LAUNCH_START:
+	case SEV_CMD_RECEIVE_START:
+	case SEV_CMD_LAUNCH_MEASURE:
+	case SEV_CMD_SEND_START:
+	case SEV_CMD_SEND_UPDATE_DATA:
+	case SEV_CMD_SEND_UPDATE_VMSA:
+	case SEV_CMD_PEK_CSR:
+	case SEV_CMD_PDH_CERT_EXPORT:
+	case SEV_CMD_GET_ID:
+	case SEV_CMD_ATTESTATION_REPORT:
+		return true;	/* cmd_buf page is marked firmware-owned around the command */
+	default:
+		return false;
+	}
+}
+
+#define prep_buffer(name, addr, len, guest, map) /* uses the caller's func and cmd_buf */ \
+	func(&((typeof(name *))cmd_buf)->addr, ((typeof(name *))cmd_buf)->len, guest, map)
+
+static int __snp_cmd_buf_copy(int cmd, void *cmd_buf, bool to_fw, int fw_err)
+{
+	int (*func)(u64 *paddr, u32 len, bool guest, struct snp_host_map *map);
+	struct sev_device *sev = psp_master->sev_data;
+	bool from_fw = !to_fw;
+
+	/*
+	 * After the command completes, return the command buffer page to the
+	 * hypervisor state (the firmware clears the immutable bit itself).
+	 *
+	 * NOTE(review): on fw_err this returns before the data buffers mapped
+	 * below are unmapped, leaving them in the firmware state; confirm.
+	 */
+	if (from_fw && sev_legacy_cmd_buf_writable(cmd)) {
+		if (snp_reclaim_pages(__pa(cmd_buf), 1, true))
+			return -EFAULT;
+
+		/* No need to go further if firmware failed to execute command. */
+		if (fw_err)
+			return 0;
+	}
+
+	if (to_fw)
+		func = map_firmware_writeable;
+	else
+		func = unmap_firmware_writeable;
+
+	/*
+	 * A command buffer may contain a system physical address. If the address
+	 * points to host memory, use an intermediate firmware page; otherwise
+	 * change the page state in the RMP table.
+	 */
+	switch (cmd) {
+	case SEV_CMD_PDH_CERT_EXPORT:
+		if (prep_buffer(struct sev_data_pdh_cert_export, pdh_cert_address,
+				pdh_cert_len, false, &sev->snp_host_map[0]))
+			goto err;
+		if (prep_buffer(struct sev_data_pdh_cert_export, cert_chain_address,
+				cert_chain_len, false, &sev->snp_host_map[1]))
+			goto err;
+		break;
+	case SEV_CMD_GET_ID:
+		if (prep_buffer(struct sev_data_get_id, address, len,
+				false, &sev->snp_host_map[0]))
+			goto err;
+		break;
+	case SEV_CMD_PEK_CSR:
+		if (prep_buffer(struct sev_data_pek_csr, address, len,
+				false, &sev->snp_host_map[0]))
+			goto err;
+		break;
+	case SEV_CMD_LAUNCH_UPDATE_DATA:
+		if (prep_buffer(struct sev_data_launch_update_data, address, len,
+				true, &sev->snp_host_map[0]))
+			goto err;
+		break;
+	case SEV_CMD_LAUNCH_UPDATE_VMSA:
+		if (prep_buffer(struct sev_data_launch_update_vmsa, address, len,
+				true, &sev->snp_host_map[0]))
+			goto err;
+		break;
+	case SEV_CMD_LAUNCH_MEASURE:
+		if (prep_buffer(struct sev_data_launch_measure, address, len,
+				false, &sev->snp_host_map[0]))
+			goto err;
+		break;
+	case SEV_CMD_LAUNCH_UPDATE_SECRET:
+		if (prep_buffer(struct sev_data_launch_secret, guest_address, guest_len,
+				true, &sev->snp_host_map[0]))
+			goto err;
+		break;
+	case SEV_CMD_DBG_DECRYPT:
+		if (prep_buffer(struct sev_data_dbg, dst_addr, len, false,
+				&sev->snp_host_map[0]))
+			goto err;
+		break;
+	case SEV_CMD_DBG_ENCRYPT:
+		if (prep_buffer(struct sev_data_dbg, dst_addr, len, true,
+				&sev->snp_host_map[0]))
+			goto err;
+		break;
+	case SEV_CMD_ATTESTATION_REPORT:
+		if (prep_buffer(struct sev_data_attestation_report, address, len,
+				false, &sev->snp_host_map[0]))
+			goto err;
+		break;
+	case SEV_CMD_SEND_START:
+		if (prep_buffer(struct sev_data_send_start, session_address,
+				session_len, false, &sev->snp_host_map[0]))
+			goto err;
+		break;
+	case SEV_CMD_SEND_UPDATE_DATA:
+		if (prep_buffer(struct sev_data_send_update_data, hdr_address, hdr_len,
+				false, &sev->snp_host_map[0]))
+			goto err;
+		if (prep_buffer(struct sev_data_send_update_data, trans_address,
+				trans_len, false, &sev->snp_host_map[1]))
+			goto err;
+		break;
+	case SEV_CMD_SEND_UPDATE_VMSA:
+		if (prep_buffer(struct sev_data_send_update_vmsa, hdr_address, hdr_len,
+				false, &sev->snp_host_map[0]))
+			goto err;
+		if (prep_buffer(struct sev_data_send_update_vmsa, trans_address,
+				trans_len, false, &sev->snp_host_map[1]))
+			goto err;
+		break;
+	case SEV_CMD_RECEIVE_UPDATE_DATA:
+		if (prep_buffer(struct sev_data_receive_update_data, guest_address,
+				guest_len, true, &sev->snp_host_map[0]))
+			goto err;
+		break;
+	case SEV_CMD_RECEIVE_UPDATE_VMSA:
+		if (prep_buffer(struct sev_data_receive_update_vmsa, guest_address,
+				guest_len, true, &sev->snp_host_map[0]))
+			goto err;
+		break;
+	default:
+		break;
+	}
+
+	/* The command buffer needs to be in the firmware state. */
+	if (to_fw && sev_legacy_cmd_buf_writable(cmd)) {
+		if (rmp_mark_pages_firmware(__pa(cmd_buf), 1, true))
+			return -EFAULT;
+	}
+
+	return 0;
+
+err:
+	return -EINVAL;	/* NOTE(review): buffers mapped before the failure are not unmapped */
+}
+
+static inline bool need_firmware_copy(int cmd)
+{
+	struct sev_device *sev = psp_master->sev_data;
+
+	/* Once SNP is initialized, legacy SEV commands behave differently. */
+	return cmd < SEV_CMD_SNP_INIT && sev->snp_initialized;
+}
+
+static int snp_aware_copy_to_firmware(int cmd, void *data)
+{
+	return __snp_cmd_buf_copy(cmd, data, /* to_fw */ true, /* fw_err */ 0);
+}
+
+static int snp_aware_copy_from_firmware(int cmd, void *data, int fw_err)
+{
+	return __snp_cmd_buf_copy(cmd, data, /* to_fw */ false, fw_err);
+}
+
 static int __sev_do_cmd_locked(int cmd, void *data, int *psp_ret)
 {
 	struct psp_device *psp = psp_master;
 	struct sev_device *sev;
 	unsigned int phys_lsb, phys_msb;
 	unsigned int reg, ret = 0;
+	void *cmd_buf;
 	int buf_len;
 
 	if (!psp || !psp->sev_data)
@@ -487,12 +770,28 @@  static int __sev_do_cmd_locked(int cmd, void *data, int *psp_ret)
 	 * work for some memory, e.g. vmalloc'd addresses, and @data may not be
 	 * physically contiguous.
 	 */
-	if (data)
-		memcpy(sev->cmd_buf, data, buf_len);
+	if (data) {
+		if (sev->cmd_buf_active >= 2)	/* only cmd_buf and cmd_buf_backup exist */
+			return -EBUSY;
+
+		cmd_buf = sev->cmd_buf_active ? sev->cmd_buf_backup : sev->cmd_buf;
+
+		memcpy(cmd_buf, data, buf_len);
+		sev->cmd_buf_active++;
+
+		/*
+		 * The behavior of the SEV-legacy commands is altered when the
+		 * SNP firmware is in the INIT state.
+		 */
+		if (need_firmware_copy(cmd) && snp_aware_copy_to_firmware(cmd, cmd_buf))
+			return -EFAULT;	/* NOTE(review): cmd_buf_active is left incremented */
+	} else {
+		cmd_buf = sev->cmd_buf;
+	}
 
 	/* Get the physical address of the command buffer */
-	phys_lsb = data ? lower_32_bits(__psp_pa(sev->cmd_buf)) : 0;
-	phys_msb = data ? upper_32_bits(__psp_pa(sev->cmd_buf)) : 0;
+	phys_lsb = data ? lower_32_bits(__psp_pa(cmd_buf)) : 0;
+	phys_msb = data ? upper_32_bits(__psp_pa(cmd_buf)) : 0;
 
 	dev_dbg(sev->dev, "sev command id %#x buffer 0x%08x%08x timeout %us\n",
 		cmd, phys_msb, phys_lsb, psp_timeout);
@@ -533,15 +832,24 @@  static int __sev_do_cmd_locked(int cmd, void *data, int *psp_ret)
 		ret = sev_write_init_ex_file_if_required(cmd);
 	}
 
-	print_hex_dump_debug("(out): ", DUMP_PREFIX_OFFSET, 16, 2, data,
-			     buf_len, false);
-
 	/*
 	 * Copy potential output from the PSP back to data.  Do this even on
 	 * failure in case the caller wants to glean something from the error.
 	 */
-	if (data)
-		memcpy(data, sev->cmd_buf, buf_len);
+	if (data) {
+		/*
+		 * Restore the page state after the command completes.
+		 */
+		if (need_firmware_copy(cmd) &&
+		    snp_aware_copy_from_firmware(cmd, cmd_buf, ret))
+			return -EFAULT;	/* NOTE(review): skips copy-back and cmd_buf_active-- */
+
+		memcpy(data, cmd_buf, buf_len);
+		sev->cmd_buf_active--;
+	}
+
+	print_hex_dump_debug("(out): ", DUMP_PREFIX_OFFSET, 16, 2, data,
+			     buf_len, false);
 
 	return ret;
 }
@@ -639,6 +947,14 @@  static int ___sev_platform_init_locked(int *error, bool probe)
 	if (probe && !psp_init_on_probe)
 		return 0;
 
+	/*
+	 * Allocate the intermediate buffers used for the legacy command handling.
+	 */
+	if (rc != -ENODEV && alloc_snp_host_map(sev)) {
+		dev_notice(sev->dev, "Failed to alloc host map (disabling legacy SEV)\n");
+		goto skip_legacy;
+	}
+
 	if (!sev_es_tmr) {
 		/* Obtain the TMR memory area for SEV-ES use */
 		sev_es_tmr = sev_fw_alloc(sev_es_tmr_size);
@@ -691,6 +1007,7 @@  static int ___sev_platform_init_locked(int *error, bool probe)
 	dev_info(sev->dev, "SEV API:%d.%d build:%d\n", sev->api_major,
 		 sev->api_minor, sev->build);
 
+skip_legacy:
 	return 0;
 }
 
@@ -1616,10 +1933,12 @@  int sev_dev_init(struct psp_device *psp)
 	if (!sev)
 		goto e_err;
 
-	sev->cmd_buf = (void *)devm_get_free_pages(dev, GFP_KERNEL, 0);
+	sev->cmd_buf = (void *)devm_get_free_pages(dev, GFP_KERNEL, 1);
 	if (!sev->cmd_buf)
 		goto e_sev;
 
+	sev->cmd_buf_backup = (uint8_t *)sev->cmd_buf + PAGE_SIZE;
+
 	psp->sev_data = sev;
 
 	sev->dev = dev;
@@ -1685,6 +2004,12 @@  static void sev_firmware_shutdown(struct sev_device *sev)
 		snp_range_list = NULL;
 	}
 
+	/*
+	 * The host map needs the immutable bit cleared, so it must be freed
+	 * before the SNP firmware is shut down.
+	 */
+	free_snp_host_map(sev);
+
 	sev_snp_shutdown(&error);
 }
 
@@ -1753,6 +2078,7 @@  void sev_pci_init(void)
 	return;
 
 err:
+	free_snp_host_map(sev);
 	psp_master->sev_data = NULL;
 }
 
diff --git a/drivers/crypto/ccp/sev-dev.h b/drivers/crypto/ccp/sev-dev.h
index 85506325051a..2c2fe42189a5 100644
--- a/drivers/crypto/ccp/sev-dev.h
+++ b/drivers/crypto/ccp/sev-dev.h
@@ -29,11 +29,20 @@ 
 #define SEV_CMD_COMPLETE		BIT(1)
 #define SEV_CMDRESP_IOC			BIT(0)
 
+#define MAX_SNP_HOST_MAP_BUFS		2	/* data buffers remapped per legacy command */
+
 struct sev_misc_dev {
 	struct kref refcount;
 	struct miscdevice misc;
 };
 
+struct snp_host_map {
+	u64 paddr;	/* caller's address, restored on unmap */
+	u32 len;	/* length recorded at map time */
+	void *host;	/* pre-allocated bounce buffer (kernel virtual address) */
+	bool active;	/* set once mapping succeeded; cleared on unmap */
+};
+
 struct sev_device {
 	struct device *dev;
 	struct psp_device *psp;
@@ -52,8 +61,11 @@  struct sev_device {
 	u8 build;
 
 	void *cmd_buf;
+	void *cmd_buf_backup;
+	int cmd_buf_active;
 
 	bool snp_initialized;
+	struct snp_host_map snp_host_map[MAX_SNP_HOST_MAP_BUFS];
 };
 
 int sev_dev_init(struct psp_device *psp);