@@ -53,6 +53,18 @@
_min1 < _min2 ? _min1 : _min2; })
#endif
+/* Round number down to multiple */
+#define ALIGN_DOWN(n, m) ((n) / (m) * (m))
+
+/* Round number up to multiple */
+#define ALIGN_UP(n, m) ALIGN_DOWN((n) + (m) - 1, (m))
+
+/* Align each region to cache line size in shared memory */
+#define SHM_ALIGNMENT 64
+
+/* The version of shared memory */
+#define SHM_VERSION 1
+
#define VHOST_USER_HDR_SIZE offsetof(VhostUserMsg, payload.u64)
/* The version of the protocol we support */
@@ -100,6 +112,8 @@ vu_request_to_string(unsigned int req)
REQ(VHOST_USER_POSTCOPY_ADVISE),
REQ(VHOST_USER_POSTCOPY_LISTEN),
REQ(VHOST_USER_POSTCOPY_END),
+ REQ(VHOST_USER_GET_SHM_SIZE),
+ REQ(VHOST_USER_SET_SHM_FD),
REQ(VHOST_USER_MAX),
};
#undef REQ
@@ -890,6 +904,41 @@ vu_check_queue_msg_file(VuDev *dev, VhostUserMsg *vmsg)
return true;
}
+static int
+vu_check_queue_inflights(VuDev *dev, VuVirtq *vq) /* Rebuild in-flight descriptor state from shm after reconnect; 0 on success, -1 on error */
+{
+    int i = 0;
+
+    if ((dev->protocol_features &
+        VHOST_USER_PROTOCOL_F_SLAVE_SHMFD) == 0) {
+        return 0; /* feature not negotiated: nothing to recover */
+    }
+
+    if (unlikely(!vq->shm)) {
+        return -1; /* feature negotiated but shm region not mapped yet */
+    }
+
+    vq->used_idx = vq->vring.used->idx;
+    vq->inflight_num = 0;
+    for (i = 0; i < vq->vring.num; i++) {
+        if (vq->shm->inflight[i] == 0) {
+            continue; /* descriptor head i was not in flight at disconnect */
+        }
+
+        vq->inflight_desc[vq->inflight_num++] = i; /* queue head i for resubmission by vu_queue_pop() */
+        vq->inuse++; /* NOTE(review): inuse is accumulated without being reset first — assumes it is 0 here; verify on reconnect */
+    }
+    vq->shadow_avail_idx = vq->last_avail_idx = vq->inuse + vq->used_idx; /* skip avail entries already fetched before disconnect */
+
+    /* in case of I/O hang after reconnecting */
+    if (eventfd_write(vq->kick_fd, 1) ||
+        eventfd_write(vq->call_fd, 1)) {
+        return -1;
+    }
+
+    return 0;
+}
+
static bool
vu_set_vring_kick_exec(VuDev *dev, VhostUserMsg *vmsg)
{
@@ -925,6 +974,10 @@ vu_set_vring_kick_exec(VuDev *dev, VhostUserMsg *vmsg)
dev->vq[index].kick_fd, index);
}
+ if (vu_check_queue_inflights(dev, &dev->vq[index])) {
+ vu_panic(dev, "Failed to check inflights for vq: %d\n", index);
+ }
+
return false;
}
@@ -1215,6 +1268,115 @@ vu_set_postcopy_end(VuDev *dev, VhostUserMsg *vmsg)
return true;
}
+static int
+vu_setup_shm(VuDev *dev) /* Validate the negotiated shm layout and carve one region per virtqueue; 0 on success, -1 on error */
+{
+    int i;
+    char *addr = (char *)dev->shm_info.addr;
+    uint64_t size = 0;
+    uint32_t vq_size = ALIGN_UP(dev->shm_info.vq_size, dev->shm_info.align); /* per-queue region, padded to the negotiated alignment */
+
+    if (dev->shm_info.version != SHM_VERSION) {
+        DPRINT("Invalid version for shm: %"PRIu32"\n", dev->shm_info.version);
+        return -1;
+    }
+
+    if (dev->shm_info.dev_size != 0) {
+        DPRINT("Invalid dev_size for shm: %"PRIu32"\n", dev->shm_info.dev_size);
+        return -1;
+    }
+
+    if (dev->shm_info.vq_size != sizeof(VuVirtqShm)) {
+        DPRINT("Invalid vq_size for shm: %"PRIu32"\n", dev->shm_info.vq_size);
+        return -1;
+    }
+
+    for (i = 0; i < VHOST_MAX_NR_VIRTQUEUE; i++) {
+        size += vq_size;
+        if (size > dev->shm_info.mmap_size) {
+            break; /* mapping covers fewer queues than the max; remaining vq[i].shm stay NULL */
+        }
+        dev->vq[i].shm = (VuVirtqShm *)addr;
+        addr += vq_size;
+    }
+
+    return 0;
+}
+
+static bool
+vu_get_shm_size(VuDev *dev, VhostUserMsg *vmsg) /* Advertise the shm layout this backend expects; always requests a reply */
+{
+    if (vmsg->size != sizeof(vmsg->payload.shm)) {
+        vu_panic(dev, "Invalid get_shm_size message:%d", vmsg->size);
+        vmsg->size = 0;
+        return true; /* reply with size 0 so the master sees the failure */
+    }
+
+    vmsg->payload.shm.dev_size = 0; /* this backend needs no per-device region */
+    vmsg->payload.shm.vq_size = sizeof(VuVirtqShm);
+    vmsg->payload.shm.align = SHM_ALIGNMENT;
+    vmsg->payload.shm.version = SHM_VERSION;
+
+    DPRINT("send shm dev_size: %"PRIu32"\n", vmsg->payload.shm.dev_size);
+    DPRINT("send shm vq_size: %"PRIu32"\n", vmsg->payload.shm.vq_size);
+    DPRINT("send shm align: %"PRIu32"\n", vmsg->payload.shm.align);
+    DPRINT("send shm version: %"PRIu32"\n", vmsg->payload.shm.version);
+
+    return true;
+}
+
+static bool
+vu_set_shm_fd(VuDev *dev, VhostUserMsg *vmsg) /* Map the shm fd sent by the master, record its layout, set up per-queue regions */
+{
+    int fd;
+    uint64_t mmap_size, mmap_offset;
+    void *rc;
+
+    if (vmsg->fd_num != 1 ||
+        vmsg->size != sizeof(vmsg->payload.shm)) {
+        vu_panic(dev, "Invalid set_shm_fd message size:%d fds:%d",
+                 vmsg->size, vmsg->fd_num);
+        return false;
+    }
+
+    fd = vmsg->fds[0];
+    mmap_size = vmsg->payload.shm.mmap_size;
+    mmap_offset = vmsg->payload.shm.mmap_offset;
+    DPRINT("set_shm_fd mmap_size: %"PRIu64"\n", mmap_size);
+    DPRINT("set_shm_fd mmap_offset: %"PRIu64"\n", mmap_offset);
+    DPRINT("set_shm_fd dev_size: %"PRIu32"\n", vmsg->payload.shm.dev_size);
+    DPRINT("set_shm_fd vq_size: %"PRIu32"\n", vmsg->payload.shm.vq_size);
+    DPRINT("set_shm_fd align: %"PRIu32"\n", vmsg->payload.shm.align);
+    DPRINT("set_shm_fd version: %"PRIu32"\n", vmsg->payload.shm.version);
+
+    rc = mmap(NULL, mmap_size, PROT_READ | PROT_WRITE, MAP_SHARED,
+              fd, mmap_offset);
+
+    close(fd); /* the mapping keeps its own reference; fd is no longer needed */
+
+    if (rc == MAP_FAILED) {
+        vu_panic(dev, "set_shm_fd mmap error: %s", strerror(errno));
+        return false;
+    }
+
+    if (dev->shm_info.addr) {
+        munmap(dev->shm_info.addr, dev->shm_info.mmap_size); /* drop any previous mapping (reconnect case) */
+    }
+    dev->shm_info.addr = rc;
+    dev->shm_info.mmap_size = mmap_size;
+    dev->shm_info.dev_size = vmsg->payload.shm.dev_size;
+    dev->shm_info.vq_size = vmsg->payload.shm.vq_size;
+    dev->shm_info.align = vmsg->payload.shm.align;
+    dev->shm_info.version = vmsg->payload.shm.version;
+
+    if (vu_setup_shm(dev)) {
+        vu_panic(dev, "setup shm failed");
+        return false;
+    }
+
+    return false; /* no reply payload for this request */
+}
+
static bool
vu_process_message(VuDev *dev, VhostUserMsg *vmsg)
{
@@ -1292,6 +1454,10 @@ vu_process_message(VuDev *dev, VhostUserMsg *vmsg)
return vu_set_postcopy_listen(dev, vmsg);
case VHOST_USER_POSTCOPY_END:
return vu_set_postcopy_end(dev, vmsg);
+ case VHOST_USER_GET_SHM_SIZE:
+ return vu_get_shm_size(dev, vmsg);
+ case VHOST_USER_SET_SHM_FD:
+ return vu_set_shm_fd(dev, vmsg);
default:
vmsg_close_fds(vmsg);
vu_panic(dev, "Unhandled request: %d", vmsg->request);
@@ -1359,8 +1525,13 @@ vu_deinit(VuDev *dev)
close(vq->err_fd);
vq->err_fd = -1;
}
+ vq->shm = NULL;
}
+ if (dev->shm_info.addr) {
+ munmap(dev->shm_info.addr, dev->shm_info.mmap_size);
+ dev->shm_info.addr = NULL;
+ }
vu_close_log(dev);
if (dev->slave_fd != -1) {
@@ -1829,12 +2000,6 @@ virtqueue_map_desc(VuDev *dev,
*p_num_sg = num_sg;
}
-/* Round number down to multiple */
-#define ALIGN_DOWN(n, m) ((n) / (m) * (m))
-
-/* Round number up to multiple */
-#define ALIGN_UP(n, m) ALIGN_DOWN((n) + (m) - 1, (m))
-
static void *
virtqueue_alloc_element(size_t sz,
unsigned out_num, unsigned in_num)
@@ -1935,9 +2100,44 @@ vu_queue_map_desc(VuDev *dev, VuVirtq *vq, unsigned int idx, size_t sz)
return elem;
}
+static int
+vu_queue_inflight_get(VuDev *dev, VuVirtq *vq, int desc_idx) /* Mark descriptor head as in flight in shm; 0 on success, -1 on error */
+{
+    if ((dev->protocol_features &
+        VHOST_USER_PROTOCOL_F_SLAVE_SHMFD) == 0) {
+        return 0; /* tracking disabled: nothing to record */
+    }
+
+    if (unlikely(!vq->shm)) {
+        return -1; /* feature negotiated but shm region not mapped */
+    }
+
+    vq->shm->inflight[desc_idx] = 1;
+
+    return 0;
+}
+
+static int
+vu_queue_inflight_put(VuDev *dev, VuVirtq *vq, int desc_idx) /* Clear the in-flight mark for a completed descriptor head; 0 on success, -1 on error */
+{
+    if ((dev->protocol_features &
+        VHOST_USER_PROTOCOL_F_SLAVE_SHMFD) == 0) {
+        return 0; /* tracking disabled: nothing to clear */
+    }
+
+    if (unlikely(!vq->shm)) {
+        return -1; /* feature negotiated but shm region not mapped */
+    }
+
+    vq->shm->inflight[desc_idx] = 0;
+
+    return 0;
+}
+
void *
vu_queue_pop(VuDev *dev, VuVirtq *vq, size_t sz)
{
+ int i;
unsigned int head;
VuVirtqElement *elem;
@@ -1946,6 +2146,12 @@ vu_queue_pop(VuDev *dev, VuVirtq *vq, size_t sz)
return NULL;
}
+ if (unlikely(vq->inflight_num > 0)) {
+ i = (--vq->inflight_num);
+ elem = vu_queue_map_desc(dev, vq, vq->inflight_desc[i], sz);
+ return elem;
+ }
+
if (vu_queue_empty(dev, vq)) {
return NULL;
}
@@ -1976,6 +2182,8 @@ vu_queue_pop(VuDev *dev, VuVirtq *vq, size_t sz)
vq->inuse++;
+ vu_queue_inflight_get(dev, vq, head);
+
return elem;
}
@@ -2121,4 +2329,5 @@ vu_queue_push(VuDev *dev, VuVirtq *vq,
{
vu_queue_fill(dev, vq, elem, len, 0);
vu_queue_flush(dev, vq, 1);
+ vu_queue_inflight_put(dev, vq, elem->index);
}
@@ -53,6 +53,7 @@ enum VhostUserProtocolFeature {
VHOST_USER_PROTOCOL_F_CONFIG = 9,
VHOST_USER_PROTOCOL_F_SLAVE_SEND_FD = 10,
VHOST_USER_PROTOCOL_F_HOST_NOTIFIER = 11,
+ VHOST_USER_PROTOCOL_F_SLAVE_SHMFD = 12,
VHOST_USER_PROTOCOL_F_MAX
};
@@ -91,6 +92,8 @@ typedef enum VhostUserRequest {
VHOST_USER_POSTCOPY_ADVISE = 28,
VHOST_USER_POSTCOPY_LISTEN = 29,
VHOST_USER_POSTCOPY_END = 30,
+ VHOST_USER_GET_SHM_SIZE = 31,
+ VHOST_USER_SET_SHM_FD = 32,
VHOST_USER_MAX
} VhostUserRequest;
@@ -138,6 +141,15 @@ typedef struct VhostUserVringArea {
uint64_t offset;
} VhostUserVringArea;
+typedef struct VhostUserShm {
+    uint64_t mmap_size;   /* total size of the mapping, in bytes */
+    uint64_t mmap_offset; /* offset of the region within the passed fd */
+    uint32_t dev_size;    /* bytes reserved for device state (0 in SHM_VERSION 1) */
+    uint32_t vq_size;     /* bytes reserved per virtqueue (sizeof(VuVirtqShm)) */
+    uint32_t align;       /* alignment applied to each per-queue region */
+    uint32_t version;     /* shm layout version, compared against SHM_VERSION */
+} VhostUserShm;
+
#if defined(_WIN32)
# define VU_PACKED __attribute__((gcc_struct, packed))
#else
@@ -163,6 +175,7 @@ typedef struct VhostUserMsg {
VhostUserLog log;
VhostUserConfig config;
VhostUserVringArea area;
+ VhostUserShm shm;
} payload;
int fds[VHOST_MEMORY_MAX_NREGIONS];
@@ -234,9 +247,19 @@ typedef struct VuRing {
uint32_t flags;
} VuRing;
+typedef struct VuVirtqShm {
+    char inflight[VIRTQUEUE_MAX_SIZE]; /* inflight[i] != 0 while descriptor head i is being processed */
+} VuVirtqShm;
+
typedef struct VuVirtq {
VuRing vring;
+ VuVirtqShm *shm;
+
+ uint16_t inflight_desc[VIRTQUEUE_MAX_SIZE];
+
+ uint16_t inflight_num;
+
/* Next head to pop */
uint16_t last_avail_idx;
@@ -279,11 +302,21 @@ typedef void (*vu_set_watch_cb) (VuDev *dev, int fd, int condition,
vu_watch_cb cb, void *data);
typedef void (*vu_remove_watch_cb) (VuDev *dev, int fd);
+typedef struct VuDevShmInfo {
+    void *addr;          /* mmap'ed base address of the shared region */
+    uint64_t mmap_size;  /* size of the mapping, used for munmap */
+    uint32_t dev_size;   /* copied from VhostUserShm.dev_size */
+    uint32_t vq_size;    /* copied from VhostUserShm.vq_size */
+    uint32_t align;      /* copied from VhostUserShm.align */
+    uint32_t version;    /* copied from VhostUserShm.version */
+} VuDevShmInfo;
+
struct VuDev {
int sock;
uint32_t nregions;
VuDevRegion regions[VHOST_MEMORY_MAX_NREGIONS];
VuVirtq vq[VHOST_MAX_NR_VIRTQUEUE];
+ VuDevShmInfo shm_info;
int log_call_fd;
int slave_fd;
uint64_t log_size;