Skip to content

Commit e5d1d11

Browse files
committed
Add UDP to socket
1 parent 9271115 commit e5d1d11

2 files changed

Lines changed: 76 additions & 24 deletions

File tree

tests/snippets/stdlib_socket.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import socket
22
from testutils import assertRaises
33

4+
MESSAGE_A = b'aaaa'
5+
MESSAGE_B= b'bbbbb'
6+
7+
# TCP
48

59
listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
610
listener.bind(("127.0.0.1", 0))
@@ -10,16 +14,12 @@
1014
connector.connect(("127.0.0.1", listener.getsockname()[1]))
1115
connection = listener.accept()[0]
1216

13-
message_a = b'aaaa'
14-
message_b = b'bbbbb'
15-
16-
connector.send(message_a)
17-
connection.send(message_b)
18-
recv_a = connection.recv(len(message_a))
19-
recv_b = connector.recv(len(message_b))
20-
assert recv_a == message_a
21-
assert recv_b == message_b
22-
17+
connector.send(MESSAGE_A)
18+
connection.send(MESSAGE_B)
19+
recv_a = connection.recv(len(MESSAGE_A))
20+
recv_b = connector.recv(len(MESSAGE_B))
21+
assert recv_a == MESSAGE_A
22+
assert recv_b == MESSAGE_B
2323
connection.close()
2424
connector.close()
2525
listener.close()
@@ -35,3 +35,22 @@
3535
s.bind((888, 8888))
3636

3737
s.close()
38+
39+
# UDP
40+
sock1 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
41+
sock1.bind(("127.0.0.1", 0))
42+
43+
sock2 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
44+
sock2.bind(("127.0.0.1", 0))
45+
46+
sock1.connect(("127.0.0.1", sock2.getsockname()[1]))
47+
sock2.connect(("127.0.0.1", sock1.getsockname()[1]))
48+
49+
sock1.send(MESSAGE_A)
50+
sock2.send(MESSAGE_B)
51+
recv_a = sock2.recv(len(MESSAGE_A))
52+
recv_b = sock1.recv(len(MESSAGE_B))
53+
assert recv_a == MESSAGE_A
54+
assert recv_b == MESSAGE_B
55+
sock1.close()
56+
sock2.close()

vm/src/stdlib/socket.rs

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::cell::RefCell;
22
use std::io;
33
use std::io::Read;
44
use std::io::Write;
5-
use std::net::{SocketAddr, TcpListener, TcpStream};
5+
use std::net::{SocketAddr, TcpListener, TcpStream, UdpSocket};
66
use std::ops::DerefMut;
77

88
use crate::obj::objbytes;
@@ -53,7 +53,7 @@ impl SocketKind {
5353
enum Connection {
5454
TcpListener(TcpListener),
5555
TcpStream(TcpStream),
56-
// UdpSocket(UdpSocket),
56+
UdpSocket(UdpSocket),
5757
}
5858

5959
impl Connection {
@@ -67,6 +67,7 @@ impl Connection {
6767
fn local_addr(&self) -> io::Result<SocketAddr> {
6868
match self {
6969
Connection::TcpListener(con) => con.local_addr(),
70+
Connection::UdpSocket(con) => con.local_addr(),
7071
_ => Err(io::Error::new(io::ErrorKind::Other, "oh no!")),
7172
}
7273
}
@@ -76,6 +77,7 @@ impl Read for Connection {
7677
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
7778
match self {
7879
Connection::TcpStream(con) => con.read(buf),
80+
Connection::UdpSocket(con) => con.recv(buf),
7981
_ => Err(io::Error::new(io::ErrorKind::Other, "oh no!")),
8082
}
8183
}
@@ -85,6 +87,7 @@ impl Write for Connection {
8587
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
8688
match self {
8789
Connection::TcpStream(con) => con.write(buf),
90+
Connection::UdpSocket(con) => con.send(buf),
8891
_ => Err(io::Error::new(io::ErrorKind::Other, "oh no!")),
8992
}
9093
}
@@ -153,12 +156,27 @@ fn socket_connect(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
153156

154157
let mut socket = get_socket(zelf);
155158

156-
if let Ok(stream) = TcpStream::connect(address_string) {
157-
socket.con = Some(Connection::TcpStream(stream));
158-
Ok(vm.get_none())
159-
} else {
160-
// TODO: Socket error
161-
Err(vm.new_type_error("socket failed".to_string()))
159+
match socket.socket_kind {
160+
SocketKind::Stream => {
161+
if let Ok(stream) = TcpStream::connect(address_string) {
162+
socket.con = Some(Connection::TcpStream(stream));
163+
Ok(vm.get_none())
164+
} else {
165+
// TODO: Socket error
166+
Err(vm.new_type_error("socket failed".to_string()))
167+
}
168+
}
169+
SocketKind::Dgram => {
170+
if let Some(Connection::UdpSocket(con)) = &socket.con {
171+
match con.connect(address_string) {
172+
Ok(_) => Ok(vm.get_none()),
173+
// TODO: Socket error
174+
Err(_) => Err(vm.new_type_error("socket failed".to_string())),
175+
}
176+
} else {
177+
Err(vm.new_type_error("".to_string()))
178+
}
179+
}
162180
}
163181
}
164182

@@ -173,12 +191,25 @@ fn socket_bind(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
173191

174192
let mut socket = get_socket(zelf);
175193

176-
if let Ok(stream) = TcpListener::bind(address_string) {
177-
socket.con = Some(Connection::TcpListener(stream));
178-
Ok(vm.get_none())
179-
} else {
180-
// TODO: Socket error
181-
Err(vm.new_type_error("socket failed".to_string()))
194+
match socket.socket_kind {
195+
SocketKind::Stream => {
196+
if let Ok(stream) = TcpListener::bind(address_string) {
197+
socket.con = Some(Connection::TcpListener(stream));
198+
Ok(vm.get_none())
199+
} else {
200+
// TODO: Socket error
201+
Err(vm.new_type_error("socket failed".to_string()))
202+
}
203+
}
204+
SocketKind::Dgram => {
205+
if let Ok(dgram) = UdpSocket::bind(address_string) {
206+
socket.con = Some(Connection::UdpSocket(dgram));
207+
Ok(vm.get_none())
208+
} else {
209+
// TODO: Socket error
210+
Err(vm.new_type_error("socket failed".to_string()))
211+
}
212+
}
182213
}
183214
}
184215

@@ -325,6 +356,8 @@ pub fn mk_module(ctx: &PyContext) -> PyObjectRef {
325356
ctx.new_int(SocketKind::Stream as i32),
326357
);
327358

359+
ctx.set_attr(&py_mod, "SOCK_DGRAM", ctx.new_int(SocketKind::Dgram as i32));
360+
328361
let socket = {
329362
let socket = ctx.new_class("socket", ctx.object());
330363
ctx.set_attr(&socket, "__new__", ctx.new_rustfunc(socket_new));

0 commit comments

Comments
 (0)