diff mbox series

[1/9] verbs: Simplify query_device_ex

Message ID 1-v1-34e141ddf17e+89-query_device_ex_jgg@nvidia.com (mailing list archive)
State Not Applicable
Headers show
Series Simplify query_device() in libibverbs | expand

Commit Message

Jason Gunthorpe Nov. 16, 2020, 8:23 p.m. UTC
The obtuse logic here is hard to read, simplify it with a small macro and
add offsetofend()

Signed-off-by: Jason Gunthorpe <jgg@nvidia.com>
---
 libibverbs/cmd.c | 146 ++++++++++++++++-------------------------------
 util/util.h      |   3 +
 2 files changed, 52 insertions(+), 97 deletions(-)
diff mbox series

Patch

diff --git a/libibverbs/cmd.c b/libibverbs/cmd.c
index 25c8a971540c63..a439f8c06481dd 100644
--- a/libibverbs/cmd.c
+++ b/libibverbs/cmd.c
@@ -44,6 +44,7 @@ 
 #include <infiniband/cmd_write.h>
 #include "ibverbs.h"
 #include <ccan/minmax.h>
+#include <util/util.h>
 
 bool verbs_allow_disassociate_destroy;
 
@@ -144,117 +145,68 @@  int ibv_cmd_query_device_ex(struct ibv_context *context,
 	/* Report back supported comp_mask bits. For now no comp_mask bit is
 	 * defined */
 	attr->comp_mask = resp->comp_mask & 0;
-	if (attr_size >= offsetof(struct ibv_device_attr_ex, odp_caps) +
-			 sizeof(attr->odp_caps)) {
-		if (resp->response_length >=
-		    offsetof(struct ib_uverbs_ex_query_device_resp, odp_caps) +
-		    sizeof(resp->odp_caps)) {
-			attr->odp_caps.general_caps = resp->odp_caps.general_caps;
-			attr->odp_caps.per_transport_caps.rc_odp_caps =
-				resp->odp_caps.per_transport_caps.rc_odp_caps;
-			attr->odp_caps.per_transport_caps.uc_odp_caps =
-				resp->odp_caps.per_transport_caps.uc_odp_caps;
-			attr->odp_caps.per_transport_caps.ud_odp_caps =
-				resp->odp_caps.per_transport_caps.ud_odp_caps;
-		}
-	}
 
-	if (attr_size >= offsetof(struct ibv_device_attr_ex,
-				  completion_timestamp_mask) +
-			 sizeof(attr->completion_timestamp_mask)) {
-		if (resp->response_length >=
-		    offsetof(struct ib_uverbs_ex_query_device_resp, timestamp_mask) +
-		    sizeof(resp->timestamp_mask))
-			attr->completion_timestamp_mask = resp->timestamp_mask;
+#define CAN_COPY(_ibv_attr, _uverbs_attr)                                      \
+	(attr_size >= offsetofend(struct ibv_device_attr_ex, _ibv_attr) &&     \
+	 resp->response_length >=                                              \
+		 offsetofend(struct ib_uverbs_ex_query_device_resp,            \
+			     _uverbs_attr))
+
+	if (CAN_COPY(odp_caps, odp_caps)) {
+		attr->odp_caps.general_caps = resp->odp_caps.general_caps;
+		attr->odp_caps.per_transport_caps.rc_odp_caps =
+			resp->odp_caps.per_transport_caps.rc_odp_caps;
+		attr->odp_caps.per_transport_caps.uc_odp_caps =
+			resp->odp_caps.per_transport_caps.uc_odp_caps;
+		attr->odp_caps.per_transport_caps.ud_odp_caps =
+			resp->odp_caps.per_transport_caps.ud_odp_caps;
 	}
 
-	if (attr_size >= offsetof(struct ibv_device_attr_ex, hca_core_clock) +
-			 sizeof(attr->hca_core_clock)) {
-		if (resp->response_length >=
-		    offsetof(struct ib_uverbs_ex_query_device_resp, hca_core_clock) +
-		    sizeof(resp->hca_core_clock))
-			attr->hca_core_clock = resp->hca_core_clock;
-	}
+	if (CAN_COPY(completion_timestamp_mask, timestamp_mask))
+		attr->completion_timestamp_mask = resp->timestamp_mask;
 
-	if (attr_size >= offsetof(struct ibv_device_attr_ex, device_cap_flags_ex) +
-			 sizeof(attr->device_cap_flags_ex)) {
-		if (resp->response_length >=
-		    offsetof(struct ib_uverbs_ex_query_device_resp, device_cap_flags_ex) +
-		    sizeof(resp->device_cap_flags_ex))
-			attr->device_cap_flags_ex = resp->device_cap_flags_ex;
-	}
+	if (CAN_COPY(hca_core_clock, hca_core_clock))
+		attr->hca_core_clock = resp->hca_core_clock;
 
-	if (attr_size >= offsetof(struct ibv_device_attr_ex, rss_caps) +
-			 sizeof(attr->rss_caps)) {
-		if (resp->response_length >=
-		    offsetof(struct ib_uverbs_ex_query_device_resp, rss_caps) +
-		    sizeof(resp->rss_caps)) {
-			attr->rss_caps.supported_qpts = resp->rss_caps.supported_qpts;
-			attr->rss_caps.max_rwq_indirection_tables = resp->rss_caps.max_rwq_indirection_tables;
-			attr->rss_caps.max_rwq_indirection_table_size = resp->rss_caps.max_rwq_indirection_table_size;
-		}
-	}
+	if (CAN_COPY(device_cap_flags_ex, device_cap_flags_ex))
+		attr->device_cap_flags_ex = resp->device_cap_flags_ex;
 
-	if (attr_size >= offsetof(struct ibv_device_attr_ex, max_wq_type_rq) +
-			 sizeof(attr->max_wq_type_rq)) {
-		if (resp->response_length >=
-		    offsetof(struct ib_uverbs_ex_query_device_resp, max_wq_type_rq) +
-		    sizeof(resp->max_wq_type_rq))
-			attr->max_wq_type_rq = resp->max_wq_type_rq;
+	if (CAN_COPY(rss_caps, rss_caps)) {
+		attr->rss_caps.supported_qpts = resp->rss_caps.supported_qpts;
+		attr->rss_caps.max_rwq_indirection_tables =
+			resp->rss_caps.max_rwq_indirection_tables;
+		attr->rss_caps.max_rwq_indirection_table_size =
+			resp->rss_caps.max_rwq_indirection_table_size;
 	}
 
-	if (attr_size >= offsetof(struct ibv_device_attr_ex, raw_packet_caps) +
-			 sizeof(attr->raw_packet_caps)) {
-		if (resp->response_length >=
-		    offsetof(struct ib_uverbs_ex_query_device_resp, raw_packet_caps) +
-		    sizeof(resp->raw_packet_caps))
-			attr->raw_packet_caps = resp->raw_packet_caps;
-	}
+	if (CAN_COPY(max_wq_type_rq, max_wq_type_rq))
+		attr->max_wq_type_rq = resp->max_wq_type_rq;
 
-	if (attr_size >= offsetof(struct ibv_device_attr_ex, tm_caps) +
-			 sizeof(attr->tm_caps)) {
-		if (resp->response_length >=
-		    offsetof(struct ib_uverbs_ex_query_device_resp, tm_caps) +
-		    sizeof(resp->tm_caps)) {
-			attr->tm_caps.max_rndv_hdr_size =
-				resp->tm_caps.max_rndv_hdr_size;
-			attr->tm_caps.max_num_tags =
-				resp->tm_caps.max_num_tags;
-			attr->tm_caps.flags = resp->tm_caps.flags;
-			attr->tm_caps.max_ops =
-				resp->tm_caps.max_ops;
-			attr->tm_caps.max_sge =
-				resp->tm_caps.max_sge;
-		}
-	}
+	if (CAN_COPY(raw_packet_caps, raw_packet_caps))
+		attr->raw_packet_caps = resp->raw_packet_caps;
 
-	if (attr_size >= offsetof(struct ibv_device_attr_ex, cq_mod_caps) +
-			 sizeof(attr->cq_mod_caps)) {
-		if (resp->response_length >=
-		    offsetof(struct ib_uverbs_ex_query_device_resp, cq_moderation_caps) +
-		    sizeof(resp->cq_moderation_caps)) {
-			attr->cq_mod_caps.max_cq_count = resp->cq_moderation_caps.max_cq_moderation_count;
-			attr->cq_mod_caps.max_cq_period = resp->cq_moderation_caps.max_cq_moderation_period;
-		}
+	if (CAN_COPY(tm_caps, tm_caps)) {
+		attr->tm_caps.max_rndv_hdr_size =
+			resp->tm_caps.max_rndv_hdr_size;
+		attr->tm_caps.max_num_tags = resp->tm_caps.max_num_tags;
+		attr->tm_caps.flags = resp->tm_caps.flags;
+		attr->tm_caps.max_ops = resp->tm_caps.max_ops;
+		attr->tm_caps.max_sge = resp->tm_caps.max_sge;
 	}
 
-	if (attr_size >= offsetof(struct ibv_device_attr_ex, max_dm_size) +
-			sizeof(attr->max_dm_size)) {
-		if (resp->response_length >=
-		    offsetof(struct ib_uverbs_ex_query_device_resp, max_dm_size) +
-		    sizeof(resp->max_dm_size)) {
-			attr->max_dm_size = resp->max_dm_size;
-		}
+	if (CAN_COPY(cq_mod_caps, cq_moderation_caps)) {
+		attr->cq_mod_caps.max_cq_count =
+			resp->cq_moderation_caps.max_cq_moderation_count;
+		attr->cq_mod_caps.max_cq_period =
+			resp->cq_moderation_caps.max_cq_moderation_period;
 	}
 
-	if (attr_size >= offsetof(struct ibv_device_attr_ex, xrc_odp_caps) +
-			sizeof(attr->xrc_odp_caps)) {
-		if (resp->response_length >=
-		    offsetof(struct ib_uverbs_ex_query_device_resp, xrc_odp_caps) +
-		    sizeof(resp->xrc_odp_caps)) {
-			attr->xrc_odp_caps = resp->xrc_odp_caps;
-		}
-	}
+	if (CAN_COPY(max_dm_size, max_dm_size))
+		attr->max_dm_size = resp->max_dm_size;
+
+	if (CAN_COPY(xrc_odp_caps, xrc_odp_caps))
+		attr->xrc_odp_caps = resp->xrc_odp_caps;
+#undef CAN_COPY
 
 	return 0;
 }
diff --git a/util/util.h b/util/util.h
index 0f2c35cd0647ce..47346ca1bf5841 100644
--- a/util/util.h
+++ b/util/util.h
@@ -23,6 +23,9 @@  static inline bool __good_snprintf(size_t len, int rc)
 	 ((a)->tv_nsec CMP (b)->tv_nsec) :	\
 	 ((a)->tv_sec CMP (b)->tv_sec))
 
+#define offsetofend(_type, _member)                                            \
+	(offsetof(_type, _member) + sizeof(((_type *)0)->_member))
+
 static inline unsigned long align(unsigned long val, unsigned long align)
 {
 	return (val + align - 1) & ~(align - 1);