diff mbox series

[v2,06/15] KVM: x86/mmu: Support GFN direct mask

Message ID 20240530210714.364118-7-rick.p.edgecombe@intel.com (mailing list archive)
State New, archived
Headers show
Series TDX MMU prep series part 1 | expand

Commit Message

Rick Edgecombe May 30, 2024, 9:07 p.m. UTC
From: Isaku Yamahata <isaku.yamahata@intel.com>

Teach the MMU to map guest GFNs at a massaged position in the TDP page
tables, to aid in implementing TDX shared memory.

Like other CoCo technologies, TDX has the concept of private and shared
memory. For TDX, the private and shared mappings are managed on separate
EPT roots. The private half is managed indirectly through calls into a
protected runtime environment called the TDX module, while the shared half
is managed within KVM in normal page tables.

For TDX, the shared half will be mapped in the higher alias, with a "shared
bit" set in the GPA. However, KVM will still manage it with the same
memslots as the private half. This means memslot lookups and zapping
operations will be provided with a GFN without the shared bit set.

So KVM will need to either apply or strip the shared bit before mapping or
zapping the shared EPT. Having GFNs that sometimes have the shared bit set
and sometimes not would make the code confusing.

So instead, arrange the code such that GFNs never have the shared bit set.
Create the concept of a "direct mask" that is stripped from the fault
address when setting fault->gfn, and applied within the TDP MMU iterator.
Calling code will behave as if it is operating on the PTE mapping the GFN
(without shared bits), but within the iterator the actual mappings will be
shifted using a mask specific to the root. SPs will have their gfn set
without the shared bit. In the end, the TDP MMU will behave as if it is
mapping things at the GFN without the shared bit, but with a strange page
table format where everything is offset by the shared bit.

Since TDX only needs to shift the mapping like this for the shared half,
which is mapped in the normal TDP root, add a "gfn_direct_mask" field to
the kvm_arch structure for each VM, with a default value of 0. It will be
set to the position of the GPA shared bit in the GFN through TD-specific
initialization code. Keep TDX-specific concepts out of the MMU code by not
naming it "shared".

Ranged TLB flushes (i.e. flush_remote_tlbs_range()) target specific GFN
ranges. In the convention established above, these would need to target the
shifted GFN range. It won't matter functionally, since the actual
implementation will always result in a full flush for the only planned
user (TDX). But for code clarity, explicitly do the full flush when a
gfn_direct_mask is present.

This leaves one drawback. Some operations use the concept of a max GFN
(i.e. kvm_mmu_max_gfn()) to iterate over the whole TDP range. These would
then exceed the range actually covered by each root. This should only
result in a bit of extra iterating, and not cause functional problems. It
will be addressed in a future change.

Signed-off-by: Isaku Yamahata <isaku.yamahata@intel.com>
Co-developed-by: Rick Edgecombe <rick.p.edgecombe@intel.com>
Signed-off-by: Rick Edgecombe <rick.p.edgecombe@intel.com>
---
TDX MMU Prep v2:
 - Rename from "KVM: x86/mmu: Add address conversion functions for TDX shared bit of GPA"
 - Dropped Binbin's reviewed-by tag because of the extent of the changes
 - Rename gfn_shared_mask to gfn_direct_mask.
 - Don't include shared bits in GFNs, hide the existence in the TDP MMU
   iterator.
 - Don't do range flushes if a gfn_direct_mask is present.
---
 arch/x86/include/asm/kvm_host.h | 11 +++------
 arch/x86/kvm/mmu.h              |  5 ++++
 arch/x86/kvm/mmu/mmu_internal.h | 16 +++++++++++-
 arch/x86/kvm/mmu/tdp_iter.c     |  5 ++--
 arch/x86/kvm/mmu/tdp_iter.h     | 16 ++++++------
 arch/x86/kvm/mmu/tdp_mmu.c      | 43 ++++++++++++++++-----------------
 arch/x86/kvm/x86.c              | 10 ++++++++
 7 files changed, 66 insertions(+), 40 deletions(-)

Comments

Paolo Bonzini June 7, 2024, 7:59 a.m. UTC | #1
> Keep TDX specific concepts out of the MMU code by not
> naming it "shared".

I think that, more than keeping TDX specific concepts out of MMU code,
it is better to have a different name because it doesn't confuse
memory attributes with MMU concepts.

For example, SNP uses the same page tables for both shared and private
memory, as it handles them at the RMP level.

By the way, in patch 3 it still talks about "shared PT", please change
that to "direct SP" (so we have "direct", "external", "mirror").

Just one non-cosmetic request at the very end of the email.

