diff mbox series

[3/3] ksmbd: fix wrong smbd max read/write size check

Message ID 20220516074140.28522-3-linkinjeon@kernel.org (mailing list archive)
State New, archived
Headers show
Series [1/3] ksmbd: handle smb2 query dir request for OutputBufferLength that is too small | expand

Commit Message

Namjae Jeon May 16, 2022, 7:41 a.m. UTC
The smb-direct max read/write size can differ from the smb2 max read/write
size, so smb2_read() and smb2_write() can return an error because of a wrong
max read/write size check. This patch uses smb_direct_max_read_write_size for
this check on the smb-direct (RDMA channel) read/write path.

Signed-off-by: Namjae Jeon <linkinjeon@kernel.org>
---
 fs/ksmbd/smb2pdu.c        | 39 +++++++++++++++++++++++++--------------
 fs/ksmbd/transport_rdma.c |  5 +++++
 fs/ksmbd/transport_rdma.h |  2 ++
 3 files changed, 32 insertions(+), 14 deletions(-)

Comments

Hyunchul Lee May 17, 2022, 8:07 a.m. UTC | #1
2022년 5월 16일 (월) 오후 4:42, Namjae Jeon <linkinjeon@kernel.org>님이 작성:
>
> smb-direct max read/write size can be different with smb2 max read/write
> size. So smb2_read() can return error by wrong max read/write size check.
> This patch use smb_direct_max_read_write_size for this check in
> smb-direct read/write().
>
> Signed-off-by: Namjae Jeon <linkinjeon@kernel.org>

Looks good to me.
Reviewed-by: Hyunchul Lee <hyc.lee@gmail.com>

