Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ This document follows the conventions laid out in [Keep a CHANGELOG][].
### Added

### Changed
- Added a `FormatterFactory` member in RuntimeData to create formatters with parameters. For compatibility, the `FormatterType` member is still present and has precedence when defining both `FormatterFactory` and `FormatterType`
- Added a post-serialization and a pre-deserialization step callbacks to extend (de)serialization process
- Added an API to stash serialized data on Python capsules

### Fixed

Expand Down
14 changes: 14 additions & 0 deletions src/runtime/StateSerialization/NoopFormatter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using System;
using System.IO;
using System.Runtime.Serialization;

namespace Python.Runtime;

public class NoopFormatter : IFormatter {
public object Deserialize(Stream s) => throw new NotImplementedException();
public void Serialize(Stream s, object o) {}

public SerializationBinder? Binder { get; set; }
public StreamingContext Context { get; set; }
public ISurrogateSelector? SurrogateSelector { get; set; }
}
131 changes: 125 additions & 6 deletions src/runtime/StateSerialization/RuntimeData.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Diagnostics;
using System.IO;
using System.Linq;
Expand All @@ -17,7 +15,34 @@ namespace Python.Runtime
{
public static class RuntimeData
{
private static Type? _formatterType;

public readonly static Func<IFormatter> DefaultFormatterFactory = () =>
{
try
{
return new BinaryFormatter();
}
catch
{
return new NoopFormatter();
}
};

private static Func<IFormatter> _formatterFactory { get; set; } = DefaultFormatterFactory;

public static Func<IFormatter> FormatterFactory
{
get => _formatterFactory;
set
{
if (value == null)
throw new ArgumentNullException(nameof(value));

_formatterFactory = value;
}
}

private static Type? _formatterType = null;
public static Type? FormatterType
{
get => _formatterType;
Expand All @@ -31,6 +56,14 @@ public static Type? FormatterType
}
}

/// <summary>
/// Callback called as a last step in the serialization process
/// </summary>
public static Action? PostStashHook { get; set; } = null;
/// <summary>
/// Callback called as the first step in the deserialization process
/// </summary>
public static Action? PreRestoreHook { get; set; } = null;
public static ICLRObjectStorer? WrappersStorer { get; set; }

/// <summary>
Expand Down Expand Up @@ -74,6 +107,7 @@ internal static void Stash()
using NewReference capsule = PyCapsule_New(mem, IntPtr.Zero, IntPtr.Zero);
int res = PySys_SetObject("clr_data", capsule.BorrowOrThrow());
PythonException.ThrowIfIsNotZero(res);
PostStashHook?.Invoke();
}

internal static void RestoreRuntimeData()
Expand All @@ -90,6 +124,7 @@ internal static void RestoreRuntimeData()

private static void RestoreRuntimeDataImpl()
{
PreRestoreHook?.Invoke();
BorrowedReference capsule = PySys_GetObject("clr_data");
if (capsule.IsNull)
{
Expand Down Expand Up @@ -250,11 +285,95 @@ private static void RestoreRuntimeDataObjects(SharedObjectsState storage)
}
}

static readonly string serialization_key_namepsace = "pythonnet_serialization_";
/// <summary>
/// Removes the serialization capsule from the `sys` module object.
/// </summary>
/// <remarks>
/// The serialization data must have been set with <code>StashSerializationData</code>
/// </remarks>
/// <param name="key">The name given to the capsule on the `sys` module object</param>
public static void FreeSerializationData(string key)
{
key = serialization_key_namepsace + key;
BorrowedReference oldCapsule = PySys_GetObject(key);
if (!oldCapsule.IsNull)
{
IntPtr oldData = PyCapsule_GetPointer(oldCapsule, IntPtr.Zero);
Marshal.FreeHGlobal(oldData);
PyCapsule_SetPointer(oldCapsule, IntPtr.Zero);
PySys_SetObject(key, null);
}
}

/// <summary>
/// Stores the data in the <paramref name="stream"/> argument in a Python capsule and stores
/// the capsule on the `sys` module object with the name <paramref name="key"/>.
/// </summary>
/// <remarks>
/// No checks on pre-existing names on the `sys` module object are made.
/// </remarks>
/// <param name="key">The name given to the capsule on the `sys` module object</param>
/// <param name="stream">A MemoryStream that contains the data to be placed in the capsule</param>
public static void StashSerializationData(string key, MemoryStream stream)
{
var data = stream.GetBuffer();
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might have extra data in the end

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean by extra data? The stream container is longer than the data in it? The code is using Length not Capacity, so it should not be an issue, no?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is that stream.GetBuffer() is really just the internal buffer. The data represented may

  • start at a different index than 0 (if a non-zero index is passed in the constructor)
  • be less than the full buffer

The primitive solution is to use ToArray, but that creates another copy. Since .NET 4.6, there is also a TryGetBuffer API that outputs an ArraySegment containing all necessary information. I'll push a commit with the respective adjustment.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 6e37568

IntPtr mem = Marshal.AllocHGlobal(IntPtr.Size + data.Length);
Comment thread
filmor marked this conversation as resolved.
Outdated
// store the length of the buffer first
Marshal.WriteIntPtr(mem, (IntPtr)data.Length);
Marshal.Copy(data, 0, mem + IntPtr.Size, data.Length);

try
{
using NewReference capsule = PyCapsule_New(mem, IntPtr.Zero, IntPtr.Zero);
int res = PySys_SetObject(key, capsule.BorrowOrThrow());
PythonException.ThrowIfIsNotZero(res);
}
catch
{
Marshal.FreeHGlobal(mem);
}

}

