
[v3,net-next,2/2] udp_tunnel: use static call for GRO hooks when possible

Message ID b65c13770225f4a655657373f5ad90bcef3f57c9.1741632298.git.pabeni@redhat.com (mailing list archive)
State: Superseded
Delegated to: Netdev Maintainers
Series: udp_tunnel: GRO optimizations

Commit Message

Paolo Abeni March 10, 2025, 7:09 p.m. UTC
It's quite common to have a single UDP tunnel type active in the
whole system. In such a case we can replace the indirect call for
the UDP tunnel GRO callback with a static call.

Add the related accounting in the control path and switch to the static
call when possible. To keep the code simple, use a static array for
the registered tunnel types, and size the array based on the kernel
config.

Signed-off-by: Paolo Abeni <pabeni@redhat.com>
---
v2 -> v3:
 - avoid unneeded checks in udp_tunnel_update_gro_rcv()

v1 -> v2:
 - fix UDP_TUNNEL=n build
---
 include/net/udp_tunnel.h   |   4 ++
 net/ipv4/udp_offload.c     | 137 ++++++++++++++++++++++++++++++++++++-
 net/ipv4/udp_tunnel_core.c |   2 +
 3 files changed, 142 insertions(+), 1 deletion(-)
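
For reference, the core of the new receive-side dispatch, condensed from the
diff below (the !CONFIG_NET_UDP_TUNNEL fallback and the control-path
bookkeeping are omitted here):

	static struct sk_buff *udp_tunnel_gro_rcv(struct sock *sk,
						  struct list_head *head,
						  struct sk_buff *skb)
	{
		/* exactly one tunnel type registered: use the static call */
		if (static_branch_likely(&udp_tunnel_static_call)) {
			if (unlikely(gro_recursion_inc_test(skb))) {
				NAPI_GRO_CB(skb)->flush |= 1;
				return NULL;
			}
			return static_call(udp_tunnel_gro_rcv)(sk, head, skb);
		}
		/* zero or multiple tunnel types: keep the indirect call */
		return call_gro_receive_sk(udp_sk(sk)->gro_receive, sk, head, skb);
	}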

Comments

Willem de Bruijn March 11, 2025, 2:51 a.m. UTC | #1
Paolo Abeni wrote:
> It's quite common to have a single UDP tunnel type active in the
> whole system. In such a case we can replace the indirect call for
> the UDP tunnel GRO callback with a static call.
> 
> Add the related accounting in the control path and switch to the static
> call when possible. To keep the code simple, use a static array for
> the registered tunnel types, and size the array based on the kernel
> config.
> 
> Signed-off-by: Paolo Abeni <pabeni@redhat.com>
> ---
> v2 -> v3:
>  - avoid unneeded checks in udp_tunnel_update_gro_rcv()
> 
> v1 -> v2:
>  - fix UDP_TUNNEL=n build
> ---
>  include/net/udp_tunnel.h   |   4 ++
>  net/ipv4/udp_offload.c     | 137 ++++++++++++++++++++++++++++++++++++-
>  net/ipv4/udp_tunnel_core.c |   2 +
>  3 files changed, 142 insertions(+), 1 deletion(-)
> 
> diff --git a/include/net/udp_tunnel.h b/include/net/udp_tunnel.h
> index eda0f3e2f65fa..a7b230867eb14 100644
> --- a/include/net/udp_tunnel.h
> +++ b/include/net/udp_tunnel.h
> @@ -205,9 +205,11 @@ static inline void udp_tunnel_encap_enable(struct sock *sk)
>  
>  #if IS_ENABLED(CONFIG_NET_UDP_TUNNEL)
>  void udp_tunnel_update_gro_lookup(struct net *net, struct sock *sk, bool add);
> +void udp_tunnel_update_gro_rcv(struct sock *sk, bool add);
>  #else
>  static inline void udp_tunnel_update_gro_lookup(struct net *net,
>  						struct sock *sk, bool add) {}
> +static inline void udp_tunnel_update_gro_rcv(struct sock *sk, bool add) {}
>  #endif
>  
>  static inline void udp_tunnel_cleanup_gro(struct sock *sk)
> @@ -215,6 +217,8 @@ static inline void udp_tunnel_cleanup_gro(struct sock *sk)
>  	struct udp_sock *up = udp_sk(sk);
>  	struct net *net = sock_net(sk);
>  
> +	udp_tunnel_update_gro_rcv(sk, false);
> +
>  	if (!up->tunnel_list.pprev)
>  		return;
>  
> diff --git a/net/ipv4/udp_offload.c b/net/ipv4/udp_offload.c
> index 054d4d4a8927f..500b2a20053cd 100644
> --- a/net/ipv4/udp_offload.c
> +++ b/net/ipv4/udp_offload.c
> @@ -15,6 +15,39 @@
>  #include <net/udp_tunnel.h>
>  
>  #if IS_ENABLED(CONFIG_NET_UDP_TUNNEL)
> +
> +/*
> + * Dummy GRO tunnel callback; should never be invoked, exists
> + * mainly to avoid dangling/NULL values for the udp tunnel
> + * static call.
> + */
> +static struct sk_buff *dummy_gro_rcv(struct sock *sk,
> +				     struct list_head *head,
> +				     struct sk_buff *skb)
> +{
> +	WARN_ON_ONCE(1);
> +	NAPI_GRO_CB(skb)->flush = 1;
> +	return NULL;
> +}
> +
> +typedef struct sk_buff *(*udp_tunnel_gro_rcv_t)(struct sock *sk,
> +						struct list_head *head,
> +						struct sk_buff *skb);
> +
> +struct udp_tunnel_type_entry {
> +	udp_tunnel_gro_rcv_t gro_receive;
> +	refcount_t count;
> +};
> +
> +#define UDP_MAX_TUNNEL_TYPES (IS_ENABLED(CONFIG_GENEVE) + \
> +			      IS_ENABLED(CONFIG_VXLAN) * 2 + \
> +			      IS_ENABLED(CONFIG_FOE) * 2)
> +
> +DEFINE_STATIC_CALL(udp_tunnel_gro_rcv, dummy_gro_rcv);
> +static DEFINE_STATIC_KEY_FALSE(udp_tunnel_static_call);
> +static struct mutex udp_tunnel_gro_type_lock;
> +static struct udp_tunnel_type_entry udp_tunnel_gro_types[UDP_MAX_TUNNEL_TYPES];
> +static unsigned int udp_tunnel_gro_type_nr;
>  static DEFINE_SPINLOCK(udp_tunnel_gro_lock);
>  
>  void udp_tunnel_update_gro_lookup(struct net *net, struct sock *sk, bool add)
> @@ -43,6 +76,106 @@ void udp_tunnel_update_gro_lookup(struct net *net, struct sock *sk, bool add)
>  	spin_unlock(&udp_tunnel_gro_lock);
>  }
>  EXPORT_SYMBOL_GPL(udp_tunnel_update_gro_lookup);
> +
> +void udp_tunnel_update_gro_rcv(struct sock *sk, bool add)
> +{
> +	struct udp_tunnel_type_entry *cur = NULL, *avail = NULL;
> +	struct udp_sock *up = udp_sk(sk);
> +	int i, old_gro_type_nr;
> +
> +	if (!up->gro_receive)
> +		return;
> +
> +	mutex_lock(&udp_tunnel_gro_type_lock);
> +	for (i = 0; i < UDP_MAX_TUNNEL_TYPES; i++) {
> +		if (!refcount_read(&udp_tunnel_gro_types[i].count))

Optionally: && !avail, to fill the list from the front. And on delete
avoid gaps. For instance, like __fanout_link/__fanout_unlink.

Can stop sooner then. And list length is then implicit as i once found
the first [i].count == zero.

Then again, this list is always short. I can imagine you prefer to
leave as is.

> +			avail = &udp_tunnel_gro_types[i];
> +		else if (udp_tunnel_gro_types[i].gro_receive == up->gro_receive)
> +			cur = &udp_tunnel_gro_types[i];
> +	}
> +	old_gro_type_nr = udp_tunnel_gro_type_nr;
> +	if (add) {
> +		/*
> +		 * Update the matching entry, if found, or add a new one
> +		 * if needed
> +		 */
> +		if (cur) {
> +			refcount_inc(&cur->count);
> +			goto out;
> +		}
> +
> +		if (unlikely(!avail)) {
> +			pr_err_once("Too many UDP tunnel types, please increase UDP_MAX_TUNNEL_TYPES\n");
> +			/* Ensure static call will never be enabled */
> +			udp_tunnel_gro_type_nr = UDP_MAX_TUNNEL_TYPES + 2;
> +			goto out;
> +		}
> +
> +		refcount_set(&avail->count, 1);
> +		avail->gro_receive = up->gro_receive;
> +		udp_tunnel_gro_type_nr++;
> +	} else {
> +		/*
> +		 * The stack cleanups only successfully added tunnel, the
> +		 * lookup on removal should never fail.
> +		 */
> +		if (WARN_ON_ONCE(!cur))
> +			goto out;
> +
> +		if (!refcount_dec_and_test(&cur->count))
> +			goto out;
> +		udp_tunnel_gro_type_nr--;
> +	}
> +
> +	if (udp_tunnel_gro_type_nr == 1) {
> +		for (i = 0; i < UDP_MAX_TUNNEL_TYPES; i++) {
> +			cur = &udp_tunnel_gro_types[i];
> +			if (refcount_read(&cur->count)) {
> +				static_call_update(udp_tunnel_gro_rcv,
> +						   cur->gro_receive);
> +				static_branch_enable(&udp_tunnel_static_call);
> +			}
> +		}
> +	} else if (old_gro_type_nr == 1) {
> +		static_branch_disable(&udp_tunnel_static_call);
> +		static_call_update(udp_tunnel_gro_rcv, dummy_gro_rcv);

These operations must not be reordered, or dummy_gro_rcv might get hit.

If static calls are not configured, the last call is just a
WRITE_ONCE. Similar for static_branch_disable if !CONFIG_JUMP_LABEL.

> +	}
> +
> +out:
> +	mutex_unlock(&udp_tunnel_gro_type_lock);
> +}
> +EXPORT_SYMBOL_GPL(udp_tunnel_update_gro_rcv);
> +
> +static void udp_tunnel_gro_init(void)
> +{
> +	mutex_init(&udp_tunnel_gro_type_lock);
> +}
> +
> +static struct sk_buff *udp_tunnel_gro_rcv(struct sock *sk,
> +					  struct list_head *head,
> +					  struct sk_buff *skb)
> +{
> +	if (static_branch_likely(&udp_tunnel_static_call)) {
> +		if (unlikely(gro_recursion_inc_test(skb))) {
> +			NAPI_GRO_CB(skb)->flush |= 1;
> +			return NULL;
> +		}
> +		return static_call(udp_tunnel_gro_rcv)(sk, head, skb);
> +	}
> +	return call_gro_receive_sk(udp_sk(sk)->gro_receive, sk, head, skb);
> +}
> +
> +#else
> +
> +static void udp_tunnel_gro_init(void) {}
> +
> +static struct sk_buff *udp_tunnel_gro_rcv(struct sock *sk,
> +					  struct list_head *head,
> +					  struct sk_buff *skb)
> +{
> +	return call_gro_receive_sk(udp_sk(sk)->gro_receive, sk, head, skb);
> +}
> +
>  #endif
>  
>  static struct sk_buff *__skb_udp_tunnel_segment(struct sk_buff *skb,
> @@ -654,7 +787,7 @@ struct sk_buff *udp_gro_receive(struct list_head *head, struct sk_buff *skb,
>  
>  	skb_gro_pull(skb, sizeof(struct udphdr)); /* pull encapsulating udp header */
>  	skb_gro_postpull_rcsum(skb, uh, sizeof(struct udphdr));
> -	pp = call_gro_receive_sk(udp_sk(sk)->gro_receive, sk, head, skb);
> +	pp = udp_tunnel_gro_rcv(sk, head, skb);
>  
>  out:
>  	skb_gro_flush_final(skb, pp, flush);
> @@ -804,5 +937,7 @@ int __init udpv4_offload_init(void)
>  			.gro_complete =	udp4_gro_complete,
>  		},
>  	};
> +
> +	udp_tunnel_gro_init();
>  	return inet_add_offload(&net_hotdata.udpv4_offload, IPPROTO_UDP);
>  }
> diff --git a/net/ipv4/udp_tunnel_core.c b/net/ipv4/udp_tunnel_core.c
> index b5695826e57ad..c49fceea83139 100644
> --- a/net/ipv4/udp_tunnel_core.c
> +++ b/net/ipv4/udp_tunnel_core.c
> @@ -90,6 +90,8 @@ void setup_udp_tunnel_sock(struct net *net, struct socket *sock,
>  
>  	udp_tunnel_encap_enable(sk);
>  
> +	udp_tunnel_update_gro_rcv(sock->sk, true);
> +
>  	if (!sk->sk_dport && !sk->sk_bound_dev_if && sk_saddr_any(sock->sk))
>  		udp_tunnel_update_gro_lookup(net, sock->sk, true);
>  }
> -- 
> 2.48.1
>
Kuniyuki Iwashima March 11, 2025, 5:49 a.m. UTC | #2
From: Paolo Abeni <pabeni@redhat.com>
Date: Mon, 10 Mar 2025 20:09:49 +0100
> diff --git a/net/ipv4/udp_offload.c b/net/ipv4/udp_offload.c
> index 054d4d4a8927f..500b2a20053cd 100644
> --- a/net/ipv4/udp_offload.c
> +++ b/net/ipv4/udp_offload.c
> @@ -15,6 +15,39 @@
>  #include <net/udp_tunnel.h>
>  
>  #if IS_ENABLED(CONFIG_NET_UDP_TUNNEL)
> +
> +/*
> + * Dummy GRO tunnel callback; should never be invoked, exists
> + * mainly to avoid dangling/NULL values for the udp tunnel
> + * static call.
> + */
> +static struct sk_buff *dummy_gro_rcv(struct sock *sk,
> +				     struct list_head *head,
> +				     struct sk_buff *skb)
> +{
> +	WARN_ON_ONCE(1);
> +	NAPI_GRO_CB(skb)->flush = 1;
> +	return NULL;
> +}
> +
> +typedef struct sk_buff *(*udp_tunnel_gro_rcv_t)(struct sock *sk,
> +						struct list_head *head,
> +						struct sk_buff *skb);
> +
> +struct udp_tunnel_type_entry {
> +	udp_tunnel_gro_rcv_t gro_receive;
> +	refcount_t count;
> +};
> +
> +#define UDP_MAX_TUNNEL_TYPES (IS_ENABLED(CONFIG_GENEVE) + \
> +			      IS_ENABLED(CONFIG_VXLAN) * 2 + \
> +			      IS_ENABLED(CONFIG_FOE) * 2)

I guess this is CONFIG_NET_FOU?
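
Assuming that is the intended symbol, the sizing macro would presumably read
(illustrative only, not the posted fix):

	#define UDP_MAX_TUNNEL_TYPES (IS_ENABLED(CONFIG_GENEVE) + \
				      IS_ENABLED(CONFIG_VXLAN) * 2 + \
				      IS_ENABLED(CONFIG_NET_FOU) * 2)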
Paolo Abeni March 11, 2025, 5:24 p.m. UTC | #3
On 3/11/25 3:51 AM, Willem de Bruijn wrote:
> Paolo Abeni wrote:
[...]
>> +void udp_tunnel_update_gro_rcv(struct sock *sk, bool add)
>> +{
>> +	struct udp_tunnel_type_entry *cur = NULL, *avail = NULL;
>> +	struct udp_sock *up = udp_sk(sk);
>> +	int i, old_gro_type_nr;
>> +
>> +	if (!up->gro_receive)
>> +		return;
>> +
>> +	mutex_lock(&udp_tunnel_gro_type_lock);
>> +	for (i = 0; i < UDP_MAX_TUNNEL_TYPES; i++) {
>> +		if (!refcount_read(&udp_tunnel_gro_types[i].count))
> 
> Optionally: && !avail, to fill the list from the front. And on delete
> avoid gaps. For instance, like __fanout_link/__fanout_unlink.
> 
> Can stop sooner then. And list length is then implicit as i once found
> the first [i].count == zero.
> 
> Then again, this list is always short. I can imagine you prefer to
> leave as is.

I avoided optimizations for this slow path, to keep the code simpler.
Thinking about it again, avoiding gaps will simplify/clean up the code a
bit (no need to look up the enabled tunnel on deletion or to use `avail`
on addition), so I'll do it.

Note that I'll still need to explicitly track the number of enabled
tunnel types, as an easy way to disable the static call in the unlikely
udp_tunnel_gro_type_nr == UDP_MAX_TUNNEL_TYPES event.
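
A rough sketch of the gap-free bookkeeping described above (illustrative
only, reusing the names from this patch, not the actual v4 code; both
helpers are assumed to run under udp_tunnel_gro_type_lock, with the
capacity check done by the caller):

	/* add: entries always occupy slots [0, udp_tunnel_gro_type_nr) */
	static void udp_tunnel_gro_type_add(udp_tunnel_gro_rcv_t gro_receive)
	{
		struct udp_tunnel_type_entry *e;

		e = &udp_tunnel_gro_types[udp_tunnel_gro_type_nr++];
		e->gro_receive = gro_receive;
		refcount_set(&e->count, 1);
	}

	/* del: fill the gap with the last entry, __fanout_unlink style */
	static void udp_tunnel_gro_type_del(struct udp_tunnel_type_entry *cur)
	{
		*cur = udp_tunnel_gro_types[--udp_tunnel_gro_type_nr];
	}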

[...]
>> +	if (udp_tunnel_gro_type_nr == 1) {
>> +		for (i = 0; i < UDP_MAX_TUNNEL_TYPES; i++) {
>> +			cur = &udp_tunnel_gro_types[i];
>> +			if (refcount_read(&cur->count)) {
>> +				static_call_update(udp_tunnel_gro_rcv,
>> +						   cur->gro_receive);
>> +				static_branch_enable(&udp_tunnel_static_call);
>> +			}
>> +		}
>> +	} else if (old_gro_type_nr == 1) {
>> +		static_branch_disable(&udp_tunnel_static_call);
>> +		static_call_update(udp_tunnel_gro_rcv, dummy_gro_rcv);
> 
> These operations must not be reordered, or dummy_gro_rcv might get hit.
> 
> If static calls are not configured, the last call is just a
> WRITE_ONCE. Similar for static_branch_disable if !CONFIG_JUMP_LABEL.

When both constructs are disabled, I think a wmb/rmb pair would be needed
to ensure no reordering, and that in turn looks like overkill. I think it
would be better to just drop the WARN_ON_ONCE in dummy_gro_rcv().
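
A minimal sketch of what that would look like (illustrative, not the posted
follow-up):

	/* Dummy GRO tunnel callback; may be hit transiently while the
	 * static branch and the static call are being updated, so just
	 * abort GRO for this packet instead of warning.
	 */
	static struct sk_buff *dummy_gro_rcv(struct sock *sk,
					     struct list_head *head,
					     struct sk_buff *skb)
	{
		NAPI_GRO_CB(skb)->flush = 1;
		return NULL;
	}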

/P
Willem de Bruijn March 11, 2025, 5:30 p.m. UTC | #4
Paolo Abeni wrote:
> On 3/11/25 3:51 AM, Willem de Bruijn wrote:
> > Paolo Abeni wrote:
> [...]
> >> +void udp_tunnel_update_gro_rcv(struct sock *sk, bool add)
> >> +{
> >> +	struct udp_tunnel_type_entry *cur = NULL, *avail = NULL;
> >> +	struct udp_sock *up = udp_sk(sk);
> >> +	int i, old_gro_type_nr;
> >> +
> >> +	if (!up->gro_receive)
> >> +		return;
> >> +
> >> +	mutex_lock(&udp_tunnel_gro_type_lock);
> >> +	for (i = 0; i < UDP_MAX_TUNNEL_TYPES; i++) {
> >> +		if (!refcount_read(&udp_tunnel_gro_types[i].count))
> > 
> > Optionally: && !avail, to fill the list from the front. And on delete
> > avoid gaps. For instance, like __fanout_link/__fanout_unlink.
> > 
> > Can stop sooner then. And list length is then implicit as i once found
> > the first [i].count == zero.
> > 
> > Then again, this list is always short. I can imagine you prefer to
> > leave as is.
> 
> I avoided optimizations for this slow path, to keep the code simpler.
> Thinking about it again, avoiding gaps will simplify/clean up the code a
> bit (no need to look up the enabled tunnel on deletion or to use `avail`
> on addition), so I'll do it.
> 
> Note that I'll still need to explicitly track the number of enabled
> tunnel types, as an easy way to disable the static call in the unlikely
> udp_tunnel_gro_type_nr == UDP_MAX_TUNNEL_TYPES event.
> 
> [...]
> >> +	if (udp_tunnel_gro_type_nr == 1) {
> >> +		for (i = 0; i < UDP_MAX_TUNNEL_TYPES; i++) {
> >> +			cur = &udp_tunnel_gro_types[i];
> >> +			if (refcount_read(&cur->count)) {
> >> +				static_call_update(udp_tunnel_gro_rcv,
> >> +						   cur->gro_receive);
> >> +				static_branch_enable(&udp_tunnel_static_call);
> >> +			}
> >> +		}
> >> +	} else if (old_gro_type_nr == 1) {
> >> +		static_branch_disable(&udp_tunnel_static_call);
> >> +		static_call_update(udp_tunnel_gro_rcv, dummy_gro_rcv);
> > 
> > These operations must not be reordered, or dummy_gro_rcv might get hit.
> > 
> > If static calls are not configured, the last call is just a
> > WRITE_ONCE. Similar for static_branch_disable if !CONFIG_JUMP_LABEL.
> 
> When both constructs are disabled, I think a wmb/rmb pair would be needed
> to ensure no reordering, and that in turn looks like overkill. I think it
> would be better to just drop the WARN_ON_ONCE in dummy_gro_rcv().

SGTM, thanks.

Patch

diff --git a/include/net/udp_tunnel.h b/include/net/udp_tunnel.h
index eda0f3e2f65fa..a7b230867eb14 100644
--- a/include/net/udp_tunnel.h
+++ b/include/net/udp_tunnel.h
@@ -205,9 +205,11 @@  static inline void udp_tunnel_encap_enable(struct sock *sk)
 
 #if IS_ENABLED(CONFIG_NET_UDP_TUNNEL)
 void udp_tunnel_update_gro_lookup(struct net *net, struct sock *sk, bool add);
+void udp_tunnel_update_gro_rcv(struct sock *sk, bool add);
 #else
 static inline void udp_tunnel_update_gro_lookup(struct net *net,
 						struct sock *sk, bool add) {}
+static inline void udp_tunnel_update_gro_rcv(struct sock *sk, bool add) {}
 #endif
 
 static inline void udp_tunnel_cleanup_gro(struct sock *sk)
@@ -215,6 +217,8 @@  static inline void udp_tunnel_cleanup_gro(struct sock *sk)
 	struct udp_sock *up = udp_sk(sk);
 	struct net *net = sock_net(sk);
 
+	udp_tunnel_update_gro_rcv(sk, false);
+
 	if (!up->tunnel_list.pprev)
 		return;
 
diff --git a/net/ipv4/udp_offload.c b/net/ipv4/udp_offload.c
index 054d4d4a8927f..500b2a20053cd 100644
--- a/net/ipv4/udp_offload.c
+++ b/net/ipv4/udp_offload.c
@@ -15,6 +15,39 @@ 
 #include <net/udp_tunnel.h>
 
 #if IS_ENABLED(CONFIG_NET_UDP_TUNNEL)
+
+/*
+ * Dummy GRO tunnel callback; should never be invoked, exists
+ * mainly to avoid dangling/NULL values for the udp tunnel
+ * static call.
+ */
+static struct sk_buff *dummy_gro_rcv(struct sock *sk,
+				     struct list_head *head,
+				     struct sk_buff *skb)
+{
+	WARN_ON_ONCE(1);
+	NAPI_GRO_CB(skb)->flush = 1;
+	return NULL;
+}
+
+typedef struct sk_buff *(*udp_tunnel_gro_rcv_t)(struct sock *sk,
+						struct list_head *head,
+						struct sk_buff *skb);
+
+struct udp_tunnel_type_entry {
+	udp_tunnel_gro_rcv_t gro_receive;
+	refcount_t count;
+};
+
+#define UDP_MAX_TUNNEL_TYPES (IS_ENABLED(CONFIG_GENEVE) + \
+			      IS_ENABLED(CONFIG_VXLAN) * 2 + \
+			      IS_ENABLED(CONFIG_FOE) * 2)
+
+DEFINE_STATIC_CALL(udp_tunnel_gro_rcv, dummy_gro_rcv);
+static DEFINE_STATIC_KEY_FALSE(udp_tunnel_static_call);
+static struct mutex udp_tunnel_gro_type_lock;
+static struct udp_tunnel_type_entry udp_tunnel_gro_types[UDP_MAX_TUNNEL_TYPES];
+static unsigned int udp_tunnel_gro_type_nr;
 static DEFINE_SPINLOCK(udp_tunnel_gro_lock);
 
 void udp_tunnel_update_gro_lookup(struct net *net, struct sock *sk, bool add)
@@ -43,6 +76,106 @@  void udp_tunnel_update_gro_lookup(struct net *net, struct sock *sk, bool add)
 	spin_unlock(&udp_tunnel_gro_lock);
 }
 EXPORT_SYMBOL_GPL(udp_tunnel_update_gro_lookup);
+
+void udp_tunnel_update_gro_rcv(struct sock *sk, bool add)
+{
+	struct udp_tunnel_type_entry *cur = NULL, *avail = NULL;
+	struct udp_sock *up = udp_sk(sk);
+	int i, old_gro_type_nr;
+
+	if (!up->gro_receive)
+		return;
+
+	mutex_lock(&udp_tunnel_gro_type_lock);
+	for (i = 0; i < UDP_MAX_TUNNEL_TYPES; i++) {
+		if (!refcount_read(&udp_tunnel_gro_types[i].count))
+			avail = &udp_tunnel_gro_types[i];
+		else if (udp_tunnel_gro_types[i].gro_receive == up->gro_receive)
+			cur = &udp_tunnel_gro_types[i];
+	}
+	old_gro_type_nr = udp_tunnel_gro_type_nr;
+	if (add) {
+		/*
+		 * Update the matching entry, if found, or add a new one
+		 * if needed
+		 */
+		if (cur) {
+			refcount_inc(&cur->count);
+			goto out;
+		}
+
+		if (unlikely(!avail)) {
+			pr_err_once("Too many UDP tunnel types, please increase UDP_MAX_TUNNEL_TYPES\n");
+			/* Ensure static call will never be enabled */
+			udp_tunnel_gro_type_nr = UDP_MAX_TUNNEL_TYPES + 2;
+			goto out;
+		}
+
+		refcount_set(&avail->count, 1);
+		avail->gro_receive = up->gro_receive;
+		udp_tunnel_gro_type_nr++;
+	} else {
+		/*
+		 * The stack cleanups only successfully added tunnel, the
+		 * lookup on removal should never fail.
+		 */
+		if (WARN_ON_ONCE(!cur))
+			goto out;
+
+		if (!refcount_dec_and_test(&cur->count))
+			goto out;
+		udp_tunnel_gro_type_nr--;
+	}
+
+	if (udp_tunnel_gro_type_nr == 1) {
+		for (i = 0; i < UDP_MAX_TUNNEL_TYPES; i++) {
+			cur = &udp_tunnel_gro_types[i];
+			if (refcount_read(&cur->count)) {
+				static_call_update(udp_tunnel_gro_rcv,
+						   cur->gro_receive);
+				static_branch_enable(&udp_tunnel_static_call);
+			}
+		}
+	} else if (old_gro_type_nr == 1) {
+		static_branch_disable(&udp_tunnel_static_call);
+		static_call_update(udp_tunnel_gro_rcv, dummy_gro_rcv);
+	}
+
+out:
+	mutex_unlock(&udp_tunnel_gro_type_lock);
+}
+EXPORT_SYMBOL_GPL(udp_tunnel_update_gro_rcv);
+
+static void udp_tunnel_gro_init(void)
+{
+	mutex_init(&udp_tunnel_gro_type_lock);
+}
+
+static struct sk_buff *udp_tunnel_gro_rcv(struct sock *sk,
+					  struct list_head *head,
+					  struct sk_buff *skb)
+{
+	if (static_branch_likely(&udp_tunnel_static_call)) {
+		if (unlikely(gro_recursion_inc_test(skb))) {
+			NAPI_GRO_CB(skb)->flush |= 1;
+			return NULL;
+		}
+		return static_call(udp_tunnel_gro_rcv)(sk, head, skb);
+	}
+	return call_gro_receive_sk(udp_sk(sk)->gro_receive, sk, head, skb);
+}
+
+#else
+
+static void udp_tunnel_gro_init(void) {}
+
+static struct sk_buff *udp_tunnel_gro_rcv(struct sock *sk,
+					  struct list_head *head,
+					  struct sk_buff *skb)
+{
+	return call_gro_receive_sk(udp_sk(sk)->gro_receive, sk, head, skb);
+}
+
 #endif
 
 static struct sk_buff *__skb_udp_tunnel_segment(struct sk_buff *skb,
@@ -654,7 +787,7 @@  struct sk_buff *udp_gro_receive(struct list_head *head, struct sk_buff *skb,
 
 	skb_gro_pull(skb, sizeof(struct udphdr)); /* pull encapsulating udp header */
 	skb_gro_postpull_rcsum(skb, uh, sizeof(struct udphdr));
-	pp = call_gro_receive_sk(udp_sk(sk)->gro_receive, sk, head, skb);
+	pp = udp_tunnel_gro_rcv(sk, head, skb);
 
 out:
 	skb_gro_flush_final(skb, pp, flush);
@@ -804,5 +937,7 @@  int __init udpv4_offload_init(void)
 			.gro_complete =	udp4_gro_complete,
 		},
 	};
+
+	udp_tunnel_gro_init();
 	return inet_add_offload(&net_hotdata.udpv4_offload, IPPROTO_UDP);
 }
diff --git a/net/ipv4/udp_tunnel_core.c b/net/ipv4/udp_tunnel_core.c
index b5695826e57ad..c49fceea83139 100644
--- a/net/ipv4/udp_tunnel_core.c
+++ b/net/ipv4/udp_tunnel_core.c
@@ -90,6 +90,8 @@  void setup_udp_tunnel_sock(struct net *net, struct socket *sock,
 
 	udp_tunnel_encap_enable(sk);
 
+	udp_tunnel_update_gro_rcv(sock->sk, true);
+
 	if (!sk->sk_dport && !sk->sk_bound_dev_if && sk_saddr_any(sock->sk))
 		udp_tunnel_update_gro_lookup(net, sock->sk, true);
 }