Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
SSLSession
  • Loading branch information
youknowone committed Oct 23, 2025
commit 0609a975b80bc204de4173145e36e2e26bd2accb
4 changes: 2 additions & 2 deletions Lib/ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@
import _ssl # if we can't import it, let the error propagate

from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION
from _ssl import _SSLContext#, MemoryBIO, SSLSession
from _ssl import _SSLContext, SSLSession #, MemoryBIO
from _ssl import (
SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError,
SSLSyscallError, SSLEOFError, SSLCertVerificationError
Expand All @@ -114,7 +114,7 @@

from _ssl import (
HAS_SNI, HAS_ECDH, HAS_NPN, HAS_ALPN, HAS_SSLv2, HAS_SSLv3, HAS_TLSv1,
HAS_TLSv1_1, HAS_TLSv1_2, HAS_TLSv1_3
HAS_TLSv1_1, HAS_TLSv1_2, HAS_TLSv1_3, HAS_PSK
)
from _ssl import _DEFAULT_CIPHERS, _OPENSSL_API_VERSION

Expand Down
218 changes: 210 additions & 8 deletions stdlib/src/ssl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,16 @@ mod _ssl {
},
socket::{self, PySocket},
vm::{
PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
builtins::{PyBaseExceptionRef, PyStrRef, PyType, PyTypeRef, PyWeak},
class_or_notimplemented,
convert::{ToPyException, ToPyObject},
exceptions,
function::{
ArgBytesLike, ArgCallable, ArgMemoryBuffer, ArgStrOrBytesLike, Either, FsPath,
OptionalArg,
OptionalArg, PyComparisonValue,
},
types::Constructor,
types::{Comparable, Constructor, PyComparisonOp},
utils::ToCString,
},
};
Expand Down Expand Up @@ -162,6 +163,8 @@ mod _ssl {
const HAS_TLSv1_2: bool = true;
#[pyattr]
const HAS_TLSv1_3: bool = cfg!(ossl111);
#[pyattr]
const HAS_PSK: bool = true;

// the openssl version from the API headers

Expand Down Expand Up @@ -816,16 +819,22 @@ mod _ssl {
let stream = ssl::SslStream::new(ssl, SocketStream(args.sock.clone()))
.map_err(|e| convert_openssl_error(vm, e))?;

// TODO: use this
let _ = args.session;

Ok(PySslSocket {
let py_ssl_socket = PySslSocket {
ctx: zelf,
stream: PyRwLock::new(stream),
socket_type,
server_hostname: args.server_hostname,
owner: PyRwLock::new(args.owner.map(|o| o.downgrade(None, vm)).transpose()?),
})
};

// Set session if provided
if let Some(session) = args.session
&& !vm.is_none(&session)
{
py_ssl_socket.set_session(session, vm)?;
}

Ok(py_ssl_socket)
}
}

Expand Down Expand Up @@ -1103,6 +1112,73 @@ mod _ssl {
}
}

#[pygetset]
fn session(&self, _vm: &VirtualMachine) -> PyResult<Option<PySslSession>> {
let stream = self.stream.read();
unsafe {
let session_ptr = sys::SSL_get_session(stream.ssl().as_ptr());
if session_ptr.is_null() {
Ok(None)
} else {
// Increment reference count since SSL_get_session returns a borrowed reference
#[cfg(ossl110)]
let _session = sys::SSL_SESSION_up_ref(session_ptr);

Ok(Some(PySslSession {
session: session_ptr,
ctx: self.ctx.clone(),
}))
}
}
}

#[pygetset(setter)]
fn set_session(&self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
// Check if value is SSLSession type
let session = value
.downcast_ref::<PySslSession>()
.ok_or_else(|| vm.new_type_error("Value is not a SSLSession.".to_owned()))?;

// Check if session refers to the same SSLContext
if !std::ptr::eq(
self.ctx.ctx.read().as_ptr(),
session.ctx.ctx.read().as_ptr(),
) {
return Err(
vm.new_value_error("Session refers to a different SSLContext.".to_owned())
);
}

// Check if this is a client socket
if self.socket_type != SslServerOrClient::Client {
return Err(
vm.new_value_error("Cannot set session for server-side SSLSocket.".to_owned())
);
}

// Check if handshake is not finished
let stream = self.stream.read();
unsafe {
if sys::SSL_is_init_finished(stream.ssl().as_ptr()) != 0 {
return Err(
vm.new_value_error("Cannot set session after handshake.".to_owned())
);
}

if sys::SSL_set_session(stream.ssl().as_ptr(), session.session) == 0 {
return Err(convert_openssl_error(vm, ErrorStack::get()));
}
}

Ok(())
}

