Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 103 additions & 103 deletions src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs
Original file line number Diff line number Diff line change
Expand Up @@ -105,54 +105,52 @@ static void Write(NpgsqlWriteBuffer writeBuffer, int maxRows)
}
}

internal async Task WriteParse(string sql, byte[] asciiName, List<NpgsqlParameter> inputParameters, bool async, CancellationToken cancellationToken = default)
internal async Task WriteParse(string sql, byte[] asciiName, List<NpgsqlParameter> parameters, bool async, CancellationToken cancellationToken = default)
{
NpgsqlWriteBuffer.AssertASCIIOnly(asciiName);

int queryByteLen;
try
{
queryByteLen = TextEncoding.GetByteCount(sql);
}
catch (Exception e)
{
Break(e);
throw;
}
// See https://www.postgresql.org/docs/current/limits.html
if (parameters.Count > ushort.MaxValue)
ThrowHelper.ThrowArgumentException("Too many parameters in statement (max: 65535).", nameof(parameters));
var parameterCount = (ushort)parameters.Count;

var writeBuffer = WriteBuffer;
var messageLength =
sizeof(byte) + // Message code
sizeof(int) + // Length
asciiName.Length + // Statement name
sizeof(byte) + // Null terminator for the statement name
queryByteLen + sizeof(byte) + // SQL query length plus null terminator
sizeof(ushort) + // Number of parameters
inputParameters.Count * sizeof(int); // Parameter OIDs
var queryByteLen = TextEncoding.GetByteCount(sql);

var headerLength =
sizeof(byte) + // Message code
sizeof(int) + // Message length
asciiName.Length + sizeof(byte); // Statement name plus null terminator

WriteBuffer.StartMessage(messageLength);
if (WriteBuffer.WriteSpaceLeft < 1 + 4 + asciiName.Length + 1)
var messageLength =
headerLength +
queryByteLen + sizeof(byte) + // SQL query length plus null terminator
sizeof(ushort) + // Number of parameters
parameterCount * sizeof(int); // Parameter OIDs

var writeBuffer = WriteBuffer;
writeBuffer.StartMessage(messageLength);
if (writeBuffer.WriteSpaceLeft < headerLength)
{
Debug.Assert(writeBuffer.Size >= headerLength, "Write buffer too small for Parse header");
await Flush(async, cancellationToken).ConfigureAwait(false);
}

WriteBuffer.WriteByte(FrontendMessageCode.Parse);
WriteBuffer.WriteInt32(messageLength - 1);
WriteBuffer.WriteNullTerminatedString(asciiName);
writeBuffer.WriteByte(FrontendMessageCode.Parse);
writeBuffer.WriteInt32(messageLength - 1);
writeBuffer.WriteNullTerminatedString(asciiName);

await writeBuffer.WriteString(sql, queryByteLen, async, cancellationToken).ConfigureAwait(false);

if (writeBuffer.WriteSpaceLeft < 1 + 2)
if (writeBuffer.WriteSpaceLeft < sizeof(byte) + sizeof(short))
await Flush(async, cancellationToken).ConfigureAwait(false);
writeBuffer.WriteByte(0); // Null terminator for the query
writeBuffer.WriteUInt16((ushort)inputParameters.Count);

writeBuffer.WriteUInt16(parameterCount);
var databaseInfo = DatabaseInfo;
foreach (var p in inputParameters)
for (var i = 0; i < parameters.Count; i++)
{
if (writeBuffer.WriteSpaceLeft < 4)
if (writeBuffer.WriteSpaceLeft < sizeof(uint))
await Flush(async, cancellationToken).ConfigureAwait(false);

writeBuffer.WriteUInt32(databaseInfo.GetOid(p.PgTypeId).Value);
writeBuffer.WriteUInt32(databaseInfo.GetOid(parameters[i].PgTypeId).Value);
}
}

