diff mbox series

[07/14] fs: teach the mm about range locking

Message ID 20190521045242.24378-8-dave@stgolabs.net (mailing list archive)
State New, archived
Headers show
Series mmap_sem range locking | expand

Commit Message

Davidlohr Bueso May 21, 2019, 4:52 a.m. UTC
Conversion is straightforward, mmap_sem is used within the
the same function context most of the time. No change in
semantics.

Signed-off-by: Davidlohr Bueso <dbueso@suse.de>
---
 fs/aio.c                      |  5 +++--
 fs/coredump.c                 |  5 +++--
 fs/exec.c                     | 19 +++++++++-------
 fs/io_uring.c                 |  5 +++--
 fs/proc/base.c                | 23 ++++++++++++--------
 fs/proc/internal.h            |  2 ++
 fs/proc/task_mmu.c            | 32 +++++++++++++++------------
 fs/proc/task_nommu.c          | 22 +++++++++++--------
 fs/userfaultfd.c              | 50 ++++++++++++++++++++++++++-----------------
 include/linux/userfaultfd_k.h |  5 +++--
 10 files changed, 100 insertions(+), 68 deletions(-)
diff mbox series

Patch

diff --git a/fs/aio.c b/fs/aio.c
index 3490d1fa0e16..215d19dbbefa 100644
--- a/fs/aio.c
+++ b/fs/aio.c
@@ -461,6 +461,7 @@  static const struct address_space_operations aio_ctx_aops = {
 
 static int aio_setup_ring(struct kioctx *ctx, unsigned int nr_events)
 {
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 	struct aio_ring *ring;
 	struct mm_struct *mm = current->mm;
 	unsigned long size, unused;
@@ -521,7 +522,7 @@  static int aio_setup_ring(struct kioctx *ctx, unsigned int nr_events)
 	ctx->mmap_size = nr_pages * PAGE_SIZE;
 	pr_debug("attempting mmap of %lu bytes\n", ctx->mmap_size);
 
-	if (down_write_killable(&mm->mmap_sem)) {
+	if (mm_write_lock_killable(mm, &mmrange)) {
 		ctx->mmap_size = 0;
 		aio_free_ring(ctx);
 		return -EINTR;
@@ -530,7 +531,7 @@  static int aio_setup_ring(struct kioctx *ctx, unsigned int nr_events)
 	ctx->mmap_base = do_mmap_pgoff(ctx->aio_ring_file, 0, ctx->mmap_size,
 				       PROT_READ | PROT_WRITE,
 				       MAP_SHARED, 0, &unused, NULL);
-	up_write(&mm->mmap_sem);
+	mm_write_unlock(mm, &mmrange);
 	if (IS_ERR((void *)ctx->mmap_base)) {
 		ctx->mmap_size = 0;
 		aio_free_ring(ctx);
diff --git a/fs/coredump.c b/fs/coredump.c
index e42e17e55bfd..433713b63187 100644
--- a/fs/coredump.c
+++ b/fs/coredump.c
@@ -409,6 +409,7 @@  static int zap_threads(struct task_struct *tsk, struct mm_struct *mm,
 
 static int coredump_wait(int exit_code, struct core_state *core_state)
 {
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 	struct task_struct *tsk = current;
 	struct mm_struct *mm = tsk->mm;
 	int core_waiters = -EBUSY;
@@ -417,12 +418,12 @@  static int coredump_wait(int exit_code, struct core_state *core_state)
 	core_state->dumper.task = tsk;
 	core_state->dumper.next = NULL;
 
-	if (down_write_killable(&mm->mmap_sem))
+	if (mm_write_lock_killable(mm, &mmrange))
 		return -EINTR;
 
 	if (!mm->core_state)
 		core_waiters = zap_threads(tsk, mm, core_state, exit_code);
-	up_write(&mm->mmap_sem);
+	mm_write_unlock(mm, &mmrange);
 
 	if (core_waiters > 0) {
 		struct core_thread *ptr;
diff --git a/fs/exec.c b/fs/exec.c
index e96fd5328739..fbcb36bc4fd1 100644
--- a/fs/exec.c
+++ b/fs/exec.c
@@ -241,6 +241,7 @@  static void flush_arg_page(struct linux_binprm *bprm, unsigned long pos,
 
 static int __bprm_mm_init(struct linux_binprm *bprm)
 {
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 	int err;
 	struct vm_area_struct *vma = NULL;
 	struct mm_struct *mm = bprm->mm;
@@ -250,7 +251,7 @@  static int __bprm_mm_init(struct linux_binprm *bprm)
 		return -ENOMEM;
 	vma_set_anonymous(vma);
 
-	if (down_write_killable(&mm->mmap_sem)) {
+	if (mm_write_lock_killable(mm, &mmrange)) {
 		err = -EINTR;
 		goto err_free;
 	}
@@ -273,11 +274,11 @@  static int __bprm_mm_init(struct linux_binprm *bprm)
 
 	mm->stack_vm = mm->total_vm = 1;
 	arch_bprm_mm_init(mm, vma);
-	up_write(&mm->mmap_sem);
+	mm_write_unlock(mm, &mmrange);
 	bprm->p = vma->vm_end - sizeof(void *);
 	return 0;
 err:
-	up_write(&mm->mmap_sem);
+	mm_write_unlock(mm, &mmrange);
 err_free:
 	bprm->vma = NULL;
 	vm_area_free(vma);
@@ -691,6 +692,7 @@  int setup_arg_pages(struct linux_binprm *bprm,
 		    unsigned long stack_top,
 		    int executable_stack)
 {
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 	unsigned long ret;
 	unsigned long stack_shift;
 	struct mm_struct *mm = current->mm;
@@ -738,7 +740,7 @@  int setup_arg_pages(struct linux_binprm *bprm,
 		bprm->loader -= stack_shift;
 	bprm->exec -= stack_shift;
 
-	if (down_write_killable(&mm->mmap_sem))
+	if (mm_write_lock_killable(mm, &mmrange))
 		return -EINTR;
 
 	vm_flags = VM_STACK_FLAGS;
@@ -795,7 +797,7 @@  int setup_arg_pages(struct linux_binprm *bprm,
 		ret = -EFAULT;
 
 out_unlock:
-	up_write(&mm->mmap_sem);
+	mm_write_unlock(mm, &mmrange);
 	return ret;
 }
 EXPORT_SYMBOL(setup_arg_pages);
@@ -1010,6 +1012,7 @@  static int exec_mmap(struct mm_struct *mm)
 {
 	struct task_struct *tsk;
 	struct mm_struct *old_mm, *active_mm;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
 	/* Notify parent that we're no longer interested in the old VM */
 	tsk = current;
@@ -1024,9 +1027,9 @@  static int exec_mmap(struct mm_struct *mm)
 		 * through with the exec.  We must hold mmap_sem around
 		 * checking core_state and changing tsk->mm.
 		 */
-		down_read(&old_mm->mmap_sem);
+		mm_read_lock(old_mm, &mmrange);
 		if (unlikely(old_mm->core_state)) {
-			up_read(&old_mm->mmap_sem);
+			mm_read_unlock(old_mm, &mmrange);
 			return -EINTR;
 		}
 	}
@@ -1039,7 +1042,7 @@  static int exec_mmap(struct mm_struct *mm)
 	vmacache_flush(tsk);
 	task_unlock(tsk);
 	if (old_mm) {
-		up_read(&old_mm->mmap_sem);
+		mm_read_unlock(old_mm, &mmrange);
 		BUG_ON(active_mm != old_mm);
 		setmax_mm_hiwater_rss(&tsk->signal->maxrss, old_mm);
 		mm_update_next_owner(old_mm);
diff --git a/fs/io_uring.c b/fs/io_uring.c
index e11d77181398..16c06811193b 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -2597,6 +2597,7 @@  static int io_sqe_buffer_register(struct io_ring_ctx *ctx, void __user *arg,
 	struct page **pages = NULL;
 	int i, j, got_pages = 0;
 	int ret = -EINVAL;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
 	if (ctx->user_bufs)
 		return -EBUSY;
@@ -2671,7 +2672,7 @@  static int io_sqe_buffer_register(struct io_ring_ctx *ctx, void __user *arg,
 		}
 
 		ret = 0;
-		down_read(&current->mm->mmap_sem);
+		mm_read_lock(current->mm, &mmrange);
 		pret = get_user_pages(ubuf, nr_pages,
 				      FOLL_WRITE | FOLL_LONGTERM,
 				      pages, vmas);
@@ -2689,7 +2690,7 @@  static int io_sqe_buffer_register(struct io_ring_ctx *ctx, void __user *arg,
 		} else {
 			ret = pret < 0 ? pret : -EFAULT;
 		}
-		up_read(&current->mm->mmap_sem);
+		mm_read_unlock(current->mm, &mmrange);
 		if (ret) {
 			/*
 			 * if we did partial map, or found file backed vmas,
diff --git a/fs/proc/base.c b/fs/proc/base.c
index 9c8ca6cd3ce4..63d0fea104af 100644
--- a/fs/proc/base.c
+++ b/fs/proc/base.c
@@ -1962,9 +1962,11 @@  static int map_files_d_revalidate(struct dentry *dentry, unsigned int flags)
 		goto out;
 
 	if (!dname_to_vma_addr(dentry, &vm_start, &vm_end)) {
-		down_read(&mm->mmap_sem);
+		DEFINE_RANGE_LOCK_FULL(mmrange);
+
+		mm_read_lock(mm, &mmrange);
 		exact_vma_exists = !!find_exact_vma(mm, vm_start, vm_end);
-		up_read(&mm->mmap_sem);
+		mm_read_unlock(mm, &mmrange);
 	}
 
 	mmput(mm);
@@ -1995,6 +1997,7 @@  static int map_files_get_link(struct dentry *dentry, struct path *path)
 	struct task_struct *task;
 	struct mm_struct *mm;
 	int rc;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
 	rc = -ENOENT;
 	task = get_proc_task(d_inode(dentry));
@@ -2011,14 +2014,14 @@  static int map_files_get_link(struct dentry *dentry, struct path *path)
 		goto out_mmput;
 
 	rc = -ENOENT;
-	down_read(&mm->mmap_sem);
+	mm_read_lock(mm, &mmrange);
 	vma = find_exact_vma(mm, vm_start, vm_end);
 	if (vma && vma->vm_file) {
 		*path = vma->vm_file->f_path;
 		path_get(path);
 		rc = 0;
 	}
-	up_read(&mm->mmap_sem);
+	mm_read_unlock(mm, &mmrange);
 
 out_mmput:
 	mmput(mm);
@@ -2089,6 +2092,7 @@  static struct dentry *proc_map_files_lookup(struct inode *dir,
 	struct task_struct *task;
 	struct dentry *result;
 	struct mm_struct *mm;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
 	result = ERR_PTR(-ENOENT);
 	task = get_proc_task(dir);
@@ -2107,7 +2111,7 @@  static struct dentry *proc_map_files_lookup(struct inode *dir,
 	if (!mm)
 		goto out_put_task;
 
-	down_read(&mm->mmap_sem);
+	mm_read_lock(mm, &mmrange);
 	vma = find_exact_vma(mm, vm_start, vm_end);
 	if (!vma)
 		goto out_no_vma;
@@ -2117,7 +2121,7 @@  static struct dentry *proc_map_files_lookup(struct inode *dir,
 				(void *)(unsigned long)vma->vm_file->f_mode);
 
 out_no_vma:
-	up_read(&mm->mmap_sem);
+	mm_read_unlock(mm, &mmrange);
 	mmput(mm);
 out_put_task:
 	put_task_struct(task);
@@ -2141,6 +2145,7 @@  proc_map_files_readdir(struct file *file, struct dir_context *ctx)
 	GENRADIX(struct map_files_info) fa;
 	struct map_files_info *p;
 	int ret;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
 	genradix_init(&fa);
 
@@ -2160,7 +2165,7 @@  proc_map_files_readdir(struct file *file, struct dir_context *ctx)
 	mm = get_task_mm(task);
 	if (!mm)
 		goto out_put_task;
-	down_read(&mm->mmap_sem);
+	mm_read_lock(mm, &mmrange);
 
 	nr_files = 0;
 
@@ -2183,7 +2188,7 @@  proc_map_files_readdir(struct file *file, struct dir_context *ctx)
 		p = genradix_ptr_alloc(&fa, nr_files++, GFP_KERNEL);
 		if (!p) {
 			ret = -ENOMEM;
-			up_read(&mm->mmap_sem);
+			mm_read_unlock(mm, &mmrange);
 			mmput(mm);
 			goto out_put_task;
 		}
@@ -2192,7 +2197,7 @@  proc_map_files_readdir(struct file *file, struct dir_context *ctx)
 		p->end = vma->vm_end;
 		p->mode = vma->vm_file->f_mode;
 	}
-	up_read(&mm->mmap_sem);
+	mm_read_unlock(mm, &mmrange);
 	mmput(mm);
 
 	for (i = 0; i < nr_files; i++) {
diff --git a/fs/proc/internal.h b/fs/proc/internal.h
index d1671e97f7fe..df6f0ec84a8f 100644
--- a/fs/proc/internal.h
+++ b/fs/proc/internal.h
@@ -15,6 +15,7 @@ 
 #include <linux/spinlock.h>
 #include <linux/atomic.h>
 #include <linux/binfmts.h>
+#include <linux/range_lock.h>
 #include <linux/sched/coredump.h>
 #include <linux/sched/task.h>
 
@@ -287,6 +288,7 @@  struct proc_maps_private {
 #ifdef CONFIG_NUMA
 	struct mempolicy *task_mempolicy;
 #endif
+	struct range_lock mmrange;
 } __randomize_layout;
 
 struct mm_struct *proc_mem_open(struct inode *inode, unsigned int mode);
diff --git a/fs/proc/task_mmu.c b/fs/proc/task_mmu.c
index a1c2ad9f960a..7ab5c6f5b8aa 100644
--- a/fs/proc/task_mmu.c
+++ b/fs/proc/task_mmu.c
@@ -128,7 +128,7 @@  static void vma_stop(struct proc_maps_private *priv)
 	struct mm_struct *mm = priv->mm;
 
 	release_task_mempolicy(priv);
-	up_read(&mm->mmap_sem);
+	mm_read_unlock(mm, &priv->mmrange);
 	mmput(mm);
 }
 
@@ -166,7 +166,9 @@  static void *m_start(struct seq_file *m, loff_t *ppos)
 	if (!mm || !mmget_not_zero(mm))
 		return NULL;
 
-	down_read(&mm->mmap_sem);
+	range_lock_init_full(&priv->mmrange);
+
+	mm_read_lock(mm, &priv->mmrange);
 	hold_task_mempolicy(priv);
 	priv->tail_vma = get_gate_vma(mm);
 
@@ -828,7 +830,7 @@  static int show_smaps_rollup(struct seq_file *m, void *v)
 
 	memset(&mss, 0, sizeof(mss));
 
-	down_read(&mm->mmap_sem);
+	mm_read_lock(mm, &priv->mmrange);
 	hold_task_mempolicy(priv);
 
 	for (vma = priv->mm->mmap; vma; vma = vma->vm_next) {
@@ -844,7 +846,7 @@  static int show_smaps_rollup(struct seq_file *m, void *v)
 	__show_smap(m, &mss);
 
 	release_task_mempolicy(priv);
-	up_read(&mm->mmap_sem);
+	mm_read_unlock(mm, &priv->mmrange);
 	mmput(mm);
 
 out_put_task:
@@ -1080,6 +1082,7 @@  static int clear_refs_test_walk(unsigned long start, unsigned long end,
 static ssize_t clear_refs_write(struct file *file, const char __user *buf,
 				size_t count, loff_t *ppos)
 {
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 	struct task_struct *task;
 	char buffer[PROC_NUMBUF];
 	struct mm_struct *mm;
@@ -1118,7 +1121,7 @@  static ssize_t clear_refs_write(struct file *file, const char __user *buf,
 		};
 
 		if (type == CLEAR_REFS_MM_HIWATER_RSS) {
-			if (down_write_killable(&mm->mmap_sem)) {
+			if (mm_write_lock_killable(mm, &mmrange)) {
 				count = -EINTR;
 				goto out_mm;
 			}
@@ -1128,18 +1131,18 @@  static ssize_t clear_refs_write(struct file *file, const char __user *buf,
 			 * resident set size to this mm's current rss value.
 			 */
 			reset_mm_hiwater_rss(mm);
-			up_write(&mm->mmap_sem);
+			mm_write_unlock(mm, &mmrange);
 			goto out_mm;
 		}
 
-		down_read(&mm->mmap_sem);
+		mm_read_lock(mm, &mmrange);
 		tlb_gather_mmu(&tlb, mm, 0, -1);
 		if (type == CLEAR_REFS_SOFT_DIRTY) {
 			for (vma = mm->mmap; vma; vma = vma->vm_next) {
 				if (!(vma->vm_flags & VM_SOFTDIRTY))
 					continue;
-				up_read(&mm->mmap_sem);
-				if (down_write_killable(&mm->mmap_sem)) {
+				mm_read_unlock(mm, &mmrange);
+				if (mm_write_lock_killable(mm, &mmrange)) {
 					count = -EINTR;
 					goto out_mm;
 				}
@@ -1158,14 +1161,14 @@  static ssize_t clear_refs_write(struct file *file, const char __user *buf,
 					 * failed like if
 					 * get_proc_task() fails?
 					 */
-					up_write(&mm->mmap_sem);
+					mm_write_unlock(mm, &mmrange);
 					goto out_mm;
 				}
 				for (vma = mm->mmap; vma; vma = vma->vm_next) {
 					vma->vm_flags &= ~VM_SOFTDIRTY;
 					vma_set_page_prot(vma);
 				}
-				downgrade_write(&mm->mmap_sem);
+				mm_downgrade_write(mm, &mmrange);
 				break;
 			}
 
@@ -1177,7 +1180,7 @@  static ssize_t clear_refs_write(struct file *file, const char __user *buf,
 		if (type == CLEAR_REFS_SOFT_DIRTY)
 			mmu_notifier_invalidate_range_end(&range);
 		tlb_finish_mmu(&tlb, 0, -1);
-		up_read(&mm->mmap_sem);
+		mm_read_unlock(mm, &mmrange);
 out_mm:
 		mmput(mm);
 	}
@@ -1484,6 +1487,7 @@  static ssize_t pagemap_read(struct file *file, char __user *buf,
 	unsigned long start_vaddr;
 	unsigned long end_vaddr;
 	int ret = 0, copied = 0;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
 	if (!mm || !mmget_not_zero(mm))
 		goto out;
@@ -1539,9 +1543,9 @@  static ssize_t pagemap_read(struct file *file, char __user *buf,
 		/* overflow ? */
 		if (end < start_vaddr || end > end_vaddr)
 			end = end_vaddr;
-		down_read(&mm->mmap_sem);
+		mm_read_lock(mm, &mmrange);
 		ret = walk_page_range(start_vaddr, end, &pagemap_walk);
-		up_read(&mm->mmap_sem);
+		mm_read_unlock(mm, &mmrange);
 		start_vaddr = end;
 
 		len = min(count, PM_ENTRY_BYTES * pm.pos);
diff --git a/fs/proc/task_nommu.c b/fs/proc/task_nommu.c
index 36bf0f2e102e..32bf2860eff3 100644
--- a/fs/proc/task_nommu.c
+++ b/fs/proc/task_nommu.c
@@ -23,9 +23,10 @@  void task_mem(struct seq_file *m, struct mm_struct *mm)
 	struct vm_area_struct *vma;
 	struct vm_region *region;
 	struct rb_node *p;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 	unsigned long bytes = 0, sbytes = 0, slack = 0, size;
         
-	down_read(&mm->mmap_sem);
+	mm_read_lock(mm, &mmrange);
 	for (p = rb_first(&mm->mm_rb); p; p = rb_next(p)) {
 		vma = rb_entry(p, struct vm_area_struct, vm_rb);
 
@@ -77,7 +78,7 @@  void task_mem(struct seq_file *m, struct mm_struct *mm)
 		"Shared:\t%8lu bytes\n",
 		bytes, slack, sbytes);
 
-	up_read(&mm->mmap_sem);
+	mm_read_unlock(mm, &mmrange);
 }
 
 unsigned long task_vsize(struct mm_struct *mm)
@@ -85,13 +86,14 @@  unsigned long task_vsize(struct mm_struct *mm)
 	struct vm_area_struct *vma;
 	struct rb_node *p;
 	unsigned long vsize = 0;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
-	down_read(&mm->mmap_sem);
+	mm_read_lock(mm, &mmrange);
 	for (p = rb_first(&mm->mm_rb); p; p = rb_next(p)) {
 		vma = rb_entry(p, struct vm_area_struct, vm_rb);
 		vsize += vma->vm_end - vma->vm_start;
 	}
-	up_read(&mm->mmap_sem);
+	mm_read_unlock(mm, &mmrange);
 	return vsize;
 }
 
@@ -103,8 +105,9 @@  unsigned long task_statm(struct mm_struct *mm,
 	struct vm_region *region;
 	struct rb_node *p;
 	unsigned long size = kobjsize(mm);
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
-	down_read(&mm->mmap_sem);
+	mm_read_lock(mm, &mmrange);
 	for (p = rb_first(&mm->mm_rb); p; p = rb_next(p)) {
 		vma = rb_entry(p, struct vm_area_struct, vm_rb);
 		size += kobjsize(vma);
@@ -119,7 +122,7 @@  unsigned long task_statm(struct mm_struct *mm,
 		>> PAGE_SHIFT;
 	*data = (PAGE_ALIGN(mm->start_stack) - (mm->start_data & PAGE_MASK))
 		>> PAGE_SHIFT;
-	up_read(&mm->mmap_sem);
+	mm_read_unlock(mm, &mmrange);
 	size >>= PAGE_SHIFT;
 	size += *text + *data;
 	*resident = size;
@@ -201,6 +204,7 @@  static void *m_start(struct seq_file *m, loff_t *pos)
 	struct mm_struct *mm;
 	struct rb_node *p;
 	loff_t n = *pos;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
 	/* pin the task and mm whilst we play with them */
 	priv->task = get_proc_task(priv->inode);
@@ -211,13 +215,13 @@  static void *m_start(struct seq_file *m, loff_t *pos)
 	if (!mm || !mmget_not_zero(mm))
 		return NULL;
 
-	down_read(&mm->mmap_sem);
+	mm_read_lock(mm, &mmrange);
 	/* start from the Nth VMA */
 	for (p = rb_first(&mm->mm_rb); p; p = rb_next(p))
 		if (n-- == 0)
 			return p;
 
-	up_read(&mm->mmap_sem);
+	mm_read_unlock(mm, &mmrange);
 	mmput(mm);
 	return NULL;
 }
@@ -227,7 +231,7 @@  static void m_stop(struct seq_file *m, void *_vml)
 	struct proc_maps_private *priv = m->private;
 
 	if (!IS_ERR_OR_NULL(_vml)) {
-		up_read(&priv->mm->mmap_sem);
+		mm_read_unlock(priv->mm, &mmrange);
 		mmput(priv->mm);
 	}
 	if (priv->task) {
diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c
index 3b30301c90ec..3592f6d71778 100644
--- a/fs/userfaultfd.c
+++ b/fs/userfaultfd.c
@@ -220,13 +220,14 @@  static inline bool userfaultfd_huge_must_wait(struct userfaultfd_ctx *ctx,
 					 struct vm_area_struct *vma,
 					 unsigned long address,
 					 unsigned long flags,
-					 unsigned long reason)
+					 unsigned long reason,
+					 struct range_lock *mmrange)
 {
 	struct mm_struct *mm = ctx->mm;
 	pte_t *ptep, pte;
 	bool ret = true;
 
-	VM_BUG_ON(!rwsem_is_locked(&mm->mmap_sem));
+	VM_BUG_ON(!mm_is_locked(mm, mmrange));
 
 	ptep = huge_pte_offset(mm, address, vma_mmu_pagesize(vma));
 
@@ -252,7 +253,9 @@  static inline bool userfaultfd_huge_must_wait(struct userfaultfd_ctx *ctx,
 					 struct vm_area_struct *vma,
 					 unsigned long address,
 					 unsigned long flags,
-					 unsigned long reason)
+					 unsigned long reason,
+					 struct range_lock *mmrange)
+
 {
 	return false;	/* should never get here */
 }
@@ -268,7 +271,8 @@  static inline bool userfaultfd_huge_must_wait(struct userfaultfd_ctx *ctx,
 static inline bool userfaultfd_must_wait(struct userfaultfd_ctx *ctx,
 					 unsigned long address,
 					 unsigned long flags,
-					 unsigned long reason)
+					 unsigned long reason,
+					 struct range_lock *mmrange)
 {
 	struct mm_struct *mm = ctx->mm;
 	pgd_t *pgd;
@@ -278,7 +282,7 @@  static inline bool userfaultfd_must_wait(struct userfaultfd_ctx *ctx,
 	pte_t *pte;
 	bool ret = true;
 
-	VM_BUG_ON(!rwsem_is_locked(&mm->mmap_sem));
+	VM_BUG_ON(!mm_is_locked(mm, mmrange));
 
 	pgd = pgd_offset(mm, address);
 	if (!pgd_present(*pgd))
@@ -368,7 +372,7 @@  vm_fault_t handle_userfault(struct vm_fault *vmf, unsigned long reason)
 	 * Coredumping runs without mmap_sem so we can only check that
 	 * the mmap_sem is held, if PF_DUMPCORE was not set.
 	 */
-	WARN_ON_ONCE(!rwsem_is_locked(&mm->mmap_sem));
+	WARN_ON_ONCE(!mm_is_locked(mm, vmf->lockrange));
 
 	ctx = vmf->vma->vm_userfaultfd_ctx.ctx;
 	if (!ctx)
@@ -476,12 +480,13 @@  vm_fault_t handle_userfault(struct vm_fault *vmf, unsigned long reason)
 
 	if (!is_vm_hugetlb_page(vmf->vma))
 		must_wait = userfaultfd_must_wait(ctx, vmf->address, vmf->flags,
-						  reason);
+						  reason, vmf->lockrange);
 	else
 		must_wait = userfaultfd_huge_must_wait(ctx, vmf->vma,
 						       vmf->address,
-						       vmf->flags, reason);
-	up_read(&mm->mmap_sem);
+						       vmf->flags, reason,
+						       vmf->lockrange);
+	mm_read_unlock(mm, vmf->lockrange);
 
 	if (likely(must_wait && !READ_ONCE(ctx->released) &&
 		   (return_to_userland ? !signal_pending(current) :
@@ -535,7 +540,7 @@  vm_fault_t handle_userfault(struct vm_fault *vmf, unsigned long reason)
 			 * and there's no need to retake the mmap_sem
 			 * in such case.
 			 */
-			down_read(&mm->mmap_sem);
+			mm_read_lock(mm, vmf->lockrange);
 			ret = VM_FAULT_NOPAGE;
 		}
 	}
@@ -628,9 +633,10 @@  static void userfaultfd_event_wait_completion(struct userfaultfd_ctx *ctx,
 	if (release_new_ctx) {
 		struct vm_area_struct *vma;
 		struct mm_struct *mm = release_new_ctx->mm;
+		DEFINE_RANGE_LOCK_FULL(mmrange);
 
 		/* the various vma->vm_userfaultfd_ctx still points to it */
-		down_write(&mm->mmap_sem);
+		mm_write_lock(mm, &mmrange);
 		/* no task can run (and in turn coredump) yet */
 		VM_WARN_ON(!mmget_still_valid(mm));
 		for (vma = mm->mmap; vma; vma = vma->vm_next)
@@ -638,7 +644,7 @@  static void userfaultfd_event_wait_completion(struct userfaultfd_ctx *ctx,
 				vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX;
 				vma->vm_flags &= ~(VM_UFFD_WP | VM_UFFD_MISSING);
 			}
-		up_write(&mm->mmap_sem);
+		mm_write_unlock(mm, &mmrange);
 
 		userfaultfd_ctx_put(release_new_ctx);
 	}
@@ -780,7 +786,8 @@  void mremap_userfaultfd_complete(struct vm_userfaultfd_ctx *vm_ctx,
 }
 
 bool userfaultfd_remove(struct vm_area_struct *vma,
-			unsigned long start, unsigned long end)
+			unsigned long start, unsigned long end,
+			struct range_lock *mmrange)
 {
 	struct mm_struct *mm = vma->vm_mm;
 	struct userfaultfd_ctx *ctx;
@@ -792,7 +799,7 @@  bool userfaultfd_remove(struct vm_area_struct *vma,
 
 	userfaultfd_ctx_get(ctx);
 	WRITE_ONCE(ctx->mmap_changing, true);
-	up_read(&mm->mmap_sem);
+	mm_read_unlock(mm, mmrange);
 
 	msg_init(&ewq.msg);
 
@@ -872,6 +879,7 @@  static int userfaultfd_release(struct inode *inode, struct file *file)
 	/* len == 0 means wake all */
 	struct userfaultfd_wake_range range = { .len = 0, };
 	unsigned long new_flags;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
 	WRITE_ONCE(ctx->released, true);
 
@@ -886,7 +894,7 @@  static int userfaultfd_release(struct inode *inode, struct file *file)
 	 * it's critical that released is set to true (above), before
 	 * taking the mmap_sem for writing.
 	 */
-	down_write(&mm->mmap_sem);
+	mm_write_lock(mm, &mmrange);
 	if (!mmget_still_valid(mm))
 		goto skip_mm;
 	prev = NULL;
@@ -912,7 +920,7 @@  static int userfaultfd_release(struct inode *inode, struct file *file)
 		vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX;
 	}
 skip_mm:
-	up_write(&mm->mmap_sem);
+	mm_write_unlock(mm, &mmrange);
 	mmput(mm);
 wakeup:
 	/*
@@ -1299,6 +1307,7 @@  static int userfaultfd_register(struct userfaultfd_ctx *ctx,
 	unsigned long vm_flags, new_flags;
 	bool found;
 	bool basic_ioctls;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 	unsigned long start, end, vma_end;
 
 	user_uffdio_register = (struct uffdio_register __user *) arg;
@@ -1339,7 +1348,7 @@  static int userfaultfd_register(struct userfaultfd_ctx *ctx,
 	if (!mmget_not_zero(mm))
 		goto out;
 
-	down_write(&mm->mmap_sem);
+	mm_write_lock(mm, &mmrange);
 	if (!mmget_still_valid(mm))
 		goto out_unlock;
 	vma = find_vma_prev(mm, start, &prev);
@@ -1483,7 +1492,7 @@  static int userfaultfd_register(struct userfaultfd_ctx *ctx,
 		vma = vma->vm_next;
 	} while (vma && vma->vm_start < end);
 out_unlock:
-	up_write(&mm->mmap_sem);
+	mm_write_unlock(mm, &mmrange);
 	mmput(mm);
 	if (!ret) {
 		/*
@@ -1511,6 +1520,7 @@  static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
 	bool found;
 	unsigned long start, end, vma_end;
 	const void __user *buf = (void __user *)arg;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
 	ret = -EFAULT;
 	if (copy_from_user(&uffdio_unregister, buf, sizeof(uffdio_unregister)))
@@ -1528,7 +1538,7 @@  static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
 	if (!mmget_not_zero(mm))
 		goto out;
 
-	down_write(&mm->mmap_sem);
+	mm_write_lock(mm, &mmrange);
 	if (!mmget_still_valid(mm))
 		goto out_unlock;
 	vma = find_vma_prev(mm, start, &prev);
@@ -1645,7 +1655,7 @@  static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
 		vma = vma->vm_next;
 	} while (vma && vma->vm_start < end);
 out_unlock:
-	up_write(&mm->mmap_sem);
+	mm_write_unlock(mm, &mmrange);
 	mmput(mm);
 out:
 	return ret;
diff --git a/include/linux/userfaultfd_k.h b/include/linux/userfaultfd_k.h
index ac9d71e24b81..c8d3c102ce5e 100644
--- a/include/linux/userfaultfd_k.h
+++ b/include/linux/userfaultfd_k.h
@@ -68,7 +68,7 @@  extern void mremap_userfaultfd_complete(struct vm_userfaultfd_ctx *,
 
 extern bool userfaultfd_remove(struct vm_area_struct *vma,
 			       unsigned long start,
-			       unsigned long end);
+			       unsigned long end, struct range_lock *mmrange);
 
 extern int userfaultfd_unmap_prep(struct vm_area_struct *vma,
 				  unsigned long start, unsigned long end,
@@ -125,7 +125,8 @@  static inline void mremap_userfaultfd_complete(struct vm_userfaultfd_ctx *ctx,
 
 static inline bool userfaultfd_remove(struct vm_area_struct *vma,
 				      unsigned long start,
-				      unsigned long end)
+				      unsigned long end,
+				      struct range_lock *mmrange)
 {
 	return true;
 }