From 059a13cc513e137ed899804d30bb8b7563ebba99 Mon Sep 17 00:00:00 2001
From: Colin Taylor <cjntaylor@gmail.com>
Date: Fri, 22 Dec 2023 14:20:43 -0500
Subject: [PATCH] Modify unix socket tests to use stdlib tempdirs

---
 tests/test_sockets.py | 99 ++++++++++++++++++++++++++-----------------
 1 file changed, 60 insertions(+), 39 deletions(-)

diff --git a/tests/test_sockets.py b/tests/test_sockets.py
index f34a0381..91652e67 100644
--- a/tests/test_sockets.py
+++ b/tests/test_sockets.py
@@ -7,6 +7,7 @@
 import platform
 import socket
 import sys
+import tempfile
 import threading
 import time
 from contextlib import suppress
@@ -14,7 +15,7 @@
 from socket import AddressFamily
 from ssl import SSLContext, SSLError
 from threading import Thread
-from typing import Any, Iterable, Iterator, NoReturn, TypeVar, cast
+from typing import Any, Generator, Iterable, Iterator, NoReturn, TypeVar, cast
 
 import psutil
 import pytest
@@ -707,8 +708,11 @@ async def test_bind_link_local(self) -> None:
 )
 class TestUNIXStream:
     @pytest.fixture
-    def socket_path(self, tmp_path_factory: TempPathFactory) -> Path:
-        return tmp_path_factory.mktemp("unix").joinpath("socket")
+    def socket_path(self) -> Generator[Path, None, None]:
+        # Use stdlib tempdir generation
+        # Fixes `OSError: AF_UNIX path too long` from pytest generated temp_path
+        with tempfile.TemporaryDirectory() as path:
+            yield Path(path) / "socket"
 
     @pytest.fixture(params=[False, True], ids=["str", "path"])
     def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Path | str:
@@ -1026,8 +1030,11 @@ async def test_connecting_with_non_utf8(self, socket_path: Path) -> None:
 )
 class TestUNIXListener:
     @pytest.fixture
-    def socket_path(self, tmp_path_factory: TempPathFactory) -> Path:
-        return tmp_path_factory.mktemp("unix").joinpath("socket")
+    def socket_path(self) -> Generator[Path, None, None]:
+        # Use stdlib tempdir generation
+        # Fixes `OSError: AF_UNIX path too long` from pytest generated temp_path
+        with tempfile.TemporaryDirectory() as path:
+            yield Path(path) / "socket"
 
     @pytest.fixture(params=[False, True], ids=["str", "path"])
     def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Path | str:
@@ -1140,35 +1147,37 @@ async def handle(stream: SocketStream) -> None:
 
     client_addresses: list[str | IPSockAddrType] = []
     listeners: list[Listener] = [await create_tcp_listener(local_host="localhost")]
-    if sys.platform != "win32":
-        socket_path = tmp_path_factory.mktemp("unix").joinpath("socket")
-        listeners.append(await create_unix_listener(socket_path))
-
-    expected_addresses: list[str | IPSockAddrType] = []
-    async with MultiListener(listeners) as multi_listener:
-        async with create_task_group() as tg:
-            tg.start_soon(multi_listener.serve, handle)
-            for listener in multi_listener.listeners:
-                event = Event()
-                local_address = listener.extra(SocketAttribute.local_address)
-                if (
-                    sys.platform != "win32"
-                    and listener.extra(SocketAttribute.family)
-                    == socket.AddressFamily.AF_UNIX
-                ):
-                    assert isinstance(local_address, str)
-                    stream: SocketStream = await connect_unix(local_address)
-                else:
-                    assert isinstance(local_address, tuple)
-                    stream = await connect_tcp(*local_address)
+    with tempfile.TemporaryDirectory() as path:
+        if sys.platform != "win32":
+            listeners.append(await create_unix_listener(Path(path) / "socket"))
 
