Skip to content

Commit ac8facf

Browse files
committed
1. Implement RamTensor::resize
2. ReshapeOperator allocate memory for output tensor
1 parent 2351fbd commit ac8facf

7 files changed

Lines changed: 36 additions & 13 deletions

File tree

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#ifndef _RESHAPE_TEST_H
22
#define _RESHAPE_TEST_H
33

4-
static const float random_input_arr[15] = { 1.8550491333007812, 4.670377254486084, 1.5111958980560303, 4.3641228675842285, 3.122225284576416, 1.1933523416519165, 3.7784199714660645, 4.052943706512451, 1.0156375169754028, 1.4321529865264893, 3.4896273612976074, 1.3438916206359863, 4.172100067138672, 0.589367151260376, 1.9008618593215942 };
5-
static const float ref_output_arr[15] = { 1.8550491333007812, 4.670377254486084, 1.5111958980560303, 4.3641228675842285, 3.122225284576416, 1.1933523416519165, 3.7784199714660645, 4.052943706512451, 1.0156375169754028, 1.4321529865264893, 3.4896273612976074, 1.3438916206359863, 4.172100067138672, 0.589367151260376, 1.9008618593215942 };
4+
static const float random_input_arr[15] = { 3.484638214111328, 2.033799886703491, 3.2437448501586914, 4.783249855041504, 3.497023582458496, 3.511240005493164, 1.558927297592163, 3.7084484100341797, 2.570117712020874, 0.2405869960784912, 1.8713605403900146, 4.19132661819458, 0.6596618890762329, 0.9029078483581543, 0.2223271131515503 };
5+
static const float ref_output_arr[15] = { 3.484638214111328, 2.033799886703491, 3.2437448501586914, 4.783249855041504, 3.497023582458496, 3.511240005493164, 1.558927297592163, 3.7084484100341797, 2.570117712020874, 0.2405869960784912, 1.8713605403900146, 4.19132661819458, 0.6596618890762329, 0.9029078483581543, 0.2223271131515503 };
66

77
#endif // _RESHAPE_TEST_H

TESTS/operators/test_reshape.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33

44
#include "arenaAllocator.hpp"
55
#include "context.hpp"
6-
#include "Reshape.hpp"
76
#include "RamTensor.hpp"
7+
#include "Reshape.hpp"
88
#include "RomTensor.hpp"
99

1010
#include "gtest/gtest.h"
@@ -19,16 +19,17 @@ TEST(Reshape, reshape_test) {
1919
Context::get_default_context()->set_metadata_allocator(&meta_allocator);
2020
Context::get_default_context()->set_ram_data_allocator(&ram_allocator);
2121
Tensor input_tensor = new RomTensor({ 3,5 }, flt, random_input_arr);
22-
Tensor output_tensor = new RamTensor({ 5,3,1 }, flt);
23-
ReshapeOperator<float> op;
22+
Tensor output_tensor = new RamTensor(flt);
23+
24+
ReshapeOperator<float> op({ 5,3,1 });
2425
op
2526
.set_inputs({ { ReshapeOperator<float>::input, input_tensor } })
2627
.set_outputs({ { ReshapeOperator<float>::output, output_tensor } })
2728
.eval();
2829
for (int i = 0; i < 15; ++i) {
2930
EXPECT_NEAR((float) output_tensor(i), ref_output_arr[i], 0.0001);
3031
}
31-
TensorShape target_shape(5, 3, 1);
32+
TensorShape target_shape(5,3,1);
3233
TensorShape output_shape = output_tensor->get_shape();
3334
EXPECT_TRUE(target_shape == output_shape);
3435

python/test_scripts/gen_reshape.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def main(cpp_fname, const_fname):
3030
tensor_input = tf.random.uniform((3, 5), maxval=5, dtype=tf.float32)
3131
np_input = tensor_input.numpy()
3232
new_shape = [5, 3, 1]
33+
op_construct_params.append("{{ {} }}".format(",".join(map(str, new_shape))))
3334
tensor_output = tf.reshape(tensor_input, new_shape)
3435
np_output = tensor_output.numpy()
3536

@@ -44,9 +45,7 @@ def main(cpp_fname, const_fname):
4445
const_var_name="random_input_arr",
4546
),
4647
env.get_template("declare_ram_tensor.cpp").render(
47-
tensor_name="output_tensor",
48-
shape=np_output.shape,
49-
tensor_type_str="float",
48+
tensor_name="output_tensor", tensor_type_str="float",
5049
),
5150
]
5251
)
@@ -55,7 +54,7 @@ def main(cpp_fname, const_fname):
5554
output_names.append("output_tensor")
5655
ref_output_names.append("ref_output_arr")
5756
other_tests_str.append(
58-
f"TensorShape target_shape({ ', '.join(map(str, new_shape)) });\n"
57+
f"TensorShape target_shape({ ','.join(map(str, new_shape)) });\n"
5958
" TensorShape output_shape = output_tensor->get_shape();\n"
6059
" EXPECT_TRUE(target_shape == output_shape);\n"
6160
)
@@ -74,6 +73,7 @@ def main(cpp_fname, const_fname):
7473
declare_tensor_strs=declare_tensor_strs,
7574
op_cls=op_cls,
7675
op_type_signature=op_type_signature,
76+
op_construct_params=op_construct_params,
7777
op_name=op_name,
7878
inputs_str=inputs_str,
7979
outputs_str=outputs_str,
Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
1-
Tensor {{tensor_name}} = new RamTensor({ {%for s in shape%}{{ s }}{{"," if not loop.last}}{%endfor%} }, {{TENSOR_TYPE_MAP[tensor_type_str]}});
1+
{%if shape %}
2+
Tensor {{tensor_name}} = new RamTensor({ {%for s in shape%}{{ s }}{{"," if not loop.last}}{%endfor%} }, {{TENSOR_TYPE_MAP[tensor_type_str]}});
3+
{% else %}
4+
Tensor {{tensor_name}} = new RamTensor({{TENSOR_TYPE_MAP[tensor_type_str]}});
5+
{%endif%}

