
[01/11] xfrm: Add support for per cpu xfrm state handling.

Message ID 20241115083343.2340827-2-steffen.klassert@secunet.com (mailing list archive)
State Accepted
Commit 1ddf9916ac09313128e40d6581cef889c0b4ce84
Delegated to: Netdev Maintainers
Series [01/11] xfrm: Add support for per cpu xfrm state handling.

Checks

Context Check Description
netdev/series_format warning Pull request is its own cover letter; Target tree name not specified in the subject
netdev/tree_selection success Guessed tree name to be net-next, async
netdev/ynl success Generated files up to date; no warnings/errors; no diff in generated;
netdev/fixes_present success Fixes tag not required for -next series
netdev/header_inline success No static functions without inline keyword in header files
netdev/build_32bit success Errors and warnings before: 40 this patch: 40
netdev/build_tools success Errors and warnings before: 0 (+0) this patch: 0 (+0)
netdev/cc_maintainers warning 3 maintainers not CCed: edumazet@google.com horms@kernel.org pabeni@redhat.com
netdev/build_clang success Errors and warnings before: 71 this patch: 71
netdev/verify_signedoff success Signed-off-by tag matches author and committer
netdev/deprecated_api success None detected
netdev/check_selftest success No net selftest shell script
netdev/verify_fixes success No Fixes tag
netdev/build_allmodconfig_warn success Errors and warnings before: 4258 this patch: 4258
netdev/checkpatch warning CHECK: Unnecessary parentheses around 'best->pcpu_num == pcpu_id' WARNING: Block comments use a trailing */ on a separate line WARNING: Missing a blank line after declarations WARNING: line length of 86 exceeds 80 columns WARNING: line length of 87 exceeds 80 columns WARNING: line length of 88 exceeds 80 columns WARNING: line length of 89 exceeds 80 columns WARNING: line length of 97 exceeds 80 columns WARNING: line length of 98 exceeds 80 columns
netdev/build_clang_rust success No Rust files in patch. Skipping build
netdev/kdoc success Errors and warnings before: 0 this patch: 0
netdev/source_inline success Was 0 now: 0
netdev/contest success net-next-2024-11-15--21-00 (tests: 789)

Commit Message

Steffen Klassert Nov. 15, 2024, 8:33 a.m. UTC
Currently all flows for a certain SA must be processed by the same
cpu to avoid packet reordering and lock contention of the xfrm
state lock.

To get rid of this limitation, the IETF standardized per cpu SAs
in RFC 9611. This patch implements the xfrm part of it.

We add the cpu as a lookup key for xfrm states and a config option
to generate acquire messages for each cpu.
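
The lookup preference described here can be modeled in user space. The following is a minimal sketch, not the kernel code: `struct sa_model` and `pick_sa()` are hypothetical names, and the dying/add_time tie-breakers of xfrm_state_look_at() are left out. It assumes, as in the patch, that a pcpu_num of UINT_MAX marks an SA that is not bound to any cpu.

```c
#include <limits.h>
#include <stddef.h>

/* Hypothetical, stripped-down stand-in for struct xfrm_state. */
struct sa_model {
	unsigned int pcpu_num;	/* UINT_MAX: not bound to a cpu (fallback) */
};

/*
 * Return the SA to use on 'cpu', or NULL if none qualifies.
 * An SA bound to a different cpu is skipped; an SA bound to
 * 'cpu' replaces a previously chosen fallback SA.
 */
static const struct sa_model *pick_sa(const struct sa_model *sas, size_t n,
				      unsigned int cpu)
{
	const struct sa_model *best = NULL;
	size_t i;

	for (i = 0; i < n; i++) {
		const struct sa_model *x = &sas[i];

		if (x->pcpu_num != UINT_MAX && x->pcpu_num != cpu)
			continue;

		if (!best ||
		    (best->pcpu_num == UINT_MAX && x->pcpu_num == cpu))
			best = x;
	}
	return best;
}
```

With one fallback SA and SAs bound to cpus 1 and 2, a lookup on cpu 1 returns the SA bound to cpu 1, while a lookup on cpu 3 returns the fallback.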

With that, we can have on each cpu a SA with identical traffic selector
so that flows can be processed in parallel on all cpus.
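
How the per cpu SAs come into existence can be sketched as a small decision function. This is a simplified user-space model of the logic in xfrm_state_find(), under the assumption that no error is pending and no acquire is already in progress. `enum sa_action`, `decide()` and the `MODEL_` names are hypothetical and exist only for illustration.

```c
#include <limits.h>
#include <stdbool.h>

#define MODEL_POLICY_CPU_ACQUIRE 4	/* mirrors XFRM_POLICY_CPU_ACQUIRE */
#define MODEL_NO_CPU UINT_MAX		/* SA not bound to a cpu */

enum sa_action {
	SA_USE_BEST,			/* use the SA the lookup found */
	SA_ACQUIRE,			/* no SA at all: plain acquire */
	SA_ACQUIRE_CPU_USE_FALLBACK,	/* acquire an SA for this cpu and
					 * keep using the fallback SA */
};

/*
 * 'have_best' says whether the lookup found a usable SA,
 * 'best_pcpu' is its pcpu_num (ignored when have_best is false).
 */
static enum sa_action decide(unsigned int pol_flags, bool have_best,
			     unsigned int best_pcpu, unsigned int cpu)
{
	if (!have_best)
		return SA_ACQUIRE;

	if (!(pol_flags & MODEL_POLICY_CPU_ACQUIRE) || best_pcpu == cpu)
		return SA_USE_BEST;

	return SA_ACQUIRE_CPU_USE_FALLBACK;
}
```

The third case is the interesting one: the policy asks for per cpu acquires, only a fallback SA exists, so an acquire carrying the current cpu id is generated while traffic continues on the fallback.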

Signed-off-by: Steffen Klassert <steffen.klassert@secunet.com>
Tested-by: Antony Antony <antony.antony@secunet.com>
Tested-by: Tobias Brunner <tobias@strongswan.org>
---
 include/net/xfrm.h        |  5 ++--
 include/uapi/linux/xfrm.h |  2 ++
 net/key/af_key.c          |  7 +++--
 net/xfrm/xfrm_compat.c    |  6 ++--
 net/xfrm/xfrm_state.c     | 58 +++++++++++++++++++++++++++++++--------
 net/xfrm/xfrm_user.c      | 56 ++++++++++++++++++++++++++++++++++---
 6 files changed, 112 insertions(+), 22 deletions(-)

Comments

Jakub Kicinski Nov. 16, 2024, 2:09 a.m. UTC | #1
On Fri, 15 Nov 2024 09:33:33 +0100 Steffen Klassert wrote:
> +	/* We need the cpu id just as a lookup key,
> +	 * we don't require it to be stable.
> +	 */
> +	pcpu_id = get_cpu();
> +	put_cpu();

Why not smp_processor_id() ?

> +	if (attrs[XFRMA_SA_PCPU]) {
> +		x->pcpu_num = nla_get_u32(attrs[XFRMA_SA_PCPU]);
> +		if (x->pcpu_num >= num_possible_cpus())
> +			goto error;
> +	}

cpu ids can be sparse, shouldn't it be checking if the CPU is online ?

Steffen Klassert Nov. 18, 2024, 6:23 a.m. UTC | #2
On Fri, Nov 15, 2024 at 06:09:08PM -0800, Jakub Kicinski wrote:
> On Fri, 15 Nov 2024 09:33:33 +0100 Steffen Klassert wrote:
> > +	/* We need the cpu id just as a lookup key,
> > +	 * we don't require it to be stable.
> > +	 */
> > +	pcpu_id = get_cpu();
> > +	put_cpu();
> 
> Why not smp_processor_id() ?

This might be executed in preemptible code sections, and
smp_processor_id() will throw a warning if that happens.
Maybe raw_smp_processor_id() can be used here instead,
but I was not sure that's the right thing.

> 
> > +	if (attrs[XFRMA_SA_PCPU]) {
> > +		x->pcpu_num = nla_get_u32(attrs[XFRMA_SA_PCPU]);
> > +		if (x->pcpu_num >= num_possible_cpus())
> > +			goto error;
> > +	}
> 
> cpu ids can be sparse, shouldn't it be checking if the CPU is online ?

I thought about that. But then we must wait for an IKE negotiation
before we can use a freshly booted cpu. If we pre-negotiate an SA,
the cpu can be used right away. It depends a bit on the use case what's
preferred here.
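
The sparse id concern can be illustrated with a small user-space model. It is only a sketch: the possible-cpu set is modeled as a 64-bit mask, and `model_num_possible()`, `valid_by_count()` and `valid_by_mask()` are hypothetical helpers. The assumption is that num_possible_cpus() reports the number of bits set in the possible mask, so comparing an id against that count is only exact when the ids are contiguous. In the kernel, a membership test would be cpu_possible().

```c
#include <stdbool.h>
#include <stdint.h>

/* Number of set bits, i.e. how many cpus the mask marks as possible. */
static unsigned int model_num_possible(uint64_t possible)
{
	unsigned int n = 0;

	for (; possible; possible &= possible - 1)
		n++;
	return n;
}

/* The check used in the patch: id must be below the cpu count. */
static bool valid_by_count(uint64_t possible, unsigned int id)
{
	return id < model_num_possible(possible);
}

/* Alternative: id must be a member of the possible mask. */
static bool valid_by_mask(uint64_t possible, unsigned int id)
{
	return id < 64 && ((possible >> id) & 1);
}
```

With possible cpus 0, 2, 4 and 6 (mask 0x55), the count-based check rejects id 6 although it is possible, and accepts id 1 although it is not. Neither variant requires the cpu to be online, so a pre-negotiated SA for a not yet booted cpu would still be accepted.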
patchwork-bot+netdevbpf@kernel.org Nov. 18, 2024, noon UTC | #3
Hello:

This series was applied to netdev/net-next.git (main)
by Steffen Klassert <steffen.klassert@secunet.com>:

On Fri, 15 Nov 2024 09:33:33 +0100 you wrote:
> Currently all flows for a certain SA must be processed by the same
> cpu to avoid packet reordering and lock contention of the xfrm
> state lock.
> 
> To get rid of this limitation, the IETF standardized per cpu SAs
> in RFC 9611. This patch implements the xfrm part of it.
> 
> [...]

Here is the summary with links:
  - [01/11] xfrm: Add support for per cpu xfrm state handling.
    https://git.kernel.org/netdev/net-next/c/1ddf9916ac09
  - [02/11] xfrm: Cache used outbound xfrm states at the policy.
    https://git.kernel.org/netdev/net-next/c/0045e3d80613
  - [03/11] xfrm: Add an inbound percpu state cache.
    https://git.kernel.org/netdev/net-next/c/81a331a0e72d
  - [04/11] xfrm: Restrict percpu SA attribute to specific netlink message types
    https://git.kernel.org/netdev/net-next/c/83dfce38c49f
  - [05/11] xfrm: Convert xfrm_get_tos() to dscp_t.
    https://git.kernel.org/netdev/net-next/c/766f532089af
  - [06/11] xfrm: Convert xfrm_bundle_create() to dscp_t.
    https://git.kernel.org/netdev/net-next/c/01f61cbfc8b2
  - [07/11] xfrm: Convert xfrm_dst_lookup() to dscp_t.
    https://git.kernel.org/netdev/net-next/c/3021a2a3403d
  - [08/11] xfrm: Convert struct xfrm_dst_lookup_params -> tos to dscp_t.
    https://git.kernel.org/netdev/net-next/c/e57dfaa4b0a7
  - [09/11] xfrm: Add error handling when nla_put_u32() returns an error
    https://git.kernel.org/netdev/net-next/c/9d287e70c51f
  - [10/11] xfrm: replace deprecated strncpy with strscpy_pad
    https://git.kernel.org/netdev/net-next/c/9e1a6db68e3c
  - [11/11] xfrm: Fix acquire state insertion.
    https://git.kernel.org/netdev/net-next/c/a35672819f8d

You are awesome, thank you!

Patch

diff --git a/include/net/xfrm.h b/include/net/xfrm.h
index a0bdd58f401c..f5275618e744 100644
--- a/include/net/xfrm.h
+++ b/include/net/xfrm.h
@@ -188,6 +188,7 @@  struct xfrm_state {
 	refcount_t		refcnt;
 	spinlock_t		lock;
 
+	u32			pcpu_num;
 	struct xfrm_id		id;
 	struct xfrm_selector	sel;
 	struct xfrm_mark	mark;
@@ -1684,7 +1685,7 @@  struct xfrmk_spdinfo {
 	u32 spdhmcnt;
 };
 
-struct xfrm_state *xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq);
+struct xfrm_state *xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq, u32 pcpu_num);
 int xfrm_state_delete(struct xfrm_state *x);
 int xfrm_state_flush(struct net *net, u8 proto, bool task_valid, bool sync);
 int xfrm_dev_state_flush(struct net *net, struct net_device *dev, bool task_valid);
@@ -1796,7 +1797,7 @@  int verify_spi_info(u8 proto, u32 min, u32 max, struct netlink_ext_ack *extack);
 int xfrm_alloc_spi(struct xfrm_state *x, u32 minspi, u32 maxspi,
 		   struct netlink_ext_ack *extack);
 struct xfrm_state *xfrm_find_acq(struct net *net, const struct xfrm_mark *mark,
-				 u8 mode, u32 reqid, u32 if_id, u8 proto,
+				 u8 mode, u32 reqid, u32 if_id, u32 pcpu_num, u8 proto,
 				 const xfrm_address_t *daddr,
 				 const xfrm_address_t *saddr, int create,
 				 unsigned short family);
diff --git a/include/uapi/linux/xfrm.h b/include/uapi/linux/xfrm.h
index f28701500714..d73a97e3030a 100644
--- a/include/uapi/linux/xfrm.h
+++ b/include/uapi/linux/xfrm.h
@@ -322,6 +322,7 @@  enum xfrm_attr_type_t {
 	XFRMA_MTIMER_THRESH,	/* __u32 in seconds for input SA */
 	XFRMA_SA_DIR,		/* __u8 */
 	XFRMA_NAT_KEEPALIVE_INTERVAL,	/* __u32 in seconds for NAT keepalive */
+	XFRMA_SA_PCPU,		/* __u32 */
 	__XFRMA_MAX
 
 #define XFRMA_OUTPUT_MARK XFRMA_SET_MARK	/* Compatibility */
@@ -437,6 +438,7 @@  struct xfrm_userpolicy_info {
 #define XFRM_POLICY_LOCALOK	1	/* Allow user to override global policy */
 	/* Automatically expand selector to include matching ICMP payloads. */
 #define XFRM_POLICY_ICMP	2
+#define XFRM_POLICY_CPU_ACQUIRE	4
 	__u8				share;
 };
 
diff --git a/net/key/af_key.c b/net/key/af_key.c
index f79fb99271ed..c56bb4f451e6 100644
--- a/net/key/af_key.c
+++ b/net/key/af_key.c
@@ -1354,7 +1354,7 @@  static int pfkey_getspi(struct sock *sk, struct sk_buff *skb, const struct sadb_
 	}
 
 	if (hdr->sadb_msg_seq) {
-		x = xfrm_find_acq_byseq(net, DUMMY_MARK, hdr->sadb_msg_seq);
+		x = xfrm_find_acq_byseq(net, DUMMY_MARK, hdr->sadb_msg_seq, UINT_MAX);
 		if (x && !xfrm_addr_equal(&x->id.daddr, xdaddr, family)) {
 			xfrm_state_put(x);
 			x = NULL;
@@ -1362,7 +1362,8 @@  static int pfkey_getspi(struct sock *sk, struct sk_buff *skb, const struct sadb_
 	}
 
 	if (!x)
-		x = xfrm_find_acq(net, &dummy_mark, mode, reqid, 0, proto, xdaddr, xsaddr, 1, family);
+		x = xfrm_find_acq(net, &dummy_mark, mode, reqid, 0, UINT_MAX,
+				  proto, xdaddr, xsaddr, 1, family);
 
 	if (x == NULL)
 		return -ENOENT;
@@ -1417,7 +1418,7 @@  static int pfkey_acquire(struct sock *sk, struct sk_buff *skb, const struct sadb
 	if (hdr->sadb_msg_seq == 0 || hdr->sadb_msg_errno == 0)
 		return 0;
 
-	x = xfrm_find_acq_byseq(net, DUMMY_MARK, hdr->sadb_msg_seq);
+	x = xfrm_find_acq_byseq(net, DUMMY_MARK, hdr->sadb_msg_seq, UINT_MAX);
 	if (x == NULL)
 		return 0;
 
diff --git a/net/xfrm/xfrm_compat.c b/net/xfrm/xfrm_compat.c
index 91357ccaf4af..5b9ee63e30b6 100644
--- a/net/xfrm/xfrm_compat.c
+++ b/net/xfrm/xfrm_compat.c
@@ -132,6 +132,7 @@  static const struct nla_policy compat_policy[XFRMA_MAX+1] = {
 	[XFRMA_MTIMER_THRESH]	= { .type = NLA_U32 },
 	[XFRMA_SA_DIR]          = NLA_POLICY_RANGE(NLA_U8, XFRM_SA_DIR_IN, XFRM_SA_DIR_OUT),
 	[XFRMA_NAT_KEEPALIVE_INTERVAL]	= { .type = NLA_U32 },
+	[XFRMA_SA_PCPU]		= { .type = NLA_U32 },
 };
 
 static struct nlmsghdr *xfrm_nlmsg_put_compat(struct sk_buff *skb,
@@ -282,9 +283,10 @@  static int xfrm_xlate64_attr(struct sk_buff *dst, const struct nlattr *src)
 	case XFRMA_MTIMER_THRESH:
 	case XFRMA_SA_DIR:
 	case XFRMA_NAT_KEEPALIVE_INTERVAL:
+	case XFRMA_SA_PCPU:
 		return xfrm_nla_cpy(dst, src, nla_len(src));
 	default:
-		BUILD_BUG_ON(XFRMA_MAX != XFRMA_NAT_KEEPALIVE_INTERVAL);
+		BUILD_BUG_ON(XFRMA_MAX != XFRMA_SA_PCPU);
 		pr_warn_once("unsupported nla_type %d\n", src->nla_type);
 		return -EOPNOTSUPP;
 	}
@@ -439,7 +441,7 @@  static int xfrm_xlate32_attr(void *dst, const struct nlattr *nla,
 	int err;
 
 	if (type > XFRMA_MAX) {
-		BUILD_BUG_ON(XFRMA_MAX != XFRMA_NAT_KEEPALIVE_INTERVAL);
+		BUILD_BUG_ON(XFRMA_MAX != XFRMA_SA_PCPU);
 		NL_SET_ERR_MSG(extack, "Bad attribute");
 		return -EOPNOTSUPP;
 	}
diff --git a/net/xfrm/xfrm_state.c b/net/xfrm/xfrm_state.c
index 37478d36a8df..ebef07b80afa 100644
--- a/net/xfrm/xfrm_state.c
+++ b/net/xfrm/xfrm_state.c
@@ -679,6 +679,7 @@  struct xfrm_state *xfrm_state_alloc(struct net *net)
 		x->lft.hard_packet_limit = XFRM_INF;
 		x->replay_maxage = 0;
 		x->replay_maxdiff = 0;
+		x->pcpu_num = UINT_MAX;
 		spin_lock_init(&x->lock);
 	}
 	return x;
@@ -1155,6 +1156,12 @@  static void xfrm_state_look_at(struct xfrm_policy *pol, struct xfrm_state *x,
 			       struct xfrm_state **best, int *acq_in_progress,
 			       int *error)
 {
+	/* We need the cpu id just as a lookup key,
+	 * we don't require it to be stable.
+	 */
+	unsigned int pcpu_id = get_cpu();
+	put_cpu();
+
 	/* Resolution logic:
 	 * 1. There is a valid state with matching selector. Done.
 	 * 2. Valid state with inappropriate selector. Skip.
@@ -1174,13 +1181,18 @@  static void xfrm_state_look_at(struct xfrm_policy *pol, struct xfrm_state *x,
 							&fl->u.__fl_common))
 			return;
 
+		if (x->pcpu_num != UINT_MAX && x->pcpu_num != pcpu_id)
+			return;
+
 		if (!*best ||
+		    ((*best)->pcpu_num == UINT_MAX && x->pcpu_num == pcpu_id) ||
 		    (*best)->km.dying > x->km.dying ||
 		    ((*best)->km.dying == x->km.dying &&
 		     (*best)->curlft.add_time < x->curlft.add_time))
 			*best = x;
 	} else if (x->km.state == XFRM_STATE_ACQ) {
-		*acq_in_progress = 1;
+		if (!*best || x->pcpu_num == pcpu_id)
+			*acq_in_progress = 1;
 	} else if (x->km.state == XFRM_STATE_ERROR ||
 		   x->km.state == XFRM_STATE_EXPIRED) {
 		if ((!x->sel.family ||
@@ -1209,6 +1221,13 @@  xfrm_state_find(const xfrm_address_t *daddr, const xfrm_address_t *saddr,
 	unsigned short encap_family = tmpl->encap_family;
 	unsigned int sequence;
 	struct km_event c;
+	unsigned int pcpu_id;
+
+	/* We need the cpu id just as a lookup key,
+	 * we don't require it to be stable.
+	 */
+	pcpu_id = get_cpu();
+	put_cpu();
 
 	to_put = NULL;
 
@@ -1282,7 +1301,10 @@  xfrm_state_find(const xfrm_address_t *daddr, const xfrm_address_t *saddr,
 	}
 
 found:
-	x = best;
+	if (!(pol->flags & XFRM_POLICY_CPU_ACQUIRE) ||
+	    (best && (best->pcpu_num == pcpu_id)))
+		x = best;
+
 	if (!x && !error && !acquire_in_progress) {
 		if (tmpl->id.spi &&
 		    (x0 = __xfrm_state_lookup_all(net, mark, daddr,
@@ -1314,6 +1336,8 @@  xfrm_state_find(const xfrm_address_t *daddr, const xfrm_address_t *saddr,
 		xfrm_init_tempstate(x, fl, tmpl, daddr, saddr, family);
 		memcpy(&x->mark, &pol->mark, sizeof(x->mark));
 		x->if_id = if_id;
+		if ((pol->flags & XFRM_POLICY_CPU_ACQUIRE) && best)
+			x->pcpu_num = pcpu_id;
 
 		error = security_xfrm_state_alloc_acquire(x, pol->security, fl->flowi_secid);
 		if (error) {
@@ -1392,6 +1416,11 @@  xfrm_state_find(const xfrm_address_t *daddr, const xfrm_address_t *saddr,
 			x = NULL;
 			error = -ESRCH;
 		}
+
+		/* Use the already installed 'fallback' while the CPU-specific
+		 * SA acquire is handled*/
+		if (best)
+			x = best;
 	}
 out:
 	if (x) {
@@ -1524,12 +1553,14 @@  static void __xfrm_state_bump_genids(struct xfrm_state *xnew)
 	unsigned int h;
 	u32 mark = xnew->mark.v & xnew->mark.m;
 	u32 if_id = xnew->if_id;
+	u32 cpu_id = xnew->pcpu_num;
 
 	h = xfrm_dst_hash(net, &xnew->id.daddr, &xnew->props.saddr, reqid, family);
 	hlist_for_each_entry(x, net->xfrm.state_bydst+h, bydst) {
 		if (x->props.family	== family &&
 		    x->props.reqid	== reqid &&
 		    x->if_id		== if_id &&
+		    x->pcpu_num		== cpu_id &&
 		    (mark & x->mark.m) == x->mark.v &&
 		    xfrm_addr_equal(&x->id.daddr, &xnew->id.daddr, family) &&
 		    xfrm_addr_equal(&x->props.saddr, &xnew->props.saddr, family))
@@ -1552,7 +1583,7 @@  EXPORT_SYMBOL(xfrm_state_insert);
 static struct xfrm_state *__find_acq_core(struct net *net,
 					  const struct xfrm_mark *m,
 					  unsigned short family, u8 mode,
-					  u32 reqid, u32 if_id, u8 proto,
+					  u32 reqid, u32 if_id, u32 pcpu_num, u8 proto,
 					  const xfrm_address_t *daddr,
 					  const xfrm_address_t *saddr,
 					  int create)
@@ -1569,6 +1600,7 @@  static struct xfrm_state *__find_acq_core(struct net *net,
 		    x->id.spi       != 0 ||
 		    x->id.proto	    != proto ||
 		    (mark & x->mark.m) != x->mark.v ||
+		    x->pcpu_num != pcpu_num ||
 		    !xfrm_addr_equal(&x->id.daddr, daddr, family) ||
 		    !xfrm_addr_equal(&x->props.saddr, saddr, family))
 			continue;
@@ -1602,6 +1634,7 @@  static struct xfrm_state *__find_acq_core(struct net *net,
 			break;
 		}
 
+		x->pcpu_num = pcpu_num;
 		x->km.state = XFRM_STATE_ACQ;
 		x->id.proto = proto;
 		x->props.family = family;
@@ -1630,7 +1663,7 @@  static struct xfrm_state *__find_acq_core(struct net *net,
 	return x;
 }
 
-static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq);
+static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq, u32 pcpu_num);
 
 int xfrm_state_add(struct xfrm_state *x)
 {
@@ -1656,7 +1689,7 @@  int xfrm_state_add(struct xfrm_state *x)
 	}
 
 	if (use_spi && x->km.seq) {
-		x1 = __xfrm_find_acq_byseq(net, mark, x->km.seq);
+		x1 = __xfrm_find_acq_byseq(net, mark, x->km.seq, x->pcpu_num);
 		if (x1 && ((x1->id.proto != x->id.proto) ||
 		    !xfrm_addr_equal(&x1->id.daddr, &x->id.daddr, family))) {
 			to_put = x1;
@@ -1666,7 +1699,7 @@  int xfrm_state_add(struct xfrm_state *x)
 
 	if (use_spi && !x1)
 		x1 = __find_acq_core(net, &x->mark, family, x->props.mode,
-				     x->props.reqid, x->if_id, x->id.proto,
+				     x->props.reqid, x->if_id, x->pcpu_num, x->id.proto,
 				     &x->id.daddr, &x->props.saddr, 0);
 
 	__xfrm_state_bump_genids(x);
@@ -1791,6 +1824,7 @@  static struct xfrm_state *xfrm_state_clone(struct xfrm_state *orig,
 	x->props.flags = orig->props.flags;
 	x->props.extra_flags = orig->props.extra_flags;
 
+	x->pcpu_num = orig->pcpu_num;
 	x->if_id = orig->if_id;
 	x->tfcpad = orig->tfcpad;
 	x->replay_maxdiff = orig->replay_maxdiff;
@@ -2066,13 +2100,14 @@  EXPORT_SYMBOL(xfrm_state_lookup_byaddr);
 
 struct xfrm_state *
 xfrm_find_acq(struct net *net, const struct xfrm_mark *mark, u8 mode, u32 reqid,
-	      u32 if_id, u8 proto, const xfrm_address_t *daddr,
+	      u32 if_id, u32 pcpu_num, u8 proto, const xfrm_address_t *daddr,
 	      const xfrm_address_t *saddr, int create, unsigned short family)
 {
 	struct xfrm_state *x;
 
 	spin_lock_bh(&net->xfrm.xfrm_state_lock);
-	x = __find_acq_core(net, mark, family, mode, reqid, if_id, proto, daddr, saddr, create);
+	x = __find_acq_core(net, mark, family, mode, reqid, if_id, pcpu_num,
+			    proto, daddr, saddr, create);
 	spin_unlock_bh(&net->xfrm.xfrm_state_lock);
 
 	return x;
@@ -2207,7 +2242,7 @@  xfrm_state_sort(struct xfrm_state **dst, struct xfrm_state **src, int n,
 
 /* Silly enough, but I'm lazy to build resolution list */
 
-static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq)
+static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq, u32 pcpu_num)
 {
 	unsigned int h = xfrm_seq_hash(net, seq);
 	struct xfrm_state *x;
@@ -2215,6 +2250,7 @@  static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 s
 	hlist_for_each_entry_rcu(x, net->xfrm.state_byseq + h, byseq) {
 		if (x->km.seq == seq &&
 		    (mark & x->mark.m) == x->mark.v &&
+		    x->pcpu_num == pcpu_num &&
 		    x->km.state == XFRM_STATE_ACQ) {
 			xfrm_state_hold(x);
 			return x;
@@ -2224,12 +2260,12 @@  static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 s
 	return NULL;
 }
 
-struct xfrm_state *xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq)
+struct xfrm_state *xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq, u32 pcpu_num)
 {
 	struct xfrm_state *x;
 
 	spin_lock_bh(&net->xfrm.xfrm_state_lock);
-	x = __xfrm_find_acq_byseq(net, mark, seq);
+	x = __xfrm_find_acq_byseq(net, mark, seq, pcpu_num);
 	spin_unlock_bh(&net->xfrm.xfrm_state_lock);
 	return x;
 }
diff --git a/net/xfrm/xfrm_user.c b/net/xfrm/xfrm_user.c
index e3b8ce89831a..e4d448950d05 100644
--- a/net/xfrm/xfrm_user.c
+++ b/net/xfrm/xfrm_user.c
@@ -460,6 +460,12 @@  static int verify_newsa_info(struct xfrm_usersa_info *p,
 		}
 	}
 
+	if (!sa_dir && attrs[XFRMA_SA_PCPU]) {
+		NL_SET_ERR_MSG(extack, "SA_PCPU only supported with SA_DIR");
+		err = -EINVAL;
+		goto out;
+	}
+
 out:
 	return err;
 }
@@ -841,6 +847,12 @@  static struct xfrm_state *xfrm_state_construct(struct net *net,
 		x->nat_keepalive_interval =
 			nla_get_u32(attrs[XFRMA_NAT_KEEPALIVE_INTERVAL]);
 
+	if (attrs[XFRMA_SA_PCPU]) {
+		x->pcpu_num = nla_get_u32(attrs[XFRMA_SA_PCPU]);
+		if (x->pcpu_num >= num_possible_cpus())
+			goto error;
+	}
+
 	err = __xfrm_init_state(x, false, attrs[XFRMA_OFFLOAD_DEV], extack);
 	if (err)
 		goto error;
@@ -1296,6 +1308,11 @@  static int copy_to_user_state_extra(struct xfrm_state *x,
 		if (ret)
 			goto out;
 	}
+	if (x->pcpu_num != UINT_MAX) {
+		ret = nla_put_u32(skb, XFRMA_SA_PCPU, x->pcpu_num);
+		if (ret)
+			goto out;
+	}
 	if (x->dir)
 		ret = nla_put_u8(skb, XFRMA_SA_DIR, x->dir);
 
@@ -1700,6 +1717,7 @@  static int xfrm_alloc_userspi(struct sk_buff *skb, struct nlmsghdr *nlh,
 	u32 mark;
 	struct xfrm_mark m;
 	u32 if_id = 0;
+	u32 pcpu_num = UINT_MAX;
 
 	p = nlmsg_data(nlh);
 	err = verify_spi_info(p->info.id.proto, p->min, p->max, extack);
@@ -1716,8 +1734,16 @@  static int xfrm_alloc_userspi(struct sk_buff *skb, struct nlmsghdr *nlh,
 	if (attrs[XFRMA_IF_ID])
 		if_id = nla_get_u32(attrs[XFRMA_IF_ID]);
 
+	if (attrs[XFRMA_SA_PCPU]) {
+		pcpu_num = nla_get_u32(attrs[XFRMA_SA_PCPU]);
+		if (pcpu_num >= num_possible_cpus()) {
+			err = -EINVAL;
+			goto out_noput;
+		}
+	}
+
 	if (p->info.seq) {
-		x = xfrm_find_acq_byseq(net, mark, p->info.seq);
+		x = xfrm_find_acq_byseq(net, mark, p->info.seq, pcpu_num);
 		if (x && !xfrm_addr_equal(&x->id.daddr, daddr, family)) {
 			xfrm_state_put(x);
 			x = NULL;
@@ -1726,7 +1752,7 @@  static int xfrm_alloc_userspi(struct sk_buff *skb, struct nlmsghdr *nlh,
 
 	if (!x)
 		x = xfrm_find_acq(net, &m, p->info.mode, p->info.reqid,
-				  if_id, p->info.id.proto, daddr,
+				  if_id, pcpu_num, p->info.id.proto, daddr,
 				  &p->info.saddr, 1,
 				  family);
 	err = -ENOENT;
@@ -2526,7 +2552,8 @@  static inline unsigned int xfrm_aevent_msgsize(struct xfrm_state *x)
 	       + nla_total_size(sizeof(struct xfrm_mark))
 	       + nla_total_size(4) /* XFRM_AE_RTHR */
 	       + nla_total_size(4) /* XFRM_AE_ETHR */
-	       + nla_total_size(sizeof(x->dir)); /* XFRMA_SA_DIR */
+	       + nla_total_size(sizeof(x->dir)) /* XFRMA_SA_DIR */
+	       + nla_total_size(4); /* XFRMA_SA_PCPU */
 }
 
 static int build_aevent(struct sk_buff *skb, struct xfrm_state *x, const struct km_event *c)
@@ -2582,6 +2609,8 @@  static int build_aevent(struct sk_buff *skb, struct xfrm_state *x, const struct
 	err = xfrm_if_id_put(skb, x->if_id);
 	if (err)
 		goto out_cancel;
+	if (x->pcpu_num != UINT_MAX)
+		err = nla_put_u32(skb, XFRMA_SA_PCPU, x->pcpu_num);
 
 	if (x->dir) {
 		err = nla_put_u8(skb, XFRMA_SA_DIR, x->dir);
@@ -2852,6 +2881,13 @@  static int xfrm_add_acquire(struct sk_buff *skb, struct nlmsghdr *nlh,
 
 	xfrm_mark_get(attrs, &mark);
 
+	if (attrs[XFRMA_SA_PCPU]) {
+		x->pcpu_num = nla_get_u32(attrs[XFRMA_SA_PCPU]);
+		err = -EINVAL;
+		if (x->pcpu_num >= num_possible_cpus())
+			goto free_state;
+	}
+
 	err = verify_newpolicy_info(&ua->policy, extack);
 	if (err)
 		goto free_state;
@@ -3182,6 +3218,7 @@  const struct nla_policy xfrma_policy[XFRMA_MAX+1] = {
 	[XFRMA_MTIMER_THRESH]   = { .type = NLA_U32 },
 	[XFRMA_SA_DIR]          = NLA_POLICY_RANGE(NLA_U8, XFRM_SA_DIR_IN, XFRM_SA_DIR_OUT),
 	[XFRMA_NAT_KEEPALIVE_INTERVAL] = { .type = NLA_U32 },
+	[XFRMA_SA_PCPU]		= { .type = NLA_U32 },
 };
 EXPORT_SYMBOL_GPL(xfrma_policy);
 
@@ -3348,7 +3385,8 @@  static inline unsigned int xfrm_expire_msgsize(void)
 {
 	return NLMSG_ALIGN(sizeof(struct xfrm_user_expire)) +
 	       nla_total_size(sizeof(struct xfrm_mark)) +
-	       nla_total_size(sizeof_field(struct xfrm_state, dir));
+	       nla_total_size(sizeof_field(struct xfrm_state, dir)) +
+	       nla_total_size(4); /* XFRMA_SA_PCPU */
 }
 
 static int build_expire(struct sk_buff *skb, struct xfrm_state *x, const struct km_event *c)
@@ -3374,6 +3412,11 @@  static int build_expire(struct sk_buff *skb, struct xfrm_state *x, const struct
 	err = xfrm_if_id_put(skb, x->if_id);
 	if (err)
 		return err;
+	if (x->pcpu_num != UINT_MAX) {
+		err = nla_put_u32(skb, XFRMA_SA_PCPU, x->pcpu_num);
+		if (err)
+			return err;
+	}
 
 	if (x->dir) {
 		err = nla_put_u8(skb, XFRMA_SA_DIR, x->dir);
@@ -3481,6 +3524,8 @@  static inline unsigned int xfrm_sa_len(struct xfrm_state *x)
 	}
 	if (x->if_id)
 		l += nla_total_size(sizeof(x->if_id));
+	if (x->pcpu_num)
+		l += nla_total_size(sizeof(x->pcpu_num));
 
 	/* Must count x->lastused as it may become non-zero behind our back. */
 	l += nla_total_size_64bit(sizeof(u64));
@@ -3587,6 +3632,7 @@  static inline unsigned int xfrm_acquire_msgsize(struct xfrm_state *x,
 	       + nla_total_size(sizeof(struct xfrm_user_tmpl) * xp->xfrm_nr)
 	       + nla_total_size(sizeof(struct xfrm_mark))
 	       + nla_total_size(xfrm_user_sec_ctx_size(x->security))
+	       + nla_total_size(4) /* XFRMA_SA_PCPU */
 	       + userpolicy_type_attrsize();
 }
 
@@ -3623,6 +3669,8 @@  static int build_acquire(struct sk_buff *skb, struct xfrm_state *x,
 		err = xfrm_if_id_put(skb, xp->if_id);
 	if (!err && xp->xdo.dev)
 		err = copy_user_offload(&xp->xdo, skb);
+	if (!err && x->pcpu_num != UINT_MAX)
+		err = nla_put_u32(skb, XFRMA_SA_PCPU, x->pcpu_num);
 	if (err) {
 		nlmsg_cancel(skb, nlh);
 		return err;