Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Bugfix: RecursionError when reverse/righthand operations invoked. e.g…
…. __rsub__, __rmul__
  • Loading branch information
gertdreyer committed Feb 23, 2024
commit e7a3aba0733ab3b76ac7c87815df7b06ad057ab8
2 changes: 1 addition & 1 deletion src/runtime/ClassManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ private static ClassInfo GetClassInfo(Type type, ClassBase impl)
ci.members[pyName] = new MethodObject(type, name, forwardMethods).AllocObject();
// Only methods where only the right operand is the declaring type.
if (reverseMethods.Length > 0)
ci.members[pyNameReverse] = new MethodObject(type, name, reverseMethods).AllocObject();
ci.members[pyNameReverse] = new MethodObject(type, name, reverseMethods, reverse_args: true).AllocObject();
}
}

Expand Down
40 changes: 24 additions & 16 deletions src/runtime/MethodBinder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,22 @@ internal class MethodBinder

[NonSerialized]
public bool init = false;

public const bool DefaultAllowThreads = true;
public bool allow_threads = DefaultAllowThreads;

internal MethodBinder()
public bool args_reversed = false;

internal MethodBinder(bool reverse_args = false)
Comment thread
gertdreyer marked this conversation as resolved.
Outdated
{
list = new List<MaybeMethodBase>();
args_reversed = reverse_args;
}

internal MethodBinder(MethodInfo mi)
internal MethodBinder(MethodInfo mi, bool reverse_args = false)
{
list = new List<MaybeMethodBase> { new MaybeMethodBase(mi) };
args_reversed = reverse_args;
}

public int Count
Expand Down Expand Up @@ -271,10 +276,11 @@ internal static int ArgPrecedence(Type t)
/// <param name="inst">The Python target of the method invocation.</param>
/// <param name="args">The Python arguments.</param>
/// <param name="kw">The Python keyword arguments.</param>
/// <param name="reverse_args">Reverse arguments of methods. Used for methods such as __radd__, __rsub__, __rmod__ etc</param>
/// <returns>A Binding if successful. Otherwise null.</returns>
internal Binding? Bind(BorrowedReference inst, BorrowedReference args, BorrowedReference kw)
internal Binding? Bind(BorrowedReference inst, BorrowedReference args, BorrowedReference kw, bool reverse_args = false)
{
return Bind(inst, args, kw, null, null);
return Bind(inst, args, kw, null, null, reverse_args);
}

/// <summary>
Expand All @@ -287,10 +293,11 @@ internal static int ArgPrecedence(Type t)
/// <param name="args">The Python arguments.</param>
/// <param name="kw">The Python keyword arguments.</param>
/// <param name="info">If not null, only bind to that method.</param>
/// <param name="reverse_args">Reverse arguments of methods. Used for methods such as __radd__, __rsub__, __rmod__ etc</param>
/// <returns>A Binding if successful. Otherwise null.</returns>
internal Binding? Bind(BorrowedReference inst, BorrowedReference args, BorrowedReference kw, MethodBase? info)
internal Binding? Bind(BorrowedReference inst, BorrowedReference args, BorrowedReference kw, MethodBase? info, bool reverse_args = false)
{
return Bind(inst, args, kw, info, null);
return Bind(inst, args, kw, info, null, reverse_args);
}

private readonly struct MatchedMethod
Expand Down Expand Up @@ -334,8 +341,9 @@ public MismatchedMethod(Exception exception, MethodBase mb)
/// <param name="kw">The Python keyword arguments.</param>
/// <param name="info">If not null, only bind to that method.</param>
/// <param name="methodinfo">If not null, additionally attempt to bind to the generic methods in this array by inferring generic type parameters.</param>
/// <param name="reverse_args">Reverse arguments of methods. Used for methods such as __radd__, __rsub__, __rmod__ etc</param>
/// <returns>A Binding if successful. Otherwise null.</returns>
internal Binding? Bind(BorrowedReference inst, BorrowedReference args, BorrowedReference kw, MethodBase? info, MethodBase[]? methodinfo)
internal Binding? Bind(BorrowedReference inst, BorrowedReference args, BorrowedReference kw, MethodBase? info, MethodBase[]? methodinfo, bool reverse_args = false)
{
// loop to find match, return invoker w/ or w/o error
var kwargDict = new Dictionary<string, PyObject>();
Expand Down Expand Up @@ -363,10 +371,10 @@ public MismatchedMethod(Exception exception, MethodBase mb)
_methods = GetMethods();
}

