diff mbox series

[v2,20/25] selftests: tcp_authopt: Add tests for rollover

Message ID 070c753a89b9aa53c1fef03d10a1bf14503fcc11.1635784253.git.cdleonard@gmail.com (mailing list archive)
State New
Headers show
Series [v2,01/25] tcp: authopt: Initial support and key management | expand

Commit Message

Leonard Crestez Nov. 1, 2021, 4:34 p.m. UTC
RFC5925 requires that the use can examine or control the keys being
used. This is implemented in linux via fields on the TCP_AUTHOPT
sockopt.

Add socket-level tests for the adjusting keyids on live connections and
checking the they are reflected on the peer.

Also check smooth transitions via rnextkeyid.

Signed-off-by: Leonard Crestez <cdleonard@gmail.com>
---
 .../tcp_authopt_test/linux_tcp_authopt.py     |  16 +-
 .../tcp_authopt_test/test_rollover.py         | 181 ++++++++++++++++++
 2 files changed, 194 insertions(+), 3 deletions(-)
 create mode 100644 tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_rollover.py
diff mbox series

Patch

diff --git a/tools/testing/selftests/tcp_authopt/tcp_authopt_test/linux_tcp_authopt.py b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/linux_tcp_authopt.py
index b9dc9decda07..75cf5f993ccb 100644
--- a/tools/testing/selftests/tcp_authopt/tcp_authopt_test/linux_tcp_authopt.py
+++ b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/linux_tcp_authopt.py
@@ -29,10 +29,12 @@  TCP_AUTHOPT_KEY = 39
 
 TCP_AUTHOPT_MAXKEYLEN = 80
 
 
 class TCP_AUTHOPT_FLAG(IntFlag):
+    LOCK_KEYID = BIT(0)
+    LOCK_RNEXTKEYID = BIT(1)
     REJECT_UNEXPECTED = BIT(2)
 
 
 class TCP_AUTHOPT_KEY_FLAG(IntFlag):
     DEL = BIT(0)
@@ -48,24 +50,32 @@  class TCP_AUTHOPT_ALG(IntEnum):
 @dataclass
 class tcp_authopt:
     """Like linux struct tcp_authopt"""
 
     flags: int = 0
-    sizeof = 4
+    send_keyid: int = 0
+    send_rnextkeyid: int = 0
+    recv_keyid: int = 0
+    recv_rnextkeyid: int = 0
+    sizeof = 8
 
     def pack(self) -> bytes:
         return struct.pack(
-            "I",
+            "IBBBB",
             self.flags,
+            self.send_keyid,
+            self.send_rnextkeyid,
+            self.recv_keyid,
+            self.recv_rnextkeyid,
         )
 
     def __bytes__(self):
         return self.pack()
 
     @classmethod
     def unpack(cls, b: bytes):
-        tup = struct.unpack("I", b)
+        tup = struct.unpack("IBBBB", b)
         return cls(*tup)
 
 
 def set_tcp_authopt(sock, opt: tcp_authopt):
     return sock.setsockopt(socket.SOL_TCP, TCP_AUTHOPT, bytes(opt))
