Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prepend and quote savepoints
Closes #2892
Closes #3262
  • Loading branch information
roji committed Nov 1, 2020
commit 6d010535612f0e4917b5a2765bdb8390a81c5af4
14 changes: 6 additions & 8 deletions src/Npgsql/NpgsqlConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -575,14 +575,12 @@ async ValueTask<NpgsqlTransaction> BeginTransaction(IsolationLevel level, bool a

try
{
// Note that beginning a transaction doesn't actually send anything to the backend
// (only prepends), so strictly speaking we don't have to start a user action.
// However, we do this for consistency as if we did (for the checks and exceptions)
using (connector.StartUserAction())
{
connector.Transaction.Init(level);
return connector.Transaction;
}
// Note that beginning a transaction doesn't actually send anything to the backend (only prepends), so strictly speaking we
// don't have to start a user action. However, we do this for consistency as if we did (for the checks and exceptions)
using var _ = connector.StartUserAction();

connector.Transaction.Init(level);
return connector.Transaction;
}
catch
{
Expand Down
14 changes: 7 additions & 7 deletions src/Npgsql/NpgsqlConnector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ sealed partial class NpgsqlConnector : IDisposable
/// The number of messages that were prepended to the current message chain, but not yet sent.
/// Note that this only tracks messages which produce a ReadyForQuery message
/// </summary>
int _pendingPrependedResponses;
internal int PendingPrependedResponses { get; set; }

internal NpgsqlDataReader? CurrentReader;

Expand Down Expand Up @@ -1022,7 +1022,7 @@ async Task MultiplexingReadLoop()
/// </summary>
internal void PrependInternalMessage(byte[] rawMessage, int responseMessageCount)
{
_pendingPrependedResponses += responseMessageCount;
PendingPrependedResponses += responseMessageCount;

var t = WritePregenerated(rawMessage);
Debug.Assert(t.IsCompleted, "Could not fully write pregenerated message into the buffer");
Expand Down Expand Up @@ -1061,7 +1061,7 @@ internal ValueTask<IBackendMessage> ReadMessage(bool async, DataRowLoadingMode d
bool attemptPostgresCancellation = true,
CancellationToken cancellationToken = default)
{
if (_pendingPrependedResponses > 0 ||
if (PendingPrependedResponses > 0 ||
dataRowLoadingMode != DataRowLoadingMode.NonSequential ||
readingNotifications ||
ReadBuffer.ReadBytesLeft < 5)
Expand Down Expand Up @@ -1103,13 +1103,13 @@ internal ValueTask<IBackendMessage> ReadMessage(bool async, DataRowLoadingMode d
CancellationToken cancellationToken2 = default)
{
// First read the responses of any prepended messages.
if (_pendingPrependedResponses > 0 && !isReadingPrependedMessage)
if (PendingPrependedResponses > 0 && !isReadingPrependedMessage)
{
try
{
// TODO: There could be room for optimization here, rather than the async call(s)
ReadBuffer.Timeout = TimeSpan.FromMilliseconds(InternalCommandTimeout);
for (; _pendingPrependedResponses > 0; _pendingPrependedResponses--)
for (; PendingPrependedResponses > 0; PendingPrependedResponses--)
await ReadMessageLong(DataRowLoadingMode.Skip, readingNotifications2: false, isReadingPrependedMessage: true,
attemptPostgresCancellation2: false, cancellationToken2);
}
Expand Down Expand Up @@ -1795,7 +1795,7 @@ internal async Task Reset(bool async, CancellationToken cancellationToken = defa

// Our buffer may contain unsent prepended messages (such as BeginTransaction), clear it out completely
WriteBuffer.Clear();
_pendingPrependedResponses = 0;
PendingPrependedResponses = 0;

// We may have allocated an oversize read buffer, switch back to the original one
// TODO: Replace this with array pooling, #2326
Expand Down Expand Up @@ -2139,7 +2139,7 @@ internal async Task ExecuteInternalCommand(byte[] data, bool async, Cancellation
{
Debug.Assert(State != ConnectorState.Ready, "Forgot to start a user action...");

Log.Trace($"Executing internal pregenerated command", Id);
Log.Trace("Executing internal pregenerated command", Id);

await WritePregenerated(data, async, cancellationToken);
await Flush(async, cancellationToken);
Expand Down
91 changes: 67 additions & 24 deletions src/Npgsql/NpgsqlTransaction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.Data;
using System.Data.Common;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Npgsql.Logging;
Expand Down Expand Up @@ -184,40 +185,63 @@ public override Task RollbackAsync(CancellationToken cancellationToken = default

#region Savepoints

async Task Save(string name, bool async, CancellationToken cancellationToken = default)
/// <summary>
/// Creates a transaction save point.
/// </summary>
/// <param name="name">The name of the savepoint.</param>
/// <remarks>
/// This method does not cause a database roundtrip to be made. The savepoint creation statement will instead be sent along with
/// the next command.
/// </remarks>
#if NET
public override void Save(string name)
#else
public void Save(string name)
#endif
{
if (name == null)
throw new ArgumentNullException(nameof(name));
if (string.IsNullOrWhiteSpace(name))
throw new ArgumentException("name can't be empty", nameof(name));
if (name.Contains(";"))
throw new ArgumentException("name can't contain a semicolon");

CheckReady();
if (!_connector.DatabaseInfo.SupportsTransactions)
return;
using (_connector.StartUserAction())
{
Log.Debug($"Creating savepoint {name}", _connector.Id);
await _connector.ExecuteInternalCommand($"SAVEPOINT {name}", async, cancellationToken);
}
}

/// <summary>
/// Creates a transaction save point.
/// </summary>
/// <param name="name">The name of the savepoint.</param>
#if NET
public override void Save(string name) => Save(name, false).GetAwaiter().GetResult();
#else
public void Save(string name) => Save(name, false).GetAwaiter().GetResult();
#endif
// Note that creating a savepoint doesn't actually send anything to the backend (only prepends), so strictly speaking we don't
// have to start a user action. However, we do this for consistency as if we did (for the checks and exceptions)
using var _ = _connector.StartUserAction();

Log.Debug($"Creating savepoint {name}", _connector.Id);

if (RequiresQuoting(name))
name = $"\"{name.Replace("\"", "\"\"")}\"";

// Note: savepoint names are PostgreSQL identifiers, and so limited by default to 63 characters.
// Since we are prepending, we assume below that the statement will always fit in the buffer.
_connector.WriteBuffer.WriteByte(FrontendMessageCode.Query);
_connector.WriteBuffer.WriteInt32(
sizeof(int) + // Message length (including self excluding code)
_connector.TextEncoding.GetByteCount("SAVEPOINT ") +
_connector.TextEncoding.GetByteCount(name) +
sizeof(byte)); // Null terminator

_connector.WriteBuffer.WriteString("SAVEPOINT ");
_connector.WriteBuffer.WriteString(name);
_connector.WriteBuffer.WriteByte(0);

_connector.PendingPrependedResponses += 2;
}

/// <summary>
/// Creates a transaction save point.
/// </summary>
/// <param name="name">The name of the savepoint.</param>
/// <param name="cancellationToken">The token to monitor for cancellation requests. The default value is <see cref="CancellationToken.None"/>.</param>
/// <remarks>
/// This method does not cause a database roundtrip to be made, and will therefore always complete synchronously.
/// The savepoint creation statement will instead be sent along with the next command.
/// </remarks>
#if NET
public override Task SaveAsync(string name, CancellationToken cancellationToken = default)
#else
Expand All @@ -226,8 +250,8 @@ public Task SaveAsync(string name, CancellationToken cancellationToken = default
{
if (cancellationToken.IsCancellationRequested)
return Task.FromCanceled(cancellationToken);
using (NoSynchronizationContextScope.Enter())
return Save(name, true, cancellationToken);
Save(name);
return Task.CompletedTask;
}

async Task Rollback(string name, bool async, CancellationToken cancellationToken = default)
Expand All @@ -236,15 +260,17 @@ async Task Rollback(string name, bool async, CancellationToken cancellationToken
throw new ArgumentNullException(nameof(name));
if (string.IsNullOrWhiteSpace(name))
throw new ArgumentException("name can't be empty", nameof(name));
if (name.Contains(";"))
throw new ArgumentException("name can't contain a semicolon");

CheckReady();
if (!_connector.DatabaseInfo.SupportsTransactions)
return;
using (_connector.StartUserAction())
{
Log.Debug($"Rolling back savepoint {name}", _connector.Id);

if (RequiresQuoting(name))
name = $"\"{name.Replace("\"", "\"\"")}\"";

await _connector.ExecuteInternalCommand($"ROLLBACK TO SAVEPOINT {name}", async, cancellationToken);
}
}
Expand Down Expand Up @@ -283,15 +309,17 @@ async Task Release(string name, bool async, CancellationToken cancellationToken
throw new ArgumentNullException(nameof(name));
if (string.IsNullOrWhiteSpace(name))
throw new ArgumentException("name can't be empty", nameof(name));
if (name.Contains(";"))
throw new ArgumentException("name can't contain a semicolon");

CheckReady();
if (!_connector.DatabaseInfo.SupportsTransactions)
return;
using (_connector.StartUserAction())
{
Log.Debug($"Releasing savepoint {name}", _connector.Id);

if (RequiresQuoting(name))
name = $"\"{name.Replace("\"", "\"\"")}\"";

await _connector.ExecuteInternalCommand($"RELEASE SAVEPOINT {name}", async, cancellationToken);
}
}
Expand Down Expand Up @@ -401,6 +429,21 @@ void CheckReady()
throw new InvalidOperationException("This NpgsqlTransaction has completed; it is no longer usable.");
}

static bool RequiresQuoting(string identifier)
{
Debug.Assert(identifier.Length > 0);

var first = identifier[0];
if (first != '_' && !char.IsLower(first))
return true;

foreach (var c in identifier.AsSpan(1))
if (c != '_' && c != '$' && !char.IsLower(c) && !char.IsDigit(c))
return true;

return false;
}

#endregion
}
}
4 changes: 2 additions & 2 deletions test/Npgsql.Tests/ReaderTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1582,7 +1582,7 @@ await pgMock
Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open));
}