> ---
>  fs/ksmbd/smb2pdu.c        | 39 +++++++++++++++++++++++++--------------
>  fs/ksmbd/transport_rdma.c |  5 +++++
>  fs/ksmbd/transport_rdma.h |  2 ++
>  3 files changed, 32 insertions(+), 14 deletions(-)
>
> diff --git a/fs/ksmbd/smb2pdu.c b/fs/ksmbd/smb2pdu.c
> index eb7ca5f24a3b..937f9760f181 100644
> --- a/fs/ksmbd/smb2pdu.c
> +++ b/fs/ksmbd/smb2pdu.c
> @@ -6098,6 +6098,8 @@ int smb2_read(struct ksmbd_work *work)
>         size_t length, mincount;
>         ssize_t nbytes = 0, remain_bytes = 0;
>         int err = 0;
> +       bool is_rdma_channel = false;
> +       unsigned int max_read_size = conn->vals->max_read_size;
>
>         WORK_BUFFERS(work, req, rsp);
>
> @@ -6109,6 +6111,11 @@ int smb2_read(struct ksmbd_work *work)
>
>         if (req->Channel == SMB2_CHANNEL_RDMA_V1_INVALIDATE ||
>             req->Channel == SMB2_CHANNEL_RDMA_V1) {
> +               is_rdma_channel = true;
> +               max_read_size = get_smbd_max_read_write_size();
> +       }
> +
> +       if (is_rdma_channel == true) {
>                 unsigned int ch_offset = le16_to_cpu(req->ReadChannelInfoOffset);
>
>                 if (ch_offset < offsetof(struct smb2_read_req, Buffer)) {
> @@ -6140,9 +6147,9 @@ int smb2_read(struct ksmbd_work *work)
>         length = le32_to_cpu(req->Length);
>         mincount = le32_to_cpu(req->MinimumCount);
>
> -       if (length > conn->vals->max_read_size) {
> +       if (length > max_read_size) {
>                 ksmbd_debug(SMB, "limiting read size to max size(%u)\n",
> -                           conn->vals->max_read_size);
> +                           max_read_size);
>                 err = -EINVAL;
>                 goto out;
>         }
> @@ -6174,8 +6181,7 @@ int smb2_read(struct ksmbd_work *work)
>         ksmbd_debug(SMB, "nbytes %zu, offset %lld mincount %zu\n",
>                     nbytes, offset, mincount);
>
> -       if (req->Channel == SMB2_CHANNEL_RDMA_V1_INVALIDATE ||
> -           req->Channel == SMB2_CHANNEL_RDMA_V1) {
> +       if (is_rdma_channel == true) {
>                 /* write data to the client using rdma channel */
>                 remain_bytes = smb2_read_rdma_channel(work, req,
>                                                       work->aux_payload_buf,
> @@ -6336,8 +6342,9 @@ int smb2_write(struct ksmbd_work *work)
>         size_t length;
>         ssize_t nbytes;
>         char *data_buf;
> -       bool writethrough = false;
> +       bool writethrough = false, is_rdma_channel = false;
>         int err = 0;
> +       unsigned int max_write_size = work->conn->vals->max_write_size;
>
>         WORK_BUFFERS(work, req, rsp);
>
> @@ -6346,8 +6353,17 @@ int smb2_write(struct ksmbd_work *work)
>                 return smb2_write_pipe(work);
>         }
>
> +       offset = le64_to_cpu(req->Offset);
> +       length = le32_to_cpu(req->Length);
> +
>         if (req->Channel == SMB2_CHANNEL_RDMA_V1 ||
>             req->Channel == SMB2_CHANNEL_RDMA_V1_INVALIDATE) {
> +               is_rdma_channel = true;
> +               max_write_size = get_smbd_max_read_write_size();
> +               length = le32_to_cpu(req->RemainingBytes);
> +       }
> +
> +       if (is_rdma_channel == true) {
>                 unsigned int ch_offset = le16_to_cpu(req->WriteChannelInfoOffset);
>
>                 if (req->Length != 0 || req->DataOffset != 0 ||
> @@ -6382,12 +6398,9 @@ int smb2_write(struct ksmbd_work *work)
>                 goto out;
>         }
>
> -       offset = le64_to_cpu(req->Offset);
> -       length = le32_to_cpu(req->Length);
> -
> -       if (length > work->conn->vals->max_write_size) {
> +       if (length > max_write_size) {
>                 ksmbd_debug(SMB, "limiting write size to max size(%u)\n",
> -                           work->conn->vals->max_write_size);
> +                           max_write_size);
>                 err = -EINVAL;
>                 goto out;
>         }
> @@ -6395,8 +6408,7 @@ int smb2_write(struct ksmbd_work *work)
>         if (le32_to_cpu(req->Flags) & SMB2_WRITEFLAG_WRITE_THROUGH)
>                 writethrough = true;
>
> -       if (req->Channel != SMB2_CHANNEL_RDMA_V1 &&
> -           req->Channel != SMB2_CHANNEL_RDMA_V1_INVALIDATE) {
> +       if (is_rdma_channel == false) {
>                 if ((u64)le16_to_cpu(req->DataOffset) + length >
>                     get_rfc1002_len(work->request_buf)) {
>                         pr_err("invalid write data offset %u, smb_len %u\n",
> @@ -6422,8 +6434,7 @@ int smb2_write(struct ksmbd_work *work)
>                 /* read data from the client using rdma channel, and
>                  * write the data.
>                  */
> -               nbytes = smb2_write_rdma_channel(work, req, fp, offset,
> -                                                le32_to_cpu(req->RemainingBytes),
> +               nbytes = smb2_write_rdma_channel(work, req, fp, offset, length,
>                                                  writethrough);
>                 if (nbytes < 0) {
>                         err = (int)nbytes;
> diff --git a/fs/ksmbd/transport_rdma.c b/fs/ksmbd/transport_rdma.c
> index 6d652ff38b82..0741fd129d16 100644
> --- a/fs/ksmbd/transport_rdma.c
> +++ b/fs/ksmbd/transport_rdma.c
> @@ -220,6 +220,11 @@ void init_smbd_max_io_size(unsigned int sz)
>         smb_direct_max_read_write_size = sz;
>  }
>
> +unsigned int get_smbd_max_read_write_size(void)
> +{
> +       return smb_direct_max_read_write_size;
> +}
> +
>  static inline int get_buf_page_count(void *buf, int size)
>  {
>         return DIV_ROUND_UP((uintptr_t)buf + size, PAGE_SIZE) -
> diff --git a/fs/ksmbd/transport_rdma.h b/fs/ksmbd/transport_rdma.h
> index e7b4e6790fab..77aee4e5c9dc 100644
> --- a/fs/ksmbd/transport_rdma.h
> +++ b/fs/ksmbd/transport_rdma.h
> @@ -57,11 +57,13 @@ int ksmbd_rdma_init(void);
>  void ksmbd_rdma_destroy(void);
>  bool ksmbd_rdma_capable_netdev(struct net_device *netdev);
>  void init_smbd_max_io_size(unsigned int sz);
> +unsigned int get_smbd_max_read_write_size(void);
>  #else
>  static inline int ksmbd_rdma_init(void) { return 0; }
>  static inline int ksmbd_rdma_destroy(void) { return 0; }
>  static inline bool ksmbd_rdma_capable_netdev(struct net_device *netdev) { return false; }
>  static inline void init_smbd_max_io_size(unsigned int sz) { }
> +static inline unsigned int get_smbd_max_read_write_size(void) { return 0; }
>  #endif
>
>  #endif /* __KSMBD_TRANSPORT_RDMA_H__ */
> --
> 2.25.1
>
diff mbox series

Patch

diff --git a/fs/ksmbd/smb2pdu.c b/fs/ksmbd/smb2pdu.c
index eb7ca5f24a3b..937f9760f181 100644
--- a/fs/ksmbd/smb2pdu.c
+++ b/fs/ksmbd/smb2pdu.c
@@ -6098,6 +6098,8 @@  int smb2_read(struct ksmbd_work *work)
 	size_t length, mincount;
 	ssize_t nbytes = 0, remain_bytes = 0;
 	int err = 0;
+	bool is_rdma_channel = false;
+	unsigned int max_read_size = conn->vals->max_read_size;
 
 	WORK_BUFFERS(work, req, rsp);
 
@@ -6109,6 +6111,11 @@  int smb2_read(struct ksmbd_work *work)
 
 	if (req->Channel == SMB2_CHANNEL_RDMA_V1_INVALIDATE ||
 	    req->Channel == SMB2_CHANNEL_RDMA_V1) {
+		is_rdma_channel = true;
+		max_read_size = get_smbd_max_read_write_size();
+	}
+
+	if (is_rdma_channel == true) {
 		unsigned int ch_offset = le16_to_cpu(req->ReadChannelInfoOffset);
 
 		if (ch_offset < offsetof(struct smb2_read_req, Buffer)) {
@@ -6140,9 +6147,9 @@  int smb2_read(struct ksmbd_work *work)
 	length = le32_to_cpu(req->Length);
 	mincount = le32_to_cpu(req->MinimumCount);
 
-	if (length > conn->vals->max_read_size) {
+	if (length > max_read_size) {
 		ksmbd_debug(SMB, "limiting read size to max size(%u)\n",
-			    conn->vals->max_read_size);
+			    max_read_size);
 		err = -EINVAL;
 		goto out;
 	}
@@ -6174,8 +6181,7 @@  int smb2_read(struct ksmbd_work *work)
 	ksmbd_debug(SMB, "nbytes %zu, offset %lld mincount %zu\n",
 		    nbytes, offset, mincount);
 
-	if (req->Channel == SMB2_CHANNEL_RDMA_V1_INVALIDATE ||
-	    req->Channel == SMB2_CHANNEL_RDMA_V1) {
+	if (is_rdma_channel == true) {
 		/* write data to the client using rdma channel */
 		remain_bytes = smb2_read_rdma_channel(work, req,
 						      work->aux_payload_buf,
@@ -6336,8 +6342,9 @@  int smb2_write(struct ksmbd_work *work)
 	size_t length;
 	ssize_t nbytes;
 	char *data_buf;
-	bool writethrough = false;
+	bool writethrough = false, is_rdma_channel = false;
 	int err = 0;
+	unsigned int max_write_size = work->conn->vals->max_write_size;
 
 	WORK_BUFFERS(work, req, rsp);
 
@@ -6346,8 +6353,17 @@  int smb2_write(struct ksmbd_work *work)
 		return smb2_write_pipe(work);
 	}
 
+	offset = le64_to_cpu(req->Offset);
+	length = le32_to_cpu(req->Length);
+
 	if (req->Channel == SMB2_CHANNEL_RDMA_V1 ||
 	    req->Channel == SMB2_CHANNEL_RDMA_V1_INVALIDATE) {
+		is_rdma_channel = true;
+		max_write_size = get_smbd_max_read_write_size();
+		length = le32_to_cpu(req->RemainingBytes);
+	}
+
+	if (is_rdma_channel == true) {
 		unsigned int ch_offset = le16_to_cpu(req->WriteChannelInfoOffset);
 
 		if (req->Length != 0 || req->DataOffset != 0 ||
@@ -6382,12 +6398,9 @@  int smb2_write(struct ksmbd_work *work)
 		goto out;
 	}
 
-	offset = le64_to_cpu(req->Offset);
-	length = le32_to_cpu(req->Length);
-
-	if (length > work->conn->vals->max_write_size) {
+	if (length > max_write_size) {
 		ksmbd_debug(SMB, "limiting write size to max size(%u)\n",
-			    work->conn->vals->max_write_size);
+			    max_write_size);
 		err = -EINVAL;
 		goto out;
 	}
@@ -6395,8 +6408,7 @@  int smb2_write(struct ksmbd_work *work)
 	if (le32_to_cpu(req->Flags) & SMB2_WRITEFLAG_WRITE_THROUGH)
 		writethrough = true;
 
-	if (req->Channel != SMB2_CHANNEL_RDMA_V1 &&
-	    req->Channel != SMB2_CHANNEL_RDMA_V1_INVALIDATE) {
+	if (is_rdma_channel == false) {
 		if ((u64)le16_to_cpu(req->DataOffset) + length >
 		    get_rfc1002_len(work->request_buf)) {
 			pr_err("invalid write data offset %u, smb_len %u\n",
@@ -6422,8 +6434,7 @@  int smb2_write(struct ksmbd_work *work)
 		/* read data from the client using rdma channel, and
 		 * write the data.
 		 */
-		nbytes = smb2_write_rdma_channel(work, req, fp, offset,
-						 le32_to_cpu(req->RemainingBytes),
+		nbytes = smb2_write_rdma_channel(work, req, fp, offset, length,
 						 writethrough);
 		if (nbytes < 0) {
 			err = (int)nbytes;
diff --git a/fs/ksmbd/transport_rdma.c b/fs/ksmbd/transport_rdma.c
index 6d652ff38b82..0741fd129d16 100644
--- a/fs/ksmbd/transport_rdma.c
+++ b/fs/ksmbd/transport_rdma.c
@@ -220,6 +220,11 @@  void init_smbd_max_io_size(unsigned int sz)
 	smb_direct_max_read_write_size = sz;
 }
 
+unsigned int get_smbd_max_read_write_size(void)
+{
+	return smb_direct_max_read_write_size;
+}
+
 static inline int get_buf_page_count(void *buf, int size)
 {
 	return DIV_ROUND_UP((uintptr_t)buf + size, PAGE_SIZE) -
diff --git a/fs/ksmbd/transport_rdma.h b/fs/ksmbd/transport_rdma.h
index e7b4e6790fab..77aee4e5c9dc 100644
--- a/fs/ksmbd/transport_rdma.h
+++ b/fs/ksmbd/transport_rdma.h
@@ -57,11 +57,13 @@  int ksmbd_rdma_init(void);
 void ksmbd_rdma_destroy(void);
 bool ksmbd_rdma_capable_netdev(struct net_device *netdev);
 void init_smbd_max_io_size(unsigned int sz);
+unsigned int get_smbd_max_read_write_size(void);
 #else
 static inline int ksmbd_rdma_init(void) { return 0; }
 static inline int ksmbd_rdma_destroy(void) { return 0; }
 static inline bool ksmbd_rdma_capable_netdev(struct net_device *netdev) { return false; }
 static inline void init_smbd_max_io_size(unsigned int sz) { }
+static inline unsigned int get_smbd_max_read_write_size(void) { return 0; }
 #endif
 
 #endif /* __KSMBD_TRANSPORT_RDMA_H__ */