diff mbox series

[v5,2/6] memcontrol: allows mem_cgroup_iter() to check for onlineness

Message ID 20231106183159.3562879-3-nphamcs@gmail.com (mailing list archive)
State New
Headers show
Series workload-specific and memory pressure-driven zswap writeback | expand

Commit Message

Nhat Pham Nov. 6, 2023, 6:31 p.m. UTC
The new zswap writeback scheme requires an online-only memcg hierarchy
traversal. Add a new parameter to mem_cgroup_iter() to check for
onlineness before returning.

Signed-off-by: Nhat Pham <nphamcs@gmail.com>
---
 include/linux/memcontrol.h |  4 ++--
 mm/memcontrol.c            | 17 ++++++++++-------
 mm/shrinker.c              |  4 ++--
 mm/vmscan.c                | 26 +++++++++++++-------------
 4 files changed, 27 insertions(+), 24 deletions(-)
diff mbox series

Patch

diff --git a/include/linux/memcontrol.h b/include/linux/memcontrol.h
index 6edd3ec4d8d5..55c85f952afd 100644
--- a/include/linux/memcontrol.h
+++ b/include/linux/memcontrol.h
@@ -832,7 +832,7 @@  static inline void mem_cgroup_put(struct mem_cgroup *memcg)
 
 struct mem_cgroup *mem_cgroup_iter(struct mem_cgroup *,
 				   struct mem_cgroup *,
-				   struct mem_cgroup_reclaim_cookie *);
+				   struct mem_cgroup_reclaim_cookie *, bool online);
 void mem_cgroup_iter_break(struct mem_cgroup *, struct mem_cgroup *);
 void mem_cgroup_scan_tasks(struct mem_cgroup *memcg,
 			   int (*)(struct task_struct *, void *), void *arg);
