diff mbox series

[07/12] huge_memory: Allow mappings of PMD sized pages

Message ID b63e8b07ceed8cf7b9cd07332132d6713853c777.1725941415.git-series.apopple@nvidia.com (mailing list archive)
State Handled Elsewhere, archived
Delegated to: Ira Weiny
Headers show
Series fs/dax: Fix FS DAX page reference counts | expand

Commit Message

Alistair Popple Sept. 10, 2024, 4:14 a.m. UTC
Currently DAX folio/page reference counts are managed differently to
normal pages. To allow these to be managed the same as normal pages
introduce dax_insert_pfn_pmd. This will map the entire PMD-sized folio
and take references as it would for a normally mapped page.

This is distinct from the current mechanism, vmf_insert_pfn_pmd, which
simply inserts a special devmap PMD entry into the page table without
holding a reference to the page for the mapping.

Signed-off-by: Alistair Popple <apopple@nvidia.com>
---
 include/linux/huge_mm.h |  1 +-
 mm/huge_memory.c        | 57 ++++++++++++++++++++++++++++++++++--------
 2 files changed, 48 insertions(+), 10 deletions(-)

Comments

Dan Williams Sept. 27, 2024, 2:48 a.m. UTC | #1
Alistair Popple wrote:
> Currently DAX folio/page reference counts are managed differently to
> normal pages. To allow these to be managed the same as normal pages
> introduce dax_insert_pfn_pmd. This will map the entire PMD-sized folio
> and take references as it would for a normally mapped page.
> 
> This is distinct from the current mechanism, vmf_insert_pfn_pmd, which
> simply inserts a special devmap PMD entry into the page table without
> holding a reference to the page for the mapping.

It would be useful to mention the rationale for the locking changes and
your understanding of the new "pgtable deposit" handling, because those
things make this not a trivial conversion.