-                expected_addresses.append(stream.extra(SocketAttribute.local_address))
-                await event.wait()
-                await stream.aclose()
+        expected_addresses: list[str | IPSockAddrType] = []
+        async with MultiListener(listeners) as multi_listener:
+            async with create_task_group() as tg:
+                tg.start_soon(multi_listener.serve, handle)
+                for listener in multi_listener.listeners:
+                    event = Event()
+                    local_address = listener.extra(SocketAttribute.local_address)
+                    if (
+                        sys.platform != "win32"
+                        and listener.extra(SocketAttribute.family)
+                        == socket.AddressFamily.AF_UNIX
+                    ):
+                        assert isinstance(local_address, str)
+                        stream: SocketStream = await connect_unix(local_address)
+                    else:
+                        assert isinstance(local_address, tuple)
+                        stream = await connect_tcp(*local_address)
+
+                    expected_addresses.append(
+                        stream.extra(SocketAttribute.local_address)
+                    )
+                    await event.wait()
+                    await stream.aclose()
 
-            tg.cancel_scope.cancel()
+                tg.cancel_scope.cancel()
 
-    assert client_addresses == expected_addresses
+        assert client_addresses == expected_addresses
 
 
 @pytest.mark.usefixtures("check_asyncio_bug")
@@ -1423,16 +1432,22 @@ async def test_send_after_close(self, family: AnyIPAddressFamily) -> None:
 )
 class TestUNIXDatagramSocket:
     @pytest.fixture
-    def socket_path(self, tmp_path_factory: TempPathFactory) -> Path:
-        return tmp_path_factory.mktemp("unix").joinpath("socket")
+    def socket_path(self) -> Generator[Path, None, None]:
+        # Use stdlib tempdir generation
+        # Fixes `OSError: AF_UNIX path too long` from pytest generated temp_path
+        with tempfile.TemporaryDirectory() as path:
+            yield Path(path) / "socket"
 
     @pytest.fixture(params=[False, True], ids=["str", "path"])
     def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Path | str:
         return socket_path if request.param else str(socket_path)
 
     @pytest.fixture
-    def peer_socket_path(self, tmp_path_factory: TempPathFactory) -> Path:
-        return tmp_path_factory.mktemp("unix").joinpath("peer_socket")
+    def peer_socket_path(self) -> Generator[Path, None, None]:
+        # Use stdlib tempdir generation
+        # Fixes `OSError: AF_UNIX path too long` from pytest generated temp_path
+        with tempfile.TemporaryDirectory() as path:
+            yield Path(path) / "peer_socket"
 
     async def test_extra_attributes(self, socket_path: Path) -> None:
         async with await create_unix_datagram_socket(local_path=socket_path) as unix_dg:
@@ -1545,16 +1560,22 @@ async def test_local_path_invalid_ascii(self, socket_path: Path) -> None:
 )
 class TestConnectedUNIXDatagramSocket:
     @pytest.fixture
-    def socket_path(self, tmp_path_factory: TempPathFactory) -> Path:
-        return tmp_path_factory.mktemp("unix").joinpath("socket")
+    def socket_path(self) -> Generator[Path, None, None]:
+        # Use stdlib tempdir generation
+        # Fixes `OSError: AF_UNIX path too long` from pytest generated temp_path
+        with tempfile.TemporaryDirectory() as path:
+            yield Path(path) / "socket"
 
     @pytest.fixture(params=[False, True], ids=["str", "path"])
     def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Path | str:
         return socket_path if request.param else str(socket_path)
 
     @pytest.fixture
-    def peer_socket_path(self, tmp_path_factory: TempPathFactory) -> Path:
-        return tmp_path_factory.mktemp("unix").joinpath("peer_socket")
+    def peer_socket_path(self) -> Generator[Path, None, None]:
+        # Use stdlib tempdir generation
+        # Fixes `OSError: AF_UNIX path too long` from pytest generated temp_path
+        with tempfile.TemporaryDirectory() as path:
+            yield Path(path) / "peer_socket"
 
     @pytest.fixture(params=[False, True], ids=["peer_str", "peer_path"])
     def peer_socket_path_or_str(
