
[RFC,07/12] khugepaged: Scan PTEs order-wise

Message ID 20241216165105.56185-8-dev.jain@arm.com (mailing list archive)
State New
Series khugepaged: Asynchronous mTHP collapse

Commit Message

Dev Jain Dec. 16, 2024, 4:51 p.m. UTC
Scan the PTEs order-wise, using the mask of suitable orders for this VMA
derived in conjunction with sysfs THP settings. Scale down the tunables; in
case of collapse failure, we drop down to the next order. Otherwise, we try to
jump to the highest possible order and then start a fresh scan. Note that
madvise(MADV_COLLAPSE) has not been generalized.

Signed-off-by: Dev Jain <dev.jain@arm.com>
---
 mm/khugepaged.c | 84 ++++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 69 insertions(+), 15 deletions(-)
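
To make the tunable scaling concrete, here is a small standalone illustration
(not kernel code; it assumes 4K pages, so HPAGE_PMD_ORDER is 9, and uses the
current khugepaged defaults for the max_ptes_* tunables) of the per-order
thresholds the scan ends up using:

	#include <stdio.h>

	#define HPAGE_PMD_ORDER	9	/* 2M PMD mapping on 4K pages */
	#define MAX_PTES_NONE	511	/* default khugepaged_max_ptes_none */
	#define MAX_PTES_SWAP	64	/* default khugepaged_max_ptes_swap */
	#define MAX_PTES_SHARED	256	/* default khugepaged_max_ptes_shared */

	int main(void)
	{
		int order, shift;

		/* Each tunable is scaled down by the order gap to PMD_ORDER. */
		for (order = HPAGE_PMD_ORDER; order >= 2; order--) {
			shift = HPAGE_PMD_ORDER - order;
			printf("order %d (%4d KiB): none <= %3d, swap <= %2d, shared <= %3d\n",
			       order, 4 << order,
			       MAX_PTES_NONE >> shift,
			       MAX_PTES_SWAP >> shift,
			       MAX_PTES_SHARED >> shift);
		}
		return 0;
	}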

Comments

Ryan Roberts Dec. 17, 2024, 6:15 p.m. UTC | #1
On 16/12/2024 16:51, Dev Jain wrote:
> Scan the PTEs order-wise, using the mask of suitable orders for this VMA
> derived in conjunction with sysfs THP settings. Scale down the tunables; in
> case of collapse failure, we drop down to the next order. Otherwise, we try to
> jump to the highest possible order and then start a fresh scan. Note that
> madvise(MADV_COLLAPSE) has not been generalized.

Is there a reason you are not modifying MADV_COLLAPSE? It's really just a
synchronous way to do what khugepaged does asynchronously (isn't it?), so it would
behave the same way in an ideal world.

> 
> Signed-off-by: Dev Jain <dev.jain@arm.com>
> ---
>  mm/khugepaged.c | 84 ++++++++++++++++++++++++++++++++++++++++---------
>  1 file changed, 69 insertions(+), 15 deletions(-)
> 
> diff --git a/mm/khugepaged.c b/mm/khugepaged.c
> index 886c76816963..078794aa3335 100644
> --- a/mm/khugepaged.c
> +++ b/mm/khugepaged.c
> @@ -20,6 +20,7 @@
>  #include <linux/swapops.h>
>  #include <linux/shmem_fs.h>
>  #include <linux/ksm.h>
> +#include <linux/count_zeros.h>
>  
>  #include <asm/tlb.h>
>  #include <asm/pgalloc.h>
> @@ -1111,7 +1112,7 @@ static int alloc_charge_folio(struct folio **foliop, struct mm_struct *mm,
>  }
>  
>  static int collapse_huge_page(struct mm_struct *mm, unsigned long address,
> -			      int referenced, int unmapped,
> +			      int referenced, int unmapped, int order,
>  			      struct collapse_control *cc)
>  {
>  	LIST_HEAD(compound_pagelist);
> @@ -1278,38 +1279,59 @@ static int hpage_collapse_scan_ptes(struct mm_struct *mm,
>  				   unsigned long address, bool *mmap_locked,
>  				   struct collapse_control *cc)
>  {
> -	pmd_t *pmd;
> -	pte_t *pte, *_pte;
> -	int result = SCAN_FAIL, referenced = 0;
> -	int none_or_zero = 0, shared = 0;
> -	struct page *page = NULL;
> +	unsigned int max_ptes_shared, max_ptes_none, max_ptes_swap;
> +	int referenced, shared, none_or_zero, unmapped;
> +	unsigned long _address, org_address = address;

nit: Perhaps it's clearer to keep the original address in address and use a
variable, start, for the starting point of each scan?

>  	struct folio *folio = NULL;
> -	unsigned long _address;
> -	spinlock_t *ptl;
> -	int node = NUMA_NO_NODE, unmapped = 0;
> +	struct page *page = NULL;
> +	int node = NUMA_NO_NODE;
> +	int result = SCAN_FAIL;
>  	bool writable = false;
> +	unsigned long orders;
> +	pte_t *pte, *_pte;
> +	spinlock_t *ptl;
> +	pmd_t *pmd;
> +	int order;
>  
>  	VM_BUG_ON(address & ~HPAGE_PMD_MASK);
>  
> +	orders = thp_vma_allowable_orders(vma, vma->vm_flags,
> +			TVA_IN_PF | TVA_ENFORCE_SYSFS, BIT(PMD_ORDER + 1) - 1);

Perhaps THP_ORDERS_ALL instead of "BIT(PMD_ORDER + 1) - 1"?

> +	orders = thp_vma_suitable_orders(vma, address, orders);
> +	order = highest_order(orders);
> +
> +	/* MADV_COLLAPSE needs to work irrespective of sysfs setting */
> +	if (!cc->is_khugepaged)
> +		order = HPAGE_PMD_ORDER;
> +
> +scan_pte_range:
> +
> +	max_ptes_shared = khugepaged_max_ptes_shared >> (HPAGE_PMD_ORDER - order);
> +	max_ptes_none = khugepaged_max_ptes_none >> (HPAGE_PMD_ORDER - order);
> +	max_ptes_swap = khugepaged_max_ptes_swap >> (HPAGE_PMD_ORDER - order);
> +	referenced = 0, shared = 0, none_or_zero = 0, unmapped = 0;
> +
> +	/* Check pmd after taking mmap lock */
>  	result = find_pmd_or_thp_or_none(mm, address, &pmd);
>  	if (result != SCAN_SUCCEED)
>  		goto out;
>  
>  	memset(cc->node_load, 0, sizeof(cc->node_load));
>  	nodes_clear(cc->alloc_nmask);
> +
>  	pte = pte_offset_map_lock(mm, pmd, address, &ptl);
>  	if (!pte) {
>  		result = SCAN_PMD_NULL;
>  		goto out;
>  	}
>  
> -	for (_address = address, _pte = pte; _pte < pte + HPAGE_PMD_NR;
> +	for (_address = address, _pte = pte; _pte < pte + (1UL << order);
>  	     _pte++, _address += PAGE_SIZE) {
>  		pte_t pteval = ptep_get(_pte);
>  		if (is_swap_pte(pteval)) {
>  			++unmapped;
>  			if (!cc->is_khugepaged ||
> -			    unmapped <= khugepaged_max_ptes_swap) {
> +			    unmapped <= max_ptes_swap) {
>  				/*
>  				 * Always be strict with uffd-wp
>  				 * enabled swap entries.  Please see
> @@ -1330,7 +1352,7 @@ static int hpage_collapse_scan_ptes(struct mm_struct *mm,
>  			++none_or_zero;
>  			if (!userfaultfd_armed(vma) &&
>  			    (!cc->is_khugepaged ||
> -			     none_or_zero <= khugepaged_max_ptes_none)) {
> +			     none_or_zero <= max_ptes_none)) {
>  				continue;
>  			} else {
>  				result = SCAN_EXCEED_NONE_PTE;
> @@ -1375,7 +1397,7 @@ static int hpage_collapse_scan_ptes(struct mm_struct *mm,
>  		if (folio_likely_mapped_shared(folio)) {
>  			++shared;
>  			if (cc->is_khugepaged &&
> -			    shared > khugepaged_max_ptes_shared) {
> +			    shared > max_ptes_shared) {
>  				result = SCAN_EXCEED_SHARED_PTE;
>  				count_vm_event(THP_SCAN_EXCEED_SHARED_PTE);
>  				goto out_unmap;
> @@ -1432,7 +1454,7 @@ static int hpage_collapse_scan_ptes(struct mm_struct *mm,
>  		result = SCAN_PAGE_RO;
>  	} else if (cc->is_khugepaged &&
>  		   (!referenced ||
> -		    (unmapped && referenced < HPAGE_PMD_NR / 2))) {
> +		    (unmapped && referenced < (1UL << order) / 2))) {
>  		result = SCAN_LACK_REFERENCED_PAGE;
>  	} else {
>  		result = SCAN_SUCCEED;
> @@ -1441,9 +1463,41 @@ static int hpage_collapse_scan_ptes(struct mm_struct *mm,
>  	pte_unmap_unlock(pte, ptl);
>  	if (result == SCAN_SUCCEED) {
>  		result = collapse_huge_page(mm, address, referenced,
> -					    unmapped, cc);
> +					    unmapped, order, cc);
>  		/* collapse_huge_page will return with the mmap_lock released */
>  		*mmap_locked = false;
> +
> +		/* Immediately exit on exhaustion of range */
> +		if (_address == org_address + (PAGE_SIZE << HPAGE_PMD_ORDER))
> +			goto out;

Looks like this assumes this function is always asked to scan a full PTE table?
Does that mean that you can't handle collapse for VMAs that don't span a whole
PMD entry? I think we will want to support that.

> +	}
> +	if (result != SCAN_SUCCEED) {
> +
> +		/* Go to the next order. */
> +		order = next_order(&orders, order);
> +		if (order < 2)

This should be:
		if (!orders)

I think the return order is undefined when order is the last order in orders.

> +			goto out;
> +		goto maybe_mmap_lock;
> +	} else {
> +		address = _address;
> +		pte = _pte;
> +
> +
> +		/* Get highest order possible starting from address */
> +		order = count_trailing_zeros(address >> PAGE_SHIFT);
> +
> +		/* This needs to be present in the mask too */
> +		if (!(orders & (1UL << order)))
> +			order = next_order(&orders, order);

Not quite; if the exact order isn't in the bitmap, this will pick out the
highest order in the bitmap, which may be higher than count_trailing_zeros()
returned. You could do:

		order = count_trailing_zeros(address >> PAGE_SHIFT);
		orders &= (1UL << order + 1) - 1;
		order = next_order(&orders, order);
		if (!orders)
			goto out;

That will mask out any orders that are bigger than the one returned by
count_trailing_zeros() then next_order() will return the highest order in the
remaining set.

But even that doesn't quite work because next_order() is destructive. Once you
arrive on a higher order address boundary, you want to be able to select a
higher order from the original orders bitmap. But you have lost them on a
previous go around the loop.

Perhaps stash orig_orders at the top of the function when you first calculate
it. Then I think this works (totally untested):

		order = count_trailing_zeros(address >> PAGE_SHIFT);
		orders = orig_orders & (1UL << order + 1) - 1;
		order = next_order(&orders, order);
		if (!orders)
			goto out;

You might want to do something like this for the first go around the loop, but I
think address is currently always at the start of the PMD on entry, so not
needed until that restriction is removed.

> +		if (order < 2)
> +			goto out;
> +
> +maybe_mmap_lock:
> +		if (!(*mmap_locked)) {
> +			mmap_read_lock(mm);

Given the lock was already held in read mode on entering this function, then
released by collapse_huge_page(), is it definitely safe to retake this lock and
rerun this function? Is it possible that state that was checked before entering
this function has changed since the lock was released that would now need
re-checking?

> +			*mmap_locked = true;
> +		}
> +		goto scan_pte_range;
>  	}
>  out:
>  	trace_mm_khugepaged_scan_pmd(mm, &folio->page, writable, referenced,
Dev Jain Dec. 18, 2024, 9:24 a.m. UTC | #2
On 17/12/24 11:45 pm, Ryan Roberts wrote:
> On 16/12/2024 16:51, Dev Jain wrote:
>> Scan the PTEs order-wise, using the mask of suitable orders for this VMA
>> derived in conjunction with sysfs THP settings. Scale down the tunables; in
>> case of collapse failure, we drop down to the next order. Otherwise, we try to
>> jump to the highest possible order and then start a fresh scan. Note that
>> madvise(MADV_COLLAPSE) has not been generalized.
> Is there a reason you are not modifying MADV_COLLAPSE? It's really just a
> synchronous way to do what khugepaged does asynchronously (isn't it?), so it would
> behave the same way in an ideal world.

Correct, but I started running into return-value problems for madvise(). For example,
the return value of hpage_collapse_scan_ptes() will be the return value of the last
mTHP scan. In that case, what do we want madvise() to return? If I collapse the range
into multiple 64K mTHPs, then I should still return failure, because otherwise the
caller would logically assume that MADV_COLLAPSE succeeded and therefore that a
PMD hugepage is mapped there. But then we report failure even though the caller's memory
did get collapsed; and if we return success instead, the khugepaged selftest starts
failing. Basically, this will be (kind of?) an ABI change and I really didn't want to
sway the discussion away from khugepaged, so I just kept it simple :)

>
>> Signed-off-by: Dev Jain <dev.jain@arm.com>
>> ---
>>   mm/khugepaged.c | 84 ++++++++++++++++++++++++++++++++++++++++---------
>>   1 file changed, 69 insertions(+), 15 deletions(-)
>>
>> diff --git a/mm/khugepaged.c b/mm/khugepaged.c
>> index 886c76816963..078794aa3335 100644
>> --- a/mm/khugepaged.c
>> +++ b/mm/khugepaged.c
>> @@ -20,6 +20,7 @@
>>   #include <linux/swapops.h>
>>   #include <linux/shmem_fs.h>
>>   #include <linux/ksm.h>
>> +#include <linux/count_zeros.h>
>>   
>>   #include <asm/tlb.h>
>>   #include <asm/pgalloc.h>
>> @@ -1111,7 +1112,7 @@ static int alloc_charge_folio(struct folio **foliop, struct mm_struct *mm,
>>   }
>>   
>>   static int collapse_huge_page(struct mm_struct *mm, unsigned long address,
>> -			      int referenced, int unmapped,
>> +			      int referenced, int unmapped, int order,
>>   			      struct collapse_control *cc)
>>   {
>>   	LIST_HEAD(compound_pagelist);
>> @@ -1278,38 +1279,59 @@ static int hpage_collapse_scan_ptes(struct mm_struct *mm,
>>   				   unsigned long address, bool *mmap_locked,
>>   				   struct collapse_control *cc)
>>   {
>> -	pmd_t *pmd;
>> -	pte_t *pte, *_pte;
>> -	int result = SCAN_FAIL, referenced = 0;
>> -	int none_or_zero = 0, shared = 0;
>> -	struct page *page = NULL;
>> +	unsigned int max_ptes_shared, max_ptes_none, max_ptes_swap;
>> +	int referenced, shared, none_or_zero, unmapped;
>> +	unsigned long _address, org_address = address;
> nit: Perhaps it's clearer to keep the original address in address and use a
> variable, start, for the starting point of each scan?

Probably...will keep it in mind.

>
>>   	struct folio *folio = NULL;
>> -	unsigned long _address;
>> -	spinlock_t *ptl;
>> -	int node = NUMA_NO_NODE, unmapped = 0;
>> +	struct page *page = NULL;
>> +	int node = NUMA_NO_NODE;
>> +	int result = SCAN_FAIL;
>>   	bool writable = false;
>> +	unsigned long orders;
>> +	pte_t *pte, *_pte;
>> +	spinlock_t *ptl;
>> +	pmd_t *pmd;
>> +	int order;
>>   
>>   	VM_BUG_ON(address & ~HPAGE_PMD_MASK);
>>   
>> +	orders = thp_vma_allowable_orders(vma, vma->vm_flags,
>> +			TVA_IN_PF | TVA_ENFORCE_SYSFS, BIT(PMD_ORDER + 1) - 1);
> Perhaps THP_ORDERS_ALL instead of "BIT(PMD_ORDER + 1) - 1"?

Ah yes, THP_ORDERS_ALL_ANON.
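
For reference, that mask is defined in huge_mm.h as (roughly):

	/* All anon THP orders: everything up to PMD_ORDER except orders 0 and 1 */
	#define THP_ORDERS_ALL_ANON	((BIT(PMD_ORDER + 1) - 1) & ~(BIT(0) | BIT(1)))

which also lines up with the "order < 2" checks used in this patch.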

>
>> +	orders = thp_vma_suitable_orders(vma, address, orders);
>> +	order = highest_order(orders);
>> +
>> +	/* MADV_COLLAPSE needs to work irrespective of sysfs setting */
>> +	if (!cc->is_khugepaged)
>> +		order = HPAGE_PMD_ORDER;
>> +
>> +scan_pte_range:
>> +
>> +	max_ptes_shared = khugepaged_max_ptes_shared >> (HPAGE_PMD_ORDER - order);
>> +	max_ptes_none = khugepaged_max_ptes_none >> (HPAGE_PMD_ORDER - order);
>> +	max_ptes_swap = khugepaged_max_ptes_swap >> (HPAGE_PMD_ORDER - order);
>> +	referenced = 0, shared = 0, none_or_zero = 0, unmapped = 0;
>> +
>> +	/* Check pmd after taking mmap lock */
>>   	result = find_pmd_or_thp_or_none(mm, address, &pmd);
>>   	if (result != SCAN_SUCCEED)
>>   		goto out;
>>   
>>   	memset(cc->node_load, 0, sizeof(cc->node_load));
>>   	nodes_clear(cc->alloc_nmask);
>> +
>>   	pte = pte_offset_map_lock(mm, pmd, address, &ptl);
>>   	if (!pte) {
>>   		result = SCAN_PMD_NULL;
>>   		goto out;
>>   	}
>>   
>> -	for (_address = address, _pte = pte; _pte < pte + HPAGE_PMD_NR;
>> +	for (_address = address, _pte = pte; _pte < pte + (1UL << order);
>>   	     _pte++, _address += PAGE_SIZE) {
>>   		pte_t pteval = ptep_get(_pte);
>>   		if (is_swap_pte(pteval)) {
>>   			++unmapped;
>>   			if (!cc->is_khugepaged ||
>> -			    unmapped <= khugepaged_max_ptes_swap) {
>> +			    unmapped <= max_ptes_swap) {
>>   				/*
>>   				 * Always be strict with uffd-wp
>>   				 * enabled swap entries.  Please see
>> @@ -1330,7 +1352,7 @@ static int hpage_collapse_scan_ptes(struct mm_struct *mm,
>>   			++none_or_zero;
>>   			if (!userfaultfd_armed(vma) &&
>>   			    (!cc->is_khugepaged ||
>> -			     none_or_zero <= khugepaged_max_ptes_none)) {
>> +			     none_or_zero <= max_ptes_none)) {
>>   				continue;
>>   			} else {
>>   				result = SCAN_EXCEED_NONE_PTE;
>> @@ -1375,7 +1397,7 @@ static int hpage_collapse_scan_ptes(struct mm_struct *mm,
>>   		if (folio_likely_mapped_shared(folio)) {
>>   			++shared;
>>   			if (cc->is_khugepaged &&
>> -			    shared > khugepaged_max_ptes_shared) {
>> +			    shared > max_ptes_shared) {
>>   				result = SCAN_EXCEED_SHARED_PTE;
>>   				count_vm_event(THP_SCAN_EXCEED_SHARED_PTE);
>>   				goto out_unmap;
>> @@ -1432,7 +1454,7 @@ static int hpage_collapse_scan_ptes(struct mm_struct *mm,
>>   		result = SCAN_PAGE_RO;
>>   	} else if (cc->is_khugepaged &&
>>   		   (!referenced ||
>> -		    (unmapped && referenced < HPAGE_PMD_NR / 2))) {
>> +		    (unmapped && referenced < (1UL << order) / 2))) {
>>   		result = SCAN_LACK_REFERENCED_PAGE;
>>   	} else {
>>   		result = SCAN_SUCCEED;
>> @@ -1441,9 +1463,41 @@ static int hpage_collapse_scan_ptes(struct mm_struct *mm,
>>   	pte_unmap_unlock(pte, ptl);
>>   	if (result == SCAN_SUCCEED) {
>>   		result = collapse_huge_page(mm, address, referenced,
>> -					    unmapped, cc);
>> +					    unmapped, order, cc);
>>   		/* collapse_huge_page will return with the mmap_lock released */
>>   		*mmap_locked = false;
>> +
>> +		/* Immediately exit on exhaustion of range */
>> +		if (_address == org_address + (PAGE_SIZE << HPAGE_PMD_ORDER))
>> +			goto out;
> Looks like this assumes this function is always asked to scan a full PTE table?
> Does that mean that you can't handle collapse for VMAs that don't span a whole
> PMD entry? I think we will want to support that.

Correct. Yes, we should support that, otherwise khugepaged will scan only large VMAs.
We will have to make a change in khugepaged_scan_mm_slot() for the anon case.

>
>> +	}
>> +	if (result != SCAN_SUCCEED) {
>> +
>> +		/* Go to the next order. */
>> +		order = next_order(&orders, order);
>> +		if (order < 2)
> This should be:
> 		if (!orders)
>
> I think the return order is undefined when order is the last order in orders.

The returned order is -1, from what I could gather reading the code.
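
For reference, the huge_mm.h helpers are (roughly):

	static inline unsigned long highest_order(unsigned long orders)
	{
		return fls_long(orders) - 1;
	}

	static inline int next_order(unsigned long *orders, int prev)
	{
		*orders &= ~BIT(prev);
		return highest_order(*orders);
	}

so once the last order has been cleared from the bitmap, next_order() does come
back as -1; checking !orders as suggested is the more explicit test.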

>
>> +			goto out;
>> +		goto maybe_mmap_lock;
>> +	} else {
>> +		address = _address;
>> +		pte = _pte;
>> +
>> +
>> +		/* Get highest order possible starting from address */
>> +		order = count_trailing_zeros(address >> PAGE_SHIFT);
>> +
>> +		/* This needs to be present in the mask too */
>> +		if (!(orders & (1UL << order)))
>> +			order = next_order(&orders, order);
> Not quite; if the exact order isn't in the bitmap, this will pick out the
> highest order in the bitmap, which may be higher than count_trailing_zeros()
> returned.

Oh okay, nice catch.

>   You could do:
>
> 		order = count_trailing_zeros(address >> PAGE_SHIFT);
> 		orders &= (1UL << order + 1) - 1;
> 		order = next_order(&orders, order);
> 		if (!orders)
> 			goto out;
>
> That will mask out any orders that are bigger than the one returned by
> count_trailing_zeros() then next_order() will return the highest order in the
> remaining set.
>
> But even that doesn't quite work because next_order() is destructive. Once you
> arrive on a higher order address boundary, you want to be able to select a
> higher order from the original orders bitmap. But you have lost them on a
> previous go around the loop.
>
> Perhaps stash orig_orders at the top of the function when you first calculate
> it. Then I think this works (totally untested):
>
> 		order = count_trailing_zeros(address >> PAGE_SHIFT);
> 		orders = orig_orders & (1UL << order + 1) - 1;
> 		order = next_order(&orders, order);
> 		if (!orders)
> 			goto out;
>
> You might want to do something like this for the first go around the loop, but I
> think address is currently always at the start of the PMD on entry, so not
> needed until that restriction is removed.

Will take a look. We just need an order <= the order derived from the trailing
zeroes, and then the highest enabled order at or below it in the bitmask;
shouldn't be too complicated.
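
i.e. something like the following (untested), assuming orig_orders is stashed
where the orders bitmap is first computed:

		/* Largest order permitted by the alignment of address */
		order = count_trailing_zeros(address >> PAGE_SHIFT);
		/* Clamp so the mask below stays well-defined */
		order = min(order, HPAGE_PMD_ORDER);
		/* Restart from the original bitmap, keeping only orders <= that */
		orders = orig_orders & ((1UL << (order + 1)) - 1);
		if (!orders)
			goto out;
		order = highest_order(orders);
		if (order < 2)
			goto out;

Restarting from orig_orders each time means higher orders cleared by earlier
next_order() calls are not lost, and an exactly-aligned order that is allowed
gets picked directly instead of being skipped.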

>
>> +		if (order < 2)
>> +			goto out;
>> +
>> +maybe_mmap_lock:
>> +		if (!(*mmap_locked)) {
>> +			mmap_read_lock(mm);
> Given the lock was already held in read mode on entering this function, then
> released by collapse_huge_page(), is it definitely safe to retake this lock and
> rerun this function? Is it possible that state that was checked before entering
> this function has changed since the lock was released that would now need
> re-checking?

Thanks, what I am missing is a hugepage_vma_revalidate().
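
Something along these lines (untested), mirroring how collapse_huge_page()
revalidates after re-acquiring the lock; the hugepage_vma_revalidate()
arguments follow its existing use there:

	maybe_mmap_lock:
		if (!(*mmap_locked)) {
			mmap_read_lock(mm);
			*mmap_locked = true;
			/* The lock was dropped; the VMA may have changed or gone. */
			result = hugepage_vma_revalidate(mm, address, true, &vma, cc);
			if (result != SCAN_SUCCEED)
				goto out;
		}
		goto scan_pte_range;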

>> +			*mmap_locked = true;
>> +		}
>> +		goto scan_pte_range;
>>   	}
>>   out:
>>   	trace_mm_khugepaged_scan_pmd(mm, &folio->page, writable, referenced,

Patch

diff --git a/mm/khugepaged.c b/mm/khugepaged.c
index 886c76816963..078794aa3335 100644
--- a/mm/khugepaged.c
+++ b/mm/khugepaged.c
@@ -20,6 +20,7 @@ 
 #include <linux/swapops.h>
 #include <linux/shmem_fs.h>
 #include <linux/ksm.h>
+#include <linux/count_zeros.h>
 
 #include <asm/tlb.h>
 #include <asm/pgalloc.h>
@@ -1111,7 +1112,7 @@  static int alloc_charge_folio(struct folio **foliop, struct mm_struct *mm,
 }
 
 static int collapse_huge_page(struct mm_struct *mm, unsigned long address,
-			      int referenced, int unmapped,
+			      int referenced, int unmapped, int order,
 			      struct collapse_control *cc)
 {
 	LIST_HEAD(compound_pagelist);
@@ -1278,38 +1279,59 @@  static int hpage_collapse_scan_ptes(struct mm_struct *mm,
 				   unsigned long address, bool *mmap_locked,
 				   struct collapse_control *cc)
 {
-	pmd_t *pmd;
-	pte_t *pte, *_pte;
-	int result = SCAN_FAIL, referenced = 0;
-	int none_or_zero = 0, shared = 0;
-	struct page *page = NULL;
+	unsigned int max_ptes_shared, max_ptes_none, max_ptes_swap;
+	int referenced, shared, none_or_zero, unmapped;
+	unsigned long _address, org_address = address;
 	struct folio *folio = NULL;
-	unsigned long _address;
-	spinlock_t *ptl;
-	int node = NUMA_NO_NODE, unmapped = 0;
+	struct page *page = NULL;
+	int node = NUMA_NO_NODE;
+	int result = SCAN_FAIL;
 	bool writable = false;
+	unsigned long orders;
+	pte_t *pte, *_pte;
+	spinlock_t *ptl;
+	pmd_t *pmd;
+	int order;
 
 	VM_BUG_ON(address & ~HPAGE_PMD_MASK);
 
+	orders = thp_vma_allowable_orders(vma, vma->vm_flags,
+			TVA_IN_PF | TVA_ENFORCE_SYSFS, BIT(PMD_ORDER + 1) - 1);
+	orders = thp_vma_suitable_orders(vma, address, orders);
+	order = highest_order(orders);
+
+	/* MADV_COLLAPSE needs to work irrespective of sysfs setting */
+	if (!cc->is_khugepaged)
+		order = HPAGE_PMD_ORDER;
+
+scan_pte_range:
+
+	max_ptes_shared = khugepaged_max_ptes_shared >> (HPAGE_PMD_ORDER - order);
+	max_ptes_none = khugepaged_max_ptes_none >> (HPAGE_PMD_ORDER - order);
+	max_ptes_swap = khugepaged_max_ptes_swap >> (HPAGE_PMD_ORDER - order);
+	referenced = 0, shared = 0, none_or_zero = 0, unmapped = 0;
+
+	/* Check pmd after taking mmap lock */
 	result = find_pmd_or_thp_or_none(mm, address, &pmd);
 	if (result != SCAN_SUCCEED)
 		goto out;
 
 	memset(cc->node_load, 0, sizeof(cc->node_load));
 	nodes_clear(cc->alloc_nmask);
+
 	pte = pte_offset_map_lock(mm, pmd, address, &ptl);
 	if (!pte) {
 		result = SCAN_PMD_NULL;
 		goto out;
 	}
 
-	for (_address = address, _pte = pte; _pte < pte + HPAGE_PMD_NR;
+	for (_address = address, _pte = pte; _pte < pte + (1UL << order);
 	     _pte++, _address += PAGE_SIZE) {
 		pte_t pteval = ptep_get(_pte);
 		if (is_swap_pte(pteval)) {
 			++unmapped;
 			if (!cc->is_khugepaged ||
-			    unmapped <= khugepaged_max_ptes_swap) {
+			    unmapped <= max_ptes_swap) {
 				/*
 				 * Always be strict with uffd-wp
 				 * enabled swap entries.  Please see
@@ -1330,7 +1352,7 @@  static int hpage_collapse_scan_ptes(struct mm_struct *mm,
 			++none_or_zero;
 			if (!userfaultfd_armed(vma) &&
 			    (!cc->is_khugepaged ||
-			     none_or_zero <= khugepaged_max_ptes_none)) {
+			     none_or_zero <= max_ptes_none)) {
 				continue;
 			} else {
 				result = SCAN_EXCEED_NONE_PTE;
@@ -1375,7 +1397,7 @@  static int hpage_collapse_scan_ptes(struct mm_struct *mm,
 		if (folio_likely_mapped_shared(folio)) {
 			++shared;
 			if (cc->is_khugepaged &&
-			    shared > khugepaged_max_ptes_shared) {
+			    shared > max_ptes_shared) {
 				result = SCAN_EXCEED_SHARED_PTE;
 				count_vm_event(THP_SCAN_EXCEED_SHARED_PTE);
 				goto out_unmap;
@@ -1432,7 +1454,7 @@  static int hpage_collapse_scan_ptes(struct mm_struct *mm,
 		result = SCAN_PAGE_RO;
 	} else if (cc->is_khugepaged &&
 		   (!referenced ||
-		    (unmapped && referenced < HPAGE_PMD_NR / 2))) {
+		    (unmapped && referenced < (1UL << order) / 2))) {
 		result = SCAN_LACK_REFERENCED_PAGE;
 	} else {
 		result = SCAN_SUCCEED;
@@ -1441,9 +1463,41 @@  static int hpage_collapse_scan_ptes(struct mm_struct *mm,
 	pte_unmap_unlock(pte, ptl);
 	if (result == SCAN_SUCCEED) {
 		result = collapse_huge_page(mm, address, referenced,
-					    unmapped, cc);
+					    unmapped, order, cc);
 		/* collapse_huge_page will return with the mmap_lock released */
 		*mmap_locked = false;
+
+		/* Immediately exit on exhaustion of range */
+		if (_address == org_address + (PAGE_SIZE << HPAGE_PMD_ORDER))
+			goto out;
+	}
+	if (result != SCAN_SUCCEED) {
+
+		/* Go to the next order. */
+		order = next_order(&orders, order);
+		if (order < 2)
+			goto out;
+		goto maybe_mmap_lock;
+	} else {
+		address = _address;
+		pte = _pte;
+
+
+		/* Get highest order possible starting from address */
+		order = count_trailing_zeros(address >> PAGE_SHIFT);
+
+		/* This needs to be present in the mask too */
+		if (!(orders & (1UL << order)))
+			order = next_order(&orders, order);
+		if (order < 2)
+			goto out;
+
+maybe_mmap_lock:
+		if (!(*mmap_locked)) {
+			mmap_read_lock(mm);
+			*mmap_locked = true;
+		}
+		goto scan_pte_range;
 	}
 out:
 	trace_mm_khugepaged_scan_pmd(mm, &folio->page, writable, referenced,