diff mbox series

[v15,01/13] x86/sev: Carve out and export SNP guest messaging init routines

Message ID 20241203090045.942078-2-nikunj@amd.com (mailing list archive)
State New
Headers show
Series Add Secure TSC support for SNP guests | expand

Commit Message

Nikunj A. Dadhania Dec. 3, 2024, 9 a.m. UTC
Currently, the sev-guest driver is the only user of SNP guest messaging.
All routines for initializing SNP guest messaging are implemented within
the sev-guest driver and are not available during early boot. In
prepratation for adding Secure TSC guest support, carve out APIs to
allocate and initialize guest messaging descriptor context and make it part
of coco/sev/core.c. As there is no user of sev_guest_platform_data anymore,
remove the structure.

Signed-off-by: Nikunj A Dadhania <nikunj@amd.com>
Reviewed-by: Tom Lendacky <thomas.lendacky@amd.com>
---
 arch/x86/include/asm/sev.h              |  24 ++-
 arch/x86/coco/sev/core.c                | 183 +++++++++++++++++++++-
 drivers/virt/coco/sev-guest/sev-guest.c | 197 +++---------------------
 arch/x86/Kconfig                        |   1 +
 drivers/virt/coco/sev-guest/Kconfig     |   1 -
 5 files changed, 220 insertions(+), 186 deletions(-)

Comments

Borislav Petkov Dec. 3, 2024, 2:19 p.m. UTC | #1
On Tue, Dec 03, 2024 at 02:30:33PM +0530, Nikunj A Dadhania wrote:
> Currently, the sev-guest driver is the only user of SNP guest messaging.
> All routines for initializing SNP guest messaging are implemented within
> the sev-guest driver and are not available during early boot. In
> prepratation for adding Secure TSC guest support, carve out APIs to

Unknown word [prepratation] in commit message.
Suggestions: ['preparation', 'preparations', 'reparation', 'perpetration', 'reputation', 'perpetuation', 'peroration', 'presentation', 'repatriation', 'propagation', "preparation's"]

Please introduce a spellchecker into your patch creation workflow.

