diff mbox series

[v2,1/7] mm: khugepaged: retract_page_tables() use pte_offset_map_rw_nolock()

Message ID 4c3f4aa29f38c013c4529a43bce846a3edd31523.1730360798.git.zhengqi.arch@bytedance.com (mailing list archive)
State New
Headers show
Series synchronously scan and reclaim empty user PTE pages | expand

Commit Message

Qi Zheng Oct. 31, 2024, 8:13 a.m. UTC
In retract_page_tables(), we may modify the pmd entry after acquiring the
pml (PMD lock) and ptl (PTE lock), so we should also check whether the pmd
entry is stable. Use pte_offset_map_rw_nolock() + pmd_same() to do so;
this also lets us remove the call to pte_lockptr().

Signed-off-by: Qi Zheng <zhengqi.arch@bytedance.com>
---
 mm/khugepaged.c | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

Comments

Jann Horn Nov. 6, 2024, 9:48 p.m. UTC | #1
On Thu, Oct 31, 2024 at 9:14 AM Qi Zheng <zhengqi.arch@bytedance.com> wrote:
> In retract_page_tables(), we may modify the pmd entry after acquiring the
> pml and ptl, so we should also check whether the pmd entry is stable.

Why does taking the PMD lock not guarantee that the PMD entry is stable?