> 
> Signed-off-by: Alistair Popple <apopple@nvidia.com>
> ---
>  include/linux/huge_mm.h |  1 +-
>  mm/huge_memory.c        | 57 ++++++++++++++++++++++++++++++++++--------
>  2 files changed, 48 insertions(+), 10 deletions(-)
> 
> diff --git a/include/linux/huge_mm.h b/include/linux/huge_mm.h
> index d3a1872..eaf3f78 100644
> --- a/include/linux/huge_mm.h
> +++ b/include/linux/huge_mm.h
> @@ -40,6 +40,7 @@ int change_huge_pmd(struct mmu_gather *tlb, struct vm_area_struct *vma,
>  
>  vm_fault_t vmf_insert_pfn_pmd(struct vm_fault *vmf, pfn_t pfn, bool write);
>  vm_fault_t vmf_insert_pfn_pud(struct vm_fault *vmf, pfn_t pfn, bool write);
> +vm_fault_t dax_insert_pfn_pmd(struct vm_fault *vmf, pfn_t pfn, bool write);
>  vm_fault_t dax_insert_pfn_pud(struct vm_fault *vmf, pfn_t pfn, bool write);
>  
>  enum transparent_hugepage_flag {
> diff --git a/mm/huge_memory.c b/mm/huge_memory.c
> index e8985a4..790041e 100644
> --- a/mm/huge_memory.c
> +++ b/mm/huge_memory.c
> @@ -1237,14 +1237,12 @@ static void insert_pfn_pmd(struct vm_area_struct *vma, unsigned long addr,
>  {
>  	struct mm_struct *mm = vma->vm_mm;
>  	pmd_t entry;
> -	spinlock_t *ptl;
>  
> -	ptl = pmd_lock(mm, pmd);
>  	if (!pmd_none(*pmd)) {
>  		if (write) {
>  			if (pmd_pfn(*pmd) != pfn_t_to_pfn(pfn)) {
>  				WARN_ON_ONCE(!is_huge_zero_pmd(*pmd));
> -				goto out_unlock;
> +				return;
>  			}
>  			entry = pmd_mkyoung(*pmd);
>  			entry = maybe_pmd_mkwrite(pmd_mkdirty(entry), vma);
> @@ -1252,7 +1250,7 @@ static void insert_pfn_pmd(struct vm_area_struct *vma, unsigned long addr,
>  				update_mmu_cache_pmd(vma, addr, pmd);
>  		}
>  
> -		goto out_unlock;
> +		return;
>  	}
>  
>  	entry = pmd_mkhuge(pfn_t_pmd(pfn, prot));
> @@ -1271,11 +1269,6 @@ static void insert_pfn_pmd(struct vm_area_struct *vma, unsigned long addr,
>  
>  	set_pmd_at(mm, addr, pmd, entry);
>  	update_mmu_cache_pmd(vma, addr, pmd);
> -
> -out_unlock:
> -	spin_unlock(ptl);
> -	if (pgtable)
> -		pte_free(mm, pgtable);
>  }
>  
>  /**
> @@ -1294,6 +1287,7 @@ vm_fault_t vmf_insert_pfn_pmd(struct vm_fault *vmf, pfn_t pfn, bool write)
>  	struct vm_area_struct *vma = vmf->vma;
>  	pgprot_t pgprot = vma->vm_page_prot;
>  	pgtable_t pgtable = NULL;
> +	spinlock_t *ptl;
>  
>  	/*
>  	 * If we had pmd_special, we could avoid all these restrictions,
> @@ -1316,12 +1310,55 @@ vm_fault_t vmf_insert_pfn_pmd(struct vm_fault *vmf, pfn_t pfn, bool write)
>  	}
>  
>  	track_pfn_insert(vma, &pgprot, pfn);
> -
> +	ptl = pmd_lock(vma->vm_mm, vmf->pmd);
>  	insert_pfn_pmd(vma, addr, vmf->pmd, pfn, pgprot, write, pgtable);
> +	spin_unlock(ptl);
> +	if (pgtable)
> +		pte_free(vma->vm_mm, pgtable);
> +
>  	return VM_FAULT_NOPAGE;
>  }
>  EXPORT_SYMBOL_GPL(vmf_insert_pfn_pmd);
>  
> +vm_fault_t dax_insert_pfn_pmd(struct vm_fault *vmf, pfn_t pfn, bool write)
> +{
> +	struct vm_area_struct *vma = vmf->vma;
> +	unsigned long addr = vmf->address & PMD_MASK;
> +	struct mm_struct *mm = vma->vm_mm;
> +	spinlock_t *ptl;
> +	pgtable_t pgtable = NULL;
> +	struct folio *folio;
> +	struct page *page;
> +
> +	if (addr < vma->vm_start || addr >= vma->vm_end)
> +		return VM_FAULT_SIGBUS;
> +
> +	if (arch_needs_pgtable_deposit()) {
> +		pgtable = pte_alloc_one(vma->vm_mm);
> +		if (!pgtable)
> +			return VM_FAULT_OOM;
> +	}
> +
> +	track_pfn_insert(vma, &vma->vm_page_prot, pfn);
> +
> +	ptl = pmd_lock(mm, vmf->pmd);
> +	if (pmd_none(*vmf->pmd)) {
> +		page = pfn_t_to_page(pfn);
> +		folio = page_folio(page);
> +		folio_get(folio);
> +		folio_add_file_rmap_pmd(folio, page, vma);
> +		add_mm_counter(mm, mm_counter_file(folio), HPAGE_PMD_NR);
> +	}
> +	insert_pfn_pmd(vma, addr, vmf->pmd, pfn, vma->vm_page_prot,
> +		write, pgtable);
> +	spin_unlock(ptl);
> +	if (pgtable)
> +		pte_free(mm, pgtable);

Are not the deposit rules that the extra page table stick around for the
lifetime of the inserted pte? So would that not require this incremental
change?

---
diff --git a/mm/huge_memory.c b/mm/huge_memory.c
index ea65c2db2bb1..5ef1e5d21a96 100644
--- a/mm/huge_memory.c
+++ b/mm/huge_memory.c
@@ -1232,7 +1232,7 @@ vm_fault_t do_huge_pmd_anonymous_page(struct vm_fault *vmf)
 
 static void insert_pfn_pmd(struct vm_area_struct *vma, unsigned long addr,
 			   pmd_t *pmd, unsigned long pfn, pgprot_t prot,
-			   bool write, pgtable_t pgtable)
+			   bool write, pgtable_t *pgtable)
 {
 	struct mm_struct *mm = vma->vm_mm;
 	pmd_t entry;
@@ -1258,10 +1258,10 @@ static void insert_pfn_pmd(struct vm_area_struct *vma, unsigned long addr,
 		entry = maybe_pmd_mkwrite(entry, vma);
 	}
 
-	if (pgtable) {
-		pgtable_trans_huge_deposit(mm, pmd, pgtable);
+	if (*pgtable) {
+		pgtable_trans_huge_deposit(mm, pmd, *pgtable);
 		mm_inc_nr_ptes(mm);
-		pgtable = NULL;
+		*pgtable = NULL;
 	}
 
 	set_pmd_at(mm, addr, pmd, entry);
@@ -1306,7 +1306,7 @@ vm_fault_t vmf_insert_pfn_pmd(struct vm_fault *vmf, unsigned long pfn, bool writ
 
 	track_pfn_insert(vma, &pgprot, pfn);
 	ptl = pmd_lock(vma->vm_mm, vmf->pmd);
-	insert_pfn_pmd(vma, addr, vmf->pmd, pfn, pgprot, write, pgtable);
+	insert_pfn_pmd(vma, addr, vmf->pmd, pfn, pgprot, write, &pgtable);
 	spin_unlock(ptl);
 	if (pgtable)
 		pte_free(vma->vm_mm, pgtable);
@@ -1344,8 +1344,8 @@ vm_fault_t dax_insert_pfn_pmd(struct vm_fault *vmf, unsigned long pfn, bool writ
 		folio_add_file_rmap_pmd(folio, page, vma);
 		add_mm_counter(mm, mm_counter_file(folio), HPAGE_PMD_NR);
 	}
-	insert_pfn_pmd(vma, addr, vmf->pmd, pfn, vma->vm_page_prot,
-		write, pgtable);
+	insert_pfn_pmd(vma, addr, vmf->pmd, pfn, vma->vm_page_prot, write,
+		       &pgtable);
 	spin_unlock(ptl);
 	if (pgtable)
 		pte_free(mm, pgtable);
---

Along these lines it would be lovely if someone from the PowerPC side
could test these changes, or if someone has a canned qemu command line
to test radix vs hash with pmem+dax that they can share?

> +
> +	return VM_FAULT_NOPAGE;
> +}
> +EXPORT_SYMBOL_GPL(dax_insert_pfn_pmd);

Like I mentioned before, lets make the exported function
vmf_insert_folio() and move the pte, pmd, pud internal private / static
details of the implementation. The "dax_" specific aspect of this was
removed at the conversion of a dax_pfn to a folio.
Alistair Popple Oct. 14, 2024, 6:53 a.m. UTC | #2
Dan Williams <dan.j.williams@intel.com> writes:

> Alistair Popple wrote:
>> Currently DAX folio/page reference counts are managed differently to
>> normal pages. To allow these to be managed the same as normal pages
>> introduce dax_insert_pfn_pmd. This will map the entire PMD-sized folio
>> and take references as it would for a normally mapped page.
>> 
>> This is distinct from the current mechanism, vmf_insert_pfn_pmd, which
>> simply inserts a special devmap PMD entry into the page table without
>> holding a reference to the page for the mapping.
>
> It would be useful to mention the rationale for the locking changes and
> your understanding of the new "pgtable deposit" handling, because those
> things make this not a trivial conversion.

My intent was not to change the locking for the existing
vmf_insert_pfn_pmd() but just to move it up a level in the stack so
dax_insert_pfn_pmd() could do the metadata manipulation while holding
the lock. Looks like I didn't get that quite right though, so I will
review it for the next version.

>> 
>> Signed-off-by: Alistair Popple <apopple@nvidia.com>
>> ---
>>  include/linux/huge_mm.h |  1 +-
>>  mm/huge_memory.c        | 57 ++++++++++++++++++++++++++++++++++--------
>>  2 files changed, 48 insertions(+), 10 deletions(-)
>> 
>> diff --git a/include/linux/huge_mm.h b/include/linux/huge_mm.h
>> index d3a1872..eaf3f78 100644
>> --- a/include/linux/huge_mm.h
>> +++ b/include/linux/huge_mm.h
>> @@ -40,6 +40,7 @@ int change_huge_pmd(struct mmu_gather *tlb, struct vm_area_struct *vma,
>>  
>>  vm_fault_t vmf_insert_pfn_pmd(struct vm_fault *vmf, pfn_t pfn, bool write);
>>  vm_fault_t vmf_insert_pfn_pud(struct vm_fault *vmf, pfn_t pfn, bool write);
>> +vm_fault_t dax_insert_pfn_pmd(struct vm_fault *vmf, pfn_t pfn, bool write);
>>  vm_fault_t dax_insert_pfn_pud(struct vm_fault *vmf, pfn_t pfn, bool write);
>>  
>>  enum transparent_hugepage_flag {
>> diff --git a/mm/huge_memory.c b/mm/huge_memory.c
>> index e8985a4..790041e 100644
>> --- a/mm/huge_memory.c
>> +++ b/mm/huge_memory.c
>> @@ -1237,14 +1237,12 @@ static void insert_pfn_pmd(struct vm_area_struct *vma, unsigned long addr,
>>  {
>>  	struct mm_struct *mm = vma->vm_mm;
>>  	pmd_t entry;
>> -	spinlock_t *ptl;
>>  
>> -	ptl = pmd_lock(mm, pmd);
>>  	if (!pmd_none(*pmd)) {
>>  		if (write) {
>>  			if (pmd_pfn(*pmd) != pfn_t_to_pfn(pfn)) {
>>  				WARN_ON_ONCE(!is_huge_zero_pmd(*pmd));
>> -				goto out_unlock;
>> +				return;
>>  			}
>>  			entry = pmd_mkyoung(*pmd);
>>  			entry = maybe_pmd_mkwrite(pmd_mkdirty(entry), vma);
>> @@ -1252,7 +1250,7 @@ static void insert_pfn_pmd(struct vm_area_struct *vma, unsigned long addr,
>>  				update_mmu_cache_pmd(vma, addr, pmd);
>>  		}
>>  
>> -		goto out_unlock;
>> +		return;
>>  	}
>>  
>>  	entry = pmd_mkhuge(pfn_t_pmd(pfn, prot));
>> @@ -1271,11 +1269,6 @@ static void insert_pfn_pmd(struct vm_area_struct *vma, unsigned long addr,
>>  
>>  	set_pmd_at(mm, addr, pmd, entry);
>>  	update_mmu_cache_pmd(vma, addr, pmd);
>> -
>> -out_unlock:
>> -	spin_unlock(ptl);
>> -	if (pgtable)
>> -		pte_free(mm, pgtable);
>>  }
>>  
>>  /**
>> @@ -1294,6 +1287,7 @@ vm_fault_t vmf_insert_pfn_pmd(struct vm_fault *vmf, pfn_t pfn, bool write)
>>  	struct vm_area_struct *vma = vmf->vma;
>>  	pgprot_t pgprot = vma->vm_page_prot;
>>  	pgtable_t pgtable = NULL;
>> +	spinlock_t *ptl;
>>  
>>  	/*
>>  	 * If we had pmd_special, we could avoid all these restrictions,
>> @@ -1316,12 +1310,55 @@ vm_fault_t vmf_insert_pfn_pmd(struct vm_fault *vmf, pfn_t pfn, bool write)
>>  	}
>>  
>>  	track_pfn_insert(vma, &pgprot, pfn);
>> -
>> +	ptl = pmd_lock(vma->vm_mm, vmf->pmd);
>>  	insert_pfn_pmd(vma, addr, vmf->pmd, pfn, pgprot, write, pgtable);
>> +	spin_unlock(ptl);
>> +	if (pgtable)
>> +		pte_free(vma->vm_mm, pgtable);
>> +
>>  	return VM_FAULT_NOPAGE;
>>  }
>>  EXPORT_SYMBOL_GPL(vmf_insert_pfn_pmd);
>>  
>> +vm_fault_t dax_insert_pfn_pmd(struct vm_fault *vmf, pfn_t pfn, bool write)
>> +{
>> +	struct vm_area_struct *vma = vmf->vma;
>> +	unsigned long addr = vmf->address & PMD_MASK;
>> +	struct mm_struct *mm = vma->vm_mm;
>> +	spinlock_t *ptl;
>> +	pgtable_t pgtable = NULL;
>> +	struct folio *folio;
>> +	struct page *page;
>> +
>> +	if (addr < vma->vm_start || addr >= vma->vm_end)
>> +		return VM_FAULT_SIGBUS;
>> +
>> +	if (arch_needs_pgtable_deposit()) {
>> +		pgtable = pte_alloc_one(vma->vm_mm);
>> +		if (!pgtable)
>> +			return VM_FAULT_OOM;
>> +	}
>> +
>> +	track_pfn_insert(vma, &vma->vm_page_prot, pfn);
>> +
>> +	ptl = pmd_lock(mm, vmf->pmd);
>> +	if (pmd_none(*vmf->pmd)) {
>> +		page = pfn_t_to_page(pfn);
>> +		folio = page_folio(page);
>> +		folio_get(folio);
>> +		folio_add_file_rmap_pmd(folio, page, vma);
>> +		add_mm_counter(mm, mm_counter_file(folio), HPAGE_PMD_NR);
>> +	}
>> +	insert_pfn_pmd(vma, addr, vmf->pmd, pfn, vma->vm_page_prot,
>> +		write, pgtable);
>> +	spin_unlock(ptl);
>> +	if (pgtable)
>> +		pte_free(mm, pgtable);
>
> Are not the deposit rules that the extra page table stick around for the
> lifetime of the inserted pte? So would that not require this incremental
> change?

Yeah, thanks for catching this.

> ---
> diff --git a/mm/huge_memory.c b/mm/huge_memory.c
> index ea65c2db2bb1..5ef1e5d21a96 100644
> --- a/mm/huge_memory.c
> +++ b/mm/huge_memory.c
> @@ -1232,7 +1232,7 @@ vm_fault_t do_huge_pmd_anonymous_page(struct vm_fault *vmf)
>  
>  static void insert_pfn_pmd(struct vm_area_struct *vma, unsigned long addr,
>  			   pmd_t *pmd, unsigned long pfn, pgprot_t prot,
> -			   bool write, pgtable_t pgtable)
> +			   bool write, pgtable_t *pgtable)
>  {
>  	struct mm_struct *mm = vma->vm_mm;
>  	pmd_t entry;
> @@ -1258,10 +1258,10 @@ static void insert_pfn_pmd(struct vm_area_struct *vma, unsigned long addr,
>  		entry = maybe_pmd_mkwrite(entry, vma);
>  	}
>  
> -	if (pgtable) {
> -		pgtable_trans_huge_deposit(mm, pmd, pgtable);
> +	if (*pgtable) {
> +		pgtable_trans_huge_deposit(mm, pmd, *pgtable);
>  		mm_inc_nr_ptes(mm);
> -		pgtable = NULL;
> +		*pgtable = NULL;
>  	}
>  
>  	set_pmd_at(mm, addr, pmd, entry);
> @@ -1306,7 +1306,7 @@ vm_fault_t vmf_insert_pfn_pmd(struct vm_fault *vmf, unsigned long pfn, bool writ
>  
>  	track_pfn_insert(vma, &pgprot, pfn);
>  	ptl = pmd_lock(vma->vm_mm, vmf->pmd);
> -	insert_pfn_pmd(vma, addr, vmf->pmd, pfn, pgprot, write, pgtable);
> +	insert_pfn_pmd(vma, addr, vmf->pmd, pfn, pgprot, write, &pgtable);
>  	spin_unlock(ptl);
>  	if (pgtable)
>  		pte_free(vma->vm_mm, pgtable);
> @@ -1344,8 +1344,8 @@ vm_fault_t dax_insert_pfn_pmd(struct vm_fault *vmf, unsigned long pfn, bool writ
>  		folio_add_file_rmap_pmd(folio, page, vma);
>  		add_mm_counter(mm, mm_counter_file(folio), HPAGE_PMD_NR);
>  	}
> -	insert_pfn_pmd(vma, addr, vmf->pmd, pfn, vma->vm_page_prot,
> -		write, pgtable);
> +	insert_pfn_pmd(vma, addr, vmf->pmd, pfn, vma->vm_page_prot, write,
> +		       &pgtable);
>  	spin_unlock(ptl);
>  	if (pgtable)
>  		pte_free(mm, pgtable);
> ---
>
> Along these lines it would be lovely if someone from the PowerPC side
> could test these changes, or if someone has a canned qemu command line
> to test radix vs hash with pmem+dax that they can share?

Michael, Nick, do you know of a qemu command or anyone who might?

>> +
>> +	return VM_FAULT_NOPAGE;
>> +}
>> +EXPORT_SYMBOL_GPL(dax_insert_pfn_pmd);
>
> Like I mentioned before, lets make the exported function
> vmf_insert_folio() and move the pte, pmd, pud internal private / static
> details of the implementation. The "dax_" specific aspect of this was
> removed at the conversion of a dax_pfn to a folio.

Ok, let me try that. Note that vmf_insert_pfn{_pmd|_pud} will have to
stick around though.
Alistair Popple Oct. 23, 2024, 11:14 p.m. UTC | #3
Alistair Popple <apopple@nvidia.com> writes:

> Alistair Popple wrote:
>> Dan Williams <dan.j.williams@intel.com> writes:

[...]

>>> +
>>> +	return VM_FAULT_NOPAGE;
>>> +}
>>> +EXPORT_SYMBOL_GPL(dax_insert_pfn_pmd);
>>
>> Like I mentioned before, lets make the exported function
>> vmf_insert_folio() and move the pte, pmd, pud internal private / static
>> details of the implementation. The "dax_" specific aspect of this was
>> removed at the conversion of a dax_pfn to a folio.
>
> Ok, let me try that. Note that vmf_insert_pfn{_pmd|_pud} will have to
> stick around though.

Creating a single vmf_insert_folio() seems somewhat difficult because it
needs to be called from multiple fault paths (either PTE, PMD or PUD
fault) and do something different for each.

Specifically the issue I ran into is that DAX does not downgrade PMD
entries to PTE entries if they are backed by storage. So the PTE fault
handler will get a PMD-sized DAX entry and therefore a PMD size folio.

The way I tried implementing vmf_insert_folio() was to look at
folio_order() to determine which internal implementation to call. But
that doesn't work for a PTE fault, because there's no way to determine
if we should PTE map a subpage or PMD map the entire folio.

We could pass down some context as to what type of fault we're handling,
or add it to the vmf struct, but that seems excessive given callers
already know this and could just call a specific
vmf_insert_page_{pte|pmd|pud}.
Dan Williams Oct. 23, 2024, 11:38 p.m. UTC | #4
Alistair Popple wrote:
> 
> Alistair Popple <apopple@nvidia.com> writes:
> 
> > Alistair Popple wrote:
> >> Dan Williams <dan.j.williams@intel.com> writes:
> 
> [...]
> 
> >>> +
> >>> +	return VM_FAULT_NOPAGE;
> >>> +}
> >>> +EXPORT_SYMBOL_GPL(dax_insert_pfn_pmd);
> >>
> >> Like I mentioned before, lets make the exported function
> >> vmf_insert_folio() and move the pte, pmd, pud internal private / static
> >> details of the implementation. The "dax_" specific aspect of this was
> >> removed at the conversion of a dax_pfn to a folio.
> >
> > Ok, let me try that. Note that vmf_insert_pfn{_pmd|_pud} will have to
> > stick around though.
> 
> Creating a single vmf_insert_folio() seems somewhat difficult because it
> needs to be called from multiple fault paths (either PTE, PMD or PUD
> fault) and do something different for each.
> 
> Specifically the issue I ran into is that DAX does not downgrade PMD
> entries to PTE entries if they are backed by storage. So the PTE fault
> handler will get a PMD-sized DAX entry and therefore a PMD size folio.
> 
> The way I tried implementing vmf_insert_folio() was to look at
> folio_order() to determine which internal implementation to call. But
> that doesn't work for a PTE fault, because there's no way to determine
> if we should PTE map a subpage or PMD map the entire folio.

Ah, that conflict makes sense.

> We could pass down some context as to what type of fault we're handling,
> or add it to the vmf struct, but that seems excessive given callers
> already know this and could just call a specific
> vmf_insert_page_{pte|pmd|pud}.

Ok, I think it is good to capture that "because dax does not downgrade
entries it may satisfy PTE faults with PMD inserts", or something like
that in comment or changelog.
diff mbox series

Patch

diff --git a/include/linux/huge_mm.h b/include/linux/huge_mm.h
index d3a1872..eaf3f78 100644
--- a/include/linux/huge_mm.h
+++ b/include/linux/huge_mm.h
@@ -40,6 +40,7 @@  int change_huge_pmd(struct mmu_gather *tlb, struct vm_area_struct *vma,
 
 vm_fault_t vmf_insert_pfn_pmd(struct vm_fault *vmf, pfn_t pfn, bool write);
 vm_fault_t vmf_insert_pfn_pud(struct vm_fault *vmf, pfn_t pfn, bool write);
+vm_fault_t dax_insert_pfn_pmd(struct vm_fault *vmf, pfn_t pfn, bool write);
 vm_fault_t dax_insert_pfn_pud(struct vm_fault *vmf, pfn_t pfn, bool write);
 
 enum transparent_hugepage_flag {
diff --git a/mm/huge_memory.c b/mm/huge_memory.c
index e8985a4..790041e 100644
--- a/mm/huge_memory.c
+++ b/mm/huge_memory.c
@@ -1237,14 +1237,12 @@  static void insert_pfn_pmd(struct vm_area_struct *vma, unsigned long addr,
 {
 	struct mm_struct *mm = vma->vm_mm;
 	pmd_t entry;
-	spinlock_t *ptl;
 
-	ptl = pmd_lock(mm, pmd);
 	if (!pmd_none(*pmd)) {
 		if (write) {
 			if (pmd_pfn(*pmd) != pfn_t_to_pfn(pfn)) {
 				WARN_ON_ONCE(!is_huge_zero_pmd(*pmd));
-				goto out_unlock;
+				return;
 			}
 			entry = pmd_mkyoung(*pmd);
 			entry = maybe_pmd_mkwrite(pmd_mkdirty(entry), vma);
@@ -1252,7 +1250,7 @@  static void insert_pfn_pmd(struct vm_area_struct *vma, unsigned long addr,
 				update_mmu_cache_pmd(vma, addr, pmd);
 		}
 
-		goto out_unlock;
+		return;
 	}
 
 	entry = pmd_mkhuge(pfn_t_pmd(pfn, prot));
@@ -1271,11 +1269,6 @@  static void insert_pfn_pmd(struct vm_area_struct *vma, unsigned long addr,
 
 	set_pmd_at(mm, addr, pmd, entry);
 	update_mmu_cache_pmd(vma, addr, pmd);
-
-out_unlock:
-	spin_unlock(ptl);
-	if (pgtable)
-		pte_free(mm, pgtable);
 }
 
 /**
@@ -1294,6 +1287,7 @@  vm_fault_t vmf_insert_pfn_pmd(struct vm_fault *vmf, pfn_t pfn, bool write)
 	struct vm_area_struct *vma = vmf->vma;
 	pgprot_t pgprot = vma->vm_page_prot;
 	pgtable_t pgtable = NULL;
+	spinlock_t *ptl;
 
 	/*
 	 * If we had pmd_special, we could avoid all these restrictions,
@@ -1316,12 +1310,55 @@  vm_fault_t vmf_insert_pfn_pmd(struct vm_fault *vmf, pfn_t pfn, bool write)
 	}
 
 	track_pfn_insert(vma, &pgprot, pfn);
-
+	ptl = pmd_lock(vma->vm_mm, vmf->pmd);
 	insert_pfn_pmd(vma, addr, vmf->pmd, pfn, pgprot, write, pgtable);
+	spin_unlock(ptl);
+	if (pgtable)
+		pte_free(vma->vm_mm, pgtable);
+
 	return VM_FAULT_NOPAGE;
 }
 EXPORT_SYMBOL_GPL(vmf_insert_pfn_pmd);
 
+vm_fault_t dax_insert_pfn_pmd(struct vm_fault *vmf, pfn_t pfn, bool write)
+{
+	struct vm_area_struct *vma = vmf->vma;
+	unsigned long addr = vmf->address & PMD_MASK;
+	struct mm_struct *mm = vma->vm_mm;
+	spinlock_t *ptl;
+	pgtable_t pgtable = NULL;
+	struct folio *folio;
+	struct page *page;
+
+	if (addr < vma->vm_start || addr >= vma->vm_end)
+		return VM_FAULT_SIGBUS;
+
+	if (arch_needs_pgtable_deposit()) {
+		pgtable = pte_alloc_one(vma->vm_mm);
+		if (!pgtable)
+			return VM_FAULT_OOM;
+	}
+
+	track_pfn_insert(vma, &vma->vm_page_prot, pfn);
+
+	ptl = pmd_lock(mm, vmf->pmd);
+	if (pmd_none(*vmf->pmd)) {
+		page = pfn_t_to_page(pfn);
+		folio = page_folio(page);
+		folio_get(folio);
+		folio_add_file_rmap_pmd(folio, page, vma);
+		add_mm_counter(mm, mm_counter_file(folio), HPAGE_PMD_NR);
+	}
+	insert_pfn_pmd(vma, addr, vmf->pmd, pfn, vma->vm_page_prot,
+		write, pgtable);
+	spin_unlock(ptl);
+	if (pgtable)
+		pte_free(mm, pgtable);
+
+	return VM_FAULT_NOPAGE;
+}
+EXPORT_SYMBOL_GPL(dax_insert_pfn_pmd);
+
 #ifdef CONFIG_HAVE_ARCH_TRANSPARENT_HUGEPAGE_PUD
 static pud_t maybe_pud_mkwrite(pud_t pud, struct vm_area_struct *vma)
 {