diff mbox series

[v2,1/4] net: tun: fix tun_xdp_one() for IFF_TUN mode

Message ID 20210622161533.1214662-1-dwmw2@infradead.org (mailing list archive)
State Changes Requested
Delegated to: Netdev Maintainers
Headers show
Series [v2,1/4] net: tun: fix tun_xdp_one() for IFF_TUN mode | expand

Commit Message

David Woodhouse June 22, 2021, 4:15 p.m. UTC
From: David Woodhouse <dwmw@amazon.co.uk>

In tun_get_user(), skb->protocol is either taken from the tun_pi header
or inferred from the first byte of the packet in IFF_TUN mode, while
eth_type_trans() is called only in the IFF_TAP mode where the payload
is expected to be an Ethernet frame.

The alternative path in tun_xdp_one() was unconditionally using
eth_type_trans(), which corrupts packets in IFF_TUN mode. Fix it to
do the correct thing for IFF_TUN mode, as tun_get_user() does.

Fixes: 043d222f93ab ("tuntap: accept an array of XDP buffs through sendmsg()")
Signed-off-by: David Woodhouse <dwmw@amazon.co.uk>
---
 drivers/net/tun.c | 44 +++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 43 insertions(+), 1 deletion(-)

Comments

Jason Wang June 23, 2021, 3:45 a.m. UTC | #1
On 2021/6/23 at 12:15 AM, David Woodhouse wrote:
> From: David Woodhouse <dwmw@amazon.co.uk>
>
> In tun_get_user(), skb->protocol is either taken from the tun_pi header
> or inferred from the first byte of the packet in IFF_TUN mode, while
> eth_type_trans() is called only in the IFF_TAP mode where the payload
> is expected to be an Ethernet frame.
>
> The alternative path in tun_xdp_one() was unconditionally using
> eth_type_trans(), which corrupts packets in IFF_TUN mode. Fix it to
> do the correct thing for IFF_TUN mode, as tun_get_user() does.
>
> Fixes: 043d222f93ab ("tuntap: accept an array of XDP buffs through sendmsg()")
> Signed-off-by: David Woodhouse <dwmw@amazon.co.uk>
> ---
>   drivers/net/tun.c | 44 +++++++++++++++++++++++++++++++++++++++++++-
>   1 file changed, 43 insertions(+), 1 deletion(-)
>
> diff --git a/drivers/net/tun.c b/drivers/net/tun.c
> index 4cf38be26dc9..f812dcdc640e 100644
> --- a/drivers/net/tun.c
> +++ b/drivers/net/tun.c
> @@ -2394,8 +2394,50 @@ static int tun_xdp_one(struct tun_struct *tun,
>   		err = -EINVAL;
>   		goto out;
>   	}
> +	switch (tun->flags & TUN_TYPE_MASK) {
> +	case IFF_TUN:
> +		if (tun->flags & IFF_NO_PI) {
> +			u8 ip_version = skb->len ? (skb->data[0] >> 4) : 0;
> +
> +			switch (ip_version) {
> +			case 4:
> +				skb->protocol = htons(ETH_P_IP);
> +				break;
> +			case 6:
> +				skb->protocol = htons(ETH_P_IPV6);
> +				break;
> +			default:
> +				atomic_long_inc(&tun->dev->rx_dropped);
> +				kfree_skb(skb);
> +				err = -EINVAL;
> +				goto out;
> +			}
> +		} else {
> +			struct tun_pi *pi = (struct tun_pi *)skb->data;
> +			if (!pskb_may_pull(skb, sizeof(*pi))) {
> +				atomic_long_inc(&tun->dev->rx_dropped);
> +				kfree_skb(skb);
> +				err = -ENOMEM;
> +				goto out;
> +			}
> +			skb_pull_inline(skb, sizeof(*pi));
> +			skb->protocol = pi->proto;


As I replied to the previous version, it would be better if we could
unify this with the similar logic in tun_get_user().

Thanks


> +		}
> +
> +		skb_reset_mac_header(skb);
> +		skb->dev = tun->dev;
> +		break;
> +	case IFF_TAP:
> +		if (!pskb_may_pull(skb, ETH_HLEN)) {
> +			atomic_long_inc(&tun->dev->rx_dropped);
> +			kfree_skb(skb);
> +			err = -ENOMEM;
> +			goto out;
> +		}
> +		skb->protocol = eth_type_trans(skb, tun->dev);
> +		break;
> +	}
>   
> -	skb->protocol = eth_type_trans(skb, tun->dev);
>   	skb_reset_network_header(skb);
>   	skb_probe_transport_header(skb);
>   	skb_record_rx_queue(skb, tfile->queue_index);
David Woodhouse June 23, 2021, 8:30 a.m. UTC | #2
On Wed, 2021-06-23 at 11:45 +0800, Jason Wang wrote:
> As replied in previous version, it would be better if we can unify 
> similar logic in tun_get_user().

Ah sorry, I missed that the first time.

Yes, that was my initial inclination too. But in the tun_get_user()
case we already *have* "pi", having read it separately into a local
variable; it never made it to the skb. So the cases are subtly
different enough that abstracting it out didn't seem to make sense.

If I try harder to unify it, I suppose it looks something like this and
*might* just make the cut for "simple enough to be backported to stable
kernels in a bug fix".

I'll add the PI mode to my test cases and try it (as well as *actually*
unifying the offending code, of course).

