diff mbox series

[04/19] selftests: tcp_authopt: Initial sockopt manipulation

Message ID 1d55493c5879971355f31d6a8ac33612f07df500.1632240523.git.cdleonard@gmail.com (mailing list archive)
State New
Headers show
Series tcp: Initial support for RFC5925 auth option | expand

Commit Message

Leonard Crestez Sept. 21, 2021, 4:14 p.m. UTC
Signed-off-by: Leonard Crestez <cdleonard@gmail.com>
---
 .../tcp_authopt/tcp_authopt_test/conftest.py  |  41 +++
 .../tcp_authopt_test/linux_tcp_authopt.py     | 238 ++++++++++++++++++
 .../tcp_authopt/tcp_authopt_test/sockaddr.py  | 112 +++++++++
 .../tcp_authopt_test/test_sockopt.py          | 185 ++++++++++++++
 4 files changed, 576 insertions(+)
 create mode 100644 tools/testing/selftests/tcp_authopt/tcp_authopt_test/conftest.py
 create mode 100644 tools/testing/selftests/tcp_authopt/tcp_authopt_test/linux_tcp_authopt.py
 create mode 100644 tools/testing/selftests/tcp_authopt/tcp_authopt_test/sockaddr.py
 create mode 100644 tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_sockopt.py
diff mbox series

Patch