await pgMock.WriteScalarResponse(1);
await pgMock.WriteScalarResponseAndFlush(1);
Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1));
}

Expand Down Expand Up @@ -1634,7 +1634,7 @@ await pgMock
Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open));
}

await pgMock.WriteScalarResponse(1);
await pgMock.WriteScalarResponseAndFlush(1);
Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1));
}

Expand Down
30 changes: 28 additions & 2 deletions test/Npgsql.Tests/Support/PgServerMock.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ internal async Task SkipMessage()
_readBuffer.Skip(len - 4);
}

internal async Task ReadMessageType(byte expectedCode)
internal async Task ExpectMessage(byte expectedCode)
{
CheckDisposed();

Expand All @@ -84,13 +84,39 @@ internal async Task ReadMessageType(byte expectedCode)
_readBuffer.Skip(len - 4);
}

internal Task ExpectExtendedQuery()
=> ExpectMessages(
FrontendMessageCode.Parse,
FrontendMessageCode.Bind,
FrontendMessageCode.Describe,
FrontendMessageCode.Execute,
FrontendMessageCode.Sync);

internal async Task ExpectMessages(params byte[] expectedCodes)
{
foreach (var expectedCode in expectedCodes)
await ExpectMessage(expectedCode);
}

internal async Task ExpectSimpleQuery(string expectedSql)
{
CheckDisposed();

await _readBuffer.EnsureAsync(5);
var actualCode = _readBuffer.ReadByte();
Assert.That(actualCode, Is.EqualTo(FrontendMessageCode.Query), $"Expected message of type Query but got '{(char)actualCode}'");
_ = _readBuffer.ReadInt32();
var actualSql = _readBuffer.ReadNullTerminatedString();
Assert.That(actualSql, Is.EqualTo(expectedSql));
}

