Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions Lib/test/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -7529,6 +7529,62 @@ def detach():
pass


class ReentrantMutationTests(unittest.TestCase):
"""Regression tests for re-entrant mutation in sendmsg/recvmsg_into.

These tests verify that mutating sequences during argument parsing
via __buffer__ protocol does not cause crashes.

See: https://github.com/python/cpython/issues/143988
"""

@unittest.skipUnless(hasattr(socket.socket, "sendmsg"),
"sendmsg not supported")
def test_sendmsg_reentrant_data_mutation(self):
Comment thread
vstinner marked this conversation as resolved.
seq = []

class MutBuffer:
def __init__(self):
self.tripped = False

def __buffer__(self, flags):
if not self.tripped:
self.tripped = True
seq.clear()
return memoryview(b'Hello')

seq = [MutBuffer(), b'World', b'Test']

left, right = socket.socketpair()
with left, right:
left.sendmsg(seq)
self.assertEqual(right.recv(1024), b'HelloWorldTest')

@unittest.skipUnless(hasattr(socket.socket, "recvmsg_into"),
"recvmsg_into not supported")
def test_recvmsg_into_reentrant_buffer_mutation(self):
seq = []
buf1 = bytearray(100)

class MutBuffer:
def __init__(self):
self.tripped = False

def __buffer__(self, flags):
if not self.tripped:
self.tripped = True
seq.clear()
return memoryview(buf1)

seq = [MutBuffer(), bytearray(100), bytearray(100)]

left, right = socket.socketpair()
with left, right:
left.send(b'Hello World!')
right.recvmsg_into(seq)
self.assertEqual(buf1, b'Hello World!'.ljust(100, b'\x00'))


def setUpModule():
thread_info = threading_helper.threading_setup()
unittest.addModuleCleanup(threading_helper.threading_cleanup, *thread_info)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fixed crashes in :meth:`socket.socket.sendmsg` and :meth:`socket.socket.recvmsg_into`
Comment thread
vstinner marked this conversation as resolved.
that could occur if buffer sequences are concurrently mutated.
26 changes: 14 additions & 12 deletions Modules/socketmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -4526,17 +4526,19 @@ sock_recvmsg_into(PyObject *self, PyObject *args)
struct iovec *iovs = NULL;
Py_ssize_t i, nitems, nbufs = 0;
Py_buffer *bufs = NULL;
PyObject *buffers_arg, *fast, *retval = NULL;
PyObject *buffers_arg, *buffers_tuple, *retval = NULL;

if (!PyArg_ParseTuple(args, "O|ni:recvmsg_into",
&buffers_arg, &ancbufsize, &flags))
return NULL;

if ((fast = PySequence_Fast(buffers_arg,
"recvmsg_into() argument 1 must be an "
"iterable")) == NULL)
buffers_tuple = PySequence_Tuple(buffers_arg);
if (buffers_tuple == NULL) {
PyErr_SetString(PyExc_TypeError,
"recvmsg_into() argument 1 must be an iterable");
return NULL;
nitems = PySequence_Fast_GET_SIZE(fast);
}
nitems = PyTuple_GET_SIZE(buffers_tuple);
if (nitems > INT_MAX) {
PyErr_SetString(PyExc_OSError, "recvmsg_into() argument 1 is too long");
goto finally;
Expand All @@ -4550,7 +4552,7 @@ sock_recvmsg_into(PyObject *self, PyObject *args)
goto finally;
}
for (; nbufs < nitems; nbufs++) {
if (!PyArg_Parse(PySequence_Fast_GET_ITEM(fast, nbufs),
if (!PyArg_Parse(PyTuple_GET_ITEM(buffers_tuple, nbufs),
"w*;recvmsg_into() argument 1 must be an iterable "
"of single-segment read-write buffers",
&bufs[nbufs]))
Expand All @@ -4566,7 +4568,7 @@ sock_recvmsg_into(PyObject *self, PyObject *args)
PyBuffer_Release(&bufs[i]);
PyMem_Free(bufs);
PyMem_Free(iovs);
Py_DECREF(fast);
Py_DECREF(buffers_tuple);
return retval;
}

Expand Down Expand Up @@ -4861,14 +4863,14 @@ sock_sendmsg_iovec(PySocketSockObject *s, PyObject *data_arg,

/* Fill in an iovec for each message part, and save the Py_buffer
structs to release afterwards. */
data_fast = PySequence_Fast(data_arg,
"sendmsg() argument 1 must be an "
"iterable");
data_fast = PySequence_Tuple(data_arg);
Comment thread
vstinner marked this conversation as resolved.
if (data_fast == NULL) {
PyErr_SetString(PyExc_TypeError,
"sendmsg() argument 1 must be an iterable");
goto finally;
}

ndataparts = PySequence_Fast_GET_SIZE(data_fast);
ndataparts = PyTuple_GET_SIZE(data_fast);
if (ndataparts > INT_MAX) {
PyErr_SetString(PyExc_OSError, "sendmsg() argument 1 is too long");
goto finally;
Expand All @@ -4890,7 +4892,7 @@ sock_sendmsg_iovec(PySocketSockObject *s, PyObject *data_arg,
}
}
for (; ndatabufs < ndataparts; ndatabufs++) {
if (PyObject_GetBuffer(PySequence_Fast_GET_ITEM(data_fast, ndatabufs),
if (PyObject_GetBuffer(PyTuple_GET_ITEM(data_fast, ndatabufs),
&databufs[ndatabufs], PyBUF_SIMPLE) < 0)
goto finally;
iovs[ndatabufs].iov_base = databufs[ndatabufs].buf;
Expand Down
Loading