return Bind(inst, args, kwargDict, _methods, matchGenerics: true);
return Bind(inst, args, kwargDict, _methods, matchGenerics: true, reverse_args);
}

static Binding? Bind(BorrowedReference inst, BorrowedReference args, Dictionary<string, PyObject> kwargDict, MethodBase[] methods, bool matchGenerics)
private static Binding? Bind(BorrowedReference inst, BorrowedReference args, Dictionary<string, PyObject> kwargDict, MethodBase[] methods, bool matchGenerics, bool reversed = false)
Comment thread
gertdreyer marked this conversation as resolved.
Outdated
{
var pynargs = (int)Runtime.PyTuple_Size(args);
var isGeneric = false;
Expand All @@ -386,7 +394,7 @@ public MismatchedMethod(Exception exception, MethodBase mb)
// Binary operator methods will have 2 CLR args but only one Python arg
// (unary operators will have 1 less each), since Python operator methods are bound.
isOperator = isOperator && pynargs == pi.Length - 1;
bool isReverse = isOperator && OperatorMethod.IsReverse((MethodInfo)mi); // Only cast if isOperator.
bool isReverse = isOperator && reversed; // Only cast if isOperator.
if (isReverse && OperatorMethod.IsComparisonOp((MethodInfo)mi))
continue; // Comparison operators in Python have no reverse mode.
if (!MatchesArgumentCount(pynargs, pi, kwargDict, out bool paramsArray, out ArrayList? defaultArgList, out int kwargsMatched, out int defaultsNeeded) && !isOperator)
Expand Down Expand Up @@ -809,14 +817,14 @@ static bool MatchesArgumentCount(int positionalArgumentCount, ParameterInfo[] pa
return match;
}

internal virtual NewReference Invoke(BorrowedReference inst, BorrowedReference args, BorrowedReference kw)
internal virtual NewReference Invoke(BorrowedReference inst, BorrowedReference args, BorrowedReference kw, bool reverse_args = false)
{
return Invoke(inst, args, kw, null, null);
return Invoke(inst, args, kw, null, null, reverse_args);
}

internal virtual NewReference Invoke(BorrowedReference inst, BorrowedReference args, BorrowedReference kw, MethodBase? info)
internal virtual NewReference Invoke(BorrowedReference inst, BorrowedReference args, BorrowedReference kw, MethodBase? info, bool reverse_args = false)
{
return Invoke(inst, args, kw, info, null);
return Invoke(inst, args, kw, info, null, reverse_args = false);
}

protected static void AppendArgumentTypes(StringBuilder to, BorrowedReference args)
Expand Down Expand Up @@ -852,7 +860,7 @@ protected static void AppendArgumentTypes(StringBuilder to, BorrowedReference ar
to.Append(')');
}

internal virtual NewReference Invoke(BorrowedReference inst, BorrowedReference args, BorrowedReference kw, MethodBase? info, MethodBase[]? methodinfo)
internal virtual NewReference Invoke(BorrowedReference inst, BorrowedReference args, BorrowedReference kw, MethodBase? info, MethodBase[]? methodinfo, bool reverse_args = false)
{
// No valid methods, nothing to bind.
if (GetMethods().Length == 0)
Expand All @@ -865,7 +873,7 @@ internal virtual NewReference Invoke(BorrowedReference inst, BorrowedReference a
return Exceptions.RaiseTypeError(msg.ToString());
}

Binding? binding = Bind(inst, args, kw, info, methodinfo);
Binding? binding = Bind(inst, args, kw, info, methodinfo, reverse_args);
object result;
IntPtr ts = IntPtr.Zero;

Expand Down
3 changes: 2 additions & 1 deletion src/runtime/Types/MethodBinding.cs
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,8 @@ public static NewReference tp_call(BorrowedReference ob, BorrowedReference args,
}
}
}
return self.m.Invoke(target is null ? BorrowedReference.Null : target, args, kw, self.info.UnsafeValue);