internal Task FlushAsync()
{
CheckDisposed();
return _writeBuffer.Flush(async: true);
}

internal Task WriteScalarResponse(int value)
internal Task WriteScalarResponseAndFlush(int value)
=> WriteParseComplete()
.WriteBindComplete()
.WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Int4))
Expand Down
36 changes: 34 additions & 2 deletions test/Npgsql.Tests/TransactionTests.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
using System;
using System.Buffers.Binary;
using System.Data;
using System.Threading.Tasks;
using Npgsql.BackendMessages;
using Npgsql.Tests.Support;
using Npgsql.Util;
using NUnit.Framework;
using static Npgsql.Tests.TestUtil;
Expand Down Expand Up @@ -368,11 +371,40 @@ public async Task SavepointAsync()
}

[Test]
public async Task SavepointWithSemicolon()
public async Task SavepointQuoted()
{
await using var conn = await OpenConnectionAsync();
await using var tx = conn.BeginTransaction();
Assert.That(() => tx.Save("a;b"), Throws.Exception.TypeOf<ArgumentException>());
tx.Save("a;b");
tx.Rollback("a;b");
}

[Test(Description = "Makes sure that creating a savepoint doesn't perform an additional roundtrip, but prepends to the next command")]
public async Task SavepointPrepends()
{
await using var postmasterMock = PgPostmasterMock.Start(ConnectionString);
using var _ = CreateTempPool(postmasterMock.ConnectionString, out var connectionString);
await using var conn = await OpenConnectionAsync(connectionString);
var pgMock = await postmasterMock.WaitForServerConnection();

using var tx = conn.BeginTransaction();
var saveTask = tx.SaveAsync("foo");
Assert.That(saveTask.Status, Is.EqualTo(TaskStatus.RanToCompletion));

// If we're here, SaveAsync above didn't wait for any response, which is the right behavior

await pgMock
.WriteCommandComplete()
.WriteReadyForQuery() // BEGIN response
.WriteCommandComplete()
.WriteReadyForQuery() // SAVEPOINT response
.WriteScalarResponseAndFlush(1);

await conn.ExecuteScalarAsync("SELECT 1");

await pgMock.ExpectSimpleQuery("BEGIN TRANSACTION ISOLATION LEVEL READ COMMITTED");
await pgMock.ExpectSimpleQuery("SAVEPOINT foo");
await pgMock.ExpectExtendedQuery();
}

[Test, Description("Check IsCompleted before, during and after a normal committed transaction")]
Expand Down