static byte[] emptyBuffer = new byte[0];
/// <summary>
/// Retreives the previously stored data on a Python capsule.
/// Throws if the object corresponding to the <paramref name="key"/> parameter
/// on the `sys` module object is not a capsule.
/// </summary>
/// <param name="key">The name given to the capsule on the `sys` module object</param>
/// <returns>A MemoryStream containing the previously saved serialization data.
/// The stream is empty if no name matches the key. </returns>
public static MemoryStream GetSerializationData(string key)
{
BorrowedReference capsule = PySys_GetObject(key);
if (capsule.IsNull)
{
// nothing to do.
return new MemoryStream(emptyBuffer, writable:false);
}
var ptr = PyCapsule_GetPointer(capsule, IntPtr.Zero);
if (ptr == IntPtr.Zero)
{
// The PyCapsule API returns NULL on error; NULL cannot be stored
// as a capsule's value
PythonException.ThrowIfIsNull(null);
}
var len = (int)Marshal.ReadIntPtr(ptr);
byte[] buffer = new byte[len];
Marshal.Copy(ptr+IntPtr.Size, buffer, 0, len);
return new MemoryStream(buffer, writable:false);
}

internal static IFormatter CreateFormatter()
{
return FormatterType != null ?
(IFormatter)Activator.CreateInstance(FormatterType)
: new BinaryFormatter();

if (FormatterType != null)
{
return (IFormatter)Activator.CreateInstance(FormatterType);
}
return FormatterFactory();
}
}
}
117 changes: 117 additions & 0 deletions tests/domain_tests/TestRunner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1132,6 +1132,66 @@ import System

",
},
new TestCase
{
Name = "test_serialize_unserializable_object",
DotNetBefore = @"
namespace TestNamespace
{
public class NotSerializableTextWriter : System.IO.TextWriter
{
override public System.Text.Encoding Encoding { get { return System.Text.Encoding.ASCII;} }
}
[System.Serializable]
public static class SerializableWriter
{
private static System.IO.TextWriter _writer = null;
public static System.IO.TextWriter Writer {get { return _writer; }}
public static void CreateInternalWriter()
{
_writer = System.IO.TextWriter.Synchronized(new NotSerializableTextWriter());
}
}
}
",
DotNetAfter = @"
namespace TestNamespace
{
public class NotSerializableTextWriter : System.IO.TextWriter
{
override public System.Text.Encoding Encoding { get { return System.Text.Encoding.ASCII;} }
}
[System.Serializable]
public static class SerializableWriter
{
private static System.IO.TextWriter _writer = null;
public static System.IO.TextWriter Writer {get { return _writer; }}
public static void CreateInternalWriter()
{
_writer = System.IO.TextWriter.Synchronized(new NotSerializableTextWriter());
}
}
}
",
PythonCode = @"
import sys

def before_reload():
import clr
import System
clr.AddReference('DomainTests')
import TestNamespace
TestNamespace.SerializableWriter.CreateInternalWriter();
sys.__obj = TestNamespace.SerializableWriter.Writer
sys.__obj.WriteLine('test')

def after_reload():
import clr
import System
sys.__obj.WriteLine('test')

",
}
};

/// <summary>
Expand All @@ -1142,7 +1202,59 @@ import System
const string CaseRunnerTemplate = @"
using System;
using System.IO;
using System.Runtime.Serialization;
using System.Runtime.Serialization.Formatters.Binary;
using Python.Runtime;

namespace Serialization
{{
// Classes in this namespace is mostly useful for test_serialize_unserializable_object
class NotSerializableSerializer : ISerializationSurrogate
{{
public NotSerializableSerializer()
{{
}}
public void GetObjectData(object obj, SerializationInfo info, StreamingContext context)
{{
info.AddValue(""notSerialized_tp"", obj.GetType());
}}
public object SetObjectData(object obj, SerializationInfo info, StreamingContext context, ISurrogateSelector selector)
{{
if (info == null)
{{
return null;
}}
Type typeObj = info.GetValue(""notSerialized_tp"", typeof(Type)) as Type;
if (typeObj == null)
{{
return null;
}}

obj = Activator.CreateInstance(typeObj);
return obj;
}}
}}
class NonSerializableSelector : SurrogateSelector
{{
public override ISerializationSurrogate GetSurrogate(Type type, StreamingContext context, out ISurrogateSelector selector)
{{
if (type == null)
{{
throw new ArgumentNullException();
}}
selector = (ISurrogateSelector)this;
if (type.IsSerializable)
{{
return null; // use whichever default
}}
else
{{
return (ISerializationSurrogate)(new NotSerializableSerializer());
}}
}}
}}
}}

namespace CaseRunner
{{
class CaseRunner
Expand All @@ -1151,6 +1263,11 @@ public static int Main()
{{
try
{{
RuntimeData.FormatterFactory = () =>
Comment thread
filmor marked this conversation as resolved.
{{
return new BinaryFormatter(){{SurrogateSelector = new Serialization.NonSerializableSelector()}};
}};

PythonEngine.Initialize();
using (Py.GIL())
{{
Expand Down
3 changes: 3 additions & 0 deletions tests/domain_tests/test_domain_reload.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,6 @@ def test_nested_type():

def test_import_after_reload():
_run_test("import_after_reload")

def test_import_after_reload():
_run_test("test_serialize_unserializable_object")