@@ -8,6 +8,9 @@ from string import ascii_lowercase as al
import unittest
import random
+from pyverbs.pyverbs_error import PyverbsError, PyverbsRDMAError
+from pyverbs.addr import AHAttr, AH, GlobalRoute
+from pyverbs.wr import SGE, SendWR, RecvWR
from pyverbs.qp import QPCap, QPInitAttrEx
import pyverbs.device as d
import pyverbs.enums as e
@@ -23,6 +26,7 @@ MIN_DM_LOG_ALIGN = 0
MAX_DM_LOG_ALIGN = 6
# Raw Packet QP supports TSO header, which creates a larger send WQE.
MAX_RAW_PACKET_SEND_WR = 2500
+GRH_SIZE = 40
def get_mr_length():
@@ -242,9 +246,154 @@ def wc_status_to_str(status):
except KeyError:
return 'Unknown WC status ({s})'.format(s=status)
+# Traffic helpers
+
+def get_send_wr(agr_obj, is_server):
+ """
+ Creates a single SGE Send WR for agr_obj's QP type. The content of the
+ message is either 's' for server side or 'c' for client side.
+ :param agr_obj: Aggregation object which contains all resources necessary
+ :param is_server: Indicates whether this is server or client side
+ :return: send wr
+ """
+ qp_type = agr_obj.qp.qp_type
+ mr = agr_obj.mr
+ if qp_type == e.IBV_QPT_UD:
+ send_sge = SGE(mr.buf + GRH_SIZE, agr_obj.msg_size, mr.lkey)
+ else:
+ send_sge = SGE(mr.buf, agr_obj.msg_size, mr.lkey)
+ msg = agr_obj.msg_size * ('s' if is_server else 'c')
+ mr.write(msg, agr_obj.msg_size)
+ return SendWR(num_sge=1, sg=[send_sge])
+
+
+def get_recv_wr(agr_obj):
+ """
+ Creates a single SGE Recv WR for agr_obj's QP type.
+ :param agr_obj: Aggregation object which contains all resources necessary
+ :return: recv wr
+ """
+ qp_type = agr_obj.qp.qp_type
+ mr = agr_obj.mr
+ if qp_type == e.IBV_QPT_UD:
+ recv_sge = SGE(mr.buf, agr_obj.msg_size + GRH_SIZE, mr.lkey)
+ else:
+ recv_sge = SGE(mr.buf, agr_obj.msg_size, mr.lkey)
+ return RecvWR(sg=[recv_sge], num_sge=1)
+
+
+def post_send(agr_obj, send_wr, gid_index, port):
+ """
+ Post a single send WR to the QP. Post_send's second parameter (send bad wr)
+ is ignored for simplicity. For UD traffic an address vector is added as
+ well.
+ :param agr_obj: aggregation object which contains all resources necessary
+ :param send_wr: Send work request to post send
+ :param gid_index: Local gid index
+ :param port: IB port number
+ :return: None
+ """
+ qp_type = agr_obj.qp.qp_type
+ if qp_type == e.IBV_QPT_UD:
+ gr = GlobalRoute(dgid=agr_obj.ctx.query_gid(port, gid_index),
+ sgid_index=gid_index)
+ ah_attr = AHAttr(port_num=port, is_global=1, gr=gr,
+ dlid=agr_obj.port_attr.lid)
+ ah = AH(agr_obj.pd, attr=ah_attr)
+ send_wr.set_wr_ud(ah, agr_obj.rqpn, agr_obj.UD_QKEY)
+ agr_obj.qp.post_send(send_wr, None)
+
+
+def post_recv(qp, recv_wr, num_wqes=1):
+ """
+ Call the QP's post_recv() method <num_wqes> times. Post_recv's second
+ parameter (recv bad wr) is ignored for simplicity.
+ :param qp: QP which posts receive work request
+ :param recv_wr: Receive work request to post
+ :param num_wqes: Number of WQEs to post
+ :return: None
+ """
+ for _ in range(num_wqes):
+ qp.post_recv(recv_wr, None)
-# Decorators
+def poll_cq(cq, count=1):
+ """
+ Poll <count> completions from the CQ.
+ Note: This function calls the blocking poll() method of the CQ
+ until <count> completions were received. Alternatively, gets a
+ single CQ event when events are used.
+ :param cq: CQ to poll from
+ :param count: How many completions to poll
+ :return: An array of work completions of length <count>, None
+ when events are used
+ """
+ wcs = None
+ while count > 0:
+ nc, wcs = cq.poll(count)
+ for wc in wcs:
+ if wc.status != e.IBV_WC_SUCCESS:
+ raise PyverbsRDMAError('Completion status is {s}'.
+ format(s=wc_status_to_str(wc.status)))
+ count -= nc
+ return wcs
+
+
+def validate(received_str, is_server, msg_size):
+ """
+ Validates the received buffer against the expected result.
+ The application should set client's send buffer to 'c's and the
+ server's send buffer to 's's.
+ If the expected buffer is different than the actual, an exception will
+ be raised.
+ :param received_str: The received buffer to check
+ :param is_server: Indicates whether this is the server (receiver) or
+ client side
+ :param msg_size: the message size of the received packet
+ :return: None
+ """
+ expected_str = msg_size * ('c' if is_server else 's')
+ received_str = received_str.decode()
+ if received_str[0:msg_size] == \
+ expected_str[0:msg_size]:
+ return
+ else:
+ raise PyverbsError(
+ 'Data validation failure: expected {exp}, received {rcv}'.
+ format(exp=expected_str, rcv=received_str))
+
+
+def traffic(client, server, iters, gid_idx, port):
+ """
+ Runs basic traffic between two sides
+ :param client: client side, clients base class is BaseTraffic
+ :param server: server side, servers base class is BaseTraffic
+ :param iters: number of traffic iterations
+ :param gid_idx: local gid index
+ :param port: IB port
+ :return:
+ """
+ s_recv_wr = get_recv_wr(server)
+ c_recv_wr = get_recv_wr(client)
+ post_recv(client.qp, c_recv_wr, client.num_msgs)
+ post_recv(server.qp, s_recv_wr, server.num_msgs)
+ for _ in range(iters):
+ c_send_wr = get_send_wr(client, False)
+ post_send(client, c_send_wr, gid_idx, port)
+ poll_cq(client.cq)
+ poll_cq(server.cq)
+ post_recv(client.qp, c_recv_wr)
+ msg_received = server.mr.read(server.msg_size, 0)
+ validate(msg_received, True, server.msg_size)
+ s_send_wr = get_send_wr(server, True)
+ post_send(server, s_send_wr, gid_idx, port)
+ poll_cq(server.cq)
+ poll_cq(client.cq)
+ post_recv(server.qp, s_recv_wr)
+ msg_received = client.mr.read(client.msg_size, 0)
+ validate(msg_received, False, client.msg_size)
+
+# Decorators
def requires_odp(qp_type):
def outer(func):
def inner(instance):