@@ -1381,7 +1381,7 @@  static inline struct lruvec *folio_lruvec_lock_irqsave(struct folio *folio,
 static inline struct mem_cgroup *
 mem_cgroup_iter(struct mem_cgroup *root,
 		struct mem_cgroup *prev,
-		struct mem_cgroup_reclaim_cookie *reclaim)
+		struct mem_cgroup_reclaim_cookie *reclaim, bool online)
 {
 	return NULL;
 }
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index 61c0c46c2d62..6f7fc0101252 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -221,14 +221,14 @@  enum res_type {
  * be used for reference counting.
  */
 #define for_each_mem_cgroup_tree(iter, root)		\
-	for (iter = mem_cgroup_iter(root, NULL, NULL);	\
+	for (iter = mem_cgroup_iter(root, NULL, NULL, false);	\
 	     iter != NULL;				\
-	     iter = mem_cgroup_iter(root, iter, NULL))
+	     iter = mem_cgroup_iter(root, iter, NULL, false))
 
 #define for_each_mem_cgroup(iter)			\
-	for (iter = mem_cgroup_iter(NULL, NULL, NULL);	\
+	for (iter = mem_cgroup_iter(NULL, NULL, NULL, false);	\
 	     iter != NULL;				\
-	     iter = mem_cgroup_iter(NULL, iter, NULL))
+	     iter = mem_cgroup_iter(NULL, iter, NULL, false))
 
 static inline bool task_is_dying(void)
 {
@@ -1139,6 +1139,7 @@  struct mem_cgroup *get_mem_cgroup_from_current(void)
  * @root: hierarchy root
  * @prev: previously returned memcg, NULL on first invocation
  * @reclaim: cookie for shared reclaim walks, NULL for full walks
+ * @online: skip offline memcgs
  *
  * Returns references to children of the hierarchy below @root, or
  * @root itself, or %NULL after a full round-trip.
@@ -1153,7 +1154,8 @@  struct mem_cgroup *get_mem_cgroup_from_current(void)
  */
 struct mem_cgroup *mem_cgroup_iter(struct mem_cgroup *root,
 				   struct mem_cgroup *prev,
-				   struct mem_cgroup_reclaim_cookie *reclaim)
+				   struct mem_cgroup_reclaim_cookie *reclaim,
+				   bool online)
 {
 	struct mem_cgroup_reclaim_iter *iter;
 	struct cgroup_subsys_state *css = NULL;
@@ -1223,7 +1225,8 @@  struct mem_cgroup *mem_cgroup_iter(struct mem_cgroup *root,
 		 * is provided by the caller, so we know it's alive
 		 * and kicking, and don't take an extra reference.
 		 */
-		if (css == &root->css || css_tryget(css)) {
+		if (css == &root->css || (!online && css_tryget(css)) ||
+				css_tryget_online(css)) {
 			memcg = mem_cgroup_from_css(css);
 			break;
 		}
@@ -1836,7 +1839,7 @@  static int mem_cgroup_soft_reclaim(struct mem_cgroup *root_memcg,
 	excess = soft_limit_excess(root_memcg);
 
 	while (1) {
-		victim = mem_cgroup_iter(root_memcg, victim, &reclaim);
+		victim = mem_cgroup_iter(root_memcg, victim, &reclaim, false);
 		if (!victim) {
 			loop++;
 			if (loop >= 2) {
diff --git a/mm/shrinker.c b/mm/shrinker.c
index dd91eab43ed3..54f5d3aa4f27 100644
--- a/mm/shrinker.c
+++ b/mm/shrinker.c
@@ -160,7 +160,7 @@  static int expand_shrinker_info(int new_id)
 	new_size = shrinker_unit_size(new_nr_max);
 	old_size = shrinker_unit_size(shrinker_nr_max);
 
-	memcg = mem_cgroup_iter(NULL, NULL, NULL);
+	memcg = mem_cgroup_iter(NULL, NULL, NULL, false);
 	do {
 		ret = expand_one_shrinker_info(memcg, new_size, old_size,
 					       new_nr_max);
@@ -168,7 +168,7 @@  static int expand_shrinker_info(int new_id)
 			mem_cgroup_iter_break(NULL, memcg);
 			goto out;
 		}
-	} while ((memcg = mem_cgroup_iter(NULL, memcg, NULL)) != NULL);
+	} while ((memcg = mem_cgroup_iter(NULL, memcg, NULL, false)) != NULL);
 out:
 	if (!ret)
 		shrinker_nr_max = new_nr_max;
diff --git a/mm/vmscan.c b/mm/vmscan.c
index 2cc0cb41fb32..065d29502580 100644
--- a/mm/vmscan.c
+++ b/mm/vmscan.c
@@ -397,10 +397,10 @@  static unsigned long drop_slab_node(int nid)
 	unsigned long freed = 0;
 	struct mem_cgroup *memcg = NULL;
 
-	memcg = mem_cgroup_iter(NULL, NULL, NULL);
+	memcg = mem_cgroup_iter(NULL, NULL, NULL, false);
 	do {
 		freed += shrink_slab(GFP_KERNEL, nid, memcg, 0);
-	} while ((memcg = mem_cgroup_iter(NULL, memcg, NULL)) != NULL);
+	} while ((memcg = mem_cgroup_iter(NULL, memcg, NULL, false)) != NULL);
 
 	return freed;
 }
@@ -3931,7 +3931,7 @@  static void lru_gen_age_node(struct pglist_data *pgdat, struct scan_control *sc)
 	if (!min_ttl || sc->order || sc->priority == DEF_PRIORITY)
 		return;
 
-	memcg = mem_cgroup_iter(NULL, NULL, NULL);
+	memcg = mem_cgroup_iter(NULL, NULL, NULL, false);
 	do {
 		struct lruvec *lruvec = mem_cgroup_lruvec(memcg, pgdat);
 
@@ -3941,7 +3941,7 @@  static void lru_gen_age_node(struct pglist_data *pgdat, struct scan_control *sc)
 		}
 
 		cond_resched();
-	} while ((memcg = mem_cgroup_iter(NULL, memcg, NULL)));
+	} while ((memcg = mem_cgroup_iter(NULL, memcg, NULL, false)));
 
 	/*
 	 * The main goal is to OOM kill if every generation from all memcgs is
@@ -5033,7 +5033,7 @@  static void lru_gen_change_state(bool enabled)
 	else
 		static_branch_disable_cpuslocked(&lru_gen_caps[LRU_GEN_CORE]);
 
-	memcg = mem_cgroup_iter(NULL, NULL, NULL);
+	memcg = mem_cgroup_iter(NULL, NULL, NULL, false);
 	do {
 		int nid;
 
@@ -5057,7 +5057,7 @@  static void lru_gen_change_state(bool enabled)
 		}
 
 		cond_resched();
-	} while ((memcg = mem_cgroup_iter(NULL, memcg, NULL)));
+	} while ((memcg = mem_cgroup_iter(NULL, memcg, NULL, false)));
 unlock:
 	mutex_unlock(&state_mutex);
 	put_online_mems();
@@ -5160,7 +5160,7 @@  static void *lru_gen_seq_start(struct seq_file *m, loff_t *pos)
 	if (!m->private)
 		return ERR_PTR(-ENOMEM);
 
-	memcg = mem_cgroup_iter(NULL, NULL, NULL);
+	memcg = mem_cgroup_iter(NULL, NULL, NULL, false);
 	do {
 		int nid;
 
@@ -5168,7 +5168,7 @@  static void *lru_gen_seq_start(struct seq_file *m, loff_t *pos)
 			if (!nr_to_skip--)
 				return get_lruvec(memcg, nid);
 		}
-	} while ((memcg = mem_cgroup_iter(NULL, memcg, NULL)));
+	} while ((memcg = mem_cgroup_iter(NULL, memcg, NULL, false)));
 
 	return NULL;
 }
@@ -5191,7 +5191,7 @@  static void *lru_gen_seq_next(struct seq_file *m, void *v, loff_t *pos)
 
 	nid = next_memory_node(nid);
 	if (nid == MAX_NUMNODES) {
-		memcg = mem_cgroup_iter(NULL, memcg, NULL);
+		memcg = mem_cgroup_iter(NULL, memcg, NULL, false);
 		if (!memcg)
 			return NULL;
 
@@ -5794,7 +5794,7 @@  static void shrink_node_memcgs(pg_data_t *pgdat, struct scan_control *sc)
 	struct mem_cgroup *target_memcg = sc->target_mem_cgroup;
 	struct mem_cgroup *memcg;
 
-	memcg = mem_cgroup_iter(target_memcg, NULL, NULL);
+	memcg = mem_cgroup_iter(target_memcg, NULL, NULL, false);
 	do {
 		struct lruvec *lruvec = mem_cgroup_lruvec(memcg, pgdat);
 		unsigned long reclaimed;
@@ -5844,7 +5844,7 @@  static void shrink_node_memcgs(pg_data_t *pgdat, struct scan_control *sc)
 				   sc->nr_scanned - scanned,
 				   sc->nr_reclaimed - reclaimed);
 
-	} while ((memcg = mem_cgroup_iter(target_memcg, memcg, NULL)));
+	} while ((memcg = mem_cgroup_iter(target_memcg, memcg, NULL, false)));
 }
 
 static void shrink_node(pg_data_t *pgdat, struct scan_control *sc)
@@ -6511,12 +6511,12 @@  static void kswapd_age_node(struct pglist_data *pgdat, struct scan_control *sc)
 	if (!inactive_is_low(lruvec, LRU_INACTIVE_ANON))
 		return;
 
-	memcg = mem_cgroup_iter(NULL, NULL, NULL);
+	memcg = mem_cgroup_iter(NULL, NULL, NULL, false);
 	do {
 		lruvec = mem_cgroup_lruvec(memcg, pgdat);
 		shrink_active_list(SWAP_CLUSTER_MAX, lruvec,
 				   sc, LRU_ACTIVE_ANON);
-		memcg = mem_cgroup_iter(NULL, memcg, NULL);
+		memcg = mem_cgroup_iter(NULL, memcg, NULL, false);
 	} while (memcg);
 }