Skip to content
Prev Previous commit
Next Next commit
Rework row reading
  • Loading branch information
NinoFloris committed Mar 26, 2024
commit fe448b4688bae9df37f557e371e395921e5d6cd3
24 changes: 5 additions & 19 deletions src/Npgsql/Internal/NpgsqlConnector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1399,15 +1399,7 @@ internal ValueTask<IBackendMessage> ReadMessage(
}
}

internal IBackendMessage? ParseResultSetMessage(NpgsqlReadBuffer buf, BackendMessageCode code, int len, bool handleCallbacks = false)
=> code switch
{
BackendMessageCode.DataRow => _dataRowMessage.Load(len),
BackendMessageCode.CommandComplete => _commandCompleteMessage.Load(buf, len),
_ => ParseServerMessage(buf, code, len, false, handleCallbacks)
};

internal IBackendMessage? ParseServerMessage(NpgsqlReadBuffer buf, BackendMessageCode code, int len, bool isPrependedMessage, bool handleCallbacks = true)
internal IBackendMessage? ParseServerMessage(NpgsqlReadBuffer buf, BackendMessageCode code, int len, bool isPrependedMessage = false)
{
switch (code)
{
Expand Down Expand Up @@ -1443,18 +1435,12 @@ internal ValueTask<IBackendMessage> ReadMessage(
ReadParameterStatus(buf.GetNullTerminatedBytes(), buf.GetNullTerminatedBytes());
return null;
case BackendMessageCode.NoticeResponse:
if (handleCallbacks)
{
var notice = PostgresNotice.Load(buf, Settings.IncludeErrorDetail, LoggingConfiguration.ExceptionLogger);
LogMessages.ReceivedNotice(ConnectionLogger, notice.MessageText, Id);
Connection?.OnNotice(notice);
}
var notice = PostgresNotice.Load(buf, Settings.IncludeErrorDetail, LoggingConfiguration.ExceptionLogger);
LogMessages.ReceivedNotice(ConnectionLogger, notice.MessageText, Id);
Connection?.OnNotice(notice);
return null;
case BackendMessageCode.NotificationResponse:
if (handleCallbacks)
{
Connection?.OnNotification(new NpgsqlNotificationEventArgs(buf));
}
Connection?.OnNotification(new NpgsqlNotificationEventArgs(buf));
return null;

case BackendMessageCode.AuthenticationRequest:
Expand Down
32 changes: 32 additions & 0 deletions src/Npgsql/Internal/NpgsqlReadBuffer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.IO;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -13,6 +14,37 @@

namespace Npgsql.Internal;

readonly struct MessageHeader(BackendMessageCode code, int length)
{
public const int ByteCount = sizeof(byte) + sizeof(int);

public BackendMessageCode Code { get; } = code;
public int Length { get; } = length;

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool TryParse(ref ReadOnlySpan<byte> source, out MessageHeader header)
{
const int headerSize = sizeof(byte) + sizeof(int);

if (source.Length < headerSize)
{
header = default;
return false;
}

ref var first = ref MemoryMarshal.GetReference(source);
var code = (BackendMessageCode)first;
var length =
BitConverter.IsLittleEndian
? BinaryPrimitives.ReverseEndianness(Unsafe.ReadUnaligned<int>(ref Unsafe.Add(ref first, 1)))
: Unsafe.ReadUnaligned<int>(ref Unsafe.Add(ref first, 1));

header = new MessageHeader(code, length - sizeof(int)); // Transmitted length includes itself
source = source.Slice(headerSize);
return true;
}
}

/// <summary>
/// A buffer used by Npgsql to read data from the socket efficiently.
/// Provides methods which decode different values types and tracks the current position.
Expand Down
170 changes: 82 additions & 88 deletions src/Npgsql/NpgsqlDataReader.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Buffers;
using System.Buffers.Binary;
using System.Collections;
using System.Collections.Generic;
using System.Collections.ObjectModel;
Expand Down Expand Up @@ -63,6 +64,8 @@ public sealed class NpgsqlDataReader : DbDataReader, IDbColumnSchemaGenerator
/// </summary>
internal int StatementIndex { get; private set; }

bool _expectErrorBarrier;

/// <summary>
/// Records, for each column, its starting offset and length in the current row.
/// Used only in non-sequential mode.
Expand Down Expand Up @@ -147,6 +150,7 @@ internal void Init(
State = ReaderState.BetweenResults;
_recordsAffected = null;
_startTimestamp = startTimestamp;
_expectErrorBarrier = false;
}

#region Read
Expand All @@ -161,7 +165,7 @@ internal void Init(
public override bool Read()
{
ThrowIfClosedOrDisposed();
return TryRead()?.Result ?? Read(false).GetAwaiter().GetResult();
return Read(async: false).GetAwaiter().GetResult();
}

/// <summary>
Expand All @@ -174,125 +178,115 @@ public override bool Read()
public override Task<bool> ReadAsync(CancellationToken cancellationToken)
{
ThrowIfClosedOrDisposed();
return TryRead() ?? Read(async: true, cancellationToken);
return Read(async: true, cancellationToken);
}

// This is an optimized execution path that avoids calling any async methods for the (usual)
// case where the next row (or CommandComplete) is already in memory.
Task<bool>? TryRead()
[MethodImpl(MethodImplOptions.AggressiveInlining)]
Task<bool> Read(bool async, CancellationToken cancellationToken = default)
{
switch (State)
{
case ReaderState.InResult:
{
if (!_isRowBuffered || _behavior.HasFlag(CommandBehavior.SingleRow))
return InResultSlow(async, cancellationToken);

// Consume current row.
var buffer = Buffer;
buffer.PgReader.Commit();
buffer.ReadPosition = _dataMsgEnd;

var span = buffer.Span;
if (!MessageHeader.TryParse(ref span, out var header))
return InResultSlow(async, cancellationToken);

switch (header.Code)
{
// sizeof(short) is for the number of columns.
case BackendMessageCode.DataRow when span.Length >= (_isSequential ? sizeof(short) : header.Length):
Debug.Assert(BinaryPrimitives.ReadInt16BigEndian(span) == ColumnCount);

var msgEnd = _dataMsgEnd = buffer.ReadPosition + MessageHeader.ByteCount + header.Length;

_isRowBuffered = msgEnd <= buffer.FilledBytes;
_column = -1;

if (_columns.Count != 0)
_columns.Clear();

_columnsStartPos = buffer.ReadPosition += MessageHeader.ByteCount + sizeof(short);
return TrueTask;
case BackendMessageCode.CommandComplete or BackendMessageCode.EmptyQueryResponse when !_expectErrorBarrier && span.Length >= header.Length:
buffer.ReadPosition += MessageHeader.ByteCount;
ProcessMessage(Connector.ParseServerMessage(Buffer, BackendMessageCode.CommandComplete, header.Length)!);
return FalseTask;
default:
return InResultSlow(async, cancellationToken);
}
}
case ReaderState.BeforeResult:
// First Read() after NextResult. Data row has already been processed.
State = ReaderState.InResult;
return TrueTask;
case ReaderState.InResult:
break;
default:
Debug.Assert(Enum.IsDefined(State));
return FalseTask;
}

// We have a special case path for SingleRow.
if (_behavior.HasFlag(CommandBehavior.SingleRow) || !_isRowBuffered)
return null;

ConsumeBufferedRow();

const int headerSize = sizeof(byte) + sizeof(int);
var buffer = Buffer;
var readPosition = buffer.ReadPosition;
var bytesLeft = buffer.FilledBytes - readPosition;
if (bytesLeft < headerSize)
return null;
var messageCode = (BackendMessageCode)buffer.ReadByte();
var len = buffer.ReadInt32() - sizeof(int); // Transmitted length includes itself
var isDataRow = messageCode is BackendMessageCode.DataRow;
// sizeof(short) is for the number of columns
var sufficientBytes = isDataRow && _isSequential ? headerSize + sizeof(short) : headerSize + len;
if (bytesLeft < sufficientBytes
|| !isDataRow && (_statements[StatementIndex].AppendErrorBarrier ?? Command.EnableErrorBarriers)
// Could be an error, let main read handle it.
|| Connector.ParseResultSetMessage(buffer, messageCode, len) is not { } msg)
async Task<bool> InResultSlow(bool async, CancellationToken cancellationToken)
{
buffer.ReadPosition = readPosition;
return null;
}
ProcessMessage(msg);
return isDataRow ? TrueTask : FalseTask;
}

async Task<bool> Read(bool async, CancellationToken cancellationToken = default)
{
using var registration = Connector.StartNestedCancellableOperation(cancellationToken);
try
{
switch (State)
Debug.Assert(State is ReaderState.InResult);
using var registration = Connector.StartNestedCancellableOperation(cancellationToken);
try
{
case ReaderState.BeforeResult:
// First Read() after NextResult. Data row has already been processed.
State = ReaderState.InResult;
return true;

case ReaderState.InResult:
await ConsumeRow(async).ConfigureAwait(false);
// No more rows for single row.
if (_behavior.HasFlag(CommandBehavior.SingleRow))
{
// TODO: See optimization proposal in #410
await Consume(async).ConfigureAwait(false);
return false;
}
break;

case ReaderState.BetweenResults:
case ReaderState.Consumed:
case ReaderState.Closed:
case ReaderState.Disposed:
return false;
default:
ThrowHelper.ThrowArgumentOutOfRangeException();
return false;
}

var msg = await ReadMessage(async).ConfigureAwait(false);
await ConsumeRow(async).ConfigureAwait(false);

switch (msg.Code)
{
case BackendMessageCode.DataRow:
ProcessMessage(msg);
return true;
var msg = await ReadMessage(async).ConfigureAwait(false);

case BackendMessageCode.CommandComplete:
case BackendMessageCode.EmptyQueryResponse:
ProcessMessage(msg);
if (_statements[StatementIndex].AppendErrorBarrier ?? Command.EnableErrorBarriers)
Expect<ReadyForQueryMessage>(await Connector.ReadMessage(async).ConfigureAwait(false), Connector);
return false;
switch (msg.Code)
{
case BackendMessageCode.DataRow:
ProcessMessage(msg);
return true;

default:
throw Connector.UnexpectedMessageReceived(msg.Code);
case BackendMessageCode.CommandComplete:
case BackendMessageCode.EmptyQueryResponse:
ProcessMessage(msg);
if (_expectErrorBarrier)
Expect<ReadyForQueryMessage>(await Connector.ReadMessage(async).ConfigureAwait(false), Connector);
return false;
default:
throw Connector.UnexpectedMessageReceived(msg.Code);
}
}
catch
{
// Break may have progressed the reader already.
if (State is not ReaderState.Closed)
State = ReaderState.Consumed;
throw;
}
}
catch
{
// Break may have progressed the reader already.
if (State is not ReaderState.Closed)
State = ReaderState.Consumed;
throw;
}
}

ValueTask<IBackendMessage> ReadMessage(bool async)
{
return _isSequential ? ReadMessageSequential(Connector, async) : Connector.ReadMessage(async);
return _isSequential ? ReadMessageSequential(async, Connector) : Connector.ReadMessage(async);

static async ValueTask<IBackendMessage> ReadMessageSequential(NpgsqlConnector connector, bool async)
static async ValueTask<IBackendMessage> ReadMessageSequential(bool async, NpgsqlConnector connector)
{
var msg = await connector.ReadMessage(async, DataRowLoadingMode.Sequential).ConfigureAwait(false);
if (msg.Code == BackendMessageCode.DataRow)
if (msg.Code is BackendMessageCode.DataRow)
{
// Make sure that the datarow's column count is already buffered
await connector.ReadBuffer.Ensure(2, async).ConfigureAwait(false);
await connector.ReadBuffer.Ensure(sizeof(short), async).ConfigureAwait(false);
return msg;
}
return msg;
Expand All @@ -310,8 +304,7 @@ static async ValueTask<IBackendMessage> ReadMessageSequential(NpgsqlConnector co
public override bool NextResult()
{
ThrowIfClosedOrDisposed();
return (_isSchemaOnly ? NextResultSchemaOnly(false) : NextResult(false))
.GetAwaiter().GetResult();
return (_isSchemaOnly ? NextResultSchemaOnly(false) : NextResult(false)).GetAwaiter().GetResult();
}

/// <summary>
Expand Down Expand Up @@ -418,6 +411,7 @@ async Task<bool> NextResult(bool async, bool isConsuming = false, CancellationTo

if (RowDescription is not null)
{
_expectErrorBarrier = statement.AppendErrorBarrier ?? Command.EnableErrorBarriers;
if (ColumnInfoCache?.Length >= ColumnCount)
Array.Clear(ColumnInfoCache, 0, ColumnCount);
else
Expand Down Expand Up @@ -776,7 +770,7 @@ async Task<bool> NextResultSchemaOnly(bool async, bool isConsuming = false, Canc

#region ProcessMessage

internal void ProcessMessage(IBackendMessage msg)
void ProcessMessage(IBackendMessage msg)
{
if (msg.Code is not BackendMessageCode.DataRow)
{
Expand Down