diff mbox series

[v2,18/25] selftests: tcp_authopt: Initial sne test

Message ID cc1f1ce2d723613db48b15249406e9c9f14479a4.1635784253.git.cdleonard@gmail.com (mailing list archive)
State New
Headers show
Series [v2,01/25] tcp: authopt: Initial support and key management | expand

Commit Message

Leonard Crestez Nov. 1, 2021, 4:34 p.m. UTC
In order to trigger a seq or ack rollover create many connection in a
loop and check for a "high" value, then make a lot of traffic.

This relies on both TCP_REPAIR and TCP_REPAIR_AUTHOPT, making it unfit
for upstream.

Signed-off-by: Leonard Crestez <cdleonard@gmail.com>
---
 .../tcp_authopt_test/linux_tcp_repair.py      |  67 ++++++
 .../tcp_authopt/tcp_authopt_test/test_sne.py  | 202 ++++++++++++++++++
 2 files changed, 269 insertions(+)
 create mode 100644 tools/testing/selftests/tcp_authopt/tcp_authopt_test/linux_tcp_repair.py
 create mode 100644 tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_sne.py
diff mbox series

Patch

diff --git a/tools/testing/selftests/tcp_authopt/tcp_authopt_test/linux_tcp_repair.py b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/linux_tcp_repair.py
new file mode 100644
index 000000000000..68f111a207e6
--- /dev/null
+++ b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/linux_tcp_repair.py
@@ -0,0 +1,67 @@ 
+# SPDX-License-Identifier: GPL-2.0
+import socket
+import struct
+from contextlib import contextmanager
+from enum import IntEnum
+
+# Extra sockopts not present in python stdlib
+TCP_REPAIR = 19
+TCP_REPAIR_QUEUE = 20
+TCP_QUEUE_SEQ = 21
+TCP_REPAIR_OPTIONS = 22
+TCP_REPAIR_WINDOW = 29
+
+
+class TCP_REPAIR_VAL(IntEnum):
+    OFF = 0
+    ON = 1
+    OFF_NO_WP = -1
+
+
+def get_tcp_repair(sock) -> TCP_REPAIR_VAL:
+    return TCP_REPAIR_VAL(sock.getsockopt(socket.SOL_TCP, TCP_REPAIR))
+
+
+def set_tcp_repair(sock, val: TCP_REPAIR_VAL) -> None:
+    return sock.setsockopt(socket.SOL_TCP, TCP_REPAIR, int(val))
+
+
+class TCP_REPAIR_QUEUE_ID(IntEnum):
+    NO_QUEUE = 0
+    RECV_QUEUE = 1
+    SEND_QUEUE = 2
+
+
+def get_tcp_repair_queue(sock) -> TCP_REPAIR_QUEUE_ID:
+    return TCP_REPAIR_QUEUE_ID(sock.getsockopt(socket.SOL_TCP, TCP_REPAIR_QUEUE))
+
+
+def set_tcp_repair_queue(sock, val: TCP_REPAIR_QUEUE_ID) -> None:
+    return sock.setsockopt(socket.SOL_TCP, TCP_REPAIR_QUEUE, int(val))
+
+
+def get_tcp_queue_seq(sock) -> int:
+    return struct.unpack("I", sock.getsockopt(socket.SOL_TCP, TCP_QUEUE_SEQ, 4))[0]
+
+
+def set_tcp_queue_seq(sock, val: int) -> None:
+    return sock.setsockopt(socket.SOL_TCP, TCP_QUEUE_SEQ, val)[0]
+
+
+@contextmanager
+def tcp_repair_toggle(sock, off_val=TCP_REPAIR_VAL.OFF_NO_WP):
+    """Set TCP_REPAIR on/off as a context"""
+    try:
+        set_tcp_repair(sock, TCP_REPAIR_VAL.ON)
+        yield
+    finally:
+        set_tcp_repair(sock, off_val)
+
+
+def get_tcp_repair_recv_send_queue_seq(sock):
+    with tcp_repair_toggle(sock):
+        set_tcp_repair_queue(sock, TCP_REPAIR_QUEUE_ID.RECV_QUEUE)
+        recv_seq = get_tcp_queue_seq(sock)
+        set_tcp_repair_queue(sock, TCP_REPAIR_QUEUE_ID.SEND_QUEUE)
+        send_seq = get_tcp_queue_seq(sock)
+        return (recv_seq, send_seq)
diff --git a/tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_sne.py b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_sne.py
new file mode 100644
index 000000000000..180d7cdbd5f3
--- /dev/null
+++ b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_sne.py
@@ -0,0 +1,202 @@ 
+# SPDX-License-Identifier: GPL-2.0
+"""Validate SNE implementation for TCP-AO"""
+
+import logging
+import socket
+from contextlib import ExitStack
+from ipaddress import ip_address
+
+import pytest
+
+from .linux_tcp_authopt import set_tcp_authopt_key_kwargs
+from .linux_tcp_repair import get_tcp_repair_recv_send_queue_seq, tcp_repair_toggle
+from .netns_fixture import NamespaceFixture
+from .scapy_conntrack import TCPConnectionKey, TCPConnectionTracker
+from .scapy_utils import AsyncSnifferContext, create_capture_socket, tcp_seq_wrap
+from .server import SimpleServerThread
+from .utils import (
+    DEFAULT_TCP_SERVER_PORT,
+    check_socket_echo,
+    create_client_socket,
+    create_listen_socket,
+    socket_set_linger,
+)
+from .validator import TcpAuthValidator, TcpAuthValidatorKey
+
+logger = logging.getLogger(__name__)
+
+
+def add_connection_info(
+    tracker: TCPConnectionTracker,
+    saddr,
+    daddr,
+    sport,
+    dport,
+    sisn,
+    disn,
+):
+    client2server_key = TCPConnectionKey(
+        saddr=saddr,
+        daddr=daddr,
+        sport=sport,
+        dport=dport,
+    )
+    client2server_conn = tracker.get_or_create(client2server_key)
+    client2server_conn.sisn = sisn
+    client2server_conn.disn = disn
+    client2server_conn.snd_sne.reset(sisn)
+    client2server_conn.rcv_sne.reset(disn)
+    client2server_conn.found_syn = True
+    client2server_conn.found_synack = True
+    server2client_conn = tracker.get_or_create(client2server_key.rev())
+    server2client_conn.sisn = disn
+    server2client_conn.disn = sisn
+    server2client_conn.snd_sne.reset(disn)
+    server2client_conn.rcv_sne.reset(sisn)
+    server2client_conn.found_syn = True
+    server2client_conn.found_synack = True
+
+
+@pytest.mark.parametrize("signed", [False, True])
+def test_high_seq_rollover(exit_stack: ExitStack, signed: bool):
+    """Test SNE by rolling over from a high seq/ack value
+
+    Create many connections until a very high seq/ack is found and then transfer
+    enough for those values to roll over.
+
+    A side effect of this approach is that this stresses connection
+    establishment.
+    """
+    overflow = 0x200000
+    bufsize = 0x10000
+    secret_key = b"12345"
+    mode = "echo"
+    validator_enabled = True
+
+    nsfixture = exit_stack.enter_context(NamespaceFixture())
+    server_addr = nsfixture.get_addr(socket.AF_INET, 1)
+    client_addr = nsfixture.get_addr(socket.AF_INET, 2)
+    server_addr_port = (str(server_addr), DEFAULT_TCP_SERVER_PORT)
+    listen_socket = create_listen_socket(
+        ns=nsfixture.server_netns_name,
+        bind_addr=server_addr,
+        listen_depth=1024,
+    )
+    exit_stack.enter_context(listen_socket)
+    if signed:
+        set_tcp_authopt_key_kwargs(listen_socket, key=secret_key)
+    server_thread = SimpleServerThread(listen_socket, mode=mode, bufsize=bufsize)
+    exit_stack.enter_context(server_thread)
+
+    found = False
+    client_socket = None
+    for iternum in range(50000):
+        try:
+            # Manually assign increasing client ports
+            #
+            # Sometimes linux kills timewait sockets (TCPTimeWaitOverflow) and
+            # then attempts to reuse the port. The stricter validation
+            # requirements of TCP-AO mean the other side of the socket survives
+            # and rejects packets coming from the reused port.
+            #
+            # This issue is not related to SNE so a workaround is acceptable.
+            client_socket = create_client_socket(
+                ns=nsfixture.client_netns_name,
+                bind_addr=client_addr,
+                bind_port=10000 + iternum,
+            )
+            if signed:
+                set_tcp_authopt_key_kwargs(client_socket, key=secret_key)
+            try:
+                client_socket.connect(server_addr_port)
+            except:
+                logger.error("failed connect on iteration %d", iternum, exc_info=True)
+                raise
+
+            recv_seq, send_seq = get_tcp_repair_recv_send_queue_seq(client_socket)
+            if (recv_seq + overflow > 0x100000000 and mode == "echo") or (
+                send_seq + overflow > 0x100000000
+            ):
+                found = True
+                break
+            # Wait for graceful close to avoid swamping server listen queue.
+            # This makes the test work even with a server listen_depth=1 but set
+            # a very high value anyway.
+            socket_set_linger(client_socket, 1, 1)
+            client_socket.close()
+            client_socket = None
+        finally:
+            if not found and client_socket:
+                client_socket.close()
+    assert found
+    assert client_socket is not None
+
+    logger.debug("setup recv_seq %08x send_seq %08x", recv_seq, send_seq)
+
+    # Init validator
+    if signed and validator_enabled:
+        capture_filter = f"tcp port {DEFAULT_TCP_SERVER_PORT}"
+        capture_socket = create_capture_socket(
+            ns=nsfixture.client_netns_name,
+            iface="veth0",
+            filter=capture_filter,
+        )
+        sniffer = exit_stack.enter_context(
+            AsyncSnifferContext(opened_socket=capture_socket)
+        )
+        validator = TcpAuthValidator()
+        validator.keys.append(
+            TcpAuthValidatorKey(key=secret_key, alg_name="HMAC-SHA-1-96")
+        )
+
+        # SYN+SYNACK is not captured so initialize connection info manually
+        add_connection_info(
+            validator.tracker,
+            saddr=ip_address(client_addr),
+            daddr=ip_address(server_addr),
+            dport=client_socket.getpeername()[1],
+            sport=client_socket.getsockname()[1],
+            sisn=tcp_seq_wrap(send_seq - 1),
+            disn=tcp_seq_wrap(recv_seq - 1),
+        )
+
+    logger.info("transfer %d bytes", 2 * overflow)
+    fail_transfer = False
+    for iternum in range(2 * overflow // bufsize):
+        try:
+            if mode == "recv":
+                from .utils import randbytes
+
+                send_buf = randbytes(bufsize)
+                client_socket.sendall(send_buf)
+            else:
+                check_socket_echo(client_socket, bufsize)
+        except:
+            logger.error("failed traffic on iteration %d", iternum, exc_info=True)
+            fail_transfer = True
+            break
+
+    new_recv_seq, new_send_seq = get_tcp_repair_recv_send_queue_seq(client_socket)
+    logger.debug("final recv_seq %08x send_seq %08x", new_recv_seq, new_send_seq)
+    assert new_recv_seq < recv_seq or new_send_seq < send_seq
+
+    # Validate capture
+    if signed and validator_enabled:
+        sniffer.stop()
+        for p in sniffer.results:
+            validator.handle_packet(p)
+        # Allow incomplete connections from FIN/ACK of connections dropped
+        # because of low seq/ack
+        validator.raise_errors(allow_incomplete=True)
+        client_scappy_key = TCPConnectionKey(
+            saddr=ip_address(client_addr),
+            daddr=ip_address(server_addr),
+            dport=client_socket.getpeername()[1],
+            sport=client_socket.getsockname()[1],
+        )
+        client_scappy_conn = validator.tracker.get(client_scappy_key)
+        snd_sne_rollover = client_scappy_conn.snd_sne.sne != 0
+        rcv_sne_rollover = client_scappy_conn.rcv_sne.sne != 0
+        assert snd_sne_rollover or rcv_sne_rollover
+
+    assert not fail_transfer