@@ -90,7 +90,7 @@ static struct mm_struct tboot_mm = {
.pgd = swapper_pg_dir,
.mm_users = ATOMIC_INIT(2),
.mm_count = ATOMIC_INIT(1),
- .mmap_sem = __RWSEM_INITIALIZER(init_mm.mmap_sem),
+ .mmap_sem = MM_LOCK_INITIALIZER(init_mm.mmap_sem),
.page_table_lock = __SPIN_LOCK_UNLOCKED(init_mm.page_table_lock),
.mmlist = LIST_HEAD_INIT(init_mm.mmlist),
};
@@ -60,7 +60,7 @@ struct mm_struct efi_mm = {
.mm_rb = RB_ROOT,
.mm_users = ATOMIC_INIT(2),
.mm_count = ATOMIC_INIT(1),
- .mmap_sem = __RWSEM_INITIALIZER(efi_mm.mmap_sem),
+ .mmap_sem = MM_LOCK_INITIALIZER(efi_mm.mmap_sem),
.page_table_lock = __SPIN_LOCK_UNLOCKED(efi_mm.page_table_lock),
.mmlist = LIST_HEAD_INIT(efi_mm.mmlist),
.cpu_bitmap = { [BITS_TO_LONGS(NR_CPUS)] = 0},
@@ -2,17 +2,26 @@
#define _LINUX_MM_LOCK_H
#include <linux/sched.h>
-
-static inline void mm_init_lock(struct mm_struct *mm)
-{
- init_rwsem(&mm->mmap_sem);
-}
+#include <linux/lockdep.h>
#ifdef CONFIG_MM_LOCK_RWSEM_INLINE
+#define MM_LOCK_INITIALIZER __RWSEM_INITIALIZER
#define MM_COARSE_LOCK_RANGE_INITIALIZER {}
+static inline void mm_init_lock(struct mm_struct *mm)
+{
+ init_rwsem(&mm->mmap_sem);
+}
+
static inline void mm_init_coarse_lock_range(struct mm_lock_range *range) {}
+static inline void mm_init_lock_range(struct mm_lock_range *range,
+ unsigned long start, unsigned long end) {}
+
+static inline bool mm_range_is_coarse(struct mm_lock_range *range)
+{
+ return true;
+}
static inline void mm_write_range_lock(struct mm_struct *mm,
struct mm_lock_range *range)
@@ -86,15 +95,80 @@ static inline struct mm_lock_range *mm_coarse_lock_range(void)
return NULL;
}
-#else /* CONFIG_MM_LOCK_RWSEM_CHECKED */
+#else /* !CONFIG_MM_LOCK_RWSEM_INLINE */
+
+#ifdef CONFIG_MM_LOCK_RWSEM_CHECKED
+#define MM_LOCK_INITIALIZER __RWSEM_INITIALIZER
#define MM_COARSE_LOCK_RANGE_INITIALIZER { .mm = NULL }
+static inline void mm_init_lock(struct mm_struct *mm)
+{
+ init_rwsem(&mm->mmap_sem);
+}
+
static inline void mm_init_coarse_lock_range(struct mm_lock_range *range)
{
range->mm = NULL;
}
+static inline void mm_init_lock_range(struct mm_lock_range *range,
+ unsigned long start, unsigned long end) {
+ mm_init_coarse_lock_range(range);
+}
+
+static inline bool mm_range_is_coarse(struct mm_lock_range *range)
+{
+ return true;
+}
+
+#else /* CONFIG_MM_LOCK_RANGE */
+
+#ifdef CONFIG_DEBUG_LOCK_ALLOC
+#define __DEP_MAP_MM_LOCK_INITIALIZER(lockname) \
+ .dep_map = { .name = #lockname },
+#else
+#define __DEP_MAP_MM_LOCK_INITIALIZER(lockname)
+#endif
+
+#define MM_LOCK_INITIALIZER(name) { \
+ .mutex = __MUTEX_INITIALIZER(name.mutex), \
+ .rb_root = RB_ROOT, \
+ __DEP_MAP_MM_LOCK_INITIALIZER(name) \
+}
+
+#define MM_COARSE_LOCK_RANGE_INITIALIZER { \
+ .start = 0, \
+ .end = ~0UL, \
+}
+
+static inline void mm_init_lock(struct mm_struct *mm)
+{
+ static struct lock_class_key __key;
+
+ mutex_init(&mm->mmap_sem.mutex);
+ mm->mmap_sem.rb_root = RB_ROOT;
+ lockdep_init_map(&mm->mmap_sem.dep_map, "&mm->mmap_sem", &__key, 0);
+}
+
+static inline void mm_init_lock_range(struct mm_lock_range *range,
+ unsigned long start, unsigned long end) {
+ range->start = start;
+ range->end = end;
+}
+
+static inline void mm_init_coarse_lock_range(struct mm_lock_range *range)
+{
+ mm_init_lock_range(range, 0, ~0UL);
+}
+
+static inline bool mm_range_is_coarse(struct mm_lock_range *range)
+{
+ return range->start == 0 && range->end == ~0UL;
+}
+
+#endif /* CONFIG_MM_LOCK_RANGE */
+
extern void mm_write_range_lock(struct mm_struct *mm,
struct mm_lock_range *range);
#ifdef CONFIG_LOCKDEP
@@ -129,11 +203,11 @@ static inline struct mm_lock_range *mm_coarse_lock_range(void)
	return &current->mm_coarse_lock_range;
}
-#endif
+#endif /* !CONFIG_MM_LOCK_RWSEM_INLINE */
static inline void mm_read_release(struct mm_struct *mm, unsigned long ip)
{
- rwsem_release(&mm->mmap_sem.dep_map, ip);
+ lock_release(&mm->mmap_sem.dep_map, ip);
}
static inline void mm_write_lock(struct mm_struct *mm)
@@ -183,7 +257,13 @@ static inline void mm_read_unlock(struct mm_struct *mm)
static inline bool mm_is_locked(struct mm_struct *mm)
{
+#ifndef CONFIG_MM_LOCK_RANGE
return rwsem_is_locked(&mm->mmap_sem) != 0;
+#elif defined(CONFIG_LOCKDEP)
+ return lockdep_is_held(&mm->mmap_sem); /* Close enough for asserts */
+#else
+ return true;
+#endif
}
#endif /* _LINUX_MM_LOCK_H */
@@ -283,6 +283,21 @@ struct vm_userfaultfd_ctx {
struct vm_userfaultfd_ctx {};
#endif /* CONFIG_USERFAULTFD */
+/*
+ * struct mm_lock stores locked address ranges for a given mm,
+ * implementing a fine-grained replacement for the mmap_sem rwsem.
+ */
+#ifdef CONFIG_MM_LOCK_RANGE
+struct mm_lock {
+ struct mutex mutex;
+ struct rb_root rb_root;
+ unsigned long seq;
+#ifdef CONFIG_DEBUG_LOCK_ALLOC
+ struct lockdep_map dep_map;
+#endif
+};
+#endif
+
/*
* This struct defines a memory VMM memory area. There is one of these
* per VM-area/task. A VM area is any part of the process virtual memory
@@ -426,7 +441,12 @@ struct mm_struct {
spinlock_t page_table_lock; /* Protects page tables and some
* counters
*/
+
+#ifndef CONFIG_MM_LOCK_RANGE
struct rw_semaphore mmap_sem;
+#else
+ struct mm_lock mmap_sem;
+#endif
struct list_head mmlist; /* List of maybe swapped mm's. These
* are globally strung together off
@@ -12,6 +12,7 @@
#include <linux/threads.h>
#include <linux/atomic.h>
#include <linux/cpumask.h>
+#include <linux/rbtree.h>
#include <asm/page.h>
@@ -100,6 +101,20 @@ struct mm_lock_range {
#ifdef CONFIG_MM_LOCK_RWSEM_CHECKED
struct mm_struct *mm;
#endif
+#ifdef CONFIG_MM_LOCK_RANGE
+ /* First cache line - used in insert / remove / iter */
+ struct rb_node rb;
+ long flags_count;
+ unsigned long start; /* First address of the range. */
+ unsigned long end; /* First address after the range. */
+ struct {
+ unsigned long read_end; /* Largest end in reader nodes. */
+ unsigned long write_end; /* Largest end in writer nodes. */
+ } __subtree; /* Subtree augmented information. */
+ /* Second cache line - used in wait and wake. */
+ unsigned long seq; /* Killable wait sequence number. */
+ struct task_struct *task; /* Task trying to lock this range. */
+#endif
};
#endif /* _LINUX_MM_TYPES_TASK_H */
@@ -741,7 +741,7 @@ config MAPPING_DIRTY_HELPERS
choice
prompt "MM lock implementation (mmap_sem)"
- default MM_LOCK_RWSEM_CHECKED
+ default MM_LOCK_RANGE
config MM_LOCK_RWSEM_INLINE
bool "rwsem, inline"
@@ -755,6 +755,13 @@ config MM_LOCK_RWSEM_CHECKED
This option implements the MM lock using a read-write semaphore,
ignoring the passed address range but checking its validity.
+config MM_LOCK_RANGE
+ bool "range lock"
+ help
+ This option implements the MM lock as a read-write range lock,
+ thus avoiding false conflicts between operations that operate
+ on non-overlapping address ranges.
+
endchoice
endmenu
@@ -109,3 +109,4 @@ obj-$(CONFIG_HMM_MIRROR) += hmm.o
obj-$(CONFIG_MEMFD_CREATE) += memfd.o
obj-$(CONFIG_MAPPING_DIRTY_HELPERS) += mapping_dirty_helpers.o
obj-$(CONFIG_MM_LOCK_RWSEM_CHECKED) += mm_lock_rwsem_checked.o
+obj-$(CONFIG_MM_LOCK_RANGE) += mm_lock_range.o
@@ -1,5 +1,6 @@
// SPDX-License-Identifier: GPL-2.0
#include <linux/mm_types.h>
+#include <linux/mm_lock.h>
#include <linux/rbtree.h>
#include <linux/rwsem.h>
#include <linux/spinlock.h>
@@ -31,7 +32,7 @@ struct mm_struct init_mm = {
.pgd = swapper_pg_dir,
.mm_users = ATOMIC_INIT(2),
.mm_count = ATOMIC_INIT(1),
- .mmap_sem = __RWSEM_INITIALIZER(init_mm.mmap_sem),
+ .mmap_sem = MM_LOCK_INITIALIZER(init_mm.mmap_sem),
.page_table_lock = __SPIN_LOCK_UNLOCKED(init_mm.page_table_lock),
.arg_lock = __SPIN_LOCK_UNLOCKED(init_mm.arg_lock),
.mmlist = LIST_HEAD_INIT(init_mm.mmlist),
new file mode 100644
@@ -0,0 +1,691 @@
+#include <linux/mm_lock.h>
+#include <linux/rbtree_augmented.h>
+#include <linux/mutex.h>
+#include <linux/lockdep.h>
+#include <linux/sched.h>
+#include <linux/sched/signal.h>
+#include <linux/sched/wake_q.h>
+
+/* range->flags_count definitions */
+#define MM_LOCK_RANGE_WRITE 1
+#define MM_LOCK_RANGE_COUNT_ONE 2
+
+static inline bool rbcompute(struct mm_lock_range *range, bool exit)
+{
+ struct mm_lock_range *child;
+ unsigned long subtree_read_end = range->end, subtree_write_end = 0;
+ if (range->flags_count & MM_LOCK_RANGE_WRITE) {
+ subtree_read_end = 0;
+ subtree_write_end = range->end;
+ }
+ if (range->rb.rb_left) {
+ child = rb_entry(range->rb.rb_left, struct mm_lock_range, rb);
+ if (child->__subtree.read_end > subtree_read_end)
+ subtree_read_end = child->__subtree.read_end;
+ if (child->__subtree.write_end > subtree_write_end)
+ subtree_write_end = child->__subtree.write_end;
+ }
+ if (range->rb.rb_right) {
+ child = rb_entry(range->rb.rb_right, struct mm_lock_range, rb);
+ if (child->__subtree.read_end > subtree_read_end)
+ subtree_read_end = child->__subtree.read_end;
+ if (child->__subtree.write_end > subtree_write_end)
+ subtree_write_end = child->__subtree.write_end;
+ }
+ if (exit && range->__subtree.read_end == subtree_read_end &&
+ range->__subtree.write_end == subtree_write_end)
+ return true;
+ range->__subtree.read_end = subtree_read_end;
+ range->__subtree.write_end = subtree_write_end;
+ return false;
+}
+
+RB_DECLARE_CALLBACKS(static, augment, struct mm_lock_range, rb,
+ __subtree, rbcompute);
+
+static void insert_read(struct mm_lock_range *range, struct rb_root *root)
+{
+ struct rb_node **link = &root->rb_node, *rb_parent = NULL;
+ unsigned long start = range->start, end = range->end;
+ struct mm_lock_range *parent;
+
+ while (*link) {
+ rb_parent = *link;
+ parent = rb_entry(rb_parent, struct mm_lock_range, rb);
+ if (parent->__subtree.read_end < end)
+ parent->__subtree.read_end = end;
+ if (start < parent->start)
+ link = &parent->rb.rb_left;
+ else
+ link = &parent->rb.rb_right;
+ }
+
+ range->__subtree.read_end = end;
+ range->__subtree.write_end = 0;
+ rb_link_node(&range->rb, rb_parent, link);
+ rb_insert_augmented(&range->rb, root, &augment);
+}
+
+static void insert_write(struct mm_lock_range *range, struct rb_root *root)
+{
+ struct rb_node **link = &root->rb_node, *rb_parent = NULL;
+ unsigned long start = range->start, end = range->end;
+ struct mm_lock_range *parent;
+
+ while (*link) {
+ rb_parent = *link;
+ parent = rb_entry(rb_parent, struct mm_lock_range, rb);
+ if (parent->__subtree.write_end < end)
+ parent->__subtree.write_end = end;
+ if (start < parent->start)
+ link = &parent->rb.rb_left;
+ else
+ link = &parent->rb.rb_right;
+ }
+
+ range->__subtree.read_end = 0;
+ range->__subtree.write_end = end;
+ rb_link_node(&range->rb, rb_parent, link);
+ rb_insert_augmented(&range->rb, root, &augment);
+}
+
+static void remove(struct mm_lock_range *range, struct rb_root *root)
+{
+ rb_erase_augmented(&range->rb, root, &augment);
+}
+
+/*
+ * Iterate over ranges intersecting [start;end)
+ *
+ * Note that a range intersects [start;end) iff:
+ * Cond1: range->start < end
+ * and
+ * Cond2: start < range->end
+ */
+
+static struct mm_lock_range *
+subtree_search(struct mm_lock_range *range,
+ unsigned long start, unsigned long end)
+{
+ while (true) {
+ /*
+ * Loop invariant: start < range->__subtree.read_end
+ * or start < range->__subtree.write_end
+ * (Cond2 is satisfied by one of the subtree ranges)
+ */
+ if (range->rb.rb_left) {
+ struct mm_lock_range *left = rb_entry(
+ range->rb.rb_left, struct mm_lock_range, rb);
+ if (start < left->__subtree.read_end ||
+ start < left->__subtree.write_end) {
+ /*
+ * Some ranges in left subtree satisfy Cond2.
+ * Iterate to find the leftmost such range R.
+ * If it also satisfies Cond1, that's the
+ * match we are looking for. Otherwise, there
+ * is no matching interval as ranges to the
+ * right of R can't satisfy Cond1 either.
+ */
+ range = left;
+ continue;
+ }
+ }
+ if (range->start < end) { /* Cond1 */
+ if (start < range->end) /* Cond2 */
+ return range; /* range is leftmost match */
+ if (range->rb.rb_right) {
+ range = rb_entry(range->rb.rb_right,
+ struct mm_lock_range, rb);
+ if (start < range->__subtree.read_end ||
+ start < range->__subtree.write_end)
+ continue;
+ }
+ }
+ return NULL; /* No match */
+ }
+}
+
+static struct mm_lock_range *
+iter_first(struct rb_root *root, unsigned long start, unsigned long end)
+{
+ struct mm_lock_range *range;
+
+ if (!root->rb_node)
+ return NULL;
+ range = rb_entry(root->rb_node, struct mm_lock_range, rb);
+ if (range->__subtree.read_end <= start &&
+ range->__subtree.write_end <= start)
+ return NULL;
+ return subtree_search(range, start, end);
+}
+
+static struct mm_lock_range *
+iter_next(struct mm_lock_range *range, unsigned long start, unsigned long end)
+{
+ struct rb_node *rb = range->rb.rb_right, *prev;
+
+ while (true) {
+ /*
+ * Loop invariants:
+ * Cond1: range->start < end
+ * rb == range->rb.rb_right
+ *
+ * First, search right subtree if suitable
+ */
+ if (rb) {
+ struct mm_lock_range *right = rb_entry(
+ rb, struct mm_lock_range, rb);
+ if (start < right->__subtree.read_end ||
+ start < right->__subtree.write_end)
+ return subtree_search(right, start, end);
+ }
+
+ /* Move up the tree until we come from a range's left child */
+ do {
+ rb = rb_parent(&range->rb);
+ if (!rb)
+ return NULL;
+ prev = &range->rb;
+ range = rb_entry(rb, struct mm_lock_range, rb);
+ rb = range->rb.rb_right;
+ } while (prev == rb);
+
+ /* Check if the range intersects [start;end) */
+ if (end <= range->start) /* !Cond1 */
+ return NULL;
+ else if (start < range->end) /* Cond2 */
+ return range;
+ }
+}
+
+#define FOR_EACH_RANGE(mm, start, end, tmp) \
+for (tmp = iter_first(&mm->mmap_sem.rb_root, start, end); tmp; \
+ tmp = iter_next(tmp, start, end))
+
+static struct mm_lock_range *
+subtree_search_read(struct mm_lock_range *range,
+ unsigned long start, unsigned long end)
+{
+ while (true) {
+ /*
+ * Loop invariant: start < range->__subtree.read_end
+ * (Cond2 is satisfied by one of the subtree ranges)
+ */
+ if (range->rb.rb_left) {
+ struct mm_lock_range *left = rb_entry(
+ range->rb.rb_left, struct mm_lock_range, rb);
+ if (start < left->__subtree.read_end) {
+ /*
+ * Some ranges in left subtree satisfy Cond2.
+ * Iterate to find the leftmost such range R.
+ * If it also satisfies Cond1, that's the
+ * match we are looking for. Otherwise, there
+ * is no matching interval as ranges to the
+ * right of R can't satisfy Cond1 either.
+ */
+ range = left;
+ continue;
+ }
+ }
+ if (range->start < end) { /* Cond1 */
+ if (start < range->end && /* Cond2 */
+ !(range->flags_count & MM_LOCK_RANGE_WRITE))
+ return range; /* range is leftmost match */
+ if (range->rb.rb_right) {
+ range = rb_entry(range->rb.rb_right,
+ struct mm_lock_range, rb);
+ if (start < range->__subtree.read_end)
+ continue;
+ }
+ }
+ return NULL; /* No match */
+ }
+}
+
+static struct mm_lock_range *
+iter_first_read(struct rb_root *root, unsigned long start, unsigned long end)
+{
+ struct mm_lock_range *range;
+
+ if (!root->rb_node)
+ return NULL;
+ range = rb_entry(root->rb_node, struct mm_lock_range, rb);
+ if (range->__subtree.read_end <= start)
+ return NULL;
+ return subtree_search_read(range, start, end);
+}
+
+static struct mm_lock_range *
+iter_next_read(struct mm_lock_range *range,
+ unsigned long start, unsigned long end)
+{
+ struct rb_node *rb = range->rb.rb_right, *prev;
+
+ while (true) {
+ /*
+ * Loop invariants:
+ * Cond1: range->start < end
+ * rb == range->rb.rb_right
+ *
+ * First, search right subtree if suitable
+ */
+ if (rb) {
+ struct mm_lock_range *right = rb_entry(
+ rb, struct mm_lock_range, rb);
+ if (start < right->__subtree.read_end)
+ return subtree_search_read(right, start, end);
+ }
+
+ /* Move up the tree until we come from a range's left child */
+ do {
+ rb = rb_parent(&range->rb);
+ if (!rb)
+ return NULL;
+ prev = &range->rb;
+ range = rb_entry(rb, struct mm_lock_range, rb);
+ rb = range->rb.rb_right;
+ } while (prev == rb);
+
+ /* Check if the range intersects [start;end) */
+ if (end <= range->start) /* !Cond1 */
+ return NULL;
+ else if (start < range->end && /* Cond2 */
+ !(range->flags_count & MM_LOCK_RANGE_WRITE))
+ return range;
+ }
+}
+
+#define FOR_EACH_RANGE_READ(mm, start, end, tmp) \
+for (tmp = iter_first_read(&mm->mmap_sem.rb_root, start, end); tmp; \
+ tmp = iter_next_read(tmp, start, end))
+
+static struct mm_lock_range *
+subtree_search_write(struct mm_lock_range *range,
+ unsigned long start, unsigned long end)
+{
+ while (true) {
+ /*
+ * Loop invariant: start < range->__subtree.write_end
+ * (Cond2 is satisfied by one of the subtree ranges)
+ */
+ if (range->rb.rb_left) {
+ struct mm_lock_range *left = rb_entry(
+ range->rb.rb_left, struct mm_lock_range, rb);
+ if (start < left->__subtree.write_end) {
+ /*
+ * Some ranges in left subtree satisfy Cond2.
+ * Iterate to find the leftmost such range R.
+ * If it also satisfies Cond1, that's the
+ * match we are looking for. Otherwise, there
+ * is no matching interval as ranges to the
+ * right of R can't satisfy Cond1 either.
+ */
+ range = left;
+ continue;
+ }
+ }
+ if (range->start < end) { /* Cond1 */
+ if (start < range->end && /* Cond2 */
+ range->flags_count & MM_LOCK_RANGE_WRITE)
+ return range; /* range is leftmost match */
+ if (range->rb.rb_right) {
+ range = rb_entry(range->rb.rb_right,
+ struct mm_lock_range, rb);
+ if (start < range->__subtree.write_end)
+ continue;
+ }
+ }
+ return NULL; /* No match */
+ }
+}
+
+static struct mm_lock_range *
+iter_first_write(struct rb_root *root, unsigned long start, unsigned long end)
+{
+ struct mm_lock_range *range;
+
+ if (!root->rb_node)
+ return NULL;
+ range = rb_entry(root->rb_node, struct mm_lock_range, rb);
+ if (range->__subtree.write_end <= start)
+ return NULL;
+ return subtree_search_write(range, start, end);
+}
+
+static struct mm_lock_range *
+iter_next_write(struct mm_lock_range *range,
+ unsigned long start, unsigned long end)
+{
+ struct rb_node *rb = range->rb.rb_right, *prev;
+
+ while (true) {
+ /*
+ * Loop invariants:
+ * Cond1: range->start < end
+ * rb == range->rb.rb_right
+ *
+ * First, search right subtree if suitable
+ */
+ if (rb) {
+ struct mm_lock_range *right = rb_entry(
+ rb, struct mm_lock_range, rb);
+ if (start < right->__subtree.write_end)
+ return subtree_search_write(right, start, end);
+ }
+
+ /* Move up the tree until we come from a range's left child */
+ do {
+ rb = rb_parent(&range->rb);
+ if (!rb)
+ return NULL;
+ prev = &range->rb;
+ range = rb_entry(rb, struct mm_lock_range, rb);
+ rb = range->rb.rb_right;
+ } while (prev == rb);
+
+ /* Check if the range intersects [start;end) */
+ if (end <= range->start) /* !Cond1 */
+ return NULL;
+ else if (start < range->end && /* Cond2 */
+ range->flags_count & MM_LOCK_RANGE_WRITE)
+ return range;
+ }
+}
+
+#define FOR_EACH_RANGE_WRITE(mm, start, end, tmp) \
+for (tmp = iter_first_write(&mm->mmap_sem.rb_root, start, end); tmp; \
+ tmp = iter_next_write(tmp, start, end))
+
+static bool queue_read(struct mm_struct *mm, struct mm_lock_range *range)
+{
+ struct mm_lock_range *conflict;
+ long flags_count = 0;
+
+ FOR_EACH_RANGE_WRITE(mm, range->start, range->end, conflict)
+ flags_count -= MM_LOCK_RANGE_COUNT_ONE;
+ range->flags_count = flags_count;
+ insert_read(range, &mm->mmap_sem.rb_root);
+ return flags_count < 0;
+}
+
+static bool queue_write(struct mm_struct *mm, struct mm_lock_range *range)
+{
+ struct mm_lock_range *conflict;
+ long flags_count = MM_LOCK_RANGE_WRITE;
+
+ FOR_EACH_RANGE(mm, range->start, range->end, conflict)
+ flags_count -= MM_LOCK_RANGE_COUNT_ONE;
+ range->flags_count = flags_count;
+ insert_write(range, &mm->mmap_sem.rb_root);
+ return flags_count < 0;
+}
+
+static inline void prepare_wait(struct mm_lock_range *range, unsigned long seq)
+{
+ range->seq = seq;
+ range->task = current;
+}
+
+static void wait(struct mm_lock_range *range)
+{
+ while (true) {
+ set_current_state(TASK_UNINTERRUPTIBLE);
+ if (range->flags_count >= 0)
+ break;
+ schedule();
+ }
+ __set_current_state(TASK_RUNNING);
+}
+
+static bool wait_killable(struct mm_lock_range *range)
+{
+ while (true) {
+ set_current_state(TASK_INTERRUPTIBLE);
+ if (range->flags_count >= 0) {
+ __set_current_state(TASK_RUNNING);
+ return true;
+ }
+ if (signal_pending(current)) {
+ __set_current_state(TASK_RUNNING);
+ return false;
+ }
+ schedule();
+ }
+}
+
+static inline void unlock_conflict(struct mm_lock_range *range,
+ struct wake_q_head *wake_q)
+{
+ if ((range->flags_count += MM_LOCK_RANGE_COUNT_ONE) >= 0)
+ wake_q_add(wake_q, range->task);
+}
+
+void mm_write_range_lock(struct mm_struct *mm, struct mm_lock_range *range)
+{
+ bool contended;
+
+ lock_acquire_exclusive(&mm->mmap_sem.dep_map, 0, 0, NULL, _RET_IP_);
+
+ mutex_lock(&mm->mmap_sem.mutex);
+ if ((contended = queue_write(mm, range)))
+ prepare_wait(range, mm->mmap_sem.seq);
+ mutex_unlock(&mm->mmap_sem.mutex);
+
+ if (contended) {
+ lock_contended(&mm->mmap_sem.dep_map, _RET_IP_);
+ wait(range);
+ }
+ lock_acquired(&mm->mmap_sem.dep_map, _RET_IP_);
+}
+EXPORT_SYMBOL(mm_write_range_lock);
+
+#ifdef CONFIG_LOCKDEP
+void mm_write_range_lock_nested(struct mm_struct *mm,
+ struct mm_lock_range *range, int subclass)
+{
+ bool contended;
+
+ lock_acquire_exclusive(&mm->mmap_sem.dep_map, subclass, 0, NULL,
+ _RET_IP_);
+
+ mutex_lock(&mm->mmap_sem.mutex);
+ if ((contended = queue_write(mm, range)))
+ prepare_wait(range, mm->mmap_sem.seq);
+ mutex_unlock(&mm->mmap_sem.mutex);
+
+ if (contended) {
+ lock_contended(&mm->mmap_sem.dep_map, _RET_IP_);
+ wait(range);
+ }
+ lock_acquired(&mm->mmap_sem.dep_map, _RET_IP_);
+}
+EXPORT_SYMBOL(mm_write_range_lock_nested);
+#endif
+
+int mm_write_range_lock_killable(struct mm_struct *mm,
+ struct mm_lock_range *range)
+{
+ bool contended;
+
+ lock_acquire_exclusive(&mm->mmap_sem.dep_map, 0, 0, NULL, _RET_IP_);
+
+ mutex_lock(&mm->mmap_sem.mutex);
+ if ((contended = queue_write(mm, range)))
+ prepare_wait(range, ++(mm->mmap_sem.seq));
+ mutex_unlock(&mm->mmap_sem.mutex);
+
+ if (contended) {
+ lock_contended(&mm->mmap_sem.dep_map, _RET_IP_);
+ if (!wait_killable(range)) {
+ struct mm_lock_range *conflict;
+ DEFINE_WAKE_Q(wake_q);
+
+ mutex_lock(&mm->mmap_sem.mutex);
+ remove(range, &mm->mmap_sem.rb_root);
+ FOR_EACH_RANGE(mm, range->start, range->end, conflict)
+ if (conflict->flags_count < 0 &&
+ conflict->seq - range->seq <= (~0UL >> 1))
+ unlock_conflict(conflict, &wake_q);
+ mutex_unlock(&mm->mmap_sem.mutex);
+
+ wake_up_q(&wake_q);
+ lock_release(&mm->mmap_sem.dep_map, _RET_IP_);
+ return -EINTR;
+ }
+ }
+ lock_acquired(&mm->mmap_sem.dep_map, _RET_IP_);
+ return 0;
+}
+EXPORT_SYMBOL(mm_write_range_lock_killable);
+
+bool mm_write_range_trylock(struct mm_struct *mm, struct mm_lock_range *range)
+{
+ bool locked = false;
+
+ if (!mutex_trylock(&mm->mmap_sem.mutex))
+ goto exit;
+ if (iter_first(&mm->mmap_sem.rb_root, range->start, range->end))
+ goto unlock;
+ lock_acquire_exclusive(&mm->mmap_sem.dep_map, 0, 1, NULL,
+ _RET_IP_);
+ range->flags_count = MM_LOCK_RANGE_WRITE;
+ insert_write(range, &mm->mmap_sem.rb_root);
+ locked = true;
+unlock:
+ mutex_unlock(&mm->mmap_sem.mutex);
+exit:
+ return locked;
+}
+EXPORT_SYMBOL(mm_write_range_trylock);
+
+void mm_write_range_unlock(struct mm_struct *mm, struct mm_lock_range *range)
+{
+ struct mm_lock_range *conflict;
+ DEFINE_WAKE_Q(wake_q);
+
+ mutex_lock(&mm->mmap_sem.mutex);
+ remove(range, &mm->mmap_sem.rb_root);
+ FOR_EACH_RANGE(mm, range->start, range->end, conflict)
+ unlock_conflict(conflict, &wake_q);
+ mutex_unlock(&mm->mmap_sem.mutex);
+
+ wake_up_q(&wake_q);
+ lock_release(&mm->mmap_sem.dep_map, _RET_IP_);
+}
+EXPORT_SYMBOL(mm_write_range_unlock);
+
+void mm_downgrade_write_range_lock(struct mm_struct *mm,
+ struct mm_lock_range *range)
+{
+ struct mm_lock_range *conflict;
+ DEFINE_WAKE_Q(wake_q);
+
+ mutex_lock(&mm->mmap_sem.mutex);
+ FOR_EACH_RANGE_READ(mm, range->start, range->end, conflict)
+ unlock_conflict(conflict, &wake_q);
+ range->flags_count -= MM_LOCK_RANGE_WRITE;
+ augment_propagate(&range->rb, NULL);
+ mutex_unlock(&mm->mmap_sem.mutex);
+
+ wake_up_q(&wake_q);
+ lock_downgrade(&mm->mmap_sem.dep_map, _RET_IP_);
+}
+EXPORT_SYMBOL(mm_downgrade_write_range_lock);
+
+void mm_read_range_lock(struct mm_struct *mm, struct mm_lock_range *range)
+{
+ bool contended;
+
+ lock_acquire_shared(&mm->mmap_sem.dep_map, 0, 0, NULL, _RET_IP_);
+
+ mutex_lock(&mm->mmap_sem.mutex);
+ if ((contended = queue_read(mm, range)))
+ prepare_wait(range, mm->mmap_sem.seq);
+ mutex_unlock(&mm->mmap_sem.mutex);
+
+ if (contended) {
+ lock_contended(&mm->mmap_sem.dep_map, _RET_IP_);
+ wait(range);
+ }
+ lock_acquired(&mm->mmap_sem.dep_map, _RET_IP_);
+}
+EXPORT_SYMBOL(mm_read_range_lock);
+
+int mm_read_range_lock_killable(struct mm_struct *mm,
+ struct mm_lock_range *range)
+{
+ bool contended;
+
+ lock_acquire_shared(&mm->mmap_sem.dep_map, 0, 0, NULL, _RET_IP_);
+
+ mutex_lock(&mm->mmap_sem.mutex);
+ if ((contended = queue_read(mm, range)))
+ prepare_wait(range, ++(mm->mmap_sem.seq));
+ mutex_unlock(&mm->mmap_sem.mutex);
+
+ if (contended) {
+ lock_contended(&mm->mmap_sem.dep_map, _RET_IP_);
+ if (!wait_killable(range)) {
+ struct mm_lock_range *conflict;
+ DEFINE_WAKE_Q(wake_q);
+
+ mutex_lock(&mm->mmap_sem.mutex);
+ remove(range, &mm->mmap_sem.rb_root);
+ FOR_EACH_RANGE_WRITE(mm, range->start, range->end,
+ conflict)
+ if (conflict->flags_count < 0 &&
+ conflict->seq - range->seq <= (~0UL >> 1))
+ unlock_conflict(conflict, &wake_q);
+ mutex_unlock(&mm->mmap_sem.mutex);
+
+ wake_up_q(&wake_q);
+ lock_release(&mm->mmap_sem.dep_map, _RET_IP_);
+ return -EINTR;
+ }
+ }
+ lock_acquired(&mm->mmap_sem.dep_map, _RET_IP_);
+ return 0;
+}
+EXPORT_SYMBOL(mm_read_range_lock_killable);
+
+bool mm_read_range_trylock(struct mm_struct *mm, struct mm_lock_range *range)
+{
+ bool locked = false;
+
+ if (!mutex_trylock(&mm->mmap_sem.mutex))
+ goto exit;
+ if (iter_first_write(&mm->mmap_sem.rb_root, range->start, range->end))
+ goto unlock;
+ lock_acquire_shared(&mm->mmap_sem.dep_map, 0, 1, NULL, _RET_IP_);
+ range->flags_count = 0;
+ insert_read(range, &mm->mmap_sem.rb_root);
+ locked = true;
+unlock:
+ mutex_unlock(&mm->mmap_sem.mutex);
+exit:
+ return locked;
+}
+EXPORT_SYMBOL(mm_read_range_trylock);
+
+void mm_read_range_unlock_non_owner(struct mm_struct *mm,
+ struct mm_lock_range *range)
+{
+ struct mm_lock_range *conflict;
+ DEFINE_WAKE_Q(wake_q);
+
+ mutex_lock(&mm->mmap_sem.mutex);
+ remove(range, &mm->mmap_sem.rb_root);
+ FOR_EACH_RANGE_WRITE(mm, range->start, range->end, conflict)
+ unlock_conflict(conflict, &wake_q);
+ mutex_unlock(&mm->mmap_sem.mutex);
+
+ wake_up_q(&wake_q);
+}
+EXPORT_SYMBOL(mm_read_range_unlock_non_owner);
+
+void mm_read_range_unlock(struct mm_struct *mm, struct mm_lock_range *range)
+{
+ mm_read_range_unlock_non_owner(mm, range);
+ lock_release(&mm->mmap_sem.dep_map, _RET_IP_);
+}
+EXPORT_SYMBOL(mm_read_range_unlock);
This change implements fine-grained reader-writer range locks. Existing
locked ranges are represented as an augmented rbtree protected by a mutex.
The locked ranges hold information about two overlapping interval trees,
representing the reader and writer locks respectively. This data structure
allows quickly searching for existing readers, writers, or both,
intersecting a given address range.

When locking a range, a count of all existing conflicting ranges (either
already locked, or queued) is stored in the mm_lock_range struct. If the
count is non-zero, the locking task is put to sleep until all conflicting
lock ranges are released.

When unlocking a range, the conflict count for all existing (queued)
conflicting ranges is decremented. If a count reaches zero, the task
waiting on that range is woken up - it now holds a lock on its desired
address range.

The general approach for this range locking implementation was first
proposed by Jan Kara back in 2013, and later worked on by at least
Laurent Dufour and Davidlohr Bueso. I have extended the approach by
using separate indexes for the reader and writer range locks.

Signed-off-by: Michel Lespinasse <walken@google.com>
---
 arch/x86/kernel/tboot.c       |   2 +-
 drivers/firmware/efi/efi.c    |   2 +-
 include/linux/mm_lock.h       |  96 ++++-
 include/linux/mm_types.h      |  20 +
 include/linux/mm_types_task.h |  15 +
 mm/Kconfig                    |   9 +-
 mm/Makefile                   |   1 +
 mm/init-mm.c                  |   3 +-
 mm/mm_lock_range.c            | 691 ++++++++++++++++++++++++++++++++++
 9 files changed, 827 insertions(+), 12 deletions(-)
 create mode 100644 mm/mm_lock_range.c
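
For reference, a minimal usage sketch (not part of the patch) of how a caller
might take and release a write lock on a single page-sized range with the API
added above. The function name example_touch_range and its caller context are
assumptions made for illustration; the mm_lock.h calls it uses are the ones
introduced in this series.

#include <linux/mm.h>
#include <linux/mm_lock.h>

static int example_touch_range(struct mm_struct *mm, unsigned long addr)
{
	struct mm_lock_range range;

	/* Describe the address range the caller intends to modify. */
	mm_init_lock_range(&range, addr, addr + PAGE_SIZE);

	/*
	 * Queue the range in the lock tree; sleeps until all conflicting
	 * (overlapping) reader and writer ranges have been released.
	 */
	mm_write_range_lock(mm, &range);

	/* ... operate on VMAs / page tables within [addr, addr + PAGE_SIZE) ... */

	mm_write_range_unlock(mm, &range);
	return 0;
}

With the MM_LOCK_RWSEM_INLINE and MM_LOCK_RWSEM_CHECKED options the same calls
fall back to taking the whole mmap_sem rwsem (the range is ignored, or only
validated), so callers written against the range API keep working unchanged
regardless of which implementation is selected.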