Skip to content

Commit bd54d28

Browse files
committed
Fix for Python3, and add Oid cmp tests
1 parent 406c317 commit bd54d28

File tree

2 files changed

+35
-5
lines changed

2 files changed

+35
-5
lines changed

src/oid.c

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -179,10 +179,31 @@ Oid_init(Oid *self, PyObject *args, PyObject *kw)
179179
}
180180

181181

182-
int
183-
Oid_compare(PyObject *o1, PyObject *o2)
182+
PyObject *
183+
Oid_richcompare(PyObject *o1, PyObject *o2, int op)
184184
{
185-
return git_oid_cmp(&((Oid*)o1)->oid, &((Oid*)o2)->oid);
185+
PyObject *res;
186+
187+
/* Support only equual (and not-equal). */
188+
if (op != Py_EQ && op != Py_NE) {
189+
PyErr_SetNone(PyExc_TypeError);
190+
return NULL;
191+
}
192+
193+
/* Comparing to something else than an Oid is not supported. */
194+
if (!PyObject_TypeCheck(o2, &OidType)) {
195+
Py_INCREF(Py_NotImplemented);
196+
return Py_NotImplemented;
197+
}
198+
199+
/* Ok go. */
200+
if (git_oid_cmp(&((Oid*)o1)->oid, &((Oid*)o2)->oid) == 0)
201+
res = (op == Py_EQ) ? Py_True : Py_False;
202+
else
203+
res = (op == Py_EQ) ? Py_False : Py_True;
204+
205+
Py_INCREF(res);
206+
return res;
186207
}
187208

188209

@@ -220,7 +241,7 @@ PyTypeObject OidType = {
220241
0, /* tp_print */
221242
0, /* tp_getattr */
222243
0, /* tp_setattr */
223-
(cmpfunc)Oid_compare, /* tp_compare */
244+
0, /* tp_compare */
224245
0, /* tp_repr */
225246
0, /* tp_as_number */
226247
0, /* tp_as_sequence */
@@ -235,7 +256,7 @@ PyTypeObject OidType = {
235256
Oid__doc__, /* tp_doc */
236257
0, /* tp_traverse */
237258
0, /* tp_clear */
238-
0, /* tp_richcompare */
259+
(richcmpfunc)Oid_richcompare, /* tp_richcompare */
239260
0, /* tp_weaklistoffset */
240261
0, /* tp_iter */
241262
0, /* tp_iternext */

test/test_oid.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,21 @@ def test_long(self):
6767

6868
def test_cmp(self):
6969
oid1 = Oid(raw=RAW)
70+
71+
# Equal
7072
oid2 = Oid(hex=HEX)
7173
self.assertEqual(oid1, oid2)
7274

75+
# Not equal
7376
oid2 = Oid(hex="15b648aec6ed045b5ca6f57f8b7831a8b4757299")
7477
self.assertNotEqual(oid1, oid2)
7578

79+
# Other
80+
with self.assertRaises(TypeError): oid1 < oid2
81+
with self.assertRaises(TypeError): oid1 <= oid2
82+
with self.assertRaises(TypeError): oid1 > oid2
83+
with self.assertRaises(TypeError): oid1 >= oid2
84+
7685

7786
if __name__ == '__main__':
7887
unittest.main()

0 commit comments

Comments
 (0)