@@ -19,15 +19,17 @@ use proc_macro2::{Span, TokenStream as TokenStream2};
1919use quote:: quote;
2020use rustpython_bytecode:: bytecode:: CodeObject ;
2121use rustpython_compiler:: compile;
22+ use std:: collections:: HashMap ;
2223use std:: env;
2324use std:: fs;
24- use std:: path:: PathBuf ;
25+ use std:: path:: { Path , PathBuf } ;
2526use syn:: parse:: { Parse , ParseStream , Result as ParseResult } ;
26- use syn:: { self , parse2, Lit , LitByteStr , Meta , Token } ;
27+ use syn:: { self , parse2, Lit , LitByteStr , LitStr , Meta , Token } ;
2728
2829enum CompilationSourceKind {
2930 File ( PathBuf ) ,
3031 SourceCode ( String ) ,
32+ Dir ( PathBuf ) ,
3133}
3234
3335struct CompilationSource {
@@ -36,14 +38,22 @@ struct CompilationSource {
3638}
3739
3840impl CompilationSource {
39- fn compile ( self , mode : & compile:: Mode , module_name : String ) -> Result < CodeObject , Diagnostic > {
40- let compile = |source| {
41- compile:: compile ( source, mode, module_name, 0 ) . map_err ( |err| {
42- Diagnostic :: spans_error ( self . span , format ! ( "Compile error: {}" , err) )
43- } )
44- } ;
45-
46- match & self . kind {
41+ fn compile_string (
42+ & self ,
43+ source : & str ,
44+ mode : & compile:: Mode ,
45+ module_name : String ,
46+ ) -> Result < CodeObject , Diagnostic > {
47+ compile:: compile ( source, mode, module_name, 0 )
48+ . map_err ( |err| Diagnostic :: spans_error ( self . span , format ! ( "Compile error: {}" , err) ) )
49+ }
50+
51+ fn compile (
52+ & self ,
53+ mode : & compile:: Mode ,
54+ module_name : String ,
55+ ) -> Result < HashMap < String , CodeObject > , Diagnostic > {
56+ Ok ( match & self . kind {
4757 CompilationSourceKind :: File ( rel_path) => {
4858 let mut path = PathBuf :: from (
4959 env:: var_os ( "CARGO_MANIFEST_DIR" ) . expect ( "CARGO_MANIFEST_DIR is not present" ) ,
@@ -55,10 +65,59 @@ impl CompilationSource {
5565 format ! ( "Error reading file {:?}: {}" , path, err) ,
5666 )
5767 } ) ?;
58- compile ( & source)
68+ hashmap ! { module_name. clone( ) => self . compile_string( & source, mode, module_name. clone( ) ) ?}
69+ }
70+ CompilationSourceKind :: SourceCode ( code) => {
71+ hashmap ! { module_name. clone( ) => self . compile_string( code, mode, module_name. clone( ) ) ?}
72+ }
73+ CompilationSourceKind :: Dir ( rel_path) => {
74+ let mut path = PathBuf :: from (
75+ env:: var_os ( "CARGO_MANIFEST_DIR" ) . expect ( "CARGO_MANIFEST_DIR is not present" ) ,
76+ ) ;
77+ path. push ( rel_path) ;
78+ self . compile_dir ( & path, String :: new ( ) , mode) ?
79+ }
80+ } )
81+ }
82+
83+ fn compile_dir (
84+ & self ,
85+ path : & Path ,
86+ parent : String ,
87+ mode : & compile:: Mode ,
88+ ) -> Result < HashMap < String , CodeObject > , Diagnostic > {
89+ let mut code_map = HashMap :: new ( ) ;
90+ let paths = fs:: read_dir ( & path) . map_err ( |err| {
91+ Diagnostic :: spans_error ( self . span , format ! ( "Error listing dir {:?}: {}" , path, err) )
92+ } ) ?;
93+ for path in paths {
94+ let path = path. map_err ( |err| {
95+ Diagnostic :: spans_error ( self . span , format ! ( "Failed to list file: {}" , err) )
96+ } ) ?;
97+ let path = path. path ( ) ;
98+ let file_name = path. file_name ( ) . unwrap ( ) . to_str ( ) . unwrap ( ) ;
99+ if path. is_dir ( ) {
100+ code_map. extend ( self . compile_dir (
101+ & path,
102+ format ! ( "{}{}." , parent, file_name) ,
103+ mode,
104+ ) ?) ;
105+ } else if file_name. ends_with ( ".py" ) {
106+ let source = fs:: read_to_string ( & path) . map_err ( |err| {
107+ Diagnostic :: spans_error (
108+ self . span ,
109+ format ! ( "Error reading file {:?}: {}" , path, err) ,
110+ )
111+ } ) ?;
112+ let file_name_splitte: Vec < & str > = file_name. splitn ( 2 , '.' ) . collect ( ) ;
113+ let module_name = format ! ( "{}{}" , parent, file_name_splitte[ 0 ] ) ;
114+ code_map. insert (
115+ module_name. clone ( ) ,
116+ self . compile_string ( & source, mode, module_name) ?,
117+ ) ;
59118 }
60- CompilationSourceKind :: SourceCode ( code) => compile ( code) ,
61119 }
120+ Ok ( code_map)
62121 }
63122}
64123
@@ -69,7 +128,7 @@ struct PyCompileInput {
69128}
70129
71130impl PyCompileInput {
72- fn compile ( & self ) -> Result < CodeObject , Diagnostic > {
131+ fn compile ( & self ) -> Result < HashMap < String , CodeObject > , Diagnostic > {
73132 let mut module_name = None ;
74133 let mut mode = None ;
75134 let mut source: Option < CompilationSource > = None ;
@@ -122,6 +181,16 @@ impl PyCompileInput {
122181 kind : CompilationSourceKind :: File ( path) ,
123182 span : extract_spans ( & name_value) . unwrap ( ) ,
124183 } ) ;
184+ } else if name_value. ident == "dir" {
185+ assert_source_empty ( & source) ?;
186+ let path = match & name_value. lit {
187+ Lit :: Str ( s) => PathBuf :: from ( s. value ( ) ) ,
188+ _ => bail_span ! ( name_value. lit, "source must be a string" ) ,
189+ } ;
190+ source = Some ( CompilationSource {
191+ kind : CompilationSourceKind :: Dir ( path) ,
192+ span : extract_spans ( & name_value) . unwrap ( ) ,
193+ } ) ;
125194 }
126195 }
127196 }
@@ -154,16 +223,23 @@ impl Parse for PyCompileInput {
154223pub fn impl_py_compile_bytecode ( input : TokenStream2 ) -> Result < TokenStream2 , Diagnostic > {
155224 let input: PyCompileInput = parse2 ( input) ?;
156225
157- let code_obj = input. compile ( ) ?;
226+ let code_map = input. compile ( ) ?;
158227
159- let bytes = bincode:: serialize ( & code_obj) . expect ( "Failed to serialize" ) ;
160- let bytes = LitByteStr :: new ( & bytes, Span :: call_site ( ) ) ;
228+ let modules = code_map. iter ( ) . map ( |( module_name, code_obj) | {
229+ let module_name = LitStr :: new ( & module_name, Span :: call_site ( ) ) ;
230+ let bytes = bincode:: serialize ( & code_obj) . expect ( "Failed to serialize" ) ;
231+ let bytes = LitByteStr :: new ( & bytes, Span :: call_site ( ) ) ;
232+ quote ! { #module_name. into( ) => bincode:: deserialize:: <:: rustpython_vm:: bytecode:: CodeObject >( #bytes)
233+ . expect( "Deserializing CodeObject failed" ) }
234+ } ) ;
161235
162236 let output = quote ! {
163237 ( {
164238 use :: rustpython_vm:: __exports:: bincode;
165- bincode:: deserialize:: <:: rustpython_vm:: bytecode:: CodeObject >( #bytes)
166- . expect( "Deserializing CodeObject failed" )
239+ use :: rustpython_vm:: __exports:: hashmap;
240+ hashmap! {
241+ #( #modules) , *
242+ }
167243 } )
168244 } ;
169245
0 commit comments