Skip to content

Commit b83443c

Browse files
committed
Partially implement async generators
1 parent e97937a commit b83443c

File tree

8 files changed

+378
-17
lines changed

8 files changed

+378
-17
lines changed

compiler/src/compile.rs

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ struct CompileContext {
3939
func: FunctionContext,
4040
}
4141

42-
#[derive(Clone, Copy)]
42+
#[derive(Clone, Copy, PartialEq)]
4343
enum FunctionContext {
4444
NoFunction,
4545
Function,
@@ -562,6 +562,17 @@ impl<O: OutputStream> Compiler<O> {
562562
}
563563
match value {
564564
Some(v) => {
565+
<<<<<<< Updated upstream
566+
=======
567+
if self.ctx.func == FunctionContext::AsyncFunction
568+
&& self.current_output().is_generator()
569+
{
570+
return Err(self.error_loc(
571+
CompileErrorType::AsyncReturnValue,
572+
statement.location.clone(),
573+
));
574+
}
575+
>>>>>>> Stashed changes
565576
self.compile_expression(v)?;
566577
}
567578
None => {
@@ -1543,6 +1554,8 @@ impl<O: OutputStream> Compiler<O> {
15431554
self.set_source_location(expression.location);
15441555

15451556
use ast::ExpressionType::*;
1557+
#[allow(unused_imports)] // not unused, overrides ast::ExpressionType::None
1558+
use Option::None;
15461559
match &expression.node {
15471560
Call {
15481561
function,
@@ -1637,13 +1650,25 @@ impl<O: OutputStream> Compiler<O> {
16371650
self.mark_generator();
16381651
match value {
16391652
Some(expression) => self.compile_expression(expression)?,
1640-
Option::None => self.emit(Instruction::LoadConst {
1653+
None => self.emit(Instruction::LoadConst {
16411654
value: bytecode::Constant::None,
16421655
}),
16431656
};
16441657
self.emit(Instruction::YieldValue);
16451658
}
16461659
Await { value } => {
1660+
if self.ctx.func != FunctionContext::AsyncFunction {
1661+
<<<<<<< Updated upstream
1662+
return Err(CompileError {
1663+
statement: None,
1664+
error: CompileErrorType::InvalidAwait,
1665+
location: self.current_source_location.clone(),
1666+
source_path: None,
1667+
});
1668+
=======
1669+
return Err(self.error(CompileErrorType::InvalidAwait));
1670+
>>>>>>> Stashed changes
1671+
}
16471672
self.compile_expression(value)?;
16481673
self.emit(Instruction::GetAwaitable);
16491674
self.emit(Instruction::LoadConst {
@@ -1652,6 +1677,32 @@ impl<O: OutputStream> Compiler<O> {
16521677
self.emit(Instruction::YieldFrom);
16531678
}
16541679
YieldFrom { value } => {
1680+
match self.ctx.func {
1681+
FunctionContext::NoFunction => {
1682+
<<<<<<< Updated upstream
1683+
return Err(CompileError {
1684+
statement: None,
1685+
error: CompileErrorType::InvalidYieldFrom,
1686+
location: self.current_source_location.clone(),
1687+
source_path: None,
1688+
})
1689+
}
1690+
FunctionContext::AsyncFunction => {
1691+
return Err(CompileError {
1692+
statement: None,
1693+
error: CompileErrorType::AsyncYieldFrom,
1694+
location: self.current_source_location.clone(),
1695+
source_path: None,
1696+
})
1697+
=======
1698+
return Err(self.error(CompileErrorType::InvalidYieldFrom))
1699+
}
1700+
FunctionContext::AsyncFunction => {
1701+
return Err(self.error(CompileErrorType::AsyncYieldFrom))
1702+
>>>>>>> Stashed changes
1703+
}
1704+
FunctionContext::Function => {}
1705+
}
16551706
self.mark_generator();
16561707
self.compile_expression(value)?;
16571708
self.emit(Instruction::GetIter);
@@ -1670,7 +1721,7 @@ impl<O: OutputStream> Compiler<O> {
16701721
value: bytecode::Constant::Boolean { value: false },
16711722
});
16721723
}
1673-
None => {
1724+
ast::ExpressionType::None => {
16741725
self.emit(Instruction::LoadConst {
16751726
value: bytecode::Constant::None,
16761727
});

compiler/src/error.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ pub enum CompileErrorType {
5454
InvalidContinue,
5555
InvalidReturn,
5656
InvalidYield,
57+
InvalidYieldFrom,
58+
InvalidAwait,
59+
AsyncYieldFrom,
5760
}
5861

5962
impl CompileError {
@@ -96,6 +99,9 @@ impl fmt::Display for CompileError {
9699
CompileErrorType::InvalidContinue => "'continue' outside loop".to_owned(),
97100
CompileErrorType::InvalidReturn => "'return' outside function".to_owned(),
98101
CompileErrorType::InvalidYield => "'yield' outside function".to_owned(),
102+
CompileErrorType::InvalidYieldFrom => "'yield from' outside function".to_owned(),
103+
CompileErrorType::InvalidAwait => "'await' outside async function".to_owned(),
104+
CompileErrorType::AsyncYieldFrom => "'yield from' inside async function".to_owned(),
99105
};
100106

101107
if let Some(statement) = &self.statement {

vm/src/frame.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use itertools::Itertools;
77
use crate::bytecode;
88
use crate::exceptions::{self, ExceptionCtor, PyBaseExceptionRef};
99
use crate::function::{single_or_tuple_any, PyFuncArgs};
10+
use crate::obj::objasyncgenerator::PyAsyncGenWrappedValue;
1011
use crate::obj::objbool;
1112
use crate::obj::objcode::PyCodeRef;
1213
use crate::obj::objcoroinner::Coro;
@@ -118,11 +119,15 @@ impl ExecutionResult {
118119
}
119120

120121
/// Turn an ExecutionResult into a PyResult that would be returned from a generator or coroutine
121-
pub fn into_result(self, vm: &VirtualMachine) -> PyResult {
122+
pub fn into_result(self, async_stopiter: bool, vm: &VirtualMachine) -> PyResult {
122123
match self {
123124
ExecutionResult::Yield(value) => Ok(value),
124125
ExecutionResult::Return(value) => {
125-
let stop_iteration = vm.ctx.exceptions.stop_iteration.clone();
126+
let stop_iteration = if async_stopiter {
127+
vm.ctx.exceptions.stop_async_iteration.clone()
128+
} else {
129+
vm.ctx.exceptions.stop_iteration.clone()
130+
};
126131
let args = if vm.is_none(&value) {
127132
vec![]
128133
} else {
@@ -379,6 +384,11 @@ impl Frame {
379384
}
380385
bytecode::Instruction::YieldValue => {
381386
let value = self.pop_value();
387+
let value = if self.code.flags.contains(bytecode::CodeFlags::IS_COROUTINE) {
388+
PyAsyncGenWrappedValue(value).into_ref(vm).into_object()
389+
} else {
390+
value
391+
};
382392
Ok(Some(ExecutionResult::Yield(value)))
383393
}
384394
bytecode::Instruction::YieldFrom => self.execute_yield_from(vm),

vm/src/obj/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
//! This package contains the python basic/builtin types
22
3+
pub mod objasyncgenerator;
34
pub mod objbool;
45
pub mod objbuiltinfunc;
56
pub mod objbytearray;

0 commit comments

Comments
 (0)