@@ -43,6 +43,9 @@ typedef struct {
/* Active LAM mode: X86_CR3_LAM_U48 or X86_CR3_LAM_U57 or 0 (disabled) */
unsigned long lam_cr3_mask;
+
+ /* Significant bits of the virtual address. Excludes tag bits. */
+ u64 untag_mask;
#endif
struct mutex lock;
@@ -100,6 +100,12 @@ static inline unsigned long mm_lam_cr3_mask(struct mm_struct *mm)
static inline void dup_lam(struct mm_struct *oldmm, struct mm_struct *mm)
{
mm->context.lam_cr3_mask = oldmm->context.lam_cr3_mask;
+ mm->context.untag_mask = oldmm->context.untag_mask;
+}
+
+static inline void mm_reset_untag_mask(struct mm_struct *mm)
+{
+ mm->context.untag_mask = -1UL;
}
#else
@@ -112,6 +118,10 @@ static inline unsigned long mm_lam_cr3_mask(struct mm_struct *mm)
static inline void dup_lam(struct mm_struct *oldmm, struct mm_struct *mm)
{
}
+
+static inline void mm_reset_untag_mask(struct mm_struct *mm)
+{
+}
#endif
#define enter_lazy_tlb enter_lazy_tlb
@@ -138,6 +148,7 @@ static inline int init_new_context(struct task_struct *tsk,
mm->context.execute_only_pkey = -1;
}
#endif
+ mm_reset_untag_mask(mm);
init_new_context_ldt(mm);
return 0;
}
@@ -7,6 +7,7 @@
#include <linux/compiler.h>
#include <linux/instrumented.h>
#include <linux/kasan-checks.h>
+#include <linux/mm_types.h>
#include <linux/string.h>
#include <asm/asm.h>
#include <asm/page.h>
@@ -21,6 +22,30 @@ static inline bool pagefault_disabled(void);
# define WARN_ON_IN_IRQ()
#endif
+#ifdef CONFIG_X86_64
+/*
+ * Mask out tag bits from the address.
+ *
+ * Magic with the 'sign' allows to untag userspace pointer without any branches
+ * while leaving kernel addresses intact.
+ */
+#define untagged_addr(mm, addr) ({ \
+ u64 __addr = (__force u64)(addr); \
+ s64 sign = (s64)__addr >> 63; \
+ __addr &= (mm)->context.untag_mask | sign; \
+ (__force __typeof__(addr))__addr; \
+})
+
+#define untagged_ptr(mm, ptr) ({ \
+ u64 __ptrval = (__force u64)(ptr); \
+ __ptrval = untagged_addr(mm, __ptrval); \
+ (__force __typeof__(*(ptr)) *)__ptrval; \
+})
+#else
+#define untagged_addr(mm, addr) (addr)
+#define untagged_ptr(mm, ptr) (ptr)
+#endif
+
/**
* access_ok - Checks if a user space pointer is valid
* @addr: User space pointer to start of block to check
@@ -41,7 +66,7 @@ static inline bool pagefault_disabled(void);
#define access_ok(addr, size) \
({ \
WARN_ON_IN_IRQ(); \
- likely(__access_ok(addr, size)); \
+ likely(__access_ok(untagged_addr(current->mm, addr), size)); \
})
#include <asm-generic/access_ok.h>
@@ -127,7 +152,13 @@ extern int __get_user_bad(void);
* Return: zero on success, or -EFAULT on error.
* On error, the variable @x is set to zero.
*/
-#define get_user(x,ptr) ({ might_fault(); do_get_user_call(get_user,x,ptr); })
+#define get_user(x,ptr) \
+({ \
+ __typeof__(*(ptr)) __user *__ptr_clean; \
+ __ptr_clean = untagged_ptr(current->mm, ptr); \
+ might_fault(); \
+ do_get_user_call(get_user,x,__ptr_clean); \
+})
/**
* __get_user - Get a simple variable from user space, with less checking.
@@ -227,7 +258,12 @@ extern void __put_user_nocheck_8(void);
*
* Return: zero on success, or -EFAULT on error.
*/
-#define put_user(x, ptr) ({ might_fault(); do_put_user_call(put_user,x,ptr); })
+#define put_user(x, ptr) ({ \
+ __typeof__(*(ptr)) __user *__ptr_clean; \
+ __ptr_clean = untagged_ptr(current->mm, ptr); \
+ might_fault(); \
+ do_put_user_call(put_user,x,__ptr_clean); \
+})
/**
* __put_user - Write a simple value into user space, with less checking.
@@ -47,6 +47,7 @@
#include <asm/frame.h>
#include <asm/unwind.h>
#include <asm/tdx.h>
+#include <asm/mmu_context.h>
#include "process.h"
@@ -367,6 +368,8 @@ void arch_setup_new_exec(void)
task_clear_spec_ssb_noexec(current);
speculation_ctrl_update(read_thread_flags());
}
+
+ mm_reset_untag_mask(current->mm);
}
#ifdef CONFIG_X86_IOPL_IOPERM