> allocate and initialize guest messaging descriptor context and make it part
> of coco/sev/core.c. As there is no user of sev_guest_platform_data anymore,
> remove the structure.
> 
> Signed-off-by: Nikunj A Dadhania <nikunj@amd.com>
> Reviewed-by: Tom Lendacky <thomas.lendacky@amd.com>
> ---
>  arch/x86/include/asm/sev.h              |  24 ++-
>  arch/x86/coco/sev/core.c                | 183 +++++++++++++++++++++-
>  drivers/virt/coco/sev-guest/sev-guest.c | 197 +++---------------------
>  arch/x86/Kconfig                        |   1 +
>  drivers/virt/coco/sev-guest/Kconfig     |   1 -
>  5 files changed, 220 insertions(+), 186 deletions(-)
> 
> diff --git a/arch/x86/include/asm/sev.h b/arch/x86/include/asm/sev.h
> index 91f08af31078..f78c94e29c74 100644
> --- a/arch/x86/include/asm/sev.h
> +++ b/arch/x86/include/asm/sev.h
> @@ -14,6 +14,7 @@
>  #include <asm/insn.h>
>  #include <asm/sev-common.h>
>  #include <asm/coco.h>
> +#include <asm/set_memory.h>
>  
>  #define GHCB_PROTOCOL_MIN	1ULL
>  #define GHCB_PROTOCOL_MAX	2ULL
> @@ -170,10 +171,6 @@ struct snp_guest_msg {
>  	u8 payload[PAGE_SIZE - sizeof(struct snp_guest_msg_hdr)];
>  } __packed;
>  
> -struct sev_guest_platform_data {
> -	u64 secrets_gpa;
> -};
> -
>  struct snp_guest_req {
>  	void *req_buf;
>  	size_t req_sz;
> @@ -253,6 +250,7 @@ struct snp_msg_desc {
>  
>  	u32 *os_area_msg_seqno;
>  	u8 *vmpck;
> +	int vmpck_id;
>  };
>  
>  /*
> @@ -458,6 +456,20 @@ void set_pte_enc_mask(pte_t *kpte, unsigned long pfn, pgprot_t new_prot);
>  void snp_kexec_finish(void);
>  void snp_kexec_begin(void);
>  
> +static inline bool snp_is_vmpck_empty(struct snp_msg_desc *mdesc)
> +{
> +	static const char zero_key[VMPCK_KEY_LEN] = {0};
> +
> +	if (mdesc->vmpck)
> +		return !memcmp(mdesc->vmpck, zero_key, VMPCK_KEY_LEN);
> +
> +	return true;
> +}

This function looks silly in a header with that array allocation.

I think you should simply do:

	if (memchr_inv(mdesc->vmpck, 0, VMPCK_KEY_LEN))

at the call sites and not have this helper at all.

But please do verify whether what I'm saying actually makes sense and if it
does, this can be a cleanup pre-patch.


> +
> +int snp_msg_init(struct snp_msg_desc *mdesc, int vmpck_id);
> +struct snp_msg_desc *snp_msg_alloc(void);
> +void snp_msg_free(struct snp_msg_desc *mdesc);
> +
>  #else	/* !CONFIG_AMD_MEM_ENCRYPT */
>  
>  #define snp_vmpl 0
> @@ -498,6 +510,10 @@ static inline int prepare_pte_enc(struct pte_enc_desc *d) { return 0; }
>  static inline void set_pte_enc_mask(pte_t *kpte, unsigned long pfn, pgprot_t new_prot) { }
>  static inline void snp_kexec_finish(void) { }
>  static inline void snp_kexec_begin(void) { }
> +static inline bool snp_is_vmpck_empty(struct snp_msg_desc *mdesc) { return false; }
> +static inline int snp_msg_init(struct snp_msg_desc *mdesc, int vmpck_id) { return -1; }
> +static inline struct snp_msg_desc *snp_msg_alloc(void) { return NULL; }
> +static inline void snp_msg_free(struct snp_msg_desc *mdesc) { }
>  
>  #endif	/* CONFIG_AMD_MEM_ENCRYPT */
>  
> diff --git a/arch/x86/coco/sev/core.c b/arch/x86/coco/sev/core.c
> index c5b0148b8c0a..3cc741eefd06 100644
> --- a/arch/x86/coco/sev/core.c
> +++ b/arch/x86/coco/sev/core.c
> @@ -25,6 +25,7 @@
>  #include <linux/psp-sev.h>
>  #include <linux/dmi.h>
>  #include <uapi/linux/sev-guest.h>
> +#include <crypto/gcm.h>
>  
>  #include <asm/init.h>
>  #include <asm/cpu_entry_area.h>
> @@ -2580,15 +2581,9 @@ static struct platform_device sev_guest_device = {
>  
>  static int __init snp_init_platform_device(void)
>  {
> -	struct sev_guest_platform_data data;
> -
>  	if (!cc_platform_has(CC_ATTR_GUEST_SEV_SNP))
>  		return -ENODEV;
>  
> -	data.secrets_gpa = secrets_pa;
> -	if (platform_device_add_data(&sev_guest_device, &data, sizeof(data)))
> -		return -ENODEV;
> -
>  	if (platform_device_register(&sev_guest_device))
>  		return -ENODEV;
>  
> @@ -2667,3 +2662,179 @@ static int __init sev_sysfs_init(void)
>  }
>  arch_initcall(sev_sysfs_init);
>  #endif // CONFIG_SYSFS
> +
> +static void free_shared_pages(void *buf, size_t sz)
> +{
> +	unsigned int npages = PAGE_ALIGN(sz) >> PAGE_SHIFT;
> +	int ret;
> +
> +	if (!buf)
> +		return;
> +
> +	ret = set_memory_encrypted((unsigned long)buf, npages);
> +	if (ret) {
> +		WARN_ONCE(ret, "failed to restore encryption mask (leak it)\n");

Looking at where this lands:

set_memory_encrypted
|-> __set_memory_enc_dec

and that doing now:

        if (cc_platform_has(CC_ATTR_MEM_ENCRYPT)) {
                if (!down_read_trylock(&mem_enc_lock))
                        return -EBUSY;


after

859e63b789d6 ("x86/tdx: Convert shared memory back to private on kexec")

we probably should pay attention to this here firing and maybe turning that
_trylock() into a normal down_read*

Anyway, just something to pay attention to in the future.

> +		return;
> +	}
> +
> +	__free_pages(virt_to_page(buf), get_order(sz));
> +}

...

> +struct snp_msg_desc *snp_msg_alloc(void)
> +{
> +	struct snp_msg_desc *mdesc;
> +	void __iomem *mem;
> +
> +	BUILD_BUG_ON(sizeof(struct snp_guest_msg) > PAGE_SIZE);
> +
> +	mdesc = kzalloc(sizeof(struct snp_msg_desc), GFP_KERNEL);

The above ones use GFP_KERNEL_ACCOUNT. What's the difference?

> +	if (!mdesc)
> +		return ERR_PTR(-ENOMEM);
> +
> +	mem = ioremap_encrypted(secrets_pa, PAGE_SIZE);
> +	if (!mem)
> +		goto e_free_mdesc;
> +
> +	mdesc->secrets = (__force struct snp_secrets_page *)mem;
> +
> +	/* Allocate the shared page used for the request and response message. */
> +	mdesc->request = alloc_shared_pages(sizeof(struct snp_guest_msg));
> +	if (!mdesc->request)
> +		goto e_unmap;
> +
> +	mdesc->response = alloc_shared_pages(sizeof(struct snp_guest_msg));
> +	if (!mdesc->response)
> +		goto e_free_request;
> +
> +	mdesc->certs_data = alloc_shared_pages(SEV_FW_BLOB_MAX_SIZE);
> +	if (!mdesc->certs_data)
> +		goto e_free_response;
> +
> +	/* initial the input address for guest request */
> +	mdesc->input.req_gpa = __pa(mdesc->request);
> +	mdesc->input.resp_gpa = __pa(mdesc->response);
> +	mdesc->input.data_gpa = __pa(mdesc->certs_data);
> +
> +	return mdesc;
> +
> +e_free_response:
> +	free_shared_pages(mdesc->response, sizeof(struct snp_guest_msg));
> +e_free_request:
> +	free_shared_pages(mdesc->request, sizeof(struct snp_guest_msg));
> +e_unmap:
> +	iounmap(mem);
> +e_free_mdesc:
> +	kfree(mdesc);
> +
> +	return ERR_PTR(-ENOMEM);
> +}
> +EXPORT_SYMBOL_GPL(snp_msg_alloc);
> +
> +void snp_msg_free(struct snp_msg_desc *mdesc)
> +{
> +	if (!mdesc)
> +		return;
> +
> +	mdesc->vmpck = NULL;
> +	mdesc->os_area_msg_seqno = NULL;

	memset(mdesc, ...);

at the end instead of those assignments.

> +	kfree(mdesc->ctx);
> +
> +	free_shared_pages(mdesc->response, sizeof(struct snp_guest_msg));
> +	free_shared_pages(mdesc->request, sizeof(struct snp_guest_msg));
> +	iounmap((__force void __iomem *)mdesc->secrets);


> +	kfree(mdesc);
> +}
> +EXPORT_SYMBOL_GPL(snp_msg_free);
> diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c
> index b699771be029..5268511bc9b8 100644
> --- a/drivers/virt/coco/sev-guest/sev-guest.c
> +++ b/drivers/virt/coco/sev-guest/sev-guest.c

...

> @@ -993,115 +898,57 @@ static int __init sev_guest_probe(struct platform_device *pdev)
>  	if (!cc_platform_has(CC_ATTR_GUEST_SEV_SNP))
>  		return -ENODEV;
>  
> -	if (!dev->platform_data)
> -		return -ENODEV;
> -
> -	data = (struct sev_guest_platform_data *)dev->platform_data;
> -	mapping = ioremap_encrypted(data->secrets_gpa, PAGE_SIZE);
> -	if (!mapping)
> -		return -ENODEV;
> -
> -	secrets = (__force void *)mapping;
> -
> -	ret = -ENOMEM;
>  	snp_dev = devm_kzalloc(&pdev->dev, sizeof(struct snp_guest_dev), GFP_KERNEL);
>  	if (!snp_dev)
> -		goto e_unmap;
> -
> -	mdesc = devm_kzalloc(&pdev->dev, sizeof(struct snp_msg_desc), GFP_KERNEL);
> -	if (!mdesc)
> -		goto e_unmap;
> -
> -	/* Adjust the default VMPCK key based on the executing VMPL level */
> -	if (vmpck_id == -1)
> -		vmpck_id = snp_vmpl;
> +		return -ENOMEM;
>  
> -	ret = -EINVAL;
> -	mdesc->vmpck = get_vmpck(vmpck_id, secrets, &mdesc->os_area_msg_seqno);
> -	if (!mdesc->vmpck) {
> -		dev_err(dev, "Invalid VMPCK%d communication key\n", vmpck_id);
> -		goto e_unmap;
> -	}
> +	mdesc = snp_msg_alloc();
> +	if (IS_ERR_OR_NULL(mdesc))
> +		return -ENOMEM;
>  
> -	/* Verify that VMPCK is not zero. */
> -	if (is_vmpck_empty(mdesc)) {
> -		dev_err(dev, "Empty VMPCK%d communication key\n", vmpck_id);
> -		goto e_unmap;
> -	}
> +	ret = snp_msg_init(mdesc, vmpck_id);
> +	if (ret)
> +		return -EIO;

You just leaked mdesc here.

Audit all your error paths.

Thx.
Nikunj A. Dadhania Dec. 3, 2024, 2:35 p.m. UTC | #2
On 12/3/2024 7:49 PM, Borislav Petkov wrote:
> On Tue, Dec 03, 2024 at 02:30:33PM +0530, Nikunj A Dadhania wrote:
>> Currently, the sev-guest driver is the only user of SNP guest messaging.
>> All routines for initializing SNP guest messaging are implemented within
>> the sev-guest driver and are not available during early boot. In
>> prepratation for adding Secure TSC guest support, carve out APIs to
> 
> Unknown word [prepratation] in commit message.
> Suggestions: ['preparation', 'preparations', 'reparation', 'perpetration', 'reputation', 'perpetuation', 'peroration', 'presentation', 'repatriation', 'propagation', "preparation's"]
> 
> Please introduce a spellchecker into your patch creation workflow.

This is what I use with checkpatch, that didnt catch the wrong spelling. Do you suggest using something else ?

./scripts/checkpatch.pl --codespell < sectsc_v15/v15-0001-x86-sev-Carve-out-and-export-SNP-guest-messaging.patch
total: 0 errors, 0 warnings, 569 lines checked

"[PATCH v15 01/13] x86/sev: Carve out and export SNP guest messaging" has no obvious style problems and is ready for submission.

Regards
Nikunj
Borislav Petkov Dec. 3, 2024, 2:50 p.m. UTC | #3
On Tue, Dec 03, 2024 at 08:05:32PM +0530, Nikunj A. Dadhania wrote:
> This is what I use with checkpatch, that didnt catch the wrong spelling.

Not surprised.

> Do you suggest using something else ?

You can enable spellchecking in your editor with which you write the commit
messages. For example:

https://www.linux.com/training-tutorials/using-spell-checking-vim/

Or, you can use my tool:

https://git.kernel.org/pub/scm/linux/kernel/git/bp/bp.git/log/?h=vp

You'd need to fish it out of the repo.

It doesn't completely replace checkpatch yet but I am extending it with
features as I go. But it does spellcheck:

$ ~/dev/vp/.tip/bin/vp.py ~/tmp/review/new
prepratation for adding Secure TSC guest support, carve out APIs to
Unknown word [prepratation] in commit message.
Suggestions: ['preparation', 'preparations', 'reparation', 'perpetration', 'reputation', 'perpetuation', 'peroration', 'presentation', 'repatriation', 'propagation', "preparation's"]

Class patch:
    original subject: [[PATCH v15 01/13] x86/sev: Carve out and export SNP guest messaging init routines]
             subject: [x86/sev: Carve out and export SNP guest messaging init routines]
              sender: [Nikunj A Dadhania <nikunj@amd.com>]
              author: [Nikunj A Dadhania <nikunj@amd.com>]
             version: [15]
              number: [1]
                name: [x86-sev-carve_out_and_export_snp_guest_messaging_init_routines]
                date: [Tue, 03 Dec 2024 14:30:33 +0530]
          message-id: [20241203090045.942078-2-nikunj@amd.com]

I'm sure there are gazillion other ways to automate it, ofc.

HTH.
Nikunj A. Dadhania Dec. 3, 2024, 2:52 p.m. UTC | #4
On 12/3/2024 8:20 PM, Borislav Petkov wrote:
> On Tue, Dec 03, 2024 at 08:05:32PM +0530, Nikunj A. Dadhania wrote:
>> This is what I use with checkpatch, that didnt catch the wrong spelling.
> 
> Not surprised.
> 
>> Do you suggest using something else ?
> 
> You can enable spellchecking in your editor with which you write the commit
> messages. For example:
> 
> https://www.linux.com/training-tutorials/using-spell-checking-vim/
> 
> Or, you can use my tool:
> 
> https://git.kernel.org/pub/scm/linux/kernel/git/bp/bp.git/log/?h=vp
> 
> You'd need to fish it out of the repo.

Sure will give it a try.

Regards
Nikunj
diff mbox series

Patch

diff --git a/arch/x86/include/asm/sev.h b/arch/x86/include/asm/sev.h
index 91f08af31078..f78c94e29c74 100644
--- a/arch/x86/include/asm/sev.h
+++ b/arch/x86/include/asm/sev.h
@@ -14,6 +14,7 @@ 
 #include <asm/insn.h>
 #include <asm/sev-common.h>
 #include <asm/coco.h>
+#include <asm/set_memory.h>
 
 #define GHCB_PROTOCOL_MIN	1ULL
 #define GHCB_PROTOCOL_MAX	2ULL
@@ -170,10 +171,6 @@  struct snp_guest_msg {
 	u8 payload[PAGE_SIZE - sizeof(struct snp_guest_msg_hdr)];
 } __packed;
 
-struct sev_guest_platform_data {
-	u64 secrets_gpa;
-};
-
 struct snp_guest_req {
 	void *req_buf;
 	size_t req_sz;
@@ -253,6 +250,7 @@  struct snp_msg_desc {
 
 	u32 *os_area_msg_seqno;
 	u8 *vmpck;
+	int vmpck_id;
 };
 
 /*
@@ -458,6 +456,20 @@  void set_pte_enc_mask(pte_t *kpte, unsigned long pfn, pgprot_t new_prot);
 void snp_kexec_finish(void);
 void snp_kexec_begin(void);
 
+static inline bool snp_is_vmpck_empty(struct snp_msg_desc *mdesc)
+{
+	static const char zero_key[VMPCK_KEY_LEN] = {0};
+
+	if (mdesc->vmpck)
+		return !memcmp(mdesc->vmpck, zero_key, VMPCK_KEY_LEN);
+
+	return true;
+}
+
+int snp_msg_init(struct snp_msg_desc *mdesc, int vmpck_id);
+struct snp_msg_desc *snp_msg_alloc(void);
+void snp_msg_free(struct snp_msg_desc *mdesc);
+
 #else	/* !CONFIG_AMD_MEM_ENCRYPT */
 
 #define snp_vmpl 0
@@ -498,6 +510,10 @@  static inline int prepare_pte_enc(struct pte_enc_desc *d) { return 0; }
 static inline void set_pte_enc_mask(pte_t *kpte, unsigned long pfn, pgprot_t new_prot) { }
 static inline void snp_kexec_finish(void) { }
 static inline void snp_kexec_begin(void) { }
+static inline bool snp_is_vmpck_empty(struct snp_msg_desc *mdesc) { return false; }
+static inline int snp_msg_init(struct snp_msg_desc *mdesc, int vmpck_id) { return -1; }
+static inline struct snp_msg_desc *snp_msg_alloc(void) { return NULL; }
+static inline void snp_msg_free(struct snp_msg_desc *mdesc) { }
 
 #endif	/* CONFIG_AMD_MEM_ENCRYPT */
 
diff --git a/arch/x86/coco/sev/core.c b/arch/x86/coco/sev/core.c
index c5b0148b8c0a..3cc741eefd06 100644
--- a/arch/x86/coco/sev/core.c
+++ b/arch/x86/coco/sev/core.c
@@ -25,6 +25,7 @@ 
 #include <linux/psp-sev.h>
 #include <linux/dmi.h>
 #include <uapi/linux/sev-guest.h>
+#include <crypto/gcm.h>
 
 #include <asm/init.h>
 #include <asm/cpu_entry_area.h>
@@ -2580,15 +2581,9 @@  static struct platform_device sev_guest_device = {
 
 static int __init snp_init_platform_device(void)
 {
-	struct sev_guest_platform_data data;
-
 	if (!cc_platform_has(CC_ATTR_GUEST_SEV_SNP))
 		return -ENODEV;
 
-	data.secrets_gpa = secrets_pa;
-	if (platform_device_add_data(&sev_guest_device, &data, sizeof(data)))
-		return -ENODEV;
-
 	if (platform_device_register(&sev_guest_device))
 		return -ENODEV;
 
@@ -2667,3 +2662,179 @@  static int __init sev_sysfs_init(void)
 }
 arch_initcall(sev_sysfs_init);
 #endif // CONFIG_SYSFS
+
+static void free_shared_pages(void *buf, size_t sz)
+{
+	unsigned int npages = PAGE_ALIGN(sz) >> PAGE_SHIFT;
+	int ret;
+
+	if (!buf)
+		return;
+
+	ret = set_memory_encrypted((unsigned long)buf, npages);
+	if (ret) {
+		WARN_ONCE(ret, "failed to restore encryption mask (leak it)\n");
+		return;
+	}
+
+	__free_pages(virt_to_page(buf), get_order(sz));
+}
+
+static void *alloc_shared_pages(size_t sz)
+{
+	unsigned int npages = PAGE_ALIGN(sz) >> PAGE_SHIFT;
+	struct page *page;
+	int ret;
+
+	page = alloc_pages(GFP_KERNEL_ACCOUNT, get_order(sz));
+	if (!page)
+		return NULL;
+
+	ret = set_memory_decrypted((unsigned long)page_address(page), npages);
+	if (ret) {
+		pr_err("failed to mark page shared, ret=%d\n", ret);
+		__free_pages(page, get_order(sz));
+		return NULL;
+	}
+
+	return page_address(page);
+}
+
+static u8 *get_vmpck(int id, struct snp_secrets_page *secrets, u32 **seqno)
+{
+	u8 *key = NULL;
+
+	switch (id) {
+	case 0:
+		*seqno = &secrets->os_area.msg_seqno_0;
+		key = secrets->vmpck0;
+		break;
+	case 1:
+		*seqno = &secrets->os_area.msg_seqno_1;
+		key = secrets->vmpck1;
+		break;
+	case 2:
+		*seqno = &secrets->os_area.msg_seqno_2;
+		key = secrets->vmpck2;
+		break;
+	case 3:
+		*seqno = &secrets->os_area.msg_seqno_3;
+		key = secrets->vmpck3;
+		break;
+	default:
+		break;
+	}
+
+	return key;
+}
+
+static struct aesgcm_ctx *snp_init_crypto(u8 *key, size_t keylen)
+{
+	struct aesgcm_ctx *ctx;
+
+	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL_ACCOUNT);
+	if (!ctx)
+		return NULL;
+
+	if (aesgcm_expandkey(ctx, key, keylen, AUTHTAG_LEN)) {
+		pr_err("Crypto context initialization failed\n");
+		kfree(ctx);
+		return NULL;
+	}
+
+	return ctx;
+}
+
+int snp_msg_init(struct snp_msg_desc *mdesc, int vmpck_id)
+{
+	/* Adjust the default VMPCK key based on the executing VMPL level */
+	if (vmpck_id == -1)
+		vmpck_id = snp_vmpl;
+
+	mdesc->vmpck = get_vmpck(vmpck_id, mdesc->secrets, &mdesc->os_area_msg_seqno);
+	if (!mdesc->vmpck) {
+		pr_err("Invalid VMPCK%d communication key\n", vmpck_id);
+		return -EINVAL;
+	}
+
+	/* Verify that VMPCK is not zero. */
+	if (snp_is_vmpck_empty(mdesc)) {
+		pr_err("Empty VMPCK%d communication key\n", vmpck_id);
+		return -EINVAL;
+	}
+
+	mdesc->vmpck_id = vmpck_id;
+
+	mdesc->ctx = snp_init_crypto(mdesc->vmpck, VMPCK_KEY_LEN);
+	if (!mdesc->ctx)
+		return -ENOMEM;
+
+	return 0;
+}
+EXPORT_SYMBOL_GPL(snp_msg_init);
+
+struct snp_msg_desc *snp_msg_alloc(void)
+{
+	struct snp_msg_desc *mdesc;
+	void __iomem *mem;
+
+	BUILD_BUG_ON(sizeof(struct snp_guest_msg) > PAGE_SIZE);
+
+	mdesc = kzalloc(sizeof(struct snp_msg_desc), GFP_KERNEL);
+	if (!mdesc)
+		return ERR_PTR(-ENOMEM);
+
+	mem = ioremap_encrypted(secrets_pa, PAGE_SIZE);
+	if (!mem)
+		goto e_free_mdesc;
+
+	mdesc->secrets = (__force struct snp_secrets_page *)mem;
+
+	/* Allocate the shared page used for the request and response message. */
+	mdesc->request = alloc_shared_pages(sizeof(struct snp_guest_msg));
+	if (!mdesc->request)
+		goto e_unmap;
+
+	mdesc->response = alloc_shared_pages(sizeof(struct snp_guest_msg));
+	if (!mdesc->response)
+		goto e_free_request;
+
+	mdesc->certs_data = alloc_shared_pages(SEV_FW_BLOB_MAX_SIZE);
+	if (!mdesc->certs_data)
+		goto e_free_response;
+
+	/* initial the input address for guest request */
+	mdesc->input.req_gpa = __pa(mdesc->request);
+	mdesc->input.resp_gpa = __pa(mdesc->response);
+	mdesc->input.data_gpa = __pa(mdesc->certs_data);
+
+	return mdesc;
+
+e_free_response:
+	free_shared_pages(mdesc->response, sizeof(struct snp_guest_msg));
+e_free_request:
+	free_shared_pages(mdesc->request, sizeof(struct snp_guest_msg));
+e_unmap:
+	iounmap(mem);
+e_free_mdesc:
+	kfree(mdesc);
+
+	return ERR_PTR(-ENOMEM);
+}
+EXPORT_SYMBOL_GPL(snp_msg_alloc);
+
+void snp_msg_free(struct snp_msg_desc *mdesc)
+{
+	if (!mdesc)
+		return;
+
+	mdesc->vmpck = NULL;
+	mdesc->os_area_msg_seqno = NULL;
+	kfree(mdesc->ctx);
+
+	free_shared_pages(mdesc->response, sizeof(struct snp_guest_msg));
+	free_shared_pages(mdesc->request, sizeof(struct snp_guest_msg));
+	iounmap((__force void __iomem *)mdesc->secrets);
+	kfree(mdesc);
+}
+EXPORT_SYMBOL_GPL(snp_msg_free);
diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c
index b699771be029..5268511bc9b8 100644
--- a/drivers/virt/coco/sev-guest/sev-guest.c
+++ b/drivers/virt/coco/sev-guest/sev-guest.c
@@ -63,16 +63,6 @@  MODULE_PARM_DESC(vmpck_id, "The VMPCK ID to use when communicating with the PSP.
 /* Mutex to serialize the shared buffer access and command handling. */
 static DEFINE_MUTEX(snp_cmd_mutex);
 
-static bool is_vmpck_empty(struct snp_msg_desc *mdesc)
-{
-	char zero_key[VMPCK_KEY_LEN] = {0};
-
-	if (mdesc->vmpck)
-		return !memcmp(mdesc->vmpck, zero_key, VMPCK_KEY_LEN);
-
-	return true;
-}
-
 /*
  * If an error is received from the host or AMD Secure Processor (ASP) there
  * are two options. Either retry the exact same encrypted request or discontinue
@@ -93,7 +83,7 @@  static bool is_vmpck_empty(struct snp_msg_desc *mdesc)
 static void snp_disable_vmpck(struct snp_msg_desc *mdesc)
 {
 	pr_alert("Disabling VMPCK%d communication key to prevent IV reuse.\n",
-		  vmpck_id);
+		  mdesc->vmpck_id);
 	memzero_explicit(mdesc->vmpck, VMPCK_KEY_LEN);
 	mdesc->vmpck = NULL;
 }
@@ -147,23 +137,6 @@  static inline struct snp_guest_dev *to_snp_dev(struct file *file)
 	return container_of(dev, struct snp_guest_dev, misc);
 }
 
-static struct aesgcm_ctx *snp_init_crypto(u8 *key, size_t keylen)
-{
-	struct aesgcm_ctx *ctx;
-
-	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL_ACCOUNT);
-	if (!ctx)
-		return NULL;
-
-	if (aesgcm_expandkey(ctx, key, keylen, AUTHTAG_LEN)) {
-		pr_err("Crypto context initialization failed\n");
-		kfree(ctx);
-		return NULL;
-	}
-
-	return ctx;
-}
-
 static int verify_and_dec_payload(struct snp_msg_desc *mdesc, struct snp_guest_req *req)
 {
 	struct snp_guest_msg *resp_msg = &mdesc->secret_response;
@@ -335,7 +308,7 @@  static int snp_send_guest_request(struct snp_msg_desc *mdesc, struct snp_guest_r
 	guard(mutex)(&snp_cmd_mutex);
 
 	/* Check if the VMPCK is not empty */
-	if (is_vmpck_empty(mdesc)) {
+	if (snp_is_vmpck_empty(mdesc)) {
 		pr_err_ratelimited("VMPCK is disabled\n");
 		return -ENOTTY;
 	}
@@ -414,7 +387,7 @@  static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_io
 
 	req.msg_version = arg->msg_version;
 	req.msg_type = SNP_MSG_REPORT_REQ;
-	req.vmpck_id = vmpck_id;
+	req.vmpck_id = mdesc->vmpck_id;
 	req.req_buf = report_req;
 	req.req_sz = sizeof(*report_req);
 	req.resp_buf = report_resp->data;
@@ -461,7 +434,7 @@  static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_reque
 
 	req.msg_version = arg->msg_version;
 	req.msg_type = SNP_MSG_KEY_REQ;
-	req.vmpck_id = vmpck_id;
+	req.vmpck_id = mdesc->vmpck_id;
 	req.req_buf = derived_key_req;
 	req.req_sz = sizeof(*derived_key_req);
 	req.resp_buf = buf;
@@ -539,7 +512,7 @@  static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
 
 	req.msg_version = arg->msg_version;
 	req.msg_type = SNP_MSG_REPORT_REQ;
-	req.vmpck_id = vmpck_id;
+	req.vmpck_id = mdesc->vmpck_id;
 	req.req_buf = &report_req->data;
 	req.req_sz = sizeof(report_req->data);
 	req.resp_buf = report_resp->data;
@@ -616,76 +589,11 @@  static long snp_guest_ioctl(struct file *file, unsigned int ioctl, unsigned long
 	return ret;
 }
 
-static void free_shared_pages(void *buf, size_t sz)
-{
-	unsigned int npages = PAGE_ALIGN(sz) >> PAGE_SHIFT;
-	int ret;
-
-	if (!buf)
-		return;
-
-	ret = set_memory_encrypted((unsigned long)buf, npages);
-	if (ret) {
-		WARN_ONCE(ret, "failed to restore encryption mask (leak it)\n");
-		return;
-	}
-
-	__free_pages(virt_to_page(buf), get_order(sz));
-}
-
-static void *alloc_shared_pages(struct device *dev, size_t sz)
-{
-	unsigned int npages = PAGE_ALIGN(sz) >> PAGE_SHIFT;
-	struct page *page;
-	int ret;
-
-	page = alloc_pages(GFP_KERNEL_ACCOUNT, get_order(sz));
-	if (!page)
-		return NULL;
-
-	ret = set_memory_decrypted((unsigned long)page_address(page), npages);
-	if (ret) {
-		dev_err(dev, "failed to mark page shared, ret=%d\n", ret);
-		__free_pages(page, get_order(sz));
-		return NULL;
-	}
-
-	return page_address(page);
-}
-
 static const struct file_operations snp_guest_fops = {
 	.owner	= THIS_MODULE,
 	.unlocked_ioctl = snp_guest_ioctl,
 };
 
-static u8 *get_vmpck(int id, struct snp_secrets_page *secrets, u32 **seqno)
-{
-	u8 *key = NULL;
-
-	switch (id) {
-	case 0:
-		*seqno = &secrets->os_area.msg_seqno_0;
-		key = secrets->vmpck0;
-		break;
-	case 1:
-		*seqno = &secrets->os_area.msg_seqno_1;
-		key = secrets->vmpck1;
-		break;
-	case 2:
-		*seqno = &secrets->os_area.msg_seqno_2;
-		key = secrets->vmpck2;
-		break;
-	case 3:
-		*seqno = &secrets->os_area.msg_seqno_3;
-		key = secrets->vmpck3;
-		break;
-	default:
-		break;
-	}
-
-	return key;
-}
-
 struct snp_msg_report_resp_hdr {
 	u32 status;
 	u32 report_size;
@@ -979,13 +887,10 @@  static void unregister_sev_tsm(void *data)
 
 static int __init sev_guest_probe(struct platform_device *pdev)
 {
-	struct sev_guest_platform_data *data;
-	struct snp_secrets_page *secrets;
 	struct device *dev = &pdev->dev;
 	struct snp_guest_dev *snp_dev;
 	struct snp_msg_desc *mdesc;
 	struct miscdevice *misc;
-	void __iomem *mapping;
 	int ret;
 
 	BUILD_BUG_ON(sizeof(struct snp_guest_msg) > PAGE_SIZE);
@@ -993,115 +898,57 @@  static int __init sev_guest_probe(struct platform_device *pdev)
 	if (!cc_platform_has(CC_ATTR_GUEST_SEV_SNP))
 		return -ENODEV;
 
-	if (!dev->platform_data)
-		return -ENODEV;
-
-	data = (struct sev_guest_platform_data *)dev->platform_data;
-	mapping = ioremap_encrypted(data->secrets_gpa, PAGE_SIZE);
-	if (!mapping)
-		return -ENODEV;
-
-	secrets = (__force void *)mapping;
-
-	ret = -ENOMEM;
 	snp_dev = devm_kzalloc(&pdev->dev, sizeof(struct snp_guest_dev), GFP_KERNEL);
 	if (!snp_dev)
-		goto e_unmap;
-
-	mdesc = devm_kzalloc(&pdev->dev, sizeof(struct snp_msg_desc), GFP_KERNEL);
-	if (!mdesc)
-		goto e_unmap;
-
-	/* Adjust the default VMPCK key based on the executing VMPL level */
-	if (vmpck_id == -1)
-		vmpck_id = snp_vmpl;
+		return -ENOMEM;
 
-	ret = -EINVAL;
-	mdesc->vmpck = get_vmpck(vmpck_id, secrets, &mdesc->os_area_msg_seqno);
-	if (!mdesc->vmpck) {
-		dev_err(dev, "Invalid VMPCK%d communication key\n", vmpck_id);
-		goto e_unmap;
-	}
+	mdesc = snp_msg_alloc();
+	if (IS_ERR_OR_NULL(mdesc))
+		return -ENOMEM;
 
-	/* Verify that VMPCK is not zero. */
-	if (is_vmpck_empty(mdesc)) {
-		dev_err(dev, "Empty VMPCK%d communication key\n", vmpck_id);
-		goto e_unmap;
-	}
+	ret = snp_msg_init(mdesc, vmpck_id);
+	if (ret)
+		return -EIO;
 
 	platform_set_drvdata(pdev, snp_dev);
 	snp_dev->dev = dev;
-	mdesc->secrets = secrets;
-
-	/* Allocate the shared page used for the request and response message. */
-	mdesc->request = alloc_shared_pages(dev, sizeof(struct snp_guest_msg));
-	if (!mdesc->request)
-		goto e_unmap;
-
-	mdesc->response = alloc_shared_pages(dev, sizeof(struct snp_guest_msg));
-	if (!mdesc->response)
-		goto e_free_request;
-
-	mdesc->certs_data = alloc_shared_pages(dev, SEV_FW_BLOB_MAX_SIZE);
-	if (!mdesc->certs_data)
-		goto e_free_response;
-
-	ret = -EIO;
-	mdesc->ctx = snp_init_crypto(mdesc->vmpck, VMPCK_KEY_LEN);
-	if (!mdesc->ctx)
-		goto e_free_cert_data;
 
 	misc = &snp_dev->misc;
 	misc->minor = MISC_DYNAMIC_MINOR;
 	misc->name = DEVICE_NAME;
 	misc->fops = &snp_guest_fops;
 
-	/* Initialize the input addresses for guest request */
-	mdesc->input.req_gpa = __pa(mdesc->request);
-	mdesc->input.resp_gpa = __pa(mdesc->response);
-	mdesc->input.data_gpa = __pa(mdesc->certs_data);
-
 	/* Set the privlevel_floor attribute based on the vmpck_id */
-	sev_tsm_ops.privlevel_floor = vmpck_id;
+	sev_tsm_ops.privlevel_floor = mdesc->vmpck_id;
 
 	ret = tsm_register(&sev_tsm_ops, snp_dev);
 	if (ret)
-		goto e_free_cert_data;
+		goto e_msg_init;
 
 	ret = devm_add_action_or_reset(&pdev->dev, unregister_sev_tsm, NULL);
 	if (ret)
-		goto e_free_cert_data;
+		goto e_msg_init;
 
 	ret =  misc_register(misc);
 	if (ret)
-		goto e_free_ctx;
+		goto e_msg_init;
 
 	snp_dev->msg_desc = mdesc;
-	dev_info(dev, "Initialized SEV guest driver (using VMPCK%d communication key)\n", vmpck_id);
+	dev_info(dev, "Initialized SEV guest driver (using VMPCK%d communication key)\n",
+		 mdesc->vmpck_id);
 	return 0;
 
-e_free_ctx:
-	kfree(mdesc->ctx);
-e_free_cert_data:
-	free_shared_pages(mdesc->certs_data, SEV_FW_BLOB_MAX_SIZE);
-e_free_response:
-	free_shared_pages(mdesc->response, sizeof(struct snp_guest_msg));
-e_free_request:
-	free_shared_pages(mdesc->request, sizeof(struct snp_guest_msg));
-e_unmap:
-	iounmap(mapping);
+e_msg_init:
+	snp_msg_free(mdesc);
+
 	return ret;
 }
 
 static void __exit sev_guest_remove(struct platform_device *pdev)
 {
 	struct snp_guest_dev *snp_dev = platform_get_drvdata(pdev);
-	struct snp_msg_desc *mdesc = snp_dev->msg_desc;
 
-	free_shared_pages(mdesc->certs_data, SEV_FW_BLOB_MAX_SIZE);
-	free_shared_pages(mdesc->response, sizeof(struct snp_guest_msg));
-	free_shared_pages(mdesc->request, sizeof(struct snp_guest_msg));
-	kfree(mdesc->ctx);
+	snp_msg_free(snp_dev->msg_desc);
 	misc_deregister(&snp_dev->misc);
 }
 
diff --git a/arch/x86/Kconfig b/arch/x86/Kconfig
index 9d7bd0ae48c4..0f7e3acf37e3 100644
--- a/arch/x86/Kconfig
+++ b/arch/x86/Kconfig
@@ -1559,6 +1559,7 @@  config AMD_MEM_ENCRYPT
 	select ARCH_HAS_CC_PLATFORM
 	select X86_MEM_ENCRYPT
 	select UNACCEPTED_MEMORY
+	select CRYPTO_LIB_AESGCM
 	help
 	  Say yes to enable support for the encryption of system memory.
 	  This requires an AMD processor that supports Secure Memory
diff --git a/drivers/virt/coco/sev-guest/Kconfig b/drivers/virt/coco/sev-guest/Kconfig
index 0b772bd921d8..a6405ab6c2c3 100644
--- a/drivers/virt/coco/sev-guest/Kconfig
+++ b/drivers/virt/coco/sev-guest/Kconfig
@@ -2,7 +2,6 @@  config SEV_GUEST
 	tristate "AMD SEV Guest driver"
 	default m
 	depends on AMD_MEM_ENCRYPT
-	select CRYPTO_LIB_AESGCM
 	select TSM_REPORTS
 	help
 	  SEV-SNP firmware provides the guest a mechanism to communicate with