> Using pte_offset_map_rw_nolock() + pmd_same() to do it, and then we can
> also remove the calling of the pte_lockptr().
>
> Signed-off-by: Qi Zheng <zhengqi.arch@bytedance.com>
> ---
>  mm/khugepaged.c | 17 ++++++++++++++++-
>  1 file changed, 16 insertions(+), 1 deletion(-)
>
> diff --git a/mm/khugepaged.c b/mm/khugepaged.c
> index 6f8d46d107b4b..6d76dde64f5fb 100644
> --- a/mm/khugepaged.c
> +++ b/mm/khugepaged.c
> @@ -1721,6 +1721,7 @@ static void retract_page_tables(struct address_space *mapping, pgoff_t pgoff)
>                 spinlock_t *pml;
>                 spinlock_t *ptl;
>                 bool skipped_uffd = false;
> +               pte_t *pte;
>
>                 /*
>                  * Check vma->anon_vma to exclude MAP_PRIVATE mappings that
> @@ -1756,11 +1757,25 @@ static void retract_page_tables(struct address_space *mapping, pgoff_t pgoff)
>                                         addr, addr + HPAGE_PMD_SIZE);
>                 mmu_notifier_invalidate_range_start(&range);
>
> +               pte = pte_offset_map_rw_nolock(mm, pmd, addr, &pgt_pmd, &ptl);
> +               if (!pte) {
> +                       mmu_notifier_invalidate_range_end(&range);
> +                       continue;
> +               }
> +
>                 pml = pmd_lock(mm, pmd);

I don't understand why you're mapping the page table before locking
the PMD. Doesn't that just mean we need more error checking
afterwards?


> -               ptl = pte_lockptr(mm, pmd);
>                 if (ptl != pml)
>                         spin_lock_nested(ptl, SINGLE_DEPTH_NESTING);
>
> +               if (unlikely(!pmd_same(pgt_pmd, pmdp_get_lockless(pmd)))) {
> +                       pte_unmap_unlock(pte, ptl);
> +                       if (ptl != pml)
> +                               spin_unlock(pml);
> +                       mmu_notifier_invalidate_range_end(&range);
> +                       continue;
> +               }
> +               pte_unmap(pte);
> +
>                 /*
>                  * Huge page lock is still held, so normally the page table
>                  * must remain empty; and we have already skipped anon_vma
> --
> 2.20.1
>
Qi Zheng Nov. 7, 2024, 7:54 a.m. UTC | #2
Hi Jann,

On 2024/11/7 05:48, Jann Horn wrote:
> On Thu, Oct 31, 2024 at 9:14 AM Qi Zheng <zhengqi.arch@bytedance.com> wrote:
>> In retract_page_tables(), we may modify the pmd entry after acquiring the
>> pml and ptl, so we should also check whether the pmd entry is stable.
> 
> Why does taking the PMD lock not guarantee that the PMD entry is stable?

Because the pmd entry may have changed before we took the pmd lock, we
need to recheck it after taking the pmd or pte lock.

> 
>> Using pte_offset_map_rw_nolock() + pmd_same() to do it, and then we can
>> also remove the calling of the pte_lockptr().
>>
>> Signed-off-by: Qi Zheng <zhengqi.arch@bytedance.com>
>> ---
>>   mm/khugepaged.c | 17 ++++++++++++++++-
>>   1 file changed, 16 insertions(+), 1 deletion(-)
>>
>> diff --git a/mm/khugepaged.c b/mm/khugepaged.c
>> index 6f8d46d107b4b..6d76dde64f5fb 100644
>> --- a/mm/khugepaged.c
>> +++ b/mm/khugepaged.c
>> @@ -1721,6 +1721,7 @@ static void retract_page_tables(struct address_space *mapping, pgoff_t pgoff)
>>                  spinlock_t *pml;
>>                  spinlock_t *ptl;
>>                  bool skipped_uffd = false;
>> +               pte_t *pte;
>>
>>                  /*
>>                   * Check vma->anon_vma to exclude MAP_PRIVATE mappings that
>> @@ -1756,11 +1757,25 @@ static void retract_page_tables(struct address_space *mapping, pgoff_t pgoff)
>>                                          addr, addr + HPAGE_PMD_SIZE);
>>                  mmu_notifier_invalidate_range_start(&range);
>>
>> +               pte = pte_offset_map_rw_nolock(mm, pmd, addr, &pgt_pmd, &ptl);
>> +               if (!pte) {
>> +                       mmu_notifier_invalidate_range_end(&range);
>> +                       continue;
>> +               }
>> +
>>                  pml = pmd_lock(mm, pmd);
> 
> I don't understand why you're mapping the page table before locking
> the PMD. Doesn't that just mean we need more error checking
> afterwards?

The main purpose is to obtain the pmdval. If we don't use
pte_offset_map_rw_nolock(), we must take care to recheck the pmd entry
before calling pte_lockptr(), like this:

pmdval = pmdp_get_lockless(pmd);
pmd_lock
recheck pmdval
pte_lockptr(mm, pmd)

Otherwise, the system may crash. Consider the following
situation:

     CPU 0              CPU 1

zap_pte_range
--> clear pmd entry
     free pte page (by RCU)

                       retract_page_tables
                       --> pmd_lock
                           pte_lockptr(mm, pmd)  <-- BOOM!!

So maybe calling pte_offset_map_rw_nolock() is more convenient.

Thanks,
Qi


> 
> 
>> -               ptl = pte_lockptr(mm, pmd);
>>                  if (ptl != pml)
>>                          spin_lock_nested(ptl, SINGLE_DEPTH_NESTING);
>>
>> +               if (unlikely(!pmd_same(pgt_pmd, pmdp_get_lockless(pmd)))) {
>> +                       pte_unmap_unlock(pte, ptl);
>> +                       if (ptl != pml)
>> +                               spin_unlock(pml);
>> +                       mmu_notifier_invalidate_range_end(&range);
>> +                       continue;
>> +               }
>> +               pte_unmap(pte);
>> +
>>                  /*
>>                   * Huge page lock is still held, so normally the page table
>>                   * must remain empty; and we have already skipped anon_vma
>> --
>> 2.20.1
>>
Jann Horn Nov. 7, 2024, 5:57 p.m. UTC | #3
On Thu, Nov 7, 2024 at 8:54 AM Qi Zheng <zhengqi.arch@bytedance.com> wrote:
> On 2024/11/7 05:48, Jann Horn wrote:
> > On Thu, Oct 31, 2024 at 9:14 AM Qi Zheng <zhengqi.arch@bytedance.com> wrote:
> >> In retract_page_tables(), we may modify the pmd entry after acquiring the
> >> pml and ptl, so we should also check whether the pmd entry is stable.
> >
> > Why does taking the PMD lock not guarantee that the PMD entry is stable?
>
> Because the pmd entry may have changed before taking the pmd lock, so we
> need to recheck it after taking the pmd or pte lock.

You mean it could have changed from the value we obtained from
find_pmd_or_thp_or_none(mm, addr, &pmd)? I don't think that matters
though.

> >> Using pte_offset_map_rw_nolock() + pmd_same() to do it, and then we can
> >> also remove the calling of the pte_lockptr().
> >>
> >> Signed-off-by: Qi Zheng <zhengqi.arch@bytedance.com>
> >> ---
> >>   mm/khugepaged.c | 17 ++++++++++++++++-
> >>   1 file changed, 16 insertions(+), 1 deletion(-)
> >>
> >> diff --git a/mm/khugepaged.c b/mm/khugepaged.c
> >> index 6f8d46d107b4b..6d76dde64f5fb 100644
> >> --- a/mm/khugepaged.c
> >> +++ b/mm/khugepaged.c
> >> @@ -1721,6 +1721,7 @@ static void retract_page_tables(struct address_space *mapping, pgoff_t pgoff)
> >>                  spinlock_t *pml;
> >>                  spinlock_t *ptl;
> >>                  bool skipped_uffd = false;
> >> +               pte_t *pte;
> >>
> >>                  /*
> >>                   * Check vma->anon_vma to exclude MAP_PRIVATE mappings that
> >> @@ -1756,11 +1757,25 @@ static void retract_page_tables(struct address_space *mapping, pgoff_t pgoff)
> >>                                          addr, addr + HPAGE_PMD_SIZE);
> >>                  mmu_notifier_invalidate_range_start(&range);
> >>
> >> +               pte = pte_offset_map_rw_nolock(mm, pmd, addr, &pgt_pmd, &ptl);
> >> +               if (!pte) {
> >> +                       mmu_notifier_invalidate_range_end(&range);
> >> +                       continue;
> >> +               }
> >> +
> >>                  pml = pmd_lock(mm, pmd);
> >
> > I don't understand why you're mapping the page table before locking
> > the PMD. Doesn't that just mean we need more error checking
> > afterwards?
>
> The main purpose is to obtain the pmdval. If we don't use
> pte_offset_map_rw_nolock, we should pay attention to recheck pmd entry
> before pte_lockptr(), like this:
>
> pmdval = pmdp_get_lockless(pmd);
> pmd_lock
> recheck pmdval
> pte_lockptr(mm, pmd)
>
> Otherwise, it may cause the system to crash. Consider the following
> situation:
>
>      CPU 0              CPU 1
>
> zap_pte_range
> --> clear pmd entry
>      free pte page (by RCU)
>
>                        retract_page_tables
>                        --> pmd_lock
>                            pte_lockptr(mm, pmd)  <-- BOOM!!
>
> So maybe calling pte_offset_map_rw_nolock() is more convenient.

How about refactoring find_pmd_or_thp_or_none() like this, by moving
the checks of the PMD entry value into a separate helper:



-static int find_pmd_or_thp_or_none(struct mm_struct *mm,
-                                  unsigned long address,
-                                  pmd_t **pmd)
+static int check_pmd_state(pmd_t *pmd)
 {
-       pmd_t pmde;
+       pmd_t pmde = pmdp_get_lockless(*pmd);

-       *pmd = mm_find_pmd(mm, address);
-       if (!*pmd)
-               return SCAN_PMD_NULL;
-
-       pmde = pmdp_get_lockless(*pmd);
        if (pmd_none(pmde))
                return SCAN_PMD_NONE;
        if (!pmd_present(pmde))
                return SCAN_PMD_NULL;
        if (pmd_trans_huge(pmde))
                return SCAN_PMD_MAPPED;
        if (pmd_devmap(pmde))
                return SCAN_PMD_NULL;
        if (pmd_bad(pmde))
                return SCAN_PMD_NULL;
        return SCAN_SUCCEED;
 }

+static int find_pmd_or_thp_or_none(struct mm_struct *mm,
+                                  unsigned long address,
+                                  pmd_t **pmd)
+{
+
+       *pmd = mm_find_pmd(mm, address);
+       if (!*pmd)
+               return SCAN_PMD_NULL;
+       return check_pmd_state(*pmd);
+}
+


And simplifying retract_page_tables() a little bit like this:


        i_mmap_lock_read(mapping);
        vma_interval_tree_foreach(vma, &mapping->i_mmap, pgoff, pgoff) {
                struct mmu_notifier_range range;
                struct mm_struct *mm;
                unsigned long addr;
                pmd_t *pmd, pgt_pmd;
                spinlock_t *pml;
                spinlock_t *ptl;
-               bool skipped_uffd = false;
+               bool success = false;

                /*
                 * Check vma->anon_vma to exclude MAP_PRIVATE mappings that
                 * got written to. These VMAs are likely not worth removing
                 * page tables from, as PMD-mapping is likely to be split later.
                 */
                if (READ_ONCE(vma->anon_vma))
                        continue;

                addr = vma->vm_start + ((pgoff - vma->vm_pgoff) << PAGE_SHIFT);
