Skip to content

Commit 789c66a

Browse files
committed
downcastasble
1 parent cee18f5 commit 789c66a

File tree

13 files changed

+54
-45
lines changed

13 files changed

+54
-45
lines changed

.cspell.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
"dedentations",
6161
"dedents",
6262
"deduped",
63+
"downcastable",
6364
"downcasted",
6465
"dumpable",
6566
"emscripten",

stdlib/src/array.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -667,7 +667,7 @@ mod array {
667667
ArrayContentType::from_char(spec).map_err(|err| vm.new_value_error(err))?;
668668

669669
if let OptionalArg::Present(init) = init {
670-
if let Some(init) = init.payload::<Self>() {
670+
if let Some(init) = init.downcast_ref::<Self>() {
671671
match (spec, init.read().typecode()) {
672672
(spec, ch) if spec == ch => array.frombytes(&init.get_bytes()),
673673
(spec, 'u') => {
@@ -681,7 +681,7 @@ mod array {
681681
}
682682
}
683683
}
684-
} else if let Some(wtf8) = init.payload::<PyStr>() {
684+
} else if let Some(wtf8) = init.downcast_ref::<PyStr>() {
685685
if spec == 'u' {
686686
let bytes = Self::_unicode_to_wchar_bytes(wtf8.as_wtf8(), array.itemsize());
687687
array.frombytes_move(bytes);
@@ -690,7 +690,7 @@ mod array {
690690
"cannot use a str to initialize an array with typecode '{spec}'"
691691
)));
692692
}
693-
} else if init.payload_is::<PyBytes>() || init.payload_is::<PyByteArray>() {
693+
} else if init.downcastable::<PyBytes>() || init.downcastable::<PyByteArray>() {
694694
init.try_bytes_like(vm, |x| array.frombytes(x))?;
695695
} else if let Ok(iter) = ArgIterable::try_from_object(vm, init.clone()) {
696696
for obj in iter.iter(vm)? {
@@ -765,7 +765,7 @@ mod array {
765765
let mut w = zelf.try_resizable(vm)?;
766766
if zelf.is(&obj) {
767767
w.imul(2, vm)
768-
} else if let Some(array) = obj.payload::<Self>() {
768+
} else if let Some(array) = obj.downcast_ref::<Self>() {
769769
w.iadd(&array.read(), vm)
770770
} else {
771771
let iter = ArgIterable::try_from_object(vm, obj)?;
@@ -1013,7 +1013,7 @@ mod array {
10131013
cloned = zelf.read().clone();
10141014
&cloned
10151015
} else {
1016-
match value.payload::<Self>() {
1016+
match value.downcast_ref::<Self>() {
10171017
Some(array) => {
10181018
guard = array.read();
10191019
&*guard
@@ -1059,7 +1059,7 @@ mod array {
10591059

10601060
#[pymethod]
10611061
fn __add__(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
1062-
if let Some(other) = other.payload::<Self>() {
1062+
if let Some(other) = other.downcast_ref::<Self>() {
10631063
self.read()
10641064
.add(&other.read(), vm)
10651065
.map(|array| Self::from(array).into_ref(&vm.ctx))
@@ -1079,7 +1079,7 @@ mod array {
10791079
) -> PyResult<PyRef<Self>> {
10801080
if zelf.is(&other) {
10811081
zelf.try_resizable(vm)?.imul(2, vm)?;
1082-
} else if let Some(other) = other.payload::<Self>() {
1082+
} else if let Some(other) = other.downcast_ref::<Self>() {
10831083
zelf.try_resizable(vm)?.iadd(&other.read(), vm)?;
10841084
} else {
10851085
return Err(vm.new_type_error(format!(

stdlib/src/select.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,7 @@ mod decl {
350350
fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<Self> {
351351
let timeout = if vm.is_none(&obj) {
352352
None
353-
} else if let Some(float) = obj.payload::<PyFloat>() {
353+
} else if let Some(float) = obj.downcast_ref::<PyFloat>() {
354354
let float = float.to_f64();
355355
if float.is_nan() {
356356
return Err(vm.new_value_error("Invalid value NaN (not a number)"));

stdlib/src/sqlite.rs

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -535,7 +535,7 @@ mod _sqlite {
535535
let access = ptr_to_str(access, vm)?;
536536

537537
let val = callable.call((action, arg1, arg2, db_name, access), vm)?;
538-
let Some(val) = val.payload::<PyInt>() else {
538+
let Some(val) = val.downcast_ref::<PyInt>() else {
539539
return Ok(SQLITE_DENY);
540540
};
541541
val.try_to_primitive::<c_int>(vm)
@@ -1897,18 +1897,18 @@ mod _sqlite {
18971897
Ok(self
18981898
.description
18991899
.iter()
1900-
.map(|x| x.payload::<PyTuple>().unwrap().as_slice()[0].clone())
1900+
.map(|x| x.downcast_ref::<PyTuple>().unwrap().as_slice()[0].clone())
19011901
.collect())
19021902
}
19031903

19041904
fn subscript(&self, needle: &PyObject, vm: &VirtualMachine) -> PyResult {
1905-
if let Some(i) = needle.payload::<PyInt>() {
1905+
if let Some(i) = needle.downcast_ref::ref::<PyInt>() {
19061906
let i = i.try_to_primitive::<isize>(vm)?;
19071907
self.data.getitem_by_index(vm, i)
1908-
} else if let Some(name) = needle.payload::<PyStr>() {
1908+
} else if let Some(name) = needle.downcast_ref::<PyStr>() {
19091909
for (obj, i) in self.description.iter().zip(0..) {
1910-
let obj = &obj.payload::<PyTuple>().unwrap().as_slice()[0];
1911-
let Some(obj) = obj.payload::<PyStr>() else {
1910+
let obj = &obj.downcast_ref::<PyTuple>().unwrap().as_slice()[0];
1911+
let Some(obj) = obj.downcast_ref::<PyStr>() else {
19121912
break;
19131913
};
19141914
let a_iter = name.as_str().chars().flat_map(|x| x.to_uppercase());
@@ -1919,7 +1919,7 @@ mod _sqlite {
19191919
}
19201920
}
19211921
Err(vm.new_index_error("No item with that key"))
1922-
} else if let Some(slice) = needle.payload::<PySlice>() {
1922+
} else if let Some(slice) = needle.downcast_ref::<PySlice>() {
19231923
let list = self.data.getitem_by_slice(vm, slice.to_saturated(vm)?)?;
19241924
Ok(vm.ctx.new_tuple(list).into())
19251925
} else {
@@ -1962,7 +1962,7 @@ mod _sqlite {
19621962
vm: &VirtualMachine,
19631963
) -> PyResult<PyComparisonValue> {
19641964
op.eq_only(|| {
1965-
if let Some(other) = other.payload::<Self>() {
1965+
if let Some(other) = other.downcast_ref::<Self>() {
19661966
let eq = vm
19671967
.bool_eq(zelf.description.as_object(), other.description.as_object())?
19681968
&& vm.bool_eq(zelf.data.as_object(), other.data.as_object())?;
@@ -2179,7 +2179,7 @@ mod _sqlite {
21792179
let mut byte: u8 = 0;
21802180
let ret = inner.blob.read_single(&mut byte, index);
21812181
self.check(ret, vm).map(|_| vm.ctx.new_int(byte).into())
2182-
} else if let Some(slice) = needle.payload::<PySlice>() {
2182+
} else if let Some(slice) = needle.downcast_ref::<PySlice>() {
21832183
let blob_len = inner.blob.bytes();
21842184
let slice = slice.to_saturated(vm)?;
21852185
let (range, step, length) = slice.adjust_indices(blob_len as usize);
@@ -2220,7 +2220,7 @@ mod _sqlite {
22202220
let inner = self.inner(vm)?;
22212221

22222222
if let Some(index) = needle.try_index_opt(vm) {
2223-
let Some(value) = value.payload::<PyInt>() else {
2223+
let Some(value) = value.downcast_ref::<PyInt>() else {
22242224
return Err(vm.new_type_error(format!(
22252225
"'{}' object cannot be interpreted as an integer",
22262226
value.class()
@@ -2232,7 +2232,7 @@ mod _sqlite {
22322232
Self::expect_write(blob_len, 1, index, vm)?;
22332233
let ret = inner.blob.write_single(value, index);
22342234
self.check(ret, vm)
2235-
} else if let Some(_slice) = needle.payload::<PySlice>() {
2235+
} else if let Some(_slice) = needle.downcast_ref::<PySlice>() {
22362236
Err(vm.new_not_implemented_error("Blob slice assignment is not implemented"))
22372237
// let blob_len = inner.blob.bytes();
22382238
// let slice = slice.to_saturated(vm)?;
@@ -2645,15 +2645,15 @@ mod _sqlite {
26452645

26462646
let ret = if vm.is_none(obj) {
26472647
unsafe { sqlite3_bind_null(self.st, pos) }
2648-
} else if let Some(val) = obj.payload::<PyInt>() {
2648+
} else if let Some(val) = obj.downcast_ref::<PyInt>() {
26492649
let val = val.try_to_primitive::<i64>(vm).map_err(|_| {
26502650
vm.new_overflow_error("Python int too large to convert to SQLite INTEGER")
26512651
})?;
26522652
unsafe { sqlite3_bind_int64(self.st, pos, val) }
2653-
} else if let Some(val) = obj.payload::<PyFloat>() {
2653+
} else if let Some(val) = obj.downcast_ref::<PyFloat>() {
26542654
let val = val.to_f64();
26552655
unsafe { sqlite3_bind_double(self.st, pos, val) }
2656-
} else if let Some(val) = obj.payload::<PyStr>() {
2656+
} else if let Some(val) = obj.downcast_ref::<PyStr>() {
26572657
let (ptr, len) = str_to_ptr_len(val, vm)?;
26582658
unsafe { sqlite3_bind_text(self.st, pos, ptr, len, SQLITE_TRANSIENT()) }
26592659
} else if let Ok(buffer) = PyBuffer::try_from_borrowed_object(vm, obj) {
@@ -2900,11 +2900,11 @@ mod _sqlite {
29002900
unsafe {
29012901
if vm.is_none(val) {
29022902
sqlite3_result_null(self.ctx)
2903-
} else if let Some(val) = val.payload::<PyInt>() {
2903+
} else if let Some(val) = val.downcast_ref::<PyInt>() {
29042904
sqlite3_result_int64(self.ctx, val.try_to_primitive(vm)?)
2905-
} else if let Some(val) = val.payload::<PyFloat>() {
2905+
} else if let Some(val) = val.downcast_ref::<PyFloat>() {
29062906
sqlite3_result_double(self.ctx, val.to_f64())
2907-
} else if let Some(val) = val.payload::<PyStr>() {
2907+
} else if let Some(val) = val.downcast_ref::<PyStr>() {
29082908
let (ptr, len) = str_to_ptr_len(val, vm)?;
29092909
sqlite3_result_text(self.ctx, ptr, len, SQLITE_TRANSIENT())
29102910
} else if let Ok(buffer) = PyBuffer::try_from_borrowed_object(vm, val) {

vm/src/builtins/asyncgenerator.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,9 @@ impl PyAsyncGenAThrow {
344344
let ret = self.ag.inner.send(self.ag.as_object(), val, vm);
345345
if self.aclose {
346346
match ret {
347-
Ok(PyIterReturn::Return(v)) if v.payload_is::<PyAsyncGenWrappedValue>() => {
347+
Ok(PyIterReturn::Return(v))
348+
if v.downcastable::<PyAsyncGenWrappedValue>() =>
349+
{
348350
Err(self.yield_close(vm))
349351
}
350352
other => other
@@ -392,7 +394,7 @@ impl PyAsyncGenAThrow {
392394

393395
fn ignored_close(&self, res: &PyResult<PyIterReturn>) -> bool {
394396
res.as_ref().is_ok_and(|v| match v {
395-
PyIterReturn::Return(obj) => obj.payload_is::<PyAsyncGenWrappedValue>(),
397+
PyIterReturn::Return(obj) => obj.downcastable::<PyAsyncGenWrappedValue>(),
396398
PyIterReturn::StopIteration(_) => false,
397399
})
398400
}

vm/src/builtins/memory.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -909,13 +909,13 @@ impl TryFromObject for SubscriptNeedle {
909909
// TODO: number protocol
910910
if let Some(i) = obj.payload::<PyInt>() {
911911
Ok(Self::Index(i.try_to_primitive(vm)?))
912-
} else if obj.payload_is::<PySlice>() {
912+
} else if obj.downcastable::<PySlice>() {
913913
Ok(Self::Slice(unsafe { obj.downcast_unchecked::<PySlice>() }))
914914
} else if let Ok(i) = obj.try_index(vm) {
915915
Ok(Self::Index(i.try_to_primitive(vm)?))
916916
} else {
917917
if let Some(tuple) = obj.payload::<PyTuple>() {
918-
if tuple.iter().all(|x| x.payload_is::<PyInt>()) {
918+
if tuple.iter().all(|x| x.downcastable::<PyInt>()) {
919919
let v = tuple
920920
.iter()
921921
.map(|x| {
@@ -924,7 +924,7 @@ impl TryFromObject for SubscriptNeedle {
924924
})
925925
.try_collect()?;
926926
return Ok(Self::MultiIndex(v));
927-
} else if tuple.iter().all(|x| x.payload_is::<PySlice>()) {
927+
} else if tuple.iter().all(|x| x.downcastable::<PySlice>()) {
928928
return Err(vm.new_not_implemented_error(
929929
"multi-dimensional slicing is not implemented",
930930
));

vm/src/convert/transmute_from.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ unsafe impl<T: PyPayload> TransmuteFromObject for PyRef<T> {
1717
fn check(vm: &VirtualMachine, obj: &PyObject) -> PyResult<()> {
1818
let class = T::class(&vm.ctx);
1919
if obj.fast_isinstance(class) {
20-
if obj.payload_is::<T>() {
20+
if obj.downcastable::<T>() {
2121
Ok(())
2222
} else {
2323
Err(vm.new_downcast_runtime_error(class, obj))

vm/src/frame.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,7 +1077,7 @@ impl ExecutingFrame<'_> {
10771077
}
10781078
bytecode::Instruction::GetAwaitable => {
10791079
let awaited_obj = self.pop_value();
1080-
let awaitable = if awaited_obj.payload_is::<PyCoroutine>() {
1080+
let awaitable = if awaited_obj.downcastable::<PyCoroutine>() {
10811081
awaited_obj
10821082
} else {
10831083
let await_method = vm.get_method_or_type_error(
@@ -2337,7 +2337,7 @@ impl fmt::Debug for Frame {
23372337
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23382338
let state = self.state.lock();
23392339
let stack_str = state.stack.iter().fold(String::new(), |mut s, elem| {
2340-
if elem.payload_is::<Self>() {
2340+
if elem.downcastable::<Self>() {
23412341
s.push_str("\n > {frame}");
23422342
} else {
23432343
std::fmt::write(&mut s, format_args!("\n > {elem:?}")).unwrap();

vm/src/macros.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ macro_rules! match_class {
160160
// An arm taken when the object is an instance of the specified built-in
161161
// class.
162162
(match ($obj:expr) { $class:ty => $expr:expr, $($rest:tt)* }) => {
163-
if $obj.payload_is::<$class>() {
163+
if $obj.downcastable::<$class>() {
164164
$expr
165165
} else {
166166
$crate::match_class!(match ($obj) { $($rest)* })

vm/src/object/core.rs

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -528,22 +528,28 @@ impl PyObjectRef {
528528
Self { ptr }
529529
}
530530

531+
#[inline(always)]
532+
pub fn downcastable<T: PyObjectPayload>(&self) -> bool {
533+
self.payload_is::<T>()
534+
}
535+
531536
/// Attempt to downcast this reference to a subclass.
532537
///
533538
/// If the downcast fails, the original ref is returned in as `Err` so
534539
/// another downcast can be attempted without unnecessary cloning.
535540
#[inline(always)]
536541
pub fn downcast<T: PyObjectPayload>(self) -> Result<PyRef<T>, Self> {
537-
if self.payload_is::<T>() {
542+
if self.downcastable::<T>() {
538543
Ok(unsafe { self.downcast_unchecked() })
539544
} else {
540545
Err(self)
541546
}
542547
}
543548

549+
/// Attempt to downcast this reference to a subclass.
544550
#[inline(always)]
545551
pub fn downcast_ref<T: PyObjectPayload>(&self) -> Option<&Py<T>> {
546-
if self.payload_is::<T>() {
552+
if self.downcastable::<T>() {
547553
// SAFETY: just checked that the payload is T, and PyRef is repr(transparent) over
548554
// PyObjectRef
549555
Some(unsafe { &*(self as *const Self as *const PyRef<T>) })
@@ -570,7 +576,7 @@ impl PyObjectRef {
570576
/// T must be the exact payload type
571577
#[inline(always)]
572578
pub unsafe fn downcast_unchecked_ref<T: PyObjectPayload>(&self) -> &Py<T> {
573-
debug_assert!(self.payload_is::<T>());
579+
debug_assert!(self.downcastable::<T>());
574580
// SAFETY: requirements forwarded from caller
575581
unsafe { &*(self as *const Self as *const PyRef<T>) }
576582
}
@@ -589,10 +595,10 @@ impl PyObjectRef {
589595
if self.class().is(T::class(&vm.ctx)) {
590596
// TODO: is this always true?
591597
assert!(
592-
self.payload_is::<T>(),
598+
self.downcastable::<T>(),
593599
"obj.__class__ is T::class() but payload is not T"
594600
);
595-
// SAFETY: just asserted that payload_is::<T>()
601+
// SAFETY: just asserted that downcastable::<T>()
596602
Ok(unsafe { PyRefExact::new_unchecked(PyRef::from_obj_unchecked(self)) })
597603
} else {
598604
Err(self)
@@ -733,7 +739,7 @@ impl PyObject {
733739

734740
#[inline(always)]
735741
pub fn downcast_ref<T: PyObjectPayload>(&self) -> Option<&Py<T>> {
736-
if self.payload_is::<T>() {
742+
if self.downcastable::<T>() {
737743
// SAFETY: just checked that the payload is T, and PyRef is repr(transparent) over
738744
// PyObjectRef
739745
Some(unsafe { self.downcast_unchecked_ref::<T>() })
@@ -756,7 +762,7 @@ impl PyObject {
756762
/// T must be the exact payload type
757763
#[inline(always)]
758764
pub unsafe fn downcast_unchecked_ref<T: PyObjectPayload>(&self) -> &Py<T> {
759-
debug_assert!(self.payload_is::<T>());
765+
debug_assert!(self.downcastable::<T>());
760766
// SAFETY: requirements forwarded from caller
761767
unsafe { &*(self as *const Self as *const Py<T>) }
762768
}
@@ -1045,7 +1051,7 @@ impl<T: PyObjectPayload> PyRef<T> {
10451051
/// Safety: payload type of `obj` must be `T`
10461052
#[inline(always)]
10471053
unsafe fn from_obj_unchecked(obj: PyObjectRef) -> Self {
1048-
debug_assert!(obj.payload_is::<T>());
1054+
debug_assert!(obj.downcast_ref::<T>().is_some());
10491055
let obj = ManuallyDrop::new(obj);
10501056
Self {
10511057
ptr: obj.ptr.cast(),

0 commit comments

Comments
 (0)