On Thu, May 30, 2024 at 11:07 PM Rick Edgecombe
<rick.p.edgecombe@intel.com> wrote:
> +static inline gfn_t kvm_gfn_root_mask(const struct kvm *kvm, const struct kvm_mmu_page *root)
> +{
> +       if (is_mirror_sp(root))
> +               return 0;

Maybe add a comment:

/*
 * Since mirror SPs are used only for TDX, which maps private memory
 * at its "natural" GFN, no mask needs to be applied to them - and, dually,
 * we expect that the mask is only used for the shared PT.
 */

> +       return kvm_gfn_direct_mask(kvm);

Ok, please excuse me again for being fussy on the naming. Typically I
think of a "mask" as something that you check against, or something
that you do x &~ mask, not as something that you add. Maybe
kvm_gfn_root_offset and gfn_direct_offset?

I also thought of gfn_direct_fixed_bits, but I'm not sure it
translates as well to kvm_gfn_root_fixed_bits. Anyway, I'll leave it
to you to make a decision, speak up if you think it's not an
improvement or if (especially for fixed_bits) it results in too long
lines.

Fortunately this kind of change is decently easy to do with a
search/replace on the patch files themselves.

> +}
> +
>  static inline bool kvm_mmu_page_ad_need_write_protect(struct kvm_mmu_page *sp)
>  {
>         /*
> @@ -359,7 +368,12 @@ static inline int __kvm_mmu_do_page_fault(struct kvm_vcpu *vcpu, gpa_t cr2_or_gp
>         int r;
>
>         if (vcpu->arch.mmu->root_role.direct) {
> -               fault.gfn = fault.addr >> PAGE_SHIFT;
> +               /*
> +                * Things like memslots don't understand the concept of a shared
> +                * bit. Strip it so that the GFN can be used like normal, and the
> +                * fault.addr can be used when the shared bit is needed.
> +                */
> +               fault.gfn = gpa_to_gfn(fault.addr) & ~kvm_gfn_direct_mask(vcpu->kvm);
>                 fault.slot = kvm_vcpu_gfn_to_memslot(vcpu, fault.gfn);

Please add a comment to struct kvm_page_fault's gfn field, about how
it differs from addr.

> +       /* Mask applied to convert the GFN to the mapping GPA */
> +       gfn_t gfn_mask;

s/mask/offset/ or s/mask/fixed_bits/ here, if you go for it; won't
repeat myself below.

>         /* The level of the root page given to the iterator */
>         int root_level;
>         /* The lowest level the iterator should traverse to */
> @@ -120,18 +122,18 @@ struct tdp_iter {
>   * Iterates over every SPTE mapping the GFN range [start, end) in a
>   * preorder traversal.
>   */
> -#define for_each_tdp_pte_min_level(iter, root, min_level, start, end) \
> -       for (tdp_iter_start(&iter, root, min_level, start); \
> -            iter.valid && iter.gfn < end;                   \
> +#define for_each_tdp_pte_min_level(iter, kvm, root, min_level, start, end)               \
> +       for (tdp_iter_start(&iter, root, min_level, start, kvm_gfn_root_mask(kvm, root)); \
> +            iter.valid && iter.gfn < end;                                                \
>              tdp_iter_next(&iter))
>
> -#define for_each_tdp_pte(iter, root, start, end) \
> -       for_each_tdp_pte_min_level(iter, root, PG_LEVEL_4K, start, end)
> +#define for_each_tdp_pte(iter, kvm, root, start, end)                          \
> +       for_each_tdp_pte_min_level(iter, kvm, root, PG_LEVEL_4K, start, end)

Maybe add the kvm pointer / remove the mmu pointer in a separate patch
to make the mask-related changes easier to identify?

> diff --git a/arch/x86/kvm/x86.c b/arch/x86/kvm/x86.c
> index 7c593a081eba..0e6325b5f5e7 100644
> --- a/arch/x86/kvm/x86.c
> +++ b/arch/x86/kvm/x86.c
> @@ -13987,6 +13987,16 @@ int kvm_sev_es_string_io(struct kvm_vcpu *vcpu, unsigned int size,
>  }
>  EXPORT_SYMBOL_GPL(kvm_sev_es_string_io);
>
> +#ifdef __KVM_HAVE_ARCH_FLUSH_REMOTE_TLBS_RANGE
> +int kvm_arch_flush_remote_tlbs_range(struct kvm *kvm, gfn_t gfn, u64 nr_pages)
> +{
> +       if (!kvm_x86_ops.flush_remote_tlbs_range || kvm_gfn_direct_mask(kvm))

I think the code need not check kvm_gfn_direct_mask() here? In the old
patches that I have it check kvm_gfn_direct_mask() in the vmx/main.c
callback.

Paolo
Paolo Bonzini June 7, 2024, 8 a.m. UTC | #2
On Thu, May 30, 2024 at 11:07 PM Rick Edgecombe
<rick.p.edgecombe@intel.com> wrote:
> -                                                  u64 nr_pages)
> -{
> -       if (!kvm_x86_ops.flush_remote_tlbs_range)
> -               return -EOPNOTSUPP;
> -
> -       return static_call(kvm_x86_flush_remote_tlbs_range)(kvm, gfn, nr_pages);
> -}
> +int kvm_arch_flush_remote_tlbs_range(struct kvm *kvm, gfn_t gfn, u64 nr_pages);
>  #endif /* CONFIG_HYPERV */

Ah, since you are at it please move the prototype out of the #ifdef
CONFIG_HYPERV.

Paolo
Rick Edgecombe June 7, 2024, 6:39 p.m. UTC | #3
On Fri, 2024-06-07 at 09:59 +0200, Paolo Bonzini wrote:
> Just one non-cosmetic request at the very end of the email.
> 
> On Thu, May 30, 2024 at 11:07 PM Rick Edgecombe
> <rick.p.edgecombe@intel.com> wrote:
> > +static inline gfn_t kvm_gfn_root_mask(const struct kvm *kvm, const struct kvm_mmu_page *root)
> > +{
> > +       if (is_mirror_sp(root))
> > +               return 0;
> 
> Maybe add a comment:
> 
> /*
>  * Since mirror SPs are used only for TDX, which maps private memory
>  * at its "natural" GFN, no mask needs to be applied to them - and, dually,
>  * we expect that the mask is only used for the shared PT.
>  */

Sure, seems like a good idea.

> 
> > +       return kvm_gfn_direct_mask(kvm);
> 
> Ok, please excuse me again for being fussy on the naming. Typically I
> think of a "mask" as something that you check against, or something
> that you do x &~ mask, not as something that you add. Maybe
> kvm_gfn_root_offset and gfn_direct_offset?
> 
> I also thought of gfn_direct_fixed_bits, but I'm not sure it
> translates as well to kvm_gfn_root_fixed_bits. Anyway, I'll leave it
> to you to make a decision, speak up if you think it's not an
> improvement or if (especially for fixed_bits) it results in too long
> lines.
> 
> Fortunately this kind of change is decently easy to do with a
> search/replace on the patch files themselves.

Yea, it's no problem to update the code. I'll be happy if this code is more
understandable for non-tdx developers.

As for the name, I guess I'd be less keen on "offset" because it's not clear
that it is a power-of-two value that can be used with bitwise operations. 

I'm not sure what the "fixed" adds and it makes it longer. Also, many PTE bits
cannot be moved and they are not referred to as fixed, whereas the shared bit
actually *can* be moved via GPAW (not that the MMU code cares about that
though).

Just "bits" sounds better to me, so maybe I'll try?
kvm_gfn_direct_bits()
kvm_gfn_root_bits()

> 
> > +}
> > +
> >   static inline bool kvm_mmu_page_ad_need_write_protect(struct kvm_mmu_page *sp)
> >   {
> >          /*
> > @@ -359,7 +368,12 @@ static inline int __kvm_mmu_do_page_fault(struct kvm_vcpu *vcpu, gpa_t cr2_or_gp
> >          int r;
> > 
> >          if (vcpu->arch.mmu->root_role.direct) {
> > -               fault.gfn = fault.addr >> PAGE_SHIFT;
> > +               /*
> > +                * Things like memslots don't understand the concept of a shared
> > +                * bit. Strip it so that the GFN can be used like normal, and the
> > +                * fault.addr can be used when the shared bit is needed.
> > +                */
> > +               fault.gfn = gpa_to_gfn(fault.addr) & ~kvm_gfn_direct_mask(vcpu->kvm);
> >                  fault.slot = kvm_vcpu_gfn_to_memslot(vcpu, fault.gfn);
> 
> Please add a comment to struct kvm_page_fault's gfn field, about how
> it differs from addr.

Doh, yes totally.

> 
> > +       /* Mask applied to convert the GFN to the mapping GPA */
> > +       gfn_t gfn_mask;
> 
> s/mask/offset/ or s/mask/fixed_bits/ here, if you go for it; won't
> repeat myself below.
> 
> >          /* The level of the root page given to the iterator */
> >          int root_level;
> >          /* The lowest level the iterator should traverse to */
> > @@ -120,18 +122,18 @@ struct tdp_iter {
> >    * Iterates over every SPTE mapping the GFN range [start, end) in a
> >    * preorder traversal.
> >    */
> > -#define for_each_tdp_pte_min_level(iter, root, min_level, start, end) \
> > -       for (tdp_iter_start(&iter, root, min_level, start); \
> > -            iter.valid && iter.gfn < end;                   \
> > +#define for_each_tdp_pte_min_level(iter, kvm, root, min_level, start, end)               \
> > +       for (tdp_iter_start(&iter, root, min_level, start, kvm_gfn_root_mask(kvm, root)); \
> > +            iter.valid && iter.gfn < end;                                                \
> >               tdp_iter_next(&iter))
> > 
> > -#define for_each_tdp_pte(iter, root, start, end) \
> > -       for_each_tdp_pte_min_level(iter, root, PG_LEVEL_4K, start, end)
> > +#define for_each_tdp_pte(iter, kvm, root, start, end)                          \
> > +       for_each_tdp_pte_min_level(iter, kvm, root, PG_LEVEL_4K, start, end)
> 
> Maybe add the kvm pointer / remove the mmu pointer in a separate patch
> to make the mask-related changes easier to identify?

Hmm, yea. I can split it.

> 
> > diff --git a/arch/x86/kvm/x86.c b/arch/x86/kvm/x86.c
> > index 7c593a081eba..0e6325b5f5e7 100644
> > --- a/arch/x86/kvm/x86.c
> > +++ b/arch/x86/kvm/x86.c
> > @@ -13987,6 +13987,16 @@ int kvm_sev_es_string_io(struct kvm_vcpu *vcpu, unsigned int size,
> >   }
> >   EXPORT_SYMBOL_GPL(kvm_sev_es_string_io);
> > 
> > +#ifdef __KVM_HAVE_ARCH_FLUSH_REMOTE_TLBS_RANGE
> > +int kvm_arch_flush_remote_tlbs_range(struct kvm *kvm, gfn_t gfn, u64 nr_pages)
> > +{
> > +       if (!kvm_x86_ops.flush_remote_tlbs_range || kvm_gfn_direct_mask(kvm))
> 
> I think the code need not check kvm_gfn_direct_mask() here? In the old
> patches that I have it check kvm_gfn_direct_mask() in the vmx/main.c
> callback.

You mean a VMX/TDX implementation of flush_remote_tlbs_range that just returns
-EOPNOTSUPP? Which version of the patches is this? I couldn't find anything like
that.

But I guess we could add one in the later patches, in which case we could drop
the hunk in this one. I see the benefit being less churn.


The downside would be a wider distribution of the concerns for dealing with
multiple aliases for a GFN. Currently, the behavior of having multiple aliases
is implemented in core MMU code. While it's fine to pollute tdx.c with TDX
specific knowledge of course, removing the handling of this corner from mmu.c
might make it less understandable for non-TDX readers who are working in MMU
code. Basically, if a concept fits into some non-TDX abstraction like this,
having it in core code seems the better default to me.

For this reason, my preference would be to leave the logic in core code. But I'm
fine changing it. I'll move it into the tdx.c for now, unless you are convinced
by the above.
Paolo Bonzini June 8, 2024, 8:52 a.m. UTC | #4
On Fri, Jun 7, 2024 at 8:39 PM Edgecombe, Rick P
<rick.p.edgecombe@intel.com> wrote:
> > > +       return kvm_gfn_direct_mask(kvm);
> >
> > Ok, please excuse me again for being fussy on the naming. Typically I
> > think of a "mask" as something that you check against, or something
> > that you do x &~ mask, not as something that you add. Maybe
> > kvm_gfn_root_offset and gfn_direct_offset?
>
> As for the name, I guess I'd be less keen on "offset" because it's not clear
> that it is a power-of-two value that can be used with bitwise operations.
>
> I'm not sure what the "fixed" adds and it makes it longer. Also, many PTE bits
> cannot be moved and they are not referred to as fixed, where the shared bit
> actually *can* be moved via GPAW (not that the MMU code cares about that
> though).
>
> Just "bits" sounds better to me, so maybe I'll try?
> kvm_gfn_direct_bits()
> kvm_gfn_root_bits()

Yep, kvm_gfn_direct_bits and kvm_gfn_root_bits are good.

Paolo
Paolo Bonzini June 8, 2024, 9:08 a.m. UTC | #5
On Fri, Jun 7, 2024 at 8:39 PM Edgecombe, Rick P
<rick.p.edgecombe@intel.com> wrote:
> > I think the code need not check kvm_gfn_direct_mask() here? In the old
> > patches that I have it check kvm_gfn_direct_mask() in the vmx/main.c
> > callback.
>
> You mean a VMX/TDX implementation of flush_remote_tlbs_range that just returns
> -EOPNOTSUPP? Which version of the patches is this? I couldn't find anything like
> that.

Something from Intel's GitHub, roughly June 2023... Looking at the
whole history, it starts with

     if (!kvm_x86_ops.flush_remote_tlbs_range)
         return -EOPNOTSUPP;

     return static_call(kvm_x86_flush_remote_tlbs_range)(kvm, gfn, nr_pages);

and it only assigns the callback in vmx.c (not main.c); then it adds
an implementation of the callback for TDX that has:

static int vt_flush_remote_tlbs_range(struct kvm *kvm, gfn_t gfn, gfn_t nr_pages)
{
        if (is_td(kvm))
                return tdx_sept_flush_remote_tlbs_range(kvm, gfn, nr_pages);

        /* fallback to flush_remote_tlbs method */
        return -EOPNOTSUPP;
}

where the callback knows that it should flush both private and shared
GFNs. So I didn't remember it correctly, but still there is no check
for the presence of direct-mapping bits.

> The downside would be wider distribution of the concerns for dealing with
> multiple aliases for a GFN. Currently, the behavior to have multiple aliases is
> implemented in core MMU code. While it's fine to pollute tdx.c with TDX specific
> knowledge of course, removing the handling of this corner from mmu.c might make
> it less understandable for non-tdx readers who are working in MMU code.
> Basically, if a concept fits into some non-TDX abstraction like this, having it
> in core code seems the better default to me.

I am not sure why it's an MMU concept that "if you offset the shared
mappings you cannot implement flush_remote_tlbs_range". It seems more
like, you need to know what you're doing?

Right now it makes no difference because you don't set the callback;
but if you ever wanted to implement flush_remote_tlbs_range as an
optimization you'd have to remove the condition from the "if". So it's
better not to have it in the first place.

Perhaps add a comment instead, like:

     if (!kvm_x86_ops.flush_remote_tlbs_range)
         return -EOPNOTSUPP;

+    /*
+     * If applicable, the callback should flush GFNs both with and without
+     * the direct-mapping bits.
+     */
     return static_call(kvm_x86_flush_remote_tlbs_range)(kvm, gfn, nr_pages);

Paolo
Rick Edgecombe June 9, 2024, 11:25 p.m. UTC | #6
On Sat, 2024-06-08 at 11:08 +0200, Paolo Bonzini wrote:
> > The downside would be wider distribution of the concerns for dealing with
> > multiple aliases for a GFN. Currently, the behavior to have multiple aliases is
> > implemented in core MMU code. While it's fine to pollute tdx.c with TDX specific
> > knowledge of course, removing the handling of this corner from mmu.c might make
> > it less understandable for non-tdx readers who are working in MMU code.
> > Basically, if a concept fits into some non-TDX abstraction like this, having it
> > in core code seems the better default to me.
> 
> I am not sure why it's an MMU concept that "if you offset the shared
> mappings you cannot implement flush_remote_tlbs_range". It seems more
> like, you need to know what you're doing?
> 
> Right now it makes no difference because you don't set the callback;
> but if you ever wanted to implement flush_remote_tlbs_range as an
> optimization you'd have to remove the condition from the "if". So it's
> better not to have it in the first place.

Yea that's true.

> 
> Perhaps add a comment instead, like:
> 
>      if (!kvm_x86_ops.flush_remote_tlbs_range)
>          return -EOPNOTSUPP;
> 
> +    /*
> +     * If applicable, the callback should flush GFNs both with and without
> +     * the direct-mapping bits.
> +     */
>      return static_call(kvm_x86_flush_remote_tlbs_range)(kvm, gfn, nr_pages);

Ok, works for me.

Patch

diff --git a/arch/x86/include/asm/kvm_host.h b/arch/x86/include/asm/kvm_host.h
index 084f4708aff1..c9af963ab897 100644
--- a/arch/x86/include/asm/kvm_host.h
+++ b/arch/x86/include/asm/kvm_host.h
@@ -1535,6 +1535,8 @@  struct kvm_arch {
 	 */
 #define SPLIT_DESC_CACHE_MIN_NR_OBJECTS (SPTE_ENT_PER_PAGE + 1)
 	struct kvm_mmu_memory_cache split_desc_cache;
+
+	gfn_t gfn_direct_mask;
 };
 
 struct kvm_vm_stat {
@@ -1908,14 +1910,7 @@  static inline int kvm_arch_flush_remote_tlbs(struct kvm *kvm)
 }
 
 #define __KVM_HAVE_ARCH_FLUSH_REMOTE_TLBS_RANGE
-static inline int kvm_arch_flush_remote_tlbs_range(struct kvm *kvm, gfn_t gfn,
-						   u64 nr_pages)
-{
-	if (!kvm_x86_ops.flush_remote_tlbs_range)
-		return -EOPNOTSUPP;
-
-	return static_call(kvm_x86_flush_remote_tlbs_range)(kvm, gfn, nr_pages);
-}
+int kvm_arch_flush_remote_tlbs_range(struct kvm *kvm, gfn_t gfn, u64 nr_pages);
 #endif /* CONFIG_HYPERV */
 
 enum kvm_intr_type {
diff --git a/arch/x86/kvm/mmu.h b/arch/x86/kvm/mmu.h
index 0c3bf89cf7db..f0713b6e4ee5 100644
--- a/arch/x86/kvm/mmu.h
+++ b/arch/x86/kvm/mmu.h
@@ -323,4 +323,9 @@  static inline bool kvm_has_mirrored_tdp(const struct kvm *kvm)
 {
 	return kvm->arch.vm_type == KVM_X86_TDX_VM;
 }
+
+static inline gfn_t kvm_gfn_direct_mask(const struct kvm *kvm)
+{
+	return kvm->arch.gfn_direct_mask;
+}
 #endif
diff --git a/arch/x86/kvm/mmu/mmu_internal.h b/arch/x86/kvm/mmu/mmu_internal.h
index 6d82e389cd65..076871c9e694 100644
--- a/arch/x86/kvm/mmu/mmu_internal.h
+++ b/arch/x86/kvm/mmu/mmu_internal.h
@@ -6,6 +6,8 @@ 
 #include <linux/kvm_host.h>
 #include <asm/kvm_host.h>
 
+#include "mmu.h"
+
 #ifdef CONFIG_KVM_PROVE_MMU
 #define KVM_MMU_WARN_ON(x) WARN_ON_ONCE(x)
 #else
@@ -189,6 +191,13 @@  static inline void kvm_mmu_alloc_private_spt(struct kvm_vcpu *vcpu, struct kvm_m
 	sp->mirrored_spt = kvm_mmu_memory_cache_alloc(&vcpu->arch.mmu_mirrored_spt_cache);
 }
 
+static inline gfn_t kvm_gfn_root_mask(const struct kvm *kvm, const struct kvm_mmu_page *root)
+{
+	if (is_mirror_sp(root))
+		return 0;
+	return kvm_gfn_direct_mask(kvm);
+}
+
 static inline bool kvm_mmu_page_ad_need_write_protect(struct kvm_mmu_page *sp)
 {
 	/*
@@ -359,7 +368,12 @@  static inline int __kvm_mmu_do_page_fault(struct kvm_vcpu *vcpu, gpa_t cr2_or_gp
 	int r;
 
 	if (vcpu->arch.mmu->root_role.direct) {
-		fault.gfn = fault.addr >> PAGE_SHIFT;
+		/*
+		 * Things like memslots don't understand the concept of a shared
+		 * bit. Strip it so that the GFN can be used like normal, and the
+		 * fault.addr can be used when the shared bit is needed.
+		 */
+		fault.gfn = gpa_to_gfn(fault.addr) & ~kvm_gfn_direct_mask(vcpu->kvm);
 		fault.slot = kvm_vcpu_gfn_to_memslot(vcpu, fault.gfn);
 	}
 
diff --git a/arch/x86/kvm/mmu/tdp_iter.c b/arch/x86/kvm/mmu/tdp_iter.c
index 04c247bfe318..a3bfe7fe473a 100644
--- a/arch/x86/kvm/mmu/tdp_iter.c
+++ b/arch/x86/kvm/mmu/tdp_iter.c
@@ -12,7 +12,7 @@ 
 static void tdp_iter_refresh_sptep(struct tdp_iter *iter)
 {
 	iter->sptep = iter->pt_path[iter->level - 1] +
-		SPTE_INDEX(iter->gfn << PAGE_SHIFT, iter->level);
+		SPTE_INDEX((iter->gfn | iter->gfn_mask) << PAGE_SHIFT, iter->level);
 	iter->old_spte = kvm_tdp_mmu_read_spte(iter->sptep);
 }
 
@@ -37,7 +37,7 @@  void tdp_iter_restart(struct tdp_iter *iter)
  * rooted at root_pt, starting with the walk to translate next_last_level_gfn.
  */
 void tdp_iter_start(struct tdp_iter *iter, struct kvm_mmu_page *root,
-		    int min_level, gfn_t next_last_level_gfn)
+		    int min_level, gfn_t next_last_level_gfn, gfn_t gfn_mask)
 {
 	if (WARN_ON_ONCE(!root || (root->role.level < 1) ||
 			 (root->role.level > PT64_ROOT_MAX_LEVEL))) {
@@ -46,6 +46,7 @@  void tdp_iter_start(struct tdp_iter *iter, struct kvm_mmu_page *root,
 	}
 
 	iter->next_last_level_gfn = next_last_level_gfn;
+	iter->gfn_mask = gfn_mask;
 	iter->root_level = root->role.level;
 	iter->min_level = min_level;
 	iter->pt_path[iter->root_level - 1] = (tdp_ptep_t)root->spt;
diff --git a/arch/x86/kvm/mmu/tdp_iter.h b/arch/x86/kvm/mmu/tdp_iter.h
index fae559559a80..6864d21edb4e 100644
--- a/arch/x86/kvm/mmu/tdp_iter.h
+++ b/arch/x86/kvm/mmu/tdp_iter.h
@@ -91,8 +91,10 @@  struct tdp_iter {
 	tdp_ptep_t pt_path[PT64_ROOT_MAX_LEVEL];
 	/* A pointer to the current SPTE */
 	tdp_ptep_t sptep;
-	/* The lowest GFN mapped by the current SPTE */
+	/* The lowest GFN (mask bits excluded) mapped by the current SPTE */
 	gfn_t gfn;
+	/* Mask applied to convert the GFN to the mapping GPA */
+	gfn_t gfn_mask;
 	/* The level of the root page given to the iterator */
 	int root_level;
 	/* The lowest level the iterator should traverse to */
@@ -120,18 +122,18 @@  struct tdp_iter {
  * Iterates over every SPTE mapping the GFN range [start, end) in a
  * preorder traversal.
  */
-#define for_each_tdp_pte_min_level(iter, root, min_level, start, end) \
-	for (tdp_iter_start(&iter, root, min_level, start); \
-	     iter.valid && iter.gfn < end;		     \
+#define for_each_tdp_pte_min_level(iter, kvm, root, min_level, start, end)		  \
+	for (tdp_iter_start(&iter, root, min_level, start, kvm_gfn_root_mask(kvm, root)); \
+	     iter.valid && iter.gfn < end;						  \
 	     tdp_iter_next(&iter))
 
-#define for_each_tdp_pte(iter, root, start, end) \
-	for_each_tdp_pte_min_level(iter, root, PG_LEVEL_4K, start, end)
+#define for_each_tdp_pte(iter, kvm, root, start, end)				\
+	for_each_tdp_pte_min_level(iter, kvm, root, PG_LEVEL_4K, start, end)
 
 tdp_ptep_t spte_to_child_pt(u64 pte, int level);
 
 void tdp_iter_start(struct tdp_iter *iter, struct kvm_mmu_page *root,
-		    int min_level, gfn_t next_last_level_gfn);
+		    int min_level, gfn_t next_last_level_gfn, gfn_t gfn_mask);
 void tdp_iter_next(struct tdp_iter *iter);
 void tdp_iter_restart(struct tdp_iter *iter);
 
diff --git a/arch/x86/kvm/mmu/tdp_mmu.c b/arch/x86/kvm/mmu/tdp_mmu.c
index 2770230a5636..ed93bba76483 100644
--- a/arch/x86/kvm/mmu/tdp_mmu.c
+++ b/arch/x86/kvm/mmu/tdp_mmu.c
@@ -674,18 +674,18 @@  static inline void tdp_mmu_iter_set_spte(struct kvm *kvm, struct tdp_iter *iter,
 					  iter->gfn, iter->level);
 }
 
-#define tdp_root_for_each_pte(_iter, _root, _start, _end) \
-	for_each_tdp_pte(_iter, _root, _start, _end)
+#define tdp_root_for_each_pte(_iter, _kvm, _root, _start, _end)	\
+	for_each_tdp_pte(_iter, _kvm, _root, _start, _end)
 
-#define tdp_root_for_each_leaf_pte(_iter, _root, _start, _end)	\
-	tdp_root_for_each_pte(_iter, _root, _start, _end)		\
+#define tdp_root_for_each_leaf_pte(_iter, _kvm, _root, _start, _end)	\
+	tdp_root_for_each_pte(_iter, _kvm, _root, _start, _end)		\
 		if (!is_shadow_present_pte(_iter.old_spte) ||		\
 		    !is_last_spte(_iter.old_spte, _iter.level))		\
 			continue;					\
 		else
 
-#define tdp_mmu_for_each_pte(_iter, _mmu, _start, _end)		\
-	for_each_tdp_pte(_iter, root_to_sp(_mmu->root.hpa), _start, _end)
+#define tdp_mmu_for_each_pte(_iter, _kvm, _root, _start, _end)	\
+	for_each_tdp_pte(_iter, _kvm, _root, _start, _end)
 
 /*
  * Yield if the MMU lock is contended or this thread needs to return control
@@ -751,7 +751,7 @@  static void __tdp_mmu_zap_root(struct kvm *kvm, struct kvm_mmu_page *root,
 	gfn_t end = tdp_mmu_max_gfn_exclusive();
 	gfn_t start = 0;
 
-	for_each_tdp_pte_min_level(iter, root, zap_level, start, end) {
+	for_each_tdp_pte_min_level(iter, kvm, root, zap_level, start, end) {
 retry:
 		if (tdp_mmu_iter_cond_resched(kvm, &iter, false, shared))
 			continue;
@@ -855,7 +855,7 @@  static bool tdp_mmu_zap_leafs(struct kvm *kvm, struct kvm_mmu_page *root,
 
 	rcu_read_lock();
 
-	for_each_tdp_pte_min_level(iter, root, PG_LEVEL_4K, start, end) {
+	for_each_tdp_pte_min_level(iter, kvm, root, PG_LEVEL_4K, start, end) {
 		if (can_yield &&
 		    tdp_mmu_iter_cond_resched(kvm, &iter, flush, false)) {
 			flush = false;
@@ -1104,8 +1104,8 @@  static int tdp_mmu_split_huge_page(struct kvm *kvm, struct tdp_iter *iter,
  */
 int kvm_tdp_mmu_map(struct kvm_vcpu *vcpu, struct kvm_page_fault *fault)
 {
-	struct kvm_mmu *mmu = vcpu->arch.mmu;
 	struct kvm *kvm = vcpu->kvm;
+	struct kvm_mmu_page *root = root_to_sp(vcpu->arch.mmu->root.hpa);
 	struct tdp_iter iter;
 	struct kvm_mmu_page *sp;
 	int ret = RET_PF_RETRY;
@@ -1115,8 +1115,7 @@  int kvm_tdp_mmu_map(struct kvm_vcpu *vcpu, struct kvm_page_fault *fault)
 	trace_kvm_mmu_spte_requested(fault);
 
 	rcu_read_lock();
-
-	tdp_mmu_for_each_pte(iter, mmu, fault->gfn, fault->gfn + 1) {
+	tdp_mmu_for_each_pte(iter, vcpu->kvm, root, fault->gfn, fault->gfn + 1) {
 		int r;
 
 		if (fault->nx_huge_page_workaround_enabled)
@@ -1214,7 +1213,7 @@  static __always_inline bool kvm_tdp_mmu_handle_gfn(struct kvm *kvm,
 	for_each_tdp_mmu_root(kvm, root, range->slot->as_id) {
 		rcu_read_lock();
 
-		tdp_root_for_each_leaf_pte(iter, root, range->start, range->end)
+		tdp_root_for_each_leaf_pte(iter, kvm, root, range->start, range->end)
 			ret |= handler(kvm, &iter, range);
 
 		rcu_read_unlock();
@@ -1297,7 +1296,7 @@  static bool wrprot_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
 
 	BUG_ON(min_level > KVM_MAX_HUGEPAGE_LEVEL);
 
-	for_each_tdp_pte_min_level(iter, root, min_level, start, end) {
+	for_each_tdp_pte_min_level(iter, kvm, root, min_level, start, end) {
 retry:
 		if (tdp_mmu_iter_cond_resched(kvm, &iter, false, true))
 			continue;
@@ -1460,7 +1459,7 @@  static int tdp_mmu_split_huge_pages_root(struct kvm *kvm,
 	 * level above the target level (e.g. splitting a 1GB to 512 2MB pages,
 	 * and then splitting each of those to 512 4KB pages).
 	 */
-	for_each_tdp_pte_min_level(iter, root, target_level + 1, start, end) {
+	for_each_tdp_pte_min_level(iter, kvm, root, target_level + 1, start, end) {
 retry:
 		if (tdp_mmu_iter_cond_resched(kvm, &iter, false, shared))
 			continue;
@@ -1545,7 +1544,7 @@  static bool clear_dirty_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
 
 	rcu_read_lock();
 
-	tdp_root_for_each_pte(iter, root, start, end) {
+	tdp_root_for_each_pte(iter, kvm, root, start, end) {
 retry:
 		if (!is_shadow_present_pte(iter.old_spte) ||
 		    !is_last_spte(iter.old_spte, iter.level))
@@ -1600,7 +1599,7 @@  static void clear_dirty_pt_masked(struct kvm *kvm, struct kvm_mmu_page *root,
 
 	rcu_read_lock();
 
-	tdp_root_for_each_leaf_pte(iter, root, gfn + __ffs(mask),
+	tdp_root_for_each_leaf_pte(iter, kvm, root, gfn + __ffs(mask),
 				    gfn + BITS_PER_LONG) {
 		if (!mask)
 			break;
@@ -1657,7 +1656,7 @@  static void zap_collapsible_spte_range(struct kvm *kvm,
 
 	rcu_read_lock();
 
-	for_each_tdp_pte_min_level(iter, root, PG_LEVEL_2M, start, end) {
+	for_each_tdp_pte_min_level(iter, kvm, root, PG_LEVEL_2M, start, end) {
 retry:
 		if (tdp_mmu_iter_cond_resched(kvm, &iter, false, true))
 			continue;
@@ -1727,7 +1726,7 @@  static bool write_protect_gfn(struct kvm *kvm, struct kvm_mmu_page *root,
 
 	rcu_read_lock();
 
-	for_each_tdp_pte_min_level(iter, root, min_level, gfn, gfn + 1) {
+	for_each_tdp_pte_min_level(iter, kvm, root, min_level, gfn, gfn + 1) {
 		if (!is_shadow_present_pte(iter.old_spte) ||
 		    !is_last_spte(iter.old_spte, iter.level))
 			continue;
@@ -1775,14 +1774,14 @@  bool kvm_tdp_mmu_write_protect_gfn(struct kvm *kvm,
 int kvm_tdp_mmu_get_walk(struct kvm_vcpu *vcpu, u64 addr, u64 *sptes,
 			 int *root_level)
 {
+	struct kvm_mmu_page *root = root_to_sp(vcpu->arch.mmu->root.hpa);
 	struct tdp_iter iter;
-	struct kvm_mmu *mmu = vcpu->arch.mmu;
 	gfn_t gfn = addr >> PAGE_SHIFT;
 	int leaf = -1;
 
 	*root_level = vcpu->arch.mmu->root_role.level;
 
-	tdp_mmu_for_each_pte(iter, mmu, gfn, gfn + 1) {
+	tdp_mmu_for_each_pte(iter, vcpu->kvm, root, gfn, gfn + 1) {
 		leaf = iter.level;
 		sptes[leaf] = iter.old_spte;
 	}
@@ -1804,12 +1803,12 @@  int kvm_tdp_mmu_get_walk(struct kvm_vcpu *vcpu, u64 addr, u64 *sptes,
 u64 *kvm_tdp_mmu_fast_pf_get_last_sptep(struct kvm_vcpu *vcpu, u64 addr,
 					u64 *spte)
 {
+	struct kvm_mmu_page *root = root_to_sp(vcpu->arch.mmu->root.hpa);
 	struct tdp_iter iter;
-	struct kvm_mmu *mmu = vcpu->arch.mmu;
 	gfn_t gfn = addr >> PAGE_SHIFT;
 	tdp_ptep_t sptep = NULL;
 
-	tdp_mmu_for_each_pte(iter, mmu, gfn, gfn + 1) {
+	tdp_mmu_for_each_pte(iter, vcpu->kvm, root, gfn, gfn + 1) {
 		*spte = iter.old_spte;
 		sptep = iter.sptep;
 	}
diff --git a/arch/x86/kvm/x86.c b/arch/x86/kvm/x86.c
index 7c593a081eba..0e6325b5f5e7 100644
--- a/arch/x86/kvm/x86.c
+++ b/arch/x86/kvm/x86.c
@@ -13987,6 +13987,16 @@  int kvm_sev_es_string_io(struct kvm_vcpu *vcpu, unsigned int size,
 }
 EXPORT_SYMBOL_GPL(kvm_sev_es_string_io);
 
+#ifdef __KVM_HAVE_ARCH_FLUSH_REMOTE_TLBS_RANGE
+int kvm_arch_flush_remote_tlbs_range(struct kvm *kvm, gfn_t gfn, u64 nr_pages)
+{
+	if (!kvm_x86_ops.flush_remote_tlbs_range || kvm_gfn_direct_mask(kvm))
+		return -EOPNOTSUPP;
+
+	return static_call(kvm_x86_flush_remote_tlbs_range)(kvm, gfn, nr_pages);
+}
+#endif
+
 EXPORT_TRACEPOINT_SYMBOL_GPL(kvm_entry);
 EXPORT_TRACEPOINT_SYMBOL_GPL(kvm_exit);
 EXPORT_TRACEPOINT_SYMBOL_GPL(kvm_fast_mmio);