diff mbox series

[2/3] mm: fix the reference of mempolicy in some functions.

Message ID 20221204161432.2149375-3-hezhongkun.hzk@bytedance.com (mailing list archive)
State New
Headers show
Series [1/3] mm: replace mpol_put() with mpol_kill() initiated the destruction of mpol. | expand

Commit Message

Zhongkun He Dec. 4, 2022, 4:14 p.m. UTC
There are some functions that use mempolicy in process
context, but don't reference it. Let's fix it to have
a clear life time model.

Suggested-by: Michal Hocko <mhocko@suse.com>
Signed-off-by: Zhongkun He <hezhongkun.hzk@bytedance.com>
---
 mm/hugetlb.c   | 16 ++++++-----
 mm/mempolicy.c | 78 +++++++++++++++++++++++++-------------------------
 2 files changed, 48 insertions(+), 46 deletions(-)
diff mbox series

Patch

diff --git a/mm/hugetlb.c b/mm/hugetlb.c
index 277330f40818..0c2b5233e0c9 100644
--- a/mm/hugetlb.c
+++ b/mm/hugetlb.c
@@ -4353,19 +4353,19 @@  static int __init default_hugepagesz_setup(char *s)
 }
 __setup("default_hugepagesz=", default_hugepagesz_setup);
 
-static nodemask_t *policy_mbind_nodemask(gfp_t gfp)
+static nodemask_t *policy_mbind_nodemask(gfp_t gfp, struct mempolicy **mpol)
 {
 #ifdef CONFIG_NUMA
-	struct mempolicy *mpol = get_task_policy(current);
+	*mpol = get_task_policy(current);
 
 	/*
 	 * Only enforce MPOL_BIND policy which overlaps with cpuset policy
 	 * (from policy_nodemask) specifically for hugetlb case
 	 */
-	if (mpol->mode == MPOL_BIND &&
-		(apply_policy_zone(mpol, gfp_zone(gfp)) &&
-		 cpuset_nodemask_valid_mems_allowed(&mpol->nodes)))
-		return &mpol->nodes;
+	if ((*mpol)->mode == MPOL_BIND &&
+		(apply_policy_zone(*mpol, gfp_zone(gfp)) &&
+		 cpuset_nodemask_valid_mems_allowed(&(*mpol)->nodes)))
+		return &(*mpol)->nodes;
 #endif
 	return NULL;
 }
@@ -4375,14 +4375,16 @@  static unsigned int allowed_mems_nr(struct hstate *h)
 	int node;
 	unsigned int nr = 0;
 	nodemask_t *mbind_nodemask;
+	struct mempolicy *mpol = NULL;
 	unsigned int *array = h->free_huge_pages_node;
 	gfp_t gfp_mask = htlb_alloc_mask(h);
 
-	mbind_nodemask = policy_mbind_nodemask(gfp_mask);
+	mbind_nodemask = policy_mbind_nodemask(gfp_mask, &mpol);
 	for_each_node_mask(node, cpuset_current_mems_allowed) {
 		if (!mbind_nodemask || node_isset(node, *mbind_nodemask))
 			nr += array[node];
 	}
+	mpol_put(mpol);
 
 	return nr;
 }
diff --git a/mm/mempolicy.c b/mm/mempolicy.c
index f1857ebded46..0feffb7ff01e 100644
--- a/mm/mempolicy.c
+++ b/mm/mempolicy.c
@@ -159,7 +159,7 @@  int numa_map_to_online_node(int node)
 EXPORT_SYMBOL_GPL(numa_map_to_online_node);
 
 /* Obtain a reference on the specified task mempolicy */
-static mempolicy *get_task_mpol(struct task_struct *p)
+static struct mempolicy *get_task_mpol(struct task_struct *p)
 {
 	struct mempolicy *pol;
 
@@ -925,7 +925,8 @@  static void get_policy_nodemask(struct mempolicy *p, nodemask_t *nodes)
 		*nodes = p->nodes;
 		break;
 	case MPOL_LOCAL:
-		/* return empty node mask for local allocation */killbreak;
+		/* return empty node mask for local allocation */
+		break;
 	default:
 		BUG();
 	}
