diff mbox series

[rdma-core,2/8] pyverbs: Add TSO support

Message ID 20191231091915.23874-3-noaos@mellanox.com (mailing list archive)
State Not Applicable
Headers show
Series pyverbs: Add support for the new post send API | expand

Commit Message

Noa Osherovich Dec. 31, 2019, 9:19 a.m. UTC
Add TSO class to allow users to easily set the TSO field in a send
work request.

Signed-off-by: Ahmad Ghazawi <ahmadg@mellanox.com>
Signed-off-by: Noa Osherovich <noaos@mellanox.com>
Reviewed-by: Edward Srouji <edwards@mellanox.com>
---
 pyverbs/libibverbs.pxd |  7 ++---
 pyverbs/wr.pxd         |  6 +++++
 pyverbs/wr.pyx         | 60 ++++++++++++++++++++++++++++++++++++++++--
 3 files changed, 66 insertions(+), 7 deletions(-)
diff mbox series

Patch

diff --git a/pyverbs/libibverbs.pxd b/pyverbs/libibverbs.pxd
index fea8a1b408da..8949759511a5 100755
--- a/pyverbs/libibverbs.pxd
+++ b/pyverbs/libibverbs.pxd
@@ -311,10 +311,6 @@  cdef extern from 'infiniband/verbs.h':
         unsigned short  hdr_sz
         unsigned short  mss
 
-    cdef union unnamed:
-        bind_mw         bind_mw
-        tso             tso
-
     cdef struct xrc:
         unsigned int    remote_srqn
 
@@ -330,7 +326,8 @@  cdef extern from 'infiniband/verbs.h':
         unsigned int    send_flags
         wr              wr
         qp_type         qp_type
-        unnamed         unnamed
+        bind_mw         bind_mw
+        tso             tso
 
     cdef struct ibv_qp_cap:
         unsigned int    max_send_wr
diff --git a/pyverbs/wr.pxd b/pyverbs/wr.pxd
index e259249ef7f8..5cb282dc65c3 100644
--- a/pyverbs/wr.pxd
+++ b/pyverbs/wr.pxd
@@ -16,3 +16,9 @@  cdef class RecvWR(PyverbsCM):
 
 cdef class SendWR(PyverbsCM):
     cdef v.ibv_send_wr send_wr
+
+cdef class TSO(PyverbsCM):
+    cdef void* buf
+    cdef int length
+    cdef int mss
+    cpdef alloc_buf(self, content, length)
diff --git a/pyverbs/wr.pyx b/pyverbs/wr.pyx
index 06186e9b0fd2..b1df3677fce3 100644
--- a/pyverbs/wr.pyx
+++ b/pyverbs/wr.pyx
@@ -1,12 +1,16 @@ 
 # SPDX-License-Identifier: (GPL-2.0 OR Linux-OpenIB)
 # Copyright (c) 2019 Mellanox Technologies Inc. All rights reserved. See COPYING file
 
+import resource
+
+from cpython.pycapsule cimport PyCapsule_New
+from libc.stdlib cimport free, malloc
+from libc.string cimport memcpy
+
 from pyverbs.pyverbs_error import PyverbsUserError, PyverbsError
 from pyverbs.base import PyverbsRDMAErrno
 cimport pyverbs.libibverbs_enums as e
 from pyverbs.addr cimport AH
-from libc.stdlib cimport free, malloc
-from libc.string cimport memcpy
 
 
 cdef class SGE(PyverbsCM):
@@ -281,6 +285,58 @@  cdef class SendWR(PyverbsCM):
         """
         self.send_wr.qp_type.xrc.remote_srqn = remote_srqn
 
+    def set_tso(self, TSO tso not None):
+        """
+        Set the members of the tso struct in the send_wr's anonymous union.
+        :param tso: A TSO object to copy to the work request
+        :return: None
+        """
+        self.send_wr.tso.hdr = tso.buf
+        self.send_wr.tso.hdr_sz = tso.length
+        self.send_wr.tso.mss = tso.mss
+
+
+cdef class TSO(PyverbsCM):
+    """ Represents the 'tso' anonymous struct inside a send WR """
+    def __init__(self, data, length, mss):
+        super().__init__()
+        self.alloc_buf(data, length)
+        self.length = length
+        self.mss = mss
+
+    cpdef alloc_buf(self, data, length):
+        if self.buf != NULL:
+            free(self.buf)
+        self.buf = malloc(length)
+        if self.buf == NULL:
+            raise PyverbsError('Failed to allocate TSO buffer of length {l}'.\
+                                format(l=length))
+        if isinstance(data, str):
+            data = data.encode()
+        memcpy(self.buf, <char*>data, length)
+        self.length = length
+
+    def __dealloc__(self):
+        self.close()
+
+    cpdef close(self):
+        if self.buf != NULL:
+            free(self.buf)
+            self.buf = NULL
+
+    @property
+    def buf(self):
+        return PyCapsule_New(self.buf, NULL, NULL)
+
+    def __str__(self):
+        print_format = '{:22}: {:<20}\n'
+        data = <char*>(self.buf)
+        data = data[:self.length].decode()
+        return print_format.format('MSS', self.mss) +\
+               print_format.format('Length', self.length) +\
+               print_format.format('Data', data)
+
+
 def send_flags_to_str(flags):
     send_flags = {e.IBV_SEND_FENCE: 'IBV_SEND_FENCE',
                   e.IBV_SEND_SIGNALED: 'IBV_SEND_SIGNALED',