@@ -2334,13 +2334,12 @@ static void unfetch_descs(struct vhost_virtqueue *vq)
* number of output then some number of input descriptors, it's actually two
* iovecs, but we pack them into one and note how many of each there were.
*
- * This function returns the descriptor number found, or vq->num (which is
- * never a valid descriptor number) if none was found. A negative code is
- * returned on error. */
-int vhost_get_vq_desc(struct vhost_virtqueue *vq,
- struct iovec iov[], unsigned int iov_size,
- unsigned int *out_num, unsigned int *in_num,
- struct vhost_log *log, unsigned int *log_num)
+ * This function returns a value > 0 if a descriptor was found, or 0 if none were found.
+ * A negative code is returned on error. */
+int vhost_get_avail_buf(struct vhost_virtqueue *vq, struct vhost_buf *buf,
+ struct iovec iov[], unsigned int iov_size,
+ unsigned int *out_num, unsigned int *in_num,
+ struct vhost_log *log, unsigned int *log_num)
{
int ret = fetch_descs(vq);
int i;
@@ -2353,6 +2352,8 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
*out_num = *in_num = 0;
if (unlikely(log))
*log_num = 0;
+ buf->in_len = buf->out_len = 0;
+ buf->descs = 0;
for (i = vq->first_desc; i < vq->ndescs; ++i) {
unsigned iov_count = *in_num + *out_num;
@@ -2382,6 +2383,7 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
/* If this is an input descriptor,
* increment that count. */
*in_num += ret;
+ buf->in_len += desc->len;
if (unlikely(log && ret)) {
log[*log_num].addr = desc->addr;
log[*log_num].len = desc->len;
@@ -2397,9 +2399,11 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
goto err;
}
*out_num += ret;
+ buf->out_len += desc->len;
}
- ret = desc->id;
+ buf->id = desc->id;
+ ++buf->descs;
if (!(desc->flags & VRING_DESC_F_NEXT))
break;
@@ -2407,12 +2411,41 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
vq->first_desc = i + 1;
- return ret;
+ return 1;
err:
unfetch_descs(vq);
- return ret ? ret : vq->num;
+ return ret;
+}
+EXPORT_SYMBOL_GPL(vhost_get_avail_buf);
+
+/* Reverse the effect of vhost_get_avail_buf. Useful for error handling. */
+void vhost_discard_avail_bufs(struct vhost_virtqueue *vq,
+ struct vhost_buf *buf, unsigned count)
+{
+ vhost_discard_vq_desc(vq, count);
+}
+EXPORT_SYMBOL_GPL(vhost_discard_avail_bufs);
+
+/* This function returns the descriptor number found, or vq->num (which is
+ * never a valid descriptor number) if none was found. A negative code is
+ * returned on error. */
+int vhost_get_vq_desc(struct vhost_virtqueue *vq,
+ struct iovec iov[], unsigned int iov_size,
+ unsigned int *out_num, unsigned int *in_num,
+ struct vhost_log *log, unsigned int *log_num)
+{
+ struct vhost_buf buf;
+ int ret = vhost_get_avail_buf(vq, &buf,
+ iov, iov_size, out_num, in_num,
+ log, log_num);
+
+ if (likely(ret > 0))
+ return buf.id;
+ if (likely(!ret))
+ return vq->num;
+ return ret;
}
EXPORT_SYMBOL_GPL(vhost_get_vq_desc);
@@ -2506,6 +2539,26 @@ int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len)
}
EXPORT_SYMBOL_GPL(vhost_add_used);
+int vhost_put_used_buf(struct vhost_virtqueue *vq, struct vhost_buf *buf)
+{
+ return vhost_add_used(vq, buf->id, buf->in_len);
+}
+EXPORT_SYMBOL_GPL(vhost_put_used_buf);
+
+int vhost_put_used_n_bufs(struct vhost_virtqueue *vq,
+ struct vhost_buf *bufs, unsigned count)
+{
+ unsigned i;
+
+ for (i = 0; i < count; ++i) {
+ vq->heads[i].id = cpu_to_vhost32(vq, bufs[i].id);
+ vq->heads[i].len = cpu_to_vhost32(vq, bufs[i].in_len);
+ }
+
+ return vhost_add_used_n(vq, vq->heads, count);
+}
+EXPORT_SYMBOL_GPL(vhost_put_used_n_bufs);
+
static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
{
__u16 old, new;
@@ -67,6 +67,13 @@ struct vhost_desc {
u16 id;
};
+struct vhost_buf {
+ u32 out_len;
+ u32 in_len;
+ u16 descs;
+ u16 id;
+};
+
/* The virtqueue structure describes a queue attached to a device. */
struct vhost_virtqueue {
struct vhost_dev *dev;
@@ -195,7 +202,12 @@ int vhost_get_vq_desc(struct vhost_virtqueue *,
unsigned int *out_num, unsigned int *in_num,
struct vhost_log *log, unsigned int *log_num);
void vhost_discard_vq_desc(struct vhost_virtqueue *, int n);
-
+int vhost_get_avail_buf(struct vhost_virtqueue *, struct vhost_buf *buf,
+ struct iovec iov[], unsigned int iov_count,
+ unsigned int *out_num, unsigned int *in_num,
+ struct vhost_log *log, unsigned int *log_num);
+void vhost_discard_avail_bufs(struct vhost_virtqueue *,
+ struct vhost_buf *, unsigned count);
int vhost_vq_init_access(struct vhost_virtqueue *);
int vhost_add_used(struct vhost_virtqueue *, unsigned int head, int len);
int vhost_add_used_n(struct vhost_virtqueue *, struct vring_used_elem *heads,
@@ -204,6 +216,9 @@ void vhost_add_used_and_signal(struct vhost_dev *, struct vhost_virtqueue *,
unsigned int id, int len);
void vhost_add_used_and_signal_n(struct vhost_dev *, struct vhost_virtqueue *,
struct vring_used_elem *heads, unsigned count);
+int vhost_put_used_buf(struct vhost_virtqueue *, struct vhost_buf *buf);
+int vhost_put_used_n_bufs(struct vhost_virtqueue *,
+ struct vhost_buf *bufs, unsigned count);
void vhost_signal(struct vhost_dev *, struct vhost_virtqueue *);
void vhost_disable_notify(struct vhost_dev *, struct vhost_virtqueue *);
bool vhost_vq_avail_empty(struct vhost_dev *, struct vhost_virtqueue *);