diff mbox series

[RFC,V0,02/10] mm: Maintain mm_struct list in the system

Message ID 20241201153818.2633616-3-raghavendra.kt@amd.com (mailing list archive)
State New
Headers show
Series None | expand

Commit Message

Raghavendra K T Dec. 1, 2024, 3:38 p.m. UTC
Add a hook in the fork and exec path to link mm_struct.
Reuse the mm_slot infrastructure to aid insert and lookup of mm_struct.

CC: linux-fsdevel@vger.kernel.org
Suggested-by: Bharata B Rao <bharata@amd.com>
Signed-off-by: Raghavendra K T <raghavendra.kt@amd.com>
---
 fs/exec.c                |  4 ++
 include/linux/kmmscand.h | 30 ++++++++++++++
 kernel/fork.c            |  4 ++
 mm/kmmscand.c            | 86 +++++++++++++++++++++++++++++++++++++++-
 4 files changed, 123 insertions(+), 1 deletion(-)
 create mode 100644 include/linux/kmmscand.h
diff mbox series

Patch

diff --git a/fs/exec.c b/fs/exec.c
index 98cb7ba9983c..bd72107b2ab1 100644
--- a/fs/exec.c
+++ b/fs/exec.c
@@ -68,6 +68,7 @@ 
 #include <linux/user_events.h>
 #include <linux/rseq.h>
 #include <linux/ksm.h>
+#include <linux/kmmscand.h>
 
 #include <linux/uaccess.h>
 #include <asm/mmu_context.h>
@@ -274,6 +275,8 @@  static int __bprm_mm_init(struct linux_binprm *bprm)
 	if (err)
 		goto err_ksm;
 
+	kmmscand_execve(mm);
+
 	/*
 	 * Place the stack at the largest stack address the architecture
 	 * supports. Later, we'll move this to an appropriate place. We don't
@@ -296,6 +299,7 @@  static int __bprm_mm_init(struct linux_binprm *bprm)
 	return 0;
 err:
 	ksm_exit(mm);
+	kmmscand_exit(mm);
 err_ksm:
 	mmap_write_unlock(mm);
 err_free:
diff --git a/include/linux/kmmscand.h b/include/linux/kmmscand.h
new file mode 100644
index 000000000000..b120c65ee8c6
--- /dev/null
+++ b/include/linux/kmmscand.h
@@ -0,0 +1,30 @@ 
+/* SPDX-License-Identifier: GPL-2.0 */
+#ifndef _LINUX_KMMSCAND_H_
+#define _LINUX_KMMSCAND_H_
+
+#ifdef CONFIG_KMMSCAND
+extern void __kmmscand_enter(struct mm_struct *mm);
+extern void __kmmscand_exit(struct mm_struct *mm);
+
+static inline void kmmscand_execve(struct mm_struct *mm)
+{
+	__kmmscand_enter(mm);
+}
+
+static inline void kmmscand_fork(struct mm_struct *mm, struct mm_struct *oldmm)
+{
+	__kmmscand_enter(mm);
+}
+
+static inline void kmmscand_exit(struct mm_struct *mm)
+{
+	__kmmscand_exit(mm);
+}
+#else /* !CONFIG_KMMSCAND */
+static inline void __kmmscand_enter(struct mm_struct *mm) {}
+static inline void __kmmscand_exit(struct mm_struct *mm) {}
+static inline void kmmscand_execve(struct mm_struct *mm) {}
+static inline void kmmscand_fork(struct mm_struct *mm, struct mm_struct *oldmm) {}
+static inline void kmmscand_exit(struct mm_struct *mm) {}
+#endif
+#endif /* _LINUX_KMMSCAND_H_ */
diff --git a/kernel/fork.c b/kernel/fork.c
index 1450b461d196..812b0032592e 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -85,6 +85,7 @@ 
 #include <linux/user-return-notifier.h>
 #include <linux/oom.h>
 #include <linux/khugepaged.h>
+#include <linux/kmmscand.h>
 #include <linux/signalfd.h>
 #include <linux/uprobes.h>
 #include <linux/aio.h>