@@ -951,7 +952,7 @@  static long do_get_mempolicy(int *policy, nodemask_t *nmask,
 	int err;
 	struct mm_struct *mm = current->mm;
 	struct vm_area_struct *vma = NULL;
-	struct mempolicy *pol = current->mempolicy, *pol_refcount = NULL;
+	struct mempolicy *pol;
 
 	if (flags &
 		~(unsigned long)(MPOL_F_NODE|MPOL_F_ADDR|MPOL_F_MEMS_ALLOWED))
@@ -966,8 +967,10 @@  static long do_get_mempolicy(int *policy, nodemask_t *nmask,
 		task_unlock(current);
 		return 0;
 	}
+	pol = get_task_mpol(current);
 
 	if (flags & MPOL_F_ADDR) {
+		mpol_put(pol); /* put the refcount of task mpol */
 		/*
 		 * Do NOT fall back to task policy if the
 		 * vma/shared policy at addr is NULL.  We
@@ -979,27 +982,19 @@  static long do_get_mempolicy(int *policy, nodemask_t *nmask,
 			mmap_read_unlock(mm);
 			return -EFAULT;
 		}
-		if (vma->vm_ops && vma->vm_ops->get_policy)
-			pol = vma->vm_ops->get_policy(vma, addr);
-		else
-			pol = vma->vm_policy;
-	} else if (addr)
-		return -EINVAL;
+		/* obtain a reference to vma mpol. */
+		pol = __get_vma_policy(vma, addr);
+		mmap_read_unlock(mm);
+	} else if (addr) {
+		err = -EINVAL;
+		goto out;
+	}
 
 	if (!pol)
 		pol = &default_policy;	/* indicates default behavior */
 
 	if (flags & MPOL_F_NODE) {
 		if (flags & MPOL_F_ADDR) {
-			/*
-			 * Take a refcount on the mpol, because we are about to
-			 * drop the mmap_lock, after which only "pol" remains
-			 * valid, "vma" is stale.
-			 */
-			pol_refcount = pol;
-			vma = NULL;
-			mpol_get(pol);
-			mmap_read_unlock(mm);
 			err = lookup_node(mm, addr);
 			if (err < 0)
 				goto out;
@@ -1023,21 +1018,19 @@  static long do_get_mempolicy(int *policy, nodemask_t *nmask,
 
 	err = 0;
 	if (nmask) {
-		if (mpol_store_user_nodemask(pol)) {
+		/*
+		 * There is no need for a lock, since we get
+		 * a reference to mpol.
+		 */
+		if (mpol_store_user_nodemask(pol))
 			*nmask = pol->w.user_nodemask;
-		} else {
-			task_lock(current);
+		else
 			get_policy_nodemask(pol, nmask);
-			task_unlock(current);
-		}
 	}
 
  out:
-	mpol_cond_put(pol);
-	if (vma)
-		mmap_read_unlock(mm);
-	if (pol_refcount)
-		mpol_put(pol_refcount);
+	if (pol != &default_policy)
+		mpol_put(pol);
 	return err;
 }
 
@@ -1923,16 +1916,18 @@  unsigned int mempolicy_slab_node(void)
 	if (!in_task())
 		return node;
 
-	policy = current->mempolicy;
+	policy = get_task_mpol(current);
 	if (!policy)
 		return node;
 
 	switch (policy->mode) {
 	case MPOL_PREFERRED:
-		return first_node(policy->nodes);
+		node = first_node(policy->nodes);
+		break;
 
 	case MPOL_INTERLEAVE:
-		return interleave_nodes(policy);
+		node = interleave_nodes(policy);
+		break;
 
 	case MPOL_BIND:
 	case MPOL_PREFERRED_MANY:
@@ -1948,14 +1943,17 @@  unsigned int mempolicy_slab_node(void)
 		zonelist = &NODE_DATA(node)->node_zonelists[ZONELIST_FALLBACK];
 		z = first_zones_zonelist(zonelist, highest_zoneidx,
 							&policy->nodes);
-		return z->zone ? zone_to_nid(z->zone) : node;
+		node = z->zone ? zone_to_nid(z->zone) : node;
+		break;
 	}
 	case MPOL_LOCAL:
-		return node;
+		break;
 
 	default:
 		BUG();
 	}
+	mpol_put(policy);
+	return node;
 }
 
 /*
@@ -2379,21 +2377,23 @@  unsigned long alloc_pages_bulk_array_mempolicy(gfp_t gfp,
 		unsigned long nr_pages, struct page **page_array)
 {
 	struct mempolicy *pol = &default_policy;
+	unsigned long pages;
 
 	if (!in_interrupt() && !(gfp & __GFP_THISNODE))
 		pol = get_task_policy(current);
 
 	if (pol->mode == MPOL_INTERLEAVE)
-		return alloc_pages_bulk_array_interleave(gfp, pol,
+		pages = alloc_pages_bulk_array_interleave(gfp, pol,
 							 nr_pages, page_array);
-
-	if (pol->mode == MPOL_PREFERRED_MANY)
-		return alloc_pages_bulk_array_preferred_many(gfp,
+	else if (pol->mode == MPOL_PREFERRED_MANY)
+		pages = alloc_pages_bulk_array_preferred_many(gfp,
 				numa_node_id(), pol, nr_pages, page_array);
-
-	return __alloc_pages_bulk(gfp, policy_node(gfp, pol, numa_node_id()),
+	else
+		pages = __alloc_pages_bulk(gfp, policy_node(gfp, pol, numa_node_id()),
 				  policy_nodemask(gfp, pol), nr_pages, NULL,
 				  page_array);
+	mpol_put(pol);
+	return pages;
 }
 
 int vma_dup_policy(struct vm_area_struct *src, struct vm_area_struct *dst)