@@ -577,22 +577,42 @@ static int hmm_vma_walk_pmd(pmd_t *pmdp,
{
struct hmm_vma_walk *hmm_vma_walk = walk->private;
struct hmm_range *range = hmm_vma_walk->range;
+ struct vm_area_struct *vma = walk->vma;
uint64_t *pfns = range->pfns;
unsigned long addr = start, i;
pte_t *ptep;
+ pmd_t pmd;
- i = (addr - range->start) >> PAGE_SHIFT;
again:
- if (pmd_none(*pmdp))
+ pmd = READ_ONCE(*pmdp);
+ if (pmd_none(pmd))
return hmm_vma_walk_hole(start, end, walk);
- if (pmd_huge(*pmdp) && (range->vma->vm_flags & VM_HUGETLB))
+ if (pmd_huge(pmd) && (range->vma->vm_flags & VM_HUGETLB))
return hmm_pfns_bad(start, end, walk);
- if (pmd_devmap(*pmdp) || pmd_trans_huge(*pmdp)) {
- pmd_t pmd;
+ if (thp_migration_supported() && is_pmd_migration_entry(pmd)) {
+ bool fault, write_fault;
+ unsigned long npages;
+ uint64_t *pfns;
+
+ i = (addr - range->start) >> PAGE_SHIFT;
+ npages = (end - addr) >> PAGE_SHIFT;
+ pfns = &range->pfns[i];
+
+ hmm_range_need_fault(hmm_vma_walk, pfns, npages,
+ 0, &fault, &write_fault);
+ if (fault || write_fault) {
+ hmm_vma_walk->last = addr;
+ pmd_migration_entry_wait(vma->vm_mm, pmdp);
+ return -EAGAIN;
+ }
+ return 0;
+ } else if (!pmd_present(pmd))
+ return hmm_pfns_bad(start, end, walk);
+ if (pmd_devmap(pmd) || pmd_trans_huge(pmd)) {
/*
* No need to take pmd_lock here, even if some other threads
* is splitting the huge pmd we will get that event through
@@ -607,13 +627,21 @@ static int hmm_vma_walk_pmd(pmd_t *pmdp,
if (!pmd_devmap(pmd) && !pmd_trans_huge(pmd))
goto again;
+ i = (addr - range->start) >> PAGE_SHIFT;
return hmm_vma_handle_pmd(walk, addr, end, &pfns[i], pmd);
}
- if (pmd_bad(*pmdp))
+ /*
+ * We have handled all the valid case above ie either none, migration,
+ * huge or transparent huge. At this point either it is a valid pmd
+ * entry pointing to pte directory or it is a bad pmd that will not
+ * recover.
+ */
+ if (pmd_bad(pmd))
return hmm_pfns_bad(start, end, walk);
ptep = pte_offset_map(pmdp, addr);
+ i = (addr - range->start) >> PAGE_SHIFT;
for (; addr < end; addr += PAGE_SIZE, ptep++, i++) {
int r;