Expand All @@ -165,105 +163,112 @@ internal async Task WriteBind(
bool async,
CancellationToken cancellationToken = default)
{
const short textFormatCode = 0;
const short binaryFormatCode = 1;

Debug.Assert(portal == string.Empty);

NpgsqlWriteBuffer.AssertASCIIOnly(asciiName);
NpgsqlWriteBuffer.AssertASCIIOnly(portal);

var headerLength =
sizeof(byte) + // Message code
sizeof(int) + // Message length
sizeof(byte) + // Portal is always empty (only a null terminator)
asciiName.Length + sizeof(byte) + // Statement name plus null terminator
sizeof(ushort); // Number of parameter format codes that follow
if (parameters.Count > ushort.MaxValue)
ThrowHelper.ThrowArgumentException("Too many parameters in statement (max: 65535).", nameof(parameters));
var parameterCount = (ushort)parameters.Count;

// PG limit is 1664 (see https://www.postgresql.org/docs/current/limits.html) however we purely guard datatype limits here.
if (unknownResultTypeList?.Length > short.MaxValue)
ThrowHelper.ThrowArgumentException("Too many result types in statement (max: 32768).", nameof(unknownResultTypeList));
var unknownResultTypeCount = (short)(unknownResultTypeList?.Length ?? 1);

var writeBuffer = WriteBuffer;
var formatCodesSum = 0;
var paramsLength = 0;
for (var paramIndex = 0; paramIndex < parameters.Count; paramIndex++)
var parametersLength = 0;
for (var i = 0; i < parameters.Count; i++)
{
var param = parameters[paramIndex];
var param = parameters[i];
param.Bind(out var format, out var size);
paramsLength += size.Value > 0 ? size.Value : 0;
parametersLength += size.Value > 0 ? size.Value : 0;
formatCodesSum += format.ToFormatCode();
}
var formatCodeListLength = formatCodesSum is 0 ? (ushort)0 : formatCodesSum == parameterCount ? (ushort)1 : parameterCount;

var formatCodeListLength = formatCodesSum == 0 ? 0 : formatCodesSum == parameters.Count ? 1 : parameters.Count;
var headerLength =
sizeof(byte) + // Message code
sizeof(int) + // Message length
sizeof(byte) + // Portal (only a null terminator)
asciiName.Length + sizeof(byte) + // Statement name plus null terminator
sizeof(ushort); // Number of parameter format codes that follow

var messageLength = headerLength +
sizeof(short) * formatCodeListLength + // List of format codes
sizeof(short) + // Number of parameters
sizeof(int) * parameters.Count + // Parameter lengths
paramsLength + // Parameter values
sizeof(short) + // Number of result format codes
sizeof(short) * (unknownResultTypeList?.Length ?? 1); // Result format codes
var messageLength =
headerLength +
sizeof(short) * formatCodeListLength + // List of format codes
sizeof(short) + // Number of parameters
sizeof(int) * parameterCount + // Parameter lengths
parametersLength + // Parameter values
sizeof(short) + // Number of result format codes
sizeof(short) * unknownResultTypeCount; // Result format codes

WriteBuffer.StartMessage(messageLength);
if (WriteBuffer.WriteSpaceLeft < headerLength)
var writeBuffer = WriteBuffer;
writeBuffer.StartMessage(messageLength);
if (writeBuffer.WriteSpaceLeft < headerLength)
{
Debug.Assert(WriteBuffer.Size >= headerLength, "Write buffer too small for Bind header");
Debug.Assert(writeBuffer.Size >= headerLength, "Write buffer too small for Bind header");
await Flush(async, cancellationToken).ConfigureAwait(false);
}

WriteBuffer.WriteByte(FrontendMessageCode.Bind);
WriteBuffer.WriteInt32(messageLength - 1);
Debug.Assert(portal == string.Empty);
writeBuffer.WriteByte(0); // Portal is always empty

writeBuffer.WriteByte(FrontendMessageCode.Bind);
writeBuffer.WriteInt32(messageLength - 1);
writeBuffer.WriteByte(0); // Portal is always empty
writeBuffer.WriteNullTerminatedString(asciiName);
writeBuffer.WriteInt16((short)formatCodeListLength);

// 0 length implicitly means all-text, 1 means all-binary, >1 means mix-and-match
if (formatCodeListLength == 1)
// 0 length implicitly means all-text, 1 is a uniform format, >1 is per parameter formats
writeBuffer.WriteUInt16(formatCodeListLength);
if (formatCodeListLength is 1)
{
if (writeBuffer.WriteSpaceLeft < sizeof(short))
await Flush(async, cancellationToken).ConfigureAwait(false);
writeBuffer.WriteInt16(DataFormat.Binary.ToFormatCode());
writeBuffer.WriteInt16(binaryFormatCode);
}
else if (formatCodeListLength > 1)
{
for (var paramIndex = 0; paramIndex < parameters.Count; paramIndex++)
for (var i = 0; i < parameters.Count; i++)
{
if (writeBuffer.WriteSpaceLeft < sizeof(short))
await Flush(async, cancellationToken).ConfigureAwait(false);
writeBuffer.WriteInt16(parameters[paramIndex].Format.ToFormatCode());
writeBuffer.WriteInt16(parameters[i].Format.ToFormatCode());
}
}

if (writeBuffer.WriteSpaceLeft < sizeof(ushort))
await Flush(async, cancellationToken).ConfigureAwait(false);

writeBuffer.WriteUInt16((ushort)parameters.Count);
if (parameters.Count > 0)
writeBuffer.WriteUInt16(parameterCount);
if (parameterCount > 0)
{
var writer = writeBuffer.GetWriter(DatabaseInfo, async ? FlushMode.NonBlocking : FlushMode.Blocking);
try
{
for (var paramIndex = 0; paramIndex < parameters.Count; paramIndex++)
{
var param = parameters[paramIndex];
await param.Write(async, writer, cancellationToken).ConfigureAwait(false);
}
}
catch(Exception ex)
{
Break(ex);
throw;
}
for (var i = 0; i < parameters.Count; i++)
await parameters[i].Write(async, writer, cancellationToken).ConfigureAwait(false);
}

if (unknownResultTypeList != null)
{
if (writeBuffer.WriteSpaceLeft < 2 + unknownResultTypeList.Length * 2)
if (writeBuffer.WriteSpaceLeft < sizeof(short) + unknownResultTypeCount * sizeof(short))
await Flush(async, cancellationToken).ConfigureAwait(false);
writeBuffer.WriteInt16((short)unknownResultTypeList.Length);
foreach (var t in unknownResultTypeList)
writeBuffer.WriteInt16((short)(t ? 0 : 1));

writeBuffer.WriteInt16(unknownResultTypeCount);
for (var i = 0; i < unknownResultTypeList.Length; i++)
{
var unknownResultType = unknownResultTypeList[i];
writeBuffer.WriteInt16(unknownResultType ? textFormatCode : binaryFormatCode);
}
}
else
{
if (writeBuffer.WriteSpaceLeft < 4)
if (writeBuffer.WriteSpaceLeft < sizeof(short) + sizeof(short))
await Flush(async, cancellationToken).ConfigureAwait(false);
writeBuffer.WriteInt16(1);
writeBuffer.WriteInt16((short)(allResultTypesAreUnknown ? 0 : 1));

Debug.Assert(unknownResultTypeCount is 1);
writeBuffer.WriteInt16(unknownResultTypeCount);
writeBuffer.WriteInt16(allResultTypesAreUnknown ? textFormatCode : binaryFormatCode);
}
}

