diff --git a/src/Npgsql/NpgsqlClosedState.cs b/src/Npgsql/NpgsqlClosedState.cs index 19d15358c0..9718b22632 100644 --- a/src/Npgsql/NpgsqlClosedState.cs +++ b/src/Npgsql/NpgsqlClosedState.cs @@ -43,8 +43,12 @@ internal class NpgsqlNetworkStream : NetworkStream { NpgsqlConnector mContext = null; - public NpgsqlNetworkStream(NpgsqlConnector context, Socket socket, Boolean owner) + public NpgsqlNetworkStream(Socket socket, Boolean owner) : base(socket, owner) + { + } + + public void AttachConnector(NpgsqlConnector context) { mContext = context; } @@ -53,8 +57,11 @@ protected override void Dispose(bool disposing) { if (!disposing) { - mContext.Close(); - mContext = null; + if (mContext != null) + { + mContext.Close(); + mContext = null; + } } base.Dispose(disposing); @@ -88,35 +95,6 @@ public override void Open(NpgsqlConnector context, Int32 timeout) { NpgsqlEventLog.LogMethodEnter(LogLevel.Debug, CLASSNAME, "Open"); - /*TcpClient tcpc = new TcpClient(); - tcpc.Connect(new IPEndPoint(ResolveIPHost(context.Host), context.Port)); - Stream stream = tcpc.GetStream();*/ - - /*socket.SetSocketOption (SocketOptionLevel.Socket, SocketOptionName.SendTimeout, context.ConnectionTimeout*1000);*/ - - //socket.Connect(new IPEndPoint(ResolveIPHost(context.Host), context.Port)); - - /*Socket socket = new Socket(AddressFamily.InterNetwork,SocketType.Stream,ProtocolType.Tcp); - - IAsyncResult result = socket.BeginConnect(new IPEndPoint(ResolveIPHost(context.Host), context.Port), null, null); - - if (!result.AsyncWaitHandle.WaitOne(context.ConnectionTimeout*1000, false)) - { - socket.Close(); - throw new Exception(resman.GetString("Exception_ConnectionTimeout")); - } - - try - { - socket.EndConnect(result); - } - catch (Exception) - { - socket.Close(); - throw; - } - */ - IAsyncResult result; // Keep track of time remaining; Even though there may be multiple timeout-able calls, // this allows us to still respect the caller's timeout expectation. @@ -181,51 +159,55 @@ public override void Open(NpgsqlConnector context, Int32 timeout) throw lastSocketException; } - //Stream stream = new NetworkStream(socket, true); - Stream stream = new NpgsqlNetworkStream(context, socket, true); + NpgsqlNetworkStream baseStream = new NpgsqlNetworkStream(socket, true); + Stream sslStream = null; // If the PostgreSQL server has SSL connectors enabled Open SslClientStream if (response == 'S') { if (context.SSL || (context.SslMode == SslMode.Require) || (context.SslMode == SslMode.Prefer)) { - stream + baseStream .WriteInt32(8) .WriteInt32(80877103); + // Receive response + Char response = (Char) baseStream.ReadByte(); - Char response = (Char) stream.ReadByte(); if (response == 'S') { - //create empty collection - X509CertificateCollection clientCertificates = new X509CertificateCollection(); - - //trigger the callback to fetch some certificates - context.DefaultProvideClientCertificatesCallback(clientCertificates); - - if (context.UseMonoSsl) - { - stream = new SslClientStream( - stream, - context.Host, - true, - SecurityProtocolType.Default, - clientCertificates); - - ((SslClientStream)stream).ClientCertSelectionDelegate = - new CertificateSelectionCallback(context.DefaultCertificateSelectionCallback); - ((SslClientStream)stream).ServerCertValidationDelegate = - new CertificateValidationCallback(context.DefaultCertificateValidationCallback); - ((SslClientStream)stream).PrivateKeyCertSelectionDelegate = - new PrivateKeySelectionCallback(context.DefaultPrivateKeySelectionCallback); - } - else - { - SslStream sstream = new SslStream(stream, true, delegate(object sender, X509Certificate cert, X509Chain chain, SslPolicyErrors errors) - { - return context.DefaultValidateRemoteCertificateCallback(cert, chain, errors); - }); - sstream.AuthenticateAsClient(context.Host, clientCertificates, System.Security.Authentication.SslProtocols.Default, false); - stream = sstream; - } + //create empty collection + X509CertificateCollection clientCertificates = new X509CertificateCollection(); + + //trigger the callback to fetch some certificates + context.DefaultProvideClientCertificatesCallback(clientCertificates); + + if (context.UseMonoSsl) + { + SslClientStream sslStreamPriv; + + sslStreamPriv = new SslClientStream( + baseStream, + context.Host, + true, + SecurityProtocolType.Default, + clientCertificates); + + sslStreamPriv.ClientCertSelectionDelegate = + new CertificateSelectionCallback(context.DefaultCertificateSelectionCallback); + sslStreamPriv.ServerCertValidationDelegate = + new CertificateValidationCallback(context.DefaultCertificateValidationCallback); + sslStreamPriv.PrivateKeyCertSelectionDelegate = + new PrivateKeySelectionCallback(context.DefaultPrivateKeySelectionCallback); + sslStream = sslStreamPriv; + } + else + { + SslStream sslStreamPriv; + + sslStreamPriv = new SslStream(baseStream, true, context.DefaultValidateRemoteCertificateCallback); + + sslStreamPriv.AuthenticateAsClient(context.Host, clientCertificates, System.Security.Authentication.SslProtocols.Default, false); + sslStream = sslStreamPriv; + } } else if (context.SslMode == SslMode.Require) { @@ -233,8 +215,9 @@ public override void Open(NpgsqlConnector context, Int32 timeout) } } - context.Stream = new BufferedStream(stream); context.Socket = socket; + context.BaseStream = baseStream; + context.Stream = new BufferedStream(sslStream == null ? baseStream : sslStream); NpgsqlEventLog.LogMsg(resman, "Log_ConnectedTo", LogLevel.Normal, context.Host, context.Port); ChangeState(context, NpgsqlConnectedState.Instance); diff --git a/src/Npgsql/NpgsqlConnector.cs b/src/Npgsql/NpgsqlConnector.cs index 264085dd22..13ab76bd5e 100644 --- a/src/Npgsql/NpgsqlConnector.cs +++ b/src/Npgsql/NpgsqlConnector.cs @@ -103,10 +103,15 @@ internal class NpgsqlConnector private ConnectionState _connection_state; - // The physical network connection to the backend. - private Stream _stream; - + // The physical network connection socket and stream to the backend. private Socket _socket; + private NpgsqlNetworkStream _baseStream; + + // The top level stream to the backend. + // This is a BufferedStream. + // With SSL, this stream sits on top of the SSL stream, which sits on top of _baseStream. + // Otherwise, this stream sits directly on top of _baseStream. + private BufferedStream _stream; // Mediator which will hold data generated from backend. private readonly NpgsqlMediator _mediator; @@ -196,6 +201,7 @@ public NpgsqlConnector(NpgsqlConnectionStringBuilder ConnectionString, bool Pool _notificationThreadStopCount = 1; } + //Finalizer should never be used, but if some incident has left to a connector being abandoned (most likely //case being a user not cleaning up a connection properly) then this way we can at least reduce the damage. @@ -534,7 +540,7 @@ internal void DefaultProvideClientCertificatesCallback(X509CertificateCollection /// /// Default SSL ValidateRemoteCertificateCallback implementation. /// - internal bool DefaultValidateRemoteCertificateCallback(X509Certificate cert, X509Chain chain, SslPolicyErrors errors) + internal bool DefaultValidateRemoteCertificateCallback(object sender, X509Certificate cert, X509Chain chain, SslPolicyErrors errors) { if (ValidateRemoteCertificateCallback != null) { @@ -564,22 +570,31 @@ internal ProtocolVersion BackendProtocolVersion set { _backendProtocolVersion = value; } } + /// + /// The physical connection socket to the backend. + /// + internal Socket Socket + { + get { return _socket; } + set { _socket = value; } + } + /// /// The physical connection stream to the backend. /// - internal Stream Stream + internal NpgsqlNetworkStream BaseStream { - get { return _stream; } - set { _stream = value; } + get { return _baseStream; } + set { _baseStream = value; } } /// - /// The physical connection socket to the backend. + /// The top level stream to the backend. /// - internal Socket Socket + internal BufferedStream Stream { - get { return _socket; } - set { _socket = value; } + get { return _stream; } + set { _stream = value; } } /// @@ -736,6 +751,17 @@ internal void Open() } catch (NpgsqlException ne) { + if (_stream != null) + { + try + { + _stream.Dispose(); + } + catch + { + } + } + connectTimeRemaining -= Convert.ToInt32((DateTime.Now - attemptStart).TotalMilliseconds); // Check for protocol not supported. If we have been told what protocol to use, @@ -773,6 +799,9 @@ internal void Open() _connection_state = ConnectionState.Open; CurrentState = NpgsqlReadyState.Instance; + // After attachment, the stream will close the connector (this) when the stream gets disposed. + _baseStream.AttachConnector(this); + // Fall back to the old way, SELECT VERSION(). // This should not happen for protocol version 3+. if (ServerVersion == null)