diff --git a/tools/testing/selftests/tcp_authopt/tcp_authopt_test/conftest.py b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/conftest.py
new file mode 100644
index 000000000000..dfab5a631e34
--- /dev/null
+++ b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/conftest.py
@@ -0,0 +1,41 @@ 
+# SPDX-License-Identifier: GPL-2.0
+import logging
+import os
+from contextlib import ExitStack
+
+import pytest
+
+from .linux_tcp_authopt import has_tcp_authopt, enable_sysctl_tcp_authopt
+
+logger = logging.getLogger(__name__)
+
+skipif_missing_tcp_authopt = pytest.mark.skipif(
+    not has_tcp_authopt(), reason="Need CONFIG_TCP_AUTHOPT"
+)
+
+
+def can_capture():
+    # This is too restrictive:
+    return os.geteuid() == 0
+
+
+skipif_cant_capture = pytest.mark.skipif(
+    not can_capture(), reason="run as root to capture packets"
+)
+
+
+@pytest.fixture
+def exit_stack():
+    """Return a contextlib.ExitStack as a pytest fixture
+
+    This reduces indentation making code more readable
+    """
+    with ExitStack() as exit_stack:
+        yield exit_stack
+
+
+def pytest_configure():
+    # Silence messages regarding netns enter/exit:
+    logging.getLogger("nsenter").setLevel(logging.INFO)
+    if has_tcp_authopt():
+        enable_sysctl_tcp_authopt()
diff --git a/tools/testing/selftests/tcp_authopt/tcp_authopt_test/linux_tcp_authopt.py b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/linux_tcp_authopt.py
new file mode 100644
index 000000000000..339298998ff9
--- /dev/null
+++ b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/linux_tcp_authopt.py
@@ -0,0 +1,238 @@ 
+# SPDX-License-Identifier: GPL-2.0
+"""Python wrapper around linux TCP_AUTHOPT ABI"""
+
+from dataclasses import dataclass
+from ipaddress import IPv4Address, IPv6Address, ip_address
+import socket
+from enum import IntEnum, IntFlag
+import errno
+import logging
+from .sockaddr import sockaddr_in, sockaddr_in6, sockaddr_storage, sockaddr_unpack
+import typing
+import struct
+
+logger = logging.getLogger(__name__)
+
+
+def BIT(x):
+    return 1 << x
+
+
+TCP_AUTHOPT = 38
+TCP_AUTHOPT_KEY = 39
+
+TCP_AUTHOPT_MAXKEYLEN = 80
+
+
+class TCP_AUTHOPT_FLAG(IntFlag):
+    REJECT_UNEXPECTED = BIT(2)
+
+
+class TCP_AUTHOPT_KEY_FLAG(IntFlag):
+    DEL = BIT(0)
+    EXCLUDE_OPTS = BIT(1)
+    BIND_ADDR = BIT(2)
+
+
+class TCP_AUTHOPT_ALG(IntEnum):
+    HMAC_SHA_1_96 = 1
+    AES_128_CMAC_96 = 2
+
+
+@dataclass
+class tcp_authopt:
+    """Like linux struct tcp_authopt"""
+
+    flags: int = 0
+    sizeof = 4
+
+    def pack(self) -> bytes:
+        return struct.pack(
+            "I",
+            self.flags,
+        )
+
+    def __bytes__(self):
+        return self.pack()
+
+    @classmethod
+    def unpack(cls, b: bytes):
+        tup = struct.unpack("I", b)
+        return cls(*tup)
+
+
+def set_tcp_authopt(sock, opt: tcp_authopt):
+    return sock.setsockopt(socket.SOL_TCP, TCP_AUTHOPT, bytes(opt))
+
+
+def get_tcp_authopt(sock: socket.socket) -> tcp_authopt:
+    b = sock.getsockopt(socket.SOL_TCP, TCP_AUTHOPT, tcp_authopt.sizeof)
+    return tcp_authopt.unpack(b)
+
+
+class tcp_authopt_key:
+    """Like linux struct tcp_authopt_key"""
+
+    def __init__(
+        self,
+        flags: int = 0,
+        send_id: int = 0,
+        recv_id: int = 0,
+        alg=TCP_AUTHOPT_ALG.HMAC_SHA_1_96,
+        key: bytes = b"",
+        addr: bytes = b"",
+        include_options=None,
+    ):
+        self.flags = flags
+        self.send_id = send_id
+        self.recv_id = recv_id
+        self.alg = alg
+        self.key = key
+        self.addr = addr
+        if include_options is not None:
+            self.include_options = include_options
+
+    def pack(self):
+        if len(self.key) > TCP_AUTHOPT_MAXKEYLEN:
+            raise ValueError(f"Max key length is {TCP_AUTHOPT_MAXKEYLEN}")
+        data = struct.pack(
+            "IBBBB80s",
+            self.flags,
+            self.send_id,
+            self.recv_id,
+            self.alg,
+            len(self.key),
+            self.key,
+        )
+        data += bytes(self.addrbuf.ljust(sockaddr_storage.sizeof, b"\x00"))
+        return data
+
+    def __bytes__(self):
+        return self.pack()
+
+    @property
+    def key(self) -> bytes:
+        return self._key
+
+    @key.setter
+    def key(self, val: typing.Union[bytes, str]) -> bytes:
+        if isinstance(val, str):
+            val = val.encode("utf-8")
+        if len(val) > TCP_AUTHOPT_MAXKEYLEN:
+            raise ValueError(f"Max key length is {TCP_AUTHOPT_MAXKEYLEN}")
+        self._key = val
+        return val
+
+    @property
+    def addr(self):
+        if not self.addrbuf:
+            return None
+        else:
+            return sockaddr_unpack(bytes(self.addrbuf))
+
+    @addr.setter
+    def addr(self, val):
+        if isinstance(val, bytes):
+            if len(val) > sockaddr_storage.sizeof:
+                raise ValueError(f"Must be up to {sockaddr_storage.sizeof}")
+            self.addrbuf = val
+        elif val is None:
+            self.addrbuf = b""
+        elif isinstance(val, str):
+            self.addr = ip_address(val)
+        elif isinstance(val, IPv4Address):
+            self.addr = sockaddr_in(addr=val)
+        elif isinstance(val, IPv6Address):
+            self.addr = sockaddr_in6(addr=val)
+        elif (
+            isinstance(val, sockaddr_in)
+            or isinstance(val, sockaddr_in6)
+            or isinstance(val, sockaddr_storage)
+        ):
+            self.addr = bytes(val)
+        else:
+            raise TypeError(f"Can't handle addr {val}")
+        return self.addr
+
+    @property
+    def include_options(self) -> bool:
+        return (self.flags & TCP_AUTHOPT_KEY.EXCLUDE_OPTS) == 0
+
+    @include_options.setter
+    def include_options(self, value) -> bool:
+        if value:
+            self.flags &= ~TCP_AUTHOPT_KEY_FLAG.EXCLUDE_OPTS
+        else:
+            self.flags |= TCP_AUTHOPT_KEY_FLAG.EXCLUDE_OPTS
+
+    @property
+    def delete_flag(self) -> bool:
+        return bool(self.flags & TCP_AUTHOPT_KEY_FLAG.DEL)
+
+    @delete_flag.setter
+    def delete_flag(self, value) -> bool:
+        if value:
+            self.flags |= TCP_AUTHOPT_KEY_FLAG.DEL
+        else:
+            self.flags &= ~TCP_AUTHOPT_KEY_FLAG.DEL
+
+
+def set_tcp_authopt_key(sock, key: tcp_authopt_key):
+    return sock.setsockopt(socket.SOL_TCP, TCP_AUTHOPT_KEY, bytes(key))
+
+
+def del_tcp_authopt_key(sock, key: tcp_authopt_key) -> bool:
+    """Try to delete an authopt key
+
+    :return: True if a key was deleted, False if it was not present
+    """
+    import copy
+
+    key = copy.copy(key)
+    key.delete_flag = True
+    try:
+        sock.setsockopt(socket.SOL_TCP, TCP_AUTHOPT_KEY, bytes(key))
+        return True
+    except OSError as e:
+        if e.errno == errno.ENOENT:
+            return False
+        raise
+
+
+def get_sysctl_tcp_authopt() -> bool:
+    from pathlib import Path
+
+    path = Path("/proc/sys/net/ipv4/tcp_authopt")
+    if path.exists():
+        return path.read_text().strip() != "0"
+
+
+def enable_sysctl_tcp_authopt() -> bool:
+    from pathlib import Path
+
+    path = Path("/proc/sys/net/ipv4/tcp_authopt")
+    try:
+        if path.read_text().strip() == "0":
+            path.write_text("1")
+    except:
+        raise Exception("Failed to enable /proc/sys/net/ipv4/tcp_authopt")
+
+
+def has_tcp_authopt() -> bool:
+    """Check is TCP_AUTHOPT is implemented by the OS
+
+    Returns True if implemented but disabled by sysctl
+    Returns False if disabled at compile time
+    """
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+        try:
+            optbuf = bytes(4)
+            sock.setsockopt(socket.SOL_TCP, TCP_AUTHOPT, optbuf)
+            return True
+        except OSError as e:
+            if e.errno == errno.ENOPROTOOPT:
+                return False
+            elif e.errno == errno.EPERM and get_sysctl_tcp_authopt() is False:
+                return True
+            else:
+                raise
diff --git a/tools/testing/selftests/tcp_authopt/tcp_authopt_test/sockaddr.py b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/sockaddr.py
new file mode 100644
index 000000000000..be1745ac10ab
--- /dev/null
+++ b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/sockaddr.py
@@ -0,0 +1,112 @@ 
+# SPDX-License-Identifier: GPL-2.0
+"""pack/unpack wrappers for sockaddr"""
+import socket
+import struct
+from dataclasses import dataclass
+from ipaddress import IPv4Address, IPv6Address, ip_address
+
+
+@dataclass
+class sockaddr_in:
+    port: int
+    addr: IPv4Address
+    sizeof = 8
+
+    def __init__(self, port=0, addr=None):
+        self.port = port
+        if addr is None:
+            addr = IPv4Address(0)
+        self.addr = IPv4Address(addr)
+
+    def pack(self):
+        return struct.pack("HH4s", socket.AF_INET, self.port, self.addr.packed)
+
+    @classmethod
+    def unpack(cls, buffer):
+        family, port, addr_packed = struct.unpack("HH4s", buffer[:8])
+        if family != socket.AF_INET:
+            raise ValueError(f"Must be AF_INET not {family}")
+        return cls(port, addr_packed)
+
+    def __bytes__(self):
+        return self.pack()
+
+
+@dataclass
+class sockaddr_in6:
+    """Like sockaddr_in6 but for python. Always contains scope_id"""
+
+    port: int
+    addr: IPv6Address
+    flowinfo: int
+    scope_id: int
+    sizeof = 28
+
+    def __init__(self, port=0, addr=None, flowinfo=0, scope_id=0):
+        self.port = port
+        if addr is None:
+            addr = IPv6Address(0)
+        self.addr = IPv6Address(addr)
+        self.flowinfo = flowinfo
+        self.scope_id = scope_id
+
+    def pack(self):
+        return struct.pack(
+            "HHI16sI",
+            socket.AF_INET6,
+            self.port,
+            self.flowinfo,
+            self.addr.packed,
+            self.scope_id,
+        )
+
+    @classmethod
+    def unpack(cls, buffer):
+        family, port, flowinfo, addr_packed, scope_id = struct.unpack(
+            "HHI16sI", buffer[:28]
+        )
+        if family != socket.AF_INET6:
+            raise ValueError(f"Must be AF_INET6 not {family}")
+        return cls(port, addr_packed, flowinfo=flowinfo, scope_id=scope_id)
+
+    def __bytes__(self):
+        return self.pack()
+
+
+@dataclass
+class sockaddr_storage:
+    family: int
+    data: bytes
+    sizeof = 128
+
+    def pack(self):
+        return struct.pack("H126s", self.family, self.data)
+
+    def __bytes__(self):
+        return self.pack()
+
+    @classmethod
+    def unpack(cls, buffer):
+        return cls(*struct.unpack("H126s", buffer))
+
+
+def sockaddr_unpack(buffer: bytes):
+    """Unpack based on family"""
+    family = struct.unpack("H", buffer[:2])[0]
+    if family == socket.AF_INET:
+        return sockaddr_in.unpack(buffer)
+    elif family == socket.AF_INET6:
+        return sockaddr_in6.unpack(buffer)
+    else:
+        return sockaddr_storage.unpack(buffer)
+
+
+def sockaddr_convert(val):
+    """Try to convert address into some sort of sockaddr"""
+    if isinstance(val, IPv4Address):
+        return sockaddr_in(addr=val)
+    if isinstance(val, IPv6Address):
+        return sockaddr_in6(addr=val)
+    if isinstance(val, str):
+        return sockaddr_convert(ip_address(val))
+    raise TypeError(f"Don't know how to convert {val!r} to sockaddr")
diff --git a/tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_sockopt.py b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_sockopt.py
new file mode 100644
index 000000000000..dd389ae6055e
--- /dev/null
+++ b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_sockopt.py
@@ -0,0 +1,185 @@ 
+# SPDX-License-Identifier: GPL-2.0
+"""Test TCP_AUTHOPT sockopt API"""
+import errno
+import socket
+import struct
+from ipaddress import IPv4Address, IPv6Address
+
+import pytest
+
+from .linux_tcp_authopt import (
+    TCP_AUTHOPT,
+    TCP_AUTHOPT_KEY,
+    TCP_AUTHOPT_ALG,
+    TCP_AUTHOPT_FLAG,
+    TCP_AUTHOPT_KEY_FLAG,
+    set_tcp_authopt,
+    get_tcp_authopt,
+    set_tcp_authopt_key,
+    del_tcp_authopt_key,
+    tcp_authopt,
+    tcp_authopt_key,
+)
+from .sockaddr import sockaddr_in, sockaddr_in6, sockaddr_unpack
+from .conftest import skipif_missing_tcp_authopt
+
+pytestmark = skipif_missing_tcp_authopt
+
+
+def test_authopt_key_pack_noaddr():
+    b = bytes(tcp_authopt_key(key=b"a\x00b"))
+    assert b[7] == 3
+    assert b[8:13] == b"a\x00b\x00\x00"
+
+
+def test_authopt_key_pack_addr():
+    b = bytes(tcp_authopt_key(key=b"a\x00b", addr="10.0.0.1"))
+    assert struct.unpack("H", b[88:90])[0] == socket.AF_INET
+    assert sockaddr_unpack(b[88:]).addr == IPv4Address("10.0.0.1")
+
+
+def test_authopt_key_pack_addr6():
+    b = bytes(tcp_authopt_key(key=b"abc", addr="fd00::1"))
+    assert struct.unpack("H", b[88:90])[0] == socket.AF_INET6
+    assert sockaddr_unpack(b[88:]).addr == IPv6Address("fd00::1")
+
+
+def test_tcp_authopt_key_del_without_active(exit_stack):
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    exit_stack.push(sock)
+
+    # nothing happens:
+    key = tcp_authopt_key()
+    assert key.delete_flag is False
+    key.delete_flag = True
+    assert key.delete_flag is True
+    with pytest.raises(OSError) as e:
+        set_tcp_authopt_key(sock, key)
+    assert e.value.errno in [errno.EINVAL, errno.ENOENT]
+
+
+def test_tcp_authopt_key_setdel(exit_stack):
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    exit_stack.push(sock)
+    set_tcp_authopt(sock, tcp_authopt())
+
+    # delete returns ENOENT
+    key = tcp_authopt_key()
+    key.delete_flag = True
+    with pytest.raises(OSError) as e:
+        set_tcp_authopt_key(sock, key)
+    assert e.value.errno == errno.ENOENT
+
+    key = tcp_authopt_key(send_id=1, recv_id=2)
+    set_tcp_authopt_key(sock, key)
+    # First delete works fine:
+    key.delete_flag = True
+    set_tcp_authopt_key(sock, key)
+    # Duplicate delete returns ENOENT
+    with pytest.raises(OSError) as e:
+        set_tcp_authopt_key(sock, key)
+    assert e.value.errno == errno.ENOENT
+
+
+def test_get_tcp_authopt():
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+        with pytest.raises(OSError) as e:
+            sock.getsockopt(socket.SOL_TCP, TCP_AUTHOPT, 4)
+        assert e.value.errno == errno.ENOENT
+
+
+def test_set_get_tcp_authopt_flags():
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+        # No flags by default
+        set_tcp_authopt(sock, tcp_authopt())
+        opt = get_tcp_authopt(sock)
+        assert opt.flags == 0
+
+        # simple flags are echoed
+        goodflag = TCP_AUTHOPT_FLAG.REJECT_UNEXPECTED
+        set_tcp_authopt(sock, tcp_authopt(flags=goodflag))
+        opt = get_tcp_authopt(sock)
+        assert opt.flags == goodflag
+
+        # attempting to set a badflag returns an error and has no effect
+        badflag = 1 << 27
+        with pytest.raises(OSError) as e:
+            set_tcp_authopt(sock, tcp_authopt(flags=badflag))
+        opt = get_tcp_authopt(sock)
+        assert opt.flags == goodflag
+
+
+def test_set_ipv6_key_on_ipv4():
+    """Binding a key to an ipv6 address on an ipv4 socket makes no sense"""
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+        key = tcp_authopt_key("abc")
+        key.flags = TCP_AUTHOPT_KEY_FLAG.BIND_ADDR
+        key.addr = IPv6Address("::1234")
+        with pytest.raises(OSError):
+            set_tcp_authopt_key(sock, key)
+
+
+def test_set_ipv4_key_on_ipv6():
+    """This could be implemented for ipv6-mapped-ipv4 but it is not
+
+    TCP_MD5SIG has a similar limitation
+    """
+    with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock:
+        key = tcp_authopt_key("abc")
+        key.flags = TCP_AUTHOPT_KEY_FLAG.BIND_ADDR
+        key.addr = IPv4Address("1.2.3.4")
+        with pytest.raises(OSError):
+            set_tcp_authopt_key(sock, key)
+
+
+def test_authopt_key_badflags():
+    """Don't pretend to handle unknown flags"""
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+        with pytest.raises(OSError):
+            set_tcp_authopt_key(sock, tcp_authopt_key(flags=0xabcdef))
+
+
+def test_authopt_key_longer_bad():
+    """Test that pass a longer sockopt with unknown data fails
+
+    Old kernels won't pretend to handle features they don't know about
+    """
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+        key = tcp_authopt_key(alg=TCP_AUTHOPT_ALG.HMAC_SHA_1_96, key="aaa")
+        optbuf = bytes(key)
+        optbuf = optbuf.ljust(len(optbuf) + 256, b"\x5a")
+        with pytest.raises(OSError):
+            sock.setsockopt(socket.SOL_TCP, TCP_AUTHOPT_KEY, optbuf)
+
+
+def test_authopt_key_longer_zeros():
+    """Test that passing a longer sockopt padded with zeros works
+
+    This ensures applications using a larger struct tcp_authopt_key won't have
+    to pass a shorter optlen on old kernels.
+    """
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+        key = tcp_authopt_key(alg=TCP_AUTHOPT_ALG.HMAC_SHA_1_96, key="aaa")
+        optbuf = bytes(key)
+        optbuf = optbuf.ljust(len(optbuf) + 256, b"\x00")
+        sock.setsockopt(socket.SOL_TCP, TCP_AUTHOPT_KEY, optbuf)
+        # the key was added and can be deleted normally
+        assert del_tcp_authopt_key(sock, key) == True
+        assert del_tcp_authopt_key(sock, key) == False
+
+
+def test_authopt_longer_baddata():
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+        opt = tcp_authopt()
+        optbuf = bytes(opt)
+        optbuf = optbuf.ljust(len(optbuf) + 256, b"\x5a")
+        with pytest.raises(OSError):
+            sock.setsockopt(socket.SOL_TCP, TCP_AUTHOPT, optbuf)
+
+
+def test_authopt_longer_zeros():
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+        opt = tcp_authopt()
+        optbuf = bytes(opt)
+        optbuf = optbuf.ljust(len(optbuf) + 256, b"\x00")
+        sock.setsockopt(socket.SOL_TCP, TCP_AUTHOPT, optbuf)