Skip to content

Commit 716eb4a

Browse files
committed
Add validation to connect and bind address
1 parent dff6b0b commit 716eb4a

File tree

2 files changed

+38
-10
lines changed

2 files changed

+38
-10
lines changed

tests/snippets/stdlib_socket.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import socket
2+
from testutils import assertRaises
23

34
listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
45
listener.bind(("127.0.0.1", 8080))
@@ -22,3 +23,14 @@
2223
connector.close()
2324
listener.close()
2425

26+
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
27+
with assertRaises(TypeError):
28+
s.connect(("127.0.0.1", 8888, 8888))
29+
30+
with assertRaises(TypeError):
31+
s.bind(("127.0.0.1", 8888, 8888))
32+
33+
with assertRaises(TypeError):
34+
s.bind((888, 8888))
35+
36+
s.close()

vm/src/stdlib/socket.rs

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,7 @@ fn socket_connect(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
130130
required = [(zelf, None), (address, Some(vm.ctx.tuple_type()))]
131131
);
132132

133-
let elements = get_elements(address);
134-
let host = objstr::get_value(&elements[0]);
135-
let port = objint::get_value(&elements[1]);
136-
137-
let address_string = format!("{}:{}", host, port.to_string());
133+
let address_string = get_address_string(vm, address)?;
138134

139135
match zelf.payload {
140136
PyObjectPayload::Socket { ref socket } => {
@@ -157,11 +153,7 @@ fn socket_bind(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
157153
required = [(zelf, None), (address, Some(vm.ctx.tuple_type()))]
158154
);
159155

160-
let elements = get_elements(address);
161-
let host = objstr::get_value(&elements[0]);
162-
let port = objint::get_value(&elements[1]);
163-
164-
let address_string = format!("{}:{}", host, port.to_string());
156+
let address_string = get_address_string(vm, address)?;
165157

166158
match zelf.payload {
167159
PyObjectPayload::Socket { ref socket } => {
@@ -177,6 +169,30 @@ fn socket_bind(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
177169
}
178170
}
179171

172+
fn get_address_string(
173+
vm: &mut VirtualMachine,
174+
address: &PyObjectRef,
175+
) -> Result<String, PyObjectRef> {
176+
let args = PyFuncArgs {
177+
args: get_elements(address).to_vec(),
178+
kwargs: vec![],
179+
};
180+
arg_check!(
181+
vm,
182+
args,
183+
required = [
184+
(host, Some(vm.ctx.str_type())),
185+
(port, Some(vm.ctx.int_type()))
186+
]
187+
);
188+
189+
Ok(format!(
190+
"{}:{}",
191+
objstr::get_value(host),
192+
objint::get_value(port).to_string()
193+
))
194+
}
195+
180196
fn socket_listen(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
181197
Ok(vm.get_none())
182198
}

0 commit comments

Comments
 (0)