diff mbox series

[Part1,RFC,v4,15/36] x86/mm: Add support to validate memory when changing C-bit

Message ID 20210707181506.30489-16-brijesh.singh@amd.com (mailing list archive)
State New
Headers show
Series Add AMD Secure Nested Paging (SEV-SNP) Guest Support | expand

Commit Message

Brijesh Singh July 7, 2021, 6:14 p.m. UTC
The set_memory_{encrypted,decrypted}() functions are used to change pages
from decrypted (shared) to encrypted (private) and vice versa.
When SEV-SNP is active, the page state transition needs to go through
additional steps.

If the page is transitioned from shared to private, then perform the
following after the encryption attribute is set in the page table:

1. Issue the page state change VMGEXIT to add the memory region in
   the RMP table.
2. Validate the memory region after the RMP entry is added.

To maintain the security guarantees, if the page is transitioned from
private to shared, then perform the following before the encryption attribute
is removed from the page table:

1. Invalidate the page.
2. Issue the page state change VMGEXIT to remove the page from the RMP table.

To change the page state in the RMP table, use the Page State Change
VMGEXIT defined in the GHCB specification.

Signed-off-by: Brijesh Singh <brijesh.singh@amd.com>
---
 arch/x86/include/asm/sev-common.h |  24 +++++
 arch/x86/include/asm/sev.h        |   4 +
 arch/x86/include/uapi/asm/svm.h   |   2 +
 arch/x86/kernel/sev.c             | 160 ++++++++++++++++++++++++++++++
 arch/x86/mm/pat/set_memory.c      |  15 +++
 5 files changed, 205 insertions(+)

Comments

Borislav Petkov Aug. 17, 2021, 5:27 p.m. UTC | #1
On Wed, Jul 07, 2021 at 01:14:45PM -0500, Brijesh Singh wrote:
> +struct __packed psc_hdr {
> +	u16 cur_entry;
> +	u16 end_entry;
> +	u32 reserved;
> +};
> +
> +struct __packed psc_entry {
> +	u64	cur_page	: 12,
> +		gfn		: 40,
> +		operation	: 4,
> +		pagesize	: 1,
> +		reserved	: 7;
> +};
> +
> +struct __packed snp_psc_desc {
> +	struct psc_hdr hdr;
> +	struct psc_entry entries[VMGEXIT_PSC_MAX_ENTRY];
> +};

The majority of kernel code puts __packed after the struct definition,
let's put it there too pls, out of the way.

...