return self.m.Invoke(target is null ? BorrowedReference.Null : target, args, kw, self.info.UnsafeValue, self.m.binder.args_reversed);
}
finally
{
Expand Down
18 changes: 9 additions & 9 deletions src/runtime/Types/MethodObject.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,20 @@ internal class MethodObject : ExtensionType
{
[NonSerialized]
private MethodBase[]? _info = null;

private readonly List<MaybeMethodInfo> infoList;
internal string name;
internal readonly MethodBinder binder;
internal bool is_static = false;

internal PyString? doc;
internal MaybeType type;

public MethodObject(MaybeType type, string name, MethodBase[] info, bool allow_threads)
public MethodObject(MaybeType type, string name, MethodBase[] info, bool allow_threads, bool reverse_args = false)
{
this.type = type;
this.name = name;
this.infoList = new List<MaybeMethodInfo>();
binder = new MethodBinder();
binder = new MethodBinder(reverse_args);
foreach (MethodBase item in info)
{
this.infoList.Add(item);
Expand All @@ -45,8 +45,8 @@ public MethodObject(MaybeType type, string name, MethodBase[] info, bool allow_t
binder.allow_threads = allow_threads;
}

public MethodObject(MaybeType type, string name, MethodBase[] info)
: this(type, name, info, allow_threads: AllowThreads(info))
public MethodObject(MaybeType type, string name, MethodBase[] info, bool reverse_args = false)
: this(type, name, info, allow_threads: AllowThreads(info), reverse_args)
{
}

Expand All @@ -67,14 +67,14 @@ internal MethodBase[] info
}
}

public virtual NewReference Invoke(BorrowedReference inst, BorrowedReference args, BorrowedReference kw)
public virtual NewReference Invoke(BorrowedReference inst, BorrowedReference args, BorrowedReference kw, bool reverse_args = false)
{
return Invoke(inst, args, kw, null);
return Invoke(inst, args, kw, null, reverse_args);
}

public virtual NewReference Invoke(BorrowedReference target, BorrowedReference args, BorrowedReference kw, MethodBase? info)
public virtual NewReference Invoke(BorrowedReference target, BorrowedReference args, BorrowedReference kw, MethodBase? info, bool reverse_args = false)
{
return binder.Invoke(target, args, kw, info, this.info);
return binder.Invoke(target, args, kw, info, this.info, reverse_args);
}

/// <summary>
Expand Down
18 changes: 6 additions & 12 deletions src/runtime/Types/OperatorMethod.cs
Original file line number Diff line number Diff line change
Expand Up @@ -177,17 +177,14 @@ public static string ReversePyMethodName(string pyName)
}

/// <summary>
/// Check if the method is performing a reverse operation.
/// Check if the method should have a reversed operation.
/// </summary>
/// <param name="method">The operator method.</param>
/// <returns></returns>
public static bool IsReverse(MethodBase method)
public static bool HaveReverse(MethodBase method)
{
Type primaryType = method.IsOpsHelper()
? method.DeclaringType.GetGenericArguments()[0]
: method.DeclaringType;
Type leftOperandType = method.GetParameters()[0].ParameterType;
return leftOperandType != primaryType;
var pi = method.GetParameters();
return OpMethodMap.ContainsKey(method.Name) && pi.Length == 2;
}

public static void FilterMethods(MethodBase[] methods, out MethodBase[] forwardMethods, out MethodBase[] reverseMethods)
Expand All @@ -196,14 +193,11 @@ public static void FilterMethods(MethodBase[] methods, out MethodBase[] forwardM
var reverseMethodsList = new List<MethodBase>();
foreach (var method in methods)
{
if (IsReverse(method))
forwardMethodsList.Add(method);
if (HaveReverse(method))
{
reverseMethodsList.Add(method);
} else
{
forwardMethodsList.Add(method);
}

}
forwardMethods = forwardMethodsList.ToArray();
reverseMethods = reverseMethodsList.ToArray();
Expand Down