
[v3,4/4] ksmbd: smbd: handle multiple Buffer descriptors

Message ID 20220418051412.13193-4-hyc.lee@gmail.com (mailing list archive)
State New, archived
Series [v3,1/4] ksmbd: smbd: change prototypes of RDMA read/write related functions

Commit Message

Hyunchul Lee April 18, 2022, 5:14 a.m. UTC
Make ksmbd handle multiple buffer descriptors
when reading and writing files using SMB direct:
Post the work requests of rdma_rw_ctx for
RDMA read/write in smb_direct_rdma_xmit(), and
the work request for the READ/WRITE response
with a remote invalidation in smb_direct_writev().

Signed-off-by: Hyunchul Lee <hyc.lee@gmail.com>
---
changes from v2:
 - Split the v2 patch into 4 patches.

 fs/ksmbd/smb2pdu.c        |   5 +-
 fs/ksmbd/transport_rdma.c | 166 +++++++++++++++++++++++++-------------
 2 files changed, 109 insertions(+), 62 deletions(-)
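
In outline, the new smb_direct_rdma_xmit() builds one rdma_rw_ctx per buffer descriptor and then chains the work requests of all contexts into a single ib_post_send(). A simplified sketch of that chaining step (identifiers follow the patch; setup and error handling omitted):

	struct ib_send_wr *first_wr = NULL;

	/*
	 * Walk the per-descriptor contexts in reverse so that, when posted
	 * as a single chain, the work requests run in descriptor order.
	 */
	list_for_each_entry_reverse(msg, &msg_list, list)
		first_wr = rdma_rw_ctx_wrs(&msg->rw_ctx, t->qp, t->qp->port,
					   &msg->cqe, first_wr);

	ret = ib_post_send(t->qp, first_wr, NULL);

The patch waits on a single completion and takes the status from the last entry of msg_list, consistent with rdma_rw_ctx_wrs() signaling only the final work request when no chain_wr is passed.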

Comments

Namjae Jeon April 23, 2022, 2:38 p.m. UTC | #1
2022-04-18 14:14 GMT+09:00, Hyunchul Lee <hyc.lee@gmail.com>:
> Make ksmbd handle multiple buffer descriptors
> when reading and writing files using SMB direct:
> Post the work requests of rdma_rw_ctx for
> RDMA read/write in smb_direct_rdma_xmit(), and
> the work request for the READ/WRITE response
> with a remote invalidation in smb_direct_writev().
>
> Signed-off-by: Hyunchul Lee <hyc.lee@gmail.com>
> ---
> changes from v2:
>  - Split the v2 patch into 4 patches.
>
>  fs/ksmbd/smb2pdu.c        |   5 +-
>  fs/ksmbd/transport_rdma.c | 166 +++++++++++++++++++++++++-------------
>  2 files changed, 109 insertions(+), 62 deletions(-)
>
> diff --git a/fs/ksmbd/smb2pdu.c b/fs/ksmbd/smb2pdu.c
> index fc9b8def50df..621fa3e55fab 100644
> --- a/fs/ksmbd/smb2pdu.c
> +++ b/fs/ksmbd/smb2pdu.c
> @@ -6133,11 +6133,8 @@ static int smb2_set_remote_key_for_rdma(struct
> ksmbd_work *work,
>  				le32_to_cpu(desc[i].length));
>  		}
>  	}
> -	if (ch_count != 1) {
> -		ksmbd_debug(RDMA, "RDMA multiple buffer descriptors %d are not supported
> yet\n",
> -			    ch_count);
> +	if (!ch_count)
>  		return -EINVAL;
> -	}
>
>  	work->need_invalidate_rkey =
>  		(Channel == SMB2_CHANNEL_RDMA_V1_INVALIDATE);
> diff --git a/fs/ksmbd/transport_rdma.c b/fs/ksmbd/transport_rdma.c
> index 1343ff8e00fd..410b79edc9f2 100644
> --- a/fs/ksmbd/transport_rdma.c
> +++ b/fs/ksmbd/transport_rdma.c
> @@ -208,7 +208,9 @@ struct smb_direct_recvmsg {
>  struct smb_direct_rdma_rw_msg {
>  	struct smb_direct_transport	*t;
>  	struct ib_cqe		cqe;
> +	int			status;
>  	struct completion	*completion;
> +	struct list_head	list;
>  	struct rdma_rw_ctx	rw_ctx;
>  	struct sg_table		sgt;
>  	struct scatterlist	sg_list[];
> @@ -1313,6 +1315,18 @@ static int smb_direct_writev(struct ksmbd_transport
> *t,
>  	return ret;
>  }
>
> +static void smb_direct_free_rdma_rw_msg(struct smb_direct_transport *t,
> +					struct smb_direct_rdma_rw_msg *msg,
> +					enum dma_data_direction dir)
> +{
> +	if (msg->sgt.orig_nents) {
Is there any case where orig_nents is 0?
> +		rdma_rw_ctx_destroy(&msg->rw_ctx, t->qp, t->qp->port,
> +				    msg->sgt.sgl, msg->sgt.nents, dir);
> +		sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE);
> +	}
> +	kfree(msg);
> +}
> +
>  static void read_write_done(struct ib_cq *cq, struct ib_wc *wc,
>  			    enum dma_data_direction dir)
>  {
> @@ -1321,19 +1335,14 @@ static void read_write_done(struct ib_cq *cq, struct
> ib_wc *wc,
>  	struct smb_direct_transport *t = msg->t;
>
>  	if (wc->status != IB_WC_SUCCESS) {
> +		msg->status = -EIO;
>  		pr_err("read/write error. opcode = %d, status = %s(%d)\n",
>  		       wc->opcode, ib_wc_status_msg(wc->status), wc->status);
> -		smb_direct_disconnect_rdma_connection(t);
> +		if (wc->status != IB_WC_WR_FLUSH_ERR)
Why is this condition needed?
> +			smb_direct_disconnect_rdma_connection(t);
>  	}
>
> -	if (atomic_inc_return(&t->rw_credits) > 0)
> -		wake_up(&t->wait_rw_credits);
> -
> -	rdma_rw_ctx_destroy(&msg->rw_ctx, t->qp, t->qp->port,
> -			    msg->sg_list, msg->sgt.nents, dir);
> -	sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE);
>  	complete(msg->completion);
> -	kfree(msg);
>  }
>
>  static void read_done(struct ib_cq *cq, struct ib_wc *wc)
> @@ -1352,75 +1361,116 @@ static int smb_direct_rdma_xmit(struct
> smb_direct_transport *t,
>  				unsigned int desc_len,
>  				bool is_read)
>  {
> -	struct smb_direct_rdma_rw_msg *msg;
> -	int ret;
> +	struct smb_direct_rdma_rw_msg *msg, *next_msg;
> +	int i, ret;
>  	DECLARE_COMPLETION_ONSTACK(completion);
> -	struct ib_send_wr *first_wr = NULL;
> -	u32 remote_key = le32_to_cpu(desc[0].token);
> -	u64 remote_offset = le64_to_cpu(desc[0].offset);
> +	struct ib_send_wr *first_wr;
> +	LIST_HEAD(msg_list);
> +	char *desc_buf;
>  	int credits_needed;
> +	unsigned int desc_buf_len;
> +	size_t total_length = 0;
> +
> +	if (t->status != SMB_DIRECT_CS_CONNECTED)
> +		return -ENOTCONN;
> +
> +	/* calculate needed credits */
> +	credits_needed = 0;
> +	desc_buf = buf;
> +	for (i = 0; i < desc_len / sizeof(*desc); i++) {
> +		desc_buf_len = le32_to_cpu(desc[i].length);
> +
> +		credits_needed += calc_rw_credits(t, desc_buf, desc_buf_len);
> +		desc_buf += desc_buf_len;
> +		total_length += desc_buf_len;
> +		if (desc_buf_len == 0 || total_length > buf_len ||
> +		    total_length > t->max_rdma_rw_size)
> +			return -EINVAL;
> +	}
> +
> +	ksmbd_debug(RDMA, "RDMA %s, len %#x, needed credits %#x\n",
> +		    is_read ? "read" : "write", buf_len, credits_needed);
>
> -	credits_needed = calc_rw_credits(t, buf, buf_len);
>  	ret = wait_for_rw_credits(t, credits_needed);
>  	if (ret < 0)
>  		return ret;
>
> -	/* TODO: mempool */
> -	msg = kmalloc(offsetof(struct smb_direct_rdma_rw_msg, sg_list) +
> -		      sizeof(struct scatterlist) * SG_CHUNK_SIZE, GFP_KERNEL);
> -	if (!msg) {
> -		atomic_add(credits_needed, &t->rw_credits);
> -		return -ENOMEM;
> -	}
> +	/* build rdma_rw_ctx for each descriptor */
> +	desc_buf = buf;
> +	for (i = 0; i < desc_len / sizeof(*desc); i++) {
> +		msg = kzalloc(offsetof(struct smb_direct_rdma_rw_msg, sg_list) +
> +			      sizeof(struct scatterlist) * SG_CHUNK_SIZE, GFP_KERNEL);
> +		if (!msg) {
> +			ret = -ENOMEM;
> +			goto out;
> +		}
>
> -	msg->sgt.sgl = &msg->sg_list[0];
> -	ret = sg_alloc_table_chained(&msg->sgt,
> -				     get_buf_page_count(buf, buf_len),
> -				     msg->sg_list, SG_CHUNK_SIZE);
> -	if (ret) {
> -		atomic_add(credits_needed, &t->rw_credits);
> -		kfree(msg);
> -		return -ENOMEM;
> -	}
> +		desc_buf_len = le32_to_cpu(desc[i].length);
>
> -	ret = get_sg_list(buf, buf_len, msg->sgt.sgl, msg->sgt.orig_nents);
> -	if (ret <= 0) {
> -		pr_err("failed to get pages\n");
> -		goto err;
> -	}
> +		msg->t = t;
> +		msg->cqe.done = is_read ? read_done : write_done;
> +		msg->completion = &completion;
>
> -	ret = rdma_rw_ctx_init(&msg->rw_ctx, t->qp, t->qp->port,
> -			       msg->sg_list, get_buf_page_count(buf, buf_len),
> -			       0, remote_offset, remote_key,
> -			       is_read ? DMA_FROM_DEVICE : DMA_TO_DEVICE);
> -	if (ret < 0) {
> -		pr_err("failed to init rdma_rw_ctx: %d\n", ret);
> -		goto err;
> +		msg->sgt.sgl = &msg->sg_list[0];
> +		ret = sg_alloc_table_chained(&msg->sgt,
> +					     get_buf_page_count(desc_buf, desc_buf_len),
> +					     msg->sg_list, SG_CHUNK_SIZE);
> +		if (ret) {
> +			kfree(msg);
> +			ret = -ENOMEM;
> +			goto out;
> +		}
> +
> +		ret = get_sg_list(desc_buf, desc_buf_len,
> +				  msg->sgt.sgl, msg->sgt.orig_nents);
> +		if (ret <= 0) {
Is there any problem if this function returns 0?
> +			sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE);
> +			kfree(msg);
> +			goto out;
> +		}
> +
> +		ret = rdma_rw_ctx_init(&msg->rw_ctx, t->qp, t->qp->port,
> +				       msg->sgt.sgl,
> +				       get_buf_page_count(desc_buf, desc_buf_len),
> +				       0,
> +				       le64_to_cpu(desc[i].offset),
> +				       le32_to_cpu(desc[i].token),
> +				       is_read ? DMA_FROM_DEVICE : DMA_TO_DEVICE);
> +		if (ret < 0) {
> +			pr_err("failed to init rdma_rw_ctx: %d\n", ret);
> +			sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE);
> +			kfree(msg);
> +			goto out;
> +		}
> +
> +		list_add_tail(&msg->list, &msg_list);
> +		desc_buf += desc_buf_len;
>  	}
>
> -	msg->t = t;
> -	msg->cqe.done = is_read ? read_done : write_done;
> -	msg->completion = &completion;
> -	first_wr = rdma_rw_ctx_wrs(&msg->rw_ctx, t->qp, t->qp->port,
> -				   &msg->cqe, NULL);
> +	/* concatenate work requests of rdma_rw_ctxs */
> +	first_wr = NULL;
> +	list_for_each_entry_reverse(msg, &msg_list, list) {
> +		first_wr = rdma_rw_ctx_wrs(&msg->rw_ctx, t->qp, t->qp->port,
> +					   &msg->cqe, first_wr);
> +	}
>
>  	ret = ib_post_send(t->qp, first_wr, NULL);
>  	if (ret) {
> -		pr_err("failed to post send wr: %d\n", ret);
> -		goto err;
> +		pr_err("failed to post send wr for RDMA R/W: %d\n", ret);
> +		goto out;
>  	}
>
> +	msg = list_last_entry(&msg_list, struct smb_direct_rdma_rw_msg, list);
>  	wait_for_completion(&completion);
> -	return 0;
> -
> -err:
> +	ret = msg->status;
> +out:
> +	list_for_each_entry_safe(msg, next_msg, &msg_list, list) {
> +		list_del(&msg->list);
> +		smb_direct_free_rdma_rw_msg(t, msg,
> +					    is_read ? DMA_FROM_DEVICE : DMA_TO_DEVICE);
> +	}
>  	atomic_add(credits_needed, &t->rw_credits);
> -	if (first_wr)
> -		rdma_rw_ctx_destroy(&msg->rw_ctx, t->qp, t->qp->port,
> -				    msg->sg_list, msg->sgt.nents,
> -				    is_read ? DMA_FROM_DEVICE : DMA_TO_DEVICE);
> -	sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE);
> -	kfree(msg);
> +	wake_up(&t->wait_rw_credits);
>  	return ret;
>  }
>
> --
> 2.25.1
>
>
Hyunchul Lee April 25, 2022, 6:27 a.m. UTC | #2
On Sat, Apr 23, 2022 at 11:38 PM, Namjae Jeon <linkinjeon@kernel.org> wrote:
>
> 2022-04-18 14:14 GMT+09:00, Hyunchul Lee <hyc.lee@gmail.com>:
> > Make ksmbd handle multiple buffer descriptors
> > when reading and writing files using SMB direct:
> > Post the work requests of rdma_rw_ctx for
> > RDMA read/write in smb_direct_rdma_xmit(), and
> > the work request for the READ/WRITE response
> > with a remote invalidation in smb_direct_writev().
> >
> > Signed-off-by: Hyunchul Lee <hyc.lee@gmail.com>
> > ---
> > changes from v2:
> >  - Split the v2 patch into 4 patches.
> >
> >  fs/ksmbd/smb2pdu.c        |   5 +-
> >  fs/ksmbd/transport_rdma.c | 166 +++++++++++++++++++++++++-------------
> >  2 files changed, 109 insertions(+), 62 deletions(-)
> >
> > diff --git a/fs/ksmbd/smb2pdu.c b/fs/ksmbd/smb2pdu.c
> > index fc9b8def50df..621fa3e55fab 100644
> > --- a/fs/ksmbd/smb2pdu.c
> > +++ b/fs/ksmbd/smb2pdu.c
> > @@ -6133,11 +6133,8 @@ static int smb2_set_remote_key_for_rdma(struct
> > ksmbd_work *work,
> >                               le32_to_cpu(desc[i].length));
> >               }
> >       }
> > -     if (ch_count != 1) {
> > -             ksmbd_debug(RDMA, "RDMA multiple buffer descriptors %d are not supported
> > yet\n",
> > -                         ch_count);
> > +     if (!ch_count)
> >               return -EINVAL;
> > -     }
> >
> >       work->need_invalidate_rkey =
> >               (Channel == SMB2_CHANNEL_RDMA_V1_INVALIDATE);
> > diff --git a/fs/ksmbd/transport_rdma.c b/fs/ksmbd/transport_rdma.c
> > index 1343ff8e00fd..410b79edc9f2 100644
> > --- a/fs/ksmbd/transport_rdma.c
> > +++ b/fs/ksmbd/transport_rdma.c
> > @@ -208,7 +208,9 @@ struct smb_direct_recvmsg {
> >  struct smb_direct_rdma_rw_msg {
> >       struct smb_direct_transport     *t;
> >       struct ib_cqe           cqe;
> > +     int                     status;
> >       struct completion       *completion;
> > +     struct list_head        list;
> >       struct rdma_rw_ctx      rw_ctx;
> >       struct sg_table         sgt;
> >       struct scatterlist      sg_list[];
> > @@ -1313,6 +1315,18 @@ static int smb_direct_writev(struct ksmbd_transport
> > *t,
> >       return ret;
> >  }
> >
> > +static void smb_direct_free_rdma_rw_msg(struct smb_direct_transport *t,
> > +                                     struct smb_direct_rdma_rw_msg *msg,
> > +                                     enum dma_data_direction dir)
> > +{
> > +     if (msg->sgt.orig_nents) {
> Is there any case where orig_nents is 0?

I will remove this condition.
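
For reference, a minimal sketch of the helper with that check dropped (same body as the hunk above, on the assumption that the message is only freed through this path once rdma_rw_ctx_init() has succeeded):

static void smb_direct_free_rdma_rw_msg(struct smb_direct_transport *t,
					struct smb_direct_rdma_rw_msg *msg,
					enum dma_data_direction dir)
{
	/*
	 * The rw_ctx and chained sg_table are always set up by the time the
	 * msg is queued on msg_list, so they can be torn down unconditionally.
	 */
	rdma_rw_ctx_destroy(&msg->rw_ctx, t->qp, t->qp->port,
			    msg->sgt.sgl, msg->sgt.nents, dir);
	sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE);
	kfree(msg);
}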

> > +             rdma_rw_ctx_destroy(&msg->rw_ctx, t->qp, t->qp->port,
> > +                                 msg->sgt.sgl, msg->sgt.nents, dir);
> > +             sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE);
> > +     }
> > +     kfree(msg);
> > +}
> > +
> >  static void read_write_done(struct ib_cq *cq, struct ib_wc *wc,
> >                           enum dma_data_direction dir)
> >  {
> > @@ -1321,19 +1335,14 @@ static void read_write_done(struct ib_cq *cq, struct
> > ib_wc *wc,
> >       struct smb_direct_transport *t = msg->t;
> >
> >       if (wc->status != IB_WC_SUCCESS) {
> > +             msg->status = -EIO;
> >               pr_err("read/write error. opcode = %d, status = %s(%d)\n",
> >                      wc->opcode, ib_wc_status_msg(wc->status), wc->status);
> > -             smb_direct_disconnect_rdma_connection(t);
> > +             if (wc->status != IB_WC_WR_FLUSH_ERR)
> Why is this condition needed?

IB_WC_WR_FLUSH_ERR is reported after the RDMA connection has already been
disconnected, so we don't need to try to disconnect it again.
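
A minimal sketch of the resulting check in the completion handler, with that reasoning spelled out as a comment (this mirrors the hunk above rather than adding anything new):

	if (wc->status != IB_WC_SUCCESS) {
		msg->status = -EIO;
		pr_err("read/write error. opcode = %d, status = %s(%d)\n",
		       wc->opcode, ib_wc_status_msg(wc->status), wc->status);
		/*
		 * IB_WC_WR_FLUSH_ERR means the work request was flushed
		 * because the QP already left the ready state, i.e. the
		 * connection is already being torn down, so only trigger a
		 * disconnect for other failure statuses.
		 */
		if (wc->status != IB_WC_WR_FLUSH_ERR)
			smb_direct_disconnect_rdma_connection(t);
	}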

> > +                     smb_direct_disconnect_rdma_connection(t);
> >       }
> >
> > -     if (atomic_inc_return(&t->rw_credits) > 0)
> > -             wake_up(&t->wait_rw_credits);
> > -
> > -     rdma_rw_ctx_destroy(&msg->rw_ctx, t->qp, t->qp->port,
> > -                         msg->sg_list, msg->sgt.nents, dir);
> > -     sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE);
> >       complete(msg->completion);
> > -     kfree(msg);
> >  }
> >
> >  static void read_done(struct ib_cq *cq, struct ib_wc *wc)
> > @@ -1352,75 +1361,116 @@ static int smb_direct_rdma_xmit(struct
> > smb_direct_transport *t,
> >                               unsigned int desc_len,
> >                               bool is_read)
> >  {
> > -     struct smb_direct_rdma_rw_msg *msg;
> > -     int ret;
> > +     struct smb_direct_rdma_rw_msg *msg, *next_msg;
> > +     int i, ret;
> >       DECLARE_COMPLETION_ONSTACK(completion);
> > -     struct ib_send_wr *first_wr = NULL;
> > -     u32 remote_key = le32_to_cpu(desc[0].token);
> > -     u64 remote_offset = le64_to_cpu(desc[0].offset);
> > +     struct ib_send_wr *first_wr;
> > +     LIST_HEAD(msg_list);
> > +     char *desc_buf;
> >       int credits_needed;
> > +     unsigned int desc_buf_len;
> > +     size_t total_length = 0;
> > +
> > +     if (t->status != SMB_DIRECT_CS_CONNECTED)
> > +             return -ENOTCONN;
> > +
> > +     /* calculate needed credits */
> > +     credits_needed = 0;
> > +     desc_buf = buf;
> > +     for (i = 0; i < desc_len / sizeof(*desc); i++) {
> > +             desc_buf_len = le32_to_cpu(desc[i].length);
> > +
> > +             credits_needed += calc_rw_credits(t, desc_buf, desc_buf_len);
> > +             desc_buf += desc_buf_len;
> > +             total_length += desc_buf_len;
> > +             if (desc_buf_len == 0 || total_length > buf_len ||
> > +                 total_length > t->max_rdma_rw_size)
> > +                     return -EINVAL;
> > +     }
> > +
> > +     ksmbd_debug(RDMA, "RDMA %s, len %#x, needed credits %#x\n",
> > +                 is_read ? "read" : "write", buf_len, credits_needed);
> >
> > -     credits_needed = calc_rw_credits(t, buf, buf_len);
> >       ret = wait_for_rw_credits(t, credits_needed);
> >       if (ret < 0)
> >               return ret;
> >
> > -     /* TODO: mempool */
> > -     msg = kmalloc(offsetof(struct smb_direct_rdma_rw_msg, sg_list) +
> > -                   sizeof(struct scatterlist) * SG_CHUNK_SIZE, GFP_KERNEL);
> > -     if (!msg) {
> > -             atomic_add(credits_needed, &t->rw_credits);
> > -             return -ENOMEM;
> > -     }
> > +     /* build rdma_rw_ctx for each descriptor */
> > +     desc_buf = buf;
> > +     for (i = 0; i < desc_len / sizeof(*desc); i++) {
> > +             msg = kzalloc(offsetof(struct smb_direct_rdma_rw_msg, sg_list) +
> > +                           sizeof(struct scatterlist) * SG_CHUNK_SIZE, GFP_KERNEL);
> > +             if (!msg) {
> > +                     ret = -ENOMEM;
> > +                     goto out;
> > +             }
> >
> > -     msg->sgt.sgl = &msg->sg_list[0];
> > -     ret = sg_alloc_table_chained(&msg->sgt,
> > -                                  get_buf_page_count(buf, buf_len),
> > -                                  msg->sg_list, SG_CHUNK_SIZE);
> > -     if (ret) {
> > -             atomic_add(credits_needed, &t->rw_credits);
> > -             kfree(msg);
> > -             return -ENOMEM;
> > -     }
> > +             desc_buf_len = le32_to_cpu(desc[i].length);
> >
> > -     ret = get_sg_list(buf, buf_len, msg->sgt.sgl, msg->sgt.orig_nents);
> > -     if (ret <= 0) {
> > -             pr_err("failed to get pages\n");
> > -             goto err;
> > -     }
> > +             msg->t = t;
> > +             msg->cqe.done = is_read ? read_done : write_done;
> > +             msg->completion = &completion;
> >
> > -     ret = rdma_rw_ctx_init(&msg->rw_ctx, t->qp, t->qp->port,
> > -                            msg->sg_list, get_buf_page_count(buf, buf_len),
> > -                            0, remote_offset, remote_key,
> > -                            is_read ? DMA_FROM_DEVICE : DMA_TO_DEVICE);
> > -     if (ret < 0) {
> > -             pr_err("failed to init rdma_rw_ctx: %d\n", ret);
> > -             goto err;
> > +             msg->sgt.sgl = &msg->sg_list[0];
> > +             ret = sg_alloc_table_chained(&msg->sgt,
> > +                                          get_buf_page_count(desc_buf, desc_buf_len),
> > +                                          msg->sg_list, SG_CHUNK_SIZE);
> > +             if (ret) {
> > +                     kfree(msg);
> > +                     ret = -ENOMEM;
> > +                     goto out;
> > +             }
> > +
> > +             ret = get_sg_list(desc_buf, desc_buf_len,
> > +                               msg->sgt.sgl, msg->sgt.orig_nents);
> > +             if (ret <= 0) {
> Is there any problem if this function returns 0?

Yes, 0 means mapping the scatterlist failed. I will change the function
to return an error code in that case.
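
Until get_sg_list() is reworked, one way to avoid returning 0 (success) from smb_direct_rdma_xmit() in that case would be to map it to an error at the call site, e.g. (sketch only, not part of this patch):

		ret = get_sg_list(desc_buf, desc_buf_len,
				  msg->sgt.sgl, msg->sgt.orig_nents);
		if (ret <= 0) {
			/*
			 * 0 means nothing could be mapped into the
			 * scatterlist; don't let it propagate as success.
			 */
			if (ret == 0)
				ret = -EINVAL;
			sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE);
			kfree(msg);
			goto out;
		}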

> > +                     sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE);
> > +                     kfree(msg);
> > +                     goto out;
> > +             }
> > +
> > +             ret = rdma_rw_ctx_init(&msg->rw_ctx, t->qp, t->qp->port,
> > +                                    msg->sgt.sgl,
> > +                                    get_buf_page_count(desc_buf, desc_buf_len),
> > +                                    0,
> > +                                    le64_to_cpu(desc[i].offset),
> > +                                    le32_to_cpu(desc[i].token),
> > +                                    is_read ? DMA_FROM_DEVICE : DMA_TO_DEVICE);
> > +             if (ret < 0) {
> > +                     pr_err("failed to init rdma_rw_ctx: %d\n", ret);
> > +                     sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE);
> > +                     kfree(msg);
> > +                     goto out;
> > +             }
> > +
> > +             list_add_tail(&msg->list, &msg_list);
> > +             desc_buf += desc_buf_len;
> >       }
> >
> > -     msg->t = t;
> > -     msg->cqe.done = is_read ? read_done : write_done;
> > -     msg->completion = &completion;
> > -     first_wr = rdma_rw_ctx_wrs(&msg->rw_ctx, t->qp, t->qp->port,
> > -                                &msg->cqe, NULL);
> > +     /* concatenate work requests of rdma_rw_ctxs */
> > +     first_wr = NULL;
> > +     list_for_each_entry_reverse(msg, &msg_list, list) {
> > +             first_wr = rdma_rw_ctx_wrs(&msg->rw_ctx, t->qp, t->qp->port,
> > +                                        &msg->cqe, first_wr);
> > +     }
> >
> >       ret = ib_post_send(t->qp, first_wr, NULL);
> >       if (ret) {
> > -             pr_err("failed to post send wr: %d\n", ret);
> > -             goto err;
> > +             pr_err("failed to post send wr for RDMA R/W: %d\n", ret);
> > +             goto out;
> >       }
> >
> > +     msg = list_last_entry(&msg_list, struct smb_direct_rdma_rw_msg, list);
> >       wait_for_completion(&completion);
> > -     return 0;
> > -
> > -err:
> > +     ret = msg->status;
> > +out:
> > +     list_for_each_entry_safe(msg, next_msg, &msg_list, list) {
> > +             list_del(&msg->list);
> > +             smb_direct_free_rdma_rw_msg(t, msg,
> > +                                         is_read ? DMA_FROM_DEVICE : DMA_TO_DEVICE);
> > +     }
> >       atomic_add(credits_needed, &t->rw_credits);
> > -     if (first_wr)
> > -             rdma_rw_ctx_destroy(&msg->rw_ctx, t->qp, t->qp->port,
> > -                                 msg->sg_list, msg->sgt.nents,
> > -                                 is_read ? DMA_FROM_DEVICE : DMA_TO_DEVICE);
> > -     sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE);
> > -     kfree(msg);
> > +     wake_up(&t->wait_rw_credits);
> >       return ret;
> >  }
> >
> > --
> > 2.25.1
> >
> >

Patch

diff --git a/fs/ksmbd/smb2pdu.c b/fs/ksmbd/smb2pdu.c
index fc9b8def50df..621fa3e55fab 100644
--- a/fs/ksmbd/smb2pdu.c
+++ b/fs/ksmbd/smb2pdu.c
@@ -6133,11 +6133,8 @@  static int smb2_set_remote_key_for_rdma(struct ksmbd_work *work,
 				le32_to_cpu(desc[i].length));
 		}
 	}
-	if (ch_count != 1) {
-		ksmbd_debug(RDMA, "RDMA multiple buffer descriptors %d are not supported yet\n",
-			    ch_count);
+	if (!ch_count)
 		return -EINVAL;
-	}
 
 	work->need_invalidate_rkey =
 		(Channel == SMB2_CHANNEL_RDMA_V1_INVALIDATE);
diff --git a/fs/ksmbd/transport_rdma.c b/fs/ksmbd/transport_rdma.c
index 1343ff8e00fd..410b79edc9f2 100644
--- a/fs/ksmbd/transport_rdma.c
+++ b/fs/ksmbd/transport_rdma.c
@@ -208,7 +208,9 @@  struct smb_direct_recvmsg {
 struct smb_direct_rdma_rw_msg {
 	struct smb_direct_transport	*t;
 	struct ib_cqe		cqe;
+	int			status;
 	struct completion	*completion;
+	struct list_head	list;
 	struct rdma_rw_ctx	rw_ctx;
 	struct sg_table		sgt;
 	struct scatterlist	sg_list[];
@@ -1313,6 +1315,18 @@  static int smb_direct_writev(struct ksmbd_transport *t,
 	return ret;
 }
 
+static void smb_direct_free_rdma_rw_msg(struct smb_direct_transport *t,
+					struct smb_direct_rdma_rw_msg *msg,
+					enum dma_data_direction dir)
+{
+	if (msg->sgt.orig_nents) {
+		rdma_rw_ctx_destroy(&msg->rw_ctx, t->qp, t->qp->port,
+				    msg->sgt.sgl, msg->sgt.nents, dir);
+		sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE);
+	}
+	kfree(msg);
+}
+
 static void read_write_done(struct ib_cq *cq, struct ib_wc *wc,
 			    enum dma_data_direction dir)
 {
@@ -1321,19 +1335,14 @@  static void read_write_done(struct ib_cq *cq, struct ib_wc *wc,
 	struct smb_direct_transport *t = msg->t;
 
 	if (wc->status != IB_WC_SUCCESS) {
+		msg->status = -EIO;
 		pr_err("read/write error. opcode = %d, status = %s(%d)\n",
 		       wc->opcode, ib_wc_status_msg(wc->status), wc->status);
-		smb_direct_disconnect_rdma_connection(t);
+		if (wc->status != IB_WC_WR_FLUSH_ERR)
+			smb_direct_disconnect_rdma_connection(t);
 	}
 
-	if (atomic_inc_return(&t->rw_credits) > 0)
-		wake_up(&t->wait_rw_credits);
-
-	rdma_rw_ctx_destroy(&msg->rw_ctx, t->qp, t->qp->port,
-			    msg->sg_list, msg->sgt.nents, dir);
-	sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE);
 	complete(msg->completion);
-	kfree(msg);
 }
 
 static void read_done(struct ib_cq *cq, struct ib_wc *wc)
@@ -1352,75 +1361,116 @@  static int smb_direct_rdma_xmit(struct smb_direct_transport *t,
 				unsigned int desc_len,
 				bool is_read)
 {
-	struct smb_direct_rdma_rw_msg *msg;
-	int ret;
+	struct smb_direct_rdma_rw_msg *msg, *next_msg;
+	int i, ret;
 	DECLARE_COMPLETION_ONSTACK(completion);
-	struct ib_send_wr *first_wr = NULL;
-	u32 remote_key = le32_to_cpu(desc[0].token);
-	u64 remote_offset = le64_to_cpu(desc[0].offset);
+	struct ib_send_wr *first_wr;
+	LIST_HEAD(msg_list);
+	char *desc_buf;
 	int credits_needed;
+	unsigned int desc_buf_len;
+	size_t total_length = 0;
+
+	if (t->status != SMB_DIRECT_CS_CONNECTED)
+		return -ENOTCONN;
+
+	/* calculate needed credits */
+	credits_needed = 0;
+	desc_buf = buf;
+	for (i = 0; i < desc_len / sizeof(*desc); i++) {
+		desc_buf_len = le32_to_cpu(desc[i].length);
+
+		credits_needed += calc_rw_credits(t, desc_buf, desc_buf_len);
+		desc_buf += desc_buf_len;
+		total_length += desc_buf_len;
+		if (desc_buf_len == 0 || total_length > buf_len ||
+		    total_length > t->max_rdma_rw_size)
+			return -EINVAL;
+	}
+
+	ksmbd_debug(RDMA, "RDMA %s, len %#x, needed credits %#x\n",
+		    is_read ? "read" : "write", buf_len, credits_needed);
 
-	credits_needed = calc_rw_credits(t, buf, buf_len);
 	ret = wait_for_rw_credits(t, credits_needed);
 	if (ret < 0)
 		return ret;
 
-	/* TODO: mempool */
-	msg = kmalloc(offsetof(struct smb_direct_rdma_rw_msg, sg_list) +
-		      sizeof(struct scatterlist) * SG_CHUNK_SIZE, GFP_KERNEL);
-	if (!msg) {
-		atomic_add(credits_needed, &t->rw_credits);
-		return -ENOMEM;
-	}
+	/* build rdma_rw_ctx for each descriptor */
+	desc_buf = buf;
+	for (i = 0; i < desc_len / sizeof(*desc); i++) {
+		msg = kzalloc(offsetof(struct smb_direct_rdma_rw_msg, sg_list) +
+			      sizeof(struct scatterlist) * SG_CHUNK_SIZE, GFP_KERNEL);
+		if (!msg) {
+			ret = -ENOMEM;
+			goto out;
+		}
 
-	msg->sgt.sgl = &msg->sg_list[0];
-	ret = sg_alloc_table_chained(&msg->sgt,
-				     get_buf_page_count(buf, buf_len),
-				     msg->sg_list, SG_CHUNK_SIZE);
-	if (ret) {
-		atomic_add(credits_needed, &t->rw_credits);
-		kfree(msg);
-		return -ENOMEM;
-	}
+		desc_buf_len = le32_to_cpu(desc[i].length);
 
-	ret = get_sg_list(buf, buf_len, msg->sgt.sgl, msg->sgt.orig_nents);
-	if (ret <= 0) {
-		pr_err("failed to get pages\n");
-		goto err;
-	}
+		msg->t = t;
+		msg->cqe.done = is_read ? read_done : write_done;
+		msg->completion = &completion;
 
-	ret = rdma_rw_ctx_init(&msg->rw_ctx, t->qp, t->qp->port,
-			       msg->sg_list, get_buf_page_count(buf, buf_len),
-			       0, remote_offset, remote_key,
-			       is_read ? DMA_FROM_DEVICE : DMA_TO_DEVICE);
-	if (ret < 0) {
-		pr_err("failed to init rdma_rw_ctx: %d\n", ret);
-		goto err;
+		msg->sgt.sgl = &msg->sg_list[0];
+		ret = sg_alloc_table_chained(&msg->sgt,
+					     get_buf_page_count(desc_buf, desc_buf_len),
+					     msg->sg_list, SG_CHUNK_SIZE);
+		if (ret) {
+			kfree(msg);
+			ret = -ENOMEM;
+			goto out;
+		}
+
+		ret = get_sg_list(desc_buf, desc_buf_len,
+				  msg->sgt.sgl, msg->sgt.orig_nents);
+		if (ret <= 0) {
+			sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE);
+			kfree(msg);
+			goto out;
+		}
+
+		ret = rdma_rw_ctx_init(&msg->rw_ctx, t->qp, t->qp->port,
+				       msg->sgt.sgl,
+				       get_buf_page_count(desc_buf, desc_buf_len),
+				       0,
+				       le64_to_cpu(desc[i].offset),
+				       le32_to_cpu(desc[i].token),
+				       is_read ? DMA_FROM_DEVICE : DMA_TO_DEVICE);
+		if (ret < 0) {
+			pr_err("failed to init rdma_rw_ctx: %d\n", ret);
+			sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE);
+			kfree(msg);
+			goto out;
+		}
+
+		list_add_tail(&msg->list, &msg_list);
+		desc_buf += desc_buf_len;
 	}
 
-	msg->t = t;
-	msg->cqe.done = is_read ? read_done : write_done;
-	msg->completion = &completion;
-	first_wr = rdma_rw_ctx_wrs(&msg->rw_ctx, t->qp, t->qp->port,
-				   &msg->cqe, NULL);
+	/* concatenate work requests of rdma_rw_ctxs */
+	first_wr = NULL;
+	list_for_each_entry_reverse(msg, &msg_list, list) {
+		first_wr = rdma_rw_ctx_wrs(&msg->rw_ctx, t->qp, t->qp->port,
+					   &msg->cqe, first_wr);
+	}
 
 	ret = ib_post_send(t->qp, first_wr, NULL);
 	if (ret) {
-		pr_err("failed to post send wr: %d\n", ret);
-		goto err;
+		pr_err("failed to post send wr for RDMA R/W: %d\n", ret);
+		goto out;
 	}
 
+	msg = list_last_entry(&msg_list, struct smb_direct_rdma_rw_msg, list);
 	wait_for_completion(&completion);
-	return 0;
-
-err:
+	ret = msg->status;
+out:
+	list_for_each_entry_safe(msg, next_msg, &msg_list, list) {
+		list_del(&msg->list);
+		smb_direct_free_rdma_rw_msg(t, msg,
+					    is_read ? DMA_FROM_DEVICE : DMA_TO_DEVICE);
+	}
 	atomic_add(credits_needed, &t->rw_credits);
-	if (first_wr)
-		rdma_rw_ctx_destroy(&msg->rw_ctx, t->qp, t->qp->port,
-				    msg->sg_list, msg->sgt.nents,
-				    is_read ? DMA_FROM_DEVICE : DMA_TO_DEVICE);
-	sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE);
-	kfree(msg);
+	wake_up(&t->wait_rw_credits);
 	return ret;
 }