|
3 | 3 | use rand::distributions::Distribution; |
4 | 4 | use rand_distr::Normal; |
5 | 5 |
|
6 | | -use crate::function::PyFuncArgs; |
7 | | -use crate::obj::objfloat; |
8 | 6 | use crate::pyobject::{PyObjectRef, PyResult}; |
9 | 7 | use crate::vm::VirtualMachine; |
10 | 8 |
|
11 | 9 | pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { |
12 | 10 | let ctx = &vm.ctx; |
13 | 11 |
|
14 | 12 | py_module!(vm, "random", { |
15 | | - "guass" => ctx.new_rustfunc(random_gauss), |
| 13 | + "gauss" => ctx.new_rustfunc(random_normalvariate), // TODO: is this the same? |
16 | 14 | "normalvariate" => ctx.new_rustfunc(random_normalvariate), |
17 | 15 | "random" => ctx.new_rustfunc(random_random), |
18 | 16 | // "weibull", ctx.new_rustfunc(random_weibullvariate), |
19 | 17 | }) |
20 | 18 | } |
21 | 19 |
|
22 | | -fn random_gauss(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { |
23 | | - // TODO: is this the same? |
24 | | - random_normalvariate(vm, args) |
25 | | -} |
26 | | - |
27 | | -fn random_normalvariate(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { |
28 | | - arg_check!( |
29 | | - vm, |
30 | | - args, |
31 | | - required = [ |
32 | | - (mu, Some(vm.ctx.float_type())), |
33 | | - (sigma, Some(vm.ctx.float_type())) |
34 | | - ] |
35 | | - ); |
36 | | - let mu = objfloat::get_value(mu); |
37 | | - let sigma = objfloat::get_value(sigma); |
| 20 | +fn random_normalvariate(mu: f64, sigma: f64, vm: &VirtualMachine) -> PyResult<f64> { |
38 | 21 | let normal = Normal::new(mu, sigma).map_err(|rand_err| { |
39 | 22 | vm.new_exception( |
40 | 23 | vm.ctx.exceptions.arithmetic_error.clone(), |
41 | 24 | format!("invalid normal distribution: {:?}", rand_err), |
42 | 25 | ) |
43 | 26 | })?; |
44 | 27 | let value = normal.sample(&mut rand::thread_rng()); |
45 | | - let py_value = vm.ctx.new_float(value); |
46 | | - Ok(py_value) |
| 28 | + Ok(value) |
47 | 29 | } |
48 | 30 |
|
49 | | -fn random_random(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { |
50 | | - arg_check!(vm, args); |
51 | | - let value = rand::random::<f64>(); |
52 | | - let py_value = vm.ctx.new_float(value); |
53 | | - Ok(py_value) |
| 31 | +fn random_random(_vm: &VirtualMachine) -> f64 { |
| 32 | + rand::random() |
54 | 33 | } |
55 | 34 |
|
56 | 35 | /* |
|
0 commit comments