new file mode 100644
@@ -0,0 +1,91 @@
+# SPDX-License-Identifier: GPL-2.0
+import logging
+import threading
+import typing
+
+import scapy.sessions
+from scapy.packet import Packet
+
+from .scapy_conntrack import TCPConnectionInfo, TCPConnectionTracker
+
+logger = logging.getLogger(__name__)
+
+
+class FullTCPSniffSession(scapy.sessions.DefaultSession):
+ """Implementation of a scapy sniff session that can wait for a full TCP capture
+
+ Allows another thread to wait for a complete FIN handshake without polling or sleep.
+ """
+
+ #: Server port used to identify client and server
+ server_port: int
+ #: Connection tracker
+ tracker: TCPConnectionTracker
+
+ def __init__(self, server_port, **kw):
+ super().__init__(**kw)
+ self.server_port = server_port
+ self.tracker = TCPConnectionTracker()
+ self._close_event = threading.Event()
+ self._init_isn_event = threading.Event()
+ self._client_info = None
+ self._server_info = None
+
+ @property
+ def client_info(self) -> TCPConnectionInfo:
+ if not self._client_info:
+ self._client_info = self.tracker.match_one(dport=self.server_port)
+ return self._client_info
+
+ @property
+ def server_info(self) -> TCPConnectionInfo:
+ if not self._server_info:
+ self._server_info = self.tracker.match_one(sport=self.server_port)
+ return self._server_info
+
+ @property
+ def client_isn(self):
+ return self.client_info.sisn
+
+ @property
+ def server_isn(self):
+ return self.server_info.sisn
+
+ def on_packet_received(self, p: Packet):
+ super().on_packet_received(p)
+ self.tracker.handle_packet(p)
+
+ # check events:
+ if self.client_info.sisn is not None and self.client_info.disn is not None:
+ assert (
+ self.client_info.sisn == self.server_info.disn
+ and self.server_info.sisn == self.client_info.disn
+ )
+ self._init_isn_event.set()
+ if self.client_info.found_recv_finack and self.server_info.found_recv_finack:
+ self._close_event.set()
+
+ def reset(self):
+ """Reset known information about client/server"""
+ self.tracker.reset()
+ self._server_info = None
+ self._client_info = None
+ self._close_event.clear()
+ self._init_isn_event.clear()
+
+ def wait_close(self, timeout=10):
+ """Wait for a graceful close with FINs acked by both side"""
+ self._close_event.wait(timeout=timeout)
+ if not self._close_event.is_set():
+ raise TimeoutError("Timed out waiting for graceful close")
+
+ def wait_init_isn(self, timeout=10):
+ """Wait for both client_isn and server_isn to be determined"""
+ self._init_isn_event.wait(timeout=timeout)
+ if not self._init_isn_event.is_set():
+ raise TimeoutError("Timed out waiting for Initial Sequence Numbers")
+
+ def get_client_server_isn(self, timeout=10) -> typing.Tuple[int, int]:
+ """Return client/server ISN, blocking until they are captured"""
+ self.wait_init_isn(timeout=timeout)
+ return self.client_isn, self.server_isn
new file mode 100644
@@ -0,0 +1,110 @@
+# SPDX-License-Identifier: GPL-2.0
+"""Python wrapper around linux TCP_MD5SIG ABI"""
+
+import socket
+import struct
+import typing
+from dataclasses import dataclass
+from enum import IntFlag
+
+from .sockaddr import sockaddr_convert, sockaddr_unpack
+
+TCP_MD5SIG = 14
+TCP_MD5SIG_EXT = 32
+TCP_MD5SIG_MAXKEYLEN = 80
+
+
+class TCP_MD5SIG_FLAG(IntFlag):
+ PREFIX = 0x1
+ IFINDEX = 0x2
+
+
+@dataclass
+class tcp_md5sig:
+ """Like linux struct tcp_md5sig"""
+
+ addr: typing.Any
+ flags: typing.Optional[int]
+ prefixlen: typing.Optional[int]
+ keylen: typing.Optional[int]
+ ifindex: typing.Optional[int]
+ key: bytes
+
+ sizeof = 128 + 88
+
+ def __init__(
+ self, addr=None, flags=None, prefixlen=None, keylen=None, ifindex=0, key=bytes()
+ ):
+ self.addr = addr
+ self.flags = flags
+ self.prefixlen = prefixlen
+ self.ifindex = ifindex
+ self.key = key
+ self.keylen = keylen
+
+ def get_auto_flags(self):
+ return (TCP_MD5SIG_FLAG.PREFIX if self.prefixlen is not None else 0) | (
+ TCP_MD5SIG_FLAG.IFINDEX if self.ifindex else 0
+ )
+
+ def get_real_flags(self):
+ if self.flags is None:
+ return self.get_auto_flags()
+ else:
+ return self.flags
+
+ def get_addr_bytes(self) -> bytes:
+ if self.addr is None:
+ return b"\0" * 128
+ if self.addr is bytes:
+ assert len(self.addr) == 128
+ return self.addr
+ return sockaddr_convert(self.addr).pack()
+
+ def pack(self) -> bytes:
+ return struct.pack(
+ "128sBBHi80s",
+ self.get_addr_bytes(),
+ self.get_real_flags(),
+ self.prefixlen if self.prefixlen is not None else 0,
+ self.keylen if self.keylen is not None else len(self.key),
+ self.ifindex if self.ifindex is not None else 0,
+ self.key,
+ )
+
+ def __bytes__(self):
+ return self.pack()
+
+ @classmethod
+ def unpack(cls, buffer: bytes) -> "tcp_md5sig":
+ tup = struct.unpack("128sBBHi80s", buffer)
+ addr = sockaddr_unpack(tup[0])
+ return cls(addr, *tup[1:])
+
+ def set_ipv4_addr_all(self):
+ from .sockaddr import sockaddr_in
+
+ self.addr = sockaddr_in()
+ self.prefixlen = 0
+
+ def set_ipv6_addr_all(self):
+ from .sockaddr import sockaddr_in6
+
+ self.addr = sockaddr_in6()
+ self.prefixlen = 0
+
+
+def setsockopt_md5sig(sock, opt: tcp_md5sig):
+ if opt.flags != 0:
+ optname = TCP_MD5SIG_EXT
+ else:
+ optname = TCP_MD5SIG
+ return sock.setsockopt(socket.SOL_TCP, optname, bytes(opt))
+
+
+def setsockopt_md5sig_kwargs(sock, opt: tcp_md5sig = None, **kw):
+ if opt is None:
+ opt = tcp_md5sig()
+ for k, v in kw.items():
+ setattr(opt, k, v)
+ return setsockopt_md5sig(sock, opt)
new file mode 100644
@@ -0,0 +1,173 @@
+# SPDX-License-Identifier: GPL-2.0
+"""Identify TCP connections inside a capture and collect per-connection information"""
+import typing
+from dataclasses import dataclass
+
+from scapy.layers.inet import TCP
+from scapy.packet import Packet
+
+from .scapy_utils import IPvXAddress, get_packet_ipvx_dst, get_packet_ipvx_src
+from .sne_alg import SequenceNumberExtenderLinux
+
+
+@dataclass(frozen=True)
+class TCPConnectionKey:
+ """TCP connection identification key: standard 4-tuple"""
+
+ saddr: typing.Optional[IPvXAddress] = None
+ daddr: typing.Optional[IPvXAddress] = None
+ sport: int = 0
+ dport: int = 0
+
+ def rev(self) -> "TCPConnectionKey":
+ return TCPConnectionKey(self.daddr, self.saddr, self.dport, self.sport)
+
+
+def get_packet_tcp_connection_key(p: Packet) -> TCPConnectionKey:
+ th = p[TCP]
+ return TCPConnectionKey(
+ get_packet_ipvx_src(p), get_packet_ipvx_dst(p), th.sport, th.dport
+ )
+
+
+class TCPConnectionInfo:
+ saddr: typing.Optional[IPvXAddress] = None
+ daddr: typing.Optional[IPvXAddress] = None
+ sport: int = 0
+ dport: int = 0
+ sisn: typing.Optional[int] = None
+ disn: typing.Optional[int] = None
+ rcv_sne: SequenceNumberExtenderLinux
+ snd_sne: SequenceNumberExtenderLinux
+
+ found_syn = False
+ found_synack = False
+
+ found_send_fin = False
+ found_send_finack = False
+ found_recv_fin = False
+ found_recv_finack = False
+
+ def __init__(self):
+ self.rcv_sne = SequenceNumberExtenderLinux()
+ self.snd_sne = SequenceNumberExtenderLinux()
+
+ def get_key(self):
+ return TCPConnectionKey(self.saddr, self.daddr, self.sport, self.dport)
+
+ @classmethod
+ def from_key(cls, key: TCPConnectionKey) -> "TCPConnectionInfo":
+ obj = cls()
+ obj.saddr = key.saddr
+ obj.daddr = key.daddr
+ obj.sport = key.sport
+ obj.dport = key.dport
+ return obj
+
+ def handle_send(self, p: Packet):
+ th = p[TCP]
+ if self.get_key() != get_packet_tcp_connection_key(p):
+ raise ValueError("Packet not for this connection")
+
+ if th.flags.S and not th.flags.A:
+ assert th.ack == 0
+ self.found_syn = True
+ self.sisn = th.seq
+ self.snd_sne.reset(th.seq)
+ elif th.flags.S and th.flags.A:
+ self.found_synack = True
+ self.sisn = th.seq
+ self.snd_sne.reset(th.seq)
+ assert self.disn == th.ack - 1
+
+ # Should track seq numbers instead
+ if th.flags.F:
+ self.found_send_fin = True
+ if th.flags.A and self.found_recv_fin:
+ self.found_send_finack = True
+
+ # Should only take valid packets into account
+ self.snd_sne.calc(th.seq)
+
+ def handle_recv(self, p: Packet):
+ th = p[TCP]
+ if self.get_key().rev() != get_packet_tcp_connection_key(p):
+ raise ValueError("Packet not for this connection")
+
+ if th.flags.S and not th.flags.A:
+ assert th.ack == 0
+ self.found_syn = True
+ self.disn = th.seq
+ self.rcv_sne.reset(th.seq)
+ elif th.flags.S and th.flags.A:
+ self.found_synack = True
+ self.disn = th.seq
+ self.rcv_sne.reset(th.seq)
+ assert self.sisn == th.ack - 1
+
+ # Should track seq numbers instead
+ if th.flags.F:
+ self.found_recv_fin = True
+ if th.flags.A and self.found_send_fin:
+ self.found_recv_finack = True
+
+ # Should only take valid packets into account
+ self.rcv_sne.calc(th.seq)
+
+
+class TCPConnectionTracker:
+ table: typing.Dict[TCPConnectionKey, TCPConnectionInfo]
+
+ def __init__(self):
+ self.table = {}
+
+ def reset(self):
+ """Forget known connections"""
+ self.table = {}
+
+ def get_or_create(self, key: TCPConnectionKey) -> TCPConnectionInfo:
+ info = self.table.get(key, None)
+ if info is None:
+ info = TCPConnectionInfo.from_key(key)
+ self.table[key] = info
+ return info
+
+ def get(self, key: TCPConnectionKey) -> typing.Optional[TCPConnectionInfo]:
+ return self.table.get(key, None)
+
+ def handle_packet(self, p: Packet):
+ if not p or not TCP in p:
+ return
+ key = get_packet_tcp_connection_key(p)
+ info = self.get_or_create(key)
+ info.handle_send(p)
+ rkey = key.rev()
+ rinfo = self.get_or_create(rkey)
+ rinfo.handle_recv(p)
+
+ def iter_match(self, saddr=None, daddr=None, sport=None, dport=None):
+ def attr_optional_match(obj, name, val) -> bool:
+ if val is None:
+ return True
+ else:
+ return getattr(obj, name) == val
+
+ for key, info in self.table.items():
+ if (
+ attr_optional_match(key, "saddr", saddr)
+ and attr_optional_match(key, "daddr", daddr)
+ and attr_optional_match(key, "sport", sport)
+ and attr_optional_match(key, "dport", dport)
+ ):
+ yield info
+
+ def match_one(
+ self, saddr=None, daddr=None, sport=None, dport=None
+ ) -> typing.Optional[TCPConnectionInfo]:
+ res = list(self.iter_match(saddr, daddr, sport, dport))
+ if len(res) == 1:
+ return res[0]
+ elif len(res) == 0:
+ return None
+ else:
+ raise ValueError("Multiple connection matches")
new file mode 100644
@@ -0,0 +1,276 @@
+# SPDX-License-Identifier: GPL-2.0
+import logging
+import socket
+import subprocess
+from contextlib import ExitStack
+
+import pytest
+from scapy.data import ETH_P_IP, ETH_P_IPV6
+from scapy.layers.inet import IP, TCP
+from scapy.layers.inet6 import IPv6
+from scapy.layers.l2 import Ether
+from scapy.packet import Packet
+
+from . import linux_tcp_authopt
+from .full_tcp_sniff_session import FullTCPSniffSession
+from .linux_tcp_authopt import set_tcp_authopt_key, tcp_authopt_key
+from .netns_fixture import NamespaceFixture
+from .scapy_utils import (
+ AsyncSnifferContext,
+ create_capture_socket,
+ create_l2socket,
+ scapy_tcp_get_authopt_val,
+ scapy_tcp_get_md5_sig,
+)
+from .server import SimpleServerThread
+from .utils import (
+ DEFAULT_TCP_SERVER_PORT,
+ create_client_socket,
+ create_listen_socket,
+ netns_context,
+ nstat_json,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class TCPConnectionFixture:
+ """Test fixture with an instrumented TCP connection
+
+ Includes:
+ * pair of network namespaces
+ * one listen socket
+ * server thread with echo protocol
+ * one client socket
+ * one async sniffer on the server interface
+ * A `FullTCPSniffSession` examining TCP packets
+ * l2socket allowing packet injection from client
+
+ :ivar tcp_md5_key: Secret key for md5 (addr is implicit)
+ """
+
+ sniffer_session: FullTCPSniffSession
+
+ def __init__(
+ self,
+ address_family=socket.AF_INET,
+ sniffer_kwargs=None,
+ tcp_authopt_key: tcp_authopt_key = None,
+ server_thread_kwargs=None,
+ tcp_md5_key=None,
+ ):
+ self.address_family = address_family
+ self.server_port = DEFAULT_TCP_SERVER_PORT
+ self.client_port = 27972
+ self.sniffer_session = FullTCPSniffSession(DEFAULT_TCP_SERVER_PORT)
+ if sniffer_kwargs is None:
+ sniffer_kwargs = {}
+ self.sniffer_kwargs = sniffer_kwargs
+ self.tcp_authopt_key = tcp_authopt_key
+ self.server_thread = SimpleServerThread(
+ mode="echo", **(server_thread_kwargs or {})
+ )
+ self.tcp_md5_key = tcp_md5_key
+
+ def _set_tcp_md5(self):
+ from . import linux_tcp_md5sig
+ from .sockaddr import sockaddr_convert
+
+ linux_tcp_md5sig.setsockopt_md5sig(
+ self.listen_socket,
+ linux_tcp_md5sig.tcp_md5sig(
+ key=self.tcp_md5_key, addr=sockaddr_convert(self.client_addr)
+ ),
+ )
+ linux_tcp_md5sig.setsockopt_md5sig(
+ self.client_socket,
+ linux_tcp_md5sig.tcp_md5sig(
+ key=self.tcp_md5_key, addr=sockaddr_convert(self.server_addr)
+ ),
+ )
+
+ def create_client_socket(self, bind_port=0):
+ return create_client_socket(
+ ns=self.nsfixture.client_netns_name,
+ family=self.address_family,
+ bind_addr=self.client_addr,
+ bind_port=bind_port,
+ )
+
+ def __enter__(self):
+ if self.tcp_authopt_key and not linux_tcp_authopt.has_tcp_authopt():
+ pytest.skip("Need TCP_AUTHOPT")
+
+ self.exit_stack = ExitStack()
+ self.exit_stack.__enter__()
+
+ self.nsfixture = self.exit_stack.enter_context(NamespaceFixture())
+ self.server_addr = self.nsfixture.get_addr(self.address_family, 1)
+ self.client_addr = self.nsfixture.get_addr(self.address_family, 2)
+
+ self.listen_socket = create_listen_socket(
+ ns=self.nsfixture.server_netns_name,
+ family=self.address_family,
+ bind_addr=self.server_addr,
+ bind_port=self.server_port,
+ )
+ self.exit_stack.enter_context(self.listen_socket)
+ self.client_socket = self.create_client_socket(bind_port=self.client_port)
+ self.exit_stack.enter_context(self.client_socket)
+ self.server_thread.add_listen_socket(self.listen_socket)
+ self.exit_stack.enter_context(self.server_thread)
+
+ if self.tcp_authopt_key:
+ set_tcp_authopt_key(self.listen_socket, self.tcp_authopt_key)
+ set_tcp_authopt_key(self.client_socket, self.tcp_authopt_key)
+
+ if self.tcp_md5_key:
+ self._set_tcp_md5()
+
+ capture_filter = f"tcp port {self.server_port}"
+ self.capture_socket = create_capture_socket(
+ ns=self.nsfixture.server_netns_name, iface="veth0", filter=capture_filter
+ )
+ self.exit_stack.enter_context(self.capture_socket)
+
+ self.sniffer = AsyncSnifferContext(
+ opened_socket=self.capture_socket,
+ session=self.sniffer_session,
+ prn=log_tcp_authopt_packet,
+ **self.sniffer_kwargs,
+ )
+ self.exit_stack.enter_context(self.sniffer)
+
+ self.client_l2socket = create_l2socket(
+ ns=self.nsfixture.client_netns_name, iface="veth0"
+ )
+ self.exit_stack.enter_context(self.client_l2socket)
+ self.server_l2socket = create_l2socket(
+ ns=self.nsfixture.server_netns_name, iface="veth0"
+ )
+ self.exit_stack.enter_context(self.server_l2socket)
+
+ def __exit__(self, *args):
+ self.exit_stack.__exit__(*args)
+
+ @property
+ def ethertype(self):
+ if self.address_family == socket.AF_INET:
+ return ETH_P_IP
+ elif self.address_family == socket.AF_INET6:
+ return ETH_P_IPV6
+ else:
+ raise ValueError("bad address_family={self.address_family}")
+
+ def scapy_iplayer(self):
+ if self.address_family == socket.AF_INET:
+ return IP
+ elif self.address_family == socket.AF_INET6:
+ return IPv6
+ else:
+ raise ValueError("bad address_family={self.address_family}")
+
+ def create_client2server_packet(self) -> Packet:
+ return (
+ Ether(
+ type=self.ethertype,
+ src=self.nsfixture.client_mac_addr,
+ dst=self.nsfixture.server_mac_addr,
+ )
+ / self.scapy_iplayer()(src=str(self.client_addr), dst=str(self.server_addr))
+ / TCP(sport=self.client_port, dport=self.server_port)
+ )
+
+ def create_server2client_packet(self) -> Packet:
+ return (
+ Ether(
+ type=self.ethertype,
+ src=self.nsfixture.server_mac_addr,
+ dst=self.nsfixture.client_mac_addr,
+ )
+ / self.scapy_iplayer()(src=str(self.server_addr), dst=str(self.client_addr))
+ / TCP(sport=self.server_port, dport=self.client_port)
+ )
+
+ @property
+ def server_addr_port(self):
+ return (str(self.server_addr), self.server_port)
+
+ @property
+ def server_netns_name(self):
+ return self.nsfixture.server_netns_name
+
+ @property
+ def client_netns_name(self):
+ return self.nsfixture.client_netns_name
+
+ def client_nstat_json(self):
+ with netns_context(self.client_netns_name):
+ return nstat_json()
+
+ def server_nstat_json(self):
+ with netns_context(self.server_netns_name):
+ return nstat_json()
+
+ def assert_no_snmp_output_failures(self):
+ client_nstat_dict = self.client_nstat_json()
+ assert client_nstat_dict["TcpExtTCPAuthOptFailure"] == 0
+ server_nstat_dict = self.server_nstat_json()
+ assert server_nstat_dict["TcpExtTCPAuthOptFailure"] == 0
+
+ def _get_state_via_ss(self, command_prefix: str):
+ # Every namespace should have at most one socket
+ # the "state connected" filter includes TIME-WAIT but not LISTEN
+ cmd = command_prefix + "ss --numeric --no-header --tcp state connected"
+ out = subprocess.check_output(cmd, text=True, shell=True)
+ lines = out.splitlines()
+ # No socket found usually means "CLOSED". It is distinct from "TIME-WAIT"
+ if len(lines) == 0:
+ return None
+ if len(lines) > 1:
+ raise ValueError("At most one line expected")
+ return lines[0].split()[0]
+
+ def get_client_tcp_state(self):
+ return self._get_state_via_ss(f"ip netns exec {self.client_netns_name} ")
+
+ def get_server_tcp_state(self):
+ return self._get_state_via_ss(f"ip netns exec {self.server_netns_name} ")
+
+
+def format_tcp_authopt_packet(
+ p: Packet, include_ethernet=False, include_seq=False, include_md5=True
+) -> str:
+ """Format a TCP packet in a way that is useful for TCP-AO testing"""
+ if not TCP in p:
+ return p.summary()
+ th = p[TCP]
+ if isinstance(th.underlayer, IP):
+ result = p.sprintf(r"%IP.src%:%TCP.sport% > %IP.dst%:%TCP.dport%")
+ elif isinstance(th.underlayer, IPv6):
+ result = p.sprintf(r"%IPv6.src%:%TCP.sport% > %IPv6.dst%:%TCP.dport%")
+ else:
+ raise ValueError(f"Unknown TCP underlayer {th.underlayer}")
+ result += p.sprintf(r" Flags %-2s,TCP.flags%")
+ if include_ethernet:
+ result = p.sprintf(r"ethertype %Ether.type% ") + result
+ result = p.sprintf(r"%Ether.src% > %Ether.dst% ") + result
+ if include_seq:
+ result += p.sprintf(r" seq %TCP.seq% ack %TCP.ack%")
+ result += f" len {len(p[TCP].payload)}"
+ authopt = scapy_tcp_get_authopt_val(p[TCP])
+ if authopt:
+ result += f" AO keyid={authopt.keyid} rnextkeyid={authopt.rnextkeyid} mac={authopt.mac.hex()}"
+ else:
+ result += " no AO"
+ if include_md5:
+ md5sig = scapy_tcp_get_md5_sig(p[TCP])
+ if md5sig:
+ result += f" MD5 {md5sig.hex()}"
+ else:
+ result += " no MD5"
+ return result
+
+
+def log_tcp_authopt_packet(p):
+ logger.info("sniff %s", format_tcp_authopt_packet(p, include_seq=True))
new file mode 100644
@@ -0,0 +1,559 @@
+# SPDX-License-Identifier: GPL-2.0
+"""Capture packets with TCP-AO and verify signatures"""
+
+import logging
+import os
+import socket
+import subprocess
+from contextlib import ExitStack, nullcontext
+
+import pytest
+import waiting
+from scapy.layers.inet import TCP
+
+from .conftest import (
+ raises_optional_exception,
+ skipif_cant_capture,
+ skipif_missing_tcp_authopt,
+)
+from .full_tcp_sniff_session import FullTCPSniffSession
+from .linux_tcp_authopt import (
+ TCP_AUTHOPT_ALG,
+ TCP_AUTHOPT_KEY_FLAG,
+ set_tcp_authopt_key,
+ tcp_authopt_key,
+)
+from .netns_fixture import NamespaceFixture
+from .scapy_tcp_authopt import (
+ TcpAuthOptAlg_HMAC_SHA1,
+ add_tcp_authopt_signature,
+ break_tcp_authopt_signature,
+)
+from .scapy_utils import (
+ AsyncSnifferContext,
+ scapy_sniffer_stop,
+ scapy_tcp_get_authopt_val,
+ scapy_tcp_get_md5_sig,
+ tcp_seq_wrap,
+)
+from .server import SimpleServerThread
+from .tcp_connection_fixture import TCPConnectionFixture
+from .utils import (
+ DEFAULT_TCP_SERVER_PORT,
+ check_socket_echo,
+ create_client_socket,
+ create_listen_socket,
+ nstat_json,
+ socket_set_linger,
+)
+from .validator import TcpAuthValidator, TcpAuthValidatorKey
+
+logger = logging.getLogger(__name__)
+pytestmark = [skipif_missing_tcp_authopt, skipif_cant_capture]
+DEFAULT_TCP_AUTHOPT_KEY = tcp_authopt_key(
+ alg=TCP_AUTHOPT_ALG.HMAC_SHA_1_96,
+ key=b"hello",
+)
+
+
+def get_alg_id(alg_name) -> int:
+ if alg_name == "HMAC-SHA-1-96":
+ return TCP_AUTHOPT_ALG.HMAC_SHA_1_96
+ elif alg_name == "AES-128-CMAC-96":
+ return TCP_AUTHOPT_ALG.AES_128_CMAC_96
+ else:
+ raise ValueError()
+
+
+@pytest.mark.parametrize(
+ "address_family,alg_name,include_options,transfer_data",
+ [
+ (socket.AF_INET, "HMAC-SHA-1-96", True, True),
+ (socket.AF_INET, "AES-128-CMAC-96", True, True),
+ (socket.AF_INET, "AES-128-CMAC-96", False, True),
+ (socket.AF_INET6, "HMAC-SHA-1-96", True, True),
+ (socket.AF_INET6, "HMAC-SHA-1-96", False, True),
+ (socket.AF_INET6, "AES-128-CMAC-96", True, True),
+ (socket.AF_INET, "HMAC-SHA-1-96", True, False),
+ (socket.AF_INET6, "AES-128-CMAC-96", False, False),
+ ],
+)
+def test_verify_capture(
+ exit_stack, address_family, alg_name, include_options, transfer_data
+):
+ master_key = b"testvector"
+ alg_id = get_alg_id(alg_name)
+
+ session = FullTCPSniffSession(server_port=DEFAULT_TCP_SERVER_PORT)
+ sniffer = exit_stack.enter_context(
+ AsyncSnifferContext(
+ filter=f"inbound and tcp port {DEFAULT_TCP_SERVER_PORT}",
+ iface="lo",
+ session=session,
+ )
+ )
+
+ listen_socket = create_listen_socket(family=address_family)
+ listen_socket = exit_stack.enter_context(listen_socket)
+ exit_stack.enter_context(SimpleServerThread(listen_socket, mode="echo"))
+
+ client_socket = socket.socket(address_family, socket.SOCK_STREAM)
+ client_socket = exit_stack.push(client_socket)
+
+ key = tcp_authopt_key(alg=alg_id, key=master_key, include_options=include_options)
+ set_tcp_authopt_key(listen_socket, key)
+ set_tcp_authopt_key(client_socket, key)
+
+ # even if one signature is incorrect keep processing the capture
+ old_nstat = nstat_json()
+ valkey = TcpAuthValidatorKey(
+ key=master_key, alg_name=alg_name, include_options=include_options
+ )
+ validator = TcpAuthValidator(keys=[valkey])
+
+ try:
+ client_socket.settimeout(1.0)
+ client_socket.connect(("localhost", DEFAULT_TCP_SERVER_PORT))
+ if transfer_data:
+ for _ in range(5):
+ check_socket_echo(client_socket)
+ client_socket.close()
+ session.wait_close()
+ except socket.timeout:
+ # If invalid packets are sent let the validator run
+ logger.warning("socket timeout", exc_info=True)
+ pass
+
+ sniffer.stop()
+
+ logger.info("capture: %r", sniffer.results)
+ for p in sniffer.results:
+ validator.handle_packet(p)
+ validator.raise_errors()
+
+ new_nstat = nstat_json()
+ assert old_nstat["TcpExtTCPAuthOptFailure"] == new_nstat["TcpExtTCPAuthOptFailure"]
+
+
+@pytest.mark.parametrize(
+ "address_family,use_tcp_authopt,use_tcp_md5sig",
+ [
+ (socket.AF_INET, 0, 0),
+ (socket.AF_INET, 1, 0),
+ (socket.AF_INET, 0, 1),
+ (socket.AF_INET6, 0, 0),
+ (socket.AF_INET6, 1, 0),
+ (socket.AF_INET6, 0, 1),
+ (socket.AF_INET, 1, 1),
+ (socket.AF_INET6, 1, 1),
+ ],
+)
+def test_both_authopt_md5(exit_stack, address_family, use_tcp_authopt, use_tcp_md5sig):
+ """Basic test for interaction between TCP_AUTHOPT and TCP_MD5SIG
+
+ Configuring both on same socket is allowed but RFC5925 doesn't allow both on the
+ same packet or same connection.
+
+ The naive handling of inserting or validation both options is incorrect.
+ """
+ con = TCPConnectionFixture(address_family=address_family)
+ if use_tcp_authopt:
+ con.tcp_authopt_key = DEFAULT_TCP_AUTHOPT_KEY
+ if use_tcp_md5sig:
+ con.tcp_md5_key = b"hello"
+ exit_stack.enter_context(con)
+
+ con.client_socket.connect(con.server_addr_port)
+ check_socket_echo(con.client_socket)
+ check_socket_echo(con.client_socket)
+ check_socket_echo(con.client_socket)
+ con.client_socket.close()
+
+ scapy_sniffer_stop(con.sniffer)
+ fail = False
+ for p in con.sniffer.results:
+ has_tcp_authopt = scapy_tcp_get_authopt_val(p[TCP]) is not None
+ has_tcp_md5sig = scapy_tcp_get_md5_sig(p[TCP]) is not None
+
+ if has_tcp_authopt and has_tcp_md5sig:
+ logger.error("Packet has both AO and MD5: %r", p)
+ fail = False
+
+ if use_tcp_authopt:
+ if not has_tcp_authopt:
+ logger.error("missing AO: %r", p)
+ fail = True
+ elif use_tcp_md5sig:
+ if not has_tcp_md5sig:
+ logger.error("missing MD5: %r", p)
+ fail = True
+ else:
+ if has_tcp_md5sig or has_tcp_authopt:
+ logger.error("unexpected MD5 or AO: %r", p)
+ fail = True
+
+ assert not fail
+
+
+@pytest.mark.parametrize("mode", ["none", "ao", "ao-addrbind", "md5"])
+def test_v4mapv6(exit_stack, mode: str):
+ """Test ipv4 client and ipv6 server with and without TCP-AO
+
+ By default any IPv6 server will also receive packets from IPv4 clients. This
+ is not currently supported by TCP_AUTHOPT but it should fail in an orderly
+ manner.
+ """
+ nsfixture = NamespaceFixture()
+ exit_stack.enter_context(nsfixture)
+ server_ipv4_addr = nsfixture.get_addr(socket.AF_INET, 1)
+
+ listen_socket = create_listen_socket(
+ ns=nsfixture.server_netns_name, family=socket.AF_INET6
+ )
+ listen_socket = exit_stack.enter_context(listen_socket)
+
+ server_thread = SimpleServerThread(listen_socket, mode="echo")
+ exit_stack.enter_context(server_thread)
+
+ client_socket = create_client_socket(
+ ns=nsfixture.client_netns_name,
+ family=socket.AF_INET,
+ )
+ client_socket = exit_stack.push(client_socket)
+
+ if mode == "ao":
+ alg = TCP_AUTHOPT_ALG.HMAC_SHA_1_96
+ key = tcp_authopt_key(alg=alg, key="hello")
+ set_tcp_authopt_key(listen_socket, key)
+ set_tcp_authopt_key(client_socket, key)
+
+ if mode == "ao-addrbind":
+ alg = TCP_AUTHOPT_ALG.HMAC_SHA_1_96
+ client_ipv6_addr = nsfixture.get_addr(socket.AF_INET6, 2)
+ server_key = tcp_authopt_key(alg=alg, key="hello", addr=client_ipv6_addr)
+ server_key.flags = TCP_AUTHOPT_KEY_FLAG.BIND_ADDR
+ set_tcp_authopt_key(listen_socket, server_key)
+
+ client_key = tcp_authopt_key(alg=alg, key="hello")
+ set_tcp_authopt_key(client_socket, client_key)
+
+ if mode == "md5":
+ from . import linux_tcp_md5sig
+
+ server_md5key = linux_tcp_md5sig.tcp_md5sig(key=b"hello")
+ server_md5key.set_ipv6_addr_all()
+ linux_tcp_md5sig.setsockopt_md5sig(listen_socket, server_md5key)
+ client_md5key = linux_tcp_md5sig.tcp_md5sig(key=b"hellx")
+ client_md5key.set_ipv4_addr_all()
+ linux_tcp_md5sig.setsockopt_md5sig(client_socket, client_md5key)
+
+ with raises_optional_exception(socket.timeout if mode != "none" else None):
+ client_socket.connect((str(server_ipv4_addr), DEFAULT_TCP_SERVER_PORT))
+ check_socket_echo(client_socket)
+ client_socket.close()
+
+
+@pytest.mark.parametrize(
+ "address_family,signed",
+ [
+ (socket.AF_INET, True),
+ (socket.AF_INET, False),
+ (socket.AF_INET6, True),
+ (socket.AF_INET6, False),
+ ],
+)
+def test_rst(exit_stack: ExitStack, address_family, signed: bool):
+ """Check that an unsigned RST breaks a normal connection but not one protected by TCP-AO"""
+
+ con = TCPConnectionFixture(address_family=address_family)
+ if signed:
+ con.tcp_authopt_key = DEFAULT_TCP_AUTHOPT_KEY
+ exit_stack.enter_context(con)
+
+ # connect
+ con.client_socket.connect(con.server_addr_port)
+ check_socket_echo(con.client_socket)
+
+ client_isn, server_isn = con.sniffer_session.get_client_server_isn()
+ p = con.create_client2server_packet()
+ p[TCP].flags = "R"
+ p[TCP].seq = tcp_seq_wrap(client_isn + 1001)
+ p[TCP].ack = tcp_seq_wrap(server_isn + 1001)
+ con.client_l2socket.send(p)
+
+ if signed:
+ # When protected by TCP-AO unsigned RSTs are ignored.
+ check_socket_echo(con.client_socket)
+ else:
+ # By default an RST that guesses seq can kill the connection.
+ with pytest.raises(ConnectionResetError):
+ check_socket_echo(con.client_socket)
+
+
+@pytest.mark.parametrize("address_family", [socket.AF_INET, socket.AF_INET6])
+def test_rst_signed_manually(exit_stack: ExitStack, address_family):
+ """Check that an manually signed RST breaks a connection protected by TCP-AO"""
+
+ con = TCPConnectionFixture(address_family=address_family)
+ con.tcp_authopt_key = key = DEFAULT_TCP_AUTHOPT_KEY
+ exit_stack.enter_context(con)
+
+ # connect
+ con.client_socket.connect(con.server_addr_port)
+ check_socket_echo(con.client_socket)
+
+ client_isn, server_isn = con.sniffer_session.get_client_server_isn()
+ p = con.create_client2server_packet()
+ p[TCP].flags = "R"
+ p[TCP].seq = tcp_seq_wrap(client_isn + 1001)
+ p[TCP].ack = tcp_seq_wrap(server_isn + 1001)
+
+ add_tcp_authopt_signature(
+ p, TcpAuthOptAlg_HMAC_SHA1(), key.key, client_isn, server_isn
+ )
+ con.client_l2socket.send(p)
+
+ # The server socket will close in response to RST without a TIME-WAIT
+ # Attempting to send additional packets will result in a timeout because
+ # the signature can't be validated.
+ with pytest.raises(socket.timeout):
+ check_socket_echo(con.client_socket)
+
+
+@pytest.mark.parametrize("address_family", [socket.AF_INET, socket.AF_INET6])
+def test_tw_ack(exit_stack: ExitStack, address_family):
+ """Manually sent a duplicate ACK after FIN and check TWSK signs replies correctly
+
+ Kernel has a custom code path for this
+ """
+
+ con = TCPConnectionFixture(address_family=address_family)
+ con.tcp_authopt_key = key = DEFAULT_TCP_AUTHOPT_KEY
+ exit_stack.enter_context(con)
+
+ # connect and close nicely
+ con.client_socket.connect(con.server_addr_port)
+ check_socket_echo(con.client_socket)
+ assert con.get_client_tcp_state() == "ESTAB"
+ assert con.get_server_tcp_state() == "ESTAB"
+ con.client_socket.close()
+ con.sniffer_session.wait_close()
+
+ assert con.get_client_tcp_state() == "TIME-WAIT"
+ assert con.get_server_tcp_state() is None
+
+ # Sent a duplicate FIN/ACK
+ client_isn, server_isn = con.sniffer_session.get_client_server_isn()
+ p = con.create_server2client_packet()
+ p[TCP].flags = "FA"
+ p[TCP].seq = tcp_seq_wrap(server_isn + 1001)
+ p[TCP].ack = tcp_seq_wrap(client_isn + 1002)
+ add_tcp_authopt_signature(
+ p, TcpAuthOptAlg_HMAC_SHA1(), key.key, server_isn, client_isn
+ )
+ pr = con.server_l2socket.sr1(p)
+ assert pr[TCP].ack == tcp_seq_wrap(server_isn + 1001)
+ assert pr[TCP].seq == tcp_seq_wrap(client_isn + 1001)
+ assert pr[TCP].flags == "A"
+
+ scapy_sniffer_stop(con.sniffer)
+
+ val = TcpAuthValidator()
+ val.keys.append(TcpAuthValidatorKey(key=b"hello", alg_name="HMAC-SHA-1-96"))
+ for p in con.sniffer.results:
+ val.handle_packet(p)
+ val.raise_errors()
+
+ # The server does not have enough state to validate the ACK from TIME-WAIT
+ # so it reports a failure.
+ assert con.server_nstat_json()["TcpExtTCPAuthOptFailure"] == 1
+ assert con.client_nstat_json()["TcpExtTCPAuthOptFailure"] == 0
+
+
+@pytest.mark.parametrize("address_family", [socket.AF_INET, socket.AF_INET6])
+def test_tw_rst(exit_stack: ExitStack, address_family):
+ """Manually sent a signed invalid packet after FIN and check TWSK signs RST correctly
+
+ Kernel has a custom code path for this
+ """
+ key = DEFAULT_TCP_AUTHOPT_KEY
+ con = TCPConnectionFixture(
+ address_family=address_family,
+ tcp_authopt_key=key,
+ )
+ con.server_thread.keep_half_open = True
+ exit_stack.enter_context(con)
+
+ # connect, transfer data and close client nicely
+ con.client_socket.connect(con.server_addr_port)
+ check_socket_echo(con.client_socket)
+ con.client_socket.close()
+
+ # since server keeps connection open client goes to FIN-WAIT-2
+ def check_socket_states():
+ client_tcp_state_name = con.get_client_tcp_state()
+ server_tcp_state_name = con.get_server_tcp_state()
+ logger.info("%s %s", client_tcp_state_name, server_tcp_state_name)
+ return (
+ client_tcp_state_name == "FIN-WAIT-2"
+ and server_tcp_state_name == "CLOSE-WAIT"
+ )
+
+ waiting.wait(check_socket_states)
+
+ # sending a FIN-ACK with incorrect seq makes
+ # tcp_timewait_state_process return a TCP_TW_RST
+ client_isn, server_isn = con.sniffer_session.get_client_server_isn()
+ p = con.create_server2client_packet()
+ p[TCP].flags = "FA"
+ p[TCP].seq = tcp_seq_wrap(server_isn + 1001 + 1)
+ p[TCP].ack = tcp_seq_wrap(client_isn + 1002)
+ add_tcp_authopt_signature(
+ p, TcpAuthOptAlg_HMAC_SHA1(), key.key, server_isn, client_isn
+ )
+ con.server_l2socket.send(p)
+
+ # remove delay by scapy trick?
+ import time
+
+ time.sleep(1)
+ scapy_sniffer_stop(con.sniffer)
+
+ # Check client socket moved from FIN-WAIT-2 to CLOSED
+ assert con.get_client_tcp_state() is None
+
+ # Check some RST was seen
+ def is_tcp_rst(p):
+ return TCP in p and p[TCP].flags.R
+
+ assert any(is_tcp_rst(p) for p in con.sniffer.results)
+
+ # Check everything was valid
+ val = TcpAuthValidator()
+ val.keys.append(TcpAuthValidatorKey(key=b"hello", alg_name="HMAC-SHA-1-96"))
+ for p in con.sniffer.results:
+ val.handle_packet(p)
+ val.raise_errors()
+
+ # Check no snmp failures
+ con.assert_no_snmp_output_failures()
+
+
+def test_rst_linger(exit_stack: ExitStack):
+ """Test RST sent deliberately via SO_LINGER is valid"""
+ con = TCPConnectionFixture(
+ sniffer_kwargs=dict(count=8), tcp_authopt_key=DEFAULT_TCP_AUTHOPT_KEY
+ )
+ exit_stack.enter_context(con)
+
+ con.client_socket.connect(con.server_addr_port)
+ check_socket_echo(con.client_socket)
+ socket_set_linger(con.client_socket, 1, 0)
+ con.client_socket.close()
+
+ con.sniffer.join(timeout=3)
+
+ val = TcpAuthValidator()
+ val.keys.append(TcpAuthValidatorKey(key=b"hello", alg_name="HMAC-SHA-1-96"))
+ for p in con.sniffer.results:
+ val.handle_packet(p)
+ val.raise_errors()
+
+ def is_tcp_rst(p):
+ return TCP in p and p[TCP].flags.R
+
+ assert any(is_tcp_rst(p) for p in con.sniffer.results)
+
+
+@pytest.mark.parametrize(
+ "address_family,mode",
+ [
+ (socket.AF_INET, "goodsign"),
+ (socket.AF_INET, "fakesign"),
+ (socket.AF_INET, "unsigned"),
+ (socket.AF_INET6, "goodsign"),
+ (socket.AF_INET6, "fakesign"),
+ (socket.AF_INET6, "unsigned"),
+ ],
+)
+def test_badack_to_synack(exit_stack, address_family, mode: str):
+ """Test bad ack in response to server to syn/ack.
+
+ This is handled by a minisocket in the TCP_SYN_RECV state on a separate code path
+ """
+ con = TCPConnectionFixture(address_family=address_family)
+ if mode != "unsigned":
+ con.tcp_authopt_key = tcp_authopt_key(
+ alg=TCP_AUTHOPT_ALG.HMAC_SHA_1_96,
+ key=b"hello",
+ )
+ exit_stack.enter_context(con)
+
+ client_l2socket = con.client_l2socket
+ client_isn = 1000
+ server_isn = 0
+
+ def sign(packet):
+ if mode == "unsigned":
+ return
+ add_tcp_authopt_signature(
+ packet,
+ TcpAuthOptAlg_HMAC_SHA1(),
+ con.tcp_authopt_key.key,
+ client_isn,
+ server_isn,
+ )
+
+ # Prevent TCP in client namespace from sending RST
+ # Do this by removing the client address and insert a static ARP on server side
+ client_prefix_length = con.nsfixture.get_prefix_length(address_family)
+ subprocess.run(
+ f"""\
+set -e
+ip netns exec {con.nsfixture.client_netns_name} ip addr del {con.client_addr}/{client_prefix_length} dev veth0
+ip netns exec {con.nsfixture.server_netns_name} ip neigh add {con.client_addr} lladdr {con.nsfixture.client_mac_addr} dev veth0
+""",
+ shell=True,
+ check=True,
+ )
+
+ p1 = con.create_client2server_packet()
+ p1[TCP].flags = "S"
+ p1[TCP].seq = client_isn
+ p1[TCP].ack = 0
+ sign(p1)
+
+ p2 = client_l2socket.sr1(p1, timeout=1)
+ server_isn = p2[TCP].seq
+ assert p2[TCP].ack == client_isn + 1
+ assert p2[TCP].flags == "SA"
+
+ p3 = con.create_client2server_packet()
+ p3[TCP].flags = "A"
+ p3[TCP].seq = client_isn + 1
+ p3[TCP].ack = server_isn + 1
+ sign(p3)
+ if mode == "fakesign":
+ break_tcp_authopt_signature(p3)
+
+ assert con.server_nstat_json()["TcpExtTCPAuthOptFailure"] == 0
+ client_l2socket.send(p3)
+
+ def confirm_good():
+ return len(con.server_thread.server_socket) > 0
+
+ def confirm_fail():
+ return con.server_nstat_json()["TcpExtTCPAuthOptFailure"] == 1
+
+ def wait_good():
+ assert not confirm_fail()
+ return confirm_good()
+
+ def wait_fail():
+ assert not confirm_good()
+ return confirm_fail()
+
+ if mode == "fakesign":
+ waiting.wait(wait_fail, timeout_seconds=5, sleep_seconds=0.1)
+ else:
+ waiting.wait(wait_good, timeout_seconds=5, sleep_seconds=0.1)
This patch validates that the TCP-AO signatures inserted by linux are correct in all algorithm permutations, using scapy. It also tests that TCP-AO behaves correctly in a number of corner cases such as: * reset handling * timewait * syn-recv * ipv4-mapped ipv6 * interaction with tcp-md5 This reverts commit 297a301a4f1c3abe41d554a9f6df192257a017b8. Signed-off-by: Leonard Crestez <cdleonard@gmail.com> --- .../full_tcp_sniff_session.py | 91 +++ .../tcp_authopt_test/linux_tcp_md5sig.py | 110 ++++ .../tcp_authopt_test/scapy_conntrack.py | 173 ++++++ .../tcp_connection_fixture.py | 276 +++++++++ .../tcp_authopt_test/test_verify_capture.py | 559 ++++++++++++++++++ 5 files changed, 1209 insertions(+) create mode 100644 tools/testing/selftests/tcp_authopt/tcp_authopt_test/full_tcp_sniff_session.py create mode 100644 tools/testing/selftests/tcp_authopt/tcp_authopt_test/linux_tcp_md5sig.py create mode 100644 tools/testing/selftests/tcp_authopt/tcp_authopt_test/scapy_conntrack.py create mode 100644 tools/testing/selftests/tcp_authopt/tcp_authopt_test/tcp_connection_fixture.py create mode 100644 tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_verify_capture.py