@@ -1763,34 +1767,34 @@ static void retract_page_tables(struct
address_space *mapping, pgoff_t pgoff)

                /*
                 * Huge page lock is still held, so normally the page table
                 * must remain empty; and we have already skipped anon_vma
                 * and userfaultfd_wp() vmas.  But since the mmap_lock is not
                 * held, it is still possible for a racing userfaultfd_ioctl()
                 * to have inserted ptes or markers.  Now that we hold ptlock,
                 * repeating the anon_vma check protects from one category,
                 * and repeating the userfaultfd_wp() check from another.
                 */
-               if (unlikely(vma->anon_vma || userfaultfd_wp(vma))) {
-                       skipped_uffd = true;
-               } else {
+               if (likely(!vma->anon_vma && !userfaultfd_wp(vma))) {
                        pgt_pmd = pmdp_collapse_flush(vma, addr, pmd);
                        pmdp_get_lockless_sync();
+                       success = true;
                }

                if (ptl != pml)
                        spin_unlock(ptl);
+drop_pml:
                spin_unlock(pml);

                mmu_notifier_invalidate_range_end(&range);

-               if (!skipped_uffd) {
+               if (success) {
                        mm_dec_nr_ptes(mm);
                        page_table_check_pte_clear_range(mm, addr, pgt_pmd);
                        pte_free_defer(mm, pmd_pgtable(pgt_pmd));
                }
        }
        i_mmap_unlock_read(mapping);


And then instead of your patch, I think you can just do this?


@@ -1754,20 +1754,22 @@ static void retract_page_tables(struct
address_space *mapping, pgoff_t pgoff)
                 */
                if (userfaultfd_wp(vma))
                        continue;

                /* PTEs were notified when unmapped; but now for the PMD? */
                mmu_notifier_range_init(&range, MMU_NOTIFY_CLEAR, 0, mm,
                                        addr, addr + HPAGE_PMD_SIZE);
                mmu_notifier_invalidate_range_start(&range);

                pml = pmd_lock(mm, pmd);
+               if (check_pmd_state(mm, addr, pmd) != SCAN_SUCCEED)
+                       goto drop_pml;
                ptl = pte_lockptr(mm, pmd);
                if (ptl != pml)
                        spin_lock_nested(ptl, SINGLE_DEPTH_NESTING);

                /*
                 * Huge page lock is still held, so normally the page table
                 * must remain empty; and we have already skipped anon_vma
                 * and userfaultfd_wp() vmas.  But since the mmap_lock is not
                 * held, it is still possible for a racing userfaultfd_ioctl()
                 * to have inserted ptes or markers.  Now that we hold ptlock,
Qi Zheng Nov. 8, 2024, 6:31 a.m. UTC | #4
On 2024/11/8 01:57, Jann Horn wrote:
> On Thu, Nov 7, 2024 at 8:54 AM Qi Zheng <zhengqi.arch@bytedance.com> wrote:
>> On 2024/11/7 05:48, Jann Horn wrote:
>>> On Thu, Oct 31, 2024 at 9:14 AM Qi Zheng <zhengqi.arch@bytedance.com> wrote:
>>>> In retract_page_tables(), we may modify the pmd entry after acquiring the
>>>> pml and ptl, so we should also check whether the pmd entry is stable.
>>>
>>> Why does taking the PMD lock not guarantee that the PMD entry is stable?
>>
>> Because the pmd entry may have changed before taking the pmd lock, so we
>> need to recheck it after taking the pmd or pte lock.
> 
> You mean it could have changed from the value we obtained from
> find_pmd_or_thp_or_none(mm, addr, &pmd)? I don't think that matters
> though.
> 
>>>> Using pte_offset_map_rw_nolock() + pmd_same() to do it, and then we can
>>>> also remove the calling of the pte_lockptr().
>>>>
>>>> Signed-off-by: Qi Zheng <zhengqi.arch@bytedance.com>
>>>> ---
>>>>    mm/khugepaged.c | 17 ++++++++++++++++-
>>>>    1 file changed, 16 insertions(+), 1 deletion(-)
>>>>
>>>> diff --git a/mm/khugepaged.c b/mm/khugepaged.c
>>>> index 6f8d46d107b4b..6d76dde64f5fb 100644
>>>> --- a/mm/khugepaged.c
>>>> +++ b/mm/khugepaged.c
>>>> @@ -1721,6 +1721,7 @@ static void retract_page_tables(struct address_space *mapping, pgoff_t pgoff)
>>>>                   spinlock_t *pml;
>>>>                   spinlock_t *ptl;
>>>>                   bool skipped_uffd = false;
>>>> +               pte_t *pte;
>>>>
>>>>                   /*
>>>>                    * Check vma->anon_vma to exclude MAP_PRIVATE mappings that
>>>> @@ -1756,11 +1757,25 @@ static void retract_page_tables(struct address_space *mapping, pgoff_t pgoff)
>>>>                                           addr, addr + HPAGE_PMD_SIZE);
>>>>                   mmu_notifier_invalidate_range_start(&range);
>>>>
>>>> +               pte = pte_offset_map_rw_nolock(mm, pmd, addr, &pgt_pmd, &ptl);
>>>> +               if (!pte) {
>>>> +                       mmu_notifier_invalidate_range_end(&range);
>>>> +                       continue;
>>>> +               }
>>>> +
>>>>                   pml = pmd_lock(mm, pmd);
>>>
>>> I don't understand why you're mapping the page table before locking
>>> the PMD. Doesn't that just mean we need more error checking
>>> afterwards?
>>
>> The main purpose is to obtain the pmdval. If we don't use
>> pte_offset_map_rw_nolock, we should pay attention to recheck pmd entry
>> before pte_lockptr(), like this:
>>
>> pmdval = pmdp_get_lockless(pmd);
>> pmd_lock
>> recheck pmdval
>> pte_lockptr(mm, pmd)
>>
>> Otherwise, it may cause the system to crash. Consider the following
>> situation:
>>
>>       CPU 0              CPU 1
>>
>> zap_pte_range
>> --> clear pmd entry
>>       free pte page (by RCU)
>>
>>                         retract_page_tables
>>                         --> pmd_lock
>>                             pte_lockptr(mm, pmd)  <-- BOOM!!
>>
>> So maybe calling pte_offset_map_rw_nolock() is more convenient.
> 
> How about refactoring find_pmd_or_thp_or_none() like this, by moving
> the checks of the PMD entry value into a separate helper:
> 
> 
> 
> -static int find_pmd_or_thp_or_none(struct mm_struct *mm,
> -                                  unsigned long address,
> -                                  pmd_t **pmd)
> +static int check_pmd_state(pmd_t *pmd)
>   {
> -       pmd_t pmde;
> +       pmd_t pmde = pmdp_get_lockless(*pmd);
> 
> -       *pmd = mm_find_pmd(mm, address);
> -       if (!*pmd)
> -               return SCAN_PMD_NULL;
> -
> -       pmde = pmdp_get_lockless(*pmd);
>          if (pmd_none(pmde))
>                  return SCAN_PMD_NONE;
>          if (!pmd_present(pmde))
>                  return SCAN_PMD_NULL;
>          if (pmd_trans_huge(pmde))
>                  return SCAN_PMD_MAPPED;
>          if (pmd_devmap(pmde))
>                  return SCAN_PMD_NULL;
>          if (pmd_bad(pmde))
>                  return SCAN_PMD_NULL;
>          return SCAN_SUCCEED;
>   }
> 
> +static int find_pmd_or_thp_or_none(struct mm_struct *mm,
> +                                  unsigned long address,
> +                                  pmd_t **pmd)
> +{
> +
> +       *pmd = mm_find_pmd(mm, address);
> +       if (!*pmd)
> +               return SCAN_PMD_NULL;
> +       return check_pmd_state(*pmd);
> +}
> +
> 
> 
> And simplifying retract_page_tables() a little bit like this:
> 
> 
>          i_mmap_lock_read(mapping);
>          vma_interval_tree_foreach(vma, &mapping->i_mmap, pgoff, pgoff) {
>                  struct mmu_notifier_range range;
>                  struct mm_struct *mm;
>                  unsigned long addr;
>                  pmd_t *pmd, pgt_pmd;
>                  spinlock_t *pml;
>                  spinlock_t *ptl;
> -               bool skipped_uffd = false;
> +               bool success = false;
> 
>                  /*
>                   * Check vma->anon_vma to exclude MAP_PRIVATE mappings that
>                   * got written to. These VMAs are likely not worth removing
>                   * page tables from, as PMD-mapping is likely to be split later.
>                   */
>                  if (READ_ONCE(vma->anon_vma))
>                          continue;
> 
>                  addr = vma->vm_start + ((pgoff - vma->vm_pgoff) << PAGE_SHIFT);
> @@ -1763,34 +1767,34 @@ static void retract_page_tables(struct
> address_space *mapping, pgoff_t pgoff)
> 
>                  /*
>                   * Huge page lock is still held, so normally the page table
>                   * must remain empty; and we have already skipped anon_vma
>                   * and userfaultfd_wp() vmas.  But since the mmap_lock is not
>                   * held, it is still possible for a racing userfaultfd_ioctl()
>                   * to have inserted ptes or markers.  Now that we hold ptlock,
>                   * repeating the anon_vma check protects from one category,
>                   * and repeating the userfaultfd_wp() check from another.
>                   */
> -               if (unlikely(vma->anon_vma || userfaultfd_wp(vma))) {
> -                       skipped_uffd = true;
> -               } else {
> +               if (likely(!vma->anon_vma && !userfaultfd_wp(vma))) {
>                          pgt_pmd = pmdp_collapse_flush(vma, addr, pmd);
>                          pmdp_get_lockless_sync();
> +                       success = true;
>                  }
> 
>                  if (ptl != pml)
>                          spin_unlock(ptl);
> +drop_pml:
>                  spin_unlock(pml);
> 
>                  mmu_notifier_invalidate_range_end(&range);
> 
> -               if (!skipped_uffd) {
> +               if (success) {
>                          mm_dec_nr_ptes(mm);
>                          page_table_check_pte_clear_range(mm, addr, pgt_pmd);
>                          pte_free_defer(mm, pmd_pgtable(pgt_pmd));
>                  }
>          }
>          i_mmap_unlock_read(mapping);
> 
> 
> And then instead of your patch, I think you can just do this?

Ah, this does look much better! Will change to this in the next version.

Thanks!

> 
> 
> @@ -1754,20 +1754,22 @@ static void retract_page_tables(struct
> address_space *mapping, pgoff_t pgoff)
>                   */
>                  if (userfaultfd_wp(vma))
>                          continue;
> 
>                  /* PTEs were notified when unmapped; but now for the PMD? */
>                  mmu_notifier_range_init(&range, MMU_NOTIFY_CLEAR, 0, mm,
>                                          addr, addr + HPAGE_PMD_SIZE);
>                  mmu_notifier_invalidate_range_start(&range);
> 
>                  pml = pmd_lock(mm, pmd);
> +               if (check_pmd_state(mm, addr, pmd) != SCAN_SUCCEED)
> +                       goto drop_pml;
>                  ptl = pte_lockptr(mm, pmd);
>                  if (ptl != pml)
>                          spin_lock_nested(ptl, SINGLE_DEPTH_NESTING);
> 
>                  /*
>                   * Huge page lock is still held, so normally the page table
>                   * must remain empty; and we have already skipped anon_vma
>                   * and userfaultfd_wp() vmas.  But since the mmap_lock is not
>                   * held, it is still possible for a racing userfaultfd_ioctl()
>                   * to have inserted ptes or markers.  Now that we hold ptlock,
diff mbox series

Patch

diff --git a/mm/khugepaged.c b/mm/khugepaged.c
index 6f8d46d107b4b..6d76dde64f5fb 100644
--- a/mm/khugepaged.c
+++ b/mm/khugepaged.c
@@ -1721,6 +1721,7 @@  static void retract_page_tables(struct address_space *mapping, pgoff_t pgoff)
 		spinlock_t *pml;
 		spinlock_t *ptl;
 		bool skipped_uffd = false;
+		pte_t *pte;
 
 		/*
 		 * Check vma->anon_vma to exclude MAP_PRIVATE mappings that
@@ -1756,11 +1757,25 @@  static void retract_page_tables(struct address_space *mapping, pgoff_t pgoff)
 					addr, addr + HPAGE_PMD_SIZE);
 		mmu_notifier_invalidate_range_start(&range);
 
+		pte = pte_offset_map_rw_nolock(mm, pmd, addr, &pgt_pmd, &ptl);
+		if (!pte) {
+			mmu_notifier_invalidate_range_end(&range);
+			continue;
+		}
+
 		pml = pmd_lock(mm, pmd);
-		ptl = pte_lockptr(mm, pmd);
 		if (ptl != pml)
 			spin_lock_nested(ptl, SINGLE_DEPTH_NESTING);
 
+		if (unlikely(!pmd_same(pgt_pmd, pmdp_get_lockless(pmd)))) {
+			pte_unmap_unlock(pte, ptl);
+			if (ptl != pml)
+				spin_unlock(pml);
+			mmu_notifier_invalidate_range_end(&range);
+			continue;
+		}
+		pte_unmap(pte);
+
 		/*
 		 * Huge page lock is still held, so normally the page table
 		 * must remain empty; and we have already skipped anon_vma