--- a/drivers/vhost/vsock.c
+++ b/drivers/vhost/vsock.c
@@ -8,6 +8,7 @@
*/
#include <linux/miscdevice.h>
#include <linux/atomic.h>
+#include <linux/errqueue.h>
#include <linux/module.h>
#include <linux/mutex.h>
#include <linux/vmalloc.h>
@@ -32,7 +33,8 @@
enum {
VHOST_VSOCK_FEATURES = VHOST_FEATURES |
(1ULL << VIRTIO_F_ACCESS_PLATFORM) |
- (1ULL << VIRTIO_VSOCK_F_SEQPACKET)
+ (1ULL << VIRTIO_VSOCK_F_SEQPACKET) |
+ (1ULL << VIRTIO_VSOCK_F_DGRAM)
};
enum {
@@ -56,6 +58,7 @@ struct vhost_vsock {
atomic_t queued_replies;
u32 guest_cid;
+ bool dgram_allow;
bool seqpacket_allow;
};
@@ -86,6 +89,35 @@ static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
return NULL;
}
+/* Claims ownership of the skb; do not free the skb after calling! */
+static void
+vhost_transport_error(struct sk_buff *skb, int err)
+{
+ struct sock_exterr_skb *serr;
+ struct sock *sk = skb->sk;
+ struct sk_buff *clone;
+
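+	/* The extended error is stored in the skb's control buffer (skb->cb). */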
+ serr = SKB_EXT_ERR(skb);
+ memset(serr, 0, sizeof(*serr));
+ serr->ee.ee_errno = err;
+ serr->ee.ee_origin = SO_EE_ORIGIN_NONE;
+
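+	/* Queue a clone on the error queue; the original skb is freed below. */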
+ clone = skb_clone(skb, GFP_KERNEL);
+ if (!clone)
+ goto out;
+
+ if (sock_queue_err_skb(sk, clone))
+ kfree_skb(clone);
+
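+	/* Raise the error on the socket too; this wakes pollers with EPOLLERR. */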
+ sk->sk_err = err;
+ sk_error_report(sk);
+out:
+ kfree_skb(skb);
+}
+
static void
vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
struct vhost_virtqueue *vq)
@@ -162,9 +191,17 @@ vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
hdr = virtio_vsock_hdr(skb);
/* If the packet is greater than the space available in the
- * buffer, we split it using multiple buffers.
+ * buffer, we split it using multiple buffers for connectible
+ * sockets and drop the packet for datagram sockets.
*/
if (payload_len > iov_len - sizeof(*hdr)) {
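+			/* Datagrams are never split across buffers; drop the
+			 * packet and notify the sending socket instead. */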
+ if (le16_to_cpu(hdr->type) == VIRTIO_VSOCK_TYPE_DGRAM) {
+ vhost_transport_error(skb, EHOSTUNREACH);
+ continue;
+ }
+
payload_len = iov_len - sizeof(*hdr);
/* As we are copying pieces of large packet's buffer to
@@ -403,6 +438,23 @@ static bool vhost_transport_msgzerocopy_allow(void)
return true;
}
+static bool vhost_transport_dgram_allow(u32 cid, u32 port)
+{
+ struct vhost_vsock *vsock;
+ bool dgram_allow = false;
+
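+	/* RCU guards the lookup against a concurrent vhost_vsock release. */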
+ rcu_read_lock();
+ vsock = vhost_vsock_get(cid);
+
+ if (vsock)
+ dgram_allow = vsock->dgram_allow;
+
+ rcu_read_unlock();
+
+ return dgram_allow;
+}
+
static bool vhost_transport_seqpacket_allow(u32 remote_cid);
static struct virtio_transport vhost_transport = {
@@ -419,7 +470,7 @@ static struct virtio_transport vhost_transport = {
.cancel_pkt = vhost_transport_cancel_pkt,
.dgram_enqueue = virtio_transport_dgram_enqueue,
- .dgram_allow = virtio_transport_dgram_allow,
+ .dgram_allow = vhost_transport_dgram_allow,
.stream_enqueue = virtio_transport_stream_enqueue,
.stream_dequeue = virtio_transport_stream_dequeue,
@@ -811,6 +862,9 @@ static int vhost_vsock_set_features(struct vhost_vsock *vsock, u64 features)
if (features & (1ULL << VIRTIO_VSOCK_F_SEQPACKET))
vsock->seqpacket_allow = true;
+ if (features & (1ULL << VIRTIO_VSOCK_F_DGRAM))
+ vsock->dgram_allow = true;
+
for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
vq = &vsock->vqs[i];
mutex_lock(&vq->mutex);
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -1463,7 +1463,10 @@ int vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg,
return prot->recvmsg(sk, msg, len, flags, NULL);
#endif
- if (flags & MSG_OOB || flags & MSG_ERRQUEUE)
+ if (unlikely(flags & MSG_OOB))
return -EOPNOTSUPP;
if (unlikely(flags & MSG_ERRQUEUE))