diff --git a/tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_rollover.py b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_rollover.py
new file mode 100644
index 000000000000..2f48706a90e5
--- /dev/null
+++ b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_rollover.py
@@ -0,0 +1,181 @@ 
+# SPDX-License-Identifier: GPL-2.0
+import socket
+import typing
+from contextlib import ExitStack, contextmanager
+
+from .conftest import skipif_missing_tcp_authopt
+from .linux_tcp_authopt import (
+    TCP_AUTHOPT_FLAG,
+    get_tcp_authopt,
+    set_tcp_authopt,
+    set_tcp_authopt_key,
+    tcp_authopt,
+    tcp_authopt_key,
+)
+from .server import SimpleServerThread
+from .utils import DEFAULT_TCP_SERVER_PORT, check_socket_echo, create_listen_socket
+
+pytestmark = skipif_missing_tcp_authopt
+
+
+@contextmanager
+def make_tcp_authopt_socket_pair(
+    server_addr="127.0.0.1",
+    server_authopt: tcp_authopt = None,
+    server_key_list: typing.Iterable[tcp_authopt_key] = [],
+    client_authopt: tcp_authopt = None,
+    client_key_list: typing.Iterable[tcp_authopt_key] = [],
+) -> typing.Iterator[typing.Tuple[socket.socket, socket.socket]]:
+    """Make a pair for connected sockets for key switching tests
+
+    Server runs in a background thread implementing echo protocol"""
+    with ExitStack() as exit_stack:
+        listen_socket = exit_stack.enter_context(
+            create_listen_socket(bind_addr=server_addr)
+        )
+        server_thread = exit_stack.enter_context(
+            SimpleServerThread(listen_socket, mode="echo")
+        )
+        client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        client_socket.settimeout(1.0)
+
+        if server_authopt:
+            set_tcp_authopt(listen_socket, server_authopt)
+        for k in server_key_list:
+            set_tcp_authopt_key(listen_socket, k)
+        if client_authopt:
+            set_tcp_authopt(client_socket, client_authopt)
+        for k in client_key_list:
+            set_tcp_authopt_key(client_socket, k)
+
+        client_socket.connect((server_addr, DEFAULT_TCP_SERVER_PORT))
+        check_socket_echo(client_socket)
+        server_socket = server_thread.server_socket[0]
+
+        yield client_socket, server_socket
+
+
+def test_get_keyids(exit_stack: ExitStack):
+    """Check reading key ids"""
+    sk1 = tcp_authopt_key(send_id=11, recv_id=12, key="111")
+    sk2 = tcp_authopt_key(send_id=21, recv_id=22, key="222")
+    ck1 = tcp_authopt_key(send_id=12, recv_id=11, key="111")
+    client_socket, server_socket = exit_stack.enter_context(
+        make_tcp_authopt_socket_pair(
+            server_key_list=[sk1, sk2],
+            client_key_list=[ck1],
+        )
+    )
+
+    check_socket_echo(client_socket)
+    client_tcp_authopt = get_tcp_authopt(client_socket)
+    server_tcp_authopt = get_tcp_authopt(server_socket)
+    assert server_tcp_authopt.send_keyid == 11
+    assert server_tcp_authopt.send_rnextkeyid == 12
+    assert server_tcp_authopt.recv_keyid == 12
+    assert server_tcp_authopt.recv_rnextkeyid == 11
+    assert client_tcp_authopt.send_keyid == 12
+    assert client_tcp_authopt.send_rnextkeyid == 11
+    assert client_tcp_authopt.recv_keyid == 11
+    assert client_tcp_authopt.recv_rnextkeyid == 12
+
+
+def test_rollover_send_keyid(exit_stack: ExitStack):
+    """Check reading key ids"""
+    sk1 = tcp_authopt_key(send_id=11, recv_id=12, key="111")
+    sk2 = tcp_authopt_key(send_id=21, recv_id=22, key="222")
+    ck1 = tcp_authopt_key(send_id=12, recv_id=11, key="111")
+    ck2 = tcp_authopt_key(send_id=22, recv_id=21, key="222")
+    client_socket, server_socket = exit_stack.enter_context(
+        make_tcp_authopt_socket_pair(
+            server_key_list=[sk1, sk2],
+            client_key_list=[ck1, ck2],
+            client_authopt=tcp_authopt(
+                send_keyid=12, flags=TCP_AUTHOPT_FLAG.LOCK_KEYID
+            ),
+        )
+    )
+
+    check_socket_echo(client_socket)
+    assert get_tcp_authopt(client_socket).recv_keyid == 11
+    assert get_tcp_authopt(server_socket).recv_keyid == 12
+
+    # Explicit request for key2
+    set_tcp_authopt(
+        client_socket, tcp_authopt(send_keyid=22, flags=TCP_AUTHOPT_FLAG.LOCK_KEYID)
+    )
+    check_socket_echo(client_socket)
+    assert get_tcp_authopt(client_socket).recv_keyid == 21
+    assert get_tcp_authopt(server_socket).recv_keyid == 22
+
+
+def test_rollover_rnextkeyid(exit_stack: ExitStack):
+    """Check reading key ids"""
+    sk1 = tcp_authopt_key(send_id=11, recv_id=12, key="111")
+    sk2 = tcp_authopt_key(send_id=21, recv_id=22, key="222")
+    ck1 = tcp_authopt_key(send_id=12, recv_id=11, key="111")
+    ck2 = tcp_authopt_key(send_id=22, recv_id=21, key="222")
+    client_socket, server_socket = exit_stack.enter_context(
+        make_tcp_authopt_socket_pair(
+            server_key_list=[sk1],
+            client_key_list=[ck1, ck2],
+            client_authopt=tcp_authopt(
+                send_keyid=12, flags=TCP_AUTHOPT_FLAG.LOCK_KEYID
+            ),
+        )
+    )
+
+    check_socket_echo(client_socket)
+    assert get_tcp_authopt(server_socket).recv_rnextkeyid == 11
+
+    # request rnextkeyd=22 but server does not have it
+    set_tcp_authopt(
+        client_socket,
+        tcp_authopt(send_rnextkeyid=21, flags=TCP_AUTHOPT_FLAG.LOCK_RNEXTKEYID),
+    )
+    check_socket_echo(client_socket)
+    check_socket_echo(client_socket)
+    assert get_tcp_authopt(server_socket).recv_rnextkeyid == 21
+    assert get_tcp_authopt(server_socket).send_keyid == 11
+
+    # after adding k2 on server the key is switched
+    set_tcp_authopt_key(server_socket, sk2)
+    check_socket_echo(client_socket)
+    check_socket_echo(client_socket)
+    assert get_tcp_authopt(server_socket).send_keyid == 21
+
+
+def test_rollover_delkey(exit_stack: ExitStack):
+    sk1 = tcp_authopt_key(send_id=11, recv_id=12, key="111")
+    sk2 = tcp_authopt_key(send_id=21, recv_id=22, key="222")
+    ck1 = tcp_authopt_key(send_id=12, recv_id=11, key="111")
+    ck2 = tcp_authopt_key(send_id=22, recv_id=21, key="222")
+    client_socket, server_socket = exit_stack.enter_context(
+        make_tcp_authopt_socket_pair(
+            server_key_list=[sk1, sk2],
+            client_key_list=[ck1, ck2],
+            client_authopt=tcp_authopt(
+                send_keyid=12, flags=TCP_AUTHOPT_FLAG.LOCK_KEYID
+            ),
+        )
+    )
+
+    check_socket_echo(client_socket)
+    assert get_tcp_authopt(server_socket).recv_keyid == 12
+
+    # invalid send_keyid is just ignored
+    set_tcp_authopt(client_socket, tcp_authopt(send_keyid=7))
+    check_socket_echo(client_socket)
+    assert get_tcp_authopt(client_socket).send_keyid == 12
+    assert get_tcp_authopt(server_socket).recv_keyid == 12
+    assert get_tcp_authopt(client_socket).recv_keyid == 11
+
+    # If a key is removed it is replaced by anything that matches
+    ck1.delete_flag = True
+    set_tcp_authopt_key(client_socket, ck1)
+    check_socket_echo(client_socket)
+    check_socket_echo(client_socket)
+    assert get_tcp_authopt(client_socket).send_keyid == 22
+    assert get_tcp_authopt(server_socket).send_keyid == 21
+    assert get_tcp_authopt(server_socket).recv_keyid == 22
+    assert get_tcp_authopt(client_socket).recv_keyid == 21