diff mbox series

[v3,4/9] vhost: Add kthread support in function vhost_worker_create

Message ID 20241105072642.898710-5-lulu@redhat.com (mailing list archive)
State Not Applicable
Headers show
Series vhost: Add support of kthread API | expand

Checks

Context Check Description
netdev/tree_selection success Not a local patch

Commit Message

Cindy Lu Nov. 5, 2024, 7:25 a.m. UTC
Restore the previous functions kthread_wakeup and kthread_stop.
Also add a new structure, vhost_task_fn. The function vhost_worker_create
will initialize this structure based on the value of inherit_owner.

Signed-off-by: Cindy Lu <lulu@redhat.com>
---
 drivers/vhost/vhost.c | 71 ++++++++++++++++++++++++++++++++++++-------
 drivers/vhost/vhost.h |  6 ++++
 2 files changed, 66 insertions(+), 11 deletions(-)

Comments

Jason Wang Nov. 5, 2024, 9:36 a.m. UTC | #1
On Tue, Nov 5, 2024 at 3:27 PM Cindy Lu <lulu@redhat.com> wrote:
>
> Restored the previous functions kthread_wakeup and kthread_stop.
> Also add a new structure, vhost_task_fn. The function vhost_worker_create
> Will initializes this structure based on the value of inherit_owner.
>
> Signed-off-by: Cindy Lu <lulu@redhat.com>
> ---
>  drivers/vhost/vhost.c | 71 ++++++++++++++++++++++++++++++++++++-------
>  drivers/vhost/vhost.h |  6 ++++
>  2 files changed, 66 insertions(+), 11 deletions(-)
>
> diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
> index e40cef3a1fa5..603b146fccc1 100644
> --- a/drivers/vhost/vhost.c
> +++ b/drivers/vhost/vhost.c
> @@ -741,43 +741,92 @@ static void vhost_workers_free(struct vhost_dev *dev)
>         xa_destroy(&dev->worker_xa);
>  }
>
> +static int vhost_task_wakeup_fn(void *vtsk)
> +{
> +       vhost_task_wake((struct vhost_task *)vtsk);
> +       return 0;
> +}

Let's have a newline between two functions.

> +static int vhost_kthread_wakeup_fn(void *p)
> +{
> +       return wake_up_process((struct task_struct *)p);
> +}
> +static int vhost_task_stop_fn(void *vtsk)
> +{
> +       vhost_task_stop((struct vhost_task *)vtsk);
> +       return 0;
> +}
> +static int vhost_kthread_stop_fn(void *k)
> +{
> +       return kthread_stop((struct task_struct *)k);
> +}
> +
>  static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev)
>  {
>         struct vhost_worker *worker;
> -       struct vhost_task *vtsk;
> +       struct vhost_task *vtsk = NULL;
> +       struct task_struct *task = NULL;
>         char name[TASK_COMM_LEN];
>         int ret;
>         u32 id;
>
> +       /* Allocate resources for the worker */
>         worker = kzalloc(sizeof(*worker), GFP_KERNEL_ACCOUNT);
>         if (!worker)
>                 return NULL;
>
> +       worker->fn = kzalloc(sizeof(struct vhost_task_fn), GFP_KERNEL_ACCOUNT);
> +       if (!worker->fn) {
> +               kfree(worker);
> +               return NULL;
> +       }
> +
>         worker->dev = dev;
>         snprintf(name, sizeof(name), "vhost-%d", current->pid);
>
> -       vtsk = vhost_task_create(vhost_run_work_list, vhost_worker_killed,
> -                                worker, name);
> -       if (!vtsk)
> -               goto free_worker;
> -
>         mutex_init(&worker->mutex);
>         init_llist_head(&worker->work_list);
>         worker->kcov_handle = kcov_common_handle();
> -       worker->vtsk = vtsk;
>
> -       vhost_task_start(vtsk);
> +       if (dev->inherit_owner) {
> +               /* Create and start a vhost task */
> +               vtsk = vhost_task_create(vhost_run_work_list,
> +                                        vhost_worker_killed, worker, name);
> +               if (!vtsk)
> +                       goto free_worker;
> +
> +               worker->vtsk = vtsk;
> +               worker->fn->wakeup = vhost_task_wakeup_fn;
> +               worker->fn->stop = vhost_task_stop_fn;
> +
> +               vhost_task_start(vtsk);
> +       } else {
> +               /* Create and start a kernel thread */
> +               task = kthread_create(vhost_run_work_kthread_list, worker,
> +                                     "vhost-%d", current->pid);
> +               if (IS_ERR(task)) {
> +                       ret = PTR_ERR(task);
> +                       goto free_worker;
> +               }
> +               worker->task = task;
> +               worker->fn->wakeup = vhost_kthread_wakeup_fn;
> +               worker->fn->stop = vhost_kthread_stop_fn;
> +
> +               wake_up_process(task);
> +               /* Attach to the vhost cgroup */
> +               ret = vhost_attach_cgroups(dev);
> +               if (ret)
> +                       goto stop_worker;
> +       }
>
>         ret = xa_alloc(&dev->worker_xa, &id, worker, xa_limit_32b, GFP_KERNEL);
>         if (ret < 0)
>                 goto stop_worker;
>         worker->id = id;
> -
>         return worker;
> -
>  stop_worker:
> -       vhost_task_stop(vtsk);
> +       worker->fn->stop(dev->inherit_owner ? (void *)vtsk : (void *)task);
>  free_worker:
> +       kfree(worker->fn);
>         kfree(worker);
>         return NULL;
>  }
> diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
> index c650c4506c70..ebababa4e340 100644
> --- a/drivers/vhost/vhost.h
> +++ b/drivers/vhost/vhost.h
> @@ -25,8 +25,13 @@ struct vhost_work {
>         vhost_work_fn_t         fn;
>         unsigned long           flags;
>  };
> +struct vhost_task_fn {
> +       int (*wakeup)(void *task);

Let's have comments to explain the semantics of each operation.

> +       int (*stop)(void *task);
> +};