Expand Down Expand Up @@ -303,20 +308,24 @@ internal async Task WriteQuery(string sql, bool async, CancellationToken cancell
{
var queryByteLen = TextEncoding.GetByteCount(sql);

var len = sizeof(byte) +
sizeof(int) + // Message length (including self excluding code)
queryByteLen + // Query byte length
sizeof(byte);
var len =
sizeof(byte) + // Message code
sizeof(int) + // Message length
queryByteLen + // Query byte length
sizeof(byte); // Null terminator

WriteBuffer.StartMessage(len);
if (WriteBuffer.WriteSpaceLeft < 1 + 4)
if (WriteBuffer.WriteSpaceLeft < sizeof(byte) + sizeof(int))
{
Debug.Assert(WriteBuffer.Size >= sizeof(byte) + sizeof(int), "Write buffer too small for Parse header");
await Flush(async, cancellationToken).ConfigureAwait(false);
}

WriteBuffer.WriteByte(FrontendMessageCode.Query);
WriteBuffer.WriteInt32(len - 1);

await WriteBuffer.WriteString(sql, queryByteLen, async, cancellationToken).ConfigureAwait(false);
if (WriteBuffer.WriteSpaceLeft < 1)
if (WriteBuffer.WriteSpaceLeft < sizeof(byte))
await Flush(async, cancellationToken).ConfigureAwait(false);
WriteBuffer.WriteByte(0); // Null terminator
}
Expand Down Expand Up @@ -439,16 +448,7 @@ internal async Task WritePassword(byte[] payload, int offset, int count, bool as

WriteBuffer.WriteByte(FrontendMessageCode.Password);
WriteBuffer.WriteInt32(sizeof(int) + count);

if (count <= WriteBuffer.WriteSpaceLeft)
{
// The entire array fits in our WriteBuffer, copy it into the WriteBuffer as usual.
WriteBuffer.WriteBytes(payload, offset, count);
return;
}

await WriteBuffer.Flush(async, cancellationToken).ConfigureAwait(false);
await WriteBuffer.DirectWrite(new ReadOnlyMemory<byte>(payload, offset, count), async, cancellationToken).ConfigureAwait(false);
await WriteBuffer.WriteBytesRaw(new ReadOnlyMemory<byte>(payload, offset, count), async, cancellationToken).ConfigureAwait(false);
}

internal async Task WriteSASLInitialResponse(string mechanism, byte[] initialResponse, bool async, CancellationToken cancellationToken = default)
Expand Down
24 changes: 0 additions & 24 deletions src/Npgsql/Internal/NpgsqlWriteBuffer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -422,30 +422,6 @@ static async Task WriteBytesLong(NpgsqlWriteBuffer buffer, bool async, ReadOnlyM
}
}

public async Task WriteStreamRaw(Stream stream, int count, bool async, CancellationToken cancellationToken = default)
{
while (count > 0)
{
if (WriteSpaceLeft == 0)
await Flush(async, cancellationToken).ConfigureAwait(false);
try
{
var read = async
? await stream.ReadAsync(Buffer, WritePosition, Math.Min(WriteSpaceLeft, count), cancellationToken).ConfigureAwait(false)
: stream.Read(Buffer, WritePosition, Math.Min(WriteSpaceLeft, count));
if (read == 0)
throw new EndOfStreamException();
WritePosition += read;
count -= read;
}
catch (Exception e)
{
throw Connector.Break(new NpgsqlException("Exception while writing to stream", e));
}
}
Debug.Assert(count == 0);
}

public void WriteNullTerminatedString(string s)
{
AssertASCIIOnly(s);
Expand Down
Loading