@@ -2,7 +2,7 @@ use std::cell::RefCell;
22use std:: io;
33use std:: io:: Read ;
44use std:: io:: Write ;
5- use std:: net:: { SocketAddr , TcpListener , TcpStream } ;
5+ use std:: net:: { SocketAddr , TcpListener , TcpStream , UdpSocket } ;
66use std:: ops:: DerefMut ;
77
88use crate :: obj:: objbytes;
@@ -53,7 +53,7 @@ impl SocketKind {
5353enum Connection {
5454 TcpListener ( TcpListener ) ,
5555 TcpStream ( TcpStream ) ,
56- // UdpSocket(UdpSocket),
56+ UdpSocket ( UdpSocket ) ,
5757}
5858
5959impl Connection {
@@ -67,6 +67,7 @@ impl Connection {
6767 fn local_addr ( & self ) -> io:: Result < SocketAddr > {
6868 match self {
6969 Connection :: TcpListener ( con) => con. local_addr ( ) ,
70+ Connection :: UdpSocket ( con) => con. local_addr ( ) ,
7071 _ => Err ( io:: Error :: new ( io:: ErrorKind :: Other , "oh no!" ) ) ,
7172 }
7273 }
@@ -76,6 +77,7 @@ impl Read for Connection {
7677 fn read ( & mut self , buf : & mut [ u8 ] ) -> io:: Result < usize > {
7778 match self {
7879 Connection :: TcpStream ( con) => con. read ( buf) ,
80+ Connection :: UdpSocket ( con) => con. recv ( buf) ,
7981 _ => Err ( io:: Error :: new ( io:: ErrorKind :: Other , "oh no!" ) ) ,
8082 }
8183 }
@@ -85,6 +87,7 @@ impl Write for Connection {
8587 fn write ( & mut self , buf : & [ u8 ] ) -> io:: Result < usize > {
8688 match self {
8789 Connection :: TcpStream ( con) => con. write ( buf) ,
90+ Connection :: UdpSocket ( con) => con. send ( buf) ,
8891 _ => Err ( io:: Error :: new ( io:: ErrorKind :: Other , "oh no!" ) ) ,
8992 }
9093 }
@@ -153,12 +156,27 @@ fn socket_connect(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
153156
154157 let mut socket = get_socket ( zelf) ;
155158
156- if let Ok ( stream) = TcpStream :: connect ( address_string) {
157- socket. con = Some ( Connection :: TcpStream ( stream) ) ;
158- Ok ( vm. get_none ( ) )
159- } else {
160- // TODO: Socket error
161- Err ( vm. new_type_error ( "socket failed" . to_string ( ) ) )
159+ match socket. socket_kind {
160+ SocketKind :: Stream => {
161+ if let Ok ( stream) = TcpStream :: connect ( address_string) {
162+ socket. con = Some ( Connection :: TcpStream ( stream) ) ;
163+ Ok ( vm. get_none ( ) )
164+ } else {
165+ // TODO: Socket error
166+ Err ( vm. new_type_error ( "socket failed" . to_string ( ) ) )
167+ }
168+ }
169+ SocketKind :: Dgram => {
170+ if let Some ( Connection :: UdpSocket ( con) ) = & socket. con {
171+ match con. connect ( address_string) {
172+ Ok ( _) => Ok ( vm. get_none ( ) ) ,
173+ // TODO: Socket error
174+ Err ( _) => Err ( vm. new_type_error ( "socket failed" . to_string ( ) ) ) ,
175+ }
176+ } else {
177+ Err ( vm. new_type_error ( "" . to_string ( ) ) )
178+ }
179+ }
162180 }
163181}
164182
@@ -173,12 +191,25 @@ fn socket_bind(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
173191
174192 let mut socket = get_socket ( zelf) ;
175193
176- if let Ok ( stream) = TcpListener :: bind ( address_string) {
177- socket. con = Some ( Connection :: TcpListener ( stream) ) ;
178- Ok ( vm. get_none ( ) )
179- } else {
180- // TODO: Socket error
181- Err ( vm. new_type_error ( "socket failed" . to_string ( ) ) )
194+ match socket. socket_kind {
195+ SocketKind :: Stream => {
196+ if let Ok ( stream) = TcpListener :: bind ( address_string) {
197+ socket. con = Some ( Connection :: TcpListener ( stream) ) ;
198+ Ok ( vm. get_none ( ) )
199+ } else {
200+ // TODO: Socket error
201+ Err ( vm. new_type_error ( "socket failed" . to_string ( ) ) )
202+ }
203+ }
204+ SocketKind :: Dgram => {
205+ if let Ok ( dgram) = UdpSocket :: bind ( address_string) {
206+ socket. con = Some ( Connection :: UdpSocket ( dgram) ) ;
207+ Ok ( vm. get_none ( ) )
208+ } else {
209+ // TODO: Socket error
210+ Err ( vm. new_type_error ( "socket failed" . to_string ( ) ) )
211+ }
212+ }
182213 }
183214}
184215
@@ -325,6 +356,8 @@ pub fn mk_module(ctx: &PyContext) -> PyObjectRef {
325356 ctx. new_int ( SocketKind :: Stream as i32 ) ,
326357 ) ;
327358
359+ ctx. set_attr ( & py_mod, "SOCK_DGRAM" , ctx. new_int ( SocketKind :: Dgram as i32 ) ) ;
360+
328361 let socket = {
329362 let socket = ctx. new_class ( "socket" , ctx. object ( ) ) ;
330363 ctx. set_attr ( & socket, "__new__" , ctx. new_rustfunc ( socket_new) ) ;
0 commit comments