> +static int vmgexit_psc(struct snp_psc_desc *desc)
> +{
> +	int cur_entry, end_entry, ret;
> +	struct snp_psc_desc *data;
> +	struct ghcb_state state;
> +	struct ghcb *ghcb;
> +	struct psc_hdr *hdr;
> +	unsigned long flags;
> +
> +	local_irq_save(flags);
> +
> +	ghcb = __sev_get_ghcb(&state);
> +	if (unlikely(!ghcb))
> +		panic("SEV-SNP: Failed to get GHCB\n");
> +
> +	/* Copy the input desc into GHCB shared buffer */
> +	data = (struct snp_psc_desc *)ghcb->shared_buffer;
> +	memcpy(ghcb->shared_buffer, desc, sizeof(*desc));
> +
> +	hdr = &data->hdr;
> +	cur_entry = hdr->cur_entry;
> +	end_entry = hdr->end_entry;
> +
> +	/*
> +	 * As per the GHCB specification, the hypervisor can resume the guest
> +	 * before processing all the entries. Checks whether all the entries
> +	 * are processed. If not, then keep retrying.
> +	 *
> +	 * The stragtegy here is to wait for the hypervisor to change the page
> +	 * state in the RMP table before guest access the memory pages. If the
> +	 * page state was not successful, then later memory access will result
> +	 * in the crash.
> +	 */
> +	while (hdr->cur_entry <= hdr->end_entry) {
> +		ghcb_set_sw_scratch(ghcb, (u64)__pa(data));
> +
> +		ret = sev_es_ghcb_hv_call(ghcb, NULL, SVM_VMGEXIT_PSC, 0, 0);
> +
> +		/*
> +		 * Page State Change VMGEXIT can pass error code through
> +		 * exit_info_2.
> +		 */
> +		if (WARN(ret || ghcb->save.sw_exit_info_2,
> +			 "SEV-SNP: page state change failed ret=%d exit_info_2=%llx\n",
> +			 ret, ghcb->save.sw_exit_info_2))
> +			return 1;

Yikes, you return here and below with interrupts disabled.

All your returns need to be "goto out;" instead where you do

out:
        __sev_put_ghcb(&state);
        local_irq_restore(flags);

Yap, you very likely need to put the GHCB too.

> +		/*
> +		 * Lets do some sanity check that entry processing is not going
> +		 * backward. This will happen only if hypervisor is tricking us.
> +		 */
> +		if (WARN((hdr->end_entry > end_entry) || (cur_entry > hdr->cur_entry),
> +			"SEV-SNP: page state change processing going backward, end_entry "
> +			"(expected %d got %d) cur_entry (expected %d got %d)\n",
> +			end_entry, hdr->end_entry, cur_entry, hdr->cur_entry))
> +			return 1;

WARNING: quoted string split across lines
#293: FILE: arch/x86/kernel/sev.c:750:
+			"SEV-SNP: page state change processing going backward, end_entry "
+			"(expected %d got %d) cur_entry (expected %d got %d)\n",

If you're wondering what to do, yes, you can really stretch that string
and shorten it too:

                if (WARN((hdr->end_entry > end_entry) || (cur_entry > hdr->cur_entry),
"SEV-SNP: PSC processing going backwards, end_entry %d (got %d) cur_entry: %d (got %d)\n",
                         end_entry, hdr->end_entry, cur_entry, hdr->cur_entry))
                        return 1;

so that it fits on a single line and grepping can find it.

> +		/* Lets verify that reserved bit is not set in the header*/
> +		if (WARN(hdr->reserved, "Reserved bit is set in the PSC header\n"))

psc_entry has a ->reserved field too and since we're iterating over the
entries...

> +			return 1;
> +	}
> +
> +	__sev_put_ghcb(&state);
> +	local_irq_restore(flags);
> +
> +	return 0;
> +}
> +
> +static void __set_page_state(struct snp_psc_desc *data, unsigned long vaddr,
> +			     unsigned long vaddr_end, int op)
> +{
> +	struct psc_hdr *hdr;
> +	struct psc_entry *e;
> +	unsigned long pfn;
> +	int i;
> +
> +	hdr = &data->hdr;
> +	e = data->entries;
> +
> +	memset(data, 0, sizeof(*data));
> +	i = 0;
> +
> +	while (vaddr < vaddr_end) {
> +		if (is_vmalloc_addr((void *)vaddr))
> +			pfn = vmalloc_to_pfn((void *)vaddr);
> +		else
> +			pfn = __pa(vaddr) >> PAGE_SHIFT;
> +
> +		e->gfn = pfn;
> +		e->operation = op;
> +		hdr->end_entry = i;
> +
> +		/*
> +		 * The GHCB specification provides the flexibility to
> +		 * use either 4K or 2MB page size in the RMP table.
> +		 * The current SNP support does not keep track of the
> +		 * page size used in the RMP table. To avoid the
> +		 * overlap request, use the 4K page size in the RMP
> +		 * table.
> +		 */
> +		e->pagesize = RMP_PG_SIZE_4K;
> +
> +		vaddr = vaddr + PAGE_SIZE;
> +		e++;
> +		i++;
> +	}
> +
> +	/* Terminate the guest on page state change failure. */

That comment is kinda obvious :)

> +	if (vmgexit_psc(data))
> +		sev_es_terminate(1, GHCB_TERM_PSC);
> +}
> +
> +static void set_page_state(unsigned long vaddr, unsigned int npages, int op)
> +{
> +	unsigned long vaddr_end, next_vaddr;
> +	struct snp_psc_desc *desc;
> +
> +	vaddr = vaddr & PAGE_MASK;
> +	vaddr_end = vaddr + (npages << PAGE_SHIFT);
> +
> +	desc = kmalloc(sizeof(*desc), GFP_KERNEL_ACCOUNT);

kzalloc() so that you don't have to memset() later in
__set_page_state().

> +	if (!desc)
> +		panic("failed to allocate memory");

Make that error message more distinctive so that *if* it happens, one
can pinpoint the place in the code where the panic comes from.

> +	while (vaddr < vaddr_end) {
> +		/*
> +		 * Calculate the last vaddr that can be fit in one
> +		 * struct snp_psc_desc.
> +		 */
> +		next_vaddr = min_t(unsigned long, vaddr_end,
> +				(VMGEXIT_PSC_MAX_ENTRY * PAGE_SIZE) + vaddr);
> +
> +		__set_page_state(desc, vaddr, next_vaddr, op);
> +
> +		vaddr = next_vaddr;
> +	}
> +
> +	kfree(desc);
> +}
> +
Brijesh Singh Aug. 17, 2021, 6:07 p.m. UTC | #2
On 8/17/21 12:27 PM, Borislav Petkov wrote:
> 
> The majority of kernel code puts __packed after the struct definition,
> let's put it there too pls, out of the way.
> 
> ...