--- a/drivers/net/tun.c
+++ b/drivers/net/tun.c
@@ -2332,10 +2332,12 @@ static int tun_xdp_one(struct tun_struct *tun,
        unsigned int datasize = xdp->data_end - xdp->data;
        struct tun_xdp_hdr *hdr = xdp->data_hard_start;
        struct virtio_net_hdr *gso = NULL;
+       struct tun_pi *pi = NULL;
        struct bpf_prog *xdp_prog;
        struct sk_buff *skb = NULL;
        u32 rxhash = 0, act;
        int buflen = hdr->buflen;
+       int reservelen = xdp->data - xdp->data_hard_start;
        int err = 0;
        bool skb_xdp = false;
        struct page *page;
@@ -2343,6 +2345,11 @@ static int tun_xdp_one(struct tun_struct *tun,
        if (tun->flags & IFF_VNET_HDR)
                gso = &hdr->gso;
 
+       if (!(tun->flags & IFF_NO_PI)) {
+               pi = xdp->data;
+               reservelen += sizeof(*pi);
+       }
+
        xdp_prog = rcu_dereference(tun->xdp_prog);
        if (xdp_prog) {
                if (gso && gso->gso_type) {
@@ -2388,7 +2395,7 @@ static int tun_xdp_one(struct tun_struct *tun,
                goto out;
        }
 
-       skb_reserve(skb, xdp->data - xdp->data_hard_start);
+       skb_reserve(skb, reservelen);
        skb_put(skb, xdp->data_end - xdp->data);
 
        if (gso && virtio_net_hdr_to_skb(skb, gso, tun_is_little_endian(tun))) {
David Woodhouse June 23, 2021, 1:52 p.m. UTC | #3
On Wed, 2021-06-23 at 11:45 +0800, Jason Wang wrote:
> 
> As replied in previous version, it would be better if we can unify 
> similar logic in tun_get_user().

So that ends up looking something like this (incremental).

Note the '/* XXX frags && */' part in tun_skb_set_protocol(): the
'frags &&' condition was there in tun_get_user(), and it wasn't obvious
to me whether I should lift it out as a separate argument to
tun_skb_set_protocol() or whether there's a better way.

Either way, in my judgement this is less suitable for a stable fix and
more appropriate for a follow-on cleanup. But I don't feel that
strongly; I'm more than happy for you to overrule me on that.
Especially if you fix the above XXX part while you're at it :)

I tested this with vhost-net and !IFF_NO_PI, and TX works. RX is still
hosed on the vhost-net side, for the same reason that a bunch of test
cases were already listed in #if 0, but I'll address that in a separate
email. It's not part of *this* patch.

--- a/drivers/net/tun.c
+++ b/drivers/net/tun.c
@@ -1641,6 +1641,40 @@ static struct sk_buff *tun_build_skb(struct tun_struct *tun,
 	return NULL;
 }
 
+static int tun_skb_set_protocol(struct tun_struct *tun, struct sk_buff *skb,
+				__be16 pi_proto)
+{
+	switch (tun->flags & TUN_TYPE_MASK) {
+	case IFF_TUN:
+		if (tun->flags & IFF_NO_PI) {
+			u8 ip_version = skb->len ? (skb->data[0] >> 4) : 0;
+
+			switch (ip_version) {
+			case 4:
+				pi_proto = htons(ETH_P_IP);
+				break;
+			case 6:
+				pi_proto = htons(ETH_P_IPV6);
+				break;
+			default:
+				return -EINVAL;
+			}
+		}
+
+		skb_reset_mac_header(skb);
+		skb->protocol = pi_proto;
+		skb->dev = tun->dev;
+		break;
+	case IFF_TAP:
+		if (/* XXX frags && */!pskb_may_pull(skb, ETH_HLEN))
+			return -ENOMEM;
+
+		skb->protocol = eth_type_trans(skb, tun->dev);
+		break;
+	}
+	return 0;
+}
+
 /* Get packet from user space buffer */
 static ssize_t tun_get_user(struct tun_struct *tun, struct tun_file *tfile,
 			    void *msg_control, struct iov_iter *from,
@@ -1784,37 +1818,9 @@ static ssize_t tun_get_user(struct tun_struct *tun, struct tun_file *tfile,
 		return -EINVAL;
 	}
 
-	switch (tun->flags & TUN_TYPE_MASK) {
-	case IFF_TUN:
-		if (tun->flags & IFF_NO_PI) {
-			u8 ip_version = skb->len ? (skb->data[0] >> 4) : 0;
-
-			switch (ip_version) {
-			case 4:
-				pi.proto = htons(ETH_P_IP);
-				break;
-			case 6:
-				pi.proto = htons(ETH_P_IPV6);
-				break;
-			default:
-				atomic_long_inc(&tun->dev->rx_dropped);
-				kfree_skb(skb);
-				return -EINVAL;
-			}
-		}
-
-		skb_reset_mac_header(skb);
-		skb->protocol = pi.proto;
-		skb->dev = tun->dev;
-		break;
-	case IFF_TAP:
-		if (frags && !pskb_may_pull(skb, ETH_HLEN)) {
-			err = -ENOMEM;
-			goto drop;
-		}
-		skb->protocol = eth_type_trans(skb, tun->dev);
-		break;
-	}
+	err = tun_skb_set_protocol(tun, skb, pi.proto);
+	if (err)
+		goto drop;
 
 	/* copy skb_ubuf_info for callback when skb has no error */
 	if (zerocopy) {
@@ -2334,8 +2340,10 @@ static int tun_xdp_one(struct tun_struct *tun,
 	struct virtio_net_hdr *gso = NULL;
 	struct bpf_prog *xdp_prog;
 	struct sk_buff *skb = NULL;
+	__be16 proto = 0;
 	u32 rxhash = 0, act;
 	int buflen = hdr->buflen;
+	int reservelen = xdp->data - xdp->data_hard_start;
 	int err = 0;
 	bool skb_xdp = false;
 	struct page *page;
@@ -2343,6 +2351,17 @@ static int tun_xdp_one(struct tun_struct *tun,
 	if (tun->flags & IFF_VNET_HDR)
 		gso = &hdr->gso;
 
+	if (!(tun->flags & IFF_NO_PI)) {
+		struct tun_pi *pi = xdp->data;
+		if (datasize < sizeof(*pi)) {
+			atomic_long_inc(&tun->rx_frame_errors);
+			return  -EINVAL;
+		}
+		proto = pi->proto;
+		reservelen += sizeof(*pi);
+		datasize -= sizeof(*pi);
+	}
+
 	xdp_prog = rcu_dereference(tun->xdp_prog);
 	if (xdp_prog) {
 		if (gso && gso->gso_type) {
@@ -2388,8 +2407,8 @@ static int tun_xdp_one(struct tun_struct *tun,
 		goto out;
 	}
 
-	skb_reserve(skb, xdp->data - xdp->data_hard_start);
-	skb_put(skb, xdp->data_end - xdp->data);
+	skb_reserve(skb, reservelen);
+	skb_put(skb, datasize);
 
 	if (gso && virtio_net_hdr_to_skb(skb, gso, tun_is_little_endian(tun))) {
 		atomic_long_inc(&tun->rx_frame_errors);
@@ -2397,48 +2416,12 @@ static int tun_xdp_one(struct tun_struct *tun,
 		err = -EINVAL;
 		goto out;
 	}
-	switch (tun->flags & TUN_TYPE_MASK) {
-	case IFF_TUN:
-		if (tun->flags & IFF_NO_PI) {
-			u8 ip_version = skb->len ? (skb->data[0] >> 4) : 0;
 
-			switch (ip_version) {
-			case 4:
-				skb->protocol = htons(ETH_P_IP);
-				break;
-			case 6:
-				skb->protocol = htons(ETH_P_IPV6);
-				break;
-			default:
-				atomic_long_inc(&tun->dev->rx_dropped);
-				kfree_skb(skb);
-				err = -EINVAL;
-				goto out;
-			}
-		} else {
-			struct tun_pi *pi = (struct tun_pi *)skb->data;
-			if (!pskb_may_pull(skb, sizeof(*pi))) {
-				atomic_long_inc(&tun->dev->rx_dropped);
-				kfree_skb(skb);
-				err = -ENOMEM;
-				goto out;
-			}
-			skb_pull_inline(skb, sizeof(*pi));
-			skb->protocol = pi->proto;
-		}
-
-		skb_reset_mac_header(skb);
-		skb->dev = tun->dev;
-		break;
-	case IFF_TAP:
-		if (!pskb_may_pull(skb, ETH_HLEN)) {
-			atomic_long_inc(&tun->dev->rx_dropped);
-			kfree_skb(skb);
-			err = -ENOMEM;
-			goto out;
-		}
-		skb->protocol = eth_type_trans(skb, tun->dev);
-		break;
+	err = tun_skb_set_protocol(tun, skb, proto);
+	if (err) {
+		atomic_long_inc(&tun->dev->rx_dropped);
+		kfree_skb(skb);
+		goto out;
 	}
 
 	skb_reset_network_header(skb);
David Woodhouse June 23, 2021, 5:31 p.m. UTC | #4
On Wed, 2021-06-23 at 14:52 +0100, David Woodhouse wrote:
> @@ -2343,6 +2351,17 @@ static int tun_xdp_one(struct tun_struct *tun,
>         if (tun->flags & IFF_VNET_HDR)
>                 gso = &hdr->gso;
>  
> +       if (!(tun->flags & IFF_NO_PI)) {
> +               struct tun_pi *pi = xdp->data;
> +               if (datasize < sizeof(*pi)) {
> +                       atomic_long_inc(&tun->rx_frame_errors);
> +                       return  -EINVAL;
> +               }
> +               proto = pi->proto;
> +               reservelen += sizeof(*pi);
> +               datasize -= sizeof(*pi);
> +       }
> +
>         xdp_prog = rcu_dereference(tun->xdp_prog);
>         if (xdp_prog) {
>                 if (gso && gso->gso_type) {

Joy... that's wrong because when tun does both the PI and the vnet
headers, the PI header comes *first*. When tun does only PI and vhost
does the vnet headers, they come in the other order.

Will fix (and adjust the test cases to cope).
David Woodhouse June 23, 2021, 10:52 p.m. UTC | #5
On Wed, 2021-06-23 at 18:31 +0100, David Woodhouse wrote:
> 
> Joy... that's wrong because when tun does both the PI and the vnet
> headers, the PI header comes *first*. When tun does only PI and vhost
> does the vnet headers, they come in the other order.
> 
> Will fix (and adjust the test cases to cope).


I got this far, pushed to
https://git.infradead.org/users/dwmw2/linux.git/shortlog/refs/heads/vhost-net

All the test cases are now passing. I don't guarantee I haven't
actually broken qemu and IFF_TAP mode though, mind you :)

I'll need to refactor the intermediate commits a little so I won't
repost the series quite yet, but figured I should at least show what I
have for comments, as my day ends and yours begins.


As discussed, I expanded tun_get_socket()/tap_get_socket() to return
the actual header length instead of letting vhost make wild guesses.
Note that in doing so, I have made tun_get_socket() return -ENOTCONN if
the tun fd *isn't* actually attached (TUNSETIFF) to a real device yet.

I moved the sanity check back to tun/tap instead of doing it in
vhost_net_build_xdp(), because the latter has no clue about the tun PI
header and doesn't know *where* the virtio header is.


diff --git a/drivers/net/tap.c b/drivers/net/tap.c
index 8e3a28ba6b28..d1b1f1de374e 100644
--- a/drivers/net/tap.c
+++ b/drivers/net/tap.c
@@ -1132,16 +1132,35 @@ static const struct file_operations tap_fops = {
 static int tap_get_user_xdp(struct tap_queue *q, struct xdp_buff *xdp)
 {
 	struct tun_xdp_hdr *hdr = xdp->data_hard_start;
-	struct virtio_net_hdr *gso = &hdr->gso;
+	struct virtio_net_hdr *gso = NULL;
 	int buflen = hdr->buflen;
 	int vnet_hdr_len = 0;
 	struct tap_dev *tap;
 	struct sk_buff *skb;
 	int err, depth;
 
-	if (q->flags & IFF_VNET_HDR)
+	if (q->flags & IFF_VNET_HDR) {
 		vnet_hdr_len = READ_ONCE(q->vnet_hdr_sz);
+		if (xdp->data != xdp->data_hard_start + sizeof(*hdr) + vnet_hdr_len) {
+			err = -EINVAL;
+			goto err;
+		}
+
+		gso = (void *)&hdr[1];
+
+		if ((gso->flags & VIRTIO_NET_HDR_F_NEEDS_CSUM) &&
+		     tap16_to_cpu(q, gso->csum_start) +
+		     tap16_to_cpu(q, gso->csum_offset) + 2 >
+			     tap16_to_cpu(q, gso->hdr_len))
+			gso->hdr_len = cpu_to_tap16(q,
+				 tap16_to_cpu(q, gso->csum_start) +
+				 tap16_to_cpu(q, gso->csum_offset) + 2);
 
+		if (tap16_to_cpu(q, gso->hdr_len) > xdp->data_end - xdp->data) {
+			err = -EINVAL;
+			goto err;
+		}
+	}
 	skb = build_skb(xdp->data_hard_start, buflen);
 	if (!skb) {
 		err = -ENOMEM;
@@ -1155,7 +1174,7 @@ static int tap_get_user_xdp(struct tap_queue *q, struct xdp_buff *xdp)
 	skb_reset_mac_header(skb);
 	skb->protocol = eth_hdr(skb)->h_proto;
 
-	if (vnet_hdr_len) {
+	if (gso) {
 		err = virtio_net_hdr_to_skb(skb, gso, tap_is_little_endian(q));
 		if (err)
 			goto err_kfree;
@@ -1246,7 +1265,7 @@ static const struct proto_ops tap_socket_ops = {
  * attached to a device.  The returned object works like a packet socket, it
  * can be used for sock_sendmsg/sock_recvmsg.  The caller is responsible for
  * holding a reference to the file for as long as the socket is in use. */
-struct socket *tap_get_socket(struct file *file)
+struct socket *tap_get_socket(struct file *file, size_t *hlen)
 {
 	struct tap_queue *q;
 	if (file->f_op != &tap_fops)
@@ -1254,6 +1273,9 @@ struct socket *tap_get_socket(struct file *file)
 	q = file->private_data;
 	if (!q)
 		return ERR_PTR(-EBADFD);
+	if (hlen)
+		*hlen = (q->flags & IFF_VNET_HDR) ? q->vnet_hdr_sz : 0;
+
 	return &q->sock;
 }
 EXPORT_SYMBOL_GPL(tap_get_socket);
diff --git a/drivers/net/tun.c b/drivers/net/tun.c
index 4cf38be26dc9..72f8a04f493b 100644
--- a/drivers/net/tun.c
+++ b/drivers/net/tun.c
@@ -1641,6 +1641,40 @@ static struct sk_buff *tun_build_skb(struct tun_struct *tun,
 	return NULL;
 }
 
+static int tun_skb_set_protocol(struct tun_struct *tun, struct sk_buff *skb,
+				__be16 pi_proto)
+{
+	switch (tun->flags & TUN_TYPE_MASK) {
+	case IFF_TUN:
+		if (tun->flags & IFF_NO_PI) {
+			u8 ip_version = skb->len ? (skb->data[0] >> 4) : 0;
+
+			switch (ip_version) {
+			case 4:
+				pi_proto = htons(ETH_P_IP);
+				break;
+			case 6:
+				pi_proto = htons(ETH_P_IPV6);
+				break;
+			default:
+				return -EINVAL;
+			}
+		}
+
+		skb_reset_mac_header(skb);
+		skb->protocol = pi_proto;
+		skb->dev = tun->dev;
+		break;
+	case IFF_TAP:
+		if (/* frags && */!pskb_may_pull(skb, ETH_HLEN))
+			return -ENOMEM;
+
+		skb->protocol = eth_type_trans(skb, tun->dev);
+		break;
+	}
+	return 0;
+}
+
 /* Get packet from user space buffer */
 static ssize_t tun_get_user(struct tun_struct *tun, struct tun_file *tfile,
 			    void *msg_control, struct iov_iter *from,
@@ -1784,37 +1818,9 @@ static ssize_t tun_get_user(struct tun_struct *tun, struct tun_file *tfile,
 		return -EINVAL;
 	}
 
-	switch (tun->flags & TUN_TYPE_MASK) {
-	case IFF_TUN:
-		if (tun->flags & IFF_NO_PI) {
-			u8 ip_version = skb->len ? (skb->data[0] >> 4) : 0;
-
-			switch (ip_version) {
-			case 4:
-				pi.proto = htons(ETH_P_IP);
-				break;
-			case 6:
-				pi.proto = htons(ETH_P_IPV6);
-				break;
-			default:
-				atomic_long_inc(&tun->dev->rx_dropped);
-				kfree_skb(skb);
-				return -EINVAL;
-			}
-		}
-
-		skb_reset_mac_header(skb);
-		skb->protocol = pi.proto;
-		skb->dev = tun->dev;
-		break;
-	case IFF_TAP:
-		if (frags && !pskb_may_pull(skb, ETH_HLEN)) {
-			err = -ENOMEM;
-			goto drop;
-		}
-		skb->protocol = eth_type_trans(skb, tun->dev);
-		break;
-	}
+	err = tun_skb_set_protocol(tun, skb, pi.proto);
+	if (err)
+		goto drop;
 
 	/* copy skb_ubuf_info for callback when skb has no error */
 	if (zerocopy) {
@@ -2331,18 +2337,48 @@ static int tun_xdp_one(struct tun_struct *tun,
 {
 	unsigned int datasize = xdp->data_end - xdp->data;
 	struct tun_xdp_hdr *hdr = xdp->data_hard_start;
-	struct virtio_net_hdr *gso = &hdr->gso;
+	void *tun_hdr = &hdr[1];
+	struct virtio_net_hdr *gso = NULL;
 	struct bpf_prog *xdp_prog;
 	struct sk_buff *skb = NULL;
+	__be16 proto = 0;
 	u32 rxhash = 0, act;
 	int buflen = hdr->buflen;
 	int err = 0;
 	bool skb_xdp = false;
 	struct page *page;
 
+	if (!(tun->flags & IFF_NO_PI)) {
+		struct tun_pi *pi = tun_hdr;
+		tun_hdr += sizeof(*pi);
+
+		if (tun_hdr > xdp->data) {
+			atomic_long_inc(&tun->rx_frame_errors);
+			return -EINVAL;
+		}
+		proto = pi->proto;
+	}
+
+	if (tun->flags & IFF_VNET_HDR) {
+		gso = tun_hdr;
+		tun_hdr += sizeof(*gso);
+
+		if (tun_hdr > xdp->data) {
+			atomic_long_inc(&tun->rx_frame_errors);
+			return -EINVAL;
+		}
+
+		if ((gso->flags & VIRTIO_NET_HDR_F_NEEDS_CSUM) &&
+		    tun16_to_cpu(tun, gso->csum_start) + tun16_to_cpu(tun, gso->csum_offset) + 2 > tun16_to_cpu(tun, gso->hdr_len))
+			gso->hdr_len = cpu_to_tun16(tun, tun16_to_cpu(tun, gso->csum_start) + tun16_to_cpu(tun, gso->csum_offset) + 2);
+
+		if (tun16_to_cpu(tun, gso->hdr_len) > datasize)
+			return -EINVAL;
+	}
+
 	xdp_prog = rcu_dereference(tun->xdp_prog);
 	if (xdp_prog) {
-		if (gso->gso_type) {
+		if (gso && gso->gso_type) {
 			skb_xdp = true;
 			goto build;
 		}
@@ -2386,16 +2422,22 @@ static int tun_xdp_one(struct tun_struct *tun,
 	}
 
 	skb_reserve(skb, xdp->data - xdp->data_hard_start);
-	skb_put(skb, xdp->data_end - xdp->data);
+	skb_put(skb, datasize);
 
-	if (virtio_net_hdr_to_skb(skb, gso, tun_is_little_endian(tun))) {
+	if (gso && virtio_net_hdr_to_skb(skb, gso, tun_is_little_endian(tun))) {
 		atomic_long_inc(&tun->rx_frame_errors);
 		kfree_skb(skb);
 		err = -EINVAL;
 		goto out;
 	}
 
-	skb->protocol = eth_type_trans(skb, tun->dev);
+	err = tun_skb_set_protocol(tun, skb, proto);
+	if (err) {
+		atomic_long_inc(&tun->dev->rx_dropped);
+		kfree_skb(skb);
+		goto out;
+	}
+
 	skb_reset_network_header(skb);
 	skb_probe_transport_header(skb);
 	skb_record_rx_queue(skb, tfile->queue_index);
@@ -3649,7 +3691,7 @@ static void tun_cleanup(void)
  * attached to a device.  The returned object works like a packet socket, it
  * can be used for sock_sendmsg/sock_recvmsg.  The caller is responsible for
  * holding a reference to the file for as long as the socket is in use. */
-struct socket *tun_get_socket(struct file *file)
+struct socket *tun_get_socket(struct file *file, size_t *hlen)
 {
 	struct tun_file *tfile;
 	if (file->f_op != &tun_fops)
@@ -3657,6 +3699,20 @@ struct socket *tun_get_socket(struct file *file)
 	tfile = file->private_data;
 	if (!tfile)
 		return ERR_PTR(-EBADFD);
+
+	if (hlen) {
+		struct tun_struct *tun = tun_get(tfile);
+		size_t len = 0;
+
+		if (!tun)
+			return ERR_PTR(-ENOTCONN);
+		if (tun->flags & IFF_VNET_HDR)
+			len += READ_ONCE(tun->vnet_hdr_sz);
+		if (!(tun->flags & IFF_NO_PI))
+			len += sizeof(struct tun_pi);
+		tun_put(tun);
+		*hlen = len;
+	}
 	return &tfile->socket;
 }
 EXPORT_SYMBOL_GPL(tun_get_socket);
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index df82b124170e..d9491c620a9c 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -690,7 +690,6 @@ static int vhost_net_build_xdp(struct vhost_net_virtqueue *nvq,
 					     dev);
 	struct socket *sock = vhost_vq_get_backend(vq);
 	struct page_frag *alloc_frag = &net->page_frag;
-	struct virtio_net_hdr *gso;
 	struct xdp_buff *xdp = &nvq->xdp[nvq->batched_xdp];
 	struct tun_xdp_hdr *hdr;
 	size_t len = iov_iter_count(from);
@@ -715,29 +714,18 @@ static int vhost_net_build_xdp(struct vhost_net_virtqueue *nvq,
 		return -ENOMEM;
 
 	buf = (char *)page_address(alloc_frag->page) + alloc_frag->offset;
-	copied = copy_page_from_iter(alloc_frag->page,
-				     alloc_frag->offset +
-				     offsetof(struct tun_xdp_hdr, gso),
-				     sock_hlen, from);
-	if (copied != sock_hlen)
-		return -EFAULT;
-
 	hdr = buf;
-	gso = &hdr->gso;
-
-	if ((gso->flags & VIRTIO_NET_HDR_F_NEEDS_CSUM) &&
-	    vhost16_to_cpu(vq, gso->csum_start) +
-	    vhost16_to_cpu(vq, gso->csum_offset) + 2 >
-	    vhost16_to_cpu(vq, gso->hdr_len)) {
-		gso->hdr_len = cpu_to_vhost16(vq,
-			       vhost16_to_cpu(vq, gso->csum_start) +
-			       vhost16_to_cpu(vq, gso->csum_offset) + 2);
-
-		if (vhost16_to_cpu(vq, gso->hdr_len) > len)
-			return -EINVAL;
+	if (sock_hlen) {
+		copied = copy_page_from_iter(alloc_frag->page,
+					     alloc_frag->offset +
+					     sizeof(struct tun_xdp_hdr),
+					     sock_hlen, from);
+		if (copied != sock_hlen)
+			return -EFAULT;
+
+		len -= sock_hlen;
 	}
 
-	len -= sock_hlen;
 	copied = copy_page_from_iter(alloc_frag->page,
 				     alloc_frag->offset + pad,
 				     len, from);
@@ -1420,7 +1408,7 @@ static int vhost_net_release(struct inode *inode, struct file *f)
 	return 0;
 }
 
-static struct socket *get_raw_socket(int fd)
+static struct socket *get_raw_socket(int fd, size_t *hlen)
 {
 	int r;
 	struct socket *sock = sockfd_lookup(fd, &r);
@@ -1438,6 +1426,7 @@ static struct socket *get_raw_socket(int fd)
 		r = -EPFNOSUPPORT;
 		goto err;
 	}
+	*hlen = 0;
 	return sock;
 err:
 	sockfd_put(sock);
@@ -1463,33 +1452,33 @@ static struct ptr_ring *get_tap_ptr_ring(int fd)
 	return ring;
 }
 
-static struct socket *get_tap_socket(int fd)
+static struct socket *get_tap_socket(int fd, size_t *hlen)
 {
 	struct file *file = fget(fd);
 	struct socket *sock;
 
 	if (!file)
 		return ERR_PTR(-EBADF);
-	sock = tun_get_socket(file);
+	sock = tun_get_socket(file, hlen);
 	if (!IS_ERR(sock))
 		return sock;
-	sock = tap_get_socket(file);
+	sock = tap_get_socket(file, hlen);
 	if (IS_ERR(sock))
 		fput(file);
 	return sock;
 }
 
-static struct socket *get_socket(int fd)
+static struct socket *get_socket(int fd, size_t *hlen)
 {
 	struct socket *sock;
 
 	/* special case to disable backend */
 	if (fd == -1)
 		return NULL;
-	sock = get_raw_socket(fd);
+	sock = get_raw_socket(fd, hlen);
 	if (!IS_ERR(sock))
 		return sock;
-	sock = get_tap_socket(fd);
+	sock = get_tap_socket(fd, hlen);
 	if (!IS_ERR(sock))
 		return sock;
 	return ERR_PTR(-ENOTSOCK);
@@ -1521,7 +1510,7 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
 		r = -EFAULT;
 		goto err_vq;
 	}
-	sock = get_socket(fd);
+	sock = get_socket(fd, &nvq->sock_hlen);
 	if (IS_ERR(sock)) {
 		r = PTR_ERR(sock);
 		goto err_vq;
@@ -1621,7 +1610,7 @@ static long vhost_net_reset_owner(struct vhost_net *n)
 
 static int vhost_net_set_features(struct vhost_net *n, u64 features)
 {
-	size_t vhost_hlen, sock_hlen, hdr_len;
+	size_t vhost_hlen, hdr_len;
 	int i;
 
 	hdr_len = (features & ((1ULL << VIRTIO_NET_F_MRG_RXBUF) |
@@ -1631,11 +1620,8 @@ static int vhost_net_set_features(struct vhost_net *n, u64 features)
 	if (features & (1 << VHOST_NET_F_VIRTIO_NET_HDR)) {
 		/* vhost provides vnet_hdr */
 		vhost_hlen = hdr_len;
-		sock_hlen = 0;
 	} else {
-		/* socket provides vnet_hdr */
 		vhost_hlen = 0;
-		sock_hlen = hdr_len;
 	}
 	mutex_lock(&n->dev.mutex);
 	if ((features & (1 << VHOST_F_LOG_ALL)) &&
@@ -1651,7 +1637,6 @@ static int vhost_net_set_features(struct vhost_net *n, u64 features)
 		mutex_lock(&n->vqs[i].vq.mutex);
 		n->vqs[i].vq.acked_features = features;
 		n->vqs[i].vhost_hlen = vhost_hlen;
-		n->vqs[i].sock_hlen = sock_hlen;
 		mutex_unlock(&n->vqs[i].vq.mutex);
 	}
 	mutex_unlock(&n->dev.mutex);
diff --git a/include/linux/if_tap.h b/include/linux/if_tap.h
index 915a187cfabd..b460ba98f34e 100644
--- a/include/linux/if_tap.h
+++ b/include/linux/if_tap.h
@@ -3,14 +3,14 @@
 #define _LINUX_IF_TAP_H_
 
 #if IS_ENABLED(CONFIG_TAP)
-struct socket *tap_get_socket(struct file *);
+struct socket *tap_get_socket(struct file *, size_t *);
 struct ptr_ring *tap_get_ptr_ring(struct file *file);
 #else
 #include <linux/err.h>
 #include <linux/errno.h>
 struct file;
 struct socket;
-static inline struct socket *tap_get_socket(struct file *f)
+static inline struct socket *tap_get_socket(struct file *f, size_t *hlen)
 {
 	return ERR_PTR(-EINVAL);
 }
diff --git a/include/linux/if_tun.h b/include/linux/if_tun.h
index 2a7660843444..8d78b6bbc228 100644
--- a/include/linux/if_tun.h
+++ b/include/linux/if_tun.h
@@ -21,11 +21,10 @@ struct tun_msg_ctl {
 
 struct tun_xdp_hdr {
 	int buflen;
-	struct virtio_net_hdr gso;
 };
 
 #if defined(CONFIG_TUN) || defined(CONFIG_TUN_MODULE)
-struct socket *tun_get_socket(struct file *);
+struct socket *tun_get_socket(struct file *, size_t *);
 struct ptr_ring *tun_get_tx_ring(struct file *file);
 static inline bool tun_is_xdp_frame(void *ptr)
 {
@@ -45,7 +44,7 @@ void tun_ptr_free(void *ptr);
 #include <linux/errno.h>
 struct file;
 struct socket;
-static inline struct socket *tun_get_socket(struct file *f)
+static inline struct socket *tun_get_socket(struct file *f, size_t *hlen)
 {
 	return ERR_PTR(-EINVAL);
 }
diff --git a/tools/testing/selftests/Makefile b/tools/testing/selftests/Makefile
index 6c575cf34a71..300c03cfd0c7 100644
--- a/tools/testing/selftests/Makefile
+++ b/tools/testing/selftests/Makefile
@@ -71,6 +71,7 @@ TARGETS += user
 TARGETS += vDSO
+TARGETS += vhost
 TARGETS += vm
 TARGETS += x86
 TARGETS += zram
 #Please keep the TARGETS list alphabetically sorted
 # Run "make quicktest=1 run_tests" or
diff --git a/tools/testing/selftests/vhost/Makefile b/tools/testing/selftests/vhost/Makefile
new file mode 100644
index 000000000000..f5e565d80733
--- /dev/null
+++ b/tools/testing/selftests/vhost/Makefile
@@ -0,0 +1,16 @@
+# SPDX-License-Identifier: GPL-2.0
+all:
+
+include ../lib.mk
+
+.PHONY: all clean
+
+BINARIES := test_vhost_net
+
+test_vhost_net: test_vhost_net.c ../kselftest.h ../kselftest_harness.h
+	$(CC) $(CFLAGS) -g $< -o $@
+
+TEST_PROGS += $(BINARIES)
+EXTRA_CLEAN := $(BINARIES)
+
+all: $(BINARIES)
diff --git a/tools/testing/selftests/vhost/config b/tools/testing/selftests/vhost/config
new file mode 100644
index 000000000000..6391c1f32c34
--- /dev/null
+++ b/tools/testing/selftests/vhost/config
@@ -0,0 +1,2 @@
+CONFIG_VHOST_NET=y
+CONFIG_TUN=y
diff --git a/tools/testing/selftests/vhost/test_vhost_net.c b/tools/testing/selftests/vhost/test_vhost_net.c
new file mode 100644
index 000000000000..747f0e5e4f57
--- /dev/null
+++ b/tools/testing/selftests/vhost/test_vhost_net.c
@@ -0,0 +1,530 @@
+// SPDX-License-Identifier: LGPL-2.1
+
+#include "../kselftest_harness.h"
+#include "../../../virtio/asm/barrier.h"
+
+#include <sys/eventfd.h>
+
+#include <sys/types.h>
+#include <sys/stat.h>
+
+#include <fcntl.h>
+#include <unistd.h>
+#include <sys/wait.h>
+#include <sys/ioctl.h>
+#include <errno.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include <net/if.h>
+#include <sys/socket.h>
+
+#include <netinet/tcp.h>
+#include <netinet/ip.h>
+#include <netinet/ip_icmp.h>
+#include <netinet/ip6.h>
+#include <netinet/icmp6.h>
+
+#include <linux/if_tun.h>
+#include <linux/virtio_net.h>
+#include <linux/vhost.h>
+
+/*
+ * Return the 4-bit value of one ASCII hex digit (either case). Any other
+ * character ends the whole test program with KSFT_SKIP.
+ */
+static unsigned char hexnybble(char hex)
+{
+	if (hex >= '0' && hex <= '9')
+		return hex - '0';
+	if (hex >= 'a' && hex <= 'f')
+		return hex - 'a' + 10;
+	if (hex >= 'A' && hex <= 'F')
+		return hex - 'A' + 10;
+
+	exit(KSFT_SKIP);
+}
+
+/* Decode the two hex digits at hex[0] (high nybble) and hex[1] into a byte. */
+static unsigned char hexchar(char *hex)
+{
+	unsigned char high = hexnybble(hex[0]);
+	unsigned char low = hexnybble(hex[1]);
+
+	return (unsigned char)((high << 4) | low);
+}
+
+/*
+ * Create an IFF_TUN device, bring it up, and look up the "fe80" link-local
+ * address the kernel lists for it in /proc/net/if_inet6.
+ *
+ * @vnet_hdr_sz: if non-zero, set IFF_VNET_HDR and TUNSETVNETHDRSZ to this
+ * @pi:          if non-zero, keep the struct tun_pi header (no IFF_NO_PI)
+ * @addr:        receives the device's link-local address
+ *
+ * Returns the tun fd, or -1 on any failure (including the address not
+ * being listed). Every fd opened here is closed exactly once on failure.
+ */
+int open_tun(int vnet_hdr_sz, int pi, struct in6_addr *addr)
+{
+	int tun_fd = open("/dev/net/tun", O_RDWR);
+	if (tun_fd == -1)
+		return -1;
+
+	struct ifreq ifr = { 0 };
+
+	ifr.ifr_flags = IFF_TUN;
+	if (!pi)
+		ifr.ifr_flags |= IFF_NO_PI;
+	if (vnet_hdr_sz)
+		ifr.ifr_flags |= IFF_VNET_HDR;
+
+	if (ioctl(tun_fd, TUNSETIFF, (void *)&ifr) < 0)
+		goto out_tun;
+
+	if (vnet_hdr_sz &&
+	    ioctl(tun_fd, TUNSETVNETHDRSZ, &vnet_hdr_sz) < 0)
+		goto out_tun;
+
+	/* A throwaway socket is only needed for the interface-flag ioctls. */
+	int sockfd = socket(AF_INET6, SOCK_DGRAM, IPPROTO_IP);
+	if (sockfd == -1)
+		goto out_tun;
+
+	if (ioctl(sockfd, SIOCGIFFLAGS, (void *)&ifr) < 0)
+		goto out_sock;
+
+	ifr.ifr_flags |= IFF_UP;
+	if (ioctl(sockfd, SIOCSIFFLAGS, (void *)&ifr) < 0)
+		goto out_sock;
+
+	close(sockfd);
+
+	FILE *inet6 = fopen("/proc/net/if_inet6", "r");
+	if (!inet6)
+		goto out_tun;
+
+	/*
+	 * Lines start with the address as 32 hex digits and end with
+	 * " <ifname>\n"; look for an fe80 address on our interface.
+	 */
+	size_t namelen = strlen(ifr.ifr_name);
+	char buf[80];
+	while (fgets(buf, sizeof(buf), inet6)) {
+		size_t len = strlen(buf);
+
+		/* Too short for 32 hex digits plus " <ifname>\n": skip it
+		 * rather than index before the start of buf. */
+		if (len < 32 + namelen + 2)
+			continue;
+
+		if (!strncmp(buf, "fe80", 4) &&
+		    buf[len - namelen - 2] == ' ' &&
+		    !strncmp(buf + len - namelen - 1, ifr.ifr_name, namelen)) {
+			for (int i = 0; i < 16; i++)
+				addr->s6_addr[i] = hexchar(buf + i * 2);
+			fclose(inet6);
+			return tun_fd;
+		}
+	}
+	/* Not found. sockfd was already closed above, so bypass out_sock. */
+	fclose(inet6);
+	goto out_tun;
+
+ out_sock:
+	close(sockfd);
+ out_tun:
+	close(tun_fd);
+	return -1;
+}
+
+/* Ring depth; RING_MASK() below requires it to be a power of two. */
+#define RING_SIZE 32
+/* Map a free-running ring index onto a slot number. */
+#define RING_MASK(x) ((x) & (RING_SIZE-1))
+
+/* One packet buffer; each descriptor points at exactly one of these. */
+struct pkt_buf {
+	unsigned char data[2048];
+};
+
+/*
+ * A split virtqueue together with its packet buffers. The avail/used
+ * headers are immediately followed by their ring entries, presumably
+ * relying on vring_avail/vring_used ending in a flexible ring[] array
+ * (embedding those is a GNU extension), so keep the member order as is.
+ *
+ * test_vhost() uses rings[0] for RX (device-writable buffers) and
+ * rings[1] for TX. Both live in this one BSS object so that a single
+ * VHOST_SET_MEM_TABLE region in setup_vhost() covers all of them.
+ */
+struct test_vring {
+	struct vring_desc desc[RING_SIZE];
+	struct vring_avail avail;
+	__virtio16 avail_ring[RING_SIZE];
+	struct vring_used used;
+	struct vring_used_elem used_ring[RING_SIZE];
+	struct pkt_buf pkts[RING_SIZE];
+} rings[2];
+
+/*
+ * Point vhost virtqueue @idx at rings[idx], attach @tun_fd as its backend
+ * and wire up the @call_fd / @kick_fd eventfds.
+ *
+ * Returns 0 on success, or -1 (after perror()) at the first failing ioctl.
+ */
+static int setup_vring(int vhost_fd, int tun_fd, int call_fd, int kick_fd, int idx)
+{
+	struct test_vring *vring = &rings[idx];
+
+	/* Clear the whole ring (sizeof(vring) would only be a pointer). */
+	memset(vring, 0, sizeof(*vring));
+
+	struct vhost_vring_state vs = { };
+	vs.index = idx;
+	vs.num = RING_SIZE;
+	if (ioctl(vhost_fd, VHOST_SET_VRING_NUM, &vs) < 0) {
+		perror("VHOST_SET_VRING_NUM");
+		return -1;
+	}
+
+	/* Nothing has been consumed yet: start at ring index 0. */
+	vs.num = 0;
+	if (ioctl(vhost_fd, VHOST_SET_VRING_BASE, &vs) < 0) {
+		perror("VHOST_SET_VRING_BASE");
+		return -1;
+	}
+
+	struct vhost_vring_addr va = { };
+	va.index = idx;
+	va.desc_user_addr = (uint64_t)(uintptr_t)vring->desc;
+	va.avail_user_addr = (uint64_t)(uintptr_t)&vring->avail;
+	va.used_user_addr  = (uint64_t)(uintptr_t)&vring->used;
+	if (ioctl(vhost_fd, VHOST_SET_VRING_ADDR, &va) < 0) {
+		perror("VHOST_SET_VRING_ADDR");
+		return -1;
+	}
+
+	struct vhost_vring_file vf = { };
+	vf.index = idx;
+	vf.fd = tun_fd;
+	if (ioctl(vhost_fd, VHOST_NET_SET_BACKEND, &vf) < 0) {
+		perror("VHOST_NET_SET_BACKEND");
+		return -1;
+	}
+
+	vf.fd = call_fd;
+	if (ioctl(vhost_fd, VHOST_SET_VRING_CALL, &vf) < 0) {
+		perror("VHOST_SET_VRING_CALL");
+		return -1;
+	}
+
+	vf.fd = kick_fd;
+	if (ioctl(vhost_fd, VHOST_SET_VRING_KICK, &vf) < 0) {
+		perror("VHOST_SET_VRING_KICK");
+		return -1;
+	}
+
+	return 0;
+}
+
+/*
+ * Take ownership of @vhost_fd, negotiate exactly @want_features, describe
+ * the rings[] object as vhost's only memory region, then set up
+ * virtqueues 0 and 1 on top of @tun_fd.
+ *
+ * Returns 0 on success, KSFT_SKIP if vhost does not offer every bit in
+ * @want_features, or -1 (after perror()) on any ioctl failure.
+ */
+int setup_vhost(int vhost_fd, int tun_fd, int call_fd, int kick_fd, uint64_t want_features)
+{
+	if (ioctl(vhost_fd, VHOST_SET_OWNER, NULL) < 0) {
+		perror("VHOST_SET_OWNER");
+		return -1;
+	}
+
+	uint64_t features;
+	if (ioctl(vhost_fd, VHOST_GET_FEATURES, &features) < 0) {
+		perror("VHOST_GET_FEATURES");
+		return -1;
+	}
+
+	if ((features & want_features) != want_features)
+		return KSFT_SKIP;
+
+	if (ioctl(vhost_fd, VHOST_SET_FEATURES, &want_features) < 0) {
+		perror("VHOST_SET_FEATURES");
+		return -1;
+	}
+
+	/* A vhost_memory header with room for exactly one region, on the
+	 * stack without resorting to alloca(). */
+	union {
+		struct vhost_memory hdr;
+		unsigned char raw[sizeof(struct vhost_memory) +
+				  sizeof(struct vhost_memory_region)];
+	} vmem_buf;
+	struct vhost_memory *vmem = &vmem_buf.hdr;
+
+	memset(&vmem_buf, 0, sizeof(vmem_buf));
+	vmem->nregions = 1;
+	/*
+	 * I just want to map the *whole* of userspace address space. But
+	 * from userspace I don't know what that is. On x86_64 it would be:
+	 *
+	 * vmem->regions[0].guest_phys_addr = 4096;
+	 * vmem->regions[0].memory_size = 0x7fffffffe000;
+	 * vmem->regions[0].userspace_addr = 4096;
+	 *
+	 * For now, just ensure we put everything inside a single BSS region.
+	 */
+	vmem->regions[0].guest_phys_addr = (uint64_t)(uintptr_t)&rings;
+	vmem->regions[0].userspace_addr = (uint64_t)(uintptr_t)&rings;
+	vmem->regions[0].memory_size = sizeof(rings);
+
+	if (ioctl(vhost_fd, VHOST_SET_MEM_TABLE, vmem) < 0) {
+		perror("VHOST_SET_MEM_TABLE");
+		return -1;
+	}
+
+	if (setup_vring(vhost_fd, tun_fd, call_fd, kick_fd, 0))
+		return -1;
+
+	if (setup_vring(vhost_fd, tun_fd, call_fd, kick_fd, 1))
+		return -1;
+
+	return 0;
+}
+
+
+/* Echo payload. The literal exactly fills the array, so it is not
+ * NUL-terminated; it is only ever used together with sizeof(). */
+static char ping_payload[16] = "VHOST TEST PACKT";
+
+/*
+ * Add up @nwords big-endian 16-bit words starting at @buf and return the
+ * unfolded 32-bit accumulator (fold it with csum_finish()).
+ *
+ * Words are copied out with memcpy() rather than read through a uint16_t
+ * lvalue, so @buf may point into a plain byte buffer at any alignment
+ * without strict-aliasing or misaligned-access UB. Taking const void *
+ * keeps existing uint16_t * callers compiling unchanged.
+ */
+static inline uint32_t csum_partial(const void *buf, int nwords)
+{
+	const unsigned char *p = buf;
+	uint32_t sum = 0;
+
+	for (; nwords > 0; nwords--, p += 2) {
+		uint16_t word;
+
+		memcpy(&word, p, sizeof(word));
+		sum += ntohs(word);
+	}
+	return sum;
+}
+
+/*
+ * Fold a 32-bit csum_partial() accumulator down to 16 bits, complement it,
+ * and return it in network byte order, ready to store in a checksum field.
+ */
+static inline uint16_t csum_finish(uint32_t sum)
+{
+	uint32_t folded = (sum & 0xffff) + (sum >> 16);
+
+	folded += folded >> 16;
+	return htons((uint16_t)~folded);
+}
+
+/*
+ * Write an IPv6 ICMPv6 echo request from @src to @dst, carrying echo
+ * @id / @seq and ping_payload, at @data.
+ *
+ * @data needs room for the IPv6 header, the 8-byte echo header and the
+ * payload. Returns the total packet length in bytes.
+ */
+static int create_icmp_echo(unsigned char *data, struct in6_addr *dst,
+			    struct in6_addr *src, uint16_t id, uint16_t seq)
+{
+	const int icmplen = ICMP_MINLEN + sizeof(ping_payload);
+	const int plen = sizeof(struct ip6_hdr) + icmplen;
+	struct ip6_hdr *iph = (void *)data;
+	struct icmp6_hdr *icmph = (void *)(data + sizeof(*iph));
+
+	/* IPv6 Header */
+	iph->ip6_flow = htonl((6 << 28) + /* version 6 */
+			      (0 << 20) + /* traffic class */
+			      (0 << 0));  /* flow ID  */
+	iph->ip6_nxt = IPPROTO_ICMPV6;
+	iph->ip6_plen = htons(icmplen);
+	iph->ip6_hlim = 128;
+	iph->ip6_src = *src;
+	iph->ip6_dst = *dst;
+
+	/* ICMPv6 echo request */
+	icmph->icmp6_type = ICMP6_ECHO_REQUEST;
+	icmph->icmp6_code = 0;
+	/* The checksum field must read as zero while it is being computed;
+	 * don't depend on the caller passing in a zeroed buffer. */
+	icmph->icmp6_cksum = 0;
+	icmph->icmp6_data16[0] = htons(id);	/* ID */
+	icmph->icmp6_data16[1] = htons(seq);	/* sequence */
+
+	/* Some arbitrary payload */
+	memcpy(&icmph[1], ping_payload, sizeof(ping_payload));
+
+	/*
+	 * IPv6 upper-layer checksums include a pseudo-header
+	 * for IPv6 which contains the source address, the
+	 * destination address, the upper-layer packet length
+	 * and next-header field. See RFC8200 §8.1. The
+	 * checksum is as follows:
+	 *
+	 *   checksum 32 bytes of real IPv6 header:
+	 *     src addr (16 bytes)
+	 *     dst addr (16 bytes)
+	 *   8 bytes more:
+	 *     length of ICMPv6 in bytes (be32)
+	 *     3 bytes of 0
+	 *     next header byte (IPPROTO_ICMPV6)
+	 *   Then the actual ICMPv6 bytes
+	 */
+	uint32_t sum = csum_partial((uint16_t *)&iph->ip6_src, 8);      /* 8 uint16_t */
+	sum += csum_partial((uint16_t *)&iph->ip6_dst, 8);              /* 8 uint16_t */
+
+	/* The easiest way to checksum the following 8-byte
+	 * part of the pseudo-header without horridly violating
+	 * C type aliasing rules is *not* to build it in memory
+	 * at all. We know the length fits in 16 bits so the
+	 * partial checksum of 00 00 LL LL 00 00 00 NH ends up
+	 * being just LLLL + NH.
+	 */
+	sum += IPPROTO_ICMPV6;
+	sum += icmplen;
+
+	sum += csum_partial((uint16_t *)icmph, icmplen / 2);
+	/* An odd trailing byte is summed as if padded with a zero byte. */
+	if (icmplen & 1)
+		sum += (uint32_t)data[sizeof(*iph) + icmplen - 1] << 8;
+	icmph->icmp6_cksum = csum_finish(sum);
+	return plen;
+}
+
+
+/*
+ * Return 1 if the @len bytes at @data are an IPv6 ICMPv6 echo reply from
+ * @src to @dst echoing ping_payload back, 0 otherwise.
+ */
+static int check_icmp_response(unsigned char *data, uint32_t len,
+			       struct in6_addr *dst, struct in6_addr *src)
+{
+	const uint32_t iphlen = sizeof(struct ip6_hdr);
+	struct ip6_hdr *iph = (void *)data;
+
+	/* Check the length first so nothing below reads past @len. */
+	if (len < iphlen + ICMP_MINLEN + sizeof(ping_payload))
+		return 0;
+
+	return (ntohl(iph->ip6_flow) >> 28) == 6 &&		/* IPv6 header */
+	       iph->ip6_nxt == IPPROTO_ICMPV6 &&		/* next header = ICMPv6 */
+	       !memcmp(&iph->ip6_src, src, sizeof(*src)) &&	/* source matches */
+	       !memcmp(&iph->ip6_dst, dst, sizeof(*dst)) &&	/* destination matches */
+	       data[iphlen] == ICMP6_ECHO_REPLY &&		/* ICMPv6 echo reply */
+	       !memcmp(&data[iphlen + ICMP_MINLEN], ping_payload,
+		       sizeof(ping_payload));			/* same payload echoed back */
+}
+
+/*
+ * Host <-> ring byte-order helpers: identity on little-endian hosts and a
+ * byte swap on big-endian ones, i.e. the rings are treated as always
+ * little-endian. As the comment in main() notes, the non-VERSION_1 cases
+ * would need these not to swap, so main() skips them on big-endian hosts.
+ */
+#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
+#define vio16(x) (x)
+#define vio32(x) (x)
+#define vio64(x) (x)
+#else
+#define vio16(x) __builtin_bswap16(x)
+#define vio32(x) __builtin_bswap32(x)
+#define vio64(x) __builtin_bswap64(x)
+#endif
+
+
+/*
+ * Run one vhost-net <-> tun permutation: queue a single ICMPv6 echo
+ * request to the tun device's link-local address on the vhost TX ring and
+ * wait for the kernel's echo reply on the vhost RX ring.
+ *
+ * @vnet_hdr_sz: tun-side virtio_net_hdr size (0 = no IFF_VNET_HDR)
+ * @pi:          keep the tun_pi header (no IFF_NO_PI)
+ * @xdp:         skip TUNSETSNDBUF so that the XDP path is used
+ * @features:    vhost feature bits to negotiate
+ *
+ * Returns 0 on success, KSFT_SKIP if the environment can't run this
+ * permutation, or -1 on failure. Every path goes through the cleanup at
+ * "err", which closes all fds and prints the result line.
+ */
+int test_vhost(int vnet_hdr_sz, int pi, int xdp, uint64_t features)
+{
+	int call_fd = eventfd(0, EFD_CLOEXEC|EFD_NONBLOCK);
+	int kick_fd = eventfd(0, EFD_CLOEXEC|EFD_NONBLOCK);
+	int vhost_fd = open("/dev/vhost-net", O_RDWR);
+	int tun_fd = -1;
+	int ret = KSFT_SKIP;
+
+	if (call_fd < 0 || kick_fd < 0 || vhost_fd < 0)
+		goto err;
+
+	memset(rings, 0, sizeof(rings));
+
+	/* Pick up the link-local address that the kernel
+	 * assigns to the tun device. */
+	struct in6_addr tun_addr;
+	tun_fd = open_tun(vnet_hdr_sz, pi, &tun_addr);
+	if (tun_fd < 0)
+		goto err;
+
+	int pi_offset = -1;
+	int data_offset = vnet_hdr_sz;
+
+	/* The tun device puts PI *first*, before the vnet hdr */
+	if (pi) {
+		pi_offset = 0;
+		data_offset += sizeof(struct tun_pi);
+	}
+
+	/* If vhost is adding a vnet hdr it comes before all else */
+	if (features & (1ULL << VHOST_NET_F_VIRTIO_NET_HDR)) {
+		int vhost_hdr_sz = (features & ((1ULL << VIRTIO_NET_F_MRG_RXBUF) |
+						(1ULL << VIRTIO_F_VERSION_1))) ?
+			sizeof(struct virtio_net_hdr_mrg_rxbuf) :
+			sizeof(struct virtio_net_hdr);
+
+		data_offset += vhost_hdr_sz;
+		if (pi_offset != -1)
+			pi_offset += vhost_hdr_sz;
+	}
+
+	if (!xdp) {
+		int sndbuf = RING_SIZE * 2048;
+		if (ioctl(tun_fd, TUNSETSNDBUF, &sndbuf) < 0) {
+			perror("TUNSETSNDBUF");
+			ret = -1;
+			goto err;
+		}
+	}
+
+	ret = setup_vhost(vhost_fd, tun_fd, call_fd, kick_fd, features);
+	if (ret)
+		goto err;
+
+	/* A fake link-local address (fe80::1) for the userspace end */
+	struct in6_addr local_addr = { 0 };
+	local_addr.s6_addr16[0] = htons(0xfe80);
+	local_addr.s6_addr16[7] = htons(1);
+
+	/* Set up RX and TX descriptors; the latter with ping packets ready to
+	 * send to the kernel, but don't actually send them yet. */
+	for (int i = 0; i < RING_SIZE; i++) {
+		struct pkt_buf *pkt = &rings[1].pkts[i];
+		if (pi_offset != -1) {
+			/* Named tpi so it doesn't shadow the 'pi' parameter. */
+			struct tun_pi *tpi = (void *)&pkt->data[pi_offset];
+			tpi->proto = htons(ETH_P_IPV6);
+		}
+		int plen = create_icmp_echo(&pkt->data[data_offset], &tun_addr,
+					    &local_addr, 0x4747, i);
+
+		rings[1].desc[i].addr = vio64((uint64_t)(uintptr_t)pkt);
+		rings[1].desc[i].len = vio32(plen + data_offset);
+		rings[1].avail_ring[i] = vio16(i);
+
+		pkt = &rings[0].pkts[i];
+		rings[0].desc[i].addr = vio64((uint64_t)(uintptr_t)pkt);
+		rings[0].desc[i].len = vio32(sizeof(*pkt));
+		rings[0].desc[i].flags = vio16(VRING_DESC_F_WRITE);
+		rings[0].avail_ring[i] = vio16(i);
+	}
+	barrier();
+	/* Publish only the first TX descriptor: one echo request. */
+	rings[1].avail.idx = vio16(1);
+
+	uint16_t rx_seen_used = 0;
+	/* Linux select() writes the remaining time back into tv, so this is
+	 * one budget shared by all iterations, not a per-iteration timeout. */
+	struct timeval tv = { 1, 0 };
+	while (1) {
+		fd_set rfds;
+
+		FD_ZERO(&rfds);
+		FD_SET(call_fd, &rfds);
+
+		/* Keep all RING_SIZE RX buffers posted past what we consumed. */
+		rings[0].avail.idx = vio16(rx_seen_used + RING_SIZE);
+		barrier();
+		eventfd_write(kick_fd, 1);
+
+		if (select(call_fd + 1, &rfds, NULL, NULL, &tv) <= 0) {
+			ret = -1;
+			goto err;
+		}
+
+		uint16_t rx_used_idx = vio16(rings[0].used.idx);
+		barrier();
+
+		while (rx_used_idx != rx_seen_used) {
+			uint32_t desc = vio32(rings[0].used_ring[RING_MASK(rx_seen_used)].id);
+			uint32_t len  = vio32(rings[0].used_ring[RING_MASK(rx_seen_used)].len);
+
+			/* Malformed used entry: fail, but still release the fds. */
+			if (desc >= RING_SIZE || len < data_offset) {
+				ret = -1;
+				goto err;
+			}
+
+			uint64_t addr = vio64(rings[0].desc[desc].addr);
+			if (!addr) {
+				ret = -1;
+				goto err;
+			}
+
+			if (len > data_offset &&
+			    (pi_offset == -1 ||
+			     ((struct tun_pi *)(uintptr_t)(addr + pi_offset))->proto == htons(ETH_P_IPV6)) &&
+			    check_icmp_response((void *)(uintptr_t)(addr + data_offset), len - data_offset,
+						&local_addr, &tun_addr)) {
+				ret = 0;
+				goto err;
+			}
+
+			/* Give the same buffer back; avail ring entries are 16-bit. */
+			rings[0].avail_ring[RING_MASK(rx_seen_used++)] = vio16(desc);
+		}
+		barrier();
+
+		uint64_t ev_val;
+		eventfd_read(call_fd, &ev_val);
+	}
+
+ err:
+	if (call_fd != -1)
+		close(call_fd);
+	if (kick_fd != -1)
+		close(kick_fd);
+	if (vhost_fd != -1)
+		close(vhost_fd);
+	if (tun_fd != -1)
+		close(tun_fd);
+
+	printf("TEST: (hdr %d, xdp %d, pi %d, features %llx) RESULT: %d\n",
+	       vnet_hdr_sz, xdp, pi, (unsigned long long)features, ret);
+	return ret;
+}
+
+/* For iterating over all permutations. */
+#define VHDR_LEN_BITS	3	/* Tun vhdr length selection */
+#define XDP_BIT		4	/* Don't TUNSETSNDBUF, so we use XDP */
+#define PI_BIT		8	/* Don't set IFF_NO_PI */
+#define VIRTIO_V1_BIT	16	/* Use VIRTIO_F_VERSION_1 feature */
+#define VHOST_HDR_BIT	32	/* Use VHOST_NET_F_VIRTIO_NET_HDR */
+#define NR_PERMUTATIONS	(VHOST_HDR_BIT << 1)	/* every combination of the bits above */
+
+/* Tun-side vnet header sizes, indexed by (i & VHDR_LEN_BITS). */
+unsigned int tun_vhdr_lens[] = { 0, 10, 12, 20 };
+
+/*
+ * Run test_vhost() over every permutation of the option bits and exit
+ * with the worst outcome in kselftest terms: KSFT_FAIL if any permutation
+ * failed, otherwise KSFT_PASS if any passed, otherwise KSFT_SKIP.
+ */
+int main(void)
+{
+	int result = KSFT_SKIP;
+	int i, ret;
+
+	for (i = 0; i < NR_PERMUTATIONS; i++) {
+		uint64_t features = 0;
+
+		if (i & VIRTIO_V1_BIT)
+			features |= (1ULL << VIRTIO_F_VERSION_1);
+#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+		else
+			continue; /* We'd need vio16 et al not to byteswap */
+#endif
+
+		if (i & VHOST_HDR_BIT) {
+			features |= (1ULL << VHOST_NET_F_VIRTIO_NET_HDR);
+
+			/* Even though the test actually passes at the time of
+			 * writing, don't bother to try asking tun *and* vhost
+			 * both to handle a virtio_net_hdr at the same time.
+			 * That's just silly.  */
+			if (i & VHDR_LEN_BITS)
+				continue;
+		}
+
+		ret = test_vhost(tun_vhdr_lens[i & VHDR_LEN_BITS],
+				 !!(i & PI_BIT), !!(i & XDP_BIT), features);
+		/* -1 (fail) < 0 (pass) < KSFT_SKIP, so keep the minimum. */
+		if (ret < result)
+			result = ret;
+	}
+
+	/* test_vhost() reports failure as -1; exit with KSFT_FAIL instead of
+	 * the non-standard status 255 that -1 would produce. */
+	return result < 0 ? KSFT_FAIL : result;
+}
Jason Wang June 24, 2021, 6:18 a.m. UTC | #6
在 2021/6/23 下午9:52, David Woodhouse 写道:
> On Wed, 2021-06-23 at 11:45 +0800, Jason Wang wrote:
>> As replied in previous version, it would be better if we can unify
>> similar logic in tun_get_user().
> So that ends up looking something like this (incremental).
>
> Note the '/* XXX: frags && */' part in tun_skb_set_protocol(), because
> the 'frags &&' was there in tun_get_user() and it wasn't obvious to me
> whether I should be lifting that out as a separate argument to
> tun_skb_set_protocol() or if there's a better way.
>
> Either way, in my judgement this is less suitable for a stable fix and
> more appropriate for a follow-on cleanup. But I don't feel that
> strongly; I'm more than happy for you to overrule me on that.
> Especially if you fix the above XXX part while you're at it :)


By simply adding a boolean "pull" in tun_skb_set_protocol()?

Thanks


>
> I tested this with vhost-net and !IFF_NO_PI, and TX works. RX is still
> hosed on the vhost-net side, for the same reason that a bunch of test
> cases were already listed in #if 0, but I'll address that in a separate
> email. It's not part of *this* patch.
>
> --- a/drivers/net/tun.c
> +++ b/drivers/net/tun.c
> @@ -1641,6 +1641,40 @@ static struct sk_buff *tun_build_skb(struct tun_struct *tun,
>   	return NULL;
>   }
>   
> +static int tun_skb_set_protocol(struct tun_struct *tun, struct sk_buff *skb,
> +				__be16 pi_proto)
> +{
> +	switch (tun->flags & TUN_TYPE_MASK) {
> +	case IFF_TUN:
> +		if (tun->flags & IFF_NO_PI) {
> +			u8 ip_version = skb->len ? (skb->data[0] >> 4) : 0;
> +
> +			switch (ip_version) {
> +			case 4:
> +				pi_proto = htons(ETH_P_IP);
> +				break;
> +			case 6:
> +				pi_proto = htons(ETH_P_IPV6);
> +				break;
> +			default:
> +				return -EINVAL;
> +			}
> +		}
> +
> +		skb_reset_mac_header(skb);
> +		skb->protocol = pi_proto;
> +		skb->dev = tun->dev;
> +		break;
> +	case IFF_TAP:
> +		if (/* XXX frags && */!pskb_may_pull(skb, ETH_HLEN))
> +			return -ENOMEM;
> +
> +		skb->protocol = eth_type_trans(skb, tun->dev);
> +		break;
> +	}
> +	return 0;
> +}
> +
>   /* Get packet from user space buffer */
>   static ssize_t tun_get_user(struct tun_struct *tun, struct tun_file *tfile,
>   			    void *msg_control, struct iov_iter *from,
> @@ -1784,37 +1818,9 @@ static ssize_t tun_get_user(struct tun_struct *tun, struct tun_file *tfile,
>   		return -EINVAL;
>   	}
>   
> -	switch (tun->flags & TUN_TYPE_MASK) {
> -	case IFF_TUN:
> -		if (tun->flags & IFF_NO_PI) {
> -			u8 ip_version = skb->len ? (skb->data[0] >> 4) : 0;
> -
> -			switch (ip_version) {
> -			case 4:
> -				pi.proto = htons(ETH_P_IP);
> -				break;
> -			case 6:
> -				pi.proto = htons(ETH_P_IPV6);
> -				break;
> -			default:
> -				atomic_long_inc(&tun->dev->rx_dropped);
> -				kfree_skb(skb);
> -				return -EINVAL;
> -			}
> -		}
> -
> -		skb_reset_mac_header(skb);
> -		skb->protocol = pi.proto;
> -		skb->dev = tun->dev;
> -		break;
> -	case IFF_TAP:
> -		if (frags && !pskb_may_pull(skb, ETH_HLEN)) {
> -			err = -ENOMEM;
> -			goto drop;
> -		}
> -		skb->protocol = eth_type_trans(skb, tun->dev);
> -		break;
> -	}
> +	err = tun_skb_set_protocol(tun, skb, pi.proto);
> +	if (err)
> +		goto drop;
>   
>   	/* copy skb_ubuf_info for callback when skb has no error */
>   	if (zerocopy) {
> @@ -2334,8 +2340,10 @@ static int tun_xdp_one(struct tun_struct *tun,
>   	struct virtio_net_hdr *gso = NULL;
>   	struct bpf_prog *xdp_prog;
>   	struct sk_buff *skb = NULL;
> +	__be16 proto = 0;
>   	u32 rxhash = 0, act;
>   	int buflen = hdr->buflen;
> +	int reservelen = xdp->data - xdp->data_hard_start;
>   	int err = 0;
>   	bool skb_xdp = false;
>   	struct page *page;
> @@ -2343,6 +2351,17 @@ static int tun_xdp_one(struct tun_struct *tun,
>   	if (tun->flags & IFF_VNET_HDR)
>   		gso = &hdr->gso;
>   
> +	if (!(tun->flags & IFF_NO_PI)) {
> +		struct tun_pi *pi = xdp->data;
> +		if (datasize < sizeof(*pi)) {
> +			atomic_long_inc(&tun->rx_frame_errors);
> +			return  -EINVAL;
> +		}
> +		proto = pi->proto;
> +		reservelen += sizeof(*pi);
> +		datasize -= sizeof(*pi);
> +	}
> +
>   	xdp_prog = rcu_dereference(tun->xdp_prog);
>   	if (xdp_prog) {
>   		if (gso && gso->gso_type) {
> @@ -2388,8 +2407,8 @@ static int tun_xdp_one(struct tun_struct *tun,
>   		goto out;
>   	}
>   
> -	skb_reserve(skb, xdp->data - xdp->data_hard_start);
> -	skb_put(skb, xdp->data_end - xdp->data);
> +	skb_reserve(skb, reservelen);
> +	skb_put(skb, datasize);
>   
>   	if (gso && virtio_net_hdr_to_skb(skb, gso, tun_is_little_endian(tun))) {
>   		atomic_long_inc(&tun->rx_frame_errors);
> @@ -2397,48 +2416,12 @@ static int tun_xdp_one(struct tun_struct *tun,
>   		err = -EINVAL;
>   		goto out;
>   	}
> -	switch (tun->flags & TUN_TYPE_MASK) {
> -	case IFF_TUN:
> -		if (tun->flags & IFF_NO_PI) {
> -			u8 ip_version = skb->len ? (skb->data[0] >> 4) : 0;
>   
> -			switch (ip_version) {
> -			case 4:
> -				skb->protocol = htons(ETH_P_IP);
> -				break;
> -			case 6:
> -				skb->protocol = htons(ETH_P_IPV6);
> -				break;
> -			default:
> -				atomic_long_inc(&tun->dev->rx_dropped);
> -				kfree_skb(skb);
> -				err = -EINVAL;
> -				goto out;
> -			}
> -		} else {
> -			struct tun_pi *pi = (struct tun_pi *)skb->data;
> -			if (!pskb_may_pull(skb, sizeof(*pi))) {
> -				atomic_long_inc(&tun->dev->rx_dropped);
> -				kfree_skb(skb);
> -				err = -ENOMEM;
> -				goto out;
> -			}
> -			skb_pull_inline(skb, sizeof(*pi));
> -			skb->protocol = pi->proto;
> -		}
> -
> -		skb_reset_mac_header(skb);
> -		skb->dev = tun->dev;
> -		break;
> -	case IFF_TAP:
> -		if (!pskb_may_pull(skb, ETH_HLEN)) {
> -			atomic_long_inc(&tun->dev->rx_dropped);
> -			kfree_skb(skb);
> -			err = -ENOMEM;
> -			goto out;
> -		}
> -		skb->protocol = eth_type_trans(skb, tun->dev);
> -		break;
> +	err = tun_skb_set_protocol(tun, skb, proto);
> +	if (err) {
> +		atomic_long_inc(&tun->dev->rx_dropped);
> +		kfree_skb(skb);
> +		goto out;
>   	}
>   
>   	skb_reset_network_header(skb);
>
Jason Wang June 24, 2021, 6:37 a.m. UTC | #7
在 2021/6/24 上午6:52, David Woodhouse 写道:
> On Wed, 2021-06-23 at 18:31 +0100, David Woodhouse wrote:
>> Joy... that's wrong because when tun does both the PI and the vnet
>> headers, the PI header comes *first*. When tun does only PI and vhost
>> does the vnet headers, they come in the other order.
>>
>> Will fix (and adjust the test cases to cope).
>
> I got this far, pushed to
> https://git.infradead.org/users/dwmw2/linux.git/shortlog/refs/heads/vhost-net
>
> All the test cases are now passing. I don't guarantee I haven't
> actually broken qemu and IFF_TAP mode though, mind you :)


No problem, but it would be easier for me if you can post another 
version of the series.


>
> I'll need to refactor the intermediate commits a little so I won't
> repost the series quite yet, but figured I should at least show what I
> have for comments, as my day ends and yours begins.
>
>
> As discussed, I expanded tun_get_socket()/tap_get_socket() to return
> the actual header length instead of letting vhost make wild guesses.


This probably won't work since we have TUNSETVNETHDRSZ.

I agree the vhost code is tricky since it assumes only two kinds of 
header length.

But that is basically how it has worked for the past 10 years. It relies on 
userspace (QEMU) to coordinate it with TUN/TAP through 
TUNSETVNETHDRSZ during feature negotiation.


> Note that in doing so, I have made tun_get_socket() return -ENOTCONN if
> the tun fd *isn't* actually attached (TUNSETIFF) to a real device yet.


Any reason for doing this? Note that the socket is loosely coupled with 
the networking device.


>
> I moved the sanity check back to tun/tap instead of doing it in
> vhost_net_build_xdp(), because the latter has no clue about the tun PI
> header and doesn't know *where* the virtio header is.


Right, that deserves a separate patch.


>
>

[...]


>   	mutex_unlock(&n->dev.mutex);
> diff --git a/include/linux/if_tap.h b/include/linux/if_tap.h
> index 915a187cfabd..b460ba98f34e 100644
> --- a/include/linux/if_tap.h
> +++ b/include/linux/if_tap.h
> @@ -3,14 +3,14 @@
>   #define _LINUX_IF_TAP_H_
>   
>   #if IS_ENABLED(CONFIG_TAP)
> -struct socket *tap_get_socket(struct file *);
> +struct socket *tap_get_socket(struct file *, size_t *);
>   struct ptr_ring *tap_get_ptr_ring(struct file *file);
>   #else
>   #include <linux/err.h>
>   #include <linux/errno.h>
>   struct file;
>   struct socket;
> -static inline struct socket *tap_get_socket(struct file *f)
> +static inline struct socket *tap_get_socket(struct file *f, size_t *)
>   {
>   	return ERR_PTR(-EINVAL);
>   }
> diff --git a/include/linux/if_tun.h b/include/linux/if_tun.h
> index 2a7660843444..8d78b6bbc228 100644
> --- a/include/linux/if_tun.h
> +++ b/include/linux/if_tun.h
> @@ -21,11 +21,10 @@ struct tun_msg_ctl {
>   
>   struct tun_xdp_hdr {
>   	int buflen;
> -	struct virtio_net_hdr gso;


Any reason for doing this? I mean it can work, but we need to limit the 
changes that are unrelated to the fixes.

Thanks
David Woodhouse June 24, 2021, 7:05 a.m. UTC | #8
On Thu, 2021-06-24 at 14:18 +0800, Jason Wang wrote:
> 在 2021/6/23 下午9:52, David Woodhouse 写道:
> > On Wed, 2021-06-23 at 11:45 +0800, Jason Wang wrote:
> > > As replied in previous version, it would be better if we can unify
> > > similar logic in tun_get_user().
> > 
> > So that ends up looking something like this (incremental).
> > 
> > Note the '/* XXX: frags && */' part in tun_skb_set_protocol(), because
> > the 'frags &&' was there in tun_get_user() and it wasn't obvious to me
> > whether I should be lifting that out as a separate argument to
> > tun_skb_set_protocol() or if there's a better way.
> > 
> > Either way, in my judgement this is less suitable for a stable fix and
> > more appropriate for a follow-on cleanup. But I don't feel that
> > strongly; I'm more than happy for you to overrule me on that.
> > Especially if you fix the above XXX part while you're at it :)
> 
> 
> By simply adding a boolean "pull" in tun_skb_set_protocol()?

Sure; thanks. It's been a few years since I really played with skb
handling; I was half hoping for a simpler "you don't need to
because..." answer, but that works :)
David Woodhouse June 24, 2021, 7:23 a.m. UTC | #9
On Thu, 2021-06-24 at 14:37 +0800, Jason Wang wrote:
> 在 2021/6/24 上午6:52, David Woodhouse 写道:
> > On Wed, 2021-06-23 at 18:31 +0100, David Woodhouse wrote:
> > > Joy... that's wrong because when tun does both the PI and the vnet
> > > headers, the PI header comes *first*. When tun does only PI and vhost
> > > does the vnet headers, they come in the other order.
> > > 
> > > Will fix (and adjust the test cases to cope).
> > 
> > I got this far, pushed to
> > https://git.infradead.org/users/dwmw2/linux.git/shortlog/refs/heads/vhost-net
> > 
> > All the test cases are now passing. I don't guarantee I haven't
> > actually broken qemu and IFF_TAP mode though, mind you :)
> 
> 
> No problem, but it would be easier for me if you can post another 
> version of the series.

Ack; I'm reworking it now into a saner series. All three of my initial
simple fixes ended up with more changes once I expanded the test cases
to cover more permutations of PI/XDP/headers :)

> > As discussed, I expanded tun_get_socket()/tap_get_socket() to return
> > the actual header length instead of letting vhost make wild guesses.
> 
> 
> This probably won't work since we had TUNSETVNETHDRSZ.

Or indeed IFF_NO_PI.

> I agree the vhost codes is tricky since it assumes only two kinds of the 
> hdr length.
> 
> But it was basically how it works for the past 10 years. It depends on 
> the userspace (Qemu) to coordinate it with the TUN/TAP through 
> TUNSETVNETHDRSZ during the feature negotiation.

I think that in any given situation, the kernel should either work
correctly, or gracefully refuse to set it up.

My patch set will make it work correctly for all the permutations I've
looked at. I would accept an answer of "screw that, just make
tun_get_socket() return failure if IFF_NO_PI isn't set", for example.

> > Note that in doing so, I have made tun_get_socket() return -ENOTCONN if
> > the tun fd *isn't* actually attached (TUNSETIFF) to a real device yet.
> 
> Any reason for doing this? Note that the socket is loosely coupled with 
> the networking device.

Because to determine the sock_hlen to return, it needs to look at the
tun->flags and tun->vnet_hdr_sz fields. And if there isn't an actual tun
device attached, it can't.

> 
> > 
> > I moved the sanity check back to tun/tap instead of doing it in
> > vhost_net_build_xdp(), because the latter has no clue about the tun PI
> > header and doesn't know *where* the virtio header is.
> 
> 
> Right, the deserves a separate patch.

Yep, in my tree it has one, but it's a bit mixed in with other fixes
until I do that refactoring. 

> > diff --git a/include/linux/if_tun.h b/include/linux/if_tun.h
> > index 2a7660843444..8d78b6bbc228 100644
> > --- a/include/linux/if_tun.h
> > +++ b/include/linux/if_tun.h
> > @@ -21,11 +21,10 @@ struct tun_msg_ctl {
> >   
> >   struct tun_xdp_hdr {
> >   	int buflen;
> > -	struct virtio_net_hdr gso;
> 
> 
> Any reason for doing this? I meant it can work but we need limit the 
> changes that is unrelated to the fixes.

That's part of the patch that moves the sanity check back to tun/tap.
As I said it needs a little reworking, so it currently contains a
little bit of cleanup to previous code in tun_xdp_one(), but it looks
like this. The bit in drivers/vhost/net.c is obviously removing code
that I'd made conditional in a previous patch, so that will change
somewhat as I rework the series and drop the original patch.

From 2a0080f37244ec6dac8fb3e8330f9153a4373cfd Mon Sep 17 00:00:00 2001
From: David Woodhouse <dwmw@amazon.co.uk>
Date: Wed, 23 Jun 2021 23:32:00 +0100
Subject: [PATCH 10/10] net: remove virtio_net_hdr from struct tun_xdp_hdr

The tun device puts its struct tun_pi *before* the virtio_net_hdr, which
significantly complicates letting vhost validate it. Just let tap and
tun validate it for themselves, as they do in the non-XDP case anyway.

Signed-off-by: David Woodhouse <dwmw@amazon.co.uk>
---
 drivers/net/tap.c      | 25 ++++++++++++++++++++++---
 drivers/net/tun.c      | 34 ++++++++++++++++++++++++----------
 drivers/vhost/net.c    | 15 +--------------
 include/linux/if_tun.h |  1 -
 4 files changed, 47 insertions(+), 28 deletions(-)

diff --git a/drivers/net/tap.c b/drivers/net/tap.c
index 2170a0d3d34c..d1b1f1de374e 100644
--- a/drivers/net/tap.c
+++ b/drivers/net/tap.c
@@ -1132,16 +1132,35 @@ static const struct file_operations tap_fops = {
 static int tap_get_user_xdp(struct tap_queue *q, struct xdp_buff *xdp)
 {
 	struct tun_xdp_hdr *hdr = xdp->data_hard_start;
-	struct virtio_net_hdr *gso = &hdr->gso;
+	struct virtio_net_hdr *gso = NULL;
 	int buflen = hdr->buflen;
 	int vnet_hdr_len = 0;
 	struct tap_dev *tap;
 	struct sk_buff *skb;
 	int err, depth;
 
-	if (q->flags & IFF_VNET_HDR)
+	if (q->flags & IFF_VNET_HDR) {
 		vnet_hdr_len = READ_ONCE(q->vnet_hdr_sz);
+		if (xdp->data != xdp->data_hard_start + sizeof(*hdr) + vnet_hdr_len) {
+			err = -EINVAL;
+			goto err;
+		}
+
+		gso = (void *)&hdr[1];
 
+		if ((gso->flags & VIRTIO_NET_HDR_F_NEEDS_CSUM) &&
+		     tap16_to_cpu(q, gso->csum_start) +
+		     tap16_to_cpu(q, gso->csum_offset) + 2 >
+			     tap16_to_cpu(q, gso->hdr_len))
+			gso->hdr_len = cpu_to_tap16(q,
+				 tap16_to_cpu(q, gso->csum_start) +
+				 tap16_to_cpu(q, gso->csum_offset) + 2);
+
+		if (tap16_to_cpu(q, gso->hdr_len) > xdp->data_end - xdp->data) {
+			err = -EINVAL;
+			goto err;
+		}
+	}
 	skb = build_skb(xdp->data_hard_start, buflen);
 	if (!skb) {
 		err = -ENOMEM;
@@ -1155,7 +1174,7 @@ static int tap_get_user_xdp(struct tap_queue *q, struct xdp_buff *xdp)
 	skb_reset_mac_header(skb);
 	skb->protocol = eth_hdr(skb)->h_proto;
 
-	if (vnet_hdr_len) {
+	if (gso) {
 		err = virtio_net_hdr_to_skb(skb, gso, tap_is_little_endian(q));
 		if (err)
 			goto err_kfree;
diff --git a/drivers/net/tun.c b/drivers/net/tun.c
index 69f6ce87b109..72f8a04f493b 100644
--- a/drivers/net/tun.c
+++ b/drivers/net/tun.c
@@ -2337,29 +2337,43 @@ static int tun_xdp_one(struct tun_struct *tun,
 {
 	unsigned int datasize = xdp->data_end - xdp->data;
 	struct tun_xdp_hdr *hdr = xdp->data_hard_start;
+	void *tun_hdr = &hdr[1];
 	struct virtio_net_hdr *gso = NULL;
 	struct bpf_prog *xdp_prog;
 	struct sk_buff *skb = NULL;
 	__be16 proto = 0;
 	u32 rxhash = 0, act;
 	int buflen = hdr->buflen;
-	int reservelen = xdp->data - xdp->data_hard_start;
 	int err = 0;
 	bool skb_xdp = false;
 	struct page *page;
 
-	if (tun->flags & IFF_VNET_HDR)
-		gso = &hdr->gso;
-
 	if (!(tun->flags & IFF_NO_PI)) {
-		struct tun_pi *pi = xdp->data;
-		if (datasize < sizeof(*pi)) {
+		struct tun_pi *pi = tun_hdr;
+		tun_hdr += sizeof(*pi);
+
+		if (tun_hdr > xdp->data) {
 			atomic_long_inc(&tun->rx_frame_errors);
-			return  -EINVAL;
+			return -EINVAL;
 		}
 		proto = pi->proto;
-		reservelen += sizeof(*pi);
-		datasize -= sizeof(*pi);
+	}
+
+	if (tun->flags & IFF_VNET_HDR) {
+		gso = tun_hdr;
+		tun_hdr += sizeof(*gso);
+
+		if (tun_hdr > xdp->data) {
+			atomic_long_inc(&tun->rx_frame_errors);
+			return -EINVAL;
+		}
+
+		if ((gso->flags & VIRTIO_NET_HDR_F_NEEDS_CSUM) &&
+		    tun16_to_cpu(tun, gso->csum_start) + tun16_to_cpu(tun, gso->csum_offset) + 2 > tun16_to_cpu(tun, gso->hdr_len))
+			gso->hdr_len = cpu_to_tun16(tun, tun16_to_cpu(tun, gso->csum_start) + tun16_to_cpu(tun, gso->csum_offset) + 2);
+
+		if (tun16_to_cpu(tun, gso->hdr_len) > datasize)
+			return -EINVAL;
 	}
 
 	xdp_prog = rcu_dereference(tun->xdp_prog);
@@ -2407,7 +2421,7 @@ static int tun_xdp_one(struct tun_struct *tun,
 		goto out;
 	}
 
-	skb_reserve(skb, reservelen);
+	skb_reserve(skb, xdp->data - xdp->data_hard_start);
 	skb_put(skb, datasize);
 
 	if (gso && virtio_net_hdr_to_skb(skb, gso, tun_is_little_endian(tun))) {
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index e88cc18d079f..d9491c620a9c 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -716,26 +716,13 @@ static int vhost_net_build_xdp(struct vhost_net_virtqueue *nvq,
 	buf = (char *)page_address(alloc_frag->page) + alloc_frag->offset;
 	hdr = buf;
 	if (sock_hlen) {
-		struct virtio_net_hdr *gso = &hdr->gso;
-
 		copied = copy_page_from_iter(alloc_frag->page,
 					     alloc_frag->offset +
-					     offsetof(struct tun_xdp_hdr, gso),
+					     sizeof(struct tun_xdp_hdr),
 					     sock_hlen, from);
 		if (copied != sock_hlen)
 			return -EFAULT;
 
-		if ((gso->flags & VIRTIO_NET_HDR_F_NEEDS_CSUM) &&
-		    vhost16_to_cpu(vq, gso->csum_start) +
-		    vhost16_to_cpu(vq, gso->csum_offset) + 2 >
-		    vhost16_to_cpu(vq, gso->hdr_len)) {
-			gso->hdr_len = cpu_to_vhost16(vq,
-						      vhost16_to_cpu(vq, gso->csum_start) +
-						      vhost16_to_cpu(vq, gso->csum_offset) + 2);
-
-			if (vhost16_to_cpu(vq, gso->hdr_len) > len)
-				return -EINVAL;
-		}
 		len -= sock_hlen;
 	}
 
diff --git a/include/linux/if_tun.h b/include/linux/if_tun.h
index 8a7debd3f663..8d78b6bbc228 100644
--- a/include/linux/if_tun.h
+++ b/include/linux/if_tun.h
@@ -21,7 +21,6 @@ struct tun_msg_ctl {
 
 struct tun_xdp_hdr {
 	int buflen;
-	struct virtio_net_hdr gso;
 };
 
 #if defined(CONFIG_TUN) || defined(CONFIG_TUN_MODULE)
diff mbox series

Patch

diff --git a/drivers/net/tun.c b/drivers/net/tun.c
index 4cf38be26dc9..f812dcdc640e 100644
--- a/drivers/net/tun.c
+++ b/drivers/net/tun.c
@@ -2394,8 +2394,50 @@  static int tun_xdp_one(struct tun_struct *tun,
 		err = -EINVAL;
 		goto out;
 	}
+	switch (tun->flags & TUN_TYPE_MASK) {
+	case IFF_TUN:
+		if (tun->flags & IFF_NO_PI) {
+			u8 ip_version = skb->len ? (skb->data[0] >> 4) : 0;
+
+			switch (ip_version) {
+			case 4:
+				skb->protocol = htons(ETH_P_IP);
+				break;
+			case 6:
+				skb->protocol = htons(ETH_P_IPV6);
+				break;
+			default:
+				atomic_long_inc(&tun->dev->rx_dropped);
+				kfree_skb(skb);
+				err = -EINVAL;
+				goto out;
+			}
+		} else {
+			struct tun_pi *pi = (struct tun_pi *)skb->data;
+			if (!pskb_may_pull(skb, sizeof(*pi))) {
+				atomic_long_inc(&tun->dev->rx_dropped);
+				kfree_skb(skb);
+				err = -ENOMEM;
+				goto out;
+			}
+			skb_pull_inline(skb, sizeof(*pi));
+			skb->protocol = pi->proto;
+		}
+
+		skb_reset_mac_header(skb);
+		skb->dev = tun->dev;
+		break;
+	case IFF_TAP:
+		if (!pskb_may_pull(skb, ETH_HLEN)) {
+			atomic_long_inc(&tun->dev->rx_dropped);
+			kfree_skb(skb);
+			err = -ENOMEM;
+			goto out;
+		}
+		skb->protocol = eth_type_trans(skb, tun->dev);
+		break;
+	}
 
-	skb->protocol = eth_type_trans(skb, tun->dev);
 	skb_reset_network_header(skb);
 	skb_probe_transport_header(skb);
 	skb_record_rx_queue(skb, tfile->queue_index);