
[v6,10/16] mm: replace vm_lock and detached flag with a reference count

Message ID 20241216192419.2970941-11-surenb@google.com (mailing list archive)
State New
Headers show
Series move per-vma lock into vm_area_struct

Commit Message

Suren Baghdasaryan Dec. 16, 2024, 7:24 p.m. UTC
rw_semaphore is a sizable structure of 40 bytes and consumes
considerable space in each vm_area_struct. However, vma_lock has
two important properties that allow replacing the rw_semaphore
with a simpler structure:
1. Readers never wait. They try to take the vma_lock and fall back to
mmap_lock if that fails.
2. Only one writer at a time will ever try to write-lock a vma_lock
because writers first take mmap_lock in write mode.
Because of these requirements, full rw_semaphore functionality is not
needed and we can replace rw_semaphore and the vma->detached flag with
a refcount (vm_refcnt).
When a vma is in the detached state, vm_refcnt is 0 and only a call to
vma_mark_attached() can take it out of this state. Note that unlike
before, we now require both vma_mark_attached() and vma_mark_detached()
to be called only after the vma has been write-locked. vma_mark_attached()
sets vm_refcnt to 1 to indicate that the vma has been attached to the vma
tree. When a reader takes the read lock, it increments vm_refcnt, unless
the top usable bit of vm_refcnt (0x40000000) is set, indicating the
presence of a writer. When a writer takes the write lock, it both
increments vm_refcnt and sets the top usable bit to indicate its presence.
If there are readers, the writer will wait on the newly introduced
mm->vma_writer_wait. Since all writers take mmap_lock in write mode first,
there can be only one writer at a time. The last reader to release the
lock will signal the writer to wake up.
The refcount might overflow if there are many competing readers, in which
case read-locking will fail. Readers are expected to handle such failures.
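
For illustration, the resulting vm_refcnt encoding can be summarized as
follows (a sketch restating the scheme above; the constant names are the
ones introduced by the diff below):

	#define VMA_STATE_DETACHED	0x0		/* not in the vma tree */
	#define VMA_STATE_ATTACHED	0x1		/* attached, unlocked */
	#define VMA_STATE_LOCKED	0x40000000	/* writer-present bit */

	/*
	 * vm_refcnt == 0                   : detached
	 * vm_refcnt == 1                   : attached, unlocked
	 * 1 < vm_refcnt < VMA_STATE_LOCKED : attached, vm_refcnt - 1 readers
	 * vm_refcnt & VMA_STATE_LOCKED     : a writer is present, new readers
	 *                                    back off to mmap_lock
	 */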

Suggested-by: Peter Zijlstra <peterz@infradead.org>
Suggested-by: Matthew Wilcox <willy@infradead.org>
Signed-off-by: Suren Baghdasaryan <surenb@google.com>
---
 include/linux/mm.h               | 95 ++++++++++++++++++++++++--------
 include/linux/mm_types.h         | 23 ++++----
 kernel/fork.c                    |  9 +--
 mm/init-mm.c                     |  1 +
 mm/memory.c                      | 33 +++++++----
 tools/testing/vma/linux/atomic.h |  5 ++
 tools/testing/vma/vma_internal.h | 57 ++++++++++---------
 7 files changed, 147 insertions(+), 76 deletions(-)

Comments

Peter Zijlstra Dec. 16, 2024, 8:42 p.m. UTC | #1
On Mon, Dec 16, 2024 at 11:24:13AM -0800, Suren Baghdasaryan wrote:
> @@ -734,10 +761,12 @@ static inline bool vma_start_read(struct vm_area_struct *vma)
>  	 * after it has been unlocked.
>  	 * This pairs with RELEASE semantics in vma_end_write_all().
>  	 */
> +	if (oldcnt & VMA_STATE_LOCKED ||
> +	    unlikely(vma->vm_lock_seq == raw_read_seqcount(&vma->vm_mm->mm_lock_seq))) {

You likely want that unlikely to cover both conditions :-)

> +		vma_refcount_put(vma);
>  		return false;
>  	}
> +
>  	return true;
>  }
Suren Baghdasaryan Dec. 16, 2024, 8:53 p.m. UTC | #2
On Mon, Dec 16, 2024 at 12:42 PM Peter Zijlstra <peterz@infradead.org> wrote:
>
> On Mon, Dec 16, 2024 at 11:24:13AM -0800, Suren Baghdasaryan wrote:
> > @@ -734,10 +761,12 @@ static inline bool vma_start_read(struct vm_area_struct *vma)
> >        * after it has been unlocked.
> >        * This pairs with RELEASE semantics in vma_end_write_all().
> >        */
> > +     if (oldcnt & VMA_STATE_LOCKED ||
> > +         unlikely(vma->vm_lock_seq == raw_read_seqcount(&vma->vm_mm->mm_lock_seq))) {
>
> You likely want that unlikely to cover both conditions :-)

True. VMA_STATE_LOCKED is set only while the writer is updating the
vm_lock_seq and that's a narrow window. I'll make that change in the
next revision. Thanks!
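
For reference, the change being discussed would look roughly like this (a
sketch of the next revision, not the code as posted):

	if (unlikely((oldcnt & VMA_STATE_LOCKED) ||
		     vma->vm_lock_seq == raw_read_seqcount(&vma->vm_mm->mm_lock_seq))) {
		vma_refcount_put(vma);
		return false;
	}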

>
> > +             vma_refcount_put(vma);
> >               return false;
> >       }
> > +
> >       return true;
> >  }
Peter Zijlstra Dec. 16, 2024, 9:15 p.m. UTC | #3
On Mon, Dec 16, 2024 at 11:24:13AM -0800, Suren Baghdasaryan wrote:

FWIW, I find the whole VMA_STATE_{AT,DE}TACHED thing awkward. And
perhaps s/VMA_STATE_LOCKED/VMA_LOCK_OFFSET/ ?

Also, perhaps:

#define VMA_REF_LIMIT	(VMA_LOCK_OFFSET - 2)

> @@ -699,10 +700,27 @@ static inline void vma_numab_state_free(struct vm_area_struct *vma) {}
>  #ifdef CONFIG_PER_VMA_LOCK
>  static inline void vma_lock_init(struct vm_area_struct *vma)
>  {
> -	init_rwsem(&vma->vm_lock.lock);
> +#ifdef CONFIG_DEBUG_LOCK_ALLOC
> +	static struct lock_class_key lockdep_key;
> +
> +	lockdep_init_map(&vma->vmlock_dep_map, "vm_lock", &lockdep_key, 0);
> +#endif
> +	refcount_set(&vma->vm_refcnt, VMA_STATE_DETACHED);
>  	vma->vm_lock_seq = UINT_MAX;

Depending on how you do the actual allocation (GFP_ZERO) you might want
to avoid that vm_refcnt store entirely.

Perhaps instead write: VM_WARN_ON(refcount_read(&vma->vm_refcnt));

> @@ -813,25 +849,42 @@ static inline void vma_assert_write_locked(struct vm_area_struct *vma)
>  
>  static inline void vma_assert_locked(struct vm_area_struct *vma)
>  {
> -	if (!rwsem_is_locked(&vma->vm_lock.lock))
> +	if (refcount_read(&vma->vm_refcnt) <= VMA_STATE_ATTACHED)
	if (is_vma_detached(vma))

>  		vma_assert_write_locked(vma);
>  }
>  
> -static inline void vma_mark_attached(struct vm_area_struct *vma)
> +/*
> + * WARNING: to avoid racing with vma_mark_attached(), should be called either
> + * under mmap_write_lock or when the object has been isolated under
> + * mmap_write_lock, ensuring no competing writers.
> + */
> +static inline bool is_vma_detached(struct vm_area_struct *vma)
>  {
> -	vma->detached = false;
> +	return refcount_read(&vma->vm_refcnt) == VMA_STATE_DETACHED;
	return !refcount_read(&vma->vm_refcnt);
>  }
>  
> -static inline void vma_mark_detached(struct vm_area_struct *vma)
> +static inline void vma_mark_attached(struct vm_area_struct *vma)
>  {
> -	/* When detaching vma should be write-locked */
>  	vma_assert_write_locked(vma);
> -	vma->detached = true;
> +
> +	if (is_vma_detached(vma))
> +		refcount_set(&vma->vm_refcnt, VMA_STATE_ATTACHED);

Urgh, so it would be really good to not call this at all when it's not 0.
I've not tried to untangle the mess, but this is really awkward. Surely
you don't add it to the mas multiple times either.

Also:

	refcount_set(&vma->vm_refcnt, 1);

is so much clearer.

That is, should this not live in vma_iter_store*(), right before
mas_store_gfp() ?

Also, ISTR having to set vm_lock_seq right before it?

>  }
>  
> -static inline bool is_vma_detached(struct vm_area_struct *vma)
> +static inline void vma_mark_detached(struct vm_area_struct *vma)
>  {
> -	return vma->detached;
> +	vma_assert_write_locked(vma);
> +
> +	if (is_vma_detached(vma))
> +		return;

Again, this just reads like confusion :/ Surely you don't have the same
with mas_detach?

> +
> +	/* We are the only writer, so no need to use vma_refcount_put(). */
> +	if (!refcount_dec_and_test(&vma->vm_refcnt)) {
> +		/*
> +		 * Reader must have temporarily raised vm_refcnt but it will
> +		 * drop it without using the vma since vma is write-locked.
> +		 */
> +	}
>  }
Peter Zijlstra Dec. 16, 2024, 9:37 p.m. UTC | #4
On Mon, Dec 16, 2024 at 11:24:13AM -0800, Suren Baghdasaryan wrote:
> +static inline void vma_refcount_put(struct vm_area_struct *vma)
> +{
> +	int refcnt;
> +
> +	if (!__refcount_dec_and_test(&vma->vm_refcnt, &refcnt)) {
> +		rwsem_release(&vma->vmlock_dep_map, _RET_IP_);
> +
> +		if (refcnt & VMA_STATE_LOCKED)
> +			rcuwait_wake_up(&vma->vm_mm->vma_writer_wait);
> +	}
> +}
> +
>  /*
>   * Try to read-lock a vma. The function is allowed to occasionally yield false
>   * locked result to avoid performance overhead, in which case we fall back to
> @@ -710,6 +728,8 @@ static inline void vma_lock_init(struct vm_area_struct *vma)
>   */
>  static inline bool vma_start_read(struct vm_area_struct *vma)
>  {
> +	int oldcnt;
> +
>  	/*
>  	 * Check before locking. A race might cause false locked result.
>  	 * We can use READ_ONCE() for the mm_lock_seq here, and don't need
> @@ -720,13 +740,20 @@ static inline bool vma_start_read(struct vm_area_struct *vma)
>  	if (READ_ONCE(vma->vm_lock_seq) == READ_ONCE(vma->vm_mm->mm_lock_seq.sequence))
>  		return false;
>  
> +
> +	rwsem_acquire_read(&vma->vmlock_dep_map, 0, 0, _RET_IP_);
> +	/* Limit at VMA_STATE_LOCKED - 2 to leave one count for a writer */
> +	if (unlikely(!__refcount_inc_not_zero_limited(&vma->vm_refcnt, &oldcnt,
> +						      VMA_STATE_LOCKED - 2))) {
> +		rwsem_release(&vma->vmlock_dep_map, _RET_IP_);
>  		return false;
> +	}
> +	lock_acquired(&vma->vmlock_dep_map, _RET_IP_);
>  
>  	/*
> +	 * Overflow of vm_lock_seq/mm_lock_seq might produce false locked result.
>  	 * False unlocked result is impossible because we modify and check
> +	 * vma->vm_lock_seq under vma->vm_refcnt protection and mm->mm_lock_seq
>  	 * modification invalidates all existing locks.
>  	 *
>  	 * We must use ACQUIRE semantics for the mm_lock_seq so that if we are
> @@ -734,10 +761,12 @@ static inline bool vma_start_read(struct vm_area_struct *vma)
>  	 * after it has been unlocked.
>  	 * This pairs with RELEASE semantics in vma_end_write_all().
>  	 */
> +	if (oldcnt & VMA_STATE_LOCKED ||
> +	    unlikely(vma->vm_lock_seq == raw_read_seqcount(&vma->vm_mm->mm_lock_seq))) {
> +		vma_refcount_put(vma);

Suppose we have detach race with a concurrent RCU lookup like:

					vma = mas_lookup();

	vma_start_write();
	mas_detach();
					vma_start_read()
					rwsem_acquire_read()
					inc // success
	vma_mark_detach();
	dec_and_test // assumes 1->0
		     // is actually 2->1

					if (vm_lock_seq == vma->vm_mm_mm_lock_seq) // true
					  vma_refcount_put
					    dec_and_test() // 1->0
					      *NO* rwsem_release()



>  		return false;
>  	}
> +
>  	return true;
>  }
Suren Baghdasaryan Dec. 16, 2024, 9:44 p.m. UTC | #5
On Mon, Dec 16, 2024 at 1:38 PM Peter Zijlstra <peterz@infradead.org> wrote:
>
> On Mon, Dec 16, 2024 at 11:24:13AM -0800, Suren Baghdasaryan wrote:
> > +static inline void vma_refcount_put(struct vm_area_struct *vma)
> > +{
> > +     int refcnt;
> > +
> > +     if (!__refcount_dec_and_test(&vma->vm_refcnt, &refcnt)) {
> > +             rwsem_release(&vma->vmlock_dep_map, _RET_IP_);
> > +
> > +             if (refcnt & VMA_STATE_LOCKED)
> > +                     rcuwait_wake_up(&vma->vm_mm->vma_writer_wait);
> > +     }
> > +}
> > +
> >  /*
> >   * Try to read-lock a vma. The function is allowed to occasionally yield false
> >   * locked result to avoid performance overhead, in which case we fall back to
> > @@ -710,6 +728,8 @@ static inline void vma_lock_init(struct vm_area_struct *vma)
> >   */
> >  static inline bool vma_start_read(struct vm_area_struct *vma)
> >  {
> > +     int oldcnt;
> > +
> >       /*
> >        * Check before locking. A race might cause false locked result.
> >        * We can use READ_ONCE() for the mm_lock_seq here, and don't need
> > @@ -720,13 +740,20 @@ static inline bool vma_start_read(struct vm_area_struct *vma)
> >       if (READ_ONCE(vma->vm_lock_seq) == READ_ONCE(vma->vm_mm->mm_lock_seq.sequence))
> >               return false;
> >
> > +
> > +     rwsem_acquire_read(&vma->vmlock_dep_map, 0, 0, _RET_IP_);
> > +     /* Limit at VMA_STATE_LOCKED - 2 to leave one count for a writer */
> > +     if (unlikely(!__refcount_inc_not_zero_limited(&vma->vm_refcnt, &oldcnt,
> > +                                                   VMA_STATE_LOCKED - 2))) {
> > +             rwsem_release(&vma->vmlock_dep_map, _RET_IP_);
> >               return false;
> > +     }
> > +     lock_acquired(&vma->vmlock_dep_map, _RET_IP_);
> >
> >       /*
> > +      * Overflow of vm_lock_seq/mm_lock_seq might produce false locked result.
> >        * False unlocked result is impossible because we modify and check
> > +      * vma->vm_lock_seq under vma->vm_refcnt protection and mm->mm_lock_seq
> >        * modification invalidates all existing locks.
> >        *
> >        * We must use ACQUIRE semantics for the mm_lock_seq so that if we are
> > @@ -734,10 +761,12 @@ static inline bool vma_start_read(struct vm_area_struct *vma)
> >        * after it has been unlocked.
> >        * This pairs with RELEASE semantics in vma_end_write_all().
> >        */
> > +     if (oldcnt & VMA_STATE_LOCKED ||
> > +         unlikely(vma->vm_lock_seq == raw_read_seqcount(&vma->vm_mm->mm_lock_seq))) {
> > +             vma_refcount_put(vma);
>
> Suppose we have detach race with a concurrent RCU lookup like:
>
>                                         vma = mas_lookup();
>
>         vma_start_write();
>         mas_detach();
>                                         vma_start_read()
>                                         rwsem_acquire_read()
>                                         inc // success
>         vma_mark_detach();
>         dec_and_test // assumes 1->0
>                      // is actually 2->1
>
>                                         if (vm_lock_seq == vma->vm_mm_mm_lock_seq) // true
>                                           vma_refcount_put
>                                             dec_and_test() // 1->0
>                                               *NO* rwsem_release()
>

Yes, this is possible. I think that's not a problem until we start
reusing the vmas and I deal with this race later in this patchset.
I think what you described here is the same race I mention in the
description of this patch:
https://lore.kernel.org/all/20241216192419.2970941-14-surenb@google.com/
I introduce vma_ensure_detached() in that patch to handle this case
and ensure that vmas are detached before they are returned into the
slab cache for reuse. Does that make sense?


>
>
> >               return false;
> >       }
> > +
> >       return true;
> >  }
Suren Baghdasaryan Dec. 16, 2024, 9:53 p.m. UTC | #6
On Mon, Dec 16, 2024 at 1:15 PM Peter Zijlstra <peterz@infradead.org> wrote:
>
> On Mon, Dec 16, 2024 at 11:24:13AM -0800, Suren Baghdasaryan wrote:
>
> FWIW, I find the whole VMA_STATE_{AT,DE}TACHED thing awkward. And

I'm bad with naming things, so any better suggestions are welcome.
Are you suggesting to drop the VMA_STATE_{AT,DE}TACHED nomenclature and
use 0/1 values directly?

> perhaps s/VMA_STATE_LOCKED/VMA_LOCK_OFFSET/ ?

Sounds good. I'll change it to VMA_LOCK_OFFSET.

>
> Also, perhaps:
>
> #define VMA_REF_LIMIT   (VMA_LOCK_OFFSET - 2)

Ack.

>
> > @@ -699,10 +700,27 @@ static inline void vma_numab_state_free(struct vm_area_struct *vma) {}
> >  #ifdef CONFIG_PER_VMA_LOCK
> >  static inline void vma_lock_init(struct vm_area_struct *vma)
> >  {
> > -     init_rwsem(&vma->vm_lock.lock);
> > +#ifdef CONFIG_DEBUG_LOCK_ALLOC
> > +     static struct lock_class_key lockdep_key;
> > +
> > +     lockdep_init_map(&vma->vmlock_dep_map, "vm_lock", &lockdep_key, 0);
> > +#endif
> > +     refcount_set(&vma->vm_refcnt, VMA_STATE_DETACHED);
> >       vma->vm_lock_seq = UINT_MAX;
>
> Depending on how you do the actual allocation (GFP_ZERO) you might want
> to avoid that vm_refcnt store entirely.

I think we could initialize it to 0 in the slab cache constructor, and
when a vma is freed we already ensure it's 0. So even when reused, it
will be in the correct 0 state.
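
A minimal sketch of that idea (hypothetical, not part of this series;
vm_area_ctor is an illustrative name):

	static void vm_area_ctor(void *data)
	{
		struct vm_area_struct *vma = data;

		/* both fresh and recycled objects start out detached (0) */
		refcount_set(&vma->vm_refcnt, 0);
	}

	vm_area_cachep = kmem_cache_create("vm_area_struct",
					   sizeof(struct vm_area_struct), 0,
					   SLAB_PANIC|SLAB_ACCOUNT, vm_area_ctor);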

>
> Perhaps instead write: VM_WARN_ON(refcount_read(&vma->vm_refcnt));

Yes, with the above approach that should work.

>
> > @@ -813,25 +849,42 @@ static inline void vma_assert_write_locked(struct vm_area_struct *vma)
> >
> >  static inline void vma_assert_locked(struct vm_area_struct *vma)
> >  {
> > -     if (!rwsem_is_locked(&vma->vm_lock.lock))
> > +     if (refcount_read(&vma->vm_refcnt) <= VMA_STATE_ATTACHED)
>         if (is_vma_detached(vma))
>
> >               vma_assert_write_locked(vma);
> >  }
> >
> > -static inline void vma_mark_attached(struct vm_area_struct *vma)
> > +/*
> > + * WARNING: to avoid racing with vma_mark_attached(), should be called either
> > + * under mmap_write_lock or when the object has been isolated under
> > + * mmap_write_lock, ensuring no competing writers.
> > + */
> > +static inline bool is_vma_detached(struct vm_area_struct *vma)
> >  {
> > -     vma->detached = false;
> > +     return refcount_read(&vma->vm_refcnt) == VMA_STATE_DETACHED;
>         return !refcount_read(&vma->vm_refcnt);
> >  }
> >
> > -static inline void vma_mark_detached(struct vm_area_struct *vma)
> > +static inline void vma_mark_attached(struct vm_area_struct *vma)
> >  {
> > -     /* When detaching vma should be write-locked */
> >       vma_assert_write_locked(vma);
> > -     vma->detached = true;
> > +
> > +     if (is_vma_detached(vma))
> > +             refcount_set(&vma->vm_refcnt, VMA_STATE_ATTACHED);
>
> Urgh, so it would be really good to not call this at all when it's not 0.
> I've not tried to untangle the mess, but this is really awkward. Surely
> you don't add it to the mas multiple times either.

The issue is that when we merge/split/shrink/grow vmas, we skip marking
them detached while modifying them. Therefore, from the vma_mark_attached()
POV it will look like we are attaching an already attached vma. I can try
to clean that up if this is really a concern.

>
> Also:
>
>         refcount_set(&vma->vm_refcnt, 1);
>
> is so much clearer.

Ok, IIUC you are in favour of dropping VMA_STATE_ATTACHED/VMA_STATE_DETACHED.

>
> That is, should this not live in vma_iter_store*(), right before
> mas_store_gfp() ?

Currently it's done right *after* mas_store_gfp() but I was debating
with myself if it indeed should be *before* insertion into the tree...

>
> Also, ISTR having to set vm_lock_seq right before it?

Yes, vma_mark_attached() requires the vma to be write-locked beforehand,
hence the above vma_assert_write_locked(). But often the write-lock is
taken not right before vma_mark_attached(), because some other
modification functions also require the vma to be write-locked.

>
> >  }
> >
> > -static inline bool is_vma_detached(struct vm_area_struct *vma)
> > +static inline void vma_mark_detached(struct vm_area_struct *vma)
> >  {
> > -     return vma->detached;
> > +     vma_assert_write_locked(vma);
> > +
> > +     if (is_vma_detached(vma))
> > +             return;
>
> Again, this just reads like confusion :/ Surely you don't have the same
> with mas_detach?

I'll double-check if we ever double-mark vma as detached.

Thanks for the review!

>
> > +
> > +     /* We are the only writer, so no need to use vma_refcount_put(). */
> > +     if (!refcount_dec_and_test(&vma->vm_refcnt)) {
> > +             /*
> > +              * Reader must have temporarily raised vm_refcnt but it will
> > +              * drop it without using the vma since vma is write-locked.
> > +              */
> > +     }
> >  }
Peter Zijlstra Dec. 16, 2024, 10 p.m. UTC | #7
On Mon, Dec 16, 2024 at 01:53:06PM -0800, Suren Baghdasaryan wrote:

> > That is, should this not live in vma_iter_store*(), right before
> > mas_store_gfp() ?
> 
> Currently it's done right *after* mas_store_gfp() but I was debating
> with myself if it indeed should be *before* insertion into the tree...

The moment it goes into the tree it becomes visible to RCU lookups; it's
a bit weird to have it there with !refcnt at that point, but I don't
suppose it harms.
Peter Zijlstra Dec. 17, 2024, 10:30 a.m. UTC | #8
On Mon, Dec 16, 2024 at 01:44:45PM -0800, Suren Baghdasaryan wrote:
> On Mon, Dec 16, 2024 at 1:38 PM Peter Zijlstra <peterz@infradead.org> wrote:
> >
> > On Mon, Dec 16, 2024 at 11:24:13AM -0800, Suren Baghdasaryan wrote:
> > > +static inline void vma_refcount_put(struct vm_area_struct *vma)
> > > +{
> > > +     int refcnt;
> > > +
> > > +     if (!__refcount_dec_and_test(&vma->vm_refcnt, &refcnt)) {
> > > +             rwsem_release(&vma->vmlock_dep_map, _RET_IP_);
> > > +
> > > +             if (refcnt & VMA_STATE_LOCKED)
> > > +                     rcuwait_wake_up(&vma->vm_mm->vma_writer_wait);
> > > +     }
> > > +}
> > > +
> > >  /*
> > >   * Try to read-lock a vma. The function is allowed to occasionally yield false
> > >   * locked result to avoid performance overhead, in which case we fall back to
> > > @@ -710,6 +728,8 @@ static inline void vma_lock_init(struct vm_area_struct *vma)
> > >   */
> > >  static inline bool vma_start_read(struct vm_area_struct *vma)
> > >  {
> > > +     int oldcnt;
> > > +
> > >       /*
> > >        * Check before locking. A race might cause false locked result.
> > >        * We can use READ_ONCE() for the mm_lock_seq here, and don't need
> > > @@ -720,13 +740,20 @@ static inline bool vma_start_read(struct vm_area_struct *vma)
> > >       if (READ_ONCE(vma->vm_lock_seq) == READ_ONCE(vma->vm_mm->mm_lock_seq.sequence))
> > >               return false;
> > >
> > > +
> > > +     rwsem_acquire_read(&vma->vmlock_dep_map, 0, 0, _RET_IP_);
> > > +     /* Limit at VMA_STATE_LOCKED - 2 to leave one count for a writer */
> > > +     if (unlikely(!__refcount_inc_not_zero_limited(&vma->vm_refcnt, &oldcnt,
> > > +                                                   VMA_STATE_LOCKED - 2))) {
> > > +             rwsem_release(&vma->vmlock_dep_map, _RET_IP_);
> > >               return false;
> > > +     }
> > > +     lock_acquired(&vma->vmlock_dep_map, _RET_IP_);
> > >
> > >       /*
> > > +      * Overflow of vm_lock_seq/mm_lock_seq might produce false locked result.
> > >        * False unlocked result is impossible because we modify and check
> > > +      * vma->vm_lock_seq under vma->vm_refcnt protection and mm->mm_lock_seq
> > >        * modification invalidates all existing locks.
> > >        *
> > >        * We must use ACQUIRE semantics for the mm_lock_seq so that if we are
> > > @@ -734,10 +761,12 @@ static inline bool vma_start_read(struct vm_area_struct *vma)
> > >        * after it has been unlocked.
> > >        * This pairs with RELEASE semantics in vma_end_write_all().
> > >        */
> > > +     if (oldcnt & VMA_STATE_LOCKED ||
> > > +         unlikely(vma->vm_lock_seq == raw_read_seqcount(&vma->vm_mm->mm_lock_seq))) {
> > > +             vma_refcount_put(vma);
> >
> > Suppose we have detach race with a concurrent RCU lookup like:
> >
> >                                         vma = mas_lookup();
> >
> >         vma_start_write();
> >         mas_detach();
> >                                         vma_start_read()
> >                                         rwsem_acquire_read()
> >                                         inc // success
> >         vma_mark_detach();
> >         dec_and_test // assumes 1->0
> >                      // is actually 2->1
> >
> >                                         if (vm_lock_seq == vma->vm_mm_mm_lock_seq) // true
> >                                           vma_refcount_put
> >                                             dec_and_test() // 1->0
> >                                               *NO* rwsem_release()
> >
> 
> Yes, this is possible. I think that's not a problem until we start
> reusing the vmas and I deal with this race later in this patchset.
> I think what you described here is the same race I mention in the
> description of this patch:
> https://lore.kernel.org/all/20241216192419.2970941-14-surenb@google.com/
> I introduce vma_ensure_detached() in that patch to handle this case
> and ensure that vmas are detached before they are returned into the
> slab cache for reuse. Does that make sense?

So I just replied there, and no, I don't think it makes sense. Just put
the kmem_cache_free() in vma_refcount_put(), to be done on 0.

Anyway, my point was more about the weird entanglement of lockdep and
the refcount. Just pull the lockdep annotation out of _put() and put it
explicitly in the vma_start_read() error paths and vma_end_read().
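
A rough sketch of that split, assuming the function names from the patch
(illustrative only, not a tested implementation):

	static inline void vma_refcount_put(struct vm_area_struct *vma)
	{
		int refcnt;

		/* purely refcount management, no lockdep involvement here */
		if (!__refcount_dec_and_test(&vma->vm_refcnt, &refcnt) &&
		    (refcnt & VMA_STATE_LOCKED))
			rcuwait_wake_up(&vma->vm_mm->vma_writer_wait);
	}

	static inline void vma_end_read(struct vm_area_struct *vma)
	{
		rcu_read_lock(); /* keeps vma alive until the reference is dropped */
		rwsem_release(&vma->vmlock_dep_map, _RET_IP_);
		vma_refcount_put(vma);
		rcu_read_unlock();
	}

The vma_start_read() error paths would then call rwsem_release() explicitly
before vma_refcount_put(), as suggested.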

Additionally, having vma_end_write() would allow you to put a lockdep
annotation in vma_{start,end}_write() -- which was I think the original
reason I proposed it a while back, that and having improved clarity when
reading the code, since explicitly marking the end of a section is
helpful.
Suren Baghdasaryan Dec. 17, 2024, 4:27 p.m. UTC | #9
On Tue, Dec 17, 2024 at 2:30 AM Peter Zijlstra <peterz@infradead.org> wrote:
>
> On Mon, Dec 16, 2024 at 01:44:45PM -0800, Suren Baghdasaryan wrote:
> > On Mon, Dec 16, 2024 at 1:38 PM Peter Zijlstra <peterz@infradead.org> wrote:
> > >
> > > On Mon, Dec 16, 2024 at 11:24:13AM -0800, Suren Baghdasaryan wrote:
> > > > +static inline void vma_refcount_put(struct vm_area_struct *vma)
> > > > +{
> > > > +     int refcnt;
> > > > +
> > > > +     if (!__refcount_dec_and_test(&vma->vm_refcnt, &refcnt)) {
> > > > +             rwsem_release(&vma->vmlock_dep_map, _RET_IP_);
> > > > +
> > > > +             if (refcnt & VMA_STATE_LOCKED)
> > > > +                     rcuwait_wake_up(&vma->vm_mm->vma_writer_wait);
> > > > +     }
> > > > +}
> > > > +
> > > >  /*
> > > >   * Try to read-lock a vma. The function is allowed to occasionally yield false
> > > >   * locked result to avoid performance overhead, in which case we fall back to
> > > > @@ -710,6 +728,8 @@ static inline void vma_lock_init(struct vm_area_struct *vma)
> > > >   */
> > > >  static inline bool vma_start_read(struct vm_area_struct *vma)
> > > >  {
> > > > +     int oldcnt;
> > > > +
> > > >       /*
> > > >        * Check before locking. A race might cause false locked result.
> > > >        * We can use READ_ONCE() for the mm_lock_seq here, and don't need
> > > > @@ -720,13 +740,20 @@ static inline bool vma_start_read(struct vm_area_struct *vma)
> > > >       if (READ_ONCE(vma->vm_lock_seq) == READ_ONCE(vma->vm_mm->mm_lock_seq.sequence))
> > > >               return false;
> > > >
> > > > +
> > > > +     rwsem_acquire_read(&vma->vmlock_dep_map, 0, 0, _RET_IP_);
> > > > +     /* Limit at VMA_STATE_LOCKED - 2 to leave one count for a writer */
> > > > +     if (unlikely(!__refcount_inc_not_zero_limited(&vma->vm_refcnt, &oldcnt,
> > > > +                                                   VMA_STATE_LOCKED - 2))) {
> > > > +             rwsem_release(&vma->vmlock_dep_map, _RET_IP_);
> > > >               return false;
> > > > +     }
> > > > +     lock_acquired(&vma->vmlock_dep_map, _RET_IP_);
> > > >
> > > >       /*
> > > > +      * Overflow of vm_lock_seq/mm_lock_seq might produce false locked result.
> > > >        * False unlocked result is impossible because we modify and check
> > > > +      * vma->vm_lock_seq under vma->vm_refcnt protection and mm->mm_lock_seq
> > > >        * modification invalidates all existing locks.
> > > >        *
> > > >        * We must use ACQUIRE semantics for the mm_lock_seq so that if we are
> > > > @@ -734,10 +761,12 @@ static inline bool vma_start_read(struct vm_area_struct *vma)
> > > >        * after it has been unlocked.
> > > >        * This pairs with RELEASE semantics in vma_end_write_all().
> > > >        */
> > > > +     if (oldcnt & VMA_STATE_LOCKED ||
> > > > +         unlikely(vma->vm_lock_seq == raw_read_seqcount(&vma->vm_mm->mm_lock_seq))) {
> > > > +             vma_refcount_put(vma);
> > >
> > > Suppose we have detach race with a concurrent RCU lookup like:
> > >
> > >                                         vma = mas_lookup();
> > >
> > >         vma_start_write();
> > >         mas_detach();
> > >                                         vma_start_read()
> > >                                         rwsem_acquire_read()
> > >                                         inc // success
> > >         vma_mark_detach();
> > >         dec_and_test // assumes 1->0
> > >                      // is actually 2->1
> > >
> > >                                         if (vm_lock_seq == vma->vm_mm_mm_lock_seq) // true
> > >                                           vma_refcount_put
> > >                                             dec_and_test() // 1->0
> > >                                               *NO* rwsem_release()
> > >
> >
> > Yes, this is possible. I think that's not a problem until we start
> > reusing the vmas and I deal with this race later in this patchset.
> > I think what you described here is the same race I mention in the
> > description of this patch:
> > https://lore.kernel.org/all/20241216192419.2970941-14-surenb@google.com/
> > I introduce vma_ensure_detached() in that patch to handle this case
> > and ensure that vmas are detached before they are returned into the
> > slab cache for reuse. Does that make sense?
>
> So I just replied there, and no, I don't think it makes sense. Just put
> the kmem_cache_free() in vma_refcount_put(), to be done on 0.

That's very appealing indeed and makes things much simpler. The
problem I see with that is the case when we detach a vma from the tree
to isolate it, then do some cleanup and only then free it. That's done
in vms_gather_munmap_vmas() here:
https://elixir.bootlin.com/linux/v6.12.5/source/mm/vma.c#L1240 and we
even might reattach detached vmas back:
https://elixir.bootlin.com/linux/v6.12.5/source/mm/vma.c#L1312. IOW,
detached state is not final and we can't destroy the object that
reached this state. We could change states to: 0=unused (we can free
the object), 1=detached, 2=attached, etc. but then vma_start_read()
should do something like refcount_inc_more_than_one() instead of
refcount_inc_not_zero(). Would you be ok with such an approach?
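
A hypothetical sketch of such a helper (illustrative only; it is not an
existing refcount_t API and it pokes at refcount_t internals, which is
part of why atomic_t may be the better fit here):

	static inline bool refcount_inc_more_than_one(refcount_t *r)
	{
		int old = atomic_read(&r->refs);

		/* only take a reference when the object is attached (> 1) */
		do {
			if (old <= 1)
				return false;
		} while (!atomic_try_cmpxchg_relaxed(&r->refs, &old, old + 1));

		return true;
	}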

>
> Anyway, my point was more about the weird entanglement of lockdep and
> the refcount. Just pull the lockdep annotation out of _put() and put it
> explicitly in the vma_start_read() error paths and vma_end_read().

Ok, I think that's easy.

>
> Additionally, having vma_end_write() would allow you to put a lockdep
> annotation in vma_{start,end}_write() -- which was I think the original
> reason I proposed it a while back, that and having improved clarity when
> reading the code, since explicitly marking the end of a section is
> helpful.

The vma->vmlock_dep_map is tracking vma->vm_refcnt, not
vma->vm_lock_seq (similar to how, today, vma->vm_lock has its lockdep map
tracking that rw_semaphore). If I implement vma_end_write() then it
will simply be something like:

void vma_end_write(vma)
{
         vma_assert_write_locked(vma);
         vma->vm_lock_seq = UINT_MAX;
}

so, vmlock_dep_map would not be involved.

If you want to track vma->vm_lock_seq with a separate lockdep, that
would be more complicated. Specifically, for vma_end_write_all() that
would require us to call rwsem_release() on all locked vmas; however,
we currently do not track individual locked vmas. vma_end_write_all()
allows us not to worry about tracking them, knowing that once we do
mmap_write_unlock() they all will get unlocked with one increment of
mm->mm_lock_seq. If your suggestion is to replace vma_end_write_all()
with vma_end_write() and unlock vmas individually across the mm code,
that would be a sizable effort. If that is indeed your ultimate goal,
I can do that as a separate project: introduce vma_end_write(),
gradually add them in required places (not yet sure how complex that
would be), then retire vma_end_write_all() and add a lockdep for
vma->vm_lock_seq.
Peter Zijlstra Dec. 18, 2024, 9:41 a.m. UTC | #10
On Tue, Dec 17, 2024 at 08:27:46AM -0800, Suren Baghdasaryan wrote:

> > So I just replied there, and no, I don't think it makes sense. Just put
> > the kmem_cache_free() in vma_refcount_put(), to be done on 0.
> 
> That's very appealing indeed and makes things much simpler. The
> problem I see with that is the case when we detach a vma from the tree
> to isolate it, then do some cleanup and only then free it. That's done
> in vms_gather_munmap_vmas() here:
> https://elixir.bootlin.com/linux/v6.12.5/source/mm/vma.c#L1240 and we
> even might reattach detached vmas back:
> https://elixir.bootlin.com/linux/v6.12.5/source/mm/vma.c#L1312. IOW,
> detached state is not final and we can't destroy the object that
> reached this state. 

Urgh, so that's the munmap() path, but arguably when that fails, the
map stays in place.

I think this means you're marking detached too soon; you should only
mark detached once you reach the point of no return.

That said, once you've reached the point of no return; and are about to
go remove the page-tables, you very much want to ensure a lack of
concurrency.

So perhaps waiting for out-standing readers at this point isn't crazy.

Also, I'm having a very hard time reading this maple tree stuff :/
Afaict vms_gather_munmap_vmas() only adds the VMAs to be removed to a
second tree; it does not in fact unlink them from the mm yet.

AFAICT it's vma_iter_clear_gfp() that actually wipes the vmas from the
mm -- and that being able to fail is mind-boggling and I suppose is what
gives rise to much of this insanity :/

Anyway, I would expect remove_vma() to be the one that marks it detached
(it's already unreachable through vma_lookup() at this point) and there
you should wait for concurrent readers to bugger off.

> We could change states to: 0=unused (we can free
> the object), 1=detached, 2=attached, etc. but then vma_start_read()
> should do something like refcount_inc_more_than_one() instead of
> refcount_inc_not_zero(). Would you be ok with such an approach?

Urgh, I would strongly suggest ditching refcount_t if we go this route.
The thing is, refcount_t should remain a 'simple', straightforward
interface and not allow people to do the wrong thing. It's not meant to
be the kitchen sink -- we have atomic_t for that.

Anyway, the more common scheme at that point is using -1 for 'free'; I
think folio->_mapcount even uses that. For that, see
atomic_add_negative*().
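
For illustration, the -1 == 'free' convention could look roughly like this
(hypothetical, with vm_refcnt as a plain atomic_t instead of refcount_t):

	/*
	 *  -1 : free / reusable
	 *   0 : detached
	 *   1 : attached, unlocked
	 *  >1 : attached, with readers
	 */
	static inline void vma_refcnt_put(struct vm_area_struct *vma)
	{
		/* the 0 -> -1 transition means nobody can reach the vma anymore */
		if (atomic_add_negative(-1, &vma->vm_refcnt))
			kmem_cache_free(vm_area_cachep, vma);
	}

The actual free path would of course still need to be RCU-deferred, as it
is today.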

> > Additionally, having vma_end_write() would allow you to put a lockdep
> > annotation in vma_{start,end}_write() -- which was I think the original
> > reason I proposed it a while back, that and having improved clarity when
> > reading the code, since explicitly marking the end of a section is
> > helpful.
> 
> The vma->vmlock_dep_map is tracking vma->vm_refcnt, not the
> vma->vm_lock_seq (similar to how today vma->vm_lock has its lockdep
> tracking that rw_semaphore). If I implement vma_end_write() then it
> will simply be something like:
> 
> void vma_end_write(vma)
> {
>          vma_assert_write_locked(vma);
>          vma->vm_lock_seq = UINT_MAX;
> }
> 
> so, vmlock_dep_map would not be involved.

That's just weird; why would you not track vma_{start,end}_write() with
the exclusive side of the 'rwsem' dep_map?

> If you want to track vma->vm_lock_seq with a separate lockdep, that
> would be more complicated. Specifically for vma_end_write_all() that
> would require us to call rwsem_release() on all locked vmas, however
> we currently do not track individual locked vmas. vma_end_write_all()
> allows us not to worry about tracking them, knowing that once we do
> mmap_write_unlock() they all will get unlocked with one increment of
> mm->mm_lock_seq. If your suggestion is to replace vma_end_write_all()
> with vma_end_write() and unlock vmas individually across the mm code,
> that would be a sizable effort. If that is indeed your ultimate goal,
> I can do that as a separate project: introduce vma_end_write(),
> gradually add them in required places (not yet sure how complex that
> would be), then retire vma_end_write_all() and add a lockdep for
> vma->vm_lock_seq.

Yeah, so ultimately I think it would be clearer if you explicitly mark
the point where the vma modification is 'done'. But I don't suppose we
have to do that here.

Patch

diff --git a/include/linux/mm.h b/include/linux/mm.h
index ccb8f2afeca8..d9edabc385b3 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -32,6 +32,7 @@ 
 #include <linux/memremap.h>
 #include <linux/slab.h>
 #include <linux/cacheinfo.h>
+#include <linux/rcuwait.h>
 
 struct mempolicy;
 struct anon_vma;
@@ -699,10 +700,27 @@  static inline void vma_numab_state_free(struct vm_area_struct *vma) {}
 #ifdef CONFIG_PER_VMA_LOCK
 static inline void vma_lock_init(struct vm_area_struct *vma)
 {
-	init_rwsem(&vma->vm_lock.lock);
+#ifdef CONFIG_DEBUG_LOCK_ALLOC
+	static struct lock_class_key lockdep_key;
+
+	lockdep_init_map(&vma->vmlock_dep_map, "vm_lock", &lockdep_key, 0);
+#endif
+	refcount_set(&vma->vm_refcnt, VMA_STATE_DETACHED);
 	vma->vm_lock_seq = UINT_MAX;
 }
 
+static inline void vma_refcount_put(struct vm_area_struct *vma)
+{
+	int refcnt;
+
+	if (!__refcount_dec_and_test(&vma->vm_refcnt, &refcnt)) {
+		rwsem_release(&vma->vmlock_dep_map, _RET_IP_);
+
+		if (refcnt & VMA_STATE_LOCKED)
+			rcuwait_wake_up(&vma->vm_mm->vma_writer_wait);
+	}
+}
+
 /*
  * Try to read-lock a vma. The function is allowed to occasionally yield false
  * locked result to avoid performance overhead, in which case we fall back to
@@ -710,6 +728,8 @@  static inline void vma_lock_init(struct vm_area_struct *vma)
  */
 static inline bool vma_start_read(struct vm_area_struct *vma)
 {
+	int oldcnt;
+
 	/*
 	 * Check before locking. A race might cause false locked result.
 	 * We can use READ_ONCE() for the mm_lock_seq here, and don't need
@@ -720,13 +740,20 @@  static inline bool vma_start_read(struct vm_area_struct *vma)
 	if (READ_ONCE(vma->vm_lock_seq) == READ_ONCE(vma->vm_mm->mm_lock_seq.sequence))
 		return false;
 
-	if (unlikely(down_read_trylock(&vma->vm_lock.lock) == 0))
+
+	rwsem_acquire_read(&vma->vmlock_dep_map, 0, 0, _RET_IP_);
+	/* Limit at VMA_STATE_LOCKED - 2 to leave one count for a writer */
+	if (unlikely(!__refcount_inc_not_zero_limited(&vma->vm_refcnt, &oldcnt,
+						      VMA_STATE_LOCKED - 2))) {
+		rwsem_release(&vma->vmlock_dep_map, _RET_IP_);
 		return false;
+	}
+	lock_acquired(&vma->vmlock_dep_map, _RET_IP_);
 
 	/*
-	 * Overflow might produce false locked result.
+	 * Overflow of vm_lock_seq/mm_lock_seq might produce false locked result.
 	 * False unlocked result is impossible because we modify and check
-	 * vma->vm_lock_seq under vma->vm_lock protection and mm->mm_lock_seq
+	 * vma->vm_lock_seq under vma->vm_refcnt protection and mm->mm_lock_seq
 	 * modification invalidates all existing locks.
 	 *
 	 * We must use ACQUIRE semantics for the mm_lock_seq so that if we are
@@ -734,10 +761,12 @@  static inline bool vma_start_read(struct vm_area_struct *vma)
 	 * after it has been unlocked.
 	 * This pairs with RELEASE semantics in vma_end_write_all().
 	 */
-	if (unlikely(vma->vm_lock_seq == raw_read_seqcount(&vma->vm_mm->mm_lock_seq))) {
-		up_read(&vma->vm_lock.lock);
+	if (oldcnt & VMA_STATE_LOCKED ||
+	    unlikely(vma->vm_lock_seq == raw_read_seqcount(&vma->vm_mm->mm_lock_seq))) {
+		vma_refcount_put(vma);
 		return false;
 	}
+
 	return true;
 }
 
@@ -749,8 +778,17 @@  static inline bool vma_start_read(struct vm_area_struct *vma)
  */
 static inline bool vma_start_read_locked_nested(struct vm_area_struct *vma, int subclass)
 {
+	int oldcnt;
+
 	mmap_assert_locked(vma->vm_mm);
-	down_read_nested(&vma->vm_lock.lock, subclass);
+	rwsem_acquire_read(&vma->vmlock_dep_map, subclass, 0, _RET_IP_);
+	/* Limit at VMA_STATE_LOCKED - 2 to leave one count for a writer */
+	if (unlikely(!__refcount_inc_not_zero_limited(&vma->vm_refcnt, &oldcnt,
+						      VMA_STATE_LOCKED - 2))) {
+		rwsem_release(&vma->vmlock_dep_map, _RET_IP_);
+		return false;
+	}
+	lock_acquired(&vma->vmlock_dep_map, _RET_IP_);
 	return true;
 }
 
@@ -762,15 +800,13 @@  static inline bool vma_start_read_locked_nested(struct vm_area_struct *vma, int
  */
 static inline bool vma_start_read_locked(struct vm_area_struct *vma)
 {
-	mmap_assert_locked(vma->vm_mm);
-	down_read(&vma->vm_lock.lock);
-	return true;
+	return vma_start_read_locked_nested(vma, 0);
 }
 
 static inline void vma_end_read(struct vm_area_struct *vma)
 {
 	rcu_read_lock(); /* keeps vma alive till the end of up_read */
-	up_read(&vma->vm_lock.lock);
+	vma_refcount_put(vma);
 	rcu_read_unlock();
 }
 
@@ -813,25 +849,42 @@  static inline void vma_assert_write_locked(struct vm_area_struct *vma)
 
 static inline void vma_assert_locked(struct vm_area_struct *vma)
 {
-	if (!rwsem_is_locked(&vma->vm_lock.lock))
+	if (refcount_read(&vma->vm_refcnt) <= VMA_STATE_ATTACHED)
 		vma_assert_write_locked(vma);
 }
 
-static inline void vma_mark_attached(struct vm_area_struct *vma)
+/*
+ * WARNING: to avoid racing with vma_mark_attached(), should be called either
+ * under mmap_write_lock or when the object has been isolated under
+ * mmap_write_lock, ensuring no competing writers.
+ */
+static inline bool is_vma_detached(struct vm_area_struct *vma)
 {
-	vma->detached = false;
+	return refcount_read(&vma->vm_refcnt) == VMA_STATE_DETACHED;
 }
 
-static inline void vma_mark_detached(struct vm_area_struct *vma)
+static inline void vma_mark_attached(struct vm_area_struct *vma)
 {
-	/* When detaching vma should be write-locked */
 	vma_assert_write_locked(vma);
-	vma->detached = true;
+
+	if (is_vma_detached(vma))
+		refcount_set(&vma->vm_refcnt, VMA_STATE_ATTACHED);
 }
 
-static inline bool is_vma_detached(struct vm_area_struct *vma)
+static inline void vma_mark_detached(struct vm_area_struct *vma)
 {
-	return vma->detached;
+	vma_assert_write_locked(vma);
+
+	if (is_vma_detached(vma))
+		return;
+
+	/* We are the only writer, so no need to use vma_refcount_put(). */
+	if (!refcount_dec_and_test(&vma->vm_refcnt)) {
+		/*
+		 * Reader must have temporarily raised vm_refcnt but it will
+		 * drop it without using the vma since vma is write-locked.
+		 */
+	}
 }
 
 static inline void release_fault_lock(struct vm_fault *vmf)
@@ -896,10 +949,6 @@  static inline void vma_init(struct vm_area_struct *vma, struct mm_struct *mm)
 	vma->vm_mm = mm;
 	vma->vm_ops = &vma_dummy_vm_ops;
 	INIT_LIST_HEAD(&vma->anon_vma_chain);
-#ifdef CONFIG_PER_VMA_LOCK
-	/* vma is not locked, can't use vma_mark_detached() */
-	vma->detached = true;
-#endif
 	vma_numab_state_init(vma);
 	vma_lock_init(vma);
 }
diff --git a/include/linux/mm_types.h b/include/linux/mm_types.h
index 825f6328f9e5..803f718c007c 100644
--- a/include/linux/mm_types.h
+++ b/include/linux/mm_types.h
@@ -19,6 +19,7 @@ 
 #include <linux/workqueue.h>
 #include <linux/seqlock.h>
 #include <linux/percpu_counter.h>
+#include <linux/types.h>
 
 #include <asm/mmu.h>
 
@@ -599,9 +600,9 @@  static inline struct anon_vma_name *anon_vma_name_alloc(const char *name)
 }
 #endif
 
-struct vma_lock {
-	struct rw_semaphore lock;
-};
+#define VMA_STATE_DETACHED	0x0
+#define VMA_STATE_ATTACHED	0x1
+#define VMA_STATE_LOCKED	0x40000000
 
 struct vma_numab_state {
 	/*
@@ -679,19 +680,13 @@  struct vm_area_struct {
 	};
 
 #ifdef CONFIG_PER_VMA_LOCK
-	/*
-	 * Flag to indicate areas detached from the mm->mm_mt tree.
-	 * Unstable RCU readers are allowed to read this.
-	 */
-	bool detached;
-
 	/*
 	 * Can only be written (using WRITE_ONCE()) while holding both:
 	 *  - mmap_lock (in write mode)
-	 *  - vm_lock->lock (in write mode)
+	 *  - vm_refcnt VMA_STATE_LOCKED is set
 	 * Can be read reliably while holding one of:
 	 *  - mmap_lock (in read or write mode)
-	 *  - vm_lock->lock (in read or write mode)
+	 *  - vm_refcnt VMA_STATE_LOCKED is set or vm_refcnt > VMA_STATE_ATTACHED
 	 * Can be read unreliably (using READ_ONCE()) for pessimistic bailout
 	 * while holding nothing (except RCU to keep the VMA struct allocated).
 	 *
@@ -754,7 +749,10 @@  struct vm_area_struct {
 	struct vm_userfaultfd_ctx vm_userfaultfd_ctx;
 #ifdef CONFIG_PER_VMA_LOCK
 	/* Unstable RCU readers are allowed to read this. */
-	struct vma_lock vm_lock ____cacheline_aligned_in_smp;
+	refcount_t vm_refcnt ____cacheline_aligned_in_smp;
+#ifdef CONFIG_DEBUG_LOCK_ALLOC
+	struct lockdep_map vmlock_dep_map;
+#endif
 #endif
 } __randomize_layout;
 
@@ -889,6 +887,7 @@  struct mm_struct {
 					  * by mmlist_lock
 					  */
 #ifdef CONFIG_PER_VMA_LOCK
+		struct rcuwait vma_writer_wait;
 		/*
 		 * This field has lock-like semantics, meaning it is sometimes
 		 * accessed with ACQUIRE/RELEASE semantics.
diff --git a/kernel/fork.c b/kernel/fork.c
index 8cb19c23e892..283909d082cb 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -465,10 +465,6 @@  struct vm_area_struct *vm_area_dup(struct vm_area_struct *orig)
 	data_race(memcpy(new, orig, sizeof(*new)));
 	vma_lock_init(new);
 	INIT_LIST_HEAD(&new->anon_vma_chain);
-#ifdef CONFIG_PER_VMA_LOCK
-	/* vma is not locked, can't use vma_mark_detached() */
-	new->detached = true;
-#endif
 	vma_numab_state_init(new);
 	dup_anon_vma_name(orig, new);
 
@@ -488,8 +484,6 @@  static void vm_area_free_rcu_cb(struct rcu_head *head)
 	struct vm_area_struct *vma = container_of(head, struct vm_area_struct,
 						  vm_rcu);
 
-	/* The vma should not be locked while being destroyed. */
-	VM_BUG_ON_VMA(rwsem_is_locked(&vma->vm_lock.lock), vma);
 	__vm_area_free(vma);
 }
 #endif
@@ -1228,6 +1222,9 @@  static inline void mmap_init_lock(struct mm_struct *mm)
 {
 	init_rwsem(&mm->mmap_lock);
 	mm_lock_seqcount_init(mm);
+#ifdef CONFIG_PER_VMA_LOCK
+	rcuwait_init(&mm->vma_writer_wait);
+#endif
 }
 
 static struct mm_struct *mm_init(struct mm_struct *mm, struct task_struct *p,
diff --git a/mm/init-mm.c b/mm/init-mm.c
index 6af3ad675930..4600e7605cab 100644
--- a/mm/init-mm.c
+++ b/mm/init-mm.c
@@ -40,6 +40,7 @@  struct mm_struct init_mm = {
 	.arg_lock	=  __SPIN_LOCK_UNLOCKED(init_mm.arg_lock),
 	.mmlist		= LIST_HEAD_INIT(init_mm.mmlist),
 #ifdef CONFIG_PER_VMA_LOCK
+	.vma_writer_wait = __RCUWAIT_INITIALIZER(init_mm.vma_writer_wait),
 	.mm_lock_seq	= SEQCNT_ZERO(init_mm.mm_lock_seq),
 #endif
 	.user_ns	= &init_user_ns,
diff --git a/mm/memory.c b/mm/memory.c
index c6356ea703d8..cff132003e24 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -6331,7 +6331,25 @@  struct vm_area_struct *lock_mm_and_find_vma(struct mm_struct *mm,
 #ifdef CONFIG_PER_VMA_LOCK
 void __vma_start_write(struct vm_area_struct *vma, unsigned int mm_lock_seq)
 {
-	down_write(&vma->vm_lock.lock);
+	bool detached;
+
+	/*
+	 * If vma is detached then only vma_mark_attached() can raise the
+	 * vm_refcnt. mmap_write_lock prevents racing with vma_mark_attached().
+	 */
+	if (!refcount_inc_not_zero(&vma->vm_refcnt)) {
+		WRITE_ONCE(vma->vm_lock_seq, mm_lock_seq);
+		return;
+	}
+
+	rwsem_acquire(&vma->vmlock_dep_map, 0, 0, _RET_IP_);
+	/* vma is attached, set the writer present bit */
+	refcount_add(VMA_STATE_LOCKED, &vma->vm_refcnt);
+	/* wait until state is VMA_STATE_ATTACHED + (VMA_STATE_LOCKED + 1) */
+	rcuwait_wait_event(&vma->vm_mm->vma_writer_wait,
+		   refcount_read(&vma->vm_refcnt) == VMA_STATE_ATTACHED + (VMA_STATE_LOCKED + 1),
+		   TASK_UNINTERRUPTIBLE);
+	lock_acquired(&vma->vmlock_dep_map, _RET_IP_);
 	/*
 	 * We should use WRITE_ONCE() here because we can have concurrent reads
 	 * from the early lockless pessimistic check in vma_start_read().
@@ -6339,7 +6357,10 @@  void __vma_start_write(struct vm_area_struct *vma, unsigned int mm_lock_seq)
 	 * we should use WRITE_ONCE() for cleanliness and to keep KCSAN happy.
 	 */
 	WRITE_ONCE(vma->vm_lock_seq, mm_lock_seq);
-	up_write(&vma->vm_lock.lock);
+	detached = refcount_sub_and_test(VMA_STATE_LOCKED + 1,
+					 &vma->vm_refcnt);
+	rwsem_release(&vma->vmlock_dep_map, _RET_IP_);
+	VM_BUG_ON_VMA(detached, vma); /* vma should remain attached */
 }
 EXPORT_SYMBOL_GPL(__vma_start_write);
 
@@ -6355,7 +6376,6 @@  struct vm_area_struct *lock_vma_under_rcu(struct mm_struct *mm,
 	struct vm_area_struct *vma;
 
 	rcu_read_lock();
-retry:
 	vma = mas_walk(&mas);
 	if (!vma)
 		goto inval;
@@ -6363,13 +6383,6 @@  struct vm_area_struct *lock_vma_under_rcu(struct mm_struct *mm,
 	if (!vma_start_read(vma))
 		goto inval;
 
-	/* Check if the VMA got isolated after we found it */
-	if (is_vma_detached(vma)) {
-		vma_end_read(vma);
-		count_vm_vma_lock_event(VMA_LOCK_MISS);
-		/* The area was replaced with another one */
-		goto retry;
-	}
 	/*
 	 * At this point, we have a stable reference to a VMA: The VMA is
 	 * locked and we know it hasn't already been isolated.
diff --git a/tools/testing/vma/linux/atomic.h b/tools/testing/vma/linux/atomic.h
index e01f66f98982..2e2021553196 100644
--- a/tools/testing/vma/linux/atomic.h
+++ b/tools/testing/vma/linux/atomic.h
@@ -9,4 +9,9 @@ 
 #define atomic_set(x, y) do {} while (0)
 #define U8_MAX UCHAR_MAX
 
+#ifndef atomic_cmpxchg_relaxed
+#define  atomic_cmpxchg_relaxed		uatomic_cmpxchg
+#define  atomic_cmpxchg_release         uatomic_cmpxchg
+#endif /* atomic_cmpxchg_relaxed */
+
 #endif	/* _LINUX_ATOMIC_H */
diff --git a/tools/testing/vma/vma_internal.h b/tools/testing/vma/vma_internal.h
index 0cdc5f8c3d60..b55556b16060 100644
--- a/tools/testing/vma/vma_internal.h
+++ b/tools/testing/vma/vma_internal.h
@@ -25,7 +25,7 @@ 
 #include <linux/maple_tree.h>
 #include <linux/mm.h>
 #include <linux/rbtree.h>
-#include <linux/rwsem.h>
+#include <linux/refcount.h>
 
 extern unsigned long stack_guard_gap;
 #ifdef CONFIG_MMU
@@ -132,10 +132,6 @@  typedef __bitwise unsigned int vm_fault_t;
  */
 #define pr_warn_once pr_err
 
-typedef struct refcount_struct {
-	atomic_t refs;
-} refcount_t;
-
 struct kref {
 	refcount_t refcount;
 };
@@ -228,15 +224,14 @@  struct mm_struct {
 	unsigned long def_flags;
 };
 
-struct vma_lock {
-	struct rw_semaphore lock;
-};
-
-
 struct file {
 	struct address_space	*f_mapping;
 };
 
+#define VMA_STATE_DETACHED	0x0
+#define VMA_STATE_ATTACHED	0x1
+#define VMA_STATE_LOCKED	0x40000000
+
 struct vm_area_struct {
 	/* The first cache line has the info for VMA tree walking. */
 
@@ -264,16 +259,13 @@  struct vm_area_struct {
 	};
 
 #ifdef CONFIG_PER_VMA_LOCK
-	/* Flag to indicate areas detached from the mm->mm_mt tree */
-	bool detached;
-
 	/*
 	 * Can only be written (using WRITE_ONCE()) while holding both:
 	 *  - mmap_lock (in write mode)
-	 *  - vm_lock.lock (in write mode)
+	 *  - vm_refcnt VMA_STATE_LOCKED is set
 	 * Can be read reliably while holding one of:
 	 *  - mmap_lock (in read or write mode)
-	 *  - vm_lock.lock (in read or write mode)
+	 *  - vm_refcnt VMA_STATE_LOCKED is set or vm_refcnt > VMA_STATE_ATTACHED
 	 * Can be read unreliably (using READ_ONCE()) for pessimistic bailout
 	 * while holding nothing (except RCU to keep the VMA struct allocated).
 	 *
@@ -282,7 +274,6 @@  struct vm_area_struct {
 	 * slowpath.
 	 */
 	unsigned int vm_lock_seq;
-	struct vma_lock vm_lock;
 #endif
 
 	/*
@@ -335,6 +326,10 @@  struct vm_area_struct {
 	struct vma_numab_state *numab_state;	/* NUMA Balancing state */
 #endif
 	struct vm_userfaultfd_ctx vm_userfaultfd_ctx;
+#ifdef CONFIG_PER_VMA_LOCK
+	/* Unstable RCU readers are allowed to read this. */
+	refcount_t vm_refcnt;
+#endif
 } __randomize_layout;
 
 struct vm_fault {};
@@ -461,21 +456,37 @@  static inline struct vm_area_struct *vma_next(struct vma_iterator *vmi)
 
 static inline void vma_lock_init(struct vm_area_struct *vma)
 {
-	init_rwsem(&vma->vm_lock.lock);
+	refcount_set(&vma->vm_refcnt, VMA_STATE_DETACHED);
 	vma->vm_lock_seq = UINT_MAX;
 }
 
-static inline void vma_mark_attached(struct vm_area_struct *vma)
+static inline bool is_vma_detached(struct vm_area_struct *vma)
 {
-	vma->detached = false;
+	return refcount_read(&vma->vm_refcnt) == VMA_STATE_DETACHED;
 }
 
 static inline void vma_assert_write_locked(struct vm_area_struct *);
+static inline void vma_mark_attached(struct vm_area_struct *vma)
+{
+	vma_assert_write_locked(vma);
+
+	if (is_vma_detached(vma))
+		refcount_set(&vma->vm_refcnt, VMA_STATE_ATTACHED);
+}
+
 static inline void vma_mark_detached(struct vm_area_struct *vma)
 {
-	/* When detaching vma should be write-locked */
 	vma_assert_write_locked(vma);
-	vma->detached = true;
+
+	if (is_vma_detached(vma))
+		return;
+
+	if (!refcount_dec_and_test(&vma->vm_refcnt)) {
+		/*
+		 * Reader must have temporarily raised vm_refcnt but it will
+		 * drop it without using the vma since vma is write-locked.
+		 */
+	}
 }
 
 extern const struct vm_operations_struct vma_dummy_vm_ops;
@@ -488,8 +499,6 @@  static inline void vma_init(struct vm_area_struct *vma, struct mm_struct *mm)
 	vma->vm_mm = mm;
 	vma->vm_ops = &vma_dummy_vm_ops;
 	INIT_LIST_HEAD(&vma->anon_vma_chain);
-	/* vma is not locked, can't use vma_mark_detached() */
-	vma->detached = true;
 	vma_lock_init(vma);
 }
 
@@ -515,8 +524,6 @@  static inline struct vm_area_struct *vm_area_dup(struct vm_area_struct *orig)
 	memcpy(new, orig, sizeof(*new));
 	vma_lock_init(new);
 	INIT_LIST_HEAD(&new->anon_vma_chain);
-	/* vma is not locked, can't use vma_mark_detached() */
-	new->detached = true;
 
 	return new;
 }