@@ -1326,7 +1326,7 @@ static int vhost_net_open(struct inode *inode, struct file *f)
}
vhost_dev_init(dev, vqs, VHOST_NET_VQ_MAX,
UIO_MAXIOV + VHOST_NET_BATCH,
- VHOST_NET_PKT_WEIGHT, VHOST_NET_WEIGHT,
+ VHOST_NET_PKT_WEIGHT, VHOST_NET_WEIGHT, true,
NULL);
vhost_poll_init(n->poll + VHOST_NET_VQ_TX, handle_tx_net, EPOLLOUT, dev);
@@ -1628,7 +1628,7 @@ static int vhost_scsi_open(struct inode *inode, struct file *f)
vs->vqs[i].vq.handle_kick = vhost_scsi_handle_kick;
}
vhost_dev_init(&vs->dev, vqs, VHOST_SCSI_MAX_VQ, UIO_MAXIOV,
- VHOST_SCSI_WEIGHT, 0, NULL);
+ VHOST_SCSI_WEIGHT, 0, true, NULL);
vhost_scsi_init_inflight(vs, NULL);
@@ -696,7 +696,7 @@ static int vhost_vdpa_open(struct inode *inode, struct file *filep)
vqs[i] = &v->vqs[i];
vqs[i]->handle_kick = handle_vq_kick;
}
- vhost_dev_init(dev, vqs, nvqs, 0, 0, 0,
+ vhost_dev_init(dev, vqs, nvqs, 0, 0, 0, false,
vhost_vdpa_process_iotlb_msg);
dev->iotlb = vhost_iotlb_alloc(0, 0);
@@ -166,11 +166,16 @@ static int vhost_poll_wakeup(wait_queue_entry_t *wait, unsigned mode, int sync,
void *key)
{
struct vhost_poll *poll = container_of(wait, struct vhost_poll, wait);
+ struct vhost_work *work = &poll->work;
if (!(key_to_poll(key) & poll->mask))
return 0;
- vhost_poll_queue(poll);
+ if (!poll->dev->use_worker)
+ work->fn(work);
+ else
+ vhost_poll_queue(poll);
+
return 0;
}
@@ -454,6 +459,7 @@ static size_t vhost_get_desc_size(struct vhost_virtqueue *vq,
void vhost_dev_init(struct vhost_dev *dev,
struct vhost_virtqueue **vqs, int nvqs,
int iov_limit, int weight, int byte_weight,
+ bool use_worker,
int (*msg_handler)(struct vhost_dev *dev,
struct vhost_iotlb_msg *msg))
{
@@ -471,6 +477,7 @@ void vhost_dev_init(struct vhost_dev *dev,
dev->iov_limit = iov_limit;
dev->weight = weight;
dev->byte_weight = byte_weight;
+ dev->use_worker = use_worker;
dev->msg_handler = msg_handler;
init_llist_head(&dev->work_list);
init_waitqueue_head(&dev->wait);
@@ -549,18 +556,21 @@ long vhost_dev_set_owner(struct vhost_dev *dev)
/* No owner, become one */
dev->mm = get_task_mm(current);
dev->kcov_handle = kcov_common_handle();
- worker = kthread_create(vhost_worker, dev, "vhost-%d", current->pid);
- if (IS_ERR(worker)) {
- err = PTR_ERR(worker);
- goto err_worker;
- }
+ if (dev->use_worker) {
+ worker = kthread_create(vhost_worker, dev,
+ "vhost-%d", current->pid);
+ if (IS_ERR(worker)) {
+ err = PTR_ERR(worker);
+ goto err_worker;
+ }
- dev->worker = worker;
- wake_up_process(worker); /* avoid contributing to loadavg */
+ dev->worker = worker;
+ wake_up_process(worker); /* avoid contributing to loadavg */
- err = vhost_attach_cgroups(dev);
- if (err)
- goto err_cgroup;
+ err = vhost_attach_cgroups(dev);
+ if (err)
+ goto err_cgroup;
+ }
err = vhost_dev_alloc_iovecs(dev);
if (err)
@@ -568,8 +578,10 @@ long vhost_dev_set_owner(struct vhost_dev *dev)
return 0;
err_cgroup:
- kthread_stop(worker);
- dev->worker = NULL;
+ if (dev->worker) {
+ kthread_stop(dev->worker);
+ dev->worker = NULL;
+ }
err_worker:
if (dev->mm)
mmput(dev->mm);
@@ -154,6 +154,7 @@ struct vhost_dev {
int weight;
int byte_weight;
u64 kcov_handle;
+ bool use_worker;
int (*msg_handler)(struct vhost_dev *dev,
struct vhost_iotlb_msg *msg);
};
@@ -161,6 +162,7 @@ struct vhost_dev {
bool vhost_exceeds_weight(struct vhost_virtqueue *vq, int pkts, int total_len);
void vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs,
int nvqs, int iov_limit, int weight, int byte_weight,
+ bool use_worker,
int (*msg_handler)(struct vhost_dev *dev,
struct vhost_iotlb_msg *msg));
long vhost_dev_set_owner(struct vhost_dev *dev);
@@ -621,7 +621,7 @@ static int vhost_vsock_dev_open(struct inode *inode, struct file *file)
vhost_dev_init(&vsock->dev, vqs, ARRAY_SIZE(vsock->vqs),
UIO_MAXIOV, VHOST_VSOCK_PKT_WEIGHT,
- VHOST_VSOCK_WEIGHT, NULL);
+ VHOST_VSOCK_WEIGHT, true, NULL);
file->private_data = vsock;
spin_lock_init(&vsock->send_pkt_list_lock);