src/uTensor/ops/Reshape.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,22 @@
77
#include "uTensor_util.hpp"
88
#include "operatorBase.hpp"
99

10+
using std::array;
11+
1012
namespace uTensor {
1113

1214
template <typename Tin>
1315
class ReshapeOperator : public OperatorInterface<1, 1> {
1416
/* reshape input as the shape of output*/
1517
public:
18+
ReshapeOperator(const TensorShape&& shape) : _shape(shape) {}
19+
ReshapeOperator(const TensorShape& shape) : _shape(shape) {}
1620
enum names_in : uint8_t { input };
1721
enum names_out : uint8_t { output };
1822
virtual void compute(){
1923
const Tensor& input_tensor = inputs[input].tensor();
2024
Tensor& output_tensor = outputs[output].tensor();
25+
output_tensor->resize(_shape);
2126
if (input_tensor->num_elems() != output_tensor->num_elems()){
2227
uTensor_printf("inconsistent input and output shape for reshape\n");
2328
Context::get_default_context()->throwError(new InvalidReshapeError);
@@ -39,6 +44,7 @@ class ReshapeOperator : public OperatorInterface<1, 1> {
3944
}
4045
}
4146
private:
47+
TensorShape _shape;
4248
bool _check_input_shape(){
4349
Tensor& input_tensor = inputs[input].tensor();
4450
TensorShape shape = input_tensor->get_shape();

src/uTensor/tensors/RamTensor.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,19 @@ RamTensor::~RamTensor() {
6565
}
6666

6767
void RamTensor::resize(TensorShape new_shape) {
68-
uTensor_printf("Warning, RAM Tensor resize not implemented\n");
68+
// uTensor_printf("Warning, RAM Tensor resize not implemented\n");
69+
AllocatorInterface* allocator = Context::get_default_context()->get_ram_data_allocator();
70+
void* ptr = allocator->allocate(new_shape.get_linear_size()*_type_size);
71+
if (!ptr) {
72+
uTensor_printf("OOM when resizing\n");
73+
Context::get_default_context()->throwError(new OutOfMemError);
74+
return;
75+
}
76+
if (is_bound(_ram_region, *allocator)) {
77+
allocator->unbind(*_ram_region, &_ram_region);
78+
}
79+
_shape = new_shape;
80+
allocator->bind(ptr, &_ram_region);
6981
}
7082

7183
FutureMaxSizeRamTensor::FutureMaxSizeRamTensor(ttype _type)

src/uTensor/tensors/RamTensor.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ class RamTensor : public TensorInterface {
1313
virtual void* read(uint32_t linear_index) const override;
1414
virtual void* write(uint32_t linear_index) override;
1515
RamTensor(); // May be useful in subclasses
16-
RamTensor(ttype _type);
1716

1817
public:
18+
RamTensor(ttype _type);
1919
RamTensor(TensorShape _shape, ttype _type);
2020
virtual ~RamTensor();
2121
virtual void resize(TensorShape new_shape) override;

0 commit comments

Comments
 (0)