Noted.

>> +		if (WARN(ret || ghcb->save.sw_exit_info_2,
>> +			 "SEV-SNP: page state change failed ret=%d exit_info_2=%llx\n",
>> +			 ret, ghcb->save.sw_exit_info_2))
>> +			return 1;
> 
> Yikes, you return here and below with interrupts disabled.
> 
> All your returns need to be "goto out;" instead where you do
> 
> out:
>          __sev_put_ghcb(&state);
>          local_irq_restore(flags);
> 
> Yap, you very likely need to put the GHCB too.
> 

Sure, let me revisit this code to fix those paths.

>> +		/*
>> +		 * Lets do some sanity check that entry processing is not going
>> +		 * backward. This will happen only if hypervisor is tricking us.
>> +		 */
>> +		if (WARN((hdr->end_entry > end_entry) || (cur_entry > hdr->cur_entry),
>> +			"SEV-SNP: page state change processing going backward, end_entry "
>> +			"(expected %d got %d) cur_entry (expected %d got %d)\n",
>> +			end_entry, hdr->end_entry, cur_entry, hdr->cur_entry))
>> +			return 1;
> 
> WARNING: quoted string split across lines
> #293: FILE: arch/x86/kernel/sev.c:750:
> +			"SEV-SNP: page state change processing going backward, end_entry "
> +			"(expected %d got %d) cur_entry (expected %d got %d)\n",
> 
> If you're wondering what to do, yes, you can really stretch that string
> and shorten it too:

Okay.

> 
>                  if (WARN((hdr->end_entry > end_entry) || (cur_entry > hdr->cur_entry),
> "SEV-SNP: PSC processing going backwards, end_entry %d (got %d) cur_entry: %d (got %d)\n",
>                           end_entry, hdr->end_entry, cur_entry, hdr->cur_entry))
>                          return 1;
> 
> so that it fits on a single line and grepping can find it.
> 
Noted.

>> +		/* Lets verify that reserved bit is not set in the header*/
>> +		if (WARN(hdr->reserved, "Reserved bit is set in the PSC header\n"))
> 
> psc_entry has a ->reserved field too and since we're iterating over the
> entries...
> 
Sure I can add that check.


>> +
>> +	desc = kmalloc(sizeof(*desc), GFP_KERNEL_ACCOUNT);
> 
> kzalloc() so that you don't have to memset() later in
> __set_page_state().

Depending on the size, __set_page_state() can be called multiple times, 
so it should clear the desc memory before filling it.

> 
>> +	if (!desc)
>> +		panic("failed to allocate memory");
> 
> Make that error message more distinctive so that *if* it happens, one
> can pinpoint the place in the code where the panic comes from.
> 

Now I am running checkpatch and notice that it complains about the 
message too. I can add a BUG() or WARN() to get the stack trace before 
the crash.

>> +	while (vaddr < vaddr_end) {
>> +		/*
>> +		 * Calculate the last vaddr that can be fit in one
>> +		 * struct snp_psc_desc.
>> +		 */
>> +		next_vaddr = min_t(unsigned long, vaddr_end,
>> +				(VMGEXIT_PSC_MAX_ENTRY * PAGE_SIZE) + vaddr);
>> +
>> +		__set_page_state(desc, vaddr, next_vaddr, op);
>> +
>> +		vaddr = next_vaddr;
>> +	}
>> +
>> +	kfree(desc);
>> +}
>> +
>
Borislav Petkov Aug. 17, 2021, 6:17 p.m. UTC | #3
On Tue, Aug 17, 2021 at 01:07:40PM -0500, Brijesh Singh wrote:
> > > +	if (!desc)
> > > +		panic("failed to allocate memory");
> > 
> > Make that error message more distinctive so that *if* it happens, one
> > can pinpoint the place in the code where the panic comes from.
> > 
> 
> Now I am running checkpatch and notice that it complain about the message
> too. I can add a BUG() or WARN() to get the stack trace before the crashing.

