diff --git a/src/Npgsql/NpgsqlBinaryExporter.cs b/src/Npgsql/NpgsqlBinaryExporter.cs index c1bd27325d..96f690c66b 100644 --- a/src/Npgsql/NpgsqlBinaryExporter.cs +++ b/src/Npgsql/NpgsqlBinaryExporter.cs @@ -36,6 +36,19 @@ public sealed class NpgsqlBinaryExporter : ICancelable, IAsyncDisposable readonly NpgsqlTypeHandler?[] _typeHandlerCache; static readonly NpgsqlLogger Log = NpgsqlLogManager.CreateLogger(nameof(NpgsqlBinaryExporter)); + /// + /// Current timeout + /// + public TimeSpan Timeout + { + set + { + _buf.Timeout = value; + // While calling Complete(), we're using the connector, which overwrites the buffer's timeout with it's own + _connector.UserTimeout = (int)value.TotalMilliseconds; + } + } + #endregion #region Construction / Initialization diff --git a/src/Npgsql/NpgsqlBinaryImporter.cs b/src/Npgsql/NpgsqlBinaryImporter.cs index 6c39942a98..0a3b753938 100644 --- a/src/Npgsql/NpgsqlBinaryImporter.cs +++ b/src/Npgsql/NpgsqlBinaryImporter.cs @@ -42,6 +42,19 @@ public sealed class NpgsqlBinaryImporter : ICancelable, IAsyncDisposable static readonly NpgsqlLogger Log = NpgsqlLogManager.CreateLogger(nameof(NpgsqlBinaryImporter)); + /// + /// Current timeout + /// + public TimeSpan Timeout + { + set + { + _buf.Timeout = value; + // While calling Complete(), we're using the connector, which overwrites the buffer's timeout with it's own + _connector.UserTimeout = (int)value.TotalMilliseconds; + } + } + #endregion #region Construction / Initialization @@ -117,11 +130,20 @@ async Task StartRow(bool async, CancellationToken cancellationToken = default) if (_column != -1 && _column != NumColumns) ThrowHelper.ThrowInvalidOperationException_BinaryImportParametersMismatch(NumColumns, _column); - if (_buf.WriteSpaceLeft < 2) - await _buf.Flush(async, cancellationToken); - _buf.WriteInt16(NumColumns); + try + { + if (_buf.WriteSpaceLeft < 2) + await _buf.Flush(async, cancellationToken); + _buf.WriteInt16(NumColumns); - _column = 0; + _column = 0; + } + catch + { + // An exception here will have already broken the connection etc. + Cleanup(); + throw; + } } /// @@ -285,25 +307,34 @@ async Task Write([AllowNull] T value, NpgsqlParameter param, bool async, Canc return; } - if (typeof(T) == typeof(object)) - { - param.Value = value; - } - else + try { - if (!(param is NpgsqlParameter typedParam)) + if (typeof(T) == typeof(object)) { - _params[_column] = typedParam = new NpgsqlParameter(); - typedParam.NpgsqlDbType = param.NpgsqlDbType; + param.Value = value; } - typedParam.TypedValue = value; + else + { + if (!(param is NpgsqlParameter typedParam)) + { + _params[_column] = typedParam = new NpgsqlParameter(); + typedParam.NpgsqlDbType = param.NpgsqlDbType; + } + typedParam.TypedValue = value; + } + param.ResolveHandler(_connector.TypeMapper); + param.ValidateAndGetLength(); + param.LengthCache?.Rewind(); + await param.WriteWithLength(_buf, async, cancellationToken); + param.LengthCache?.Clear(); + _column++; + } + catch + { + // An exception here will have already broken the connection etc. + Cleanup(); + throw; } - param.ResolveHandler(_connector.TypeMapper); - param.ValidateAndGetLength(); - param.LengthCache?.Rewind(); - await param.WriteWithLength(_buf, async, cancellationToken); - param.LengthCache?.Clear(); - _column++; } /// @@ -328,11 +359,20 @@ async Task WriteNull(bool async, CancellationToken cancellationToken = default) if (_column == -1) throw new InvalidOperationException("A row hasn't been started"); - if (_buf.WriteSpaceLeft < 4) - await _buf.Flush(async, cancellationToken); + try + { + if (_buf.WriteSpaceLeft < 4) + await _buf.Flush(async, cancellationToken); - _buf.WriteInt32(-1); - _column++; + _buf.WriteInt32(-1); + _column++; + } + catch + { + // An exception here will have already broken the connection etc. + Cleanup(); + throw; + } } /// diff --git a/src/Npgsql/NpgsqlRawCopyStream.cs b/src/Npgsql/NpgsqlRawCopyStream.cs index c28dc2c320..0bb9606b07 100644 --- a/src/Npgsql/NpgsqlRawCopyStream.cs +++ b/src/Npgsql/NpgsqlRawCopyStream.cs @@ -37,6 +37,23 @@ public sealed class NpgsqlRawCopyStream : Stream, ICancelable public override bool CanWrite => _canWrite; public override bool CanRead => _canRead; + public override bool CanTimeout => true; + public override int WriteTimeout + { + get => (int) _writeBuf.Timeout.TotalMilliseconds; + set => _writeBuf.Timeout = TimeSpan.FromMilliseconds(value); + } + public override int ReadTimeout + { + get => (int) _readBuf.Timeout.TotalMilliseconds; + set + { + _readBuf.Timeout = TimeSpan.FromMilliseconds(value); + // While calling the connector it will overwrite our read buffer timeout + _connector.UserTimeout = value; + } + } + /// /// The copy binary format header signature /// diff --git a/test/Npgsql.Tests/CopyTests.cs b/test/Npgsql.Tests/CopyTests.cs index 512776a72e..cbd324c9d9 100644 --- a/test/Npgsql.Tests/CopyTests.cs +++ b/test/Npgsql.Tests/CopyTests.cs @@ -19,42 +19,34 @@ public class CopyTests : MultiplexingTestBase [Test, Description("Reproduce #2257")] public async Task Issue2257() { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: fails"); + await using var conn = await OpenConnectionAsync(); + await using var _ = await GetTempTableName(conn, out var table1); + await using var __ = await GetTempTableName(conn, out var table2); - using (var conn = OpenConnection(new NpgsqlConnectionStringBuilder(ConnectionString) { CommandTimeout = 3 })) + const int rowCount = 1000000; + using (var cmd = conn.CreateCommand()) { - await using var _ = await GetTempTableName(conn, out var table1); - await using var __ = await GetTempTableName(conn, out var table2); + cmd.CommandText = $"CREATE TABLE {table1} AS SELECT * FROM generate_series(1, {rowCount}) id"; + await cmd.ExecuteNonQueryAsync(); + cmd.CommandText = $"ALTER TABLE {table1} ADD CONSTRAINT {table1}_pk PRIMARY KEY (id)"; + await cmd.ExecuteNonQueryAsync(); + cmd.CommandText = $"CREATE TABLE {table2} (master_id integer NOT NULL REFERENCES {table1} (id))"; + await cmd.ExecuteNonQueryAsync(); + } - const int rowCount = 1000000; - using (var cmd = conn.CreateCommand()) + await using var writer = conn.BeginBinaryImport($"COPY {table2} FROM STDIN BINARY"); + writer.Timeout = TimeSpan.FromMilliseconds(3); + var e = Assert.Throws(() => + { + for (var i = 1; i <= rowCount; ++i) { - cmd.CommandText = $"CREATE TABLE {table1} AS SELECT * FROM generate_series(1, {rowCount}) id"; - // Creating table can take some time, so we set quite large timeout - cmd.CommandTimeout = 30; - await cmd.ExecuteNonQueryAsync(); - cmd.CommandText = $"ALTER TABLE {table1} ADD CONSTRAINT {table1}_pk PRIMARY KEY (id)"; - await cmd.ExecuteNonQueryAsync(); - cmd.CommandText = $"CREATE TABLE {table2} (master_id integer NOT NULL REFERENCES {table1} (id))"; - // We need to fail with timeout while calling writer.Complete() and conn.BeginBinaryImport reuses timeout from previous command - // so we set default timeout here - cmd.CommandTimeout = 3; - await cmd.ExecuteNonQueryAsync(); + writer.StartRow(); + writer.Write(i); } - using (var writer = conn.BeginBinaryImport($"COPY {table2} FROM STDIN BINARY")) - { - for (var i = 1; i <= rowCount; ++i) - { - writer.StartRow(); - writer.Write(i); - } - - var e = Assert.Throws(() => writer.Complete()); - Assert.That(e.InnerException, Is.TypeOf()); - } - } + writer.Complete(); + }); + Assert.That(e.InnerException, Is.TypeOf()); } #endregion