[RFC,v2,06/26] KVM: arm64: Factor memory allocation out of pgtable.c

Message ID 20210108121524.656872-7-qperret@google.com (mailing list archive)
State New, archived
Series KVM/arm64: A stage 2 for the host

Commit Message

Quentin Perret Jan. 8, 2021, 12:15 p.m. UTC
In preparation for enabling the creation of page-tables at EL2, factor
all memory allocation out of the page-table code, hence making it
re-usable with any compatible memory allocator.

No functional changes intended.

Signed-off-by: Quentin Perret <qperret@google.com>
---
 arch/arm64/include/asm/kvm_pgtable.h | 32 +++++++++-
 arch/arm64/kvm/hyp/pgtable.c         | 90 +++++++++++++++++-----------
 arch/arm64/kvm/mmu.c                 | 70 +++++++++++++++++++++-
 3 files changed, 154 insertions(+), 38 deletions(-)

Comments

Will Deacon Feb. 1, 2021, 6:16 p.m. UTC | #1
On Fri, Jan 08, 2021 at 12:15:04PM +0000, Quentin Perret wrote:
> In preparation for enabling the creation of page-tables at EL2, factor
> all memory allocation out of the page-table code, hence making it
> re-usable with any compatible memory allocator.
> 
> No functional changes intended.
> 
> Signed-off-by: Quentin Perret <qperret@google.com>
> ---
>  arch/arm64/include/asm/kvm_pgtable.h | 32 +++++++++-
>  arch/arm64/kvm/hyp/pgtable.c         | 90 +++++++++++++++++-----------
>  arch/arm64/kvm/mmu.c                 | 70 +++++++++++++++++++++-
>  3 files changed, 154 insertions(+), 38 deletions(-)
> 
> diff --git a/arch/arm64/include/asm/kvm_pgtable.h b/arch/arm64/include/asm/kvm_pgtable.h
> index 52ab38db04c7..45acc9dc6c45 100644
> --- a/arch/arm64/include/asm/kvm_pgtable.h
> +++ b/arch/arm64/include/asm/kvm_pgtable.h
> @@ -13,17 +13,41 @@
>  
>  typedef u64 kvm_pte_t;
>  
> +/**
> + * struct kvm_pgtable_mm_ops - Memory management callbacks.
> + * @zalloc_page:	Allocate a zeroed memory page.

Please describe the 'arg' parameter.

> + * @zalloc_pages_exact:	Allocate an exact number of zeroed memory pages.

I think this comment could be expanded somewhat to make it clear (a) that
the 'size' parameter is in bytes rather than pages, (b) what rounding
behaviour is applied if 'size' is not page-aligned, and (c) that the
resulting allocation is physically contiguous.

> + * @free_pages_exact:	Free an exact number of memory pages.
> + * @get_page:		Increment the refcount on a page.
> + * @put_page:		Decrement the refcount on a page.
> + * @page_count:		Returns the refcount of a page.
> + * @phys_to_virt:	Convert a physical address into a virtual address.
> + * @virt_to_phys:	Convert a virtual address into a physical address.

I think it would be good to be explicit about the nature of the virtual
address here. We're dealing with virtual addresses that are mapped in the
current context rather than e.g. guest virtual addresses.
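
For example, the kerneldoc could be expanded along these lines (only a
sketch of one possible wording, covering the 'arg', 'size' and
virtual-address points above):

/**
 * struct kvm_pgtable_mm_ops - Memory management callbacks.
 * @zalloc_page:	Allocate a single zeroed memory page. The @arg
 *			parameter is an opaque, caller-provided allocator
 *			cookie (e.g. a memcache for stage-2 tables); it may
 *			be NULL.
 * @zalloc_pages_exact:	Allocate an exact number of zeroed, physically
 *			contiguous pages. @size is in bytes and is rounded
 *			up to the next page boundary.
 * @free_pages_exact:	Free an exact number of pages previously allocated
 *			with @zalloc_pages_exact.
 * @get_page:		Increment the refcount on a page.
 * @put_page:		Decrement the refcount on a page.
 * @page_count:		Return the refcount of a page.
 * @phys_to_virt:	Convert a physical address into a virtual address
 *			mapped in the current context.
 * @virt_to_phys:	Convert a virtual address mapped in the current
 *			context into a physical address.
 */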

> + */
> +struct kvm_pgtable_mm_ops {
> +	void*		(*zalloc_page)(void *arg);
> +	void*		(*zalloc_pages_exact)(size_t size);
> +	void		(*free_pages_exact)(void *addr, size_t size);
> +	void		(*get_page)(void *addr);
> +	void		(*put_page)(void *addr);
> +	int		(*page_count)(void *addr);
> +	void*		(*phys_to_virt)(phys_addr_t phys);
> +	phys_addr_t	(*virt_to_phys)(void *addr);
> +};

[...]

> diff --git a/arch/arm64/kvm/mmu.c b/arch/arm64/kvm/mmu.c
> index 1f41173e6149..278e163beda4 100644
> --- a/arch/arm64/kvm/mmu.c
> +++ b/arch/arm64/kvm/mmu.c
> @@ -88,6 +88,48 @@ static bool kvm_is_device_pfn(unsigned long pfn)
>  	return !pfn_valid(pfn);
>  }
>  
> +static void *stage2_memcache_alloc_page(void *arg)
> +{
> +	struct kvm_mmu_memory_cache *mc = arg;
> +	kvm_pte_t *ptep = NULL;
> +
> +	/* Allocated with GFP_KERNEL_ACCOUNT, so no need to zero */

I couldn't spot where GFP_KERNEL_ACCOUNT implies __GFP_ZERO. Please can you
elaborate?

> +	if (mc && mc->nobjs)
> +		ptep = mc->objects[--mc->nobjs];
> +
> +	return ptep;
> +}

Why can't we use kvm_mmu_memory_cache_alloc() directly instead of opening up
the memory_cache?

> +static void *kvm_host_zalloc_pages_exact(size_t size)
> +{
> +	return alloc_pages_exact(size, GFP_KERNEL_ACCOUNT | __GFP_ZERO);

Hmm, so now we're passing __GFP_ZERO? ;)

> +static void kvm_host_get_page(void *addr)
> +{
> +	get_page(virt_to_page(addr));
> +}
> +
> +static void kvm_host_put_page(void *addr)
> +{
> +	put_page(virt_to_page(addr));
> +}
> +
> +static int kvm_host_page_count(void *addr)
> +{
> +	return page_count(virt_to_page(addr));
> +}
> +
> +static phys_addr_t kvm_host_pa(void *addr)
> +{
> +	return __pa(addr);
> +}
> +
> +static void *kvm_host_va(phys_addr_t phys)
> +{
> +	return __va(phys);
> +}
> +
>  /*
>   * Unmapping vs dcache management:
>   *
> @@ -351,6 +393,17 @@ int create_hyp_exec_mappings(phys_addr_t phys_addr, size_t size,
>  	return 0;
>  }
>  
> +static struct kvm_pgtable_mm_ops kvm_s2_mm_ops = {
> +	.zalloc_page		= stage2_memcache_alloc_page,
> +	.zalloc_pages_exact	= kvm_host_zalloc_pages_exact,
> +	.free_pages_exact	= free_pages_exact,
> +	.get_page		= kvm_host_get_page,
> +	.put_page		= kvm_host_put_page,
> +	.page_count		= kvm_host_page_count,
> +	.phys_to_virt		= kvm_host_va,
> +	.virt_to_phys		= kvm_host_pa,
> +};

Idle thought, but I wonder whether it would be better to have these
implementations as the default and make the mm_ops structure parameter
to kvm_pgtable_stage2_init() optional? I guess you don't gain an awful
lot though, so feel free to ignore me.

Will
Quentin Perret Feb. 1, 2021, 6:32 p.m. UTC | #2
On Monday 01 Feb 2021 at 18:16:08 (+0000), Will Deacon wrote:
> On Fri, Jan 08, 2021 at 12:15:04PM +0000, Quentin Perret wrote:
> > In preparation for enabling the creation of page-tables at EL2, factor
> > all memory allocation out of the page-table code, hence making it
> > re-usable with any compatible memory allocator.
> > 
> > No functional changes intended.
> > 
> > Signed-off-by: Quentin Perret <qperret@google.com>
> > ---
> >  arch/arm64/include/asm/kvm_pgtable.h | 32 +++++++++-
> >  arch/arm64/kvm/hyp/pgtable.c         | 90 +++++++++++++++++-----------
> >  arch/arm64/kvm/mmu.c                 | 70 +++++++++++++++++++++-
> >  3 files changed, 154 insertions(+), 38 deletions(-)
> > 
> > diff --git a/arch/arm64/include/asm/kvm_pgtable.h b/arch/arm64/include/asm/kvm_pgtable.h
> > index 52ab38db04c7..45acc9dc6c45 100644
> > --- a/arch/arm64/include/asm/kvm_pgtable.h
> > +++ b/arch/arm64/include/asm/kvm_pgtable.h
> > @@ -13,17 +13,41 @@
> >  
> >  typedef u64 kvm_pte_t;
> >  
> > +/**
> > + * struct kvm_pgtable_mm_ops - Memory management callbacks.
> > + * @zalloc_page:	Allocate a zeroed memory page.
> 
> Please describe the 'arg' parameter.
> 
> > + * @zalloc_pages_exact:	Allocate an exact number of zeroed memory pages.
> 
> I think this comment could be expanded somewhat to make it clear (a) that
> the 'size' parameter is in bytes rather than pages, (b) what rounding
> behaviour is applied if 'size' is not page-aligned, and (c) that the
> resulting allocation is physically contiguous.
> 
> > + * @free_pages_exact:	Free an exact number of memory pages.
> > + * @get_page:		Increment the refcount on a page.
> > + * @put_page:		Decrement the refcount on a page.
> > + * @page_count:		Returns the refcount of a page.
> > + * @phys_to_virt:	Convert a physical address into a virtual address.
> > + * @virt_to_phys:	Convert a virtual address into a physical address.
> 
> I think it would be good to be explicit about the nature of the virtual
> address here. We're dealing with virtual addresses that are mapped in the
> current context rather than e.g. guest virtual addresses.

Ack to all the above.

> > + */
> > +struct kvm_pgtable_mm_ops {
> > +	void*		(*zalloc_page)(void *arg);
> > +	void*		(*zalloc_pages_exact)(size_t size);
> > +	void		(*free_pages_exact)(void *addr, size_t size);
> > +	void		(*get_page)(void *addr);
> > +	void		(*put_page)(void *addr);
> > +	int		(*page_count)(void *addr);
> > +	void*		(*phys_to_virt)(phys_addr_t phys);
> > +	phys_addr_t	(*virt_to_phys)(void *addr);
> > +};
> 
> [...]
> 
> > diff --git a/arch/arm64/kvm/mmu.c b/arch/arm64/kvm/mmu.c
> > index 1f41173e6149..278e163beda4 100644
> > --- a/arch/arm64/kvm/mmu.c
> > +++ b/arch/arm64/kvm/mmu.c
> > @@ -88,6 +88,48 @@ static bool kvm_is_device_pfn(unsigned long pfn)
> >  	return !pfn_valid(pfn);
> >  }
> >  
> > +static void *stage2_memcache_alloc_page(void *arg)
> > +{
> > +	struct kvm_mmu_memory_cache *mc = arg;
> > +	kvm_pte_t *ptep = NULL;
> > +
> > +	/* Allocated with GFP_KERNEL_ACCOUNT, so no need to zero */
> 
> I couldn't spot where GFP_KERNEL_ACCOUNT implies __GFP_ZERO.

I'm not surprised, it doesn't. Clearly a broken comment; I'll fix it with
s/GFP_KERNEL_ACCOUNT/__GFP_ZERO

> Please can you elaborate?
> 
> > +	if (mc && mc->nobjs)
> > +		ptep = mc->objects[--mc->nobjs];
> > +
> > +	return ptep;
> > +}
> 
> Why can't we use kvm_mmu_memory_cache_alloc() directly instead of opening up
> the memory_cache?

I think we can -- that function didn't exist when I first wrote this,
but no good reason not to use it now.
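
Something like the below, I suppose (an untested sketch; the NULL check
keeps the current behaviour of returning NULL when no memcache was
supplied):

static void *stage2_memcache_alloc_page(void *arg)
{
	struct kvm_mmu_memory_cache *mc = arg;

	/* Allocated with __GFP_ZERO, so no need to zero */
	return mc ? kvm_mmu_memory_cache_alloc(mc) : NULL;
}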

> > +static void *kvm_host_zalloc_pages_exact(size_t size)
> > +{
> > +	return alloc_pages_exact(size, GFP_KERNEL_ACCOUNT | __GFP_ZERO);
> 
> Hmm, so now we're passing __GFP_ZERO? ;)

:-)

> > +static void kvm_host_get_page(void *addr)
> > +{
> > +	get_page(virt_to_page(addr));
> > +}
> > +
> > +static void kvm_host_put_page(void *addr)
> > +{
> > +	put_page(virt_to_page(addr));
> > +}
> > +
> > +static int kvm_host_page_count(void *addr)
> > +{
> > +	return page_count(virt_to_page(addr));
> > +}
> > +
> > +static phys_addr_t kvm_host_pa(void *addr)
> > +{
> > +	return __pa(addr);
> > +}
> > +
> > +static void *kvm_host_va(phys_addr_t phys)
> > +{
> > +	return __va(phys);
> > +}
> > +
> >  /*
> >   * Unmapping vs dcache management:
> >   *
> > @@ -351,6 +393,17 @@ int create_hyp_exec_mappings(phys_addr_t phys_addr, size_t size,
> >  	return 0;
> >  }
> >  
> > +static struct kvm_pgtable_mm_ops kvm_s2_mm_ops = {
> > +	.zalloc_page		= stage2_memcache_alloc_page,
> > +	.zalloc_pages_exact	= kvm_host_zalloc_pages_exact,
> > +	.free_pages_exact	= free_pages_exact,
> > +	.get_page		= kvm_host_get_page,
> > +	.put_page		= kvm_host_put_page,
> > +	.page_count		= kvm_host_page_count,
> > +	.phys_to_virt		= kvm_host_va,
> > +	.virt_to_phys		= kvm_host_pa,
> > +};
> 
> Idle thought, but I wonder whether it would be better to have these
> implementations as the default and make the mm_ops structure parameter
> to kvm_pgtable_stage2_init() optional? I guess you don't gain an awful
> lot though, so feel free to ignore me.

No strong opinion really, but I suppose I could do something as simple
as having static inline wrappers which provide kvm_s2_mm_ops to the
pgtable API for me. I'll probably want to make sure these are not
defined when compiling EL2 code, though, to avoid confusion.
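
Something like the below is what I have in mind (untested sketch -- the
_host suffix is made up for illustration, and kvm_s2_mm_ops would need to
be visible wherever the wrapper lives):

/* Host-only wrapper supplying the default mm_ops; not built for EL2. */
static inline int kvm_pgtable_stage2_init_host(struct kvm_pgtable *pgt,
					       struct kvm *kvm)
{
	return kvm_pgtable_stage2_init(pgt, kvm, &kvm_s2_mm_ops);
}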

Or maybe you had something else in mind?

Cheers,
Quentin
Will Deacon Feb. 1, 2021, 6:39 p.m. UTC | #3
On Mon, Feb 01, 2021 at 06:32:52PM +0000, Quentin Perret wrote:
> On Monday 01 Feb 2021 at 18:16:08 (+0000), Will Deacon wrote:
> > On Fri, Jan 08, 2021 at 12:15:04PM +0000, Quentin Perret wrote:
> > > +static struct kvm_pgtable_mm_ops kvm_s2_mm_ops = {
> > > +	.zalloc_page		= stage2_memcache_alloc_page,
> > > +	.zalloc_pages_exact	= kvm_host_zalloc_pages_exact,
> > > +	.free_pages_exact	= free_pages_exact,
> > > +	.get_page		= kvm_host_get_page,
> > > +	.put_page		= kvm_host_put_page,
> > > +	.page_count		= kvm_host_page_count,
> > > +	.phys_to_virt		= kvm_host_va,
> > > +	.virt_to_phys		= kvm_host_pa,
> > > +};
> > 
> > Idle thought, but I wonder whether it would be better to have these
> > implementations as the default and make the mm_ops structure parameter
> > to kvm_pgtable_stage2_init() optional? I guess you don't gain an awful
> > lot though, so feel free to ignore me.
> 
> No strong opinion really, but I suppose I could do something as simple
> as having static inline wrappers which provide kvm_s2_mm_ops to the
> pgtable API for me. I'll probably want to make sure these are not
> defined when compiling EL2 code, though, to avoid confusion.
> 
> Or maybe you had something else in mind?

No, just food for thought. If we can reduce the changes for normal KVM then
it's probably worth considering if it doesn't add divergent code paths. But
I'm also fine with the proposal you have here, so if it doesn't work then
don't get hung up on it.

Will

Patch

diff --git a/arch/arm64/include/asm/kvm_pgtable.h b/arch/arm64/include/asm/kvm_pgtable.h
index 52ab38db04c7..45acc9dc6c45 100644
--- a/arch/arm64/include/asm/kvm_pgtable.h
+++ b/arch/arm64/include/asm/kvm_pgtable.h
@@ -13,17 +13,41 @@ 
 
 typedef u64 kvm_pte_t;
 
+/**
+ * struct kvm_pgtable_mm_ops - Memory management callbacks.
+ * @zalloc_page:	Allocate a zeroed memory page.
+ * @zalloc_pages_exact:	Allocate an exact number of zeroed memory pages.
+ * @free_pages_exact:	Free an exact number of memory pages.
+ * @get_page:		Increment the refcount on a page.
+ * @put_page:		Decrement the refcount on a page.
+ * @page_count:		Returns the refcount of a page.
+ * @phys_to_virt:	Convert a physical address into a virtual address.
+ * @virt_to_phys:	Convert a virtual address into a physical address.
+ */
+struct kvm_pgtable_mm_ops {
+	void*		(*zalloc_page)(void *arg);
+	void*		(*zalloc_pages_exact)(size_t size);
+	void		(*free_pages_exact)(void *addr, size_t size);
+	void		(*get_page)(void *addr);
+	void		(*put_page)(void *addr);
+	int		(*page_count)(void *addr);
+	void*		(*phys_to_virt)(phys_addr_t phys);
+	phys_addr_t	(*virt_to_phys)(void *addr);
+};
+
 /**
  * struct kvm_pgtable - KVM page-table.
  * @ia_bits:		Maximum input address size, in bits.
  * @start_level:	Level at which the page-table walk starts.
  * @pgd:		Pointer to the first top-level entry of the page-table.
+ * @mm_ops:		Memory management callbacks.
  * @mmu:		Stage-2 KVM MMU struct. Unused for stage-1 page-tables.
  */
 struct kvm_pgtable {
 	u32					ia_bits;
 	u32					start_level;
 	kvm_pte_t				*pgd;
+	struct kvm_pgtable_mm_ops		*mm_ops;
 
 	/* Stage-2 only */
 	struct kvm_s2_mmu			*mmu;
@@ -86,10 +110,12 @@  struct kvm_pgtable_walker {
  * kvm_pgtable_hyp_init() - Initialise a hypervisor stage-1 page-table.
  * @pgt:	Uninitialised page-table structure to initialise.
  * @va_bits:	Maximum virtual address bits.
+ * @mm_ops:	Memory management callbacks.
  *
  * Return: 0 on success, negative error code on failure.
  */
-int kvm_pgtable_hyp_init(struct kvm_pgtable *pgt, u32 va_bits);
+int kvm_pgtable_hyp_init(struct kvm_pgtable *pgt, u32 va_bits,
+			 struct kvm_pgtable_mm_ops *mm_ops);
 
 /**
  * kvm_pgtable_hyp_destroy() - Destroy an unused hypervisor stage-1 page-table.
@@ -126,10 +152,12 @@  int kvm_pgtable_hyp_map(struct kvm_pgtable *pgt, u64 addr, u64 size, u64 phys,
  * kvm_pgtable_stage2_init() - Initialise a guest stage-2 page-table.
  * @pgt:	Uninitialised page-table structure to initialise.
  * @kvm:	KVM structure representing the guest virtual machine.
+ * @mm_ops:	Memory management callbacks.
  *
  * Return: 0 on success, negative error code on failure.
  */
-int kvm_pgtable_stage2_init(struct kvm_pgtable *pgt, struct kvm *kvm);
+int kvm_pgtable_stage2_init(struct kvm_pgtable *pgt, struct kvm *kvm,
+			    struct kvm_pgtable_mm_ops *mm_ops);
 
 /**
  * kvm_pgtable_stage2_destroy() - Destroy an unused guest stage-2 page-table.
diff --git a/arch/arm64/kvm/hyp/pgtable.c b/arch/arm64/kvm/hyp/pgtable.c
index d7122c5eac24..61a8a34ddfdb 100644
--- a/arch/arm64/kvm/hyp/pgtable.c
+++ b/arch/arm64/kvm/hyp/pgtable.c
@@ -148,9 +148,9 @@  static kvm_pte_t kvm_phys_to_pte(u64 pa)
 	return pte;
 }
 
-static kvm_pte_t *kvm_pte_follow(kvm_pte_t pte)
+static kvm_pte_t *kvm_pte_follow(kvm_pte_t pte, struct kvm_pgtable_mm_ops *mm_ops)
 {
-	return __va(kvm_pte_to_phys(pte));
+	return mm_ops->phys_to_virt(kvm_pte_to_phys(pte));
 }
 
 static void kvm_set_invalid_pte(kvm_pte_t *ptep)
@@ -159,9 +159,10 @@  static void kvm_set_invalid_pte(kvm_pte_t *ptep)
 	WRITE_ONCE(*ptep, pte & ~KVM_PTE_VALID);
 }
 
-static void kvm_set_table_pte(kvm_pte_t *ptep, kvm_pte_t *childp)
+static void kvm_set_table_pte(kvm_pte_t *ptep, kvm_pte_t *childp,
+			      struct kvm_pgtable_mm_ops *mm_ops)
 {
-	kvm_pte_t old = *ptep, pte = kvm_phys_to_pte(__pa(childp));
+	kvm_pte_t old = *ptep, pte = kvm_phys_to_pte(mm_ops->virt_to_phys(childp));
 
 	pte |= FIELD_PREP(KVM_PTE_TYPE, KVM_PTE_TYPE_TABLE);
 	pte |= KVM_PTE_VALID;
@@ -229,7 +230,7 @@  static inline int __kvm_pgtable_visit(struct kvm_pgtable_walk_data *data,
 		goto out;
 	}
 
-	childp = kvm_pte_follow(pte);
+	childp = kvm_pte_follow(pte, data->pgt->mm_ops);
 	ret = __kvm_pgtable_walk(data, childp, level + 1);
 	if (ret)
 		goto out;
@@ -304,8 +305,9 @@  int kvm_pgtable_walk(struct kvm_pgtable *pgt, u64 addr, u64 size,
 }
 
 struct hyp_map_data {
-	u64		phys;
-	kvm_pte_t	attr;
+	u64				phys;
+	kvm_pte_t			attr;
+	struct kvm_pgtable_mm_ops	*mm_ops;
 };
 
 static int hyp_map_set_prot_attr(enum kvm_pgtable_prot prot,
@@ -355,6 +357,8 @@  static int hyp_map_walker(u64 addr, u64 end, u32 level, kvm_pte_t *ptep,
 			  enum kvm_pgtable_walk_flags flag, void * const arg)
 {
 	kvm_pte_t *childp;
+	struct hyp_map_data *data = arg;
+	struct kvm_pgtable_mm_ops *mm_ops = data->mm_ops;
 
 	if (hyp_map_walker_try_leaf(addr, end, level, ptep, arg))
 		return 0;
@@ -362,11 +366,11 @@  static int hyp_map_walker(u64 addr, u64 end, u32 level, kvm_pte_t *ptep,
 	if (WARN_ON(level == KVM_PGTABLE_MAX_LEVELS - 1))
 		return -EINVAL;
 
-	childp = (kvm_pte_t *)get_zeroed_page(GFP_KERNEL);
+	childp = (kvm_pte_t *)mm_ops->zalloc_page(NULL);
 	if (!childp)
 		return -ENOMEM;
 
-	kvm_set_table_pte(ptep, childp);
+	kvm_set_table_pte(ptep, childp, mm_ops);
 	return 0;
 }
 
@@ -376,6 +380,7 @@  int kvm_pgtable_hyp_map(struct kvm_pgtable *pgt, u64 addr, u64 size, u64 phys,
 	int ret;
 	struct hyp_map_data map_data = {
 		.phys	= ALIGN_DOWN(phys, PAGE_SIZE),
+		.mm_ops	= pgt->mm_ops,
 	};
 	struct kvm_pgtable_walker walker = {
 		.cb	= hyp_map_walker,
@@ -393,16 +398,18 @@  int kvm_pgtable_hyp_map(struct kvm_pgtable *pgt, u64 addr, u64 size, u64 phys,
 	return ret;
 }
 
-int kvm_pgtable_hyp_init(struct kvm_pgtable *pgt, u32 va_bits)
+int kvm_pgtable_hyp_init(struct kvm_pgtable *pgt, u32 va_bits,
+			 struct kvm_pgtable_mm_ops *mm_ops)
 {
 	u64 levels = ARM64_HW_PGTABLE_LEVELS(va_bits);
 
-	pgt->pgd = (kvm_pte_t *)get_zeroed_page(GFP_KERNEL);
+	pgt->pgd = (kvm_pte_t *)mm_ops->zalloc_page(NULL);
 	if (!pgt->pgd)
 		return -ENOMEM;
 
 	pgt->ia_bits		= va_bits;
 	pgt->start_level	= KVM_PGTABLE_MAX_LEVELS - levels;
+	pgt->mm_ops		= mm_ops;
 	pgt->mmu		= NULL;
 	return 0;
 }
@@ -410,7 +417,9 @@  int kvm_pgtable_hyp_init(struct kvm_pgtable *pgt, u32 va_bits)
 static int hyp_free_walker(u64 addr, u64 end, u32 level, kvm_pte_t *ptep,
 			   enum kvm_pgtable_walk_flags flag, void * const arg)
 {
-	put_page(virt_to_page(kvm_pte_follow(*ptep)));
+	struct kvm_pgtable_mm_ops *mm_ops = arg;
+
+	mm_ops->put_page((void *)kvm_pte_follow(*ptep, mm_ops));
 	return 0;
 }
 
@@ -419,10 +428,11 @@  void kvm_pgtable_hyp_destroy(struct kvm_pgtable *pgt)
 	struct kvm_pgtable_walker walker = {
 		.cb	= hyp_free_walker,
 		.flags	= KVM_PGTABLE_WALK_TABLE_POST,
+		.arg	= pgt->mm_ops,
 	};
 
 	WARN_ON(kvm_pgtable_walk(pgt, 0, BIT(pgt->ia_bits), &walker));
-	put_page(virt_to_page(pgt->pgd));
+	pgt->mm_ops->put_page(pgt->pgd);
 	pgt->pgd = NULL;
 }
 
@@ -434,6 +444,8 @@  struct stage2_map_data {
 
 	struct kvm_s2_mmu		*mmu;
 	struct kvm_mmu_memory_cache	*memcache;
+
+	struct kvm_pgtable_mm_ops	*mm_ops;
 };
 
 static int stage2_map_set_prot_attr(enum kvm_pgtable_prot prot,
@@ -501,12 +513,12 @@  static int stage2_map_walk_table_pre(u64 addr, u64 end, u32 level,
 static int stage2_map_walk_leaf(u64 addr, u64 end, u32 level, kvm_pte_t *ptep,
 				struct stage2_map_data *data)
 {
+	struct kvm_pgtable_mm_ops *mm_ops = data->mm_ops;
 	kvm_pte_t *childp, pte = *ptep;
-	struct page *page = virt_to_page(ptep);
 
 	if (data->anchor) {
 		if (kvm_pte_valid(pte))
-			put_page(page);
+			mm_ops->put_page(ptep);
 
 		return 0;
 	}
@@ -520,7 +532,7 @@  static int stage2_map_walk_leaf(u64 addr, u64 end, u32 level, kvm_pte_t *ptep,
 	if (!data->memcache)
 		return -ENOMEM;
 
-	childp = kvm_mmu_memory_cache_alloc(data->memcache);
+	childp = mm_ops->zalloc_page(data->memcache);
 	if (!childp)
 		return -ENOMEM;
 
@@ -532,13 +544,13 @@  static int stage2_map_walk_leaf(u64 addr, u64 end, u32 level, kvm_pte_t *ptep,
 	if (kvm_pte_valid(pte)) {
 		kvm_set_invalid_pte(ptep);
 		kvm_call_hyp(__kvm_tlb_flush_vmid_ipa, data->mmu, addr, level);
-		put_page(page);
+		mm_ops->put_page(ptep);
 	}
 
-	kvm_set_table_pte(ptep, childp);
+	kvm_set_table_pte(ptep, childp, mm_ops);
 
 out_get_page:
-	get_page(page);
+	mm_ops->get_page(ptep);
 	return 0;
 }
 
@@ -546,13 +558,14 @@  static int stage2_map_walk_table_post(u64 addr, u64 end, u32 level,
 				      kvm_pte_t *ptep,
 				      struct stage2_map_data *data)
 {
+	struct kvm_pgtable_mm_ops *mm_ops = data->mm_ops;
 	int ret = 0;
 
 	if (!data->anchor)
 		return 0;
 
-	put_page(virt_to_page(kvm_pte_follow(*ptep)));
-	put_page(virt_to_page(ptep));
+	mm_ops->put_page(kvm_pte_follow(*ptep, mm_ops));
+	mm_ops->put_page(ptep);
 
 	if (data->anchor == ptep) {
 		data->anchor = NULL;
@@ -607,6 +620,7 @@  int kvm_pgtable_stage2_map(struct kvm_pgtable *pgt, u64 addr, u64 size,
 		.phys		= ALIGN_DOWN(phys, PAGE_SIZE),
 		.mmu		= pgt->mmu,
 		.memcache	= mc,
+		.mm_ops		= pgt->mm_ops,
 	};
 	struct kvm_pgtable_walker walker = {
 		.cb		= stage2_map_walker,
@@ -643,7 +657,9 @@  static int stage2_unmap_walker(u64 addr, u64 end, u32 level, kvm_pte_t *ptep,
 			       enum kvm_pgtable_walk_flags flag,
 			       void * const arg)
 {
-	struct kvm_s2_mmu *mmu = arg;
+	struct kvm_pgtable *pgt = arg;
+	struct kvm_s2_mmu *mmu = pgt->mmu;
+	struct kvm_pgtable_mm_ops *mm_ops = pgt->mm_ops;
 	kvm_pte_t pte = *ptep, *childp = NULL;
 	bool need_flush = false;
 
@@ -651,9 +667,9 @@  static int stage2_unmap_walker(u64 addr, u64 end, u32 level, kvm_pte_t *ptep,
 		return 0;
 
 	if (kvm_pte_table(pte, level)) {
-		childp = kvm_pte_follow(pte);
+		childp = kvm_pte_follow(pte, mm_ops);
 
-		if (page_count(virt_to_page(childp)) != 1)
+		if (mm_ops->page_count(childp) != 1)
 			return 0;
 	} else if (stage2_pte_cacheable(pte)) {
 		need_flush = true;
@@ -666,15 +682,15 @@  static int stage2_unmap_walker(u64 addr, u64 end, u32 level, kvm_pte_t *ptep,
 	 */
 	kvm_set_invalid_pte(ptep);
 	kvm_call_hyp(__kvm_tlb_flush_vmid_ipa, mmu, addr, level);
-	put_page(virt_to_page(ptep));
+	mm_ops->put_page(ptep);
 
 	if (need_flush) {
-		stage2_flush_dcache(kvm_pte_follow(pte),
+		stage2_flush_dcache(kvm_pte_follow(pte, mm_ops),
 				    kvm_granule_size(level));
 	}
 
 	if (childp)
-		put_page(virt_to_page(childp));
+		mm_ops->put_page(childp);
 
 	return 0;
 }
@@ -683,7 +699,7 @@  int kvm_pgtable_stage2_unmap(struct kvm_pgtable *pgt, u64 addr, u64 size)
 {
 	struct kvm_pgtable_walker walker = {
 		.cb	= stage2_unmap_walker,
-		.arg	= pgt->mmu,
+		.arg	= pgt,
 		.flags	= KVM_PGTABLE_WALK_LEAF | KVM_PGTABLE_WALK_TABLE_POST,
 	};
 
@@ -815,12 +831,13 @@  static int stage2_flush_walker(u64 addr, u64 end, u32 level, kvm_pte_t *ptep,
 			       enum kvm_pgtable_walk_flags flag,
 			       void * const arg)
 {
+	struct kvm_pgtable_mm_ops *mm_ops = arg;
 	kvm_pte_t pte = *ptep;
 
 	if (!kvm_pte_valid(pte) || !stage2_pte_cacheable(pte))
 		return 0;
 
-	stage2_flush_dcache(kvm_pte_follow(pte), kvm_granule_size(level));
+	stage2_flush_dcache(kvm_pte_follow(pte, mm_ops), kvm_granule_size(level));
 	return 0;
 }
 
@@ -829,6 +846,7 @@  int kvm_pgtable_stage2_flush(struct kvm_pgtable *pgt, u64 addr, u64 size)
 	struct kvm_pgtable_walker walker = {
 		.cb	= stage2_flush_walker,
 		.flags	= KVM_PGTABLE_WALK_LEAF,
+		.arg	= pgt->mm_ops,
 	};
 
 	if (cpus_have_const_cap(ARM64_HAS_STAGE2_FWB))
@@ -837,7 +855,8 @@  int kvm_pgtable_stage2_flush(struct kvm_pgtable *pgt, u64 addr, u64 size)
 	return kvm_pgtable_walk(pgt, addr, size, &walker);
 }
 
-int kvm_pgtable_stage2_init(struct kvm_pgtable *pgt, struct kvm *kvm)
+int kvm_pgtable_stage2_init(struct kvm_pgtable *pgt, struct kvm *kvm,
+			    struct kvm_pgtable_mm_ops *mm_ops)
 {
 	size_t pgd_sz;
 	u64 vtcr = kvm->arch.vtcr;
@@ -846,12 +865,13 @@  int kvm_pgtable_stage2_init(struct kvm_pgtable *pgt, struct kvm *kvm)
 	u32 start_level = VTCR_EL2_TGRAN_SL0_BASE - sl0;
 
 	pgd_sz = kvm_pgd_pages(ia_bits, start_level) * PAGE_SIZE;
-	pgt->pgd = alloc_pages_exact(pgd_sz, GFP_KERNEL_ACCOUNT | __GFP_ZERO);
+	pgt->pgd = mm_ops->zalloc_pages_exact(pgd_sz);
 	if (!pgt->pgd)
 		return -ENOMEM;
 
 	pgt->ia_bits		= ia_bits;
 	pgt->start_level	= start_level;
+	pgt->mm_ops		= mm_ops;
 	pgt->mmu		= &kvm->arch.mmu;
 
 	/* Ensure zeroed PGD pages are visible to the hardware walker */
@@ -863,15 +883,16 @@  static int stage2_free_walker(u64 addr, u64 end, u32 level, kvm_pte_t *ptep,
 			      enum kvm_pgtable_walk_flags flag,
 			      void * const arg)
 {
+	struct kvm_pgtable_mm_ops *mm_ops = arg;
 	kvm_pte_t pte = *ptep;
 
 	if (!kvm_pte_valid(pte))
 		return 0;
 
-	put_page(virt_to_page(ptep));
+	mm_ops->put_page(ptep);
 
 	if (kvm_pte_table(pte, level))
-		put_page(virt_to_page(kvm_pte_follow(pte)));
+		mm_ops->put_page(kvm_pte_follow(pte, mm_ops));
 
 	return 0;
 }
@@ -883,10 +904,11 @@  void kvm_pgtable_stage2_destroy(struct kvm_pgtable *pgt)
 		.cb	= stage2_free_walker,
 		.flags	= KVM_PGTABLE_WALK_LEAF |
 			  KVM_PGTABLE_WALK_TABLE_POST,
+		.arg	= pgt->mm_ops,
 	};
 
 	WARN_ON(kvm_pgtable_walk(pgt, 0, BIT(pgt->ia_bits), &walker));
 	pgd_sz = kvm_pgd_pages(pgt->ia_bits, pgt->start_level) * PAGE_SIZE;
-	free_pages_exact(pgt->pgd, pgd_sz);
+	pgt->mm_ops->free_pages_exact(pgt->pgd, pgd_sz);
 	pgt->pgd = NULL;
 }
diff --git a/arch/arm64/kvm/mmu.c b/arch/arm64/kvm/mmu.c
index 1f41173e6149..278e163beda4 100644
--- a/arch/arm64/kvm/mmu.c
+++ b/arch/arm64/kvm/mmu.c
@@ -88,6 +88,48 @@  static bool kvm_is_device_pfn(unsigned long pfn)
 	return !pfn_valid(pfn);
 }
 
+static void *stage2_memcache_alloc_page(void *arg)
+{
+	struct kvm_mmu_memory_cache *mc = arg;
+	kvm_pte_t *ptep = NULL;
+
+	/* Allocated with GFP_KERNEL_ACCOUNT, so no need to zero */
+	if (mc && mc->nobjs)
+		ptep = mc->objects[--mc->nobjs];
+
+	return ptep;
+}
+
+static void *kvm_host_zalloc_pages_exact(size_t size)
+{
+	return alloc_pages_exact(size, GFP_KERNEL_ACCOUNT | __GFP_ZERO);
+}
+
+static void kvm_host_get_page(void *addr)
+{
+	get_page(virt_to_page(addr));
+}
+
+static void kvm_host_put_page(void *addr)
+{
+	put_page(virt_to_page(addr));
+}
+
+static int kvm_host_page_count(void *addr)
+{
+	return page_count(virt_to_page(addr));
+}
+
+static phys_addr_t kvm_host_pa(void *addr)
+{
+	return __pa(addr);
+}
+
+static void *kvm_host_va(phys_addr_t phys)
+{
+	return __va(phys);
+}
+
 /*
  * Unmapping vs dcache management:
  *
@@ -351,6 +393,17 @@  int create_hyp_exec_mappings(phys_addr_t phys_addr, size_t size,
 	return 0;
 }
 
+static struct kvm_pgtable_mm_ops kvm_s2_mm_ops = {
+	.zalloc_page		= stage2_memcache_alloc_page,
+	.zalloc_pages_exact	= kvm_host_zalloc_pages_exact,
+	.free_pages_exact	= free_pages_exact,
+	.get_page		= kvm_host_get_page,
+	.put_page		= kvm_host_put_page,
+	.page_count		= kvm_host_page_count,
+	.phys_to_virt		= kvm_host_va,
+	.virt_to_phys		= kvm_host_pa,
+};
+
 /**
  * kvm_init_stage2_mmu - Initialise a S2 MMU strucrure
  * @kvm:	The pointer to the KVM structure
@@ -374,7 +427,7 @@  int kvm_init_stage2_mmu(struct kvm *kvm, struct kvm_s2_mmu *mmu)
 	if (!pgt)
 		return -ENOMEM;
 
-	err = kvm_pgtable_stage2_init(pgt, kvm);
+	err = kvm_pgtable_stage2_init(pgt, kvm, &kvm_s2_mm_ops);
 	if (err)
 		goto out_free_pgtable;
 
@@ -1198,6 +1251,19 @@  static int kvm_map_idmap_text(void)
 	return err;
 }
 
+static void *kvm_hyp_zalloc_page(void *arg)
+{
+	return (void *)get_zeroed_page(GFP_KERNEL);
+}
+
+static struct kvm_pgtable_mm_ops kvm_hyp_mm_ops = {
+	.zalloc_page		= kvm_hyp_zalloc_page,
+	.get_page		= kvm_host_get_page,
+	.put_page		= kvm_host_put_page,
+	.phys_to_virt		= kvm_host_va,
+	.virt_to_phys		= kvm_host_pa,
+};
+
 int kvm_mmu_init(void)
 {
 	int err;
@@ -1241,7 +1307,7 @@  int kvm_mmu_init(void)
 		goto out;
 	}
 
-	err = kvm_pgtable_hyp_init(hyp_pgtable, hyp_va_bits);
+	err = kvm_pgtable_hyp_init(hyp_pgtable, hyp_va_bits, &kvm_hyp_mm_ops);
 	if (err)
 		goto out_free_pgtable;