checkpatch complains because there's a kmalloc before it and if it
fails, the mm core will issue a warning so there's no need for a warning
here.

But in this case, you want to panic and checkpatch doesn't see that so
you can ignore it here and leave the panic message but make it more
distinctive so one can find it by grepping. IOW, something like

	if (!desc)
		panic("SEV-SNP: Failed to allocate memory for PSC descriptor");

Thx.
Brijesh Singh Aug. 17, 2021, 6:18 p.m. UTC | #4
On 8/17/21 1:17 PM, Borislav Petkov wrote:
> On Tue, Aug 17, 2021 at 01:07:40PM -0500, Brijesh Singh wrote:
>>>> +	if (!desc)
>>>> +		panic("failed to allocate memory");
>>>
>>> Make that error message more distinctive so that *if* it happens, one
>>> can pinpoint the place in the code where the panic comes from.
>>>
>>
>> Now I am running checkpatch and notice that it complain about the message
>> too. I can add a BUG() or WARN() to get the stack trace before the crashing.
> 
> checkpatch complains because there's a kmalloc before it and if it
> fails, the mm core will issue a warning so there's no need for a warning
> here.
> 
> But in this case, you want to panic and checkpatch doesn't see that so
> you can ignore it here and leave the panic message but make it more
> distinctive so one can find it by grepping. IOW, something like
> 
> 	if (!desc)
> 		panic("SEV-SNP: Failed to allocate memory for PSC descriptor");
> 

Got it, I will update the message accordingly.

thanks
Brijesh Singh Aug. 17, 2021, 8:34 p.m. UTC | #5
Hi Boris,


On 8/17/21 12:27 PM, Borislav Petkov wrote:

> 
>> +		/* Lets verify that reserved bit is not set in the header*/
>> +		if (WARN(hdr->reserved, "Reserved bit is set in the PSC header\n"))
> 
> psc_entry has a ->reserved field too and since we're iterating over the
> entries...
> 


I am not seeing any strong reason to sanity check the reserved bit in 
the psc_entry. The fields in the psc_entry are input from the guest to the 
hypervisor. The hypervisor cannot trick the guest by changing anything in 
the psc_entry because the guest does not read the hypervisor-filled value. I 
am okay with the psc_hdr check because we need to read the current and end 
entry after the PSC completes to determine whether it was successful, so 
sanity checking the PSC header makes much more sense. Let me know if you are 
okay with it?

thanks
Borislav Petkov Aug. 17, 2021, 8:44 p.m. UTC | #6
On Tue, Aug 17, 2021 at 03:34:41PM -0500, Brijesh Singh wrote:
> I am not seeing any strong reason to sanity check the reserved bit in the
> psc_entry. The fields in the psc_entry are input from guest to the
> hypervisor. The hypervisor cannot trick a guest by changing anything in the
> psc_entry because guest does not read the hypervisor filled value. I am okay
> with the psc_hdr because we need to read the current and end entry after the
> PSC completes to determine whether it was successful and sanity checking PSC
> header makes much more sense. Let me know if you are okay with it ?

Ok, fair enough.

Thx.
diff mbox series

Patch

diff --git a/arch/x86/include/asm/sev-common.h b/arch/x86/include/asm/sev-common.h
index b19d8d301f5d..2277c8085b13 100644
--- a/arch/x86/include/asm/sev-common.h
+++ b/arch/x86/include/asm/sev-common.h
@@ -60,6 +60,8 @@ 
 #define GHCB_MSR_PSC_REQ		0x014
 #define SNP_PAGE_STATE_PRIVATE		1
 #define SNP_PAGE_STATE_SHARED		2
+#define SNP_PAGE_STATE_PSMASH		3
+#define SNP_PAGE_STATE_UNSMASH		4
 #define GHCB_MSR_PSC_GFN_POS		12
 #define GHCB_MSR_PSC_GFN_MASK		GENMASK_ULL(39, 0)
 #define GHCB_MSR_PSC_OP_POS		52
