[PATCHv3,5/8] x86/uaccess: Provide untagged_addr() and remove tags before address check

Message ID 20220610143527.22974-6-kirill.shutemov@linux.intel.com
State New
Series Linear Address Masking enabling

Commit Message

Kirill A. Shutemov June 10, 2022, 2:35 p.m. UTC
untagged_addr() is a helper used by the core-mm to strip tag bits and
bring the address to its canonical shape. It only handles userspace
addresses. The untagging mask is stored in mmu_context and will be set
on enabling LAM for the process.

The tags must not be included in the check of whether it is okay to
access the userspace address.

Strip tags in access_ok().

get_user() and put_user() don't use access_ok(), but check access
against TASK_SIZE directly in assembly. Strip tags before calling into
the assembly helper.

Signed-off-by: Kirill A. Shutemov <kirill.shutemov@linux.intel.com>
---
 arch/x86/include/asm/mmu.h         |  1 +
 arch/x86/include/asm/mmu_context.h | 11 ++++++++
 arch/x86/include/asm/uaccess.h     | 44 ++++++++++++++++++++++++++++--
 arch/x86/kernel/process.c          |  3 ++
 4 files changed, 56 insertions(+), 3 deletions(-)

Comments

Edgecombe, Rick P June 13, 2022, 5:36 p.m. UTC | #1
On Fri, 2022-06-10 at 17:35 +0300, Kirill A. Shutemov wrote:
> +#ifdef CONFIG_X86_64
> +/*
> + * Mask out tag bits from the address.
> + *
> + * Magic with the 'sign' allows untagging a userspace pointer without
> + * any branches while leaving kernel addresses intact.

Trying to understand the magic part here. I guess how it works is, when
the high bit is set, it does the opposite of untagging the addresses by
setting the tag bits instead of clearing them. So:
 - For proper canonical kernel addresses (with U57) it leaves them 
   intact since the tag bits were already set.
 - For non-canonical kernel-half addresses, it fixes them up. 
   (0xeffffff000000840->0xfffffff000000840)
 - For U48 and 5 level paging, it corrupts some normal kernel 
   addresses. (0xff90ffffffffffff->0xffffffffffffffff)

I just ported this to userspace and threw some addresses at it to see
what happened, so hopefully I got that right.
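
For reference, a minimal userspace sketch of that experiment (the helper
name, mask value, and test addresses here are illustrative, not taken
from the patch):

	#include <stdio.h>
	#include <stdint.h>

	/* v3 logic: xor out the sign, apply the mask, xor the sign back */
	static uint64_t untag_v3(uint64_t addr, uint64_t untag_mask)
	{
		int64_t sign = (int64_t)addr >> 63;	/* 0 or ~0 */

		addr ^= sign;
		addr &= untag_mask;
		addr ^= sign;
		return addr;
	}

	int main(void)
	{
		/* LAM_U48-style mask: clear tag bits 62:48 */
		uint64_t mask = ~0x7fff000000000000ULL;

		/* canonical kernel address: left intact */
		printf("%016llx\n", (unsigned long long)
		       untag_v3(0xfffffff000000840ULL, mask));
		/* non-canonical kernel-half address: gets "fixed up" */
		printf("%016llx\n", (unsigned long long)
		       untag_v3(0xeffffff000000840ULL, mask));
		/* valid 5-level kernel address: corrupted by the U48 mask */
		printf("%016llx\n", (unsigned long long)
		       untag_v3(0xff90ffffffffffffULL, mask));
		return 0;
	}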

Is this special kernel address handling only needed because
copy_to_kernel_nofault(), etc call the user helpers?

> + */
> +#define untagged_addr(mm, addr)	({					\
> +	u64 __addr = (__force u64)(addr);				\
> +	s64 sign = (s64)__addr >> 63;					\
> +	__addr ^= sign;							\
> +	__addr &= (mm)->context.untag_mask;				\
> +	__addr ^= sign;							\
> +	(__force __typeof__(addr))__addr;				\
> +})
Kirill A. Shutemov June 15, 2022, 4:58 p.m. UTC | #2
On Mon, Jun 13, 2022 at 05:36:43PM +0000, Edgecombe, Rick P wrote:
> On Fri, 2022-06-10 at 17:35 +0300, Kirill A. Shutemov wrote:
> > +#ifdef CONFIG_X86_64
> > +/*
> > + * Mask out tag bits from the address.
> > + *
> > + * Magic with the 'sign' allows untagging a userspace pointer without
> > + * any branches while leaving kernel addresses intact.
> 
> Trying to understand the magic part here. I guess how it works is, when
> the high bit is set, it does the opposite of untagging the addresses by
> setting the tag bits instead of clearing them. So:
>  - For proper canonical kernel addresses (with U57) it leaves them 
>    intact since the tag bits were already set.
>  - For non-canonical kernel-half addresses, it fixes them up. 
>    (0xeffffff000000840->0xfffffff000000840)
>  - For U48 and 5 level paging, it corrupts some normal kernel 
>    addresses. (0xff90ffffffffffff->0xffffffffffffffff)
> 
> I just ported this to userspace and threw some addresses at it to see
> what happened, so hopefully I got that right.

Ouch. Thanks for noticing this. I should have caught this myself. Yes,
this implementation is broken for LAM_U48 on a 5-level machine.

What about this:

	#define untagged_addr(mm, addr)	({					\
		u64 __addr = (__force u64)(addr);				\
		s64 sign = (s64)__addr >> 63;					\
		__addr &= (mm)->context.untag_mask | sign;			\
		(__force __typeof__(addr))__addr;				\
	})

It effectively makes the mask all-ones for supervisor addresses. And it is
less magic to my eyes.
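
A userspace model of the new variant, for clarity (the helper name and
types are just illustration):

	#include <stdint.h>

	static uint64_t untag_v4(uint64_t addr, uint64_t untag_mask)
	{
		uint64_t sign = (uint64_t)((int64_t)addr >> 63);

		/* kernel half: sign == ~0, the mask becomes all-ones,
		   so the whole thing is a no-op */
		return addr & (untag_mask | sign);
	}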

The generated code also looks sane to me:

    11d0:	48 89 f8             	mov    %rdi,%rax
    11d3:	48 c1 f8 3f          	sar    $0x3f,%rax
    11d7:	48 0b 05 52 2e 00 00 	or     0x2e52(%rip),%rax        # 4030 <untag_mask>
    11de:	48 21 f8             	and    %rdi,%rax

Any comments?

> Is this special kernel address handling only needed because
> copy_to_kernel_nofault(), etc call the user helpers?

I did not have any particular use-case in mind. But just if some kernel
address gets there and bits get cleared we will have very hard to debug
bug.
Edgecombe, Rick P June 15, 2022, 7:06 p.m. UTC | #3
On Wed, 2022-06-15 at 19:58 +0300, Kirill A. Shutemov wrote:
> On Mon, Jun 13, 2022 at 05:36:43PM +0000, Edgecombe, Rick P wrote:
> > On Fri, 2022-06-10 at 17:35 +0300, Kirill A. Shutemov wrote:
> > > +#ifdef CONFIG_X86_64
> > > +/*
> > > + * Mask out tag bits from the address.
> > > + *
> > > + * Magic with the 'sign' allows untagging a userspace pointer without
> > > + * any branches while leaving kernel addresses intact.
> > 
> > Trying to understand the magic part here. I guess how it works is, when
> > the high bit is set, it does the opposite of untagging the addresses by
> > setting the tag bits instead of clearing them. So:
> >   - For proper canonical kernel addresses (with U57) it leaves them
> >     intact since the tag bits were already set.
> >   - For non-canonical kernel-half addresses, it fixes them up.
> >     (0xeffffff000000840->0xfffffff000000840)
> >   - For U48 and 5 level paging, it corrupts some normal kernel
> >     addresses. (0xff90ffffffffffff->0xffffffffffffffff)
> > 
> > I just ported this to userspace and threw some addresses at it to see
> > what happened, so hopefully I got that right.
> 
> Ouch. Thanks for noticing this. I should have caught this myself. Yes,
> this implementation is broken for LAM_U48 on a 5-level machine.
> 
> What about this:
> 
>         #define untagged_addr(mm, addr) ({                                      \
>                 u64 __addr = (__force u64)(addr);                               \
>                 s64 sign = (s64)__addr >> 63;                                   \
>                 __addr &= (mm)->context.untag_mask | sign;                      \
>                 (__force __typeof__(addr))__addr;                               \
>         })
> 
> It effectively makes the mask all-ones for supervisor addresses. And it
> is less magic to my eyes.

Yea, it seems to leave kernel-half addresses alone now, including leaving
non-canonical addresses non-canonical and 5-level addresses intact.

With the new bit math:
Reviewed-by: Rick Edgecombe <rick.p.edgecombe@intel.com>

> 
> The generated code also looks sane to me:
> 
>     11d0:       48 89 f8                mov    %rdi,%rax
>     11d3:       48 c1 f8 3f             sar    $0x3f,%rax
>     11d7:       48 0b 05 52 2e 00 00    or     0x2e52(%rip),%rax        # 4030 <untag_mask>
>     11de:       48 21 f8                and    %rdi,%rax
> 
> Any comments?
> 
> > Is this special kernel address handling only needed because
> > copy_to_kernel_nofault(), etc call the user helpers?
> 
> I did not have any particular use-case in mind. But if some kernel
> address gets there and bits get cleared, we will have a very hard to
> debug bug.

I was just thinking that if we could rearrange the code to avoid untagging
kernel addresses, we could skip this, or even VM_WARN_ON() if we see
one. Seems ok either way.
Peter Zijlstra June 16, 2022, 9:30 a.m. UTC | #4
On Mon, Jun 13, 2022 at 05:36:43PM +0000, Edgecombe, Rick P wrote:
> On Fri, 2022-06-10 at 17:35 +0300, Kirill A. Shutemov wrote:
> > +#ifdef CONFIG_X86_64
> > +/*
> > + * Mask out tag bits from the address.
> > + *
> > + * Magic with the 'sign' allows untagging a userspace pointer without
> > + * any branches while leaving kernel addresses intact.
> 
> Trying to understand the magic part here. I guess how it works is, when
> the high bit is set, it does the opposite of untagging the addresses by
> setting the tag bits instead of clearing them. So:

The magic is really rather simple to see; there are two observations:

  x ^ y ^ y == x

That is, xor is its own inverse. And secondly, xor with 1 is a bit
toggle.

So if we mask a negative value, we destroy the sign. Therefore, if we
xor with the sign-bit, we have a nop for positive numbers and a toggle
for negatives (effectively making them positive, -1, 2s complement
yada-yada) then we can mask, without fear of destroying the sign, and
then we xor again to undo whatever we did before, effectively restoring
the sign.
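
Spelled out on illustrative values, with a U57-style mask
(untag_mask == 0x81ffffffffffffff, tag bits 62:57):

  tagged user pointer, sign == 0, both xors are no-ops:
    0x7a00007f12345000 & mask == 0x0000007f12345000

  kernel pointer, sign == ~0:
    0xfffffff000000840 ^ sign == 0x0000000ffffff7bf
    0x0000000ffffff7bf & mask == 0x0000000ffffff7bf  (mask is a no-op)
    0x0000000ffffff7bf ^ sign == 0xfffffff000000840  (original restored)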

Anyway, the consequence of all this is that LAM_U48 won't work correctly
on 5-level kernels, because the mask will still destroy kernel pointers.

As such, this patch only does LAM_U57.
Peter Zijlstra June 16, 2022, 9:34 a.m. UTC | #5
On Mon, Jun 13, 2022 at 05:36:43PM +0000, Edgecombe, Rick P wrote:

> Is this special kernel address handling only needed because
> copy_to_kernel_nofault(), etc call the user helpers?

It is to make absolutely sure we don't need to go audit everything; if
code is correct without untag_pointer(), it will still be correct with it
on.

Also avoids future bugs due to being robust in general.
Peter Zijlstra June 16, 2022, 10:02 a.m. UTC | #6
On Fri, Jun 10, 2022 at 05:35:24PM +0300, Kirill A. Shutemov wrote:
> +#ifdef CONFIG_X86_64
> +/*
> + * Mask out tag bits from the address.
> + *
> + * Magic with the 'sign' allows untagging a userspace pointer without any branches
> + * while leaving kernel addresses intact.
> + */
> +#define untagged_addr(mm, addr)	({					\
> +	u64 __addr = (__force u64)(addr);				\
> +	s64 sign = (s64)__addr >> 63;					\
> +	__addr ^= sign;							\
> +	__addr &= (mm)->context.untag_mask;				\
> +	__addr ^= sign;							\
> +	(__force __typeof__(addr))__addr;				\
> +})

Can't we make that mask a constant and *always* unmask U57 irrespective
of LAM being on?
Kirill A. Shutemov June 16, 2022, 4:44 p.m. UTC | #7
On Thu, Jun 16, 2022 at 11:30:49AM +0200, Peter Zijlstra wrote:
> On Mon, Jun 13, 2022 at 05:36:43PM +0000, Edgecombe, Rick P wrote:
> > On Fri, 2022-06-10 at 17:35 +0300, Kirill A. Shutemov wrote:
> > > +#ifdef CONFIG_X86_64
> > > +/*
> > > + * Mask out tag bits from the address.
> > > + *
> > > + * Magic with the 'sign' allows untagging a userspace pointer without
> > > + * any branches while leaving kernel addresses intact.
> > 
> > Trying to understand the magic part here. I guess how it works is, when
> > the high bit is set, it does the opposite of untagging the addresses by
> > setting the tag bits instead of clearing them. So:
> 
> The magic is really rather simple to see; there are two observations:
> 
>   x ^ y ^ y == x
> 
> That is, xor is its own inverse. And secondly, xor with 1 is a bit
> toggle.
> 
> So if we mask a negative value, we destroy the sign. Therefore, if we
> xor with the sign-bit, we have a nop for positive numbers and a toggle
> for negatives (effectively making them positive, -1, 2s complement
> yada-yada) then we can mask, without fear of destroying the sign, and
> then we xor again to undo whatever we did before, effectively restoring
> the sign.
> 
> Anyway, the consequence of all this is that LAM_U48 won't work correctly
> on 5-level kernels, because the mask will still destroy kernel pointers.

Any objection against this variant (was posted in the thread):

       	#define untagged_addr(mm, addr) ({                                      \
               	u64 __addr = (__force u64)(addr);                               \
               	s64 sign = (s64)__addr >> 63;                                   \
               	__addr &= (mm)->context.untag_mask | sign;                      \
               	(__force __typeof__(addr))__addr;                               \
       	})

?

I find it easier to follow and it is LAM_U48-safe.
Kirill A. Shutemov June 16, 2022, 4:48 p.m. UTC | #8
On Thu, Jun 16, 2022 at 12:02:16PM +0200, Peter Zijlstra wrote:
> On Fri, Jun 10, 2022 at 05:35:24PM +0300, Kirill A. Shutemov wrote:
> > +#ifdef CONFIG_X86_64
> > +/*
> > + * Mask out tag bits from the address.
> > + *
> > + * Magic with the 'sign' allows untagging a userspace pointer without any branches
> > + * while leaving kernel addresses intact.
> > + */
> > +#define untagged_addr(mm, addr)	({					\
> > +	u64 __addr = (__force u64)(addr);				\
> > +	s64 sign = (s64)__addr >> 63;					\
> > +	__addr ^= sign;							\
> > +	__addr &= (mm)->context.untag_mask;				\
> > +	__addr ^= sign;							\
> > +	(__force __typeof__(addr))__addr;				\
> > +})
> 
> Can't we make that mask a constant and *always* unmask U57 irrespective
> of LAM being on?

We can do this if we give up on LAM_U48.

It would also needlessly relax the canonical check. I'm not sure it is a
good idea.
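
To illustrate the relaxation (the pointer value is made up): with LAM off
and a constant U57 mask always applied, a wild non-canonical pointer such
as

	0x0200000000001000	/* bit 57 set */

would be silently masked down to 0x0000000000001000 and pass the check,
instead of failing with -EFAULT as it does today.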
Peter Zijlstra June 17, 2022, 11:36 a.m. UTC | #9
On Thu, Jun 16, 2022 at 07:44:40PM +0300, Kirill A. Shutemov wrote:
> Any objection against this variant (was posted in the thread):
> 
>        	#define untagged_addr(mm, addr) ({                                      \
>                	u64 __addr = (__force u64)(addr);                               \
>                	s64 sign = (s64)__addr >> 63;                                   \
>                	__addr &= (mm)->context.untag_mask | sign;                      \
>                	(__force __typeof__(addr))__addr;                               \
>        	})
> 
> ?

Yeah, I suppose that should work fine.
H.J. Lu June 17, 2022, 2:22 p.m. UTC | #10
On Fri, Jun 17, 2022 at 4:36 AM Peter Zijlstra <peterz@infradead.org> wrote:
>
> On Thu, Jun 16, 2022 at 07:44:40PM +0300, Kirill A. Shutemov wrote:
> > Any objection against this variant (was posted in the thread):
> >
> >               #define untagged_addr(mm, addr) ({                                      \
> >                       u64 __addr = (__force u64)(addr);                               \
> >                       s64 sign = (s64)__addr >> 63;                                   \
> >                       __addr &= (mm)->context.untag_mask | sign;                      \
> >                       (__force __typeof__(addr))__addr;                               \
> >               })
> >
> > ?
>
> Yeah, I suppose that should work fine.

Won't the sign bit be put at the wrong place?
Peter Zijlstra June 17, 2022, 2:28 p.m. UTC | #11
On Fri, Jun 17, 2022 at 07:22:57AM -0700, H.J. Lu wrote:
> On Fri, Jun 17, 2022 at 4:36 AM Peter Zijlstra <peterz@infradead.org> wrote:
> >
> > On Thu, Jun 16, 2022 at 07:44:40PM +0300, Kirill A. Shutemov wrote:
> > > Any objection against this variant (was posted in the thread):
> > >
> > >               #define untagged_addr(mm, addr) ({                                      \
> > >                       u64 __addr = (__force u64)(addr);                               \
> > >                       s64 sign = (s64)__addr >> 63;                                   \
> > >                       __addr &= (mm)->context.untag_mask | sign;                      \
> > >                       (__force __typeof__(addr))__addr;                               \
> > >               })
> > >
> > > ?
> >
> > Yeah, I suppose that should work fine.
> 
> Won't the sign bit be put at the wrong place?

sign is either 0 or ~0; by or-ing ~0 into the mask, the masking becomes a
no-op.
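
For illustration (values made up):

	s64 sign = (s64)addr >> 63;	/* arithmetic shift copies bit 63 */

	/* addr = 0x00007f0000001000 -> sign = 0,  mask | sign == mask */
	/* addr = 0xffff800000001000 -> sign = ~0, mask | sign == ~0UL */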
Andy Lutomirski June 28, 2022, 11:40 p.m. UTC | #12
On 6/10/22 07:35, Kirill A. Shutemov wrote:
> untagged_addr() is a helper used by the core-mm to strip tag bits and
> bring the address to its canonical shape. It only handles userspace
> addresses. The untagging mask is stored in mmu_context and will be set
> on enabling LAM for the process.
> 
> The tags must not be included in the check of whether it is okay to
> access the userspace address.
> 
> Strip tags in access_ok().

What is the intended behavior for an access that spans a tag boundary?

Also, at the risk of a potentially silly question, why do we need to 
strip the tag before access_ok()?  With LAM, every valid tagged user 
address is also a valid untagged address, right?  (There is no 
particular need to enforce the actual value of TASK_SIZE_MAX on 
*access*, just on mmap.)

IOW, wouldn't it be sufficient, and probably better than what we have 
now, to just check that the entire range has the high bit clear?

--Andy
Kirill A. Shutemov June 29, 2022, 12:42 a.m. UTC | #13
On Tue, Jun 28, 2022 at 04:40:48PM -0700, Andy Lutomirski wrote:
> On 6/10/22 07:35, Kirill A. Shutemov wrote:
> > untagged_addr() is a helper used by the core-mm to strip tag bits and
> > bring the address to its canonical shape. It only handles userspace
> > addresses. The untagging mask is stored in mmu_context and will be set
> > on enabling LAM for the process.
> > 
> > The tags must not be included in the check of whether it is okay to
> > access the userspace address.
> > 
> > Strip tags in access_ok().
> 
> What is the intended behavior for an access that spans a tag boundary?

You mean if 'addr' passed to access_ok() is below 47- or 56-bit but 'addr'
+ 'size' overflows into tags? It is not a valid access and the current
implementation works correctly.
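
For example (4-level paging, illustrative numbers, TASK_SIZE_MAX just
below 1 << 47):

	addr = 0x00007ffffffff000, size = 0x2000
	untagged_addr() leaves addr as-is (no tag bits set), and
	addr + size = 0x0000800000001000 > TASK_SIZE_MAX

so access_ok() rejects the range before it could reach any tag bits.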

> Also, at the risk of a potentially silly question, why do we need to strip
> the tag before access_ok()?  With LAM, every valid tagged user address is
> also a valid untagged address, right?  (There is no particular need to
> enforce the actual value of TASK_SIZE_MAX on *access*, just on mmap.)
> 
> IOW, wouldn't it be sufficient, and probably better than what we have now,
> to just check that the entire range has the high bit clear?

No. We do things to addresses on the kernel side beyond dereferencing
them. Comparing addresses has to make sense, find_vma() has to find the
relevant mapping, and so on.
Andy Lutomirski June 30, 2022, 2:38 a.m. UTC | #14
On Tue, Jun 28, 2022, at 5:42 PM, Kirill A. Shutemov wrote:
> On Tue, Jun 28, 2022 at 04:40:48PM -0700, Andy Lutomirski wrote:
>> On 6/10/22 07:35, Kirill A. Shutemov wrote:
>> > untagged_addr() is a helper used by the core-mm to strip tag bits and
>> > bring the address to its canonical shape. It only handles userspace
>> > addresses. The untagging mask is stored in mmu_context and will be set
>> > on enabling LAM for the process.
>> > 
>> > The tags must not be included in the check of whether it is okay to
>> > access the userspace address.
>> > 
>> > Strip tags in access_ok().
>> 
>> What is the intended behavior for an access that spans a tag boundary?
>
> You mean if 'addr' passed to access_ok() is below 47- or 56-bit but 'addr'
> + 'size' overflows into tags? It is not a valid access and the current
> implementation works correctly.
>
>> Also, at the risk of a potentially silly question, why do we need to strip
>> the tag before access_ok()?  With LAM, every valid tagged user address is
>> also a valid untagged address, right?  (There is no particular need to
>> enforce the actual value of TASK_SIZE_MAX on *access*, just on mmap.)
>> 
>> IOW, wouldn't it be sufficient, and probably better than what we have now,
>> to just check that the entire range has the high bit clear?
>
> No. We do things to addresses on the kernel side beyond dereferencing
> them. Comparing addresses has to make sense, find_vma() has to find the
> relevant mapping, and so on.

I think you’re misunderstanding me. Of course things like find_vma() need to untag the address. (But things like munmap, IMO, ought not accept tagged addresses.)

But I’m not talking about that at all. I’m asking why we need to untag an address for access_ok.  In the bad old days, access_ok checked that a range was below a *variable* cutoff. But set_fs() is gone, and I don’t think this is needed any more.

With some off-the-cuff bit hackery, I think it ought to be sufficient to calculate addr+len and fail if the overflow or sign bits get set. If LAM is off, we could calculate addr | len and fail if either of the top two bits is set, but LAM may not let us get away with that.  The general point being that, on x86 (as long as we ignore AMD’s LAM-like feature) an address is a user address if the top bit is clear. Whether that address is canonical or not or will translate or not is a separate issue. (And making this change would require allowing uaccess to #GP, which requires some care.)
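
A userspace sketch of the proposed check (illustrative only, not kernel
code): accept any range that stays in the lower half of the address
space, canonical or not.

	#include <stdint.h>

	static int user_range_ok(uint64_t addr, uint64_t len)
	{
		uint64_t end;

		if (__builtin_add_overflow(addr, len, &end))
			return 0;		/* wrap-around */
		return !((addr | end) >> 63);	/* both endpoints below bit 63 */
	}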

What do you think?

>
> -- 
>  Kirill A. Shutemov
Kirill A. Shutemov July 5, 2022, 12:13 a.m. UTC | #15
On Wed, Jun 29, 2022 at 07:38:20PM -0700, Andy Lutomirski wrote:
> 
> 
> On Tue, Jun 28, 2022, at 5:42 PM, Kirill A. Shutemov wrote:
> > On Tue, Jun 28, 2022 at 04:40:48PM -0700, Andy Lutomirski wrote:
> >> On 6/10/22 07:35, Kirill A. Shutemov wrote:
> >> > untagged_addr() is a helper used by the core-mm to strip tag bits and
> >> > bring the address to its canonical shape. It only handles userspace
> >> > addresses. The untagging mask is stored in mmu_context and will be set
> >> > on enabling LAM for the process.
> >> > 
> >> > The tags must not be included in the check of whether it is okay to
> >> > access the userspace address.
> >> > 
> >> > Strip tags in access_ok().
> >> 
> >> What is the intended behavior for an access that spans a tag boundary?
> >
> > You mean if 'addr' passed to access_ok() is below 47- or 56-bit but 'addr'
> > + 'size' overflows into tags? It is not a valid access and the current
> > implementation works correctly.
> >
> >> Also, at the risk of a potentially silly question, why do we need to strip
> >> the tag before access_ok()?  With LAM, every valid tagged user address is
> >> also a valid untagged address, right?  (There is no particular need to
> >> enforce the actual value of TASK_SIZE_MAX on *access*, just on mmap.)
> >> 
> >> IOW, wouldn't it be sufficient, and probably better than what we have now,
> >> to just check that the entire range has the high bit clear?
> >
> > No. We do things to addresses on the kernel side beyond dereferencing
> > them. Comparing addresses has to make sense, find_vma() has to find the
> > relevant mapping, and so on.
> 
> I think you’re misunderstanding me. Of course things like find_vma()
> need to untag the address. (But things like munmap, IMO, ought not
> accept tagged addresses.)
> 
> But I’m not talking about that at all. I’m asking why we need to untag
> an address for access_ok.  In the bad old days, access_ok checked that a
> range was below a *variable* cutoff. But set_fs() is gone, and I don’t
> think this is needed any more.
> 
> With some off-the-cuff bit hackery, I think it ought to be sufficient to
> calculate addr+len and fail if the overflow or sign bits get set. If LAM
> is off, we could calculate addr | len and fail if either of the top two
> bits is set, but LAM may not let us get away with that.  The general
> point being that, on x86 (as long as we ignore AMD’s LAM-like feature)
> an address is a user address if the top bit is clear. Whether that
> address is canonical or not or will translate or not is a separate
> issue. (And making this change would require allowing uaccess to #GP,
> which requires some care.)
> 
> What do you think?

I think it can work. Below is my first take on this. It is raw and
requires more work.

You talk about access_ok(), but I also checked what has to be done in the
get_user()/put_user() path. Not sure it is a good idea (but it seems to
work).

The basic idea is to OR the start and end of the range and check that the
MSB is clear.

I reworked the array_index_mask_nospec() thing to make the address
all-ones if it is outside of the user range. It should stop speculation
as well as making it zero does now.
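
In C-like pseudocode, the reworked fixup sequence in getuser.S does
roughly this for an 8-byte access (illustrative, not actual kernel code):

	end = addr + 7;			/* last byte of the access */
	if ((addr | end) >> 63)		/* shl $1 moves the MSB into CF */
		goto bad_get_user;	/* architectural rejection */
	/* speculation fixup (sbb; or): under a mispredicted branch CF is
	   still set, the sbb materializes ~0, and the or turns the address
	   into all-ones instead of a usable kernel pointer */
	addr |= -(long)((addr | end) >> 63);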

The patch as it is breaks 32-bit. It has to be handled separately.

Any comments?

diff --git a/arch/x86/include/asm/uaccess.h b/arch/x86/include/asm/uaccess.h
index 803241dfc473..53c6f86b036f 100644
--- a/arch/x86/include/asm/uaccess.h
+++ b/arch/x86/include/asm/uaccess.h
@@ -45,6 +45,12 @@ static inline bool pagefault_disabled(void);
 #define untagged_ptr(mm, ptr)	(ptr)
 #endif
 
+#define __access_ok(ptr, size)						\
+({									\
+	unsigned long addr = (unsigned long)(ptr);			\
+	!((addr | (size) | addr + (size)) >> (BITS_PER_LONG - 1));	\
+})
+
 /**
  * access_ok - Checks if a user space pointer is valid
  * @addr: User space pointer to start of block to check
@@ -62,10 +68,10 @@ static inline bool pagefault_disabled(void);
  * Return: true (nonzero) if the memory block may be valid, false (zero)
  * if it is definitely invalid.
  */
-#define access_ok(addr, size)					\
+#define access_ok(addr, size)						\
 ({									\
 	WARN_ON_IN_IRQ();						\
-	likely(__access_ok(untagged_addr(current->mm, addr), size));	\
+	likely(__access_ok(addr, size));				\
 })
 
 #include <asm-generic/access_ok.h>
@@ -150,13 +156,7 @@ extern int __get_user_bad(void);
  * Return: zero on success, or -EFAULT on error.
  * On error, the variable @x is set to zero.
  */
-#define get_user(x,ptr)							\
-({									\
-	__typeof__(*(ptr)) __user *__ptr_clean;				\
-	__ptr_clean = untagged_ptr(current->mm, ptr);			\
-	might_fault();							\
-	do_get_user_call(get_user,x,__ptr_clean);			\
-})
+#define get_user(x,ptr) ({ might_fault(); do_get_user_call(get_user,x,ptr); })
 
 /**
  * __get_user - Get a simple variable from user space, with less checking.
@@ -253,12 +253,7 @@ extern void __put_user_nocheck_8(void);
  *
  * Return: zero on success, or -EFAULT on error.
  */
-#define put_user(x, ptr) ({						\
-	__typeof__(*(ptr)) __user *__ptr_clean;				\
-	__ptr_clean = untagged_ptr(current->mm, ptr);			\
-	might_fault();							\
-	do_put_user_call(put_user,x,__ptr_clean);			\
-})
+#define put_user(x, ptr) ({ might_fault(); do_put_user_call(put_user,x,ptr); })
 
 /**
  * __put_user - Write a simple value into user space, with less checking.
diff --git a/arch/x86/lib/getuser.S b/arch/x86/lib/getuser.S
index b70d98d79a9d..e526ae6ff844 100644
--- a/arch/x86/lib/getuser.S
+++ b/arch/x86/lib/getuser.S
@@ -37,22 +37,13 @@
 
 #define ASM_BARRIER_NOSPEC ALTERNATIVE "", "lfence", X86_FEATURE_LFENCE_RDTSC
 
-#ifdef CONFIG_X86_5LEVEL
-#define LOAD_TASK_SIZE_MINUS_N(n) \
-	ALTERNATIVE __stringify(mov $((1 << 47) - 4096 - (n)),%rdx), \
-		    __stringify(mov $((1 << 56) - 4096 - (n)),%rdx), X86_FEATURE_LA57
-#else
-#define LOAD_TASK_SIZE_MINUS_N(n) \
-	mov $(TASK_SIZE_MAX - (n)),%_ASM_DX
-#endif
-
 	.text
 SYM_FUNC_START(__get_user_1)
-	LOAD_TASK_SIZE_MINUS_N(0)
-	cmp %_ASM_DX,%_ASM_AX
-	jae bad_get_user
-	sbb %_ASM_DX, %_ASM_DX		/* array_index_mask_nospec() */
-	and %_ASM_DX, %_ASM_AX
+	mov %_ASM_AX, %_ASM_DX
+	shl $1, %_ASM_DX
+	jb bad_get_user
+	sbb %_ASM_DX, %_ASM_DX
+	or %_ASM_DX, %_ASM_AX
 	ASM_STAC
 1:	movzbl (%_ASM_AX),%edx
 	xor %eax,%eax
@@ -62,11 +53,13 @@ SYM_FUNC_END(__get_user_1)
 EXPORT_SYMBOL(__get_user_1)
 
 SYM_FUNC_START(__get_user_2)
-	LOAD_TASK_SIZE_MINUS_N(1)
-	cmp %_ASM_DX,%_ASM_AX
-	jae bad_get_user
-	sbb %_ASM_DX, %_ASM_DX		/* array_index_mask_nospec() */
-	and %_ASM_DX, %_ASM_AX
+	mov %_ASM_AX, %_ASM_DX
+	add $1, %_ASM_DX
+	or %_ASM_AX, %_ASM_DX
+	shl $1, %_ASM_DX
+	jb bad_get_user
+	sbb %_ASM_DX, %_ASM_DX
+	or %_ASM_DX, %_ASM_AX
 	ASM_STAC
 2:	movzwl (%_ASM_AX),%edx
 	xor %eax,%eax
@@ -76,11 +69,13 @@ SYM_FUNC_END(__get_user_2)
 EXPORT_SYMBOL(__get_user_2)
 
 SYM_FUNC_START(__get_user_4)
-	LOAD_TASK_SIZE_MINUS_N(3)
-	cmp %_ASM_DX,%_ASM_AX
-	jae bad_get_user
-	sbb %_ASM_DX, %_ASM_DX		/* array_index_mask_nospec() */
-	and %_ASM_DX, %_ASM_AX
+	mov %_ASM_AX, %_ASM_DX
+	add $3, %_ASM_DX
+	or %_ASM_AX, %_ASM_DX
+	shl $1, %_ASM_DX
+	jb bad_get_user
+	sbb %_ASM_DX, %_ASM_DX
+	or %_ASM_DX, %_ASM_AX
 	ASM_STAC
 3:	movl (%_ASM_AX),%edx
 	xor %eax,%eax
@@ -91,22 +86,26 @@ EXPORT_SYMBOL(__get_user_4)
 
 SYM_FUNC_START(__get_user_8)
 #ifdef CONFIG_X86_64
-	LOAD_TASK_SIZE_MINUS_N(7)
-	cmp %_ASM_DX,%_ASM_AX
-	jae bad_get_user
-	sbb %_ASM_DX, %_ASM_DX		/* array_index_mask_nospec() */
-	and %_ASM_DX, %_ASM_AX
+	mov %_ASM_AX, %_ASM_DX
+	add $7, %_ASM_DX
+	or %_ASM_AX, %_ASM_DX
+	shl $1, %_ASM_DX
+	jb bad_get_user
+	sbb %_ASM_DX, %_ASM_DX
+	or %_ASM_DX, %_ASM_AX
 	ASM_STAC
 4:	movq (%_ASM_AX),%rdx
 	xor %eax,%eax
 	ASM_CLAC
 	RET
 #else
-	LOAD_TASK_SIZE_MINUS_N(7)
-	cmp %_ASM_DX,%_ASM_AX
-	jae bad_get_user_8
-	sbb %_ASM_DX, %_ASM_DX		/* array_index_mask_nospec() */
-	and %_ASM_DX, %_ASM_AX
+	mov %_ASM_AX, %_ASM_DX
+	add $7, %_ASM_DX
+	or %_ASM_AX, %_ASM_DX
+	shl $1, %_ASM_DX
+	jb bad_get_user
+	sbb %_ASM_DX, %_ASM_DX
+	or %_ASM_DX, %_ASM_AX
 	ASM_STAC
 4:	movl (%_ASM_AX),%edx
 5:	movl 4(%_ASM_AX),%ecx
diff --git a/arch/x86/lib/putuser.S b/arch/x86/lib/putuser.S
index b7dfd60243b7..eb7c4396cb1e 100644
--- a/arch/x86/lib/putuser.S
+++ b/arch/x86/lib/putuser.S
@@ -33,20 +33,11 @@
  * as they get called from within inline assembly.
  */
 
-#ifdef CONFIG_X86_5LEVEL
-#define LOAD_TASK_SIZE_MINUS_N(n) \
-	ALTERNATIVE __stringify(mov $((1 << 47) - 4096 - (n)),%rbx), \
-		    __stringify(mov $((1 << 56) - 4096 - (n)),%rbx), X86_FEATURE_LA57
-#else
-#define LOAD_TASK_SIZE_MINUS_N(n) \
-	mov $(TASK_SIZE_MAX - (n)),%_ASM_BX
-#endif
-
 .text
 SYM_FUNC_START(__put_user_1)
-	LOAD_TASK_SIZE_MINUS_N(0)
-	cmp %_ASM_BX,%_ASM_CX
-	jae .Lbad_put_user
+	mov %_ASM_CX, %_ASM_BX
+	shl $1, %_ASM_BX
+	jb .Lbad_put_user
 SYM_INNER_LABEL(__put_user_nocheck_1, SYM_L_GLOBAL)
 	ENDBR
 	ASM_STAC
@@ -59,9 +50,11 @@ EXPORT_SYMBOL(__put_user_1)
 EXPORT_SYMBOL(__put_user_nocheck_1)
 
 SYM_FUNC_START(__put_user_2)
-	LOAD_TASK_SIZE_MINUS_N(1)
-	cmp %_ASM_BX,%_ASM_CX
-	jae .Lbad_put_user
+	mov %_ASM_CX, %_ASM_BX
+	add $1, %_ASM_BX
+	or %_ASM_CX, %_ASM_BX
+	shl $1, %_ASM_BX
+	jb .Lbad_put_user
 SYM_INNER_LABEL(__put_user_nocheck_2, SYM_L_GLOBAL)
 	ENDBR
 	ASM_STAC
@@ -74,9 +67,11 @@ EXPORT_SYMBOL(__put_user_2)
 EXPORT_SYMBOL(__put_user_nocheck_2)
 
 SYM_FUNC_START(__put_user_4)
-	LOAD_TASK_SIZE_MINUS_N(3)
-	cmp %_ASM_BX,%_ASM_CX
-	jae .Lbad_put_user
+	mov %_ASM_CX, %_ASM_BX
+	add $3, %_ASM_BX
+	or %_ASM_CX, %_ASM_BX
+	shl $1, %_ASM_BX
+	jb .Lbad_put_user
 SYM_INNER_LABEL(__put_user_nocheck_4, SYM_L_GLOBAL)
 	ENDBR
 	ASM_STAC
@@ -89,9 +84,11 @@ EXPORT_SYMBOL(__put_user_4)
 EXPORT_SYMBOL(__put_user_nocheck_4)
 
 SYM_FUNC_START(__put_user_8)
-	LOAD_TASK_SIZE_MINUS_N(7)
-	cmp %_ASM_BX,%_ASM_CX
-	jae .Lbad_put_user
+	mov %_ASM_CX, %_ASM_BX
+	add $7, %_ASM_BX
+	or %_ASM_CX, %_ASM_BX
+	shl $1, %_ASM_BX
+	jb .Lbad_put_user
 SYM_INNER_LABEL(__put_user_nocheck_8, SYM_L_GLOBAL)
 	ENDBR
 	ASM_STAC
Patch

diff --git a/arch/x86/include/asm/mmu.h b/arch/x86/include/asm/mmu.h
index d150e92163b6..59c617fc7c6f 100644
--- a/arch/x86/include/asm/mmu.h
+++ b/arch/x86/include/asm/mmu.h
@@ -41,6 +41,7 @@  typedef struct {
 #ifdef CONFIG_X86_64
 	unsigned short flags;
 	u64 lam_cr3_mask;
+	u64 untag_mask;
 #endif
 
 	struct mutex lock;
diff --git a/arch/x86/include/asm/mmu_context.h b/arch/x86/include/asm/mmu_context.h
index e6eac047c728..05821534aadc 100644
--- a/arch/x86/include/asm/mmu_context.h
+++ b/arch/x86/include/asm/mmu_context.h
@@ -100,6 +100,12 @@  static inline u64 mm_cr3_lam_mask(struct mm_struct *mm)
 static inline void dup_lam(struct mm_struct *oldmm, struct mm_struct *mm)
 {
 	mm->context.lam_cr3_mask = oldmm->context.lam_cr3_mask;
+	mm->context.untag_mask = oldmm->context.untag_mask;
+}
+
+static inline void mm_reset_untag_mask(struct mm_struct *mm)
+{
+	mm->context.untag_mask = -1UL;
 }
 
 #else
@@ -112,6 +118,10 @@  static inline u64 mm_cr3_lam_mask(struct mm_struct *mm)
 static inline void dup_lam(struct mm_struct *oldmm, struct mm_struct *mm)
 {
 }
+
+static inline void mm_reset_untag_mask(struct mm_struct *mm)
+{
+}
 #endif
 
 #define enter_lazy_tlb enter_lazy_tlb
@@ -138,6 +148,7 @@  static inline int init_new_context(struct task_struct *tsk,
 		mm->context.execute_only_pkey = -1;
 	}
 #endif
+	mm_reset_untag_mask(mm);
 	init_new_context_ldt(mm);
 	return 0;
 }
diff --git a/arch/x86/include/asm/uaccess.h b/arch/x86/include/asm/uaccess.h
index 35f222aa66bf..ca754521a4db 100644
--- a/arch/x86/include/asm/uaccess.h
+++ b/arch/x86/include/asm/uaccess.h
@@ -6,6 +6,7 @@ 
  */
 #include <linux/compiler.h>
 #include <linux/kasan-checks.h>
+#include <linux/mm_types.h>
 #include <linux/string.h>
 #include <asm/asm.h>
 #include <asm/page.h>
@@ -20,6 +21,32 @@  static inline bool pagefault_disabled(void);
 # define WARN_ON_IN_IRQ()
 #endif
 
+#ifdef CONFIG_X86_64
+/*
+ * Mask out tag bits from the address.
+ *
+ * Magic with the 'sign' allows untagging a userspace pointer without any branches
+ * while leaving kernel addresses intact.
+ */
+#define untagged_addr(mm, addr)	({					\
+	u64 __addr = (__force u64)(addr);				\
+	s64 sign = (s64)__addr >> 63;					\
+	__addr ^= sign;							\
+	__addr &= (mm)->context.untag_mask;				\
+	__addr ^= sign;							\
+	(__force __typeof__(addr))__addr;				\
+})
+
+#define untagged_ptr(mm, ptr)	({					\
+	u64 __ptrval = (__force u64)(ptr);				\
+	__ptrval = untagged_addr(mm, __ptrval);				\
+	(__force __typeof__(*(ptr)) *)__ptrval;				\
+})
+#else
+#define untagged_addr(mm, addr)	(addr)
+#define untagged_ptr(mm, ptr)	(ptr)
+#endif
+
 /**
  * access_ok - Checks if a user space pointer is valid
  * @addr: User space pointer to start of block to check
@@ -40,7 +67,7 @@  static inline bool pagefault_disabled(void);
 #define access_ok(addr, size)					\
 ({									\
 	WARN_ON_IN_IRQ();						\
-	likely(__access_ok(addr, size));				\
+	likely(__access_ok(untagged_addr(current->mm, addr), size));	\
 })
 
 #include <asm-generic/access_ok.h>
@@ -125,7 +152,13 @@  extern int __get_user_bad(void);
  * Return: zero on success, or -EFAULT on error.
  * On error, the variable @x is set to zero.
  */
-#define get_user(x,ptr) ({ might_fault(); do_get_user_call(get_user,x,ptr); })
+#define get_user(x,ptr)							\
+({									\
+	__typeof__(*(ptr)) __user *__ptr_clean;				\
+	__ptr_clean = untagged_ptr(current->mm, ptr);			\
+	might_fault();							\
+	do_get_user_call(get_user,x,__ptr_clean);			\
+})
 
 /**
  * __get_user - Get a simple variable from user space, with less checking.
@@ -222,7 +255,12 @@  extern void __put_user_nocheck_8(void);
  *
  * Return: zero on success, or -EFAULT on error.
  */
-#define put_user(x, ptr) ({ might_fault(); do_put_user_call(put_user,x,ptr); })
+#define put_user(x, ptr) ({						\
+	__typeof__(*(ptr)) __user *__ptr_clean;				\
+	__ptr_clean = untagged_ptr(current->mm, ptr);			\
+	might_fault();							\
+	do_put_user_call(put_user,x,__ptr_clean);			\
+})
 
 /**
  * __put_user - Write a simple value into user space, with less checking.
diff --git a/arch/x86/kernel/process.c b/arch/x86/kernel/process.c
index 9b2772b7e1f3..18b2bfdf7b9b 100644
--- a/arch/x86/kernel/process.c
+++ b/arch/x86/kernel/process.c
@@ -47,6 +47,7 @@ 
 #include <asm/frame.h>
 #include <asm/unwind.h>
 #include <asm/tdx.h>
+#include <asm/mmu_context.h>
 
 #include "process.h"
 
@@ -367,6 +368,8 @@  void arch_setup_new_exec(void)
 		task_clear_spec_ssb_noexec(current);
 		speculation_ctrl_update(read_thread_flags());
 	}
+
+	mm_reset_untag_mask(current->mm);
 }
 
 #ifdef CONFIG_X86_IOPL_IOPERM