I think the goal is to reduce if/else, so while we're at it, let's
introduce more ops. For example, a create_worker one?

>
>  struct vhost_worker {
> +       struct task_struct      *task;
>         struct vhost_task       *vtsk;
>         struct vhost_dev        *dev;
>         /* Used to serialize device wide flushing with worker swapping. */
> @@ -36,6 +41,7 @@ struct vhost_worker {
>         u32                     id;
>         int                     attachment_cnt;
>         bool                    killed;
> +       struct vhost_task_fn *fn;
>  };
>
>  /* Poll a file (eventfd or socket) */
> --
> 2.45.0
>

Thanks
Cindy Lu Nov. 6, 2024, 9:21 a.m. UTC | #2
On Tue, Nov 5, 2024 at 5:36 PM Jason Wang <jasowang@redhat.com> wrote:
>
> On Tue, Nov 5, 2024 at 3:27 PM Cindy Lu <lulu@redhat.com> wrote:
> >
> > Restored the previous functions kthread_wakeup and kthread_stop.
> > Also add a new structure, vhost_task_fn. The function vhost_worker_create
> > Will initializes this structure based on the value of inherit_owner.
> >
> > Signed-off-by: Cindy Lu <lulu@redhat.com>
> > ---
> >  drivers/vhost/vhost.c | 71 ++++++++++++++++++++++++++++++++++++-------
> >  drivers/vhost/vhost.h |  6 ++++
> >  2 files changed, 66 insertions(+), 11 deletions(-)
> >
> > diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
> > index e40cef3a1fa5..603b146fccc1 100644
> > --- a/drivers/vhost/vhost.c
> > +++ b/drivers/vhost/vhost.c
> > @@ -741,43 +741,92 @@ static void vhost_workers_free(struct vhost_dev *dev)
> >         xa_destroy(&dev->worker_xa);
> >  }
> >
> > +static int vhost_task_wakeup_fn(void *vtsk)
> > +{
> > +       vhost_task_wake((struct vhost_task *)vtsk);
> > +       return 0;
> > +}
>
> Let's have a newline between two functions.
>
will fix this
> > +static int vhost_kthread_wakeup_fn(void *p)
> > +{
> > +       return wake_up_process((struct task_struct *)p);
> > +}
> > +static int vhost_task_stop_fn(void *vtsk)
> > +{
> > +       vhost_task_stop((struct vhost_task *)vtsk);
> > +       return 0;
> > +}
> > +static int vhost_kthread_stop_fn(void *k)
> > +{
> > +       return kthread_stop((struct task_struct *)k);
> > +}
> > +
> >  static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev)
> >  {
> >         struct vhost_worker *worker;
> > -       struct vhost_task *vtsk;
> > +       struct vhost_task *vtsk = NULL;
> > +       struct task_struct *task = NULL;
> >         char name[TASK_COMM_LEN];
> >         int ret;
> >         u32 id;
> >
> > +       /* Allocate resources for the worker */
> >         worker = kzalloc(sizeof(*worker), GFP_KERNEL_ACCOUNT);
> >         if (!worker)
> >                 return NULL;
> >
> > +       worker->fn = kzalloc(sizeof(struct vhost_task_fn), GFP_KERNEL_ACCOUNT);
> > +       if (!worker->fn) {
> > +               kfree(worker);
> > +               return NULL;
> > +       }
> > +
> >         worker->dev = dev;
> >         snprintf(name, sizeof(name), "vhost-%d", current->pid);
> >
> > -       vtsk = vhost_task_create(vhost_run_work_list, vhost_worker_killed,
> > -                                worker, name);
> > -       if (!vtsk)
> > -               goto free_worker;
> > -
> >         mutex_init(&worker->mutex);
> >         init_llist_head(&worker->work_list);
> >         worker->kcov_handle = kcov_common_handle();
> > -       worker->vtsk = vtsk;
> >
> > -       vhost_task_start(vtsk);
> > +       if (dev->inherit_owner) {
> > +               /* Create and start a vhost task */
> > +               vtsk = vhost_task_create(vhost_run_work_list,
> > +                                        vhost_worker_killed, worker, name);
> > +               if (!vtsk)
> > +                       goto free_worker;
> > +
> > +               worker->vtsk = vtsk;
> > +               worker->fn->wakeup = vhost_task_wakeup_fn;
> > +               worker->fn->stop = vhost_task_stop_fn;
> > +
> > +               vhost_task_start(vtsk);
> > +       } else {
> > +               /* Create and start a kernel thread */
> > +               task = kthread_create(vhost_run_work_kthread_list, worker,
> > +                                     "vhost-%d", current->pid);
> > +               if (IS_ERR(task)) {
> > +                       ret = PTR_ERR(task);
> > +                       goto free_worker;
> > +               }
> > +               worker->task = task;
> > +               worker->fn->wakeup = vhost_kthread_wakeup_fn;
> > +               worker->fn->stop = vhost_kthread_stop_fn;
> > +
> > +               wake_up_process(task);
> > +               /* Attach to the vhost cgroup */
> > +               ret = vhost_attach_cgroups(dev);
> > +               if (ret)
> > +                       goto stop_worker;
> > +       }
> >
> >         ret = xa_alloc(&dev->worker_xa, &id, worker, xa_limit_32b, GFP_KERNEL);
> >         if (ret < 0)
> >                 goto stop_worker;
> >         worker->id = id;
> > -
> >         return worker;
> > -
> >  stop_worker:
> > -       vhost_task_stop(vtsk);
> > +       worker->fn->stop(dev->inherit_owner ? (void *)vtsk : (void *)task);
> >  free_worker:
> > +       kfree(worker->fn);
> >         kfree(worker);
> >         return NULL;
> >  }
> > diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
> > index c650c4506c70..ebababa4e340 100644
> > --- a/drivers/vhost/vhost.h
> > +++ b/drivers/vhost/vhost.h
> > @@ -25,8 +25,13 @@ struct vhost_work {
> >         vhost_work_fn_t         fn;
> >         unsigned long           flags;
> >  };
> > +struct vhost_task_fn {
> > +       int (*wakeup)(void *task);
>
> Let's have comments to explain the semantics of each operation.
>
sure, will fix this
> > +       int (*stop)(void *task);
> > +};
>
> I think the goal is to reduce if/else, so while at this, let's
> introduce more ops. For example the create_worker one?
>
sure, will change this part
thanks
cindy
> >
> >  struct vhost_worker {
> > +       struct task_struct      *task;
> >         struct vhost_task       *vtsk;
> >         struct vhost_dev        *dev;
> >         /* Used to serialize device wide flushing with worker swapping. */
> > @@ -36,6 +41,7 @@ struct vhost_worker {
> >         u32                     id;
> >         int                     attachment_cnt;
> >         bool                    killed;
> > +       struct vhost_task_fn *fn;
> >  };
> >
> >  /* Poll a file (eventfd or socket) */
> > --
> > 2.45.0
> >
>
> Thanks
>
Mike Christie Nov. 26, 2024, 9:19 p.m. UTC | #3
On 11/5/24 1:25 AM, Cindy Lu wrote:
>  static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev)
>  {
>  	struct vhost_worker *worker;
> -	struct vhost_task *vtsk;
> +	struct vhost_task *vtsk = NULL;
> +	struct task_struct *task = NULL;
>  	char name[TASK_COMM_LEN];
>  	int ret;
>  	u32 id;
>  
> +	/* Allocate resources for the worker */
>  	worker = kzalloc(sizeof(*worker), GFP_KERNEL_ACCOUNT);
>  	if (!worker)
>  		return NULL;
>  
> +	worker->fn = kzalloc(sizeof(struct vhost_task_fn), GFP_KERNEL_ACCOUNT);
> +	if (!worker->fn) {
> +		kfree(worker);
> +		return NULL;
> +	}

Why dynamically allocate this?

You could probably even just kill the vhost_task_fn struct since we just
have the 2 callouts.

> +
>  	worker->dev = dev;
>  	snprintf(name, sizeof(name), "vhost-%d", current->pid);
>  
> -	vtsk = vhost_task_create(vhost_run_work_list, vhost_worker_killed,
> -				 worker, name);
> -	if (!vtsk)
> -		goto free_worker;
> -
>  	mutex_init(&worker->mutex);
>  	init_llist_head(&worker->work_list);
>  	worker->kcov_handle = kcov_common_handle();
> -	worker->vtsk = vtsk;
>  
> -	vhost_task_start(vtsk);
> +	if (dev->inherit_owner) {
> +		/* Create and start a vhost task */

Maybe instead of this comment and the one below write something about
what inherit_owner means. We can see from the code we are creating a
vhost/kthread, but it's not really obvious why. Something like:

/*
 * If inherit_owner is true we use vhost_tasks to create
 * the worker so all settings/limits like cgroups, NPROC,
 * scheduler, etc are inherited from the owner. If false,
 * we use kthreads and only attach to the same cgroups
 * as the owner for compat with older kernels.
 */



> +		vtsk = vhost_task_create(vhost_run_work_list,
> +					 vhost_worker_killed, worker, name);
> +		if (!vtsk)
> +			goto free_worker;
> +
> +		worker->vtsk = vtsk;
> +		worker->fn->wakeup = vhost_task_wakeup_fn;
> +		worker->fn->stop = vhost_task_stop_fn;
> +
> +		vhost_task_start(vtsk);
> +	} else {
> +		/* Create and start a kernel thread */
> +		task = kthread_create(vhost_run_work_kthread_list, worker,
> +				      "vhost-%d", current->pid);
> +		if (IS_ERR(task)) {
> +			ret = PTR_ERR(task);
> +			goto free_worker;
> +		}
> +		worker->task = task;
> +		worker->fn->wakeup = vhost_kthread_wakeup_fn;
> +		worker->fn->stop = vhost_kthread_stop_fn;
> +
> +		wake_up_process(task);
> +		/* Attach to the vhost cgroup */

You don't need this comment, do you? The function name tells us the same
info.

> +		ret = vhost_attach_cgroups(dev);

I don't think this works. Patch 3/9 did:

+	xa_for_each(&dev->worker_xa, i, worker) {
+		ret = vhost_worker_cgroups_kthread(worker);

but we don't add the worker to the xa until below.

You also want to just call vhost_worker_cgroups_kthread above, because
you only want to add the one task and not loop over all of them.

I would then also maybe rename vhost_worker_cgroups_kthread to something
like vhost_attach_task_to_cgroups.



> +		if (ret)
> +			goto stop_worker;
> +	}
>  
>  	ret = xa_alloc(&dev->worker_xa, &id, worker, xa_limit_32b, GFP_KERNEL);
>  	if (ret < 0)
>  		goto stop_worker;
>  	worker->id = id;
> -
>  	return worker;
> -
>  stop_worker:
> -	vhost_task_stop(vtsk);
> +	worker->fn->stop(dev->inherit_owner ? (void *)vtsk : (void *)task);

I don't think you need to cast since the function takes a void pointer.
Same comment for the other patches like 6/9 where you are calling the
callout and casting.
Cindy Lu Nov. 27, 2024, 6:43 a.m. UTC | #4
On Wed, Nov 27, 2024 at 5:20 AM <michael.christie@oracle.com> wrote:
>
> On 11/5/24 1:25 AM, Cindy Lu wrote:
> >  static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev)
> >  {
> >       struct vhost_worker *worker;
> > -     struct vhost_task *vtsk;
> > +     struct vhost_task *vtsk = NULL;
> > +     struct task_struct *task = NULL;
> >       char name[TASK_COMM_LEN];
> >       int ret;
> >       u32 id;
> >
> > +     /* Allocate resources for the worker */
> >       worker = kzalloc(sizeof(*worker), GFP_KERNEL_ACCOUNT);
> >       if (!worker)
> >               return NULL;
> >
> > +     worker->fn = kzalloc(sizeof(struct vhost_task_fn), GFP_KERNEL_ACCOUNT);
> > +     if (!worker->fn) {
> > +             kfree(worker);
> > +             return NULL;
> > +     }
>
> Why dynamically allocate this?
>
> You could probably even just kill the vhost_task_fn struct since we just
> have to the 2 callouts.

sure, will change this
>
>
> > +
> >       worker->dev = dev;
> >       snprintf(name, sizeof(name), "vhost-%d", current->pid);
> >
> > -     vtsk = vhost_task_create(vhost_run_work_list, vhost_worker_killed,
> > -                              worker, name);
> > -     if (!vtsk)
> > -             goto free_worker;
> > -
> >       mutex_init(&worker->mutex);
> >       init_llist_head(&worker->work_list);
> >       worker->kcov_handle = kcov_common_handle();
> > -     worker->vtsk = vtsk;
> >
> > -     vhost_task_start(vtsk);
> > +     if (dev->inherit_owner) {
> > +             /* Create and start a vhost task */
>
> Maybe instead of this comment and the one below write something about
> what inherit_owner means. We can see from the code we are creating a
> vhost/kthread, but it's not really obvious why. Something like:
>
> /*
>  * If inherit_owner is true we use vhost_tasks to create
>  * the worker so all settings/limits like cgroups, NPROC,
>  * scheduler, etc are inherited from the owner. If false,
>  * we use kthreads and only attach to the same cgroups
>  * as the owner for compat with older kernels.
>  */
>
Thanks, Mike, I will change this

>
>
> > +             vtsk = vhost_task_create(vhost_run_work_list,
> > +                                      vhost_worker_killed, worker, name);
> > +             if (!vtsk)
> > +                     goto free_worker;
> > +
> > +             worker->vtsk = vtsk;
> > +             worker->fn->wakeup = vhost_task_wakeup_fn;
> > +             worker->fn->stop = vhost_task_stop_fn;
> > +
> > +             vhost_task_start(vtsk);
> > +     } else {
> > +             /* Create and start a kernel thread */
> > +             task = kthread_create(vhost_run_work_kthread_list, worker,
> > +                                   "vhost-%d", current->pid);
> > +             if (IS_ERR(task)) {
> > +                     ret = PTR_ERR(task);
> > +                     goto free_worker;
> > +             }
> > +             worker->task = task;
> > +             worker->fn->wakeup = vhost_kthread_wakeup_fn;
> > +             worker->fn->stop = vhost_kthread_stop_fn;
> > +
> > +             wake_up_process(task);
> > +             /* Attach to the vhost cgroup */
>
> You don't need this comment do you? The function name tells us the same
> info.
>
Sure, will remove this.
> > +             ret = vhost_attach_cgroups(dev);
>
> I don't think this works. Patch 3/9 did:
>
> +       xa_for_each(&dev->worker_xa, i, worker) {
> +               ret = vhost_worker_cgroups_kthread(worker);
>
> but we don't add the worker to the xa until below.
>
> You also want to just call vhost_worker_cgroups_kthread above, because
> you only want to add the one task and not loop over all of them.
>
> I would then also maybe rename vhost_worker_cgroups_kthread to something
> like vhost_attach_task_to_cgroups.
>
>
Will fix this. Thanks
>
> > +             if (ret)
> > +                     goto stop_worker;
> > +     }
> >
> >       ret = xa_alloc(&dev->worker_xa, &id, worker, xa_limit_32b, GFP_KERNEL);
> >       if (ret < 0)
> >               goto stop_worker;
> >       worker->id = id;
> > -
> >       return worker;
> > -
> >  stop_worker:
> > -     vhost_task_stop(vtsk);
> > +     worker->fn->stop(dev->inherit_owner ? (void *)vtsk : (void *)task);
>
> I don't think you need to cast since the function takes a void pointer.
> Same comment for the other patches like 6/9 where you are calling the
> callout and casting.
>
Sure, thanks. I will rewrite this part.
Thanks
cindy
diff mbox series

Patch

diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index e40cef3a1fa5..603b146fccc1 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -741,43 +741,92 @@  static void vhost_workers_free(struct vhost_dev *dev)
 	xa_destroy(&dev->worker_xa);
 }
 
+static int vhost_task_wakeup_fn(void *vtsk)
+{
+	vhost_task_wake((struct vhost_task *)vtsk);
+	return 0;
+}
+static int vhost_kthread_wakeup_fn(void *p)
+{
+	return wake_up_process((struct task_struct *)p);
+}
+static int vhost_task_stop_fn(void *vtsk)
+{
+	vhost_task_stop((struct vhost_task *)vtsk);
+	return 0;
+}
+static int vhost_kthread_stop_fn(void *k)
+{
+	return kthread_stop((struct task_struct *)k);
+}
+
 static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev)
 {
 	struct vhost_worker *worker;
-	struct vhost_task *vtsk;
+	struct vhost_task *vtsk = NULL;
+	struct task_struct *task = NULL;
 	char name[TASK_COMM_LEN];
 	int ret;
 	u32 id;
 
+	/* Allocate resources for the worker */
 	worker = kzalloc(sizeof(*worker), GFP_KERNEL_ACCOUNT);
 	if (!worker)
 		return NULL;
 
+	worker->fn = kzalloc(sizeof(struct vhost_task_fn), GFP_KERNEL_ACCOUNT);
+	if (!worker->fn) {
+		kfree(worker);
+		return NULL;
+	}
+
 	worker->dev = dev;
 	snprintf(name, sizeof(name), "vhost-%d", current->pid);
 
-	vtsk = vhost_task_create(vhost_run_work_list, vhost_worker_killed,
-				 worker, name);
-	if (!vtsk)
-		goto free_worker;
-
 	mutex_init(&worker->mutex);
 	init_llist_head(&worker->work_list);
 	worker->kcov_handle = kcov_common_handle();
-	worker->vtsk = vtsk;
 
-	vhost_task_start(vtsk);
+	if (dev->inherit_owner) {
+		/* Create and start a vhost task */
+		vtsk = vhost_task_create(vhost_run_work_list,
+					 vhost_worker_killed, worker, name);
+		if (!vtsk)
+			goto free_worker;
+
+		worker->vtsk = vtsk;
+		worker->fn->wakeup = vhost_task_wakeup_fn;
+		worker->fn->stop = vhost_task_stop_fn;
+
+		vhost_task_start(vtsk);
+	} else {
+		/* Create and start a kernel thread */
+		task = kthread_create(vhost_run_work_kthread_list, worker,
+				      "vhost-%d", current->pid);
+		if (IS_ERR(task)) {
+			ret = PTR_ERR(task);
+			goto free_worker;
+		}
+		worker->task = task;
+		worker->fn->wakeup = vhost_kthread_wakeup_fn;
+		worker->fn->stop = vhost_kthread_stop_fn;
+
+		wake_up_process(task);
+		/* Attach to the vhost cgroup */
+		ret = vhost_attach_cgroups(dev);
+		if (ret)
+			goto stop_worker;
+	}
 
 	ret = xa_alloc(&dev->worker_xa, &id, worker, xa_limit_32b, GFP_KERNEL);
 	if (ret < 0)
 		goto stop_worker;
 	worker->id = id;
-
 	return worker;
-
 stop_worker:
-	vhost_task_stop(vtsk);
+	worker->fn->stop(dev->inherit_owner ? (void *)vtsk : (void *)task);
 free_worker:
+	kfree(worker->fn);
 	kfree(worker);
 	return NULL;
 }
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index c650c4506c70..ebababa4e340 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -25,8 +25,13 @@  struct vhost_work {
 	vhost_work_fn_t		fn;
 	unsigned long		flags;
 };
+struct vhost_task_fn {
+	int (*wakeup)(void *task);
+	int (*stop)(void *task);
+};
 
 struct vhost_worker {
+	struct task_struct	*task;
 	struct vhost_task	*vtsk;
 	struct vhost_dev	*dev;
 	/* Used to serialize device wide flushing with worker swapping. */
@@ -36,6 +41,7 @@  struct vhost_worker {
 	u32			id;
 	int			attachment_cnt;
 	bool			killed;
+	struct vhost_task_fn *fn;
 };
 
 /* Poll a file (eventfd or socket) */