
[PATCHv3,4/8] x86/mm: Handle LAM on context switch

Message ID 20220610143527.22974-5-kirill.shutemov@linux.intel.com (mailing list archive)
State New
Series Linear Address Masking enabling

Commit Message

Kirill A. Shutemov June 10, 2022, 2:35 p.m. UTC
Linear Address Masking mode for userspace pointers is encoded in CR3 bits.
The mode is selected per-thread. Add new thread features to indicate that
the thread has Linear Address Masking enabled.

switch_mm_irqs_off() now respects these flags and constructs CR3
accordingly.

The active LAM mode gets recorded in the tlb_state.

Signed-off-by: Kirill A. Shutemov <kirill.shutemov@linux.intel.com>
---
 arch/x86/include/asm/mmu.h         |  1 +
 arch/x86/include/asm/mmu_context.h | 24 ++++++++++++
 arch/x86/include/asm/tlbflush.h    |  3 ++
 arch/x86/mm/tlb.c                  | 62 ++++++++++++++++++++++--------
 4 files changed, 75 insertions(+), 15 deletions(-)
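
For reference, a minimal standalone sketch of the CR3 composition this patch introduces (the bit positions are an assumption based on the LAM spec, not taken from the patch: LAM_U57 is CR3 bit 61, LAM_U48 is CR3 bit 62; names prefixed with example_ are illustrative):

#include <stdint.h>
#include <stdio.h>

#define X86_CR3_LAM_U57	(1ULL << 61)	/* assumed: LAM for 57-bit pointers */
#define X86_CR3_LAM_U48	(1ULL << 62)	/* assumed: LAM for 48-bit pointers */

/* Mirrors build_cr3() in the patch: PGD physical address | PCID | LAM bits. */
static uint64_t example_build_cr3(uint64_t pgd_pa, uint16_t pcid, uint64_t lam)
{
	return pgd_pa | pcid | lam;
}

int main(void)
{
	/* Hypothetical values, just to show where the LAM bits end up. */
	uint64_t cr3 = example_build_cr3(0x12345000ULL, 1, X86_CR3_LAM_U57);

	printf("cr3 = %#llx\n", (unsigned long long)cr3);
	return 0;
}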

Comments

Edgecombe, Rick P June 10, 2022, 11:55 p.m. UTC | #1
On Fri, 2022-06-10 at 17:35 +0300, Kirill A. Shutemov wrote:
> @@ -687,6 +716,7 @@ void initialize_tlbstate_and_flush(void)
>         struct mm_struct *mm = this_cpu_read(cpu_tlbstate.loaded_mm);
>         u64 tlb_gen = atomic64_read(&init_mm.context.tlb_gen);
>         unsigned long cr3 = __read_cr3();
> +       u64 lam = cr3 & (X86_CR3_LAM_U48 | X86_CR3_LAM_U57);
>  
>         /* Assert that CR3 already references the right mm. */
>         WARN_ON((cr3 & CR3_ADDR_MASK) != __pa(mm->pgd));
> @@ -700,7 +730,7 @@ void initialize_tlbstate_and_flush(void)
>                 !(cr4_read_shadow() & X86_CR4_PCIDE));
>  
>         /* Force ASID 0 and force a TLB flush. */
> -       write_cr3(build_cr3(mm->pgd, 0));
> +       write_cr3(build_cr3(mm->pgd, 0, lam));
>  

Can you explain why the LAM bits that were in CR3 are kept here? It
seems the worry is that some CR3 bits got changed and need to be set
to a known state. Why not take them from the mm?

Also, it warns if the CR3 pfn doesn't match the mm pgd; should it warn
if the CR3 LAM bits don't match the mm's copy?
Kirill A. Shutemov June 15, 2022, 3:54 p.m. UTC | #2
On Fri, Jun 10, 2022 at 11:55:02PM +0000, Edgecombe, Rick P wrote:
> On Fri, 2022-06-10 at 17:35 +0300, Kirill A. Shutemov wrote:
> > @@ -687,6 +716,7 @@ void initialize_tlbstate_and_flush(void)
> >         struct mm_struct *mm = this_cpu_read(cpu_tlbstate.loaded_mm);
> >         u64 tlb_gen = atomic64_read(&init_mm.context.tlb_gen);
> >         unsigned long cr3 = __read_cr3();
> > +       u64 lam = cr3 & (X86_CR3_LAM_U48 | X86_CR3_LAM_U57);
> >  
> >         /* Assert that CR3 already references the right mm. */
> >         WARN_ON((cr3 & CR3_ADDR_MASK) != __pa(mm->pgd));
> > @@ -700,7 +730,7 @@ void initialize_tlbstate_and_flush(void)
> >                 !(cr4_read_shadow() & X86_CR4_PCIDE));
> >  
> >         /* Force ASID 0 and force a TLB flush. */
> > -       write_cr3(build_cr3(mm->pgd, 0));
> > +       write_cr3(build_cr3(mm->pgd, 0, lam));
> >  
> 
> Can you explain why to keep the lam bits that were in CR3 here? It
> seems to be worried some CR3 bits got changed and need to be set to a
> known state. Why not take them from the MM?
> 
> Also, it warns if the cr3 pfn doesn't match the mm pgd, should it warn
> if cr3 lam bits don't match the MM's copy?

You are right, taking the LAM mode from init_mm is more correct. And we
need to update tlbstate with the new LAM mode.

I think both CR3 and init_mm should have LAM disabled here, as we are
bringing the CPU up. I'll add WARN_ON().
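
A rough sketch (not the actual follow-up patch) of how the hunk above could look with the LAM mode taken from init_mm, the WARN_ON()s added, and tlbstate updated; it only uses helpers that already exist or are introduced by this series:

void initialize_tlbstate_and_flush(void)
{
	struct mm_struct *mm = this_cpu_read(cpu_tlbstate.loaded_mm);
	unsigned long cr3 = __read_cr3();
	u64 lam = mm_cr3_lam_mask(&init_mm);

	/* Assert that CR3 already references the right mm. */
	WARN_ON((cr3 & CR3_ADDR_MASK) != __pa(mm->pgd));

	/* LAM is expected to be disabled while the CPU is being brought up. */
	WARN_ON(cr3 & (X86_CR3_LAM_U48 | X86_CR3_LAM_U57));
	WARN_ON(lam);

	/* ... tlb_gen and PCID checks unchanged ... */

	/* Force ASID 0 and force a TLB flush. */
	write_cr3(build_cr3(mm->pgd, 0, lam));

	/* Record the active LAM mode in tlbstate. */
	set_tlbstate_lam_cr3_mask(lam);

	/* ... rest of the function unchanged ... */
}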
Peter Zijlstra June 16, 2022, 9:08 a.m. UTC | #3
On Fri, Jun 10, 2022 at 05:35:23PM +0300, Kirill A. Shutemov wrote:

> diff --git a/arch/x86/include/asm/tlbflush.h b/arch/x86/include/asm/tlbflush.h
> index 4af5579c7ef7..5b93dad93ff4 100644
> --- a/arch/x86/include/asm/tlbflush.h
> +++ b/arch/x86/include/asm/tlbflush.h
> @@ -86,6 +86,9 @@ struct tlb_state {
>  		unsigned long		last_user_mm_spec;
>  	};
>  
> +#ifdef CONFIG_X86_64
> +	u64 lam_cr3_mask;
> +#endif
>  	u16 loaded_mm_asid;
>  	u16 next_asid;
>  

Urgh.. so there's a comment there that states:

/*
 * 6 because 6 should be plenty and struct tlb_state will fit in two cache
 * lines.
 */
#define TLB_NR_DYN_ASIDS        6

And then look at tlb_state:

struct tlb_state {
	struct mm_struct *         loaded_mm;            /*     0     8 */
	union {
		struct mm_struct * last_user_mm;         /*     8     8 */
		long unsigned int  last_user_mm_spec;    /*     8     8 */
	};                                               /*     8     8 */
	u16                        loaded_mm_asid;       /*    16     2 */
	u16                        next_asid;            /*    18     2 */
	bool                       invalidate_other;     /*    20     1 */

	/* XXX 1 byte hole, try to pack */

	short unsigned int         user_pcid_flush_mask; /*    22     2 */
	long unsigned int          cr4;                  /*    24     8 */
	struct tlb_context         ctxs[6];              /*    32    96 */

	/* size: 128, cachelines: 2, members: 8 */
	/* sum members: 127, holes: 1, sum holes: 1 */
};

If you add that u64 as you do, you'll wreck all that.

Either use that one spare byte, or find room elsewhere I suppose.
Kirill A. Shutemov June 16, 2022, 4:40 p.m. UTC | #4
On Thu, Jun 16, 2022 at 11:08:07AM +0200, Peter Zijlstra wrote:
> Either use that one spare byte, or find room elsewhere I suppose.

Okay, I will put it into the byte after invalidate_other and modify
tlbstate_lam_cr3_mask() and set_tlbstate_lam_cr3_mask() to shift it
accordingly.

It looks like this:

struct tlb_state {
	struct mm_struct *         loaded_mm;            /*     0     8 */
	union {
		struct mm_struct * last_user_mm;         /*     8     8 */
		unsigned long      last_user_mm_spec;    /*     8     8 */
	};                                               /*     8     8 */
	u16                        loaded_mm_asid;       /*    16     2 */
	u16                        next_asid;            /*    18     2 */
	bool                       invalidate_other;     /*    20     1 */
	u8                         lam;                  /*    21     1 */
	unsigned short             user_pcid_flush_mask; /*    22     2 */
	unsigned long              cr4;                  /*    24     8 */
	struct tlb_context         ctxs[6];              /*    32    96 */

	/* size: 128, cachelines: 2, members: 9 */
};
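
A standalone model of that packing (bit positions are assumptions: LAM_U57 is CR3 bit 61, LAM_U48 is CR3 bit 62; the pack/unpack helper names are illustrative, not the kernel's):

#include <assert.h>
#include <stdint.h>

#define LAM_U57_BIT		61
#define X86_CR3_LAM_U57		(1ULL << 61)
#define X86_CR3_LAM_U48		(1ULL << 62)

/* Store CR3 bits 62:61 in the low two bits of a byte... */
static uint8_t pack_lam(uint64_t cr3_lam_mask)
{
	return cr3_lam_mask >> LAM_U57_BIT;
}

/* ...and shift them back when rebuilding CR3. */
static uint64_t unpack_lam(uint8_t lam)
{
	return (uint64_t)lam << LAM_U57_BIT;
}

int main(void)
{
	assert(unpack_lam(pack_lam(0)) == 0);
	assert(unpack_lam(pack_lam(X86_CR3_LAM_U57)) == X86_CR3_LAM_U57);
	assert(unpack_lam(pack_lam(X86_CR3_LAM_U48)) == X86_CR3_LAM_U48);
	return 0;
}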
Alexander Potapenko June 17, 2022, 3:35 p.m. UTC | #5
On Fri, Jun 10, 2022 at 4:35 PM Kirill A. Shutemov
<kirill.shutemov@linux.intel.com> wrote:
>
> Linear Address Masking mode for userspace pointers encoded in CR3 bits.
> The mode is selected per-thread. Add new thread features indicate that the
> thread has Linear Address Masking enabled.
>
> switch_mm_irqs_off() now respects these flags and constructs CR3
> accordingly.
>
> The active LAM mode gets recorded in the tlb_state.
>
> Signed-off-by: Kirill A. Shutemov <kirill.shutemov@linux.intel.com>


> +#ifdef CONFIG_X86_64
> +static inline u64 mm_cr3_lam_mask(struct mm_struct *mm)
> +{
> +       return mm->context.lam_cr3_mask;
> +}

Nit: can we have either "cr3_lam_mask" or "lam_cr3_mask" in both places?
Kirill A. Shutemov June 17, 2022, 10:39 p.m. UTC | #6
On Fri, Jun 17, 2022 at 05:35:05PM +0200, Alexander Potapenko wrote:
> On Fri, Jun 10, 2022 at 4:35 PM Kirill A. Shutemov
> <kirill.shutemov@linux.intel.com> wrote:
> >
> > Linear Address Masking mode for userspace pointers encoded in CR3 bits.
> > The mode is selected per-thread. Add new thread features indicate that the
> > thread has Linear Address Masking enabled.
> >
> > switch_mm_irqs_off() now respects these flags and constructs CR3
> > accordingly.
> >
> > The active LAM mode gets recorded in the tlb_state.
> >
> > Signed-off-by: Kirill A. Shutemov <kirill.shutemov@linux.intel.com>
> 
> 
> > +#ifdef CONFIG_X86_64
> > +static inline u64 mm_cr3_lam_mask(struct mm_struct *mm)
> > +{
> > +       return mm->context.lam_cr3_mask;
> > +}
> 
> Nit: can we have either "cr3_lam_mask" or "lam_cr3_mask" in both places?

With the changes suggested by Peter, the field in the mmu_context will be
called 'lam' as it is not a CR3 mask anymore.
Andy Lutomirski June 28, 2022, 11:33 p.m. UTC | #7
On 6/10/22 07:35, Kirill A. Shutemov wrote:
> Linear Address Masking mode for userspace pointers encoded in CR3 bits.
> The mode is selected per-thread. Add new thread features indicate that the
> thread has Linear Address Masking enabled.
> 
> switch_mm_irqs_off() now respects these flags and constructs CR3
> accordingly.
> 
> The active LAM mode gets recorded in the tlb_state.
> 
> Signed-off-by: Kirill A. Shutemov <kirill.shutemov@linux.intel.com>
> ---
>   arch/x86/include/asm/mmu.h         |  1 +
>   arch/x86/include/asm/mmu_context.h | 24 ++++++++++++
>   arch/x86/include/asm/tlbflush.h    |  3 ++
>   arch/x86/mm/tlb.c                  | 62 ++++++++++++++++++++++--------
>   4 files changed, 75 insertions(+), 15 deletions(-)
> 
> diff --git a/arch/x86/include/asm/mmu.h b/arch/x86/include/asm/mmu.h
> index 5d7494631ea9..d150e92163b6 100644
> --- a/arch/x86/include/asm/mmu.h
> +++ b/arch/x86/include/asm/mmu.h
> @@ -40,6 +40,7 @@ typedef struct {
>   
>   #ifdef CONFIG_X86_64
>   	unsigned short flags;
> +	u64 lam_cr3_mask;
>   #endif
>   
>   	struct mutex lock;
> diff --git a/arch/x86/include/asm/mmu_context.h b/arch/x86/include/asm/mmu_context.h
> index b8d40ddeab00..e6eac047c728 100644
> --- a/arch/x86/include/asm/mmu_context.h
> +++ b/arch/x86/include/asm/mmu_context.h
> @@ -91,6 +91,29 @@ static inline void switch_ldt(struct mm_struct *prev, struct mm_struct *next)
>   }
>   #endif
>   
> +#ifdef CONFIG_X86_64
> +static inline u64 mm_cr3_lam_mask(struct mm_struct *mm)
> +{
> +	return mm->context.lam_cr3_mask;
> +}
> +
> +static inline void dup_lam(struct mm_struct *oldmm, struct mm_struct *mm)
> +{
> +	mm->context.lam_cr3_mask = oldmm->context.lam_cr3_mask;
> +}
> +
> +#else
> +
> +static inline u64 mm_cr3_lam_mask(struct mm_struct *mm)
> +{
> +	return 0;
> +}
> +
> +static inline void dup_lam(struct mm_struct *oldmm, struct mm_struct *mm)
> +{
> +}
> +#endif

Do we really need the ifdeffery here?  I see no real harm in having the 
field exist on 32-bit -- we don't care much about performance for 32-bit 
kernels.

> -	if (real_prev == next) {
> +	if (real_prev == next && prev_lam == new_lam) {
>   		VM_WARN_ON(this_cpu_read(cpu_tlbstate.ctxs[prev_asid].ctx_id) !=
>   			   next->context.ctx_id);

This looks wrong to me.  If we change threads within the same mm but lam 
changes (which is certainly possible by a race if nothing else) then 
this will go down the "we really are changing mms" path, not the "we're 
not changing but we might need to flush something" path.
Kirill A. Shutemov June 29, 2022, 12:34 a.m. UTC | #8
On Tue, Jun 28, 2022 at 04:33:21PM -0700, Andy Lutomirski wrote:
> On 6/10/22 07:35, Kirill A. Shutemov wrote:
> > Linear Address Masking mode for userspace pointers encoded in CR3 bits.
> > The mode is selected per-thread. Add new thread features indicate that the
> > thread has Linear Address Masking enabled.
> > 
> > switch_mm_irqs_off() now respects these flags and constructs CR3
> > accordingly.
> > 
> > The active LAM mode gets recorded in the tlb_state.
> > 
> > Signed-off-by: Kirill A. Shutemov <kirill.shutemov@linux.intel.com>
> > ---
> >   arch/x86/include/asm/mmu.h         |  1 +
> >   arch/x86/include/asm/mmu_context.h | 24 ++++++++++++
> >   arch/x86/include/asm/tlbflush.h    |  3 ++
> >   arch/x86/mm/tlb.c                  | 62 ++++++++++++++++++++++--------
> >   4 files changed, 75 insertions(+), 15 deletions(-)
> > 
> > diff --git a/arch/x86/include/asm/mmu.h b/arch/x86/include/asm/mmu.h
> > index 5d7494631ea9..d150e92163b6 100644
> > --- a/arch/x86/include/asm/mmu.h
> > +++ b/arch/x86/include/asm/mmu.h
> > @@ -40,6 +40,7 @@ typedef struct {
> >   #ifdef CONFIG_X86_64
> >   	unsigned short flags;
> > +	u64 lam_cr3_mask;
> >   #endif
> >   	struct mutex lock;
> > diff --git a/arch/x86/include/asm/mmu_context.h b/arch/x86/include/asm/mmu_context.h
> > index b8d40ddeab00..e6eac047c728 100644
> > --- a/arch/x86/include/asm/mmu_context.h
> > +++ b/arch/x86/include/asm/mmu_context.h
> > @@ -91,6 +91,29 @@ static inline void switch_ldt(struct mm_struct *prev, struct mm_struct *next)
> >   }
> >   #endif
> > +#ifdef CONFIG_X86_64
> > +static inline u64 mm_cr3_lam_mask(struct mm_struct *mm)
> > +{
> > +	return mm->context.lam_cr3_mask;
> > +}
> > +
> > +static inline void dup_lam(struct mm_struct *oldmm, struct mm_struct *mm)
> > +{
> > +	mm->context.lam_cr3_mask = oldmm->context.lam_cr3_mask;
> > +}
> > +
> > +#else
> > +
> > +static inline u64 mm_cr3_lam_mask(struct mm_struct *mm)
> > +{
> > +	return 0;
> > +}
> > +
> > +static inline void dup_lam(struct mm_struct *oldmm, struct mm_struct *mm)
> > +{
> > +}
> > +#endif
> 
> Do we really need the ifdeffery here?  I see no real harm in having the
> field exist on 32-bit -- we don't care much about performance for 32-bit
> kernels.

The waste doesn't feel right to me. I would rather keep the ifdeffery.

But sure, I can do this if needed.

> > -	if (real_prev == next) {
> > +	if (real_prev == next && prev_lam == new_lam) {
> >   		VM_WARN_ON(this_cpu_read(cpu_tlbstate.ctxs[prev_asid].ctx_id) !=
> >   			   next->context.ctx_id);
> 
> This looks wrong to me.  If we change threads within the same mm but lam
> changes (which is certainly possible by a race if nothing else) then this
> will go down the "we really are changing mms" path, not the "we're not
> changing but we might need to flush something" path.

If LAM gets enabled we must write CR3 with the new LAM mode. Without the
change, the real_prev == next case will not do this in the !was_lazy case.

Note that currently enabling LAM is done by setting the LAM mode in the mmu
context and doing switch_mm(current->mm, current->mm, current), so it is a
very important case.
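
For illustration, a sketch of that enabling flow (the function name is hypothetical; the real enabling code lives in a later patch of the series):

static void example_enable_lam_u57(void)
{
	/* Record the new mode in the mm... */
	current->mm->context.lam_cr3_mask = X86_CR3_LAM_U57;

	/* ...and "switch" to the same mm so CR3 is rewritten with the LAM bit. */
	switch_mm(current->mm, current->mm, current);
}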
Andy Lutomirski June 30, 2022, 1:51 a.m. UTC | #9
On Tue, Jun 28, 2022, at 5:34 PM, Kirill A. Shutemov wrote:
> On Tue, Jun 28, 2022 at 04:33:21PM -0700, Andy Lutomirski wrote:
>> On 6/10/22 07:35, Kirill A. Shutemov wrote:
>> > Linear Address Masking mode for userspace pointers encoded in CR3 bits.
>> > The mode is selected per-thread. Add new thread features indicate that the
>> > thread has Linear Address Masking enabled.
>> > 
>> > switch_mm_irqs_off() now respects these flags and constructs CR3
>> > accordingly.
>> > 
>> > The active LAM mode gets recorded in the tlb_state.
>> > 
>> > Signed-off-by: Kirill A. Shutemov <kirill.shutemov@linux.intel.com>
>> > ---
>> >   arch/x86/include/asm/mmu.h         |  1 +
>> >   arch/x86/include/asm/mmu_context.h | 24 ++++++++++++
>> >   arch/x86/include/asm/tlbflush.h    |  3 ++
>> >   arch/x86/mm/tlb.c                  | 62 ++++++++++++++++++++++--------
>> >   4 files changed, 75 insertions(+), 15 deletions(-)
>> > 
>> > diff --git a/arch/x86/include/asm/mmu.h b/arch/x86/include/asm/mmu.h
>> > index 5d7494631ea9..d150e92163b6 100644
>> > --- a/arch/x86/include/asm/mmu.h
>> > +++ b/arch/x86/include/asm/mmu.h
>> > @@ -40,6 +40,7 @@ typedef struct {
>> >   #ifdef CONFIG_X86_64
>> >   	unsigned short flags;
>> > +	u64 lam_cr3_mask;
>> >   #endif
>> >   	struct mutex lock;
>> > diff --git a/arch/x86/include/asm/mmu_context.h b/arch/x86/include/asm/mmu_context.h
>> > index b8d40ddeab00..e6eac047c728 100644
>> > --- a/arch/x86/include/asm/mmu_context.h
>> > +++ b/arch/x86/include/asm/mmu_context.h
>> > @@ -91,6 +91,29 @@ static inline void switch_ldt(struct mm_struct *prev, struct mm_struct *next)
>> >   }
>> >   #endif
>> > +#ifdef CONFIG_X86_64
>> > +static inline u64 mm_cr3_lam_mask(struct mm_struct *mm)
>> > +{
>> > +	return mm->context.lam_cr3_mask;
>> > +}
>> > +
>> > +static inline void dup_lam(struct mm_struct *oldmm, struct mm_struct *mm)
>> > +{
>> > +	mm->context.lam_cr3_mask = oldmm->context.lam_cr3_mask;
>> > +}
>> > +
>> > +#else
>> > +
>> > +static inline u64 mm_cr3_lam_mask(struct mm_struct *mm)
>> > +{
>> > +	return 0;
>> > +}
>> > +
>> > +static inline void dup_lam(struct mm_struct *oldmm, struct mm_struct *mm)
>> > +{
>> > +}
>> > +#endif
>> 
>> Do we really need the ifdeffery here?  I see no real harm in having the
>> field exist on 32-bit -- we don't care much about performance for 32-bit
>> kernels.
>
> The waste doesn't feel right to me. I would rather keep it.
>
> But sure I can do this if needed.

I could go either way here.

>
>> > -	if (real_prev == next) {
>> > +	if (real_prev == next && prev_lam == new_lam) {
>> >   		VM_WARN_ON(this_cpu_read(cpu_tlbstate.ctxs[prev_asid].ctx_id) !=
>> >   			   next->context.ctx_id);
>> 
>> This looks wrong to me.  If we change threads within the same mm but lam
>> changes (which is certainly possible by a race if nothing else) then this
>> will go down the "we really are changing mms" path, not the "we're not
>> changing but we might need to flush something" path.
>
> If LAM gets enabled we must write CR3 with the new LAM mode. Without the
> change real_prev == next case will not do this for !was_lazy case.

You could fix that.  Or you could determine that this isn’t actually needed, just like updating the LDT in that path isn’t needed, if you change the way LAM is updated.
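
One possible shape of that fix, sketched only (keep the real_prev == next fast path, but rewrite CR3 there when the LAM mode changed):

	if (real_prev == next) {
		/* ... existing same-mm checks and lazy-TLB handling ... */

		if (unlikely(prev_lam != new_lam)) {
			/* LAM changed within the same mm: rewrite CR3 in place. */
			set_tlbstate_lam_cr3_mask(new_lam);
			/* Flushing write; LAM changes are rare, so keep it simple. */
			write_cr3(build_cr3(next->pgd, prev_asid, new_lam));
		}
	}
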

>
> Note that currently enabling LAM is done by setting LAM mode in the mmu
> context and doing switch_mm(current->mm, current->mm, current), so it is
> very important case.
>

Well, I did separately ask why this is the case.

> -- 
>  Kirill A. Shutemov

Patch

diff --git a/arch/x86/include/asm/mmu.h b/arch/x86/include/asm/mmu.h
index 5d7494631ea9..d150e92163b6 100644
--- a/arch/x86/include/asm/mmu.h
+++ b/arch/x86/include/asm/mmu.h
@@ -40,6 +40,7 @@  typedef struct {
 
 #ifdef CONFIG_X86_64
 	unsigned short flags;
+	u64 lam_cr3_mask;
 #endif
 
 	struct mutex lock;
diff --git a/arch/x86/include/asm/mmu_context.h b/arch/x86/include/asm/mmu_context.h
index b8d40ddeab00..e6eac047c728 100644
--- a/arch/x86/include/asm/mmu_context.h
+++ b/arch/x86/include/asm/mmu_context.h
@@ -91,6 +91,29 @@  static inline void switch_ldt(struct mm_struct *prev, struct mm_struct *next)
 }
 #endif
 
+#ifdef CONFIG_X86_64
+static inline u64 mm_cr3_lam_mask(struct mm_struct *mm)
+{
+	return mm->context.lam_cr3_mask;
+}
+
+static inline void dup_lam(struct mm_struct *oldmm, struct mm_struct *mm)
+{
+	mm->context.lam_cr3_mask = oldmm->context.lam_cr3_mask;
+}
+
+#else
+
+static inline u64 mm_cr3_lam_mask(struct mm_struct *mm)
+{
+	return 0;
+}
+
+static inline void dup_lam(struct mm_struct *oldmm, struct mm_struct *mm)
+{
+}
+#endif
+
 #define enter_lazy_tlb enter_lazy_tlb
 extern void enter_lazy_tlb(struct mm_struct *mm, struct task_struct *tsk);
 
@@ -168,6 +191,7 @@  static inline int arch_dup_mmap(struct mm_struct *oldmm, struct mm_struct *mm)
 {
 	arch_dup_pkeys(oldmm, mm);
 	paravirt_arch_dup_mmap(oldmm, mm);
+	dup_lam(oldmm, mm);
 	return ldt_dup_context(oldmm, mm);
 }
 
diff --git a/arch/x86/include/asm/tlbflush.h b/arch/x86/include/asm/tlbflush.h
index 4af5579c7ef7..5b93dad93ff4 100644
--- a/arch/x86/include/asm/tlbflush.h
+++ b/arch/x86/include/asm/tlbflush.h
@@ -86,6 +86,9 @@  struct tlb_state {
 		unsigned long		last_user_mm_spec;
 	};
 
+#ifdef CONFIG_X86_64
+	u64 lam_cr3_mask;
+#endif
 	u16 loaded_mm_asid;
 	u16 next_asid;
 
diff --git a/arch/x86/mm/tlb.c b/arch/x86/mm/tlb.c
index d400b6d9d246..458867a8f4bd 100644
--- a/arch/x86/mm/tlb.c
+++ b/arch/x86/mm/tlb.c
@@ -154,17 +154,17 @@  static inline u16 user_pcid(u16 asid)
 	return ret;
 }
 
-static inline unsigned long build_cr3(pgd_t *pgd, u16 asid)
+static inline unsigned long build_cr3(pgd_t *pgd, u16 asid, u64 lam)
 {
 	if (static_cpu_has(X86_FEATURE_PCID)) {
-		return __sme_pa(pgd) | kern_pcid(asid);
+		return __sme_pa(pgd) | kern_pcid(asid) | lam;
 	} else {
 		VM_WARN_ON_ONCE(asid != 0);
-		return __sme_pa(pgd);
+		return __sme_pa(pgd) | lam;
 	}
 }
 
-static inline unsigned long build_cr3_noflush(pgd_t *pgd, u16 asid)
+static inline unsigned long build_cr3_noflush(pgd_t *pgd, u16 asid, u64 lam)
 {
 	VM_WARN_ON_ONCE(asid > MAX_ASID_AVAILABLE);
 	/*
@@ -173,7 +173,7 @@  static inline unsigned long build_cr3_noflush(pgd_t *pgd, u16 asid)
 	 * boot because all CPU's the have same capabilities:
 	 */
 	VM_WARN_ON_ONCE(!boot_cpu_has(X86_FEATURE_PCID));
-	return __sme_pa(pgd) | kern_pcid(asid) | CR3_NOFLUSH;
+	return __sme_pa(pgd) | kern_pcid(asid) | lam | CR3_NOFLUSH;
 }
 
 /*
@@ -274,15 +274,15 @@  static inline void invalidate_user_asid(u16 asid)
 		  (unsigned long *)this_cpu_ptr(&cpu_tlbstate.user_pcid_flush_mask));
 }
 
-static void load_new_mm_cr3(pgd_t *pgdir, u16 new_asid, bool need_flush)
+static void load_new_mm_cr3(pgd_t *pgdir, u16 new_asid, u64 lam, bool need_flush)
 {
 	unsigned long new_mm_cr3;
 
 	if (need_flush) {
 		invalidate_user_asid(new_asid);
-		new_mm_cr3 = build_cr3(pgdir, new_asid);
+		new_mm_cr3 = build_cr3(pgdir, new_asid, lam);
 	} else {
-		new_mm_cr3 = build_cr3_noflush(pgdir, new_asid);
+		new_mm_cr3 = build_cr3_noflush(pgdir, new_asid, lam);
 	}
 
 	/*
@@ -486,11 +486,36 @@  void cr4_update_pce(void *ignored)
 static inline void cr4_update_pce_mm(struct mm_struct *mm) { }
 #endif
 
+#ifdef CONFIG_X86_64
+static inline u64 tlbstate_lam_cr3_mask(void)
+{
+	return this_cpu_read(cpu_tlbstate.lam_cr3_mask);
+}
+
+static inline void set_tlbstate_lam_cr3_mask(u64 mask)
+{
+	this_cpu_write(cpu_tlbstate.lam_cr3_mask, mask);
+}
+
+#else
+
+static inline u64 tlbstate_lam_cr3_mask(void)
+{
+	return 0;
+}
+
+static inline void set_tlbstate_lam_cr3_mask(u64 mask)
+{
+}
+#endif
+
 void switch_mm_irqs_off(struct mm_struct *prev, struct mm_struct *next,
 			struct task_struct *tsk)
 {
 	struct mm_struct *real_prev = this_cpu_read(cpu_tlbstate.loaded_mm);
 	u16 prev_asid = this_cpu_read(cpu_tlbstate.loaded_mm_asid);
+	u64 prev_lam = tlbstate_lam_cr3_mask();
+	u64 new_lam = mm_cr3_lam_mask(next);
 	bool was_lazy = this_cpu_read(cpu_tlbstate_shared.is_lazy);
 	unsigned cpu = smp_processor_id();
 	u64 next_tlb_gen;
@@ -504,6 +529,9 @@  void switch_mm_irqs_off(struct mm_struct *prev, struct mm_struct *next,
 	 * cpu_tlbstate.loaded_mm) matches next.
 	 *
 	 * NB: leave_mm() calls us with prev == NULL and tsk == NULL.
+	 *
+	 * NB: Initial LAM enabling calls us with prev == next. We must update
+	 * CR3 if prev_lam doesn't match the new one.
 	 */
 
 	/* We don't want flush_tlb_func() to run concurrently with us. */
@@ -520,7 +548,7 @@  void switch_mm_irqs_off(struct mm_struct *prev, struct mm_struct *next,
 	 * isn't free.
 	 */
 #ifdef CONFIG_DEBUG_VM
-	if (WARN_ON_ONCE(__read_cr3() != build_cr3(real_prev->pgd, prev_asid))) {
+	if (WARN_ON_ONCE(__read_cr3() != build_cr3(real_prev->pgd, prev_asid, prev_lam))) {
 		/*
 		 * If we were to BUG here, we'd be very likely to kill
 		 * the system so hard that we don't see the call trace.
@@ -551,7 +579,7 @@  void switch_mm_irqs_off(struct mm_struct *prev, struct mm_struct *next,
 	 * provides that full memory barrier and core serializing
 	 * instruction.
 	 */
-	if (real_prev == next) {
+	if (real_prev == next && prev_lam == new_lam) {
 		VM_WARN_ON(this_cpu_read(cpu_tlbstate.ctxs[prev_asid].ctx_id) !=
 			   next->context.ctx_id);
 
@@ -622,15 +650,16 @@  void switch_mm_irqs_off(struct mm_struct *prev, struct mm_struct *next,
 		barrier();
 	}
 
+	set_tlbstate_lam_cr3_mask(new_lam);
 	if (need_flush) {
 		this_cpu_write(cpu_tlbstate.ctxs[new_asid].ctx_id, next->context.ctx_id);
 		this_cpu_write(cpu_tlbstate.ctxs[new_asid].tlb_gen, next_tlb_gen);
-		load_new_mm_cr3(next->pgd, new_asid, true);
+		load_new_mm_cr3(next->pgd, new_asid, new_lam, true);
 
 		trace_tlb_flush(TLB_FLUSH_ON_TASK_SWITCH, TLB_FLUSH_ALL);
 	} else {
 		/* The new ASID is already up to date. */
-		load_new_mm_cr3(next->pgd, new_asid, false);
+		load_new_mm_cr3(next->pgd, new_asid, new_lam, false);
 
 		trace_tlb_flush(TLB_FLUSH_ON_TASK_SWITCH, 0);
 	}
@@ -687,6 +716,7 @@  void initialize_tlbstate_and_flush(void)
 	struct mm_struct *mm = this_cpu_read(cpu_tlbstate.loaded_mm);
 	u64 tlb_gen = atomic64_read(&init_mm.context.tlb_gen);
 	unsigned long cr3 = __read_cr3();
+	u64 lam = cr3 & (X86_CR3_LAM_U48 | X86_CR3_LAM_U57);
 
 	/* Assert that CR3 already references the right mm. */
 	WARN_ON((cr3 & CR3_ADDR_MASK) != __pa(mm->pgd));
@@ -700,7 +730,7 @@  void initialize_tlbstate_and_flush(void)
 		!(cr4_read_shadow() & X86_CR4_PCIDE));
 
 	/* Force ASID 0 and force a TLB flush. */
-	write_cr3(build_cr3(mm->pgd, 0));
+	write_cr3(build_cr3(mm->pgd, 0, lam));
 
 	/* Reinitialize tlbstate. */
 	this_cpu_write(cpu_tlbstate.last_user_mm_spec, LAST_USER_MM_INIT);
@@ -1047,8 +1077,10 @@  void flush_tlb_kernel_range(unsigned long start, unsigned long end)
  */
 unsigned long __get_current_cr3_fast(void)
 {
-	unsigned long cr3 = build_cr3(this_cpu_read(cpu_tlbstate.loaded_mm)->pgd,
-		this_cpu_read(cpu_tlbstate.loaded_mm_asid));
+	unsigned long cr3 =
+		build_cr3(this_cpu_read(cpu_tlbstate.loaded_mm)->pgd,
+		this_cpu_read(cpu_tlbstate.loaded_mm_asid),
+		tlbstate_lam_cr3_mask());
 
 	/* For now, be very restrictive about when this can be called. */
 	VM_WARN_ON(in_nmi() || preemptible());