
[v2,4/4] mm/memcg: Protect memcg_stock with a local_lock_t

Message ID 20220211223537.2175879-5-bigeasy@linutronix.de (mailing list archive)
State New
Series mm/memcg: Address PREEMPT_RT problems instead of disabling it.

Commit Message

Sebastian Andrzej Siewior Feb. 11, 2022, 10:35 p.m. UTC
The members of the per-CPU structure memcg_stock_pcp are protected by
disabling interrupts. This does not work on PREEMPT_RT because it
creates an atomic context in which actions are performed that require a
preemptible context. One example is obj_cgroup_release().

The IRQ-disable sections can be replaced with local_lock_t, which
preserves the explicit disabling of interrupts while keeping the code
preemptible on PREEMPT_RT.
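
As applied to memcg_stock in the patch below, the pattern looks roughly
like this (on !PREEMPT_RT local_lock_irqsave() still disables
interrupts; on PREEMPT_RT it acquires a per-CPU sleeping lock and leaves
interrupts and preemption enabled):

	struct memcg_stock_pcp {
		local_lock_t stock_lock;
		struct mem_cgroup *cached;
		unsigned int nr_pages;
		/* ... */
	};
	static DEFINE_PER_CPU(struct memcg_stock_pcp, memcg_stock) = {
		.stock_lock = INIT_LOCAL_LOCK(stock_lock),
	};

	static bool consume_stock(struct mem_cgroup *memcg, unsigned int nr_pages)
	{
		unsigned long flags;
		bool ret = false;

		/* the per-CPU lock replaces a bare local_irq_save() */
		local_lock_irqsave(&memcg_stock.stock_lock, flags);
		/* ... operate on this_cpu_ptr(&memcg_stock) ... */
		local_unlock_irqrestore(&memcg_stock.stock_lock, flags);

		return ret;
	}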

drain_all_stock() disables preemption via get_cpu() and then invokes
drain_local_stock() if it is the local CPU, to avoid scheduling a worker
(which invokes the same function). Disabling preemption here is
problematic due to the sleeping locks in drain_local_stock().
This can be avoided by always scheduling a worker, even for the local
CPU. Using cpus_read_lock() to stabilize the cpu_online_mask is not
needed since the worker always operates on the CPU-local data structure.
Should a CPU go offline, then two workers would perform the work and no
harm is done. Using cpus_read_lock() would lead to a possible deadlock.
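
With get_cpu()/put_cpu() gone, the per-CPU loop in drain_all_stock()
boils down to the following (see the corresponding hunk in the patch
below):

	for_each_online_cpu(cpu) {
		struct memcg_stock_pcp *stock = &per_cpu(memcg_stock, cpu);
		/* ... decide whether this stock needs flushing ... */
		if (flush &&
		    !test_and_set_bit(FLUSHING_CACHED_CHARGE, &stock->flags))
			schedule_work_on(cpu, &stock->work);
	}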

drain_obj_stock() drops a reference on obj_cgroup, which leads to an
invocation of obj_cgroup_release() if it is the last reference. This in
turn leads to recursive locking of the local_lock_t. To avoid this,
obj_cgroup_release() is invoked outside of the locked section.
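
The resulting call-site pattern, taken from drain_local_stock() in the
patch below, is:

	struct obj_cgroup *old = NULL;
	unsigned long flags;

	local_lock_irqsave(&memcg_stock.stock_lock, flags);
	stock = this_cpu_ptr(&memcg_stock);
	old = drain_obj_stock(stock);
	drain_stock(stock);
	clear_bit(FLUSHING_CACHED_CHARGE, &stock->flags);
	local_unlock_irqrestore(&memcg_stock.stock_lock, flags);
	if (old)
		obj_cgroup_put(old);	/* may call obj_cgroup_release(), outside the lock */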

obj_cgroup_uncharge_pages() can be invoked with the local_lock_t
acquired or without it. Invoking it with the lock held would later lead
to a locking recursion in refill_stock(). To avoid this recursion,
provide obj_cgroup_uncharge_pages_locked(), which uses the locked
version of refill_stock().

- Replace disabling interrupts for memcg_stock with a local_lock_t.

- Schedule a worker even for the local CPU instead of invoking it
  directly (in drain_all_stock()).

- Let drain_obj_stock() return the old struct obj_cgroup which is passed
  to obj_cgroup_put() outside of the locked section.

- Provide obj_cgroup_uncharge_pages_locked() which uses the locked
  version of refill_stock() to avoid recursive locking in
  drain_obj_stock().

Link: https://lkml.kernel.org/r/20220209014709.GA26885@xsang-OptiPlex-9020
Reported-by: kernel test robot <oliver.sang@intel.com>
Signed-off-by: Sebastian Andrzej Siewior <bigeasy@linutronix.de>
---
 mm/memcontrol.c | 101 ++++++++++++++++++++++++++++++------------------
 1 file changed, 63 insertions(+), 38 deletions(-)

Comments

Johannes Weiner Feb. 14, 2022, 4:23 p.m. UTC | #1
Hi Sebastian,

On Fri, Feb 11, 2022 at 11:35:37PM +0100, Sebastian Andrzej Siewior wrote:
> +static struct obj_cgroup *drain_obj_stock(struct memcg_stock_pcp *stock)
>  {
>  	struct obj_cgroup *old = stock->cached_objcg;
>  
>  	if (!old)
> -		return;
> +		return NULL;
>  
>  	if (stock->nr_bytes) {
>  		unsigned int nr_pages = stock->nr_bytes >> PAGE_SHIFT;
>  		unsigned int nr_bytes = stock->nr_bytes & (PAGE_SIZE - 1);
>  
>  		if (nr_pages)
> -			obj_cgroup_uncharge_pages(old, nr_pages);
> +			obj_cgroup_uncharge_pages_locked(old, nr_pages);

This is a bit dubious in terms of layering. It's an objcg operation,
but what's "locked" isn't the objcg, it's the underlying stock. That
function then looks it up again, even though we have it right there.

You can open-code it and factor out the stock operation instead, and
it makes things much simpler and clearer.

I.e. something like this (untested!):

diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index 1b3550f09335..eed6e0ff84d7 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -2215,6 +2215,20 @@ static void drain_local_stock(struct work_struct *dummy)
 	local_irq_restore(flags);
 }
 
+static void __refill_stock(struct memcg_stock_pcp *stock,
+			   struct mem_cgroup *memcg,
+			   unsigned int nr_pages)
+{
+	if (stock->cached != memcg) {
+		drain_stock(stock);
+		css_get(&memcg->css);
+		stock->cached = memcg;
+	}
+	stock->nr_pages += nr_pages;
+	if (stock->nr_pages > MEMCG_CHARGE_BATCH)
+		drain_stock(stock);
+}
+
 /*
  * Cache charges(val) to local per_cpu area.
  * This will be consumed by consume_stock() function, later.
@@ -2225,18 +2239,8 @@ static void refill_stock(struct mem_cgroup *memcg, unsigned int nr_pages)
 	unsigned long flags;
 
 	local_irq_save(flags);
-
 	stock = this_cpu_ptr(&memcg_stock);
-	if (stock->cached != memcg) { /* reset if necessary */
-		drain_stock(stock);
-		css_get(&memcg->css);
-		stock->cached = memcg;
-	}
-	stock->nr_pages += nr_pages;
-
-	if (stock->nr_pages > MEMCG_CHARGE_BATCH)
-		drain_stock(stock);
-
+	__refill_stock(stock, memcg, nr_pages);
 	local_irq_restore(flags);
 }
 
@@ -3213,8 +3217,15 @@ static void drain_obj_stock(struct obj_stock *stock)
 		unsigned int nr_pages = stock->nr_bytes >> PAGE_SHIFT;
 		unsigned int nr_bytes = stock->nr_bytes & (PAGE_SIZE - 1);
 
-		if (nr_pages)
-			obj_cgroup_uncharge_pages(old, nr_pages);
+		/* Flush complete pages back to the page stock */
+		if (nr_pages) {
+			struct mem_cgroup *memcg;
+
+			memcg = get_mem_cgroup_from_objcg(objcg);
+			mem_cgroup_kmem_record(memcg, -nr_pages);
+			__refill_stock(stock, memcg, nr_pages);
+			css_put(&memcg->css);
+		}
 
 		/*
 		 * The leftover is flushed to the centralized per-memcg value.

Sebastian Andrzej Siewior Feb. 16, 2022, 3:51 p.m. UTC | #2
On 2022-02-14 11:23:11 [-0500], Johannes Weiner wrote:
> Hi Sebastian,
Hi Johannes,

> This is a bit dubious in terms of layering. It's an objcg operation,
> but what's "locked" isn't the objcg, it's the underlying stock. That
> function then looks it up again, even though we have it right there.
> 
> You can open-code it and factor out the stock operation instead, and
> it makes things much simpler and clearer.
> 
> I.e. something like this (untested!):

This then:

------>8------

From: Johannes Weiner <hannes@cmpxchg.org>
Date: Wed, 16 Feb 2022 13:25:49 +0100
Subject: [PATCH] mm/memcg: Opencode the inner part of
 obj_cgroup_uncharge_pages() in drain_obj_stock()

Provide the inner part of refill_stock() as __refill_stock() without
disabling interrupts. This eases the integration of local_lock_t where
recursive locking must be avoided.
Open code obj_cgroup_uncharge_pages() in drain_obj_stock() and use
__refill_stock(). The caller of drain_obj_stock() already disables
interrupts.

[bigeasy: Patch body around Johannes' diff]

Signed-off-by: Johannes Weiner <hannes@cmpxchg.org>
Signed-off-by: Sebastian Andrzej Siewior <bigeasy@linutronix.de>
---
 mm/memcontrol.c | 25 +++++++++++++++++++------
 1 file changed, 19 insertions(+), 6 deletions(-)

diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index 69130a5fe3d51..f574f2e1cc399 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -2227,12 +2227,9 @@ static void drain_local_stock(struct work_struct *dummy)
  * Cache charges(val) to local per_cpu area.
  * This will be consumed by consume_stock() function, later.
  */
-static void refill_stock(struct mem_cgroup *memcg, unsigned int nr_pages)
+static void __refill_stock(struct mem_cgroup *memcg, unsigned int nr_pages)
 {
 	struct memcg_stock_pcp *stock;
-	unsigned long flags;
-
-	local_irq_save(flags);
 
 	stock = this_cpu_ptr(&memcg_stock);
 	if (stock->cached != memcg) { /* reset if necessary */
@@ -2244,7 +2241,14 @@ static void refill_stock(struct mem_cgroup *memcg, unsigned int nr_pages)
 
 	if (stock->nr_pages > MEMCG_CHARGE_BATCH)
 		drain_stock(stock);
+}
 
+static void refill_stock(struct mem_cgroup *memcg, unsigned int nr_pages)
+{
+	unsigned long flags;
+
+	local_irq_save(flags);
+	__refill_stock(memcg, nr_pages);
 	local_irq_restore(flags);
 }
 
@@ -3151,8 +3155,17 @@ static void drain_obj_stock(struct memcg_stock_pcp *stock)
 		unsigned int nr_pages = stock->nr_bytes >> PAGE_SHIFT;
 		unsigned int nr_bytes = stock->nr_bytes & (PAGE_SIZE - 1);
 
-		if (nr_pages)
-			obj_cgroup_uncharge_pages(old, nr_pages);
+		if (nr_pages) {
+			struct mem_cgroup *memcg;
+
+			memcg = get_mem_cgroup_from_objcg(old);
+
+			if (!cgroup_subsys_on_dfl(memory_cgrp_subsys))
+				page_counter_uncharge(&memcg->kmem, nr_pages);
+			__refill_stock(memcg, nr_pages);
+
+			css_put(&memcg->css);
+		}
 
 		/*
 		 * The leftover is flushed to the centralized per-memcg value.

Johannes Weiner Feb. 16, 2022, 6:08 p.m. UTC | #3
On Wed, Feb 16, 2022 at 04:51:14PM +0100, Sebastian Andrzej Siewior wrote:
> On 2022-02-14 11:23:11 [-0500], Johannes Weiner wrote:
> > Hi Sebastian,
> Hi Johannes,
> 
> > This is a bit dubious in terms of layering. It's an objcg operation,
> > but what's "locked" isn't the objcg, it's the underlying stock. That
> > function then looks it up again, even though we have it right there.
> > 
> > You can open-code it and factor out the stock operation instead, and
> > it makes things much simpler and clearer.
> > 
> > I.e. something like this (untested!):
> 
> This then:
> 
> ------>8------
> 
> From: Johannes Weiner <hannes@cmpxchg.org>
> Date: Wed, 16 Feb 2022 13:25:49 +0100
> Subject: [PATCH] mm/memcg: Opencode the inner part of
>  obj_cgroup_uncharge_pages() in drain_obj_stock()
> 
> Provide the inner part of refill_stock() as __refill_stock() without
> disabling interrupts. This eases the integration of local_lock_t where
> recursive locking must be avoided.
> Open code obj_cgroup_uncharge_pages() in drain_obj_stock() and use
> __refill_stock(). The caller of drain_obj_stock() already disables
> interrupts.
> 
> [bigeasy: Patch body around Johannes' diff ]
> 
> Signed-off-by: Johannes Weiner <hannes@cmpxchg.org>
> Signed-off-by: Sebastian Andrzej Siewior <bigeasy@linutronix.de>

I thought you'd fold it into yours, but separate patch works too,
thanks!

Signed-off-by: Johannes Weiner <hannes@cmpxchg.org>

One important note, though:

> @@ -3151,8 +3155,17 @@ static void drain_obj_stock(struct memcg_stock_pcp *stock)
>  		unsigned int nr_pages = stock->nr_bytes >> PAGE_SHIFT;
>  		unsigned int nr_bytes = stock->nr_bytes & (PAGE_SIZE - 1);
>  
> -		if (nr_pages)
> -			obj_cgroup_uncharge_pages(old, nr_pages);
> +		if (nr_pages) {
> +			struct mem_cgroup *memcg;
> +
> +			memcg = get_mem_cgroup_from_objcg(old);
> +
> +			if (!cgroup_subsys_on_dfl(memory_cgrp_subsys))
> +				page_counter_uncharge(&memcg->kmem, nr_pages);
> +			__refill_stock(memcg, nr_pages);

This doesn't take "memcg: add per-memcg total kernel memory stat"
queued in -mm into account and so will break kmem accounting.

Make sure to rebase the patches to the -mm tree before sending it
out. You can find it here: https://github.com/hnaz/linux-mm

Sebastian Andrzej Siewior Feb. 17, 2022, 9:28 a.m. UTC | #4
On 2022-02-16 13:08:55 [-0500], Johannes Weiner wrote:
> I thought you'd fold it into yours, but separate patch works too,
> thanks!
> 
> Signed-off-by: Johannes Weiner <hannes@cmpxchg.org>

Thank you.

> One important note, though:
> 
> > @@ -3151,8 +3155,17 @@ static void drain_obj_stock(struct memcg_stock_pcp *stock)
> >  		unsigned int nr_pages = stock->nr_bytes >> PAGE_SHIFT;
> >  		unsigned int nr_bytes = stock->nr_bytes & (PAGE_SIZE - 1);
> >  
> > -		if (nr_pages)
> > -			obj_cgroup_uncharge_pages(old, nr_pages);
> > +		if (nr_pages) {
> > +			struct mem_cgroup *memcg;
> > +
> > +			memcg = get_mem_cgroup_from_objcg(old);
> > +
> > +			if (!cgroup_subsys_on_dfl(memory_cgrp_subsys))
> > +				page_counter_uncharge(&memcg->kmem, nr_pages);
> > +			__refill_stock(memcg, nr_pages);
> 
> This doesn't take "memcg: add per-memcg total kernel memory stat"
> queued in -mm into account and so will break kmem accounting.
>
> Make sure to rebase the patches to the -mm tree before sending it
> out. You can find it here: https://github.com/hnaz/linux-mm

Thank you, rebased on top.

Sebastian

Patch

diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index 466466f285cea..f7120a92cf46e 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -2097,6 +2097,7 @@  void unlock_page_memcg(struct page *page)
 }
 
 struct memcg_stock_pcp {
+	local_lock_t stock_lock;
 	struct mem_cgroup *cached; /* this never be root cgroup */
 	unsigned int nr_pages;
 
@@ -2112,17 +2113,20 @@  struct memcg_stock_pcp {
 	unsigned long flags;
 #define FLUSHING_CACHED_CHARGE	0
 };
-static DEFINE_PER_CPU(struct memcg_stock_pcp, memcg_stock);
+static DEFINE_PER_CPU(struct memcg_stock_pcp, memcg_stock) = {
+	.stock_lock = INIT_LOCAL_LOCK(stock_lock),
+};
 static DEFINE_MUTEX(percpu_charge_mutex);
 
 #ifdef CONFIG_MEMCG_KMEM
-static void drain_obj_stock(struct memcg_stock_pcp *stock);
+static struct obj_cgroup *drain_obj_stock(struct memcg_stock_pcp *stock);
 static bool obj_stock_flush_required(struct memcg_stock_pcp *stock,
 				     struct mem_cgroup *root_memcg);
 
 #else
-static inline void drain_obj_stock(struct memcg_stock_pcp *stock)
+static inline struct obj_cgroup *drain_obj_stock(struct memcg_stock_pcp *stock)
 {
+	return NULL;
 }
 static bool obj_stock_flush_required(struct memcg_stock_pcp *stock,
 				     struct mem_cgroup *root_memcg)
@@ -2151,7 +2155,7 @@  static bool consume_stock(struct mem_cgroup *memcg, unsigned int nr_pages)
 	if (nr_pages > MEMCG_CHARGE_BATCH)
 		return ret;
 
-	local_irq_save(flags);
+	local_lock_irqsave(&memcg_stock.stock_lock, flags);
 
 	stock = this_cpu_ptr(&memcg_stock);
 	if (memcg == stock->cached && stock->nr_pages >= nr_pages) {
@@ -2159,7 +2163,7 @@  static bool consume_stock(struct mem_cgroup *memcg, unsigned int nr_pages)
 		ret = true;
 	}
 
-	local_irq_restore(flags);
+	local_unlock_irqrestore(&memcg_stock.stock_lock, flags);
 
 	return ret;
 }
@@ -2188,6 +2192,7 @@  static void drain_stock(struct memcg_stock_pcp *stock)
 static void drain_local_stock(struct work_struct *dummy)
 {
 	struct memcg_stock_pcp *stock;
+	struct obj_cgroup *old = NULL;
 	unsigned long flags;
 
 	/*
@@ -2195,26 +2200,25 @@  static void drain_local_stock(struct work_struct *dummy)
 	 * drain_stock races is that we always operate on local CPU stock
 	 * here with IRQ disabled
 	 */
-	local_irq_save(flags);
+	local_lock_irqsave(&memcg_stock.stock_lock, flags);
 
 	stock = this_cpu_ptr(&memcg_stock);
-	drain_obj_stock(stock);
+	old = drain_obj_stock(stock);
 	drain_stock(stock);
 	clear_bit(FLUSHING_CACHED_CHARGE, &stock->flags);
 
-	local_irq_restore(flags);
+	local_unlock_irqrestore(&memcg_stock.stock_lock, flags);
+	if (old)
+		obj_cgroup_put(old);
 }
 
 /*
  * Cache charges(val) to local per_cpu area.
  * This will be consumed by consume_stock() function, later.
  */
-static void refill_stock(struct mem_cgroup *memcg, unsigned int nr_pages)
+static void __refill_stock(struct mem_cgroup *memcg, unsigned int nr_pages)
 {
 	struct memcg_stock_pcp *stock;
-	unsigned long flags;
-
-	local_irq_save(flags);
 
 	stock = this_cpu_ptr(&memcg_stock);
 	if (stock->cached != memcg) { /* reset if necessary */
@@ -2226,8 +2230,15 @@  static void refill_stock(struct mem_cgroup *memcg, unsigned int nr_pages)
 
 	if (stock->nr_pages > MEMCG_CHARGE_BATCH)
 		drain_stock(stock);
+}
 
-	local_irq_restore(flags);
+static void refill_stock(struct mem_cgroup *memcg, unsigned int nr_pages)
+{
+	unsigned long flags;
+
+	local_lock_irqsave(&memcg_stock.stock_lock, flags);
+	__refill_stock(memcg, nr_pages);
+	local_unlock_irqrestore(&memcg_stock.stock_lock, flags);
 }
 
 /*
@@ -2236,7 +2247,7 @@  static void refill_stock(struct mem_cgroup *memcg, unsigned int nr_pages)
  */
 static void drain_all_stock(struct mem_cgroup *root_memcg)
 {
-	int cpu, curcpu;
+	int cpu;
 
 	/* If someone's already draining, avoid adding running more workers. */
 	if (!mutex_trylock(&percpu_charge_mutex))
@@ -2247,7 +2258,6 @@  static void drain_all_stock(struct mem_cgroup *root_memcg)
 	 * as well as workers from this path always operate on the local
 	 * per-cpu data. CPU up doesn't touch memcg_stock at all.
 	 */
-	curcpu = get_cpu();
 	for_each_online_cpu(cpu) {
 		struct memcg_stock_pcp *stock = &per_cpu(memcg_stock, cpu);
 		struct mem_cgroup *memcg;
@@ -2263,14 +2273,9 @@  static void drain_all_stock(struct mem_cgroup *root_memcg)
 		rcu_read_unlock();
 
 		if (flush &&
-		    !test_and_set_bit(FLUSHING_CACHED_CHARGE, &stock->flags)) {
-			if (cpu == curcpu)
-				drain_local_stock(&stock->work);
-			else
-				schedule_work_on(cpu, &stock->work);
-		}
+		    !test_and_set_bit(FLUSHING_CACHED_CHARGE, &stock->flags))
+			schedule_work_on(cpu, &stock->work);
 	}
-	put_cpu();
 	mutex_unlock(&percpu_charge_mutex);
 }
 
@@ -2948,12 +2953,13 @@  static void memcg_free_cache_id(int id)
 }
 
 /*
- * obj_cgroup_uncharge_pages: uncharge a number of kernel pages from a objcg
+ * __obj_cgroup_uncharge_pages: uncharge a number of kernel pages from a objcg
  * @objcg: object cgroup to uncharge
  * @nr_pages: number of pages to uncharge
  */
-static void obj_cgroup_uncharge_pages(struct obj_cgroup *objcg,
-				      unsigned int nr_pages)
+static void __obj_cgroup_uncharge_pages(struct obj_cgroup *objcg,
+					unsigned int nr_pages,
+					void (*refill)(struct mem_cgroup *memcg, unsigned int nr_pages))
 {
 	struct mem_cgroup *memcg;
 
@@ -2961,11 +2967,24 @@  static void obj_cgroup_uncharge_pages(struct obj_cgroup *objcg,
 
 	if (!cgroup_subsys_on_dfl(memory_cgrp_subsys))
 		page_counter_uncharge(&memcg->kmem, nr_pages);
-	refill_stock(memcg, nr_pages);
+	refill(memcg, nr_pages);
 
 	css_put(&memcg->css);
 }
 
+static void obj_cgroup_uncharge_pages(struct obj_cgroup *objcg,
+				      unsigned int nr_pages)
+{
+	__obj_cgroup_uncharge_pages(objcg, nr_pages, refill_stock);
+}
+
+static void obj_cgroup_uncharge_pages_locked(struct obj_cgroup *objcg,
+					     unsigned int nr_pages)
+{
+	__obj_cgroup_uncharge_pages(objcg, nr_pages, __refill_stock);
+}
+
+
 /*
  * obj_cgroup_charge_pages: charge a number of kernel pages to a objcg
  * @objcg: object cgroup to charge
@@ -3044,10 +3063,11 @@  void mod_objcg_state(struct obj_cgroup *objcg, struct pglist_data *pgdat,
 		     enum node_stat_item idx, int nr)
 {
 	struct memcg_stock_pcp *stock;
+	struct obj_cgroup *old = NULL;
 	unsigned long flags;
 	int *bytes;
 
-	local_irq_save(flags);
+	local_lock_irqsave(&memcg_stock.stock_lock, flags);
 	stock = this_cpu_ptr(&memcg_stock);
 
 	/*
@@ -3056,7 +3076,7 @@  void mod_objcg_state(struct obj_cgroup *objcg, struct pglist_data *pgdat,
 	 * changes.
 	 */
 	if (stock->cached_objcg != objcg) {
-		drain_obj_stock(stock);
+		old = drain_obj_stock(stock);
 		obj_cgroup_get(objcg);
 		stock->nr_bytes = atomic_read(&objcg->nr_charged_bytes)
 				? atomic_xchg(&objcg->nr_charged_bytes, 0) : 0;
@@ -3100,7 +3120,9 @@  void mod_objcg_state(struct obj_cgroup *objcg, struct pglist_data *pgdat,
 	if (nr)
 		mod_objcg_mlstate(objcg, pgdat, idx, nr);
 
-	local_irq_restore(flags);
+	local_unlock_irqrestore(&memcg_stock.stock_lock, flags);
+	if (old)
+		obj_cgroup_put(old);
 }
 
 static bool consume_obj_stock(struct obj_cgroup *objcg, unsigned int nr_bytes)
@@ -3109,7 +3131,7 @@  static bool consume_obj_stock(struct obj_cgroup *objcg, unsigned int nr_bytes)
 	unsigned long flags;
 	bool ret = false;
 
-	local_irq_save(flags);
+	local_lock_irqsave(&memcg_stock.stock_lock, flags);
 
 	stock = this_cpu_ptr(&memcg_stock);
 	if (objcg == stock->cached_objcg && stock->nr_bytes >= nr_bytes) {
@@ -3117,24 +3139,24 @@  static bool consume_obj_stock(struct obj_cgroup *objcg, unsigned int nr_bytes)
 		ret = true;
 	}
 
-	local_irq_restore(flags);
+	local_unlock_irqrestore(&memcg_stock.stock_lock, flags);
 
 	return ret;
 }
 
-static void drain_obj_stock(struct memcg_stock_pcp *stock)
+static struct obj_cgroup *drain_obj_stock(struct memcg_stock_pcp *stock)
 {
 	struct obj_cgroup *old = stock->cached_objcg;
 
 	if (!old)
-		return;
+		return NULL;
 
 	if (stock->nr_bytes) {
 		unsigned int nr_pages = stock->nr_bytes >> PAGE_SHIFT;
 		unsigned int nr_bytes = stock->nr_bytes & (PAGE_SIZE - 1);
 
 		if (nr_pages)
-			obj_cgroup_uncharge_pages(old, nr_pages);
+			obj_cgroup_uncharge_pages_locked(old, nr_pages);
 
 		/*
 		 * The leftover is flushed to the centralized per-memcg value.
@@ -3169,8 +3191,8 @@  static void drain_obj_stock(struct memcg_stock_pcp *stock)
 		stock->cached_pgdat = NULL;
 	}
 
-	obj_cgroup_put(old);
 	stock->cached_objcg = NULL;
+	return old;
 }
 
 static bool obj_stock_flush_required(struct memcg_stock_pcp *stock,
@@ -3191,14 +3213,15 @@  static void refill_obj_stock(struct obj_cgroup *objcg, unsigned int nr_bytes,
 			     bool allow_uncharge)
 {
 	struct memcg_stock_pcp *stock;
+	struct obj_cgroup *old = NULL;
 	unsigned long flags;
 	unsigned int nr_pages = 0;
 
-	local_irq_save(flags);
+	local_lock_irqsave(&memcg_stock.stock_lock, flags);
 
 	stock = this_cpu_ptr(&memcg_stock);
 	if (stock->cached_objcg != objcg) { /* reset if necessary */
-		drain_obj_stock(stock);
+		old = drain_obj_stock(stock);
 		obj_cgroup_get(objcg);
 		stock->cached_objcg = objcg;
 		stock->nr_bytes = atomic_read(&objcg->nr_charged_bytes)
@@ -3212,7 +3235,9 @@  static void refill_obj_stock(struct obj_cgroup *objcg, unsigned int nr_bytes,
 		stock->nr_bytes &= (PAGE_SIZE - 1);
 	}
 
-	local_irq_restore(flags);
+	local_unlock_irqrestore(&memcg_stock.stock_lock, flags);
+	if (old)
+		obj_cgroup_put(old);
 
 	if (nr_pages)
 		obj_cgroup_uncharge_pages(objcg, nr_pages);