@@ -659,6 +660,8 @@  static __latent_entropy int dup_mmap(struct mm_struct *mm,
 	mm->exec_vm = oldmm->exec_vm;
 	mm->stack_vm = oldmm->stack_vm;
 
+	kmmscand_fork(mm, oldmm);
+
 	/* Use __mt_dup() to efficiently build an identical maple tree. */
 	retval = __mt_dup(&oldmm->mm_mt, &mm->mm_mt, GFP_KERNEL);
 	if (unlikely(retval))
@@ -1350,6 +1353,7 @@  static inline void __mmput(struct mm_struct *mm)
 	exit_aio(mm);
 	ksm_exit(mm);
 	khugepaged_exit(mm); /* must run before exit_mmap */
+	kmmscand_exit(mm);
 	exit_mmap(mm);
 	mm_put_huge_zero_folio(mm);
 	set_mm_exe_file(mm, NULL);
diff --git a/mm/kmmscand.c b/mm/kmmscand.c
index 23cf5638fe10..957128d4e425 100644
--- a/mm/kmmscand.c
+++ b/mm/kmmscand.c
@@ -7,13 +7,14 @@ 
 #include <linux/swap.h>
 #include <linux/mm_inline.h>
 #include <linux/kthread.h>
+#include <linux/kmmscand.h>
 #include <linux/string.h>
 #include <linux/delay.h>
 #include <linux/cleanup.h>
 
 #include <asm/pgalloc.h>
 #include "internal.h"
-
+#include "mm_slot.h"
 
 static struct task_struct *kmmscand_thread __read_mostly;
 static DEFINE_MUTEX(kmmscand_mutex);
@@ -30,10 +31,21 @@  static bool need_wakeup;
 
 static unsigned long kmmscand_sleep_expire;
 
+static DEFINE_SPINLOCK(kmmscand_mm_lock);
 static DECLARE_WAIT_QUEUE_HEAD(kmmscand_wait);
 
+#define KMMSCAND_SLOT_HASH_BITS 10
+static DEFINE_READ_MOSTLY_HASHTABLE(kmmscand_slots_hash, KMMSCAND_SLOT_HASH_BITS);
+
+static struct kmem_cache *kmmscand_slot_cache __read_mostly;
+
+struct kmmscand_mm_slot {
+	struct mm_slot slot;
+};
+
 struct kmmscand_scan {
 	struct list_head mm_head;
+	struct kmmscand_mm_slot *mm_slot;
 };
 
 struct kmmscand_scan kmmscand_scan = {
@@ -76,6 +88,11 @@  static void kmmscand_migrate_folio(void)
 {
 }
 
+static inline int kmmscand_test_exit(struct mm_struct *mm)
+{
+	return atomic_read(&mm->mm_users) == 0;
+}
+
 static unsigned long kmmscand_scan_mm_slot(void)
 {
 	/* placeholder for scanning */
@@ -123,6 +140,65 @@  static int kmmscand(void *none)
 	return 0;
 }
 
+static inline void kmmscand_destroy(void)
+{
+	kmem_cache_destroy(kmmscand_slot_cache);
+}
+
+void __kmmscand_enter(struct mm_struct *mm)
+{
+	struct kmmscand_mm_slot *kmmscand_slot;
+	struct mm_slot *slot;
+	int wakeup;
+
+	/* __kmmscand_exit() must not run from under us */
+	VM_BUG_ON_MM(kmmscand_test_exit(mm), mm);
+
+	kmmscand_slot = mm_slot_alloc(kmmscand_slot_cache);
+
+	if (!kmmscand_slot)
+		return;
+
+	slot = &kmmscand_slot->slot;
+
+	spin_lock(&kmmscand_mm_lock);
+	mm_slot_insert(kmmscand_slots_hash, mm, slot);
+
+	wakeup = list_empty(&kmmscand_scan.mm_head);
+	list_add_tail(&slot->mm_node, &kmmscand_scan.mm_head);
+	spin_unlock(&kmmscand_mm_lock);
+
+	mmgrab(mm);
+	if (wakeup)
+		wake_up_interruptible(&kmmscand_wait);
+}
+
+void __kmmscand_exit(struct mm_struct *mm)
+{
+	struct kmmscand_mm_slot *mm_slot;
+	struct mm_slot *slot;
+	int free = 0;
+
+	spin_lock(&kmmscand_mm_lock);
+	slot = mm_slot_lookup(kmmscand_slots_hash, mm);
+	mm_slot = mm_slot_entry(slot, struct kmmscand_mm_slot, slot);
+	if (mm_slot && kmmscand_scan.mm_slot != mm_slot) {
+		hash_del(&slot->hash);
+		list_del(&slot->mm_node);
+		free = 1;
+	}
+
+	spin_unlock(&kmmscand_mm_lock);
+
+	if (free) {
+		mm_slot_free(kmmscand_slot_cache, mm_slot);
+		mmdrop(mm);
+	} else if (mm_slot) {
+		mmap_write_lock(mm);
+		mmap_write_unlock(mm);
+	}
+}
+
 static int start_kmmscand(void)
 {
 	int err = 0;
@@ -168,6 +244,13 @@  static int __init kmmscand_init(void)
 {
 	int err;
 
+	kmmscand_slot_cache = KMEM_CACHE(kmmscand_mm_slot, 0);
+
+	if (!kmmscand_slot_cache) {
+		pr_err("kmmscand: kmem_cache error");
+		return -ENOMEM;
+	}
+
 	err = start_kmmscand();
 	if (err)
 		goto err_kmmscand;
@@ -176,6 +259,7 @@  static int __init kmmscand_init(void)
 
 err_kmmscand:
 	stop_kmmscand();
+	kmmscand_destroy();
 
 	return err;
 }