[V3,04/10] x86/pks: Preserve the PKRS MSR on context switch

Message ID 20201106232908.364581-5-ira.weiny@intel.com (mailing list archive)
State New
Series PKS: Add Protection Keys Supervisor (PKS) support V3

Commit Message

Ira Weiny Nov. 6, 2020, 11:29 p.m. UTC
From: Ira Weiny <ira.weiny@intel.com>

The PKRS MSR is defined as a per-logical-processor register.  This
isolates memory access by logical CPU.  Unfortunately, the MSR is not
managed by XSAVE.  Therefore, tasks must save/restore the MSR value on
context switch.

Define a saved PKRS value in the task struct, as well as a cached
per-logical-processor MSR value which mirrors the MSR value of the
current CPU.  Initialize all tasks with the default MSR value.  Then, on
schedule in, check the saved task MSR vs the per-cpu value.  If they
differ, proceed to write the MSR.  If not, avoid the overhead of the
MSR write and continue.

Follow on patches will update the saved PKRS as well as the MSR if
needed.
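
As a rough illustration of the scheme described above, the following is
a userspace C model, not the kernel code from this patch.  The names
fake_wrmsr(), msr_writes and sched_in_pkrs() are hypothetical and exist
only for this sketch; the patch itself uses a per-CPU variable and
wrmsrl().

```c
#include <stdint.h>

/* Hypothetical stand-ins for the hardware MSR and the per-CPU cache. */
static uint32_t fake_msr_pkrs;   /* models MSR_IA32_PKRS on one CPU */
static uint32_t pkrs_cache;      /* mirrors the last value written */
static unsigned int msr_writes;  /* counts simulated MSR writes */

/* Simulated MSR write; the real operation is comparatively expensive. */
static void fake_wrmsr(uint32_t val)
{
	fake_msr_pkrs = val;
	msr_writes++;
}

/* On schedule in, write the MSR only if the task's saved value differs. */
static void sched_in_pkrs(uint32_t task_saved_pkrs)
{
	if (pkrs_cache != task_saved_pkrs) {
		pkrs_cache = task_saved_pkrs;
		fake_wrmsr(task_saved_pkrs);
	}
}
```

Scheduling in two tasks with the same saved value back to back costs a
single simulated write in this model; only a change of value triggers
another one.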

Finally, note that the underlying WRMSR(MSR_IA32_PKRS) is not
serializing but still maintains ordering properties similar to
WRPKRU.  The current SDM section on PKRS needs updating, but the
ordering should be the same as that of WRPKRU.  To quote from the
WRPKRU text:

	WRPKRU will never execute transiently. Memory accesses affected
	by PKRU register will not execute (even transiently) until all
	prior executions of WRPKRU have completed execution and updated
	the PKRU register.

Co-developed-by: Fenghua Yu <fenghua.yu@intel.com>
Signed-off-by: Fenghua Yu <fenghua.yu@intel.com>
Co-developed-by: Peter Zijlstra <peterz@infradead.org>
Signed-off-by: Peter Zijlstra <peterz@infradead.org>
Signed-off-by: Ira Weiny <ira.weiny@intel.com>

---
Changes from V2
	Adjust for PKS enable being final patch.

Changes from V1
	Rebase to latest tip/master
		Resolve conflicts with INIT_THREAD changes

Changes since RFC V3
	Per Dave Hansen
		Update commit message
		move saved_pkrs to be in a nicer place
	Per Peter Zijlstra
		Add Comment from Peter
		Clean up white space
		Update authorship
---
 arch/x86/include/asm/msr-index.h    |  1 +
 arch/x86/include/asm/pkeys_common.h | 20 +++++++++++++++++++
 arch/x86/include/asm/processor.h    | 18 ++++++++++++++++-
 arch/x86/kernel/process.c           | 26 ++++++++++++++++++++++++
 arch/x86/mm/pkeys.c                 | 31 +++++++++++++++++++++++++++++
 5 files changed, 95 insertions(+), 1 deletion(-)

Comments

Thomas Gleixner Dec. 17, 2020, 2:50 p.m. UTC | #1
On Fri, Nov 06 2020 at 15:29, ira weiny wrote:
> --- a/arch/x86/kernel/process.c
> +++ b/arch/x86/kernel/process.c
> @@ -43,6 +43,7 @@
>  #include <asm/io_bitmap.h>
>  #include <asm/proto.h>
>  #include <asm/frame.h>
> +#include <asm/pkeys_common.h>
>  
>  #include "process.h"
>  
> @@ -187,6 +188,27 @@ int copy_thread(unsigned long clone_flags, unsigned long sp, unsigned long arg,
>  	return ret;
>  }
>  
> +#ifdef CONFIG_ARCH_HAS_SUPERVISOR_PKEYS
> +DECLARE_PER_CPU(u32, pkrs_cache);
> +static inline void pks_init_task(struct task_struct *tsk)

First of all, I asked several times now not to glue stuff onto a
function without a newline in between. It's unreadable.

But what's worse is that the declaration of pkrs_cache which is global
is in a C file and not in a header. And pkrs_cache is not even used in
this file. So what?

> +{
> +	/* New tasks get the most restrictive PKRS value */
> +	tsk->thread.saved_pkrs = INIT_PKRS_VALUE;
> +}
> +static inline void pks_sched_in(void)

Newline between functions. It's fine for stubs, but not for a real implementation.

> diff --git a/arch/x86/mm/pkeys.c b/arch/x86/mm/pkeys.c
> index d1dfe743e79f..76a62419c446 100644
> --- a/arch/x86/mm/pkeys.c
> +++ b/arch/x86/mm/pkeys.c
> @@ -231,3 +231,34 @@ u32 update_pkey_val(u32 pk_reg, int pkey, unsigned int flags)
>  
>  	return pk_reg;
>  }
> +
> +DEFINE_PER_CPU(u32, pkrs_cache);

Again, why is this global?

> +void write_pkrs(u32 new_pkrs)
> +{
> +	u32 *pkrs;
> +
> +	if (!static_cpu_has(X86_FEATURE_PKS))
> +		return;
> +
> +	pkrs = get_cpu_ptr(&pkrs_cache);

So this is called from various places including schedule and also from
the low level entry/exit code. Why do we need to have an extra
preempt_disable/enable() there via get/put_cpu_ptr()?

Just because performance in those code paths does not matter?

> +	if (*pkrs != new_pkrs) {
> +		*pkrs = new_pkrs;
> +		wrmsrl(MSR_IA32_PKRS, new_pkrs);
> +	}
> +	put_cpu_ptr(pkrs);

Now back to the context switch:

> @@ -644,6 +668,8 @@ void __switch_to_xtra(struct task_struct *prev_p, struct task_struct *next_p)
>
>	 if ((tifp ^ tifn) & _TIF_SLD)
>		 switch_to_sld(tifn);
> +
> +	pks_sched_in();
>  }

How is this supposed to work? 

switch_to() {
   ....
   switch_to_extra() {
      ....
      if (unlikely(next_tif & _TIF_WORK_CTXSW_NEXT ||
	           prev_tif & _TIF_WORK_CTXSW_PREV))
	   __switch_to_xtra(prev, next);

I.e. __switch_to_xtra() is only invoked when the above condition is
true, which is not guaranteed at all.
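
To make the problem concrete, here is a small userspace model of the
control flow outlined above.  It is not kernel code: TIF_WORK_CTXSW is a
hypothetical single flag standing in for the _TIF_WORK_CTXSW_NEXT and
_TIF_WORK_CTXSW_PREV masks, and the functions only count calls.

```c
/* Hypothetical flag standing in for the _TIF_WORK_CTXSW_* masks. */
#define TIF_WORK_CTXSW 0x1u

static unsigned int pks_sched_in_calls;

/* Stand-in for the hook the patch adds; it only counts invocations. */
static void pks_sched_in(void)
{
	pks_sched_in_calls++;
}

/* Models __switch_to_xtra(): the slow path that contains the hook. */
static void model_switch_to_xtra(void)
{
	pks_sched_in();
}

/* Models switch_to_extra(): the slow path runs only if a flag is set. */
static void model_switch_to_extra(unsigned int prev_tif, unsigned int next_tif)
{
	if ((next_tif & TIF_WORK_CTXSW) || (prev_tif & TIF_WORK_CTXSW))
		model_switch_to_xtra();
}
```

In this model a switch between two tasks with no work flags set never
reaches pks_sched_in(), so the saved value of the incoming task would not
be written.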

While I have to admit that I dropped the ball on the update for the
entry patch, I'm not too sorry about it anymore when looking at this.

Are you still sure that this is ready for merging?

Thanks,

        tglx
Dave Hansen Dec. 17, 2020, 8:41 p.m. UTC | #2
On 11/6/20 3:29 PM, ira.weiny@intel.com wrote:
>  void disable_TSC(void)
> @@ -644,6 +668,8 @@ void __switch_to_xtra(struct task_struct *prev_p, struct task_struct *next_p)
>  
>  	if ((tifp ^ tifn) & _TIF_SLD)
>  		switch_to_sld(tifn);
> +
> +	pks_sched_in();
>  }

Does the selftest for this ever actually schedule()?

I see it talking about context switching, but I don't immediately see
how it would.
Thomas Gleixner Dec. 17, 2020, 10:43 p.m. UTC | #3
On Thu, Dec 17 2020 at 15:50, Thomas Gleixner wrote:
> On Fri, Nov 06 2020 at 15:29, ira weiny wrote:
>
>> +void write_pkrs(u32 new_pkrs)
>> +{
>> +	u32 *pkrs;
>> +
>> +	if (!static_cpu_has(X86_FEATURE_PKS))
>> +		return;
>> +
>> +	pkrs = get_cpu_ptr(&pkrs_cache);
>
> So this is called from various places including schedule and also from
> the low level entry/exit code. Why do we need to have an extra
> preempt_disable/enable() there via get/put_cpu_ptr()?
>
> Just because performance in those code paths does not matter?
>
>> +	if (*pkrs != new_pkrs) {
>> +		*pkrs = new_pkrs;
>> +		wrmsrl(MSR_IA32_PKRS, new_pkrs);
>> +	}
>> +	put_cpu_ptr(pkrs);

Which made me look at the other branch of your git repo just because I
wanted to know about the 'other' storage requirements and I found this
gem:

> update_global_pkrs()
> ...
>	/*
>	 * If we are preventing access from the old value.  Force the
>	 * update on all running threads.
>	 */
>	if (((old_val == 0) && protection) ||
>	    ((old_val & PKR_WD_BIT) && (protection & PKEY_DISABLE_ACCESS))) {
>		int cpu;
>
>		for_each_online_cpu(cpu) {
>			u32 *ptr = per_cpu_ptr(&pkrs_cache, cpu);
>
>			*ptr = update_pkey_val(*ptr, pkey, protection);
>			wrmsrl_on_cpu(cpu, MSR_IA32_PKRS, *ptr);
>			put_cpu_ptr(ptr);

1) per_cpu_ptr() -> put_cpu_ptr() is broken as per_cpu_ptr() is not
   disabling preemption while put_cpu_ptr() enables it, which wrecks
   the preemption count. 

   How was that ever tested at all with any debug option enabled?

   Answer: Not at all

2) How is that sequence:

	ptr = per_cpu_ptr(&pkrs_cache, cpu);
	*ptr = update_pkey_val(*ptr, pkey, protection);
	wrmsrl_on_cpu(cpu, MSR_IA32_PKRS, *ptr);

   supposed to be correct vs. a concurrent modification of the
   pkrs_cache of the remote CPU?

   Answer: Not at all
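
The imbalance in 1) can be shown with a small userspace model.  This is
only a sketch: the model_* helpers are hypothetical stand-ins that mimic
the preemption side effects of get_cpu_ptr(), put_cpu_ptr() and
per_cpu_ptr(), and preempt_count is a plain integer here.

```c
static int preempt_count;      /* models the preemption counter */
static unsigned int slot[4];   /* models one per-CPU variable on 4 CPUs */

/* Like get_cpu_ptr(): disables preemption, returns this CPU's slot. */
static unsigned int *model_get_cpu_ptr(int this_cpu)
{
	preempt_count++;
	return &slot[this_cpu];
}

/* Like put_cpu_ptr(): re-enables preemption. */
static void model_put_cpu_ptr(unsigned int *ptr)
{
	(void)ptr;
	preempt_count--;
}

/* Like per_cpu_ptr(): address calculation only, no preemption change. */
static unsigned int *model_per_cpu_ptr(int cpu)
{
	return &slot[cpu];
}

/* The pattern from the quoted loop: per_cpu_ptr() paired with put_cpu_ptr(). */
static int broken_update_all(void)
{
	for (int cpu = 0; cpu < 4; cpu++) {
		unsigned int *ptr = model_per_cpu_ptr(cpu);

		*ptr |= 1u;
		model_put_cpu_ptr(ptr);   /* unbalanced: nothing disabled it */
	}
	return preempt_count;
}
```

A matched get/put pair leaves the counter where it started, while the
mismatched loop drives it negative by one per CPU visited.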

Also doing a wrmsrl_on_cpu() on _each_ online CPU is insane at best.

  A smp function call on a remote CPU takes ~3-5us when the remote CPU
  is not idle and can immediately respond. If the remote CPU is deep in
  idle it can take up to 100us depending on C-State it is in.

  Even if the remote CPU is not idle and just has interrupts
  disabled for a few dozen microseconds, this adds up.

  So on a 256 CPU system depending on the state of the remote CPUs this
  stalls the CPU doing the update for anything between 1 and 25ms worst
  case.

  Of course that also violates _all_ CPU isolation mechanisms.

  What for?

  Just for the theoretical chance that _all_ remote CPUs have
  seen that global permission and have it still active?

  You're not serious about that, right?
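
  The 1 to 25ms range above follows from multiplying the per-call
  latency by the number of CPUs.  A minimal sketch of that arithmetic,
  assuming the calls are issued one after another and using a
  hypothetical helper name:

```c
/*
 * Total stall in microseconds when n_cpus remote calls are issued
 * sequentially and each one takes per_call_us microseconds.
 */
static unsigned long total_stall_us(unsigned int n_cpus, unsigned int per_call_us)
{
	return (unsigned long)n_cpus * per_call_us;
}
```

  With 256 CPUs, 3us per call gives 768us, 5us gives 1280us, and 100us
  gives 25600us, i.e. roughly 1ms in the good case and about 25ms in the
  worst case.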

The only use case for this in your tree is: kmap() and the possible
usage of that mapping outside of the thread context which sets it up.

The only hint for doing this at all is:

    Some users, such as kmap(), sometimes requires PKS to be global.

'sometimes requires' is really _not_ a technical explanation.

Where is the explanation why kmap() usage 'sometimes' requires this
global trainwreck in the first place and where is the analysis why this
can't be solved differently?

Detailed use case analysis please.

Thanks,

        tglx
Ira Weiny Dec. 18, 2020, 4:05 a.m. UTC | #4
On Thu, Dec 17, 2020 at 03:50:55PM +0100, Thomas Gleixner wrote:
> On Fri, Nov 06 2020 at 15:29, ira weiny wrote:
> > --- a/arch/x86/kernel/process.c
> > +++ b/arch/x86/kernel/process.c
> > @@ -43,6 +43,7 @@
> >  #include <asm/io_bitmap.h>
> >  #include <asm/proto.h>
> >  #include <asm/frame.h>
> > +#include <asm/pkeys_common.h>
> >  
> >  #include "process.h"
> >  
> > @@ -187,6 +188,27 @@ int copy_thread(unsigned long clone_flags, unsigned long sp, unsigned long arg,
> >  	return ret;
> >  }
> >  
> > +#ifdef CONFIG_ARCH_HAS_SUPERVISOR_PKEYS
> > +DECLARE_PER_CPU(u32, pkrs_cache);
> > +static inline void pks_init_task(struct task_struct *tsk)
> 
> First of all. I asked several times now not to glue stuff onto a
> function without a newline inbetween. It's unreadable.

Fixed.

> 
> But what's worse is that the declaration of pkrs_cache which is global
> is in a C file and not in a header. And pkrs_cache is not even used in
> this file. So what?

OK, this was just a complete rebase/refactor mess up on my part.  The
global'ness is not required until we need a global update of the pkrs which was
not part of this series.

I've removed it from this patch.  And cleaned it up in patch 6/10 as well.  And
cleaned it up in the global pkrs patch which you found in my git tree.

> 
> > +{
> > +	/* New tasks get the most restrictive PKRS value */
> > +	tsk->thread.saved_pkrs = INIT_PKRS_VALUE;
> > +}
> > +static inline void pks_sched_in(void)
> 
> Newline between functions. It's fine for stubs, but not for a real implementation.

Again my apologies.

Fixed.

> 
> > diff --git a/arch/x86/mm/pkeys.c b/arch/x86/mm/pkeys.c
> > index d1dfe743e79f..76a62419c446 100644
> > --- a/arch/x86/mm/pkeys.c
> > +++ b/arch/x86/mm/pkeys.c
> > @@ -231,3 +231,34 @@ u32 update_pkey_val(u32 pk_reg, int pkey, unsigned int flags)
> >  
> >  	return pk_reg;
> >  }
> > +
> > +DEFINE_PER_CPU(u32, pkrs_cache);
> 
> Again, why is this global?

In this patch it does not need to be.  I've changed it to static.

> 
> > +void write_pkrs(u32 new_pkrs)
> > +{
> > +	u32 *pkrs;
> > +
> > +	if (!static_cpu_has(X86_FEATURE_PKS))
> > +		return;
> > +
> > +	pkrs = get_cpu_ptr(&pkrs_cache);
> 
> So this is called from various places including schedule and also from
> the low level entry/exit code. Why do we need to have an extra
> preempt_disable/enable() there via get/put_cpu_ptr()?
> 
> Just because performance in those code paths does not matter?

Honestly I don't recall the full history at this point.  The
preempt_disable/enable() is required when this is called from
pks_update_protection()  AKA when a user is trying to update the protections of
their key.  What I do remember is that this was originally not preempt safe and we
had a comment to that effect in the early patches.[1]

Somewhere along the line the preempt discussion led us to make write_pkrs()
'self contained' with the preemption protection here.  I just did not think
about any performance issues.  It is safe to call preempt_disable() from a
preempt disabled region, correct?  I seem to recall asking that and the answer
was 'yes'.

I will audit the calls again and adjust the preemption disable as needed.

[1] https://lore.kernel.org/lkml/20200717072056.73134-5-ira.weiny@intel.com/#t

> 
> > +	if (*pkrs != new_pkrs) {
> > +		*pkrs = new_pkrs;
> > +		wrmsrl(MSR_IA32_PKRS, new_pkrs);
> > +	}
> > +	put_cpu_ptr(pkrs);
> 
> Now back to the context switch:
> 
> > @@ -644,6 +668,8 @@ void __switch_to_xtra(struct task_struct *prev_p, struct task_struct *next_p)
> >
> >	 if ((tifp ^ tifn) & _TIF_SLD)
> >		 switch_to_sld(tifn);
> > +
> > +	pks_sched_in();
> >  }
> 
> How is this supposed to work? 
> 
> switch_to() {
>    ....
>    switch_to_extra() {
>       ....
>       if (unlikely(next_tif & _TIF_WORK_CTXSW_NEXT ||
> 	           prev_tif & _TIF_WORK_CTXSW_PREV))
> 	   __switch_to_xtra(prev, next);
> 
> I.e. __switch_to_xtra() is only invoked when the above condition is
> true, which is not guaranteed at all.

I did not know that.  I completely misunderstood what __switch_to_xtra()
meant.  I thought it was arch specific 'extra' stuff so it seemed reasonable to
me.

Also, our test seemed to work.  I'm still investigating what may be wrong.

> 
> While I have to admit that I dropped the ball on the update for the
> entry patch, I'm not too sorry about it anymore when looking at this.
> 
> Are you still sure that this is ready for merging?

Nope...

Thanks for the review,
Ira

> 
> Thanks,
> 
>         tglx
Ira Weiny Dec. 18, 2020, 4:10 a.m. UTC | #5
On Thu, Dec 17, 2020 at 12:41:50PM -0800, Dave Hansen wrote:
> On 11/6/20 3:29 PM, ira.weiny@intel.com wrote:
> >  void disable_TSC(void)
> > @@ -644,6 +668,8 @@ void __switch_to_xtra(struct task_struct *prev_p, struct task_struct *next_p)
> >  
> >  	if ((tifp ^ tifn) & _TIF_SLD)
> >  		switch_to_sld(tifn);
> > +
> > +	pks_sched_in();
> >  }
> 
> Does the selftest for this ever actually schedule()?

At this point I'm not sure.  This code has been in since the beginning.  So it's
seen a lot of soak time.

> 
> I see it talking about context switching, but I don't immediately see
> how it would.

We were trying to force parent and child to run on the same CPU.  I suspect
something is wrong in the timing of that test.

Ira
Thomas Gleixner Dec. 18, 2020, 1:57 p.m. UTC | #6
On Thu, Dec 17 2020 at 23:43, Thomas Gleixner wrote:
> The only use case for this in your tree is: kmap() and the possible
> usage of that mapping outside of the thread context which sets it up.
>
> The only hint for doing this at all is:
>
>     Some users, such as kmap(), sometimes requires PKS to be global.
>
> 'sometime requires' is really _not_ a technical explanation.
>
> Where is the explanation why kmap() usage 'sometimes' requires this
> global trainwreck in the first place and where is the analysis why this
> can't be solved differently?
>
> Detailed use case analysis please.

A lengthy conversation with Dan and Dave over IRC confirmed what I was
suspecting.

The approach of this whole PKS thing is to make _all_ existing code
magically "work". That means aside of the obvious thread local mappings,
the kmap() part is needed to solve the problem of async handling where
the mapping is handed to some other context which then uses it and
notifies the context which created the mapping when done. That's the
principle which was used to make highmem work long time ago.

IMO that was a mistake back then. The right thing would have been to
change the code so that it does not rely on a temporary mapping created
by the initiator. Instead let the initiator hand the page over to the
other context which then creates a temporary mapping for fiddling with
it. Water under the bridge...

Glueing PKS on to that kmap() thing is horrible and global PKS is pretty
much the opposite of what PKS wants to achieve. It's disabling
protection systemwide for an unspecified amount of time and for all
contexts.

So instead of trying to make global PKS "work" we really should go and
take a smarter approach.

  1) Many kmap() use cases are strictly thread local and the mapped
     address is never handed to some other context, which means this can
     be replaced with kmap_local() now, which preserves the mapping
     across preemption. PKS just works nicely on top of that.

  2) Modify kmap() so that it marks the to be mapped page as 'globally
     unprotected' instead of doing this global unprotect PKS dance.
     kunmap() undoes that. That obviously needs some thought
     vs. refcounting if there are concurrent users, but that's a
     solvable problem either as part of struct page itself or
     stored in some global hash.

  3) Have PKS modes:

     - STRICT:   No pardon
     
     - RELAXED:  Warn and unprotect temporary for the current context

     - SILENT:	 Like RELAXED, but w/o warning to make sysadmins happy.
                 Default should be RELAXED.

     - OFF:      Disable the whole PKS thing


  4) Have a smart #PF mechanism which does:

     if (error_code & X86_PF_PK) {
         page = virt_to_page(address);

         if (!page || !page_is_globaly_unprotected(page))
                 goto die;

         if (pks_mode == PKS_MODE_STRICT)
         	 goto die;

         WARN_ONCE(pks_mode == PKS_MODE_RELAXED, "Useful info ...");

         temporary_unprotect(page, regs);
         return;
     }

     temporary_unprotect(page, regs)
     {
        key = page_to_key(page);

	/* Return from #PF will establish this for the faulting context */
        extended_state(regs)->pks &= ~PKS_MASK(key);
     }

     This temporary unprotect is undone when the context is left, so
     depending on the context (thread, interrupt, softirq) the
     unprotected section might be way wider than actually needed, but
     that's still orders of magnitudes better than having this fully
     unrestricted global PKS mode which is completely scopeless.

     The above is at least restricted to the pages which are in use for
     a particular operation. Stray pointers during that time are
     obviously not caught, but that's not any different from that
     proposed global thingy.

     The warning allows to find the non-obvious places so they can be
     analyzed and worked on.

  5) The DAX case which you made "work" with dev_access_enable() and
     dev_access_disable(), i.e. with yet another lazy approach of
     avoiding to change a handful of usage sites.

     The use cases are strictly context local which means the global
     magic is not used at all. Why does it exist in the first place?

     Aside of that this global thing would never work at all because the
     refcounting is per thread and not global.

     So that DAX use case is just a matter of:

        grant/revoke_access(DEV_PKS_KEY, READ/WRITE)

     which is effective for the current execution context and really
     wants to be a distinct READ/WRITE protection and not the magic
     global thing which just has on/off. All usage sites know whether
     they want to read or write.
   
     That leaves the question about the refcount. AFAICT, nothing nests
     in that use case for a given execution context. I'm surely missing
     something subtle here.

     Hmm?
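
The mode handling from 3) and 4) above can be modelled in plain C.  This
is only a sketch of the decision logic in the pseudo code, with
hypothetical enum and function names; it does not touch page tables or
registers.

```c
/* Modes from 3); the enumerator names are assumptions for this sketch. */
enum pks_mode {
	PKS_MODE_STRICT,
	PKS_MODE_RELAXED,
	PKS_MODE_SILENT,
	PKS_MODE_OFF,
};

enum fault_action {
	FAULT_DIE,              /* treat as a fatal fault */
	FAULT_UNPROTECT_WARN,   /* warn, then unprotect for this context */
	FAULT_UNPROTECT_QUIET,  /* unprotect for this context, no warning */
};

/*
 * Mirrors the order of checks in the pseudo code: a page which is not
 * marked globally unprotected always dies, STRICT always dies, and only
 * RELAXED warns.  PKS_MODE_OFF is not expected to reach a PK fault at
 * all; it simply falls through like SILENT here.
 */
static enum fault_action pks_fault_action(enum pks_mode mode,
					  int page_globally_unprotected)
{
	if (!page_globally_unprotected)
		return FAULT_DIE;

	if (mode == PKS_MODE_STRICT)
		return FAULT_DIE;

	if (mode == PKS_MODE_RELAXED)
		return FAULT_UNPROTECT_WARN;

	return FAULT_UNPROTECT_QUIET;
}
```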

Thanks,

        tglx
Dave Hansen Dec. 18, 2020, 3:33 p.m. UTC | #7
On 12/17/20 8:10 PM, Ira Weiny wrote:
> On Thu, Dec 17, 2020 at 12:41:50PM -0800, Dave Hansen wrote:
>> On 11/6/20 3:29 PM, ira.weiny@intel.com wrote:
>>>  void disable_TSC(void)
>>> @@ -644,6 +668,8 @@ void __switch_to_xtra(struct task_struct *prev_p, struct task_struct *next_p)
>>>  
>>>  	if ((tifp ^ tifn) & _TIF_SLD)
>>>  		switch_to_sld(tifn);
>>> +
>>> +	pks_sched_in();
>>>  }
>>
>> Does the selftest for this ever actually schedule()?
> 
> At this point I'm not sure.  This code has been in since the beginning.  So its
> seen a lot of soak time.

Think about it another way.  Let's say this didn't get called on the
first context switch away from the PKS-using task.  Would anyone notice?
How likely is this to happen?

The function tracers or kprobes tend to be a great tool for this, at
least for testing whether the code path you expect to hit is getting hit.
Dan Williams Dec. 18, 2020, 7:20 p.m. UTC | #8
On Fri, Dec 18, 2020 at 5:58 AM Thomas Gleixner <tglx@linutronix.de> wrote:
[..]
>   5) The DAX case which you made "work" with dev_access_enable() and
>      dev_access_disable(), i.e. with yet another lazy approach of
>      avoiding to change a handful of usage sites.
>
>      The use cases are strictly context local which means the global
>      magic is not used at all. Why does it exist in the first place?
>
>      Aside of that this global thing would never work at all because the
>      refcounting is per thread and not global.
>
>      So that DAX use case is just a matter of:
>
>         grant/revoke_access(DEV_PKS_KEY, READ/WRITE)
>
>      which is effective for the current execution context and really
>      wants to be a distinct READ/WRITE protection and not the magic
>      global thing which just has on/off. All usage sites know whether
>      they want to read or write.

I was tracking and nodding until this point. Yes, kill the global /
kmap() support, but if grant/revoke_access is not integrated behind
kmap_{local,atomic}() then it's not a "handful" of sites that need to
be instrumented it's 100s. Are you suggesting that "relaxed" mode
enforcement is a way to distribute the work of teaching driver writers
that they need to incorporate explicit grant/revoke-read/write in
addition to kmap? The entire reason PTE_DEVMAP exists was to allow
get_user_pages() for PMEM and not require every downstream-GUP code
path to specifically consider whether it was talking to PMEM or RAM
pages, and certainly not whether they were reading or writing to it.
Ira Weiny Dec. 18, 2020, 7:42 p.m. UTC | #9
On Fri, Dec 18, 2020 at 02:57:51PM +0100, Thomas Gleixner wrote:
> On Thu, Dec 17 2020 at 23:43, Thomas Gleixner wrote:
> > The only use case for this in your tree is: kmap() and the possible
> > usage of that mapping outside of the thread context which sets it up.
> >
> > The only hint for doing this at all is:
> >
> >     Some users, such as kmap(), sometimes requires PKS to be global.
> >
> > 'sometime requires' is really _not_ a technical explanation.
> >
> > Where is the explanation why kmap() usage 'sometimes' requires this
> > global trainwreck in the first place and where is the analysis why this
> > can't be solved differently?
> >
> > Detailed use case analysis please.
> 
> A lengthy conversation with Dan and Dave over IRC confirmed what I was
> suspecting.
> 
> The approach of this whole PKS thing is to make _all_ existing code
> magically "work". That means aside of the obvious thread local mappings,
> the kmap() part is needed to solve the problem of async handling where
> the mapping is handed to some other context which then uses it and
> notifies the context which created the mapping when done. That's the
> principle which was used to make highmem work long time ago.
> 
> IMO that was a mistake back then. The right thing would have been to
> change the code so that it does not rely on a temporary mapping created
> by the initiator. Instead let the initiator hand the page over to the
> other context which then creates a temporary mapping for fiddling with
> it. Water under the bridge...

But maybe not.  We are getting rid of a lot of the kmaps and once the bulk are
gone perhaps we can change this and remove kmap completely?

> 
> Glueing PKS on to that kmap() thing is horrible and global PKS is pretty
> much the opposite of what PKS wants to achieve. It's disabling
> protection systemwide for an unspecified amount of time and for all
> contexts.

I agree.  This is why I have been working on converting kmap() call sites to
kmap_local_page().[1]

> 
> So instead of trying to make global PKS "work" we really should go and
> take a smarter approach.
> 
>   1) Many kmap() use cases are strictly thread local and the mapped
>      address is never handed to some other context, which means this can
>      be replaced with kmap_local() now, which preserves the mapping
>      accross preemption. PKS just works nicely on top of that.

Yes hence the massive kmap->kmap_thread patch set which is now becoming
kmap_local_page().[2]

> 
>   2) Modify kmap() so that it marks the to be mapped page as 'globaly
>      unprotected' instead of doing this global unprotect PKS dance.
>      kunmap() undoes that. That obviously needs some thought
>      vs. refcounting if there are concurrent users, but that's a
>      solvable problem either as part of struct page itself or
>      stored in some global hash.

How would this globally unprotected flag work?  I suppose if kmap created a new
PTE we could make that PTE non-PKS protected then we don't have to fiddle with
the register...  I think I like that idea.

> 
>   3) Have PKS modes:
> 
>      - STRICT:   No pardon
>      
>      - RELAXED:  Warn and unprotect temporary for the current context
> 
>      - SILENT:	 Like RELAXED, but w/o warning to make sysadmins happy.
>                  Default should be RELAXED.
> 
>      - OFF:      Disable the whole PKS thing

I'm not really sure how this solves the global problem but it is probably worth
having in general.

> 
> 
>   4) Have a smart #PF mechanism which does:
> 
>      if (error_code & X86_PF_PK) {
>          page = virt_to_page(address);
> 
>          if (!page || !page_is_globaly_unprotected(page))
>                  goto die;
> 
>          if (pks_mode == PKS_MODE_STRICT)
>          	 goto die;
> 
>          WARN_ONCE(pks_mode == PKS_MODE_RELAXED, "Useful info ...");
> 
>          temporary_unprotect(page, regs);
>          return;
>      }

I feel like this is very similar to what I had in the global patch you found in
my git tree with the exception of the RELAXED mode.  I simply had globally
unprotected or die.

global_pkey_is_enabled() handles the page_is_globaly_unprotected() and
temporary_unprotect().[3]

Anyway, I'm sorry (but not sorry) that you found it.  I've been trying to get
0-day and other testing on it and my public tree was the easiest way to do
that.  Anyway...

The patch as a whole needs work.  You are 100% correct that if a mapping is
handed to another context it is going to suck performance wise.  It has had
some internal review but not much.

Regardless I think unprotecting a global context is the easy part.  The code
you had a problem with (and I see is fully broken) was the restriction of
access.  A failure to update in that direction would only result in a wider
window of access.  I contemplated not doing a global update at all and just
leave the access open until the next context switch.  But the code as it stands
tries to force an update for a couple of reasons:

1) kmap_local_page() removes most of the need for global pks.  So I was
   thinking that global PKS could be a slow path.

2) kmap()s that are handed to other contexts are likely to be 'long term'
   and should not need to be updated 'too' often.  I will admit that I don't
   know how often 'too often' is.

But IMO these questions are best left to after the kmaps are converted.  Thus
this patch set was just basic support.  Other use cases beyond pmem such as
trusted keys or secret mem don't need a global pks feature and could build on
the patch set submitted.  I was trying to break the problem down.

> 
>      temporary_unprotect(page, regs)
>      {
>         key = page_to_key(page);
> 
> 	/* Return from #PF will establish this for the faulting context */
>         extended_state(regs)->pks &= ~PKS_MASK(key);
>      }
> 
>      This temporary unprotect is undone when the context is left, so
>      depending on the context (thread, interrupt, softirq) the
>      unprotected section might be way wider than actually needed, but
>      that's still orders of magnitudes better than having this fully
>      unrestricted global PKS mode which is completely scopeless.

I'm not sure I follow you.  How would we know when the context is left?

> 
>      The above is at least restricted to the pages which are in use for
>      a particular operation. Stray pointers during that time are
>      obviously not caught, but that's not any different from that
>      proposed global thingy.
> 
>      The warning allows to find the non-obvious places so they can be
>      analyzed and worked on.

I could add the warning for sure.

> 
>   5) The DAX case which you made "work" with dev_access_enable() and
>      dev_access_disable(), i.e. with yet another lazy approach of
>      avoiding to change a handful of usage sites.
> 
>      The use cases are strictly context local which means the global
>      magic is not used at all. Why does it exist in the first place?

I'm not following.  What is 'it'?

> 
>      Aside of that this global thing would never work at all because the
>      refcounting is per thread and not global.
> 
>      So that DAX use case is just a matter of:
> 
>         grant/revoke_access(DEV_PKS_KEY, READ/WRITE)
> 
>      which is effective for the current execution context and really
>      wants to be a distinct READ/WRITE protection and not the magic
>      global thing which just has on/off. All usage sites know whether
>      they want to read or write.
>    
>      That leaves the question about the refcount. AFAICT, nothing nests
>      in that use case for a given execution context. I'm surely missing
>      something subtle here.

The refcount is needed for non-global pks as well as global.  I've not resolved
if anything needs to be done with the refcount on the global update since the
following is legal.

kmap()
kmap_local_page()
kunmap()
kunmap_local()

Which would be a problem.  But I don't think it is ever actually done.

Another problem would be if the kmap and kunmap happened in different
contexts...  :-/  I don't think that is done either but I don't know for
certain.

Frankly, my main focus before any of this global support has been to get rid of
as many kmaps as possible.[1]  Once that is done I think more of these
questions can be answered better.

Ira

[1] https://lore.kernel.org/lkml/20201210171834.2472353-1-ira.weiny@intel.com/
[2] https://lore.kernel.org/lkml/20201009195033.3208459-1-ira.weiny@intel.com/
[3] Latest untested patch pushed for reference here because I can't find
    exactly the branch you found.
    https://github.com/weiny2/linux-kernel/commit/37439e91e141be58c13ccc4462f7782311680636

> 
>      Hmm?
> 
> Thanks,
> 
>         tglx
>
Dave Hansen Dec. 18, 2020, 8:10 p.m. UTC | #10
On 12/18/20 11:42 AM, Ira Weiny wrote:
> Another problem would be if the kmap and kunmap happened in different
> contexts...  :-/  I don't think that is done either but I don't know for
> certain.

It would be really nice to put together some surveillance patches to
help become more certain about these things.  Even a per-task counter
would be better than nothing.

On kmap:
	current->kmaps++;
On kunmap:
	current->kmaps--;
	WARN_ON(current->kmaps < 0);
On exit:
	WARN_ON(current->kmaps);

That would at least find imbalances.  You could take it even further by
having a little array, say:

struct one_kmap {
	struct page *page;
	depot_stack_handle_t handle;
};

Then:

	 struct task_struct {
		...
	+	struct one_kmap kmaps[10];
	 };

On kmap() you make a new entry in current->kmaps[], and on kunmap() you
try to find the corresponding entry.  If you can't find one in the
current task, you can even go search all the other tasks and see who
might be responsible.  If something goes and does more than 10
simultaneous kmap()s in one thread, dump a warning and give up.  Or,
dynamically allocate the kmaps[] array.
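
A userspace model of that bookkeeping might look like the following.  It
is a sketch under assumptions: struct task_model stands in for
task_struct, the stack depot handle is left out, and track_kmap(),
track_kunmap() and outstanding_kmaps() are hypothetical names.

```c
#include <stddef.h>

#define MAX_KMAPS 10

struct one_kmap {
	const void *page;   /* the mapped page */
	int in_use;         /* slot currently holds a live mapping */
};

/* Stand-in for the per-task array proposed above. */
struct task_model {
	struct one_kmap kmaps[MAX_KMAPS];
};

/* Record a mapping; returns 0 on success, -1 if all slots are taken. */
static int track_kmap(struct task_model *t, const void *page)
{
	for (size_t i = 0; i < MAX_KMAPS; i++) {
		if (!t->kmaps[i].in_use) {
			t->kmaps[i].page = page;
			t->kmaps[i].in_use = 1;
			return 0;
		}
	}
	return -1;
}

/* Drop a mapping; returns -1 if this task never recorded the page. */
static int track_kunmap(struct task_model *t, const void *page)
{
	for (size_t i = 0; i < MAX_KMAPS; i++) {
		if (t->kmaps[i].in_use && t->kmaps[i].page == page) {
			t->kmaps[i].in_use = 0;
			return 0;
		}
	}
	return -1;
}

/* Number of mappings still recorded, e.g. to check at task exit. */
static int outstanding_kmaps(const struct task_model *t)
{
	int n = 0;

	for (size_t i = 0; i < MAX_KMAPS; i++)
		n += t->kmaps[i].in_use;
	return n;
}
```

A -1 from track_kunmap() is the interesting case: the page was mapped by
some other context, or not at all.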

Then you can dump out the stack of the kmap() culprit if it exits after
a kmap() but without a corresponding kunmap().

Something like that should be low overhead enough to get it into things
like the 0day debug kernel.  It should be way cheaper than something
like lockdep.
Thomas Gleixner Dec. 18, 2020, 9:06 p.m. UTC | #11
On Fri, Dec 18 2020 at 11:20, Dan Williams wrote:
> On Fri, Dec 18, 2020 at 5:58 AM Thomas Gleixner <tglx@linutronix.de> wrote:
> [..]
>>   5) The DAX case which you made "work" with dev_access_enable() and
>>      dev_access_disable(), i.e. with yet another lazy approach of
>>      avoiding to change a handful of usage sites.
>>
>>      The use cases are strictly context local which means the global
>>      magic is not used at all. Why does it exist in the first place?
>>
>>      Aside of that this global thing would never work at all because the
>>      refcounting is per thread and not global.
>>
>>      So that DAX use case is just a matter of:
>>
>>         grant/revoke_access(DEV_PKS_KEY, READ/WRITE)
>>
>>      which is effective for the current execution context and really
>>      wants to be a distinct READ/WRITE protection and not the magic
>>      global thing which just has on/off. All usage sites know whether
>>      they want to read or write.
>
> I was tracking and nodding until this point. Yes, kill the global /
> kmap() support, but if grant/revoke_access is not integrated behind
> kmap_{local,atomic}() then it's not a "handful" of sites that need to
> be instrumented it's 100s. Are you suggesting that "relaxed" mode
> enforcement is a way to distribute the work of teaching driver writers
> that they need to incorporate explicit grant/revoke-read/write in
> addition to kmap? The entire reason PTE_DEVMAP exists was to allow
> get_user_pages() for PMEM and not require every downstream-GUP code
> path to specifically consider whether it was talking to PMEM or RAM
> pages, and certainly not whether they were reading or writing to it.

kmap_local() is fine. That can work automatically because it's strict
local to the context which does the mapping.

kmap() is dubious because it's a 'global' mapping as dictated per
HIGHMEM. So doing the RELAXED mode for kmap() is sensible I think to
identify cases where the mapped address is really handed to a different
execution context. We want to see those cases and analyse whether this
can't be solved in a different way. That's why I suggested to do a
warning in that case.

Also vs. the DAX use case I really meant the code in fs/dax and
drivers/dax/ itself which is handling this via dax_read_[un]lock.

Does that make more sense?

Thanks,

        tglx
Thomas Gleixner Dec. 18, 2020, 9:30 p.m. UTC | #12
On Fri, Dec 18 2020 at 11:42, Ira Weiny wrote:
> On Fri, Dec 18, 2020 at 02:57:51PM +0100, Thomas Gleixner wrote:
>>   2) Modify kmap() so that it marks the to be mapped page as 'globally
>>      unprotected' instead of doing this global unprotect PKS dance.
>>      kunmap() undoes that. That obviously needs some thought
>>      vs. refcounting if there are concurrent users, but that's a
>>      solvable problem either as part of struct page itself or
>>      stored in some global hash.
>
> How would this globally unprotected flag work?  I suppose if kmap created a new
> PTE we could make that PTE non-PKS protected then we don't have to fiddle with
> the register...  I think I like that idea.

No. Look at the highmem implementation of kmap(). It's a terrible idea,
really. Don't even think about that.

There is _no_ global flag. The point is that the kmap is strictly bound
to a particular struct page. So you can simply do:

  kmap(page)
    if (page_is_access_protected(page))
        atomic_inc(&page->unprotect);

  kunmap(page)
    if (page_is_access_protected(page))
        atomic_dec(&page->unprotect);

and in the #PF handler:

    if (!page->unprotect)
       goto die;

The reason why I said: either in struct page itself or in a global hash
is that struct page is already packed and people are not really happy
about increasing its size. But the principle is roughly the same.
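
The per-page counting described above can be modelled in userspace with C11
atomics. This is only a sketch under assumptions: struct page_model, its
access_protected flag and the function names are hypothetical stand-ins, and
the real counter might live in a global hash rather than in struct page.

```c
#include <assert.h>
#include <stdatomic.h>
#include <stdbool.h>

/* Hypothetical stand-in for the per-page state discussed above. */
struct page_model {
	bool access_protected;	/* models page_is_access_protected(page) */
	atomic_int unprotect;	/* number of outstanding kmap() users */
};

/* kmap(): only protected pages need to be counted. */
static void model_kmap(struct page_model *p)
{
	if (p->access_protected)
		atomic_fetch_add(&p->unprotect, 1);
}

/* kunmap(): drop the reference taken by model_kmap(). */
static void model_kunmap(struct page_model *p)
{
	if (p->access_protected)
		atomic_fetch_sub(&p->unprotect, 1);
}

/* #PF handler check: a fault is fatal when nobody holds a kmap(). */
static bool fault_is_fatal(struct page_model *p)
{
	return atomic_load(&p->unprotect) == 0;
}
```

Because the count is per page, concurrent kmap() users of the same page nest
naturally, and no global PKS state has to be torn down.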

>> 
>>   4) Have a smart #PF mechanism which does:
>> 
>>      if (error_code & X86_PF_PK) {
>>          page = virt_to_page(address);
>> 
>>          if (!page || !page_is_globaly_unprotected(page))
>>                  goto die;
>> 
>>          if (pks_mode == PKS_MODE_STRICT)
>>          	 goto die;
>> 
>>          WARN_ONCE(pks_mode == PKS_MODE_RELAXED, "Useful info ...");
>> 
>>          temporary_unprotect(page, regs);
>>          return;
>>      }
>
> I feel like this is very similar to what I had in the global patch you found in
> my git tree with the exception of the RELAXED mode.  I simply had globally
> unprotected or die.

Your stuff depends on that global_pks_state which is not maintainable
especially not the teardown side. This depends on per page state which
is clearly way simpler and more focussed.

> Regardless I think unprotecting a global context is the easy part.  The code
> you had a problem with (and I see is fully broken) was the restriction of
> access.  A failure to update in that direction would only result in a wider
> window of access.  I contemplated not doing a global update at all and just
> leave the access open until the next context switch.  But the code as it stands
> tries to force an update for a couple of reasons:
>
> 1) kmap_local_page() removes most of the need for global pks.  So I was
>    thinking that global PKS could be a slow path.
>
> 2) kmap()s that are handed to other contexts are likely to be 'long term'
>    and should not need to be updated 'too' often.  I will admit that I don't
>    know how often 'too often' is.

Even once in a while is not a justification for stopping the world for N
milliseconds.

>>      temporary_unprotect(page, regs)
>>      {
>>         key = page_to_key(page);
>> 
>> 	/* Return from #PF will establish this for the faulting context */
>>         extended_state(regs)->pks &= ~PKS_MASK(key);
>>      }
>> 
>>      This temporary unprotect is undone when the context is left, so
>>      depending on the context (thread, interrupt, softirq) the
>>      unprotected section might be way wider than actually needed, but
>>      that's still orders of magnitudes better than having this fully
>>      unrestricted global PKS mode which is completely scopeless.
>
> I'm not sure I follow you.  How would we know when the context is
> left?

The context goes away on its own. Either context switch or return from
interrupt. As I said there is an extended window where the external
context still might have unprotected access even if the initiating
context has called kunmap() already. It's not pretty, but it's not the
end of the world either.

That's why I suggested to have that WARN_ONCE() so we can actually see
why and where that happens and think about solutions to make this go
into local context, e.g. by changing the vaddr pointer to a struct page
pointer for these particular use cases and then the other context can do
kmap/unmap_local().
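
To make the temporary unprotect concrete, the following userspace sketch
shows the bit arithmetic only. It assumes the PKRU-style layout of two bits
per key (access-disable at bit 0, write-disable at bit 1 of each pair) and
assumes PKS_MASK(key) covers both bits of a key; the function name
temporary_unprotect_value() is hypothetical.

```c
#include <assert.h>
#include <stdint.h>

/* Assumed layout: 2 bits per key, AD (access disable) and WD (write disable). */
#define PKR_BITS_PER_PKEY	2
#define PKR_AD_BIT		0x1u
#define PKR_WD_BIT		0x2u

/* Assumed meaning of PKS_MASK(): both protection bits of one key. */
#define PKS_MASK(key) \
	((uint32_t)(PKR_AD_BIT | PKR_WD_BIT) << ((key) * PKR_BITS_PER_PKEY))

/*
 * Compute the PKRS value the faulting context resumes with: the bits for
 * 'key' are cleared, every other key keeps its restrictions.
 */
static uint32_t temporary_unprotect_value(uint32_t saved_pks, unsigned int key)
{
	return saved_pks & ~PKS_MASK(key);
}
```

Since only the saved value of the faulting context is modified, the wider
access disappears when that context's saved state is discarded, which is the
scoping property described above.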

>>   5) The DAX case which you made "work" with dev_access_enable() and
>>      dev_access_disable(), i.e. with yet another lazy approach of
>>      avoiding to change a handful of usage sites.
>> 
>>      The use cases are strictly context local which means the global
>>      magic is not used at all. Why does it exist in the first place?
>
> I'm not following.  What is 'it'?

That global argument to dev_access_enable()/disable(). 

>>      That leaves the question about the refcount. AFAICT, nothing nests
>>      in that use case for a given execution context. I'm surely missing
>>      something subtle here.
>
> The refcount is needed for non-global pks as well as global.  I've not resolved
> if anything needs to be done with the refcount on the global update since the
> following is legal.
>
> kmap()
> kmap_local_page()
> kunmap()
> kunmap_local()
>
> Which would be a problem.  But I don't think it is ever actually done.

If it does not exist why would we support it in the first place? We can
have some warning there to catch that case.

> Another problem would be if the kmap and kunmap happened in different
> contexts...  :-/  I don't think that is done either but I don't know for
> certain.
>
> Frankly, my main focus before any of this global support has been to
> get rid of as many kmaps as possible.[1] Once that is done I think
> more of these questions can be answered better.

I was expecting that you could answer these questions :)

Thanks,

        tglx
Dan Williams Dec. 18, 2020, 9:58 p.m. UTC | #13
On Fri, Dec 18, 2020 at 1:06 PM Thomas Gleixner <tglx@linutronix.de> wrote:
>
> On Fri, Dec 18 2020 at 11:20, Dan Williams wrote:
> > On Fri, Dec 18, 2020 at 5:58 AM Thomas Gleixner <tglx@linutronix.de> wrote:
> > [..]
> >>   5) The DAX case which you made "work" with dev_access_enable() and
> >>      dev_access_disable(), i.e. with yet another lazy approach of
> >>      avoiding to change a handful of usage sites.
> >>
> >>      The use cases are strictly context local which means the global
> >>      magic is not used at all. Why does it exist in the first place?
> >>
> >>      Aside of that this global thing would never work at all because the
> >>      refcounting is per thread and not global.
> >>
> >>      So that DAX use case is just a matter of:
> >>
> >>         grant/revoke_access(DEV_PKS_KEY, READ/WRITE)
> >>
> >>      which is effective for the current execution context and really
> >>      wants to be a distinct READ/WRITE protection and not the magic
> >>      global thing which just has on/off. All usage sites know whether
> >>      they want to read or write.
> >
> > I was tracking and nodding until this point. Yes, kill the global /
> > kmap() support, but if grant/revoke_access is not integrated behind
> > kmap_{local,atomic}() then it's not a "handful" of sites that need to
> > be instrumented it's 100s. Are you suggesting that "relaxed" mode
> > enforcement is a way to distribute the work of teaching driver writers
> > that they need to incorporate explicit grant/revoke-read/write in
> > addition to kmap? The entire reason PTE_DEVMAP exists was to allow
> > get_user_pages() for PMEM and not require every downstream-GUP code
> > path to specifically consider whether it was talking to PMEM or RAM
> > pages, and certainly not whether they were reading or writing to it.
>
> kmap_local() is fine. That can work automatically because it's strict
> local to the context which does the mapping.
>
> kmap() is dubious because it's a 'global' mapping as dictated per
> HIGHMEM. So doing the RELAXED mode for kmap() is sensible I think to
> identify cases where the mapped address is really handed to a different
> execution context. We want to see those cases and analyse whether this
> can't be solved in a different way. That's why I suggested to do a
> warning in that case.
>
> Also vs. the DAX use case I really meant the code in fs/dax and
> drivers/dax/ itself which is handling this via dax_read_[un]lock.
>
> Does that make more sense?

Yup, got it. The dax code can be precise wrt to PKS in a way that
kmap_local() cannot.
Thomas Gleixner Dec. 18, 2020, 10:44 p.m. UTC | #14
On Fri, Dec 18 2020 at 13:58, Dan Williams wrote:
> On Fri, Dec 18, 2020 at 1:06 PM Thomas Gleixner <tglx@linutronix.de> wrote:
>> kmap_local() is fine. That can work automatically because it's strict
>> local to the context which does the mapping.
>>
>> kmap() is dubious because it's a 'global' mapping as dictated per
>> HIGHMEM. So doing the RELAXED mode for kmap() is sensible I think to
>> identify cases where the mapped address is really handed to a different
>> execution context. We want to see those cases and analyse whether this
>> can't be solved in a different way. That's why I suggested to do a
>> warning in that case.
>>
>> Also vs. the DAX use case I really meant the code in fs/dax and
>> drivers/dax/ itself which is handling this via dax_read_[un]lock.
>>
>> Does that make more sense?
>
> Yup, got it. The dax code can be precise wrt to PKS in a way that
> kmap_local() cannot.

Which makes me wonder whether we should have kmap_local_for_read()
or something like that, which could obviously only be RO enforced for
the real HIGHMEM case or the (for now x86 only) enforced kmap_local()
debug mechanics on 64bit.

So for the !highmem case it would not magically make the existing kernel
mapping RO, but this could be forwarded to the PKS protection. Aside of
that it's a nice annotation in the code.

That could be used right away for all the kmap[_atomic] -> kmap_local
conversions.

Thanks,

        tglx
---
 include/linux/highmem-internal.h |   14 ++++++++++++++
 1 file changed, 14 insertions(+)

--- a/include/linux/highmem-internal.h
+++ b/include/linux/highmem-internal.h
@@ -32,6 +32,10 @@ static inline void kmap_flush_tlb(unsign
 #define kmap_prot PAGE_KERNEL
 #endif
 
+#ifndef kmap_prot_ro
+#define kmap_prot_ro PAGE_KERNEL_RO
+#endif
+
 void *kmap_high(struct page *page);
 void kunmap_high(struct page *page);
 void __kmap_flush_unused(void);
@@ -73,6 +77,11 @@ static inline void *kmap_local_page(stru
 	return __kmap_local_page_prot(page, kmap_prot);
 }
 
+static inline void *kmap_local_page_for_read(struct page *page)
+{
+	return __kmap_local_page_prot(page, kmap_prot_ro);
+}
+
 static inline void *kmap_local_page_prot(struct page *page, pgprot_t prot)
 {
 	return __kmap_local_page_prot(page, prot);
@@ -169,6 +178,11 @@ static inline void *kmap_local_page_prot
 {
 	return kmap_local_page(page);
 }
+
+static inline void *kmap_local_page_for_read(struct page *page)
+{
+	return kmap_local_page(page);
+}
 
 static inline void *kmap_local_pfn(unsigned long pfn)
 {

Patch

diff --git a/arch/x86/include/asm/msr-index.h b/arch/x86/include/asm/msr-index.h
index 972a34d93505..ddb125e44408 100644
--- a/arch/x86/include/asm/msr-index.h
+++ b/arch/x86/include/asm/msr-index.h
@@ -754,6 +754,7 @@ 
 
 #define MSR_IA32_TSC_DEADLINE		0x000006E0
 
+#define MSR_IA32_PKRS			0x000006E1
 
 #define MSR_TSX_FORCE_ABORT		0x0000010F
 
diff --git a/arch/x86/include/asm/pkeys_common.h b/arch/x86/include/asm/pkeys_common.h
index 737d916f476c..801a75615209 100644
--- a/arch/x86/include/asm/pkeys_common.h
+++ b/arch/x86/include/asm/pkeys_common.h
@@ -12,4 +12,24 @@ 
  */
 #define PKR_AD_KEY(pkey)	(PKR_AD_BIT << ((pkey) * PKR_BITS_PER_PKEY))
 
+/*
+ * Define a default PKRS value for each task.
+ *
+ * Key 0 has no restriction.  All other keys are set to the most restrictive
+ * value which is access disabled (AD=1).
+ *
+ * NOTE: This needs to be a macro to be used as part of the INIT_THREAD macro.
+ */
+#define INIT_PKRS_VALUE (PKR_AD_KEY(1) | PKR_AD_KEY(2) | PKR_AD_KEY(3) | \
+			 PKR_AD_KEY(4) | PKR_AD_KEY(5) | PKR_AD_KEY(6) | \
+			 PKR_AD_KEY(7) | PKR_AD_KEY(8) | PKR_AD_KEY(9) | \
+			 PKR_AD_KEY(10) | PKR_AD_KEY(11) | PKR_AD_KEY(12) | \
+			 PKR_AD_KEY(13) | PKR_AD_KEY(14) | PKR_AD_KEY(15))
+
+#ifdef CONFIG_ARCH_HAS_SUPERVISOR_PKEYS
+void write_pkrs(u32 new_pkrs);
+#else
+static inline void write_pkrs(u32 new_pkrs) { }
+#endif
+
 #endif /*_ASM_X86_PKEYS_INTERNAL_H */
diff --git a/arch/x86/include/asm/processor.h b/arch/x86/include/asm/processor.h
index 82a08b585818..e9c65368b0b2 100644
--- a/arch/x86/include/asm/processor.h
+++ b/arch/x86/include/asm/processor.h
@@ -18,6 +18,7 @@  struct vm86;
 #include <asm/cpufeatures.h>
 #include <asm/page.h>
 #include <asm/pgtable_types.h>
+#include <asm/pkeys_common.h>
 #include <asm/percpu.h>
 #include <asm/msr.h>
 #include <asm/desc_defs.h>
@@ -520,6 +521,12 @@  struct thread_struct {
 	unsigned long		cr2;
 	unsigned long		trap_nr;
 	unsigned long		error_code;
+
+#ifdef	CONFIG_ARCH_HAS_SUPERVISOR_PKEYS
+	/* Saved Protection key register for supervisor mappings */
+	u32			saved_pkrs;
+#endif
+
 #ifdef CONFIG_VM86
 	/* Virtual 86 mode info */
 	struct vm86		*vm86;
@@ -785,7 +792,16 @@  static inline void spin_lock_prefetch(const void *x)
 #define KSTK_ESP(task)		(task_pt_regs(task)->sp)
 
 #else
-#define INIT_THREAD { }
+
+#ifdef CONFIG_ARCH_HAS_SUPERVISOR_PKEYS
+#define INIT_THREAD_PKRS	.saved_pkrs = INIT_PKRS_VALUE
+#else
+#define INIT_THREAD_PKRS	0
+#endif
+
+#define INIT_THREAD  {						\
+	INIT_THREAD_PKRS,					\
+}
 
 extern unsigned long KSTK_ESP(struct task_struct *task);
 
diff --git a/arch/x86/kernel/process.c b/arch/x86/kernel/process.c
index ba4593a913fa..aa2ae5292ff1 100644
--- a/arch/x86/kernel/process.c
+++ b/arch/x86/kernel/process.c
@@ -43,6 +43,7 @@ 
 #include <asm/io_bitmap.h>
 #include <asm/proto.h>
 #include <asm/frame.h>
+#include <asm/pkeys_common.h>
 
 #include "process.h"
 
@@ -187,6 +188,27 @@  int copy_thread(unsigned long clone_flags, unsigned long sp, unsigned long arg,
 	return ret;
 }
 
+#ifdef CONFIG_ARCH_HAS_SUPERVISOR_PKEYS
+DECLARE_PER_CPU(u32, pkrs_cache);
+static inline void pks_init_task(struct task_struct *tsk)
+{
+	/* New tasks get the most restrictive PKRS value */
+	tsk->thread.saved_pkrs = INIT_PKRS_VALUE;
+}
+static inline void pks_sched_in(void)
+{
+	/*
+	 * PKRS is only temporarily changed during specific code paths.  Only a
+	 * preemption during these windows away from the default value would
+	 * require updating the MSR.  write_pkrs() handles this optimization.
+	 */
+	write_pkrs(current->thread.saved_pkrs);
+}
+#else
+static inline void pks_init_task(struct task_struct *tsk) { }
+static inline void pks_sched_in(void) { }
+#endif
+
 void flush_thread(void)
 {
 	struct task_struct *tsk = current;
@@ -195,6 +217,8 @@  void flush_thread(void)
 	memset(tsk->thread.tls_array, 0, sizeof(tsk->thread.tls_array));
 
 	fpu__clear_all(&tsk->thread.fpu);
+
+	pks_init_task(tsk);
 }
 
 void disable_TSC(void)
@@ -644,6 +668,8 @@  void __switch_to_xtra(struct task_struct *prev_p, struct task_struct *next_p)
 
 	if ((tifp ^ tifn) & _TIF_SLD)
 		switch_to_sld(tifn);
+
+	pks_sched_in();
 }
 
 /*
diff --git a/arch/x86/mm/pkeys.c b/arch/x86/mm/pkeys.c
index d1dfe743e79f..76a62419c446 100644
--- a/arch/x86/mm/pkeys.c
+++ b/arch/x86/mm/pkeys.c
@@ -231,3 +231,34 @@  u32 update_pkey_val(u32 pk_reg, int pkey, unsigned int flags)
 
 	return pk_reg;
 }
+
+DEFINE_PER_CPU(u32, pkrs_cache);
+
+/**
+ * write_pkrs() optimizes MSR writes by maintaining a per cpu cache which can
+ * be checked quickly.
+ *
+ * It should also be noted that the underlying WRMSR(MSR_IA32_PKRS) is not
+ * serializing but still maintains ordering properties similar to WRPKRU.
+ * The current SDM section on PKRS needs updating but should be the same as
+ * that of WRPKRU.  So to quote from the WRPKRU text:
+ *
+ *     WRPKRU will never execute transiently. Memory accesses
+ *     affected by PKRU register will not execute (even transiently)
+ *     until all prior executions of WRPKRU have completed execution
+ *     and updated the PKRU register.
+ */
+void write_pkrs(u32 new_pkrs)
+{
+	u32 *pkrs;
+
+	if (!static_cpu_has(X86_FEATURE_PKS))
+		return;
+
+	pkrs = get_cpu_ptr(&pkrs_cache);
+	if (*pkrs != new_pkrs) {
+		*pkrs = new_pkrs;
+		wrmsrl(MSR_IA32_PKRS, new_pkrs);
+	}
+	put_cpu_ptr(pkrs);
+}
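
For readers who want to check the numbers, here is a small userspace model
of two pieces of the patch: the default register value and the cached write.
It is a sketch, not kernel code. It assumes PKR_AD_BIT is 0x1 and
PKR_BITS_PER_PKEY is 2, as for PKRU; struct cpu_model, init_pkrs_value()
and model_write_pkrs() are hypothetical names, and a counter replaces the
real wrmsrl().

```c
#include <assert.h>
#include <stdint.h>

/* Assumed to match the PKRU definitions: AD is bit 0 of each 2-bit field. */
#define PKR_AD_BIT		0x1u
#define PKR_BITS_PER_PKEY	2
#define PKR_AD_KEY(pkey)	(PKR_AD_BIT << ((pkey) * PKR_BITS_PER_PKEY))

/* Same value as INIT_PKRS_VALUE: key 0 open, keys 1..15 access disabled. */
static uint32_t init_pkrs_value(void)
{
	uint32_t v = 0;

	for (unsigned int key = 1; key <= 15; key++)
		v |= PKR_AD_KEY(key);
	return v;
}

/* Models one logical CPU: the cached value and how often the MSR was written. */
struct cpu_model {
	uint32_t pkrs_cache;
	unsigned int msr_writes;
};

/* Mirrors the cache check in write_pkrs(): skip the write if nothing changes. */
static void model_write_pkrs(struct cpu_model *cpu, uint32_t new_pkrs)
{
	if (cpu->pkrs_cache != new_pkrs) {
		cpu->pkrs_cache = new_pkrs;
		cpu->msr_writes++;	/* stands in for wrmsrl(MSR_IA32_PKRS, ...) */
	}
}
```

With these assumptions the default value works out to 0x55555554, and
switching between tasks that all hold the default value costs no MSR write
after the first one.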