Skip to content

Context refactoring #181

@neil-tan

Description

@neil-tan

Draft for context changes

  • Operators on stack
  • Tensor name and pointer mapping for easier operator porting
  • Easier breakpoint setting
  • Extensibility for tensor object-pooling, persistent operators, offline-lifecycle planner for tensor object

Things to consider:

  • Would the template usage here increase the binary size significantly?
  • Consider implementing our own hash table, or some form of tensor-index planning capability

Code snippet below:

// Type aliases: operator names are project strings; tensors are resolved
// by name through an unordered map.
using OpName = utensor::string;
using TensorNamePtrMap = std::unordered_map<TName, Tensor*>;

// Bookkeeping entry for one context-owned tensor: the raw pointer plus a
// reference count used by Context::gc() to decide when to delete it.
struct TensorRecord {
    public:
        unsigned char ref_count = 0;
        Tensor* t = nullptr;
        //bool keep_alive = false;
        // Default constructor is required because the surrounding code reads
        // tTable via std::unordered_map::operator[], which default-inserts a
        // record for missing keys (NSDMIs above yield {0, nullptr}).
        TensorRecord() = default;
        // NOTE: members are initialized in declaration order (ref_count, then
        // t), not in the order written in this initializer list.
        TensorRecord(Tensor* _t, unsigned char _ref_count) :
            t(_t), ref_count(_ref_count) {}
};

class Context {
  private:
      std::unordered_map<TName, TensorRecord> tTable;
      //TODO: - Op data manager
      //      - Profiler

  public:
    Context();
    template<class T, typename... Args>
    Tensor* add(TName _name, unsigned char _ref_count, Args&&... args) {
        //pooling can be implemented here
        Tensor* t = new T(std::forward<Args>(args)...);
        t->setName(_name);
        tTable[_name] = TensorRecord(t, _ref_count);
        return t;
    }

    //Tensor lookup interface
    //non-existing tensor: returns Tensor*& but Tensor* is null
    Tensor*& operator[](TName name) {
      //TODO: define behavior for tensor-not-found
      Tensor*& t = tTable[name].t;
      return t;
    };
    
    void invoke(operator *op, OpName _name = 0);  //persistent op exists in heap
    
    //intermediate ops exists on stack
    void invoke(operator &op, OpName _name = 0) {
      //trigger registered actions based on _name here
      op.compute();
    }

    //This tensor removal function is meant to be called by code-gen directly
    void rm(TName t_name) {
      //TODO: check for t_name's existance
      delete tTable[t_name].t;
      tTable.erase(t_name);
    }
    
    //decrease ref count of used tensors and perform deletion
    //NT: Template based function worries me, as one copy of the function may be generated per op use
    template <class T>
    void gc(T &t_struct) {
      for (unsigned char i = 0; i < (sizeof(T) / sizeof(Tensor*)); i++) {
        TName t_name = ((Tensor*) &t_struct)[i]->name;
        unsigned char c = tTable[t_name].ref_count - 1;
        if(c <= 0) {
          rm(t_name);
        } else {
          tTable[t_name].ref_count = c;
        }
      }
    }

};

//Code for operators 

// Matrix-multiply operator: wraps the MatMul2 kernel, reading operands from
// the named input slots and writing into the named output slot. Slots are
// populated by generated code before Context::invoke().
template <class T1, class T2, class TOut>
class MatMulOp : public Operator {
  public:
    struct {
      // Fixed: raw pointers were left uninitialized; a compute() before the
      // slots are assigned now dereferences nullptr deterministically
      // instead of reading garbage.
      Tensor* input = nullptr;
      Tensor* dim = nullptr;
    } inputs;

    struct {
      Tensor* output = nullptr;
    } outputs;

  MatMulOp() {
    //similar to TFLM's prepare function
  }
  // 'override' already implies 'virtual'; the redundant keyword is dropped.
  void compute() override {
    MatMul2<T1, T2, TOut>(inputs.input, inputs.dim,
     outputs.output);
  }
};

// Abstract base for all operators. Persistent ops may be heap-allocated and
// deleted through an Operator*, so the destructor must be virtual — deleting
// a derived object through a base pointer without one is undefined behavior.
class Operator : public uTensor {
public:
  virtual ~Operator() = default;
  virtual void compute() = 0;
};


//// Example for Generated Code

//Old
{
    // Legacy flow: the generated code allocates the output tensor itself,
    // registers it with the context under its graph name, then pushes a
    // heap-allocated op bound to input/output tensor names.
    RamTensor<float>* out_tensor;
    out_tensor = new RamTensor<float>({ 1 });
    ctx.add(out_tensor, "MatMul_eightbit/x__port__0/min:0", 1);
    // push() queues the op with its input names (first list) and output
    // names (second list); nothing executes until eval().
    ctx.push(new MinOp(), 
             { "MatMul_eightbit/x__port__0/reshape:0", "MatMul_eightbit/x__port__0/reduction_dims:0" },
             { "MatMul_eightbit/x__port__0/min:0" });
    ctx.eval();
}

//New

{
    //Keeping one-off operators on the stack
    //Stateful operators can be allocated on the heap or a more global scope
    // Fixed: 'MinOp op();' is the most vexing parse — it declares a function
    // named op returning MinOp, not an object.
    MinOp op;
    op.inputs.input = ctx["MatMul_eightbit/x__port__0/reshape:0"];
    op.inputs.dim = ctx["MatMul_eightbit/x__port__0/reduction_dims:0"];
    //ctx.add() registers Tensor* with the context and returns the same Tensor*
    //Tensor instantiation all happens in one place now, paving the way for instance pooling
    op.outputs.output = ctx.add<RamTensor<float>>("MatMul_eightbit/x__port__0/min:0", 1, { 1 });

    //setting a breakpoint here should take you (almost) straight to the kernel
    ctx.invoke(op, "op_name_from_code_gen"); //with or without supplying the name

    //logic for reference-counting clean up
    ctx.gc(op.inputs);  //or, use code-gen to delete all the input tensors by calling rm()
}

Metadata

Metadata

Labels

No labels
No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions