
[PATCHv13,05/16] x86/uaccess: Provide untagged_addr() and remove tags before address check

Message ID 20221227030829.12508-6-kirill.shutemov@linux.intel.com (mailing list archive)
State New
Series Linear Address Masking enabling

Commit Message

Kirill A. Shutemov Dec. 27, 2022, 3:08 a.m. UTC
untagged_addr() is a helper used by the core-mm to strip tag bits and
get the address into canonical shape. It only handles userspace
addresses. The untagging mask is stored in mmu_context and will be set
on enabling LAM for the process.

The tags must not be included in the check of whether it's okay to access
the userspace address.

Strip tags in access_ok().

get_user() and put_user() don't use access_ok(), but check access
against TASK_SIZE directly in assembly. Strip tags before calling into
the assembly helper.
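
For reference, the untagging arithmetic is a branchless mask-and-sign trick.
A minimal userspace sketch of it (the LAM_U57 mask value and the sample
pointers are illustrative assumptions, not taken from this patch):

	#include <stdint.h>
	#include <stdio.h>

	/* Assumed LAM_U57 layout: bits 62:57 carry the tag, bit 63 stays significant. */
	#define UNTAG_MASK_U57	0x81ffffffffffffffULL

	static uint64_t untag(uint64_t addr, uint64_t untag_mask)
	{
		/* Arithmetic shift of bit 63: 0 for user pointers, ~0 for kernel pointers. */
		uint64_t sign = (uint64_t)((int64_t)addr >> 63);

		/* User pointers get their tag bits cleared; kernel pointers pass through. */
		return addr & (untag_mask | sign);
	}

	int main(void)
	{
		/* 0x2a00001234567000 -> 0x0000001234567000: tag 0x15 stripped */
		printf("%#018llx\n", (unsigned long long)untag(0x2a00001234567000ULL, UNTAG_MASK_U57));
		/* kernel address comes back unchanged */
		printf("%#018llx\n", (unsigned long long)untag(0xffff888012345000ULL, UNTAG_MASK_U57));
		return 0;
	}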

Signed-off-by: Kirill A. Shutemov <kirill.shutemov@linux.intel.com>
Acked-by: Peter Zijlstra (Intel) <peterz@infradead.org>
Acked-by: Andy Lutomirski <luto@kernel.org>
Tested-by: Alexander Potapenko <glider@google.com>
---
 arch/x86/include/asm/mmu.h         |  3 ++
 arch/x86/include/asm/mmu_context.h | 11 +++++++
 arch/x86/include/asm/uaccess.h     | 47 +++++++++++++++++++++++++++---
 arch/x86/kernel/process.c          |  3 ++
 4 files changed, 60 insertions(+), 4 deletions(-)

Comments

Linus Torvalds Dec. 27, 2022, 7:10 p.m. UTC | #1
On Mon, Dec 26, 2022 at 7:08 PM Kirill A. Shutemov
<kirill.shutemov@linux.intel.com> wrote:
>
> --- a/arch/x86/include/asm/uaccess.h
> +++ b/arch/x86/include/asm/uaccess.h
> @@ -21,6 +22,37 @@ static inline bool pagefault_disabled(void);
>  # define WARN_ON_IN_IRQ()
>  #endif
>
> +#ifdef CONFIG_X86_64

I think this should be CONFIG_ADDRESS_MASKING or something like that.

This is not a "64 vs 32-bit feature". This is something else.

Even if you then were to select it unconditionally for 64-bit kernels
(but why would you?) it reads better if the #ifdef's make sense.

> +#define __untagged_addr(mm, addr)      ({                              \
> +       u64 __addr = (__force u64)(addr);                               \
> +       s64 sign = (s64)__addr >> 63;                                   \
> +       __addr &= READ_ONCE((mm)->context.untag_mask) | sign;           \

Now the READ_ONCE() doesn't make much sense. There shouldn't be any
data races on that thing.

Plus:

> +#define untagged_addr(addr) __untagged_addr(current->mm, addr)

I think this should at least allow caching it in 'current' without the
mm indirection.

In fact, it might be even better off as a per-cpu variable.

Because it is now in somewhat critical code sections:

> -#define get_user(x,ptr) ({ might_fault(); do_get_user_call(get_user,x,ptr); })
> +#define get_user(x,ptr)                                                        \
> +({                                                                     \
> +       might_fault();                                                  \
> +       do_get_user_call(get_user,x,untagged_ptr(ptr)); \
> +})

This is disgusting and wrong.

The whole reason we do do_get_user_call() as a function call is
because we *don't* want to do this kind of stuff at the call sites. We
used to inline it all, but with all the clac/stac and access_ok
checks, it all just ended up ballooning so much that it was much
better to make it a special function call with particular calling
conventions.

That untagged_ptr() should be done in that asm function, not in every call site.

Now, the sad part is that we got *rid* of all this kind of crap not
that long ago when Christoph cleaned up the old legacy set_fs() mess,
and we were able to make the task limit be a constant (ok, be _two_
constants, depending on LA57). So we'd have to re-introduce that nasty
"look up task size dynamically". See commit 47058bb54b57 ("x86: remove
address space overrides using set_fs()") for the removal that would
have to be re-instated.

But see above about "maybe it should be a per-cpu variable" - and
making that ALTERNATIVE thing even nastier.

Another alternative might be to *only* test the sign bit in the
get_user/put_user functions, and just take the fault instead. Right
now we warn about non-canonical addresses because it implies somebody
might have missed an access_ok(), but we'd just mark those
get_user/put_user accesses special.

That would get this all entirely off the critical path. Most other
address masking is for relatively rare things (ie mmap/munmap), but
the user accesses are hot.

Hmm?

             Linus
Kirill A. Shutemov Dec. 31, 2022, 12:10 a.m. UTC | #2
On Tue, Dec 27, 2022 at 11:10:31AM -0800, Linus Torvalds wrote:
> On Mon, Dec 26, 2022 at 7:08 PM Kirill A. Shutemov
> <kirill.shutemov@linux.intel.com> wrote:
> >
> > --- a/arch/x86/include/asm/uaccess.h
> > +++ b/arch/x86/include/asm/uaccess.h
> > @@ -21,6 +22,37 @@ static inline bool pagefault_disabled(void);
> >  # define WARN_ON_IN_IRQ()
> >  #endif
> >
> > +#ifdef CONFIG_X86_64
> 
> I think this should be CONFIG_ADDRESS_MASKING or something like that.
> 
> This is not a "64 vs 32-bit feature". This is something else.
> 
> Even if you then were to select it unconditionally for 64-bit kernels
> (but why would you?) it reads better if the #ifdef's make sense.

I hoped to get away without a new option. It leads to more ifdeffery, but
well...

> > +#define __untagged_addr(mm, addr)      ({                              \
> > +       u64 __addr = (__force u64)(addr);                               \
> > +       s64 sign = (s64)__addr >> 63;                                   \
> > +       __addr &= READ_ONCE((mm)->context.untag_mask) | sign;           \
> 
> Now the READ_ONCE() doesn't make much sense. There shouldn't be any
> data races on that thing.

True. Removed.

> Plus:
> 
> > +#define untagged_addr(addr) __untagged_addr(current->mm, addr)
> 
> I think this should at least allow caching it in 'current' without the
> mm indirection.
> 
> In fact, it might be even better off as a per-cpu variable.
> 
> Because it is now in somewhat critical code sections:
> 
> > -#define get_user(x,ptr) ({ might_fault(); do_get_user_call(get_user,x,ptr); })
> > +#define get_user(x,ptr)                                                        \
> > +({                                                                     \
> > +       might_fault();                                                  \
> > +       do_get_user_call(get_user,x,untagged_ptr(ptr)); \
> > +})
> 
> This is disgusting and wrong.
> 
> The whole reason we do do_get_user_call() as a function call is
> because we *don't* want to do this kind of stuff at the call sites. We
> used to inline it all, but with all the clac/stac and access_ok
> checks, it all just ended up ballooning so much that it was much
> better to make it a special function call with particular calling
> conventions.
> 
> That untagged_ptr() should be done in that asm function, not in every call site.
> 
> Now, the sad part is that we got *rid* of all this kind of crap not
> that long ago when Christoph cleaned up the old legacy set_fs() mess,
> and we were able to make the task limit be a constant (ok, be _two_
> constants, depending on LA57). So we'd have to re-introduce that nasty
> "look up task size dynamically". See commit 47058bb54b57 ("x86: remove
> address space overrides using set_fs()") for the removal that would
> have to be re-instated.
> 
> But see above about "maybe it should be a per-cpu variable" - and
> making that ALTERNATIVE thing even nastier.

I made it a per-cpu variable (outside struct tlb_state to be visible in
modules). __get/put_user_X() now have a single instruction to untag the
address and it is gated by X86_FEATURE_LAM.

Seems reasonable to me.

BTW, am I blind or do we have no infrastructure to hook up static branches
from assembly?

It would be a better fit than ALTERNATIVE here. It would allow deferring
the overhead until the first user of the feature.

Is there any fundamental reason for this or just no demand?

> Another alternative might be to *only* test the sign bit in the
> get_user/put_user functions, and just take the fault instead. Right
> now we warn about non-canonical addresses because it implies somebody
> might have missed an access_ok(), but we'd just mark those
> get_user/put_user accesses special.
> 
> That would get this all entirely off the critical path. Most other
> address masking is for relatively rare things (ie mmap/munmap), but
> the user accesses are hot.
> 
> Hmm?

Below is a fixup that is supposed to address your concerns. I will also
extend the selftests to cover get/put_user().

diff --git a/arch/x86/Kconfig b/arch/x86/Kconfig
index 3604074a878b..211869aa618d 100644
--- a/arch/x86/Kconfig
+++ b/arch/x86/Kconfig
@@ -2290,6 +2290,17 @@ config RANDOMIZE_MEMORY_PHYSICAL_PADDING
 
 	  If unsure, leave at the default value.
 
+config ADDRESS_MASKING
+	bool "Linear Address Masking support"
+	depends on X86_64
+	help
+	  Linear Address Masking (LAM) modifies the checking that is applied
+	  to 64-bit linear addresses, allowing software to use the
+	  untranslated address bits for metadata.
+
+	  The capability can be used for efficient address sanitizers (ASAN)
+	  implementation and for optimizations in JITs.
+
 config HOTPLUG_CPU
 	def_bool y
 	depends on SMP
diff --git a/arch/x86/include/asm/disabled-features.h b/arch/x86/include/asm/disabled-features.h
index c44b56f7ffba..66be8acabe92 100644
--- a/arch/x86/include/asm/disabled-features.h
+++ b/arch/x86/include/asm/disabled-features.h
@@ -99,6 +99,12 @@
 # define DISABLE_TDX_GUEST	(1 << (X86_FEATURE_TDX_GUEST & 31))
 #endif
 
+#ifdef CONFIG_ADDRESS_MASKING
+# define DISABLE_LAM	0
+#else
+# define DISABLE_LAM	(1 << (X86_FEATURE_LAM & 31))
+#endif
+
 /*
  * Make sure to add features to the correct mask
  */
@@ -115,7 +121,7 @@
 #define DISABLED_MASK10	0
 #define DISABLED_MASK11	(DISABLE_RETPOLINE|DISABLE_RETHUNK|DISABLE_UNRET| \
 			 DISABLE_CALL_DEPTH_TRACKING)
-#define DISABLED_MASK12	0
+#define DISABLED_MASK12	(DISABLE_LAM)
 #define DISABLED_MASK13	0
 #define DISABLED_MASK14	0
 #define DISABLED_MASK15	0
diff --git a/arch/x86/include/asm/mmu.h b/arch/x86/include/asm/mmu.h
index 90d20679e4d7..0da5c227f490 100644
--- a/arch/x86/include/asm/mmu.h
+++ b/arch/x86/include/asm/mmu.h
@@ -44,7 +44,9 @@ typedef struct {
 
 #ifdef CONFIG_X86_64
 	unsigned long flags;
+#endif
 
+#ifdef CONFIG_ADDRESS_MASKING
 	/* Active LAM mode:  X86_CR3_LAM_U48 or X86_CR3_LAM_U57 or 0 (disabled) */
 	unsigned long lam_cr3_mask;
 
diff --git a/arch/x86/include/asm/mmu_context.h b/arch/x86/include/asm/mmu_context.h
index 4bc95c35cbd3..6ffc42dfd59d 100644
--- a/arch/x86/include/asm/mmu_context.h
+++ b/arch/x86/include/asm/mmu_context.h
@@ -91,7 +91,7 @@ static inline void switch_ldt(struct mm_struct *prev, struct mm_struct *next)
 }
 #endif
 
-#ifdef CONFIG_X86_64
+#ifdef CONFIG_ADDRESS_MASKING
 static inline unsigned long mm_lam_cr3_mask(struct mm_struct *mm)
 {
 	return READ_ONCE(mm->context.lam_cr3_mask);
diff --git a/arch/x86/include/asm/tlbflush.h b/arch/x86/include/asm/tlbflush.h
index 662598dea937..75bfaa421030 100644
--- a/arch/x86/include/asm/tlbflush.h
+++ b/arch/x86/include/asm/tlbflush.h
@@ -2,7 +2,7 @@
 #ifndef _ASM_X86_TLBFLUSH_H
 #define _ASM_X86_TLBFLUSH_H
 
-#include <linux/mm.h>
+#include <linux/mm_types.h>
 #include <linux/sched.h>
 
 #include <asm/processor.h>
@@ -12,6 +12,7 @@
 #include <asm/invpcid.h>
 #include <asm/pti.h>
 #include <asm/processor-flags.h>
+#include <asm/pgtable.h>
 
 void __flush_tlb_all(void);
 
@@ -53,6 +54,15 @@ static inline void cr4_clear_bits(unsigned long mask)
 	local_irq_restore(flags);
 }
 
+#ifdef CONFIG_ADDRESS_MASKING
+DECLARE_PER_CPU(u64, tlbstate_untag_mask);
+
+static inline u64 current_untag_mask(void)
+{
+	return this_cpu_read(tlbstate_untag_mask);
+}
+#endif
+
 #ifndef MODULE
 /*
  * 6 because 6 should be plenty and struct tlb_state will fit in two cache
@@ -101,7 +111,7 @@ struct tlb_state {
 	 */
 	bool invalidate_other;
 
-#ifdef CONFIG_X86_64
+#ifdef CONFIG_ADDRESS_MASKING
 	/*
 	 * Active LAM mode.
 	 *
@@ -367,27 +377,29 @@ static inline bool huge_pmd_needs_flush(pmd_t oldpmd, pmd_t newpmd)
 }
 #define huge_pmd_needs_flush huge_pmd_needs_flush
 
-#ifdef CONFIG_X86_64
-static inline unsigned long tlbstate_lam_cr3_mask(void)
+#ifdef CONFIG_ADDRESS_MASKING
+static inline  u64 tlbstate_lam_cr3_mask(void)
 {
-	unsigned long lam = this_cpu_read(cpu_tlbstate.lam);
+	u64 lam = this_cpu_read(cpu_tlbstate.lam);
 
 	return lam << X86_CR3_LAM_U57_BIT;
 }
 
-static inline void set_tlbstate_cr3_lam_mask(unsigned long mask)
+static inline void set_tlbstate_lam_mode(struct mm_struct *mm)
 {
-	this_cpu_write(cpu_tlbstate.lam, mask >> X86_CR3_LAM_U57_BIT);
+	this_cpu_write(cpu_tlbstate.lam,
+		       mm->context.lam_cr3_mask >> X86_CR3_LAM_U57_BIT);
+	this_cpu_write(tlbstate_untag_mask, mm->context.untag_mask);
 }
 
 #else
 
-static inline unsigned long tlbstate_lam_cr3_mask(void)
+static inline u64 tlbstate_lam_cr3_mask(void)
 {
 	return 0;
 }
 
-static inline void set_tlbstate_cr3_lam_mask(u64 mask)
+static inline void set_tlbstate_lam_mode(struct mm_struct *mm)
 {
 }
 #endif
diff --git a/arch/x86/include/asm/uaccess.h b/arch/x86/include/asm/uaccess.h
index 1d931c7f6741..730649175191 100644
--- a/arch/x86/include/asm/uaccess.h
+++ b/arch/x86/include/asm/uaccess.h
@@ -13,6 +13,7 @@
 #include <asm/page.h>
 #include <asm/smap.h>
 #include <asm/extable.h>
+#include <asm/tlbflush.h>
 
 #ifdef CONFIG_DEBUG_ATOMIC_SLEEP
 static inline bool pagefault_disabled(void);
@@ -22,7 +23,7 @@ static inline bool pagefault_disabled(void);
 # define WARN_ON_IN_IRQ()
 #endif
 
-#ifdef CONFIG_X86_64
+#ifdef CONFIG_ADDRESS_MASKING
 DECLARE_STATIC_KEY_FALSE(tagged_addr_key);
 
 /*
@@ -31,31 +32,24 @@ DECLARE_STATIC_KEY_FALSE(tagged_addr_key);
  * Magic with the 'sign' allows to untag userspace pointer without any branches
  * while leaving kernel addresses intact.
  */
-#define __untagged_addr(mm, addr)	({				\
+#define __untagged_addr(untag_mask, addr)	({			\
 	u64 __addr = (__force u64)(addr);				\
 	if (static_branch_likely(&tagged_addr_key)) {			\
 		s64 sign = (s64)__addr >> 63;				\
-		u64 mask = READ_ONCE((mm)->context.untag_mask);		\
-		__addr &= mask | sign;					\
+		__addr &= untag_mask | sign;				\
 	}								\
 	(__force __typeof__(addr))__addr;				\
 })
 
-#define untagged_addr(addr) __untagged_addr(current->mm, addr)
+#define untagged_addr(addr) __untagged_addr(current_untag_mask(), addr)
 
 #define untagged_addr_remote(mm, addr)	({				\
 	mmap_assert_locked(mm);						\
-	__untagged_addr(mm, addr);					\
+	__untagged_addr((mm)->context.untag_mask, addr);		\
 })
 
-#define untagged_ptr(ptr)	({					\
-	u64 __ptrval = (__force u64)(ptr);				\
-	__ptrval = untagged_addr(__ptrval);				\
-	(__force __typeof__(ptr))__ptrval;				\
-})
 #else
-#define untagged_addr(addr)	(addr)
-#define untagged_ptr(ptr)	(ptr)
+#define untagged_addr(addr)    (addr)
 #endif
 
 /**
@@ -167,7 +161,7 @@ extern int __get_user_bad(void);
 #define get_user(x,ptr)							\
 ({									\
 	might_fault();							\
-	do_get_user_call(get_user,x,untagged_ptr(ptr));	\
+	do_get_user_call(get_user,x,ptr);				\
 })
 
 /**
@@ -270,7 +264,7 @@ extern void __put_user_nocheck_8(void);
  */
 #define put_user(x, ptr) ({						\
 	might_fault();							\
-	do_put_user_call(put_user,x,untagged_ptr(ptr));			\
+	do_put_user_call(put_user,x,ptr);				\
 })
 
 /**
diff --git a/arch/x86/kernel/process_64.c b/arch/x86/kernel/process_64.c
index add85615d5ae..1f61e3a13b4f 100644
--- a/arch/x86/kernel/process_64.c
+++ b/arch/x86/kernel/process_64.c
@@ -743,6 +743,7 @@ static long prctl_map_vdso(const struct vdso_image *image, unsigned long addr)
 }
 #endif
 
+#ifdef CONFIG_ADDRESS_MASKING
 DEFINE_STATIC_KEY_FALSE(tagged_addr_key);
 EXPORT_SYMBOL_GPL(tagged_addr_key);
 
@@ -775,7 +776,7 @@ static int prctl_enable_tagged_addr(struct mm_struct *mm, unsigned long nr_bits)
 	}
 
 	write_cr3(__read_cr3() | mm->context.lam_cr3_mask);
-	set_tlbstate_cr3_lam_mask(mm->context.lam_cr3_mask);
+	set_tlbstate_lam_mode(mm);
 	set_bit(MM_CONTEXT_LOCK_LAM, &mm->context.flags);
 
 	mmap_write_unlock(mm);
@@ -783,6 +784,7 @@ static int prctl_enable_tagged_addr(struct mm_struct *mm, unsigned long nr_bits)
 	static_branch_enable(&tagged_addr_key);
 	return 0;
 }
+#endif
 
 long do_arch_prctl_64(struct task_struct *task, int option, unsigned long arg2)
 {
@@ -871,6 +873,7 @@ long do_arch_prctl_64(struct task_struct *task, int option, unsigned long arg2)
 	case ARCH_MAP_VDSO_64:
 		return prctl_map_vdso(&vdso_image_64, arg2);
 #endif
+#ifdef CONFIG_ADDRESS_MASKING
 	case ARCH_GET_UNTAG_MASK:
 		return put_user(task->mm->context.untag_mask,
 				(unsigned long __user *)arg2);
@@ -884,6 +887,7 @@ long do_arch_prctl_64(struct task_struct *task, int option, unsigned long arg2)
 			return put_user(0, (unsigned long __user *)arg2);
 		else
 			return put_user(LAM_U57_BITS, (unsigned long __user *)arg2);
+#endif
 	default:
 		ret = -EINVAL;
 		break;
diff --git a/arch/x86/lib/getuser.S b/arch/x86/lib/getuser.S
index b70d98d79a9d..22e92236e8f6 100644
--- a/arch/x86/lib/getuser.S
+++ b/arch/x86/lib/getuser.S
@@ -35,6 +35,13 @@
 #include <asm/smap.h>
 #include <asm/export.h>
 
+#ifdef CONFIG_ADDRESS_MASKING
+#define UNTAG_ADDR \
+	ALTERNATIVE "", __stringify(and PER_CPU_VAR(tlbstate_untag_mask), %rax), X86_FEATURE_LAM
+#else
+#define UNTAG_ADDR
+#endif
+
 #define ASM_BARRIER_NOSPEC ALTERNATIVE "", "lfence", X86_FEATURE_LFENCE_RDTSC
 
 #ifdef CONFIG_X86_5LEVEL
@@ -48,6 +55,7 @@
 
 	.text
 SYM_FUNC_START(__get_user_1)
+	UNTAG_ADDR
 	LOAD_TASK_SIZE_MINUS_N(0)
 	cmp %_ASM_DX,%_ASM_AX
 	jae bad_get_user
@@ -62,6 +70,7 @@ SYM_FUNC_END(__get_user_1)
 EXPORT_SYMBOL(__get_user_1)
 
 SYM_FUNC_START(__get_user_2)
+	UNTAG_ADDR
 	LOAD_TASK_SIZE_MINUS_N(1)
 	cmp %_ASM_DX,%_ASM_AX
 	jae bad_get_user
@@ -76,6 +85,7 @@ SYM_FUNC_END(__get_user_2)
 EXPORT_SYMBOL(__get_user_2)
 
 SYM_FUNC_START(__get_user_4)
+	UNTAG_ADDR
 	LOAD_TASK_SIZE_MINUS_N(3)
 	cmp %_ASM_DX,%_ASM_AX
 	jae bad_get_user
@@ -91,6 +101,7 @@ EXPORT_SYMBOL(__get_user_4)
 
 SYM_FUNC_START(__get_user_8)
 #ifdef CONFIG_X86_64
+	UNTAG_ADDR
 	LOAD_TASK_SIZE_MINUS_N(7)
 	cmp %_ASM_DX,%_ASM_AX
 	jae bad_get_user
diff --git a/arch/x86/lib/putuser.S b/arch/x86/lib/putuser.S
index 32125224fcca..9e0276c553a8 100644
--- a/arch/x86/lib/putuser.S
+++ b/arch/x86/lib/putuser.S
@@ -33,6 +33,13 @@
  * as they get called from within inline assembly.
  */
 
+#ifdef CONFIG_ADDRESS_MASKING
+#define UNTAG_ADDR \
+	ALTERNATIVE "", __stringify(and PER_CPU_VAR(tlbstate_untag_mask), %rcx), X86_FEATURE_LAM
+#else
+#define UNTAG_ADDR
+#endif
+
 #ifdef CONFIG_X86_5LEVEL
 #define LOAD_TASK_SIZE_MINUS_N(n) \
 	ALTERNATIVE __stringify(mov $((1 << 47) - 4096 - (n)),%rbx), \
@@ -44,6 +51,7 @@
 
 .text
 SYM_FUNC_START(__put_user_1)
+	UNTAG_ADDR
 	LOAD_TASK_SIZE_MINUS_N(0)
 	cmp %_ASM_BX,%_ASM_CX
 	jae .Lbad_put_user
@@ -66,6 +74,7 @@ SYM_FUNC_END(__put_user_nocheck_1)
 EXPORT_SYMBOL(__put_user_nocheck_1)
 
 SYM_FUNC_START(__put_user_2)
+	UNTAG_ADDR
 	LOAD_TASK_SIZE_MINUS_N(1)
 	cmp %_ASM_BX,%_ASM_CX
 	jae .Lbad_put_user
@@ -88,6 +97,7 @@ SYM_FUNC_END(__put_user_nocheck_2)
 EXPORT_SYMBOL(__put_user_nocheck_2)
 
 SYM_FUNC_START(__put_user_4)
+	UNTAG_ADDR
 	LOAD_TASK_SIZE_MINUS_N(3)
 	cmp %_ASM_BX,%_ASM_CX
 	jae .Lbad_put_user
@@ -110,6 +120,7 @@ SYM_FUNC_END(__put_user_nocheck_4)
 EXPORT_SYMBOL(__put_user_nocheck_4)
 
 SYM_FUNC_START(__put_user_8)
+	UNTAG_ADDR
 	LOAD_TASK_SIZE_MINUS_N(7)
 	cmp %_ASM_BX,%_ASM_CX
 	jae .Lbad_put_user
diff --git a/arch/x86/mm/init.c b/arch/x86/mm/init.c
index d3987359d441..be5c7d1c0265 100644
--- a/arch/x86/mm/init.c
+++ b/arch/x86/mm/init.c
@@ -1044,6 +1044,11 @@ __visible DEFINE_PER_CPU_ALIGNED(struct tlb_state, cpu_tlbstate) = {
 	.cr4 = ~0UL,	/* fail hard if we screw up cr4 shadow initialization */
 };
 
+#ifdef CONFIG_ADDRESS_MASKING
+DEFINE_PER_CPU(u64, tlbstate_untag_mask);
+EXPORT_PER_CPU_SYMBOL(tlbstate_untag_mask);
+#endif
+
 void update_cache_mode_entry(unsigned entry, enum page_cache_mode cache)
 {
 	/* entry 0 MUST be WB (hardwired to speed up translations) */
diff --git a/arch/x86/mm/tlb.c b/arch/x86/mm/tlb.c
index 9d1e7a5f141c..8c330a6d0ece 100644
--- a/arch/x86/mm/tlb.c
+++ b/arch/x86/mm/tlb.c
@@ -635,7 +635,7 @@ void switch_mm_irqs_off(struct mm_struct *prev, struct mm_struct *next,
 		barrier();
 	}
 
-	set_tlbstate_cr3_lam_mask(new_lam);
+	set_tlbstate_lam_mode(next);
 	if (need_flush) {
 		this_cpu_write(cpu_tlbstate.ctxs[new_asid].ctx_id, next->context.ctx_id);
 		this_cpu_write(cpu_tlbstate.ctxs[new_asid].tlb_gen, next_tlb_gen);
@@ -726,7 +726,7 @@ void initialize_tlbstate_and_flush(void)
 	this_cpu_write(cpu_tlbstate.next_asid, 1);
 	this_cpu_write(cpu_tlbstate.ctxs[0].ctx_id, mm->context.ctx_id);
 	this_cpu_write(cpu_tlbstate.ctxs[0].tlb_gen, tlb_gen);
-	set_tlbstate_cr3_lam_mask(0);
+	set_tlbstate_lam_mode(mm);
 
 	for (i = 1; i < TLB_NR_DYN_ASIDS; i++)
 		this_cpu_write(cpu_tlbstate.ctxs[i].ctx_id, 0);
Linus Torvalds Dec. 31, 2022, 12:42 a.m. UTC | #3
On Fri, Dec 30, 2022 at 4:10 PM Kirill A. Shutemov <kirill@shutemov.name> wrote:
>
> I made it a per-cpu variable (outside struct tlb_state to be visible in
> modules). __get/put_user_X() now have a single instruction to untag the
> address and it is gated by X86_FEATURE_LAM.

Yeah, that looks more reasonable to me.

> BTW, am I blind or do we have no infrastructure to hook up static branches
> from assembly?

I think you're right.

> It would be a better fit than ALTERNATIVE here. It would allow deferring
> the overhead until the first user of the feature.

Well, it would make the overhead worse once people actually start
using it. So it's not obvious that a static branch is really the right
thing to do.

That said, while I think that UNTAG_ADDR is quite reasonable now, the
more I look at getuser.S and putuser.S, the more I'm thinking that
getting rid of the TASK_SIZE comparison entirely is the right thing to
do on x86-64.

It's really rather nasty, with not just that whole LA57 alternative,
but it's doing a large 64-bit constant too.

Now, on 32-bit, we do indeed have to compare against TASK_SIZE
explicitly, but on 32-bit we could just use an immediate for the cmp
instruction, so even there that whole "load constant" isn't really
optimal.

And on 64-bit, we really only need to check the high bit.

In fact, we don't even want to *check* it, because then we need to do
that disgusting array_index_mask_nospec thing to mask the bits for it,
so it would be even better to use purely arithmetic with no
conditionals anywhere.

And that's exactly what we could do on x86-64:

        movq %rdx,%rax
        sarq $63,%rax
        orq %rax,%rdx

would actually be noticeably better than what we do now for
TASK_SIZE checking _and_ for the array index masking (for putuser.S,
we'd use %rbx instead of %rax in that sequence).

The above three simple instructions would replace all of the games we
now play with

        LOAD_TASK_SIZE_MINUS_N(0)
        cmp %_ASM_DX,%_ASM_AX
        jae bad_get_user
        sbb %_ASM_DX, %_ASM_DX          /* array_index_mask_nospec() */
        and %_ASM_DX, %_ASM_AX

entirely.

It would just turn all kernel addresses into all ones, which is then
guaranteed to fault. So no need for any conditional that never
triggers in real life anyway.
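
In C terms, a rough model of that arithmetic (not code from the thread,
just to illustrate the effect):

	#include <assert.h>
	#include <stdint.h>

	/*
	 * OR the address with its own sign-extended bit 63: user addresses are
	 * untouched, kernel addresses collapse to all ones - the very top of
	 * the address space, which the kernel never maps, so the access faults.
	 */
	static uint64_t user_or_all_ones(uint64_t addr)
	{
		uint64_t sign = (uint64_t)((int64_t)addr >> 63);
		return addr | sign;
	}

	int main(void)
	{
		assert(user_or_all_ones(0x00007f0012345678ULL) == 0x00007f0012345678ULL);
		assert(user_or_all_ones(0xffff888012345000ULL) == ~0ULL);
		return 0;
	}

The compare, the branch and the array_index_mask_nospec() sequence all go
away; a kernel-range pointer still ends up returning -EFAULT, just via the
exception table instead of the early branch.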

On 32-bit, we'd still have to do that old sequence, but we'd replace the

        LOAD_TASK_SIZE_MINUS_N(0)
        cmp %_ASM_DX,%_ASM_AX

with just the simpler

        cmp $TASK_SIZE_MAX-(n),%_ASM_AX

since the only reason we do that immediate load is because there is no
64-bit immediate compare instruction.

And once we don't test against TASK_SIZE, the need for UNTAG_ADDR just
goes away, so now LAM is better too.

In other words, we could actually improve on our current code _and_
simplify the LAM situation. Win-win.

Anyway, I do not hate the version of the patch you posted, but I do
think that the win-win of just making LAM not _have_ this issue in the
first place might be the preferable one.

The one thing that that "shift by 63 and bitwise or" trick does
require is that the _ASM_EXTABLE_UA() thing for getuser/putuser would
have to have an extra annotation to shut up the

        WARN_ONCE(trapnr == X86_TRAP_GP, "General protection fault in
user access. Non-canonical address?");

in ex_handler_uaccess() for the GP trap that users can now cause by
giving a non-canonical address with the high bit clear. So we'd
probably just want a new EX_TYPE_* for these cases, but that still
looks fairly straightforward.

Hmm?

              Linus
David Laight Jan. 2, 2023, 1:55 p.m. UTC | #4
From: Linus Torvalds
> Sent: 31 December 2022 00:42
> 
...
> And on 64-bit, we really only need to check the high bit.
> 
> In fact, we don't even want to *check* it, because then we need to do
> that disgusting array_index_mask_nospec thing to mask the bits for it,
> so it would be even better to use purely arithmetic with no
> conditionals anywhere.
> 
> And that's exactly what we could do on x86-64:
> 
>         movq %rdx,%rax
>         sarq $63,%rax
>         orq %rax,%rdx
> 
> would actually be noticeably better than what we do now for
> TASK_SIZE checking _and_ for the array index masking (for putuser.S,
> we'd use %rbx instead of %rax in that sequence).
...
> It would just turn all kernel addresses into all ones, which is then
> guaranteed to fault. So no need for any conditional that never
> triggers in real life anyway.

Are byte loads guaranteed to fault?
I suspect the 'all ones' address can be assigned to I/O.
So get/put_user for a byte probably needs a 'js' test after the 'orq'.
(I don't think you need to worry about a speculative load from an
uncached address.)

...
> And once we don't test against TASK_SIZE, the need for UNTAG_ADDR just
> goes away, so now LAM is better too.
> 
> In other words, we could actually improve on our current code _and_
> simplify the LAM situation. Win-win.

Presumably the fault handler already has the code to untag addresses.

It has to be said that I don't really see why tagging addresses is a
significant benefit unless the hardware checks that the PTE/TLB is
also set with the correct tag.
All it seems to me that it does is make more 'random addresses' valid.

Clearly interpreters can set and check the high address bits, but they
can also mask them after the checks (or use xor to flip the bits and
let the cpu fault on errors).

	David

Linus Torvalds Jan. 2, 2023, 7:05 p.m. UTC | #5
On Mon, Jan 2, 2023 at 5:55 AM David Laight <David.Laight@aculab.com> wrote:
>
> > It would just turn all kernel addresses into all ones, which is then
> > guaranteed to fault. So no need for any conditional that never
> > triggers in real life anyway.
>
> Are byte loads guaranteed to fault?

Yeah, we don't map the highest address on x86-64. And we don't want to
in the future either, because of how our error pointers work (ie if
somebody misses an "IS_ERR()" check and uses an error pointer as a
pointer, we want that to fault, rather than do random things).

It's not a hard requirement architecturally (either hardware or
software), and iirc on 32-bit we used to use the last virtual page for
something, so maybe I'm missing some odd use on 64-bit too, but
accessing the top-of-virtual address space on x86-64 should always
cause a clear fault afaik.

A byte access would always be a page fault, while a wrapping access
might trigger a GP fault first (probably not - on 32-bit it would be a
segment size violation, on 64-bit we've left those bad old days behind
and I don't think wrapping matters either)

> Presumably the fault handler already has the code to untag addresses.

More importantly, the fault handler isn't in any critical path. By the
time you've taken a page fault, the extra instructions to mask off any
bits are entirely irrelevant.

> It has to be said that I don't really see why tagging addresses is a
> significant benefit unless the hardware checks that the PTE/TLB is
> also set with the correct tag.

You can certainly pair it with hardware support for checking the tag,
but even without it, it can be a useful acceleration for doing
software pointer tag checking without having to always add the extra
code (and extra register pressure) to mask things off manually to
dereference it.
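
A rough userspace sketch of that kind of software tagging, assuming LAM has
been enabled for the process (the arch_prctl() option name and value follow
the rest of this series but should be treated as assumptions here, as should
the whole example):

	#include <stdint.h>
	#include <stdio.h>
	#include <stdlib.h>
	#include <sys/syscall.h>
	#include <unistd.h>

	#ifndef ARCH_ENABLE_TAGGED_ADDR
	#define ARCH_ENABLE_TAGGED_ADDR	0x4002	/* assumed; may differ */
	#endif

	#define TAG_SHIFT	57	/* LAM_U57: the tag lives in bits 62:57 */

	static void *tag_ptr(void *p, uint64_t tag)
	{
		return (void *)((uint64_t)p | (tag << TAG_SHIFT));
	}

	static uint64_t ptr_tag(const void *p)
	{
		return ((uint64_t)p >> TAG_SHIFT) & 0x3f;
	}

	int main(void)
	{
		/* Ask for 6 tag bits, i.e. LAM_U57; bail out if the kernel/CPU refuses. */
		if (syscall(SYS_arch_prctl, ARCH_ENABLE_TAGGED_ADDR, 6))
			return 1;

		int *obj = malloc(sizeof(*obj));
		if (!obj)
			return 1;

		int *tagged = tag_ptr(obj, 0x2a);

		*tagged = 42;	/* no manual masking: the hardware ignores the tag bits */
		printf("tag=%#llx value=%d\n", (unsigned long long)ptr_tag(tagged), *tagged);

		free(obj);
		return 0;
	}

If the prctl succeeds, loads and stores through 'tagged' hit the same memory
as 'obj', and the metadata rides along in the pointer for free.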

And traditionally, in various virtual machine environments, it's been
used for hiding information about what the pointer _points_ to,
without having to use up extra memory for some kind of type lookup.
Old old LISP being the traditional case (not because of some "top byte
ignore" flag, but simply because the address space was smaller than
the word size). People did the same on the original m68k - and for the
exact same reason.

Of course, on m68k it was a horrible mistake that people still
remember ("You're telling me 24 bits wasn't enough after all?") but
it's a new day, and 64 bits is a _lot_ more than 32 bits.

The new world order of "56 bits is actually enough" is likely to
remain true in a lot of environments for the foreseeable future - the
people who already disagree tend to be special, either because they want
to use the virtual address bits for *other* things (ie sparse address
spaces etc) or because they are doing globally addressable memory on
truly large machines.

A lot of "normal" use scenarios will be fundamentally limited by
physics to "56 bits is a *lot*", and using high pointer address bits
is unquestionably a good thing.

So enforcing some hardware tag check is not always what you want,
because while that is useful for *one* particular use case (ie the
ARM64 MTE extension, and hw-tagged KASAN), and while that may be the
only use in some scenarios, other environments might use the top
bits for other pointer information.

            Linus
David Laight Jan. 3, 2023, 8:37 a.m. UTC | #6
From: Linus Torvalds
> Sent: 02 January 2023 19:05
...
> > Are byte loads guaranteed to fault?
> 
> Yeah, we don't map the highest address on x86-64. And we don't want to
> in the future either, because of how our error pointers work (ie if
> somebody misses an "IS_ERR()" check and uses an error pointer as a
> pointer, we want that to fault, rather than do random things).
> 
> It's not a hard requirement architecturally (either hardware or
> software), and iirc on 32-bit we used to use the last virtual page for
> something, so maybe I'm missing some odd use on 64-bit too, but
> accessing the top-of-virtual address space on x86-64 should always
> cause a clear fault afaik.
> 
> A byte access would always be a page fault, while a wrapping access
> might trigger a GP fault first (probably not - on 32-bit it would be a
> segment size violation, on 64-bit we've left those bad old days behind
> and I don't think wrapping matters either)

For some reason I was thinking you were relying on the wrapping access.

The other check is access_ok() for longer copies.
If you make the assumption that the copy is 'reasonably sequential'
then the length can be ignored provided that the bottom of 'kernel
space' is unmapped.
For x86-64, treating -ve values as kernel, that is pretty much true.
32-bit would need at least one unmapped page between user and kernel.
I suspect the kernel is loaded at 0xc0000000 making that difficult.

	David

Kirill A. Shutemov Jan. 7, 2023, 9:10 a.m. UTC | #7
On Fri, Dec 30, 2022 at 04:42:05PM -0800, Linus Torvalds wrote:
> The one thing that that "shift by 63 and bitwise or" trick does
> require is that the _ASM_EXTABLE_UA() thing for getuser/putuser would
> have to have an extra annotation to shut up the
> 
>         WARN_ONCE(trapnr == X86_TRAP_GP, "General protection fault in
> user access. Non-canonical address?");
> 
> in ex_handler_uaccess() for the GP trap that users can now cause by
> giving a non-canonical address with the high bit clear. So we'd
> probably just want a new EX_TYPE_* for these cases, but that still
> looks fairly straightforward.

Plain _ASM_EXTABLE() seems to do the trick.

> Hmm?

Here's what I've come up with:

diff --git a/arch/x86/lib/getuser.S b/arch/x86/lib/getuser.S
index b70d98d79a9d..3e69e3727769 100644
--- a/arch/x86/lib/getuser.S
+++ b/arch/x86/lib/getuser.S
@@ -37,22 +37,22 @@
 
 #define ASM_BARRIER_NOSPEC ALTERNATIVE "", "lfence", X86_FEATURE_LFENCE_RDTSC
 
-#ifdef CONFIG_X86_5LEVEL
-#define LOAD_TASK_SIZE_MINUS_N(n) \
-	ALTERNATIVE __stringify(mov $((1 << 47) - 4096 - (n)),%rdx), \
-		    __stringify(mov $((1 << 56) - 4096 - (n)),%rdx), X86_FEATURE_LA57
-#else
-#define LOAD_TASK_SIZE_MINUS_N(n) \
-	mov $(TASK_SIZE_MAX - (n)),%_ASM_DX
-#endif
+.macro check_range size:req
+.if IS_ENABLED(CONFIG_X86_64)
+	mov %rax, %rdx
+	sar $63, %rdx
+	or %rdx, %rax
+.else
+	cmp $TASK_SIZE_MAX-\size+1, %eax
+	jae .Lbad_get_user
+	sbb %edx, %edx		/* array_index_mask_nospec() */
+	and %edx, %eax
+.endif
+.endm
 
 	.text
 SYM_FUNC_START(__get_user_1)
-	LOAD_TASK_SIZE_MINUS_N(0)
-	cmp %_ASM_DX,%_ASM_AX
-	jae bad_get_user
-	sbb %_ASM_DX, %_ASM_DX		/* array_index_mask_nospec() */
-	and %_ASM_DX, %_ASM_AX
+	check_range size=1
 	ASM_STAC
 1:	movzbl (%_ASM_AX),%edx
 	xor %eax,%eax
@@ -62,11 +62,7 @@ SYM_FUNC_END(__get_user_1)
 EXPORT_SYMBOL(__get_user_1)
 
 SYM_FUNC_START(__get_user_2)
-	LOAD_TASK_SIZE_MINUS_N(1)
-	cmp %_ASM_DX,%_ASM_AX
-	jae bad_get_user
-	sbb %_ASM_DX, %_ASM_DX		/* array_index_mask_nospec() */
-	and %_ASM_DX, %_ASM_AX
+	check_range size=2
 	ASM_STAC
 2:	movzwl (%_ASM_AX),%edx
 	xor %eax,%eax
@@ -76,11 +72,7 @@ SYM_FUNC_END(__get_user_2)
 EXPORT_SYMBOL(__get_user_2)
 
 SYM_FUNC_START(__get_user_4)
-	LOAD_TASK_SIZE_MINUS_N(3)
-	cmp %_ASM_DX,%_ASM_AX
-	jae bad_get_user
-	sbb %_ASM_DX, %_ASM_DX		/* array_index_mask_nospec() */
-	and %_ASM_DX, %_ASM_AX
+	check_range size=4
 	ASM_STAC
 3:	movl (%_ASM_AX),%edx
 	xor %eax,%eax
@@ -90,30 +82,17 @@ SYM_FUNC_END(__get_user_4)
 EXPORT_SYMBOL(__get_user_4)
 
 SYM_FUNC_START(__get_user_8)
-#ifdef CONFIG_X86_64
-	LOAD_TASK_SIZE_MINUS_N(7)
-	cmp %_ASM_DX,%_ASM_AX
-	jae bad_get_user
-	sbb %_ASM_DX, %_ASM_DX		/* array_index_mask_nospec() */
-	and %_ASM_DX, %_ASM_AX
+	check_range size=8
 	ASM_STAC
+#ifdef CONFIG_X86_64
 4:	movq (%_ASM_AX),%rdx
-	xor %eax,%eax
-	ASM_CLAC
-	RET
 #else
-	LOAD_TASK_SIZE_MINUS_N(7)
-	cmp %_ASM_DX,%_ASM_AX
-	jae bad_get_user_8
-	sbb %_ASM_DX, %_ASM_DX		/* array_index_mask_nospec() */
-	and %_ASM_DX, %_ASM_AX
-	ASM_STAC
 4:	movl (%_ASM_AX),%edx
 5:	movl 4(%_ASM_AX),%ecx
+#endif
 	xor %eax,%eax
 	ASM_CLAC
 	RET
-#endif
 SYM_FUNC_END(__get_user_8)
 EXPORT_SYMBOL(__get_user_8)
 
@@ -166,7 +145,7 @@ EXPORT_SYMBOL(__get_user_nocheck_8)
 
 SYM_CODE_START_LOCAL(.Lbad_get_user_clac)
 	ASM_CLAC
-bad_get_user:
+.Lbad_get_user:
 	xor %edx,%edx
 	mov $(-EFAULT),%_ASM_AX
 	RET
@@ -184,23 +163,23 @@ SYM_CODE_END(.Lbad_get_user_8_clac)
 #endif
 
 /* get_user */
-	_ASM_EXTABLE_UA(1b, .Lbad_get_user_clac)
-	_ASM_EXTABLE_UA(2b, .Lbad_get_user_clac)
-	_ASM_EXTABLE_UA(3b, .Lbad_get_user_clac)
+	_ASM_EXTABLE(1b, .Lbad_get_user_clac)
+	_ASM_EXTABLE(2b, .Lbad_get_user_clac)
+	_ASM_EXTABLE(3b, .Lbad_get_user_clac)
 #ifdef CONFIG_X86_64
-	_ASM_EXTABLE_UA(4b, .Lbad_get_user_clac)
+	_ASM_EXTABLE(4b, .Lbad_get_user_clac)
 #else
-	_ASM_EXTABLE_UA(4b, .Lbad_get_user_8_clac)
-	_ASM_EXTABLE_UA(5b, .Lbad_get_user_8_clac)
+	_ASM_EXTABLE(4b, .Lbad_get_user_8_clac)
+	_ASM_EXTABLE(5b, .Lbad_get_user_8_clac)
 #endif
 
 /* __get_user */
-	_ASM_EXTABLE_UA(6b, .Lbad_get_user_clac)
-	_ASM_EXTABLE_UA(7b, .Lbad_get_user_clac)
-	_ASM_EXTABLE_UA(8b, .Lbad_get_user_clac)
+	_ASM_EXTABLE(6b, .Lbad_get_user_clac)
+	_ASM_EXTABLE(7b, .Lbad_get_user_clac)
+	_ASM_EXTABLE(8b, .Lbad_get_user_clac)
 #ifdef CONFIG_X86_64
-	_ASM_EXTABLE_UA(9b, .Lbad_get_user_clac)
+	_ASM_EXTABLE(9b, .Lbad_get_user_clac)
 #else
-	_ASM_EXTABLE_UA(9b, .Lbad_get_user_8_clac)
-	_ASM_EXTABLE_UA(10b, .Lbad_get_user_8_clac)
+	_ASM_EXTABLE(9b, .Lbad_get_user_8_clac)
+	_ASM_EXTABLE(10b, .Lbad_get_user_8_clac)
 #endif
diff --git a/arch/x86/lib/putuser.S b/arch/x86/lib/putuser.S
index 32125224fcca..0ec57997a764 100644
--- a/arch/x86/lib/putuser.S
+++ b/arch/x86/lib/putuser.S
@@ -33,20 +33,20 @@
  * as they get called from within inline assembly.
  */
 
-#ifdef CONFIG_X86_5LEVEL
-#define LOAD_TASK_SIZE_MINUS_N(n) \
-	ALTERNATIVE __stringify(mov $((1 << 47) - 4096 - (n)),%rbx), \
-		    __stringify(mov $((1 << 56) - 4096 - (n)),%rbx), X86_FEATURE_LA57
-#else
-#define LOAD_TASK_SIZE_MINUS_N(n) \
-	mov $(TASK_SIZE_MAX - (n)),%_ASM_BX
-#endif
+.macro check_range size:req
+.if IS_ENABLED(CONFIG_X86_64)
+	movq %rcx, %rbx
+	sarq $63, %rbx
+	orq %rbx, %rcx
+.else
+	cmp $TASK_SIZE_MAX-\size+1, %ecx
+	jae .Lbad_put_user
+.endif
+.endm
 
 .text
 SYM_FUNC_START(__put_user_1)
-	LOAD_TASK_SIZE_MINUS_N(0)
-	cmp %_ASM_BX,%_ASM_CX
-	jae .Lbad_put_user
+	check_range size=1
 	ASM_STAC
 1:	movb %al,(%_ASM_CX)
 	xor %ecx,%ecx
@@ -66,9 +66,7 @@ SYM_FUNC_END(__put_user_nocheck_1)
 EXPORT_SYMBOL(__put_user_nocheck_1)
 
 SYM_FUNC_START(__put_user_2)
-	LOAD_TASK_SIZE_MINUS_N(1)
-	cmp %_ASM_BX,%_ASM_CX
-	jae .Lbad_put_user
+	check_range size=2
 	ASM_STAC
 3:	movw %ax,(%_ASM_CX)
 	xor %ecx,%ecx
@@ -88,9 +86,7 @@ SYM_FUNC_END(__put_user_nocheck_2)
 EXPORT_SYMBOL(__put_user_nocheck_2)
 
 SYM_FUNC_START(__put_user_4)
-	LOAD_TASK_SIZE_MINUS_N(3)
-	cmp %_ASM_BX,%_ASM_CX
-	jae .Lbad_put_user
+	check_range size=4
 	ASM_STAC
 5:	movl %eax,(%_ASM_CX)
 	xor %ecx,%ecx
@@ -110,9 +106,7 @@ SYM_FUNC_END(__put_user_nocheck_4)
 EXPORT_SYMBOL(__put_user_nocheck_4)
 
 SYM_FUNC_START(__put_user_8)
-	LOAD_TASK_SIZE_MINUS_N(7)
-	cmp %_ASM_BX,%_ASM_CX
-	jae .Lbad_put_user
+	check_range size=8
 	ASM_STAC
 7:	mov %_ASM_AX,(%_ASM_CX)
 #ifdef CONFIG_X86_32
@@ -144,15 +138,15 @@ SYM_CODE_START_LOCAL(.Lbad_put_user_clac)
 	RET
 SYM_CODE_END(.Lbad_put_user_clac)
 
-	_ASM_EXTABLE_UA(1b, .Lbad_put_user_clac)
-	_ASM_EXTABLE_UA(2b, .Lbad_put_user_clac)
-	_ASM_EXTABLE_UA(3b, .Lbad_put_user_clac)
-	_ASM_EXTABLE_UA(4b, .Lbad_put_user_clac)
-	_ASM_EXTABLE_UA(5b, .Lbad_put_user_clac)
-	_ASM_EXTABLE_UA(6b, .Lbad_put_user_clac)
-	_ASM_EXTABLE_UA(7b, .Lbad_put_user_clac)
-	_ASM_EXTABLE_UA(9b, .Lbad_put_user_clac)
+	_ASM_EXTABLE(1b, .Lbad_put_user_clac)
+	_ASM_EXTABLE(2b, .Lbad_put_user_clac)
+	_ASM_EXTABLE(3b, .Lbad_put_user_clac)
+	_ASM_EXTABLE(4b, .Lbad_put_user_clac)
+	_ASM_EXTABLE(5b, .Lbad_put_user_clac)
+	_ASM_EXTABLE(6b, .Lbad_put_user_clac)
+	_ASM_EXTABLE(7b, .Lbad_put_user_clac)
+	_ASM_EXTABLE(9b, .Lbad_put_user_clac)
 #ifdef CONFIG_X86_32
-	_ASM_EXTABLE_UA(8b, .Lbad_put_user_clac)
-	_ASM_EXTABLE_UA(10b, .Lbad_put_user_clac)
+	_ASM_EXTABLE(8b, .Lbad_put_user_clac)
+	_ASM_EXTABLE(10b, .Lbad_put_user_clac)
 #endif
Linus Torvalds Jan. 7, 2023, 5:28 p.m. UTC | #8
On Sat, Jan 7, 2023 at 1:10 AM Kirill A. Shutemov <kirill@shutemov.name> wrote:
>
> On Fri, Dec 30, 2022 at 04:42:05PM -0800, Linus Torvalds wrote:
> > in ex_handler_uaccess() for the GP trap that users can now cause by
> > giving a non-canonical address with the high bit clear. So we'd
> > probably just want a new EX_TYPE_* for these cases, but that still
> > looks fairly straightforward.
>
> Plain _ASM_EXTABLE() seems to do the trick.

Ack, for some reason I stupidly thought we'd have to change the
_ASM_EXTABLE_UA logic.

Thanks for setting me straight.

> Here's what I've come up with:

This looks good to me. And I like how you've used assembler macros
instead of the C preprocessor, it makes things more readable.

I'm personally so unused to asm macros that I never use them (and the
same is obviously true of Christoph who did that previous task size
thing), but I can appreciate others doing a better job at it.

So ack on this from me (I assume you tested it - hopefully even with
LAM), but maybe the x86 maintainers disagree violently?

The one possible downside is that *if* somebody passes non-valid user
addresses to get/put_user() intentionally (expecting an EFAULT), we
will now handle that much more slowly with a fault. But it would have
to be some really crazy use-case, and the normal case should be
simpler and faster.

But honestly, to me the upside is mainly "no need to worry about LAM
masking in asm code".

               Linus

Patch

diff --git a/arch/x86/include/asm/mmu.h b/arch/x86/include/asm/mmu.h
index 9a046aacad8d..ed72fcd2292d 100644
--- a/arch/x86/include/asm/mmu.h
+++ b/arch/x86/include/asm/mmu.h
@@ -43,6 +43,9 @@  typedef struct {
 
 	/* Active LAM mode:  X86_CR3_LAM_U48 or X86_CR3_LAM_U57 or 0 (disabled) */
 	unsigned long lam_cr3_mask;
+
+	/* Significant bits of the virtual address. Excludes tag bits. */
+	u64 untag_mask;
 #endif
 
 	struct mutex lock;
diff --git a/arch/x86/include/asm/mmu_context.h b/arch/x86/include/asm/mmu_context.h
index 464cca41d20a..71581cb4811b 100644
--- a/arch/x86/include/asm/mmu_context.h
+++ b/arch/x86/include/asm/mmu_context.h
@@ -100,6 +100,12 @@  static inline unsigned long mm_lam_cr3_mask(struct mm_struct *mm)
 static inline void dup_lam(struct mm_struct *oldmm, struct mm_struct *mm)
 {
 	mm->context.lam_cr3_mask = oldmm->context.lam_cr3_mask;
+	mm->context.untag_mask = oldmm->context.untag_mask;
+}
+
+static inline void mm_reset_untag_mask(struct mm_struct *mm)
+{
+	mm->context.untag_mask = -1UL;
 }
 
 #else
@@ -112,6 +118,10 @@  static inline unsigned long mm_lam_cr3_mask(struct mm_struct *mm)
 static inline void dup_lam(struct mm_struct *oldmm, struct mm_struct *mm)
 {
 }
+
+static inline void mm_reset_untag_mask(struct mm_struct *mm)
+{
+}
 #endif
 
 #define enter_lazy_tlb enter_lazy_tlb
@@ -138,6 +148,7 @@  static inline int init_new_context(struct task_struct *tsk,
 		mm->context.execute_only_pkey = -1;
 	}
 #endif
+	mm_reset_untag_mask(mm);
 	init_new_context_ldt(mm);
 	return 0;
 }
diff --git a/arch/x86/include/asm/uaccess.h b/arch/x86/include/asm/uaccess.h
index 1cc756eafa44..cbb463e9344f 100644
--- a/arch/x86/include/asm/uaccess.h
+++ b/arch/x86/include/asm/uaccess.h
@@ -7,6 +7,7 @@ 
 #include <linux/compiler.h>
 #include <linux/instrumented.h>
 #include <linux/kasan-checks.h>
+#include <linux/mm_types.h>
 #include <linux/string.h>
 #include <asm/asm.h>
 #include <asm/page.h>
@@ -21,6 +22,37 @@  static inline bool pagefault_disabled(void);
 # define WARN_ON_IN_IRQ()
 #endif
 
+#ifdef CONFIG_X86_64
+/*
+ * Mask out tag bits from the address.
+ *
+ * Magic with the 'sign' allows to untag userspace pointer without any branches
+ * while leaving kernel addresses intact.
+ */
+#define __untagged_addr(mm, addr)	({				\
+	u64 __addr = (__force u64)(addr);				\
+	s64 sign = (s64)__addr >> 63;					\
+	__addr &= READ_ONCE((mm)->context.untag_mask) | sign;		\
+	(__force __typeof__(addr))__addr;				\
+})
+
+#define untagged_addr(addr) __untagged_addr(current->mm, addr)
+
+#define untagged_addr_remote(mm, addr)	({				\
+	mmap_assert_locked(mm);						\
+	__untagged_addr(mm, addr);					\
+})
+
+#define untagged_ptr(ptr)	({					\
+	u64 __ptrval = (__force u64)(ptr);				\
+	__ptrval = untagged_addr(__ptrval);				\
+	(__force __typeof__(ptr))__ptrval;				\
+})
+#else
+#define untagged_addr(addr)	(addr)
+#define untagged_ptr(ptr)	(ptr)
+#endif
+
 /**
  * access_ok - Checks if a user space pointer is valid
  * @addr: User space pointer to start of block to check
@@ -38,10 +70,10 @@  static inline bool pagefault_disabled(void);
  * Return: true (nonzero) if the memory block may be valid, false (zero)
  * if it is definitely invalid.
  */
-#define access_ok(addr, size)					\
+#define access_ok(addr, size)						\
 ({									\
 	WARN_ON_IN_IRQ();						\
-	likely(__access_ok(addr, size));				\
+	likely(__access_ok(untagged_addr(addr), size));			\
 })
 
 #include <asm-generic/access_ok.h>
@@ -127,7 +159,11 @@  extern int __get_user_bad(void);
  * Return: zero on success, or -EFAULT on error.
  * On error, the variable @x is set to zero.
  */
-#define get_user(x,ptr) ({ might_fault(); do_get_user_call(get_user,x,ptr); })
+#define get_user(x,ptr)							\
+({									\
+	might_fault();							\
+	do_get_user_call(get_user,x,untagged_ptr(ptr));	\
+})
 
 /**
  * __get_user - Get a simple variable from user space, with less checking.
@@ -227,7 +263,10 @@  extern void __put_user_nocheck_8(void);
  *
  * Return: zero on success, or -EFAULT on error.
  */
-#define put_user(x, ptr) ({ might_fault(); do_put_user_call(put_user,x,ptr); })
+#define put_user(x, ptr) ({						\
+	might_fault();							\
+	do_put_user_call(put_user,x,untagged_ptr(ptr));			\
+})
 
 /**
  * __put_user - Write a simple value into user space, with less checking.
diff --git a/arch/x86/kernel/process.c b/arch/x86/kernel/process.c
index 40d156a31676..ef6bde1d40d8 100644
--- a/arch/x86/kernel/process.c
+++ b/arch/x86/kernel/process.c
@@ -47,6 +47,7 @@ 
 #include <asm/frame.h>
 #include <asm/unwind.h>
 #include <asm/tdx.h>
+#include <asm/mmu_context.h>
 
 #include "process.h"
 
@@ -367,6 +368,8 @@  void arch_setup_new_exec(void)
 		task_clear_spec_ssb_noexec(current);
 		speculation_ctrl_update(read_thread_flags());
 	}
+
+	mm_reset_untag_mask(current->mm);
 }
 
 #ifdef CONFIG_X86_IOPL_IOPERM