
[RFC] mm/hugetlb_vmemmap: fix race with speculative PFN walkers

Message ID 20240621213717.1099079-1-yuzhao@google.com
State New
Series [RFC] mm/hugetlb_vmemmap: fix race with speculative PFN walkers

Commit Message

Yu Zhao June 21, 2024, 9:37 p.m. UTC
While investigating HVO for THPs [1], it turned out that speculative
PFN walkers like compaction can race with vmemmap modifications,
e.g.,

  CPU 1 (vmemmap modifier)         CPU 2 (speculative PFN walker)
  -----------------------------    ------------------------------
  Allocates an LRU folio page1
                                   Sees page1
  Frees page1

  Allocates a hugeTLB folio page2
  (page1 being a tail of page2)

  Updates vmemmap mapping page1
                                   get_page_unless_zero(page1)

Even though page1 has a zero refcnt after HVO, get_page_unless_zero()
can still try to modify its read-only struct page, resulting in a
crash.

An independent report [2] confirmed this race.

There are two discussed approaches to fix this race:
1. Make RO vmemmap RW so that get_page_unless_zero() can fail without
   triggering a PF.
2. Use RCU to make sure get_page_unless_zero() either sees zero
   refcnts through the old vmemmap or non-zero refcnts through the new
   one.
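Approach 2 can be sketched in userspace as follows. This is a minimal single-file model under stated assumptions, not the kernel implementation: all `toy_*` names are hypothetical, the `fake_head` flag loosely stands in for page_is_fake_head(), and the reader-count "grace period" is a toy stand-in for rcu_read_lock()/synchronize_rcu() that relies on C11 sequentially consistent atomics. What it shows is the ordering: the reader performs its check and its atomic write inside one read-side section, and the updater, whose refcount is already zero, waits for every such section to finish before it remaps.

```c
#include <assert.h>
#include <stdatomic.h>
#include <stdbool.h>

/* Hypothetical stand-in for struct page: only the fields this sketch needs. */
struct toy_page {
	atomic_int refcount;
	atomic_bool fake_head;	/* loosely models page_is_fake_head() */
};

/*
 * Toy grace-period tracking: the number of readers currently inside a
 * read-side section. This is NOT how kernel RCU works.
 */
static atomic_int toy_readers;

static void toy_rcu_read_lock(void)   { atomic_fetch_add(&toy_readers, 1); }
static void toy_rcu_read_unlock(void) { atomic_fetch_sub(&toy_readers, 1); }

/* Wait until no reader is inside a read-side section. */
static void toy_synchronize_rcu(void)
{
	while (atomic_load(&toy_readers) != 0)
		;	/* spin; the real synchronize_rcu() sleeps instead */
}

/*
 * Mirrors the patched page_ref_add_unless(): skip fake heads and values
 * equal to u, then add nr unless the count is u, all inside one
 * read-side section.
 */
static bool toy_page_ref_add_unless(struct toy_page *page, int nr, int u)
{
	bool ret = false;

	toy_rcu_read_lock();
	if (!atomic_load(&page->fake_head)) {
		int c = atomic_load(&page->refcount);

		/* CAS loop with the semantics of atomic_add_unless() */
		while (c != u) {
			if (atomic_compare_exchange_weak(&page->refcount, &c, c + nr)) {
				ret = true;
				break;
			}
			/* on failure, c now holds the current value; recheck it */
		}
	}
	toy_rcu_read_unlock();
	return ret;
}

/* Updater side: refcount must already be zero, as the patch warns. */
static void toy_remap(struct toy_page *page)
{
	assert(atomic_load(&page->refcount) == 0);
	toy_synchronize_rcu();			/* wait out in-flight readers */
	atomic_store(&page->fake_head, true);	/* "remap" the page */
}
```

Unlike real RCU, this toy lets a steady stream of readers starve the updater; it only illustrates why the check and the write need to sit in the same read-side critical section that the updater waits on.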

The second approach is preferred here because:
1. It can prevent illegal modifications to struct page[] optimized by HVO;
2. It can be generalized, in a way similar to ZERO_PAGE(), to fix
   similar races in other places, e.g., arch_remove_memory() on x86
   [3], which frees the vmemmap mapping offlined struct page[].

While adding synchronize_rcu(), the goal is to be surgical rather
than optimized. Specifically, the synchronize_rcu() calls on the error
handling paths could be coalesced, but this is not done for the sake
of simplicity: notably, this fix removes ~50% more lines than it adds.

[1] https://lore.kernel.org/20240229183436.4110845-4-yuzhao@google.com/
[2] https://lore.kernel.org/917FFC7F-0615-44DD-90EE-9F85F8EA9974@linux.dev/
[3] https://lore.kernel.org/be130a96-a27e-4240-ad78-776802f57cad@redhat.com/

Signed-off-by: Yu Zhao <yuzhao@google.com>
---
 include/linux/page_ref.h |  8 ++++++-
 mm/hugetlb.c             | 50 +++++-----------------------------------
 mm/hugetlb_vmemmap.c     | 16 +++++++++++++
 3 files changed, 29 insertions(+), 45 deletions(-)


base-commit: 264efe488fd82cf3145a3dc625f394c61db99934
prerequisite-patch-id: 5029fb66d9bf40b84903a5b4f066e85101169e84
prerequisite-patch-id: 7889e5ee16b8e91cccde12468f1d2c3f65500336
prerequisite-patch-id: 0d4c19afc7b92f16bee9e9cf2b6832406389742a
prerequisite-patch-id: c56f06d4bb3e738aea489ec30313ed0c1dbac325

Comments

Muchun Song June 26, 2024, 2:37 a.m. UTC | #1
On 2024/6/22 05:37, Yu Zhao wrote:
> While investigating HVO for THPs [1], it turns out that speculative
> PFN walkers like compaction can race with vmemmap modifications,
> e.g.,
>
>    CPU 1 (vmemmap modifier)         CPU 2 (speculative PFN walker)
>    -----------------------------    ------------------------------
>    Allocates an LRU folio page1
>                                     Sees page1
>    Frees page1
>
>    Allocates a hugeTLB folio page2
>    (page1 being a tail of page2)
>
>    Updates vmemmap mapping page1
>                                     get_page_unless_zero(page1)
>
> Even though page1 has a zero refcnt after HVO, get_page_unless_zero()
> can still try to modify its read-only struct page resulting in a
> crash.
>
> An independent report [2] confirmed this race.

Right. Thanks for your continuous focus on this race.

> @@ -3079,11 +3045,8 @@ static int alloc_and_dissolve_hugetlb_folio(struct hstate *h,
>   
>   free_new:
>   	spin_unlock_irq(&hugetlb_lock);
> -	if (new_folio) {
> -		/* Folio has a zero ref count, but needs a ref to be freed */
> -		folio_ref_unfreeze(new_folio, 1);
> +	if (new_folio)
>   		update_and_free_hugetlb_folio(h, new_folio, false);
> -	}

Looking into this function, we have:

dissolve_free_huge_page
retry:
     if (!folio_test_hugetlb(folio))
         return;
     if (!folio_ref_count(folio))
         if (unlikely(!folio_test_hugetlb_freed(folio)))
             goto retry;
     remove_hugetlb_folio(h, folio, false);

Since you no longer raise the refcount in remove_hugetlb_folio(), a
concurrent dissolve_free_huge_page() can dissolve this page a second
time, and the statistics (e.g., ->nr_huge_pages) will then be wrong.
A solution seems easy: we should call folio_clear_hugetlb_freed() in
remove_hugetlb_folio().

Muchun,
Thanks.
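The double-dissolve scenario described above can be replayed with a small single-threaded model. Everything here is a hypothetical sketch: `struct toy_folio`, `toy_dissolve()` and `toy_remove_hugetlb_folio()` only mirror the checks Muchun lists, a non-zero refcount is treated as busy, the retry path simply returns false, and the folio is assumed to be vmemmap-optimized so that remove_hugetlb_folio() leaves the hugetlb flag set. With `fix` false, a second dissolve passes every check and decrements the counter again; with `fix` true, clearing the freed flag makes the second attempt back off.

```c
#include <assert.h>
#include <stdbool.h>

/* Hypothetical model of the state dissolve_free_huge_page() inspects. */
struct toy_folio {
	bool hugetlb;	/* folio_test_hugetlb() */
	bool freed;	/* folio_test_hugetlb_freed() */
	int refcount;	/* folio_ref_count() */
};

/* Stands in for h->nr_huge_pages. */
static long toy_nr_huge_pages;

/*
 * Model of remove_hugetlb_folio() with this RFC applied: the refcount is
 * left at zero, and (assuming HVO) the hugetlb flag stays set. With
 * fix == true, the freed flag is also cleared, as suggested above.
 */
static void toy_remove_hugetlb_folio(struct toy_folio *folio, bool fix)
{
	assert(folio->refcount == 0);
	if (fix)
		folio->freed = false;
	toy_nr_huge_pages--;
}

/* Returns true if this call removed the folio from the pool. */
static bool toy_dissolve(struct toy_folio *folio, bool fix)
{
	if (!folio->hugetlb)
		return false;
	if (folio->refcount != 0)
		return false;	/* in use: treated as busy */
	if (!folio->freed)
		return false;	/* the real code retries; the model backs off */
	toy_remove_hugetlb_folio(folio, fix);
	return true;
}
```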
Yu Zhao June 27, 2024, 4:47 a.m. UTC | #2
On Tue, Jun 25, 2024 at 8:38 PM Muchun Song <muchun.song@linux.dev> wrote:
> Looking into this function, we have:
>
> dissolve_free_huge_page
> retry:
>      if (!folio_test_hugetlb(folio))
>          return;
>      if (!folio_ref_count(folio))
>          if (unlikely(!folio_test_hugetlb_freed(folio)))
>              goto retry;
>      remove_hugetlb_folio(h, folio, false);
>
> Since you no longer raise the refcount in remove_hugetlb_folio(), a
> concurrent dissolve_free_huge_page() can dissolve this page a second
> time, and the statistics (e.g., ->nr_huge_pages) will then be wrong.

Thanks for pointing this out!

> A solution seems easy: we should call folio_clear_hugetlb_freed() in
> remove_hugetlb_folio().

Agreed.

Patch

diff --git a/include/linux/page_ref.h b/include/linux/page_ref.h
index 1acf5bac7f50..add92e8f31b2 100644
--- a/include/linux/page_ref.h
+++ b/include/linux/page_ref.h
@@ -230,7 +230,13 @@  static inline int folio_ref_dec_return(struct folio *folio)
 
 static inline bool page_ref_add_unless(struct page *page, int nr, int u)
 {
-	bool ret = atomic_add_unless(&page->_refcount, nr, u);
+	bool ret = false;
+
+	rcu_read_lock();
+	/* avoid writing to the vmemmap area being remapped */
+	if (!page_is_fake_head(page) && page_ref_count(page) != u)
+		ret = atomic_add_unless(&page->_refcount, nr, u);
+	rcu_read_unlock();
 
 	if (page_ref_tracepoint_active(page_ref_mod_unless))
 		__page_ref_mod_unless(page, nr, ret);
diff --git a/mm/hugetlb.c b/mm/hugetlb.c
index f35abff8be60..271d83a7cde0 100644
--- a/mm/hugetlb.c
+++ b/mm/hugetlb.c
@@ -1629,9 +1629,8 @@  static inline void destroy_compound_gigantic_folio(struct folio *folio,
  *
  * Must be called with hugetlb lock held.
  */
-static void __remove_hugetlb_folio(struct hstate *h, struct folio *folio,
-							bool adjust_surplus,
-							bool demote)
+static void remove_hugetlb_folio(struct hstate *h, struct folio *folio,
+							bool adjust_surplus)
 {
 	int nid = folio_nid(folio);
 
@@ -1661,33 +1660,13 @@  static void __remove_hugetlb_folio(struct hstate *h, struct folio *folio,
 	if (!folio_test_hugetlb_vmemmap_optimized(folio))
 		__folio_clear_hugetlb(folio);
 
-	 /*
-	  * In the case of demote we do not ref count the page as it will soon
-	  * be turned into a page of smaller size.
-	 */
-	if (!demote)
-		folio_ref_unfreeze(folio, 1);
-
 	h->nr_huge_pages--;
 	h->nr_huge_pages_node[nid]--;
 }
 
-static void remove_hugetlb_folio(struct hstate *h, struct folio *folio,
-							bool adjust_surplus)
-{
-	__remove_hugetlb_folio(h, folio, adjust_surplus, false);
-}
-
-static void remove_hugetlb_folio_for_demote(struct hstate *h, struct folio *folio,
-							bool adjust_surplus)
-{
-	__remove_hugetlb_folio(h, folio, adjust_surplus, true);
-}
-
 static void add_hugetlb_folio(struct hstate *h, struct folio *folio,
 			     bool adjust_surplus)
 {
-	int zeroed;
 	int nid = folio_nid(folio);
 
 	VM_BUG_ON_FOLIO(!folio_test_hugetlb_vmemmap_optimized(folio), folio);
@@ -1711,21 +1690,6 @@  static void add_hugetlb_folio(struct hstate *h, struct folio *folio,
 	 */
 	folio_set_hugetlb_vmemmap_optimized(folio);
 
-	/*
-	 * This folio is about to be managed by the hugetlb allocator and
-	 * should have no users.  Drop our reference, and check for others
-	 * just in case.
-	 */
-	zeroed = folio_put_testzero(folio);
-	if (unlikely(!zeroed))
-		/*
-		 * It is VERY unlikely soneone else has taken a ref
-		 * on the folio.  In this case, we simply return as
-		 * free_huge_folio() will be called when this other ref
-		 * is dropped.
-		 */
-		return;
-
 	arch_clear_hugetlb_flags(folio);
 	enqueue_hugetlb_folio(h, folio);
 }
@@ -1779,6 +1743,8 @@  static void __update_and_free_hugetlb_folio(struct hstate *h,
 		spin_unlock_irq(&hugetlb_lock);
 	}
 
+	folio_ref_unfreeze(folio, 1);
+
 	/*
 	 * Non-gigantic pages demoted from CMA allocated gigantic pages
 	 * need to be given back to CMA in free_gigantic_folio.
@@ -3079,11 +3045,8 @@  static int alloc_and_dissolve_hugetlb_folio(struct hstate *h,
 
 free_new:
 	spin_unlock_irq(&hugetlb_lock);
-	if (new_folio) {
-		/* Folio has a zero ref count, but needs a ref to be freed */
-		folio_ref_unfreeze(new_folio, 1);
+	if (new_folio)
 		update_and_free_hugetlb_folio(h, new_folio, false);
-	}
 
 	return ret;
 }
@@ -3938,7 +3901,7 @@  static int demote_free_hugetlb_folio(struct hstate *h, struct folio *folio)
 
 	target_hstate = size_to_hstate(PAGE_SIZE << h->demote_order);
 
-	remove_hugetlb_folio_for_demote(h, folio, false);
+	remove_hugetlb_folio(h, folio, false);
 	spin_unlock_irq(&hugetlb_lock);
 
 	/*
@@ -3952,7 +3915,6 @@  static int demote_free_hugetlb_folio(struct hstate *h, struct folio *folio)
 		if (rc) {
 			/* Allocation of vmemmmap failed, we can not demote folio */
 			spin_lock_irq(&hugetlb_lock);
-			folio_ref_unfreeze(folio, 1);
 			add_hugetlb_folio(h, folio, false);
 			return rc;
 		}
diff --git a/mm/hugetlb_vmemmap.c b/mm/hugetlb_vmemmap.c
index b9a55322e52c..8193906515c6 100644
--- a/mm/hugetlb_vmemmap.c
+++ b/mm/hugetlb_vmemmap.c
@@ -446,6 +446,8 @@  static int __hugetlb_vmemmap_restore_folio(const struct hstate *h,
 	unsigned long vmemmap_reuse;
 
 	VM_WARN_ON_ONCE_FOLIO(!folio_test_hugetlb(folio), folio);
+	VM_WARN_ON_ONCE_FOLIO(folio_ref_count(folio), folio);
+
 	if (!folio_test_hugetlb_vmemmap_optimized(folio))
 		return 0;
 
@@ -481,6 +483,9 @@  static int __hugetlb_vmemmap_restore_folio(const struct hstate *h,
  */
 int hugetlb_vmemmap_restore_folio(const struct hstate *h, struct folio *folio)
 {
+	/* avoid writes from page_ref_add_unless() while unfolding vmemmap */
+	synchronize_rcu();
+
 	return __hugetlb_vmemmap_restore_folio(h, folio, 0);
 }
 
@@ -505,6 +510,9 @@  long hugetlb_vmemmap_restore_folios(const struct hstate *h,
 	long restored = 0;
 	long ret = 0;
 
+	/* avoid writes from page_ref_add_unless() while unfolding vmemmap */
+	synchronize_rcu();
+
 	list_for_each_entry_safe(folio, t_folio, folio_list, lru) {
 		if (folio_test_hugetlb_vmemmap_optimized(folio)) {
 			ret = __hugetlb_vmemmap_restore_folio(h, folio,
@@ -550,6 +558,8 @@  static int __hugetlb_vmemmap_optimize_folio(const struct hstate *h,
 	unsigned long vmemmap_reuse;
 
 	VM_WARN_ON_ONCE_FOLIO(!folio_test_hugetlb(folio), folio);
+	VM_WARN_ON_ONCE_FOLIO(folio_ref_count(folio), folio);
+
 	if (!vmemmap_should_optimize_folio(h, folio))
 		return ret;
 
@@ -601,6 +611,9 @@  void hugetlb_vmemmap_optimize_folio(const struct hstate *h, struct folio *folio)
 {
 	LIST_HEAD(vmemmap_pages);
 
+	/* avoid writes from page_ref_add_unless() while folding vmemmap */
+	synchronize_rcu();
+
 	__hugetlb_vmemmap_optimize_folio(h, folio, &vmemmap_pages, 0);
 	free_vmemmap_page_list(&vmemmap_pages);
 }
@@ -644,6 +657,9 @@  void hugetlb_vmemmap_optimize_folios(struct hstate *h, struct list_head *folio_l
 
 	flush_tlb_all();
 
+	/* avoid writes from page_ref_add_unless() while folding vmemmap */
+	synchronize_rcu();
+
 	list_for_each_entry(folio, folio_list, lru) {
 		int ret;