#[pygetset]
fn session_reused(&self) -> bool {
let stream = self.stream.read();
unsafe { sys::SSL_session_reused(stream.ssl().as_ptr()) != 0 }
}

#[pymethod]
fn read(
&self,
Expand Down Expand Up @@ -1164,6 +1240,132 @@ mod _ssl {
}
}

#[pyattr]
#[pyclass(module = "ssl", name = "SSLSession")]
#[derive(PyPayload)]
struct PySslSession {
session: *mut sys::SSL_SESSION,
ctx: PyRef<PySslContext>,
}

impl fmt::Debug for PySslSession {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.pad("SSLSession")
}
}

impl Drop for PySslSession {
fn drop(&mut self) {
if !self.session.is_null() {
unsafe {
sys::SSL_SESSION_free(self.session);
}
}
}
}

unsafe impl Send for PySslSession {}
unsafe impl Sync for PySslSession {}

impl Comparable for PySslSession {
fn cmp(
zelf: &Py<Self>,
other: &crate::vm::PyObject,
op: PyComparisonOp,
_vm: &VirtualMachine,
) -> PyResult<PyComparisonValue> {
let other = class_or_notimplemented!(Self, other);

if !matches!(op, PyComparisonOp::Eq | PyComparisonOp::Ne) {
return Ok(PyComparisonValue::NotImplemented);
}
let mut eq = unsafe {
let mut self_len: libc::c_uint = 0;
let mut other_len: libc::c_uint = 0;
let self_id = sys::SSL_SESSION_get_id(zelf.session, &mut self_len);
let other_id = sys::SSL_SESSION_get_id(other.session, &mut other_len);

if self_len != other_len {
false
} else {
let self_slice = std::slice::from_raw_parts(self_id, self_len as usize);
let other_slice = std::slice::from_raw_parts(other_id, other_len as usize);
self_slice == other_slice
}
};
if matches!(op, PyComparisonOp::Ne) {
eq = !eq;
}
Ok(PyComparisonValue::Implemented(eq))
}
}

#[pyclass(with(Comparable))]
impl PySslSession {
#[pygetset]
fn time(&self) -> i64 {
unsafe {
#[cfg(ossl330)]
{
sys::SSL_SESSION_get_time(self.session) as i64
}
#[cfg(not(ossl330))]
{
sys::SSL_SESSION_get_time(self.session) as i64
}
}
}

#[pygetset]
fn timeout(&self) -> i64 {
unsafe { sys::SSL_SESSION_get_timeout(self.session) as i64 }
}

#[pygetset]
fn ticket_lifetime_hint(&self) -> u64 {
// SSL_SESSION_get_ticket_lifetime_hint may not be available in older OpenSSL
// Return 0 as default if not available
#[cfg(ossl110)]
{
// For now, return 0 as this function may not be in openssl-sys
let _ = self.session;
0
}
#[cfg(not(ossl110))]
{
let _ = self.session;
0
}
}

#[pygetset]
fn id(&self, vm: &VirtualMachine) -> PyObjectRef {
unsafe {
let mut len: libc::c_uint = 0;
let id_ptr = sys::SSL_SESSION_get_id(self.session, &mut len);
let id_slice = std::slice::from_raw_parts(id_ptr, len as usize);
vm.ctx.new_bytes(id_slice.to_vec()).into()
}
}

#[pygetset]
fn has_ticket(&self) -> bool {
// SSL_SESSION_has_ticket may not be available in older OpenSSL
// Return false as default
#[cfg(ossl110)]
{
// For now, return false as this function may not be in openssl-sys
let _ = self.session;
false
}
#[cfg(not(ossl110))]
{
let _ = self.session;
false
}
}
}

#[track_caller]
fn convert_openssl_error(vm: &VirtualMachine, err: ErrorStack) -> PyBaseExceptionRef {
let cls = ssl_error(vm);
Expand Down