Skip to content

Commit 5130aaa

Browse files
committed
Overload operators except for bit shifts and default args.
1 parent 1f40564 commit 5130aaa

8 files changed

Lines changed: 389 additions & 4 deletions

File tree

src/embed_tests/TestPyMethod.cs

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
using NUnit.Framework;
2+
3+
using Python.Runtime;
4+
5+
using System.Linq;
6+
using System.Reflection;
7+
8+
namespace Python.EmbeddingTest
9+
{
10+
public class TestPyMethod
11+
{
12+
[OneTimeSetUp]
13+
public void SetUp()
14+
{
15+
PythonEngine.Initialize();
16+
}
17+
18+
[OneTimeTearDown]
19+
public void Dispose()
20+
{
21+
PythonEngine.Shutdown();
22+
}
23+
24+
public class SampleClass
25+
{
26+
public int VoidCall() => 10;
27+
28+
public int Foo(int a, int b = 10) => a + b;
29+
30+
public int Foo2(int a = 10, params int[] args)
31+
{
32+
return a + args.Sum();
33+
}
34+
}
35+
36+
[Test]
37+
public void TestVoidCall()
38+
{
39+
string name = string.Format("{0}.{1}",
40+
typeof(SampleClass).DeclaringType.Name,
41+
typeof(SampleClass).Name);
42+
string module = MethodBase.GetCurrentMethod().DeclaringType.Namespace;
43+
PythonEngine.Exec($@"
44+
from {module} import *
45+
SampleClass = {name}
46+
obj = SampleClass()
47+
assert obj.VoidCall() == 10
48+
");
49+
}
50+
51+
[Test]
52+
public void TestDefaultParameter()
53+
{
54+
string name = string.Format("{0}.{1}",
55+
typeof(SampleClass).DeclaringType.Name,
56+
typeof(SampleClass).Name);
57+
string module = MethodBase.GetCurrentMethod().DeclaringType.Namespace;
58+
59+
PythonEngine.Exec($@"
60+
from {module} import *
61+
SampleClass = {name}
62+
obj = SampleClass()
63+
assert obj.Foo(10) == 20
64+
assert obj.Foo(10, 1) == 11
65+
66+
assert obj.Foo2() == 10
67+
assert obj.Foo2(20) == 20
68+
assert obj.Foo2(20, 30) == 50
69+
assert obj.Foo2(20, 30, 50) == 100
70+
");
71+
}
72+
73+
public class OperableObject
74+
{
75+
public int Num { get; set; }
76+
77+
public OperableObject(int num)
78+
{
79+
Num = num;
80+
}
81+
82+
public static OperableObject operator +(OperableObject a, OperableObject b)
83+
{
84+
return new OperableObject(a.Num + b.Num);
85+
}
86+
87+
public static OperableObject operator -(OperableObject a, OperableObject b)
88+
{
89+
return new OperableObject(a.Num - b.Num);
90+
}
91+
92+
public static OperableObject operator *(OperableObject a, OperableObject b)
93+
{
94+
return new OperableObject(a.Num * b.Num);
95+
}
96+
97+
public static OperableObject operator /(OperableObject a, OperableObject b)
98+
{
99+
return new OperableObject(a.Num / b.Num);
100+
}
101+
102+
public static OperableObject operator &(OperableObject a, OperableObject b)
103+
{
104+
return new OperableObject(a.Num & b.Num);
105+
}
106+
107+
public static OperableObject operator |(OperableObject a, OperableObject b)
108+
{
109+
return new OperableObject(a.Num | b.Num);
110+
}
111+
112+
public static OperableObject operator ^(OperableObject a, OperableObject b)
113+
{
114+
return new OperableObject(a.Num ^ b.Num);
115+
}
116+
117+
public static OperableObject operator <<(OperableObject a, int offset)
118+
{
119+
return new OperableObject(a.Num << offset);
120+
}
121+
122+
public static OperableObject operator >>(OperableObject a, int offset)
123+
{
124+
return new OperableObject(a.Num >> offset);
125+
}
126+
}
127+
128+
[Test]
129+
public void OperatorOverloads()
130+
{
131+
string name = string.Format("{0}.{1}",
132+
typeof(OperableObject).DeclaringType.Name,
133+
typeof(OperableObject).Name);
134+
string module = MethodBase.GetCurrentMethod().DeclaringType.Namespace;
135+
136+
PythonEngine.Exec($@"
137+
from {module} import *
138+
cls = {name}
139+
a = cls(2)
140+
b = cls(10)
141+
c = a + b
142+
assert c.Num == a.Num + b.Num
143+
144+
c = a - b
145+
assert c.Num == a.Num - b.Num
146+
147+
c = a * b
148+
assert c.Num == a.Num * b.Num
149+
150+
c = a / b
151+
assert c.Num == a.Num // b.Num
152+
153+
c = a & b
154+
assert c.Num == a.Num & b.Num
155+
156+
c = a | b
157+
assert c.Num == a.Num | b.Num
158+
159+
c = a ^ b
160+
assert c.Num == a.Num ^ b.Num
161+
");
162+
}
163+
[Test]
164+
public void BitOperatorOverloads()
165+
{
166+
string name = string.Format("{0}.{1}",
167+
typeof(OperableObject).DeclaringType.Name,
168+
typeof(OperableObject).Name);
169+
string module = MethodBase.GetCurrentMethod().DeclaringType.Namespace;
170+
171+
PythonEngine.Exec($@"
172+
from {module} import *
173+
cls = {name}
174+
a = cls(2)
175+
b = cls(10)
176+
177+
c = a << b.Num
178+
assert c.Num == a.Num << b.Num
179+
180+
c = a >> b.Num
181+
assert c.Num == a.Num >> b.Num
182+
");
183+
}
184+
}
185+
}

src/runtime/classmanager.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,10 @@ private static ClassInfo GetClassInfo(Type type)
470470

471471
ob = new MethodObject(type, name, mlist);
472472
ci.members[name] = ob;
473+
if (OperatorMethod.IsOperatorMethod(name))
474+
{
475+
ci.members[OperatorMethod.GetPyMethodName(name)] = ob;
476+
}
473477
}
474478

475479
if (ci.indexer == null && type.IsClass)

src/runtime/methodbinder.cs

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -342,15 +342,34 @@ internal Binding Bind(IntPtr inst, IntPtr args, IntPtr kw, MethodBase info, Meth
342342
bool paramsArray;
343343
int kwargsMatched;
344344
int defaultsNeeded;
345-
345+
bool isOperator = OperatorMethod.IsOperatorMethod(mi); // e.g. op_Addition is defined for OperableObject
346346
if (!MatchesArgumentCount(pynargs, pi, kwargDict, out paramsArray, out defaultArgList, out kwargsMatched, out defaultsNeeded))
347347
{
348-
continue;
348+
if (isOperator)
349+
{
350+
defaultArgList = null;
351+
}
352+
else { continue; }
349353
}
350354
var outs = 0;
355+
int clrnargs = pi.Length;
356+
isOperator = isOperator && pynargs == clrnargs - 1; // Handle mismatched arg numbers due to Python operator being bound.
351357
var margs = TryConvertArguments(pi, paramsArray, args, pynargs, kwargDict, defaultArgList,
352-
needsResolution: _methods.Length > 1,
358+
needsResolution: _methods.Length > 1, // If there's more than one possible match.
359+
isOperator: isOperator,
353360
outs: out outs);
361+
if (isOperator)
362+
{
363+
if (inst != IntPtr.Zero)
364+
{
365+
var co = ManagedType.GetManagedObject(inst) as CLRObject;
366+
if (co == null)
367+
{
368+
break;
369+
}
370+
margs[0] = co.inst;
371+
}
372+
}
354373

355374
if (margs == null)
356375
{
@@ -474,13 +493,15 @@ static IntPtr HandleParamsArray(IntPtr args, int arrayStart, int pyArgCount, out
474493
/// <param name="kwargDict">Dictionary of keyword argument name to python object pointer</param>
475494
/// <param name="defaultArgList">A list of default values for omitted parameters</param>
476495
/// <param name="needsResolution"><c>true</c>, if overloading resolution is required</param>
496+
/// <param name="isOperator"><c>true</c>, if is operator method</param>
477497
/// <param name="outs">Returns number of output parameters</param>
478498
/// <returns>An array of .NET arguments, that can be passed to a method.</returns>
479499
static object[] TryConvertArguments(ParameterInfo[] pi, bool paramsArray,
480500
IntPtr args, int pyArgCount,
481501
Dictionary<string, IntPtr> kwargDict,
482502
ArrayList defaultArgList,
483503
bool needsResolution,
504+
bool isOperator,
484505
out int outs)
485506
{
486507
outs = 0;
@@ -519,6 +540,12 @@ static object[] TryConvertArguments(ParameterInfo[] pi, bool paramsArray,
519540
op = Runtime.PyTuple_GetItem(args, paramIndex);
520541
}
521542
}
543+
if (isOperator && paramIndex == 0)
544+
{
545+
// After we've obtained the first argument from Python, we need to skip the first argument of the CLR
546+
// because operator method is a bound method in Python
547+
paramIndex++; // Leave the first .NET param as null (margs).
548+
}
522549

523550
bool isOut;
524551
if (!TryConvertArgument(op, parameter.ParameterType, needsResolution, out margs[paramIndex], out isOut))
@@ -543,6 +570,15 @@ static object[] TryConvertArguments(ParameterInfo[] pi, bool paramsArray,
543570
return margs;
544571
}
545572

573+
/// <summary>
574+
/// Try to convert a Python argument object to a managed CLR type.
575+
/// </summary>
576+
/// <param name="op">Pointer to the object at a particular parameter.</param>
577+
/// <param name="parameterType">That parameter's managed type.</param>
578+
/// <param name="needsResolution">There are multiple overloading methods that need resolution.</param>
579+
/// <param name="arg">Converted argument.</param>
580+
/// <param name="isOut">Whether the CLR type is passed by reference.</param>
581+
/// <returns></returns>
546582
static bool TryConvertArgument(IntPtr op, Type parameterType, bool needsResolution,
547583
out object arg, out bool isOut)
548584
{
@@ -633,7 +669,17 @@ static Type TryComputeClrArgumentType(Type parameterType, IntPtr argument, bool
633669

634670
return clrtype;
635671
}
636-
672+
/// <summary>
673+
/// Check whether the number of Python and .NET arguments match, and compute additional arg information.
674+
/// </summary>
675+
/// <param name="positionalArgumentCount">Number of positional args passed from Python.</param>
676+
/// <param name="parameters">Parameters of the specified .NET method.</param>
677+
/// <param name="kwargDict">Keyword args passed from Python.</param>
678+
/// <param name="paramsArray">True if the final param of the .NET method is an array (`params` keyword).</param>
679+
/// <param name="defaultArgList">List of default values for arguments.</param>
680+
/// <param name="kwargsMatched">Number of kwargs from Python that are also present in the .NET method.</param>
681+
/// <param name="defaultsNeeded">Number of non-null defaultsArgs.</param>
682+
/// <returns></returns>
637683
static bool MatchesArgumentCount(int positionalArgumentCount, ParameterInfo[] parameters,
638684
Dictionary<string, IntPtr> kwargDict,
639685
out bool paramsArray,

src/runtime/native/ITypeOffsets.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,16 @@ interface ITypeOffsets
1515
int mp_subscript { get; }
1616
int name { get; }
1717
int nb_add { get; }
18+
int nb_subtract { get; }
19+
int nb_multiply { get; }
20+
int nb_true_divide { get; }
21+
int nb_and { get; }
22+
int nb_or { get; }
23+
int nb_xor { get; }
24+
int nb_lshift { get; }
25+
int nb_rshift { get; }
26+
int nb_remainder { get; }
27+
int nb_invert { get; }
1828
int nb_inplace_add { get; }
1929
int nb_inplace_subtract { get; }
2030
int ob_size { get; }

src/runtime/native/TypeOffset.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,16 @@ static partial class TypeOffset
2222
internal static int mp_subscript { get; private set; }
2323
internal static int name { get; private set; }
2424
internal static int nb_add { get; private set; }
25+
internal static int nb_subtract { get; private set; }
26+
internal static int nb_multiply { get; private set; }
27+
internal static int nb_true_divide { get; private set; }
28+
internal static int nb_and { get; private set; }
29+
internal static int nb_or { get; private set; }
30+
internal static int nb_xor { get; private set; }
31+
internal static int nb_lshift { get; private set; }
32+
internal static int nb_rshift { get; private set; }
33+
internal static int nb_remainder { get; private set; }
34+
internal static int nb_invert { get; private set; }
2535
internal static int nb_inplace_add { get; private set; }
2636
internal static int nb_inplace_subtract { get; private set; }
2737
internal static int ob_size { get; private set; }

0 commit comments

Comments
 (0)