@@ -84,6 +86,28 @@ 
 
 #define GHCB_HV_FT_SNP			BIT_ULL(0)
 
+/* SNP Page State Change NAE event (descriptor is copied into GHCB buffer) */
+#define VMGEXIT_PSC_MAX_ENTRY		253
+
+struct psc_hdr {
+	u16 cur_entry;		/* next entry to be processed */
+	u16 end_entry;		/* last valid entry, inclusive */
+	u32 reserved;
+} __packed;
+
+struct psc_entry {
+	u64	cur_page	: 12,
+		gfn		: 40,
+		operation	: 4,	/* SNP_PAGE_STATE_* */
+		pagesize	: 1,	/* RMP_PG_SIZE_* */
+		reserved	: 7;
+} __packed;
+
+struct snp_psc_desc {
+	struct psc_hdr hdr;
+	struct psc_entry entries[VMGEXIT_PSC_MAX_ENTRY];
+} __packed;
+
 #define GHCB_MSR_TERM_REQ		0x100
 #define GHCB_MSR_TERM_REASON_SET_POS	12
 #define GHCB_MSR_TERM_REASON_SET_MASK	0xf
diff --git a/arch/x86/include/asm/sev.h b/arch/x86/include/asm/sev.h
index 9a676fb0929d..2385651c810e 100644
--- a/arch/x86/include/asm/sev.h
+++ b/arch/x86/include/asm/sev.h
@@ -109,6 +109,8 @@  void __init early_snp_set_memory_private(unsigned long vaddr, unsigned long padd
 void __init early_snp_set_memory_shared(unsigned long vaddr, unsigned long paddr,
 					unsigned int npages);
 void __init snp_prep_memory(unsigned long paddr, unsigned int sz, int op);
+void snp_set_memory_shared(unsigned long vaddr, unsigned int npages);
+void snp_set_memory_private(unsigned long vaddr, unsigned int npages);
 #else
 static inline void sev_es_ist_enter(struct pt_regs *regs) { }
 static inline void sev_es_ist_exit(void) { }
@@ -121,6 +123,8 @@  early_snp_set_memory_private(unsigned long vaddr, unsigned long paddr, unsigned
 static inline void __init
 early_snp_set_memory_shared(unsigned long vaddr, unsigned long paddr, unsigned int npages) { }
 static inline void __init snp_prep_memory(unsigned long paddr, unsigned int sz, int op) { }
+static inline void snp_set_memory_shared(unsigned long vaddr, unsigned int npages) { }
+static inline void snp_set_memory_private(unsigned long vaddr, unsigned int npages) { }
 #endif
 
 #endif
diff --git a/arch/x86/include/uapi/asm/svm.h b/arch/x86/include/uapi/asm/svm.h
index 7fbc311e2de1..f7f65febff70 100644
--- a/arch/x86/include/uapi/asm/svm.h
+++ b/arch/x86/include/uapi/asm/svm.h
@@ -108,6 +108,7 @@ 
 #define SVM_VMGEXIT_AP_JUMP_TABLE		0x80000005
 #define SVM_VMGEXIT_SET_AP_JUMP_TABLE		0
 #define SVM_VMGEXIT_GET_AP_JUMP_TABLE		1
+#define SVM_VMGEXIT_PSC				0x80000010
 #define SVM_VMGEXIT_HYPERVISOR_FEATURES		0x8000fffd
 #define SVM_VMGEXIT_UNSUPPORTED_EVENT		0x8000ffff
 
@@ -216,6 +217,7 @@ 
 	{ SVM_VMGEXIT_NMI_COMPLETE,	"vmgexit_nmi_complete" }, \
 	{ SVM_VMGEXIT_AP_HLT_LOOP,	"vmgexit_ap_hlt_loop" }, \
 	{ SVM_VMGEXIT_AP_JUMP_TABLE,	"vmgexit_ap_jump_table" }, \
+	{ SVM_VMGEXIT_PSC,	"vmgexit_page_state_change" }, \
 	{ SVM_VMGEXIT_HYPERVISOR_FEATURES,	"vmgexit_hypervisor_feature" }, \
 	{ SVM_EXIT_ERR,         "invalid_guest_state" }
 
diff --git a/arch/x86/kernel/sev.c b/arch/x86/kernel/sev.c
index 62034879fb3f..5fef7fc46282 100644
--- a/arch/x86/kernel/sev.c
+++ b/arch/x86/kernel/sev.c
@@ -694,6 +694,166 @@  void __init snp_prep_memory(unsigned long paddr, unsigned int sz, int op)
 		WARN(1, "invalid memory op %d\n", op);
 }
 
+static int vmgexit_psc(struct snp_psc_desc *desc)
+{
+	int cur_entry, end_entry, ret = 0;
+	struct snp_psc_desc *data;
+	struct es_em_ctxt ctxt;
+	struct ghcb_state state;
+	struct ghcb *ghcb;
+	struct psc_hdr *hdr;
+	unsigned long flags;
+
+	/* Interrupts stay disabled for as long as the GHCB is held. */
+	local_irq_save(flags);
+
+	ghcb = __sev_get_ghcb(&state);
+	if (unlikely(!ghcb))
+		panic("SEV-SNP: Failed to get GHCB\n");
+
+	/* Copy the input desc into GHCB shared buffer */
+	data = (struct snp_psc_desc *)ghcb->shared_buffer;
+	memcpy(ghcb->shared_buffer, desc, sizeof(*desc));
+
+	hdr = &data->hdr;
+	cur_entry = hdr->cur_entry;
+	end_entry = hdr->end_entry;
+
+	/*
+	 * As per the GHCB specification, the hypervisor can resume the guest
+	 * before processing all the entries, so keep retrying until it has.
+	 *
+	 * Every failure below sets ret = 1 and leaves through "out", so the
+	 * GHCB is always put back and interrupts are always restored.
+	 */
+	while (hdr->cur_entry <= hdr->end_entry) {
+		ghcb_set_sw_scratch(ghcb, (u64)__pa(data));
+
+		/* Don't pass a NULL ctxt: the call may store exception info in it. */
+		ret = sev_es_ghcb_hv_call(ghcb, &ctxt, SVM_VMGEXIT_PSC, 0, 0);
+
+		/* The PSC VMGEXIT can also report an error through exit_info_2. */
+		if (WARN(ret || ghcb->save.sw_exit_info_2,
+			 "SEV-SNP: page state change failed ret=%d exit_info_2=%llx\n",
+			 ret, ghcb->save.sw_exit_info_2)) {
+			ret = 1;
+			goto out;
+		}
+
+		/* Processing going backward means the hypervisor is tricking us. */
+		if (WARN((hdr->end_entry > end_entry) || (cur_entry > hdr->cur_entry),
+"SEV-SNP: PSC processing going backwards, end_entry %d (got %d) cur_entry %d (got %d)\n",
+			 end_entry, hdr->end_entry, cur_entry, hdr->cur_entry)) {
+			ret = 1;
+			goto out;
+		}
+
+		/* Verify that the reserved field is not set in the header. */
+		if (WARN(hdr->reserved, "Reserved bit is set in the PSC header\n")) {
+			ret = 1;
+			goto out;
+		}
+	}
+
+out:
+	__sev_put_ghcb(&state);
+	local_irq_restore(flags);
+
+	return ret;
+}
+
+/*
+ * Fill @data with one PSC entry per 4K page in [@vaddr, @vaddr_end) and
+ * hand it to the hypervisor; the guest is terminated if that fails.
+ * The range must fit in VMGEXIT_PSC_MAX_ENTRY entries.
+ */
+static void __set_page_state(struct snp_psc_desc *data, unsigned long vaddr,
+			     unsigned long vaddr_end, int op)
+{
+	struct psc_hdr *hdr = &data->hdr;
+	struct psc_entry *entry;
+	unsigned long pfn;
+	int idx;
+
+	/* The caller reuses one descriptor for every batch: clear it. */
+	memset(data, 0, sizeof(*data));
+
+	for (idx = 0; vaddr < vaddr_end; idx++, vaddr += PAGE_SIZE) {
+		entry = &data->entries[idx];
+
+		if (is_vmalloc_addr((void *)vaddr))
+			pfn = vmalloc_to_pfn((void *)vaddr);
+		else
+			pfn = __pa(vaddr) >> PAGE_SHIFT;
+
+		entry->gfn = pfn;
+		entry->operation = op;
+
+		/*
+		 * The GHCB specification allows either 4K or 2MB page size in
+		 * the RMP table. The current SNP support does not track which
+		 * size the RMP uses, so always request 4K to avoid overlapping
+		 * requests.
+		 */
+		entry->pagesize = RMP_PG_SIZE_4K;
+
+		/* end_entry is inclusive: index of the last filled entry. */
+		hdr->end_entry = idx;
+	}
+
+	/* Continuing after a failed PSC would crash on later accesses. */
+	if (vmgexit_psc(data))
+		sev_es_terminate(1, GHCB_TERM_PSC);
+}
+
+/* Ask the hypervisor to move @npages pages at @vaddr to page state @op. */
+static void set_page_state(unsigned long vaddr, unsigned int npages, int op)
+{
+	unsigned long vaddr_end, next_vaddr;
+	struct snp_psc_desc *desc;
+
+	vaddr = vaddr & PAGE_MASK;
+	/* Widen first: a 32-bit npages << PAGE_SHIFT wraps for ranges >= 4GB. */
+	vaddr_end = vaddr + ((unsigned long)npages << PAGE_SHIFT);
+
+	desc = kmalloc(sizeof(*desc), GFP_KERNEL_ACCOUNT);
+	if (!desc)
+		panic("SEV-SNP: Failed to allocate memory for PSC descriptor\n");
+
+	/* Split the range into batches that fit in one struct snp_psc_desc. */
+	while (vaddr < vaddr_end) {
+		next_vaddr = min_t(unsigned long, vaddr_end,
+				   (VMGEXIT_PSC_MAX_ENTRY * PAGE_SIZE) + vaddr);
+
+		/* __set_page_state() clears desc before filling each batch. */
+		__set_page_state(desc, vaddr, next_vaddr, op);
+
+		vaddr = next_vaddr;
+	}
+
+	kfree(desc);
+}
+
+void snp_set_memory_shared(unsigned long vaddr, unsigned int npages)
+{
+	/* Nothing to do unless SEV-SNP is active. */
+	if (sev_feature_enabled(SEV_SNP)) {
+		/* Invalidate first, then make the pages shared in the RMP. */
+		pvalidate_pages(vaddr, npages, 0);
+		set_page_state(vaddr, npages, SNP_PAGE_STATE_SHARED);
+	}
+}
+
+void snp_set_memory_private(unsigned long vaddr, unsigned int npages)
+{
+	/* Nothing to do unless SEV-SNP is active. */
+	if (sev_feature_enabled(SEV_SNP)) {
+		/* Make the pages private in the RMP first, then validate them. */
+		set_page_state(vaddr, npages, SNP_PAGE_STATE_PRIVATE);
+		pvalidate_pages(vaddr, npages, 1);
+	}
+}
+
 int sev_es_setup_ap_jump_table(struct real_mode_header *rmh)
 {
 	u16 startup_cs, startup_ip;
diff --git a/arch/x86/mm/pat/set_memory.c b/arch/x86/mm/pat/set_memory.c
index 156cd235659f..d09df2971d30 100644
--- a/arch/x86/mm/pat/set_memory.c
+++ b/arch/x86/mm/pat/set_memory.c
@@ -29,6 +29,7 @@ 
 #include <asm/proto.h>
 #include <asm/memtype.h>
 #include <asm/set_memory.h>
+#include <asm/sev.h>
 
 #include "../mm_internal.h"
 
@@ -2009,8 +2010,22 @@  static int __set_memory_enc_dec(unsigned long addr, int numpages, bool enc)
 	 */
 	cpa_flush(&cpa, !this_cpu_has(X86_FEATURE_SME_COHERENT));
 
+	/*
+	 * To maintain the security guarantees of an SEV-SNP guest, invalidate
+	 * the memory before clearing the encryption attribute.
+	 */
+	if (!enc)
+		snp_set_memory_shared(addr, numpages);
+
 	ret = __change_page_attr_set_clr(&cpa, 1);
 
+	/*
+	 * Now that memory is mapped encrypted in the page table, validate it
+	 * so that it is consistent with the above page state.
+	 */
+	if (!ret && enc)
+		snp_set_memory_private(addr, numpages);
+
 	/*
 	 * After changing the encryption attribute, we need to flush TLBs again
 	 * in case any speculative TLB caching occurred (but no need to flush