diff --git a/internal/batch/cap.go b/internal/batch/cap.go
new file mode 100644
index 0000000000..c3b87f5796
--- /dev/null
+++ b/internal/batch/cap.go
@@ -0,0 +1,90 @@
+package batch
+
+// capped tracks a running total against an optional upper bound.
+// A limit <= 0 means the value is unbounded.
+type capped struct {
+	current, limit int64
+}
+
+// reachedLimit reports whether a bound is set and the running total has met or passed it.
+func (c capped) reachedLimit() bool { return c.limit > 0 && c.current >= c.limit }
+
+// remaining returns how much may still be added before the bound is hit.
+// It returns -1 when unbounded and never returns another negative value:
+// an overshot total yields 0, so callers cannot mistake it for "unbounded".
+func (c capped) remaining() int64 {
+	if c.limit <= 0 {
+		return -1
+	}
+	if c.current >= c.limit {
+		return 0
+	}
+	return c.limit - c.current
+}
+
+// remainingPerN returns how many items of size n still fit before the bound is hit.
+// It returns -1 when unbounded or when n <= 0 (such items can never exhaust the bound).
+func (c capped) remainingPerN(n int64) int64 {
+	if c.limit <= 0 || n <= 0 {
+		return -1
+	}
+	return c.remaining() / n
+}
+
+// cap returns the bound itself, or -1 when unbounded.
+func (c capped) cap() int64 {
+	if c.limit <= 0 {
+		return -1
+	}
+	return c.limit
+}
+
+// capPerN returns how many items of size n fit into an empty batch.
+// It returns -1 when unbounded or when n <= 0.
+func (c capped) capPerN(n int64) int64 {
+	if c.limit <= 0 || n <= 0 {
+		return -1
+	}
+	return c.limit / n
+}
+
+// Cap tracks the bytes and rows accumulated in a batch against optional limits.
+type Cap struct {
+	bytes, rows capped
+}
+
+// ReachedLimit reports whether either the byte or the row bound has been reached.
+func (c *Cap) ReachedLimit() bool { return c.bytes.reachedLimit() || c.rows.reachedLimit() }
+
+// Rows returns the number of rows currently accounted for.
+func (c *Cap) Rows() int64 { return c.rows.current }
+
+// AddRows adds rows to the row counter without touching the byte counter.
+func (c *Cap) AddRows(rows int64) { c.rows.current += rows }
+
+// Reset zeroes both counters; the limits are kept.
+func (c *Cap) Reset() {
+	c.bytes.current = 0
+	c.rows.current = 0
+}
+
+// add increases both counters.
+func (c *Cap) add(bytes, rows int64) {
+	c.bytes.current += bytes
+	c.rows.current += rows
+}
+
+// set overwrites both counters.
+func (c *Cap) set(bytes, rows int64) {
+	c.bytes.current = bytes
+	c.rows.current = rows
+}
+
+// CappedAt returns a Cap limited to the given number of bytes and rows.
+// A non-positive value disables the corresponding limit.
+func CappedAt(bytes, rows int64) *Cap {
+	return &Cap{
+		bytes: capped{limit: bytes},
+		rows:  capped{limit: rows},
+	}
+}
diff --git a/internal/batch/slice.go b/internal/batch/slice.go
new file mode 100644
index 0000000000..ad3d1c3d35
--- /dev/null
+++ b/internal/batch/slice.go
@@ -0,0 +1,156 @@
+package batch
+
+import (
+	"github.com/apache/arrow/go/v16/arrow"
+	"github.com/apache/arrow/go/v16/arrow/util"
+)
+
+// SlicedRecord is an arrow.Record bundled with the byte size attributed to the rows it holds.
+type SlicedRecord struct {
+	arrow.Record
+	Bytes       int64 // we need this as the util.TotalRecordSize will report the full size even for the sliced record
+	bytesPerRow int64
+}
+
+// split carves the record into the three parts described on SliceRecord and updates limit accordingly.
+// The receiver is consumed: it either becomes the returned rest or is emptied.
+func (s *SlicedRecord) split(limit *Cap) (add *SlicedRecord, toFlush []arrow.Record, rest *SlicedRecord) {
+	if s == nil {
+		return nil, nil, nil
+	}
+
+	add = s.getAdd(limit)
+	if add != nil {
+		limit.add(add.Bytes, add.NumRows())
+	}
+
+	if s.Record == nil {
+		// all processed
+		return add, nil, nil
+	}
+
+	toFlush = s.getToFlush(limit)
+	if s.Record == nil {
+		// all processed
+		return add, toFlush, nil
+	}
+
+	// set bytes & rows new values
+	limit.set(s.Bytes, s.NumRows())
+	return add, toFlush, s
+}
+
+// getAdd cuts off the leading rows that still fit into the current batch.
+// It returns nil when nothing fits; otherwise the receiver keeps only the rows that were not taken
+// (its Record becomes nil when everything was taken).
+func (s *SlicedRecord) getAdd(limit *Cap) *SlicedRecord {
+	rowsByBytes := limit.bytes.remainingPerN(s.bytesPerRow)
+	rows := limit.rows.remaining()
+	switch {
+	case rows < 0:
+		rows = rowsByBytes
+	case rows > rowsByBytes && rowsByBytes >= 0:
+		rows = rowsByBytes
+	}
+
+	switch {
+	case rows == 0:
+		return nil
+	case rows < 0, rows >= s.NumRows():
+		// grab the whole record (either no limits or not overflowing)
+		res := *s
+		s.Bytes = 0
+		s.Record = nil
+		return &res
+	}
+
+	res := SlicedRecord{
+		Record:      s.NewSlice(0, rows),
+		Bytes:       rows * s.bytesPerRow,
+		bytesPerRow: s.bytesPerRow,
+	}
+	s.Record = s.NewSlice(rows, s.NumRows())
+	s.Bytes -= res.Bytes
+	return &res
+}
+
+// getToFlush cuts the receiver into as many full-batch slices as possible.
+// The receiver keeps only the trailing rows that do not fill a batch on their own
+// (its Record becomes nil when no rows are left).
+func (s *SlicedRecord) getToFlush(limit *Cap) []arrow.Record {
+	rowsByBytes := limit.bytes.capPerN(s.bytesPerRow)
+	rows := limit.rows.cap()
+	switch {
+	case rows < 0:
+		rows = rowsByBytes
+	case rows > rowsByBytes && rowsByBytes >= 0:
+		rows = rowsByBytes
+	}
+
+	switch {
+	case rows == 0:
+		// not even a single row fits
+		// we still need to process this, so slice by single row
+		flush := s.slice()
+		// every row is now part of flush; leaving it in the receiver would emit it a second time as the remainder
+		s.Record = nil
+		s.Bytes = 0
+		return flush
+	case rows < 0:
+		// as s.Record != nil we know that the limits are there in place & the s.Record.NumRows() > 0
+		panic("should never be here")
+	case rows > s.NumRows():
+		// no need to flush anything, as the amount of rows isn't enough to grant this
+		return nil
+	}
+
+	flush := make([]arrow.Record, 0, s.NumRows()/rows)
+	offset := int64(0)
+	for offset+rows <= s.NumRows() {
+		flush = append(flush, s.NewSlice(offset, offset+rows))
+		offset += rows
+	}
+	if offset == s.NumRows() {
+		// we processed everything for flush
+		s.Record = nil
+		s.Bytes = 0
+		return flush
+	}
+
+	// set record to the remainder
+	s.Record = s.NewSlice(offset, s.NumRows())
+	s.Bytes = s.NumRows() * s.bytesPerRow
+
+	return flush
+}
+
+// slice returns one single-row slice per row of the record; the receiver is left untouched.
+func (s *SlicedRecord) slice() []arrow.Record {
+	res := make([]arrow.Record, s.NumRows())
+	for i := int64(0); i < s.NumRows(); i++ {
+		res[i] = s.NewSlice(i, i+1)
+	}
+	return res
+}
+
+// newSlicedRecord wraps r together with its total and per-row byte size; it returns nil for a nil or empty record.
+func newSlicedRecord(r arrow.Record) *SlicedRecord {
+	if r == nil || r.NumRows() == 0 {
+		return nil
+	}
+	res := SlicedRecord{
+		Record: r,
+		Bytes:  util.TotalRecordSize(r),
+	}
+	res.bytesPerRow = res.Bytes / r.NumRows()
+	return &res
+}
+
+// SliceRecord will return the SlicedRecord you can add to the batch given the restrictions provided (if any).
+// The meaning of the returned values:
+//   - `add` is good to be added to the current batch that the caller is assembling
+//   - `flush` represents sliced arrow.Record that needs own batch to be flushed
+//   - `remaining` represents the overflow of the batch after `add` & `flush` are processed
+func SliceRecord(r arrow.Record, limit *Cap) (add *SlicedRecord, flush []arrow.Record, remaining *SlicedRecord) {
+	return newSlicedRecord(r).split(limit)
+}
diff --git a/internal/batch/slice_test.go b/internal/batch/slice_test.go
new file mode 100644
index 0000000000..3bf9bae39d
--- /dev/null
+++ b/internal/batch/slice_test.go
@@ -0,0 +1,133 @@
+package batch
+
+import (
+	"fmt"
+	"math/rand"
+	"strconv"
+	"testing"
+	"time"
+
+	"github.com/apache/arrow/go/v16/arrow/util"
+	"github.com/cloudquery/plugin-sdk/v4/schema"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestSliceRecord(t *testing.T) {
+	for run := 0; run < 5; run++ {
+		rows := rand.Intn(100) + 5
+		t.Run(strconv.Itoa(rows), func(t *testing.T) {
+			t.Parallel()
+			table := schema.TestTable(fmt.Sprintf("test_%d_rows", rows), schema.TestSourceOptions{})
+			tg := schema.NewTestDataGenerator(0)
+			record := tg.Generate(table, schema.GenTestDataOptions{
+				MaxRows:    rows,
+				SourceName: "test",
+				SyncTime:   time.Now(),
+			})
+
+			recordRows, recordBytes := record.NumRows(), util.TotalRecordSize(record)
+
+			t.Run("only add", func(t *testing.T) {
+				add, toFlush, rest := SliceRecord(record, CappedAt(0, 0))
+				assert.NotNil(t, add)
+				assert.Equal(t, recordRows, add.NumRows())
+				assert.Equal(t, recordBytes, add.Bytes)
+				assert.Empty(t, toFlush)
+				assert.Nil(t, rest)
+			})
+
+			t.Run("only single toFlush", func(t *testing.T) {
+				limit := &Cap{
+					bytes: capped{current: recordBytes, limit: recordBytes},
+					rows:  capped{current: recordRows, limit: recordRows},
+				}
+				add, toFlush, rest := SliceRecord(record, limit)
+				assert.Nil(t, add)
+				assert.NotEmpty(t, toFlush)
+				assert.Len(t, toFlush, 1)
+				r := toFlush[0]
+				assert.Equal(t, recordRows, r.NumRows())
+				assert.Nil(t, rest)
+			})
+
+			t.Run("row larger than byte cap", func(t *testing.T) {
+				// a 1-byte cap cannot hold a single row, so every row must be flushed on its own exactly once
+				add, toFlush, rest := SliceRecord(record, CappedAt(1, 0))
+				assert.Nil(t, add)
+				assert.Len(t, toFlush, int(recordRows))
+				for _, f := range toFlush {
+					assert.Equal(t, int64(1), f.NumRows())
+				}
+				assert.Nil(t, rest)
+			})
+
+			t.Run("full - by rows", func(t *testing.T) {
+				limit := &Cap{rows: capped{current: recordRows / 10, limit: recordRows / 5}}
+				remaining := recordRows
+
+				add, toFlush, rest := SliceRecord(record, limit)
+				// if we could add some rows
+				if (recordRows/5)-(recordRows/10) > 0 {
+					assert.NotNil(t, add)
+					assert.LessOrEqual(t, add.NumRows(), recordRows/5)
+					assert.LessOrEqual(t, add.Bytes, recordBytes/5)
+					remaining -= add.NumRows()
+				} else {
+					assert.Nil(t, add)
+				}
+
+				assert.NotEmpty(t, toFlush)
+				assert.GreaterOrEqual(t, len(toFlush), 4)
+				for _, f := range toFlush {
+					assert.LessOrEqual(t, f.NumRows(), recordRows/5)
+					remaining -= f.NumRows()
+				}
+
+				assert.GreaterOrEqual(t, remaining, int64(0))
+				if remaining == 0 {
+					assert.Nil(t, rest)
+					return
+				}
+
+				assert.NotNil(t, rest)
+				assert.Less(t, remaining, recordRows/5)
+				assert.Equal(t, remaining, rest.NumRows())
+				assert.Less(t, rest.Bytes, recordBytes/5)
+			})
+
+			t.Run("full - by bytes", func(t *testing.T) {
+				limit := &Cap{bytes: capped{current: recordBytes / 10, limit: recordBytes / 5}}
+				remaining := recordRows
+
+				add, toFlush, rest := SliceRecord(record, limit)
+				// if we could add some rows
+				if (recordBytes/5)-(recordBytes/10) >= util.TotalRecordSize(record)/record.NumRows() {
+					assert.NotNil(t, add)
+					assert.LessOrEqual(t, add.NumRows(), recordRows/5)
+					assert.LessOrEqual(t, add.Bytes, recordBytes/5)
+					remaining -= add.NumRows()
+				} else {
+					assert.Nil(t, add)
+				}
+
+				assert.NotEmpty(t, toFlush)
+				assert.GreaterOrEqual(t, len(toFlush), 4)
+				for _, f := range toFlush {
+					assert.LessOrEqual(t, f.NumRows(), recordRows/5)
+					remaining -= f.NumRows()
+				}
+
+				assert.GreaterOrEqual(t, remaining, int64(0))
+				if remaining == 0 {
+					assert.Nil(t, rest)
+					return
+				}
+
+				assert.NotNil(t, rest)
+				assert.Less(t, remaining, recordRows/5)
+				assert.Equal(t, remaining, rest.NumRows())
+				assert.Less(t, rest.Bytes, recordBytes/5)
+			})
+		})
+	}
+}
diff --git a/scalar/binary.go b/scalar/binary.go
index e6519e05fa..f834d7c67f 100644
--- a/scalar/binary.go
+++ b/scalar/binary.go
@@ -83,6 +83,9 @@ func (s *Binary) Set(val any) error {
 	return nil
 }
 
+// ByteSize returns the length of the stored value.
+func (s *Binary) ByteSize() int64 { return int64(len(s.Value)) }
+
 func (*Binary) DataType() arrow.DataType {
 	return arrow.BinaryTypes.Binary
 }
@@ -94,3 +97,8 @@ type LargeBinary struct {
 func (*LargeBinary) DataType() arrow.DataType {
 	return arrow.BinaryTypes.LargeBinary
 }
+
+var (
+	_ Scalar = (*Binary)(nil)
+	_ Scalar = (*LargeBinary)(nil)
+)
diff --git a/scalar/bool.go b/scalar/bool.go
index 862d96d5f8..a982eef857 100644
--- a/scalar/bool.go
+++ b/scalar/bool.go
@@ -85,3 +85,10 @@ func (s *Bool) Set(val any) error {
 	s.Valid = true
 	return nil
 }
+
+// ByteSize returns 1: a bool is accounted for as a single byte.
+func (*Bool) ByteSize() int64 { return int64(1) }
+
+var (
+	_ Scalar = (*Bool)(nil)
+)
diff --git a/scalar/date32.go b/scalar/date32.go
index 8df72e0399..e7174b6d63 100644
--- a/scalar/date32.go
+++ b/scalar/date32.go
@@ -105,3 +105,10 @@ func (s *Date32) Set(val any) error {
 	s.Valid = true
 	return nil
 }
+
+// ByteSize returns arrow.Date32SizeBytes.
+func (*Date32) ByteSize() int64 { return int64(arrow.Date32SizeBytes) }
+
+var (
+	_ Scalar = (*Date32)(nil)
+)
diff --git a/scalar/date64.go b/scalar/date64.go
index d32c526ba9..d01ef02580 100644
--- a/scalar/date64.go
+++
b/scalar/date64.go
@@ -105,3 +105,10 @@ func (s *Date64) Set(val any) error {
 	s.Valid = true
 	return nil
 }
+
+// ByteSize returns arrow.Date64SizeBytes.
+func (*Date64) ByteSize() int64 {
+	return int64(arrow.Date64SizeBytes)
+}
+
+var _ Scalar = (*Date64)(nil)
diff --git a/scalar/decimal.go b/scalar/decimal.go
index 487675bd61..5c9458469b 100644
--- a/scalar/decimal.go
+++ b/scalar/decimal.go
@@ -175,6 +175,11 @@ func (s *Decimal256) Set(val any) error {
 	return nil
 }
 
+// ByteSize returns arrow.Decimal256SizeBytes.
+func (*Decimal256) ByteSize() int64 {
+	return int64(arrow.Decimal256SizeBytes)
+}
+
 type Decimal128 struct {
 	Valid bool
 	Value decimal128.Num
@@ -335,3 +340,13 @@ func (s *Decimal128) Set(val any) error {
 	s.Valid = true
 	return nil
 }
+
+// ByteSize returns arrow.Decimal128SizeBytes.
+func (*Decimal128) ByteSize() int64 {
+	return int64(arrow.Decimal128SizeBytes)
+}
+
+var (
+	_ Scalar = (*Decimal256)(nil)
+	_ Scalar = (*Decimal128)(nil)
+)
diff --git a/scalar/duration.go b/scalar/duration.go
index d19d0e34cb..254f8b4ab2 100644
--- a/scalar/duration.go
+++ b/scalar/duration.go
@@ -60,3 +60,10 @@ func (s *Duration) Set(value any) error {
 	}
 	return s.Int.Set(value)
 }
+
+// ByteSize returns arrow.DurationSizeBytes.
+func (*Duration) ByteSize() int64 {
+	return int64(arrow.DurationSizeBytes)
+}
+
+var _ Scalar = (*Duration)(nil)
diff --git a/scalar/float.go b/scalar/float.go
index edfc414f74..9dc6c82a42 100644
--- a/scalar/float.go
+++ b/scalar/float.go
@@ -229,3 +229,11 @@ func (s *Float) getBitWidth() uint8 {
 	}
 	return s.BitWidth
 }
+
+// ByteSize returns the bit width reported by getBitWidth, converted to whole bytes.
+func (s *Float) ByteSize() int64 {
+	bits := s.getBitWidth()
+	return int64(bits / 8)
+}
+
+var _ Scalar = (*Float)(nil)
diff --git a/scalar/inet.go b/scalar/inet.go
index 9bf8461eee..995e347de5 100644
--- a/scalar/inet.go
+++ b/scalar/inet.go
@@ -129,6 +129,15 @@ func (s *Inet) Set(val any) error {
 	return nil
 }
 
+// ByteSize returns the sum of the lengths of Value.IP and Value.Mask.
+// NOTE(review): if Value is a pointer, this dereferences it even for an unset scalar - confirm it cannot be nil here.
+func (s *Inet) ByteSize() int64 {
+	ipLen, maskLen := len(s.Value.IP), len(s.Value.Mask)
+	return int64(ipLen + maskLen)
+}
+
+var _ Scalar = (*Inet)(nil)
+
 // Convert the net.IP to IPv4, if appropriate.
 //
 // When parsing a string to a net.IP using net.ParseIP() and the like, we get a
diff --git a/scalar/int.go b/scalar/int.go
index 46f868586e..ba7525d084 100644
--- a/scalar/int.go
+++ b/scalar/int.go
@@ -251,3 +251,11 @@ func (s *Int) getBitWidth() uint8 {
 	}
 	return s.BitWidth
 }
+
+// ByteSize returns the bit width reported by getBitWidth, converted to whole bytes.
+func (s *Int) ByteSize() int64 {
+	bits := s.getBitWidth()
+	return int64(bits / 8)
+}
+
+var _ Scalar = (*Int)(nil)
diff --git a/scalar/interval.go b/scalar/interval.go
index 6101aa8f50..dac87f6fb2 100644
--- a/scalar/interval.go
+++ b/scalar/interval.go
@@ -77,6 +77,11 @@ func (s *MonthInterval) Set(value any) error {
 	}
 }
 
+// ByteSize returns arrow.MonthIntervalSizeBytes.
+func (*MonthInterval) ByteSize() int64 {
+	return int64(arrow.MonthIntervalSizeBytes)
+}
+
 type DayTimeInterval struct {
 	Value arrow.DayTimeInterval
 	Valid bool
@@ -170,6 +175,11 @@ func (s *DayTimeInterval) Get() any {
 	return s.Value
 }
 
+// ByteSize returns arrow.DayTimeIntervalSizeBytes.
+func (*DayTimeInterval) ByteSize() int64 {
+	return int64(arrow.DayTimeIntervalSizeBytes)
+}
+
 type MonthDayNanoInterval struct {
 	Value arrow.MonthDayNanoInterval
 	Valid bool
@@ -262,3 +272,14 @@ func (s *MonthDayNanoInterval) Get() any {
 
 	return s.Value
 }
+
+// ByteSize returns arrow.MonthDayNanoIntervalSizeBytes.
+func (*MonthDayNanoInterval) ByteSize() int64 {
+	return int64(arrow.MonthDayNanoIntervalSizeBytes)
+}
+
+var (
+	_ Scalar = (*MonthInterval)(nil)
+	_ Scalar = (*DayTimeInterval)(nil)
+	_ Scalar = (*MonthDayNanoInterval)(nil)
+)
diff --git a/scalar/json.go b/scalar/json.go
index 41edc212f3..2a9a240b75 100644
--- a/scalar/json.go
+++ b/scalar/json.go
@@ -131,6 +131,13 @@ func (s *JSON) Set(val any) error {
 	return nil
 }
 
+// ByteSize returns the length of the stored value.
+func (s *JSON) ByteSize() int64 {
+	return int64(len(s.Value))
+}
+
+var _ Scalar = (*JSON)(nil)
+
 // isEmptyStringMap returns true if the value is a map from string to any (i.e. map[string]any).
 // We need to use reflection for this, because it impossible to type-assert a map[string]string into a
 // map[string]any. See https://go.dev/doc/faq#convert_slice_of_interface.
diff --git a/scalar/list.go b/scalar/list.go
index 004da5f49c..b9f4f386fa 100644
--- a/scalar/list.go
+++ b/scalar/list.go
@@ -153,6 +153,13 @@ func (s *List) Set(val any) error {
 	return nil
 }
 
+// ByteSize delegates to the ByteSize of the stored value.
+func (s *List) ByteSize() int64 { return s.Value.ByteSize() }
+
+var (
+	_ Scalar = (*List)(nil)
+)
+
 func isReflectValueNil(v reflect.Value) bool {
 	switch v.Kind() {
 	case reflect.Pointer,
diff --git a/scalar/mac.go b/scalar/mac.go
index 29b8057014..1cd9643c92 100644
--- a/scalar/mac.go
+++ b/scalar/mac.go
@@ -89,3 +89,10 @@ func (s *Mac) Set(val any) error {
 	s.Valid = true
 	return nil
 }
+
+// ByteSize returns the length of the stored value.
+func (s *Mac) ByteSize() int64 { return int64(len(s.Value)) }
+
+var (
+	_ Scalar = (*Mac)(nil)
+)
diff --git a/scalar/scalar.go b/scalar/scalar.go
index 298382d28c..ca859a8ac5 100644
--- a/scalar/scalar.go
+++ b/scalar/scalar.go
@@ -31,6 +31,9 @@ type Scalar interface {
 
 	Get() any
 	Equal(other Scalar) bool
+
+	// ByteSize reports the size in bytes attributed to the value.
+	ByteSize() int64
 }
 
 type Vector []Scalar
@@ -54,6 +57,15 @@ func (v Vector) Equal(r Vector) bool {
 	return true
 }
 
+// ByteSize returns the sum of the ByteSize of every scalar in the vector.
+func (v Vector) ByteSize() int64 {
+	var size int64
+	for _, s := range v {
+		size += s.ByteSize()
+	}
+	return size
+}
+
 func NewScalar(dt arrow.DataType) Scalar {
 	switch dt.ID() {
 	case arrow.TIMESTAMP:
diff --git a/scalar/string.go b/scalar/string.go
index babdd83c5a..67ddc11050 100644
--- a/scalar/string.go
+++ b/scalar/string.go
@@ -82,22 +82,24 @@ func (s *String) Set(val any) error {
 	return nil
 }
 
+// ByteSize returns the length of the stored string in bytes.
+func (s *String) ByteSize() int64 { return int64(len(s.Value)) }
+
 type LargeString struct {
 	s String
 }
 
-func (s *LargeString) IsValid() bool {
-	return s.s.Valid
-}
+// String and the methods below delegate to the wrapped String (String keeps calling s.s.String(), as before).
+func (s *LargeString) String() string    { return s.s.String() }
+func (s *LargeString) IsValid() bool     { return s.s.IsValid() }
+func (s *LargeString) Set(val any) error { return s.s.Set(val) }
+func (s *LargeString) Get() any          { return s.s.Get() }
+func (s *LargeString) ByteSize() int64   { return s.s.ByteSize() }
 
 func (*LargeString) DataType() arrow.DataType {
 	return arrow.BinaryTypes.LargeString
 }
 
-func (s *LargeString) String() string {
-	return s.s.String()
-}
-
 func (s *LargeString) Equal(rhs Scalar) bool {
 	if rhs == nil {
 		return false
@@ -109,10 +111,7 @@ func (s *LargeString) Equal(rhs Scalar) bool {
 	return s.s.Valid == r.s.Valid && s.s.Value == r.s.Value
 }
 
-func (s *LargeString) Get() any {
-	return s.s.Get()
-}
-
-func (s *LargeString) Set(val any) error {
-	return s.s.Set(val)
-}
+var (
+	_ Scalar = (*String)(nil)
+	_ Scalar = (*LargeString)(nil)
+)
diff --git a/scalar/struct.go b/scalar/struct.go
index 088466da0f..b8f974b237 100644
--- a/scalar/struct.go
+++ b/scalar/struct.go
@@ -104,15 +104,35 @@ func (s *Struct) Set(val any) error {
 		s.Value = val
 	}
 
-	if rv := reflect.ValueOf(val); rv.Kind() == reflect.Pointer && !rv.Elem().IsValid() { // typed nil
-		s.Valid = false
-		return nil
-	}
-
-	s.Valid = true
+	rv := reflect.ValueOf(val)
+	s.Valid = rv.Kind() != reflect.Pointer || rv.Elem().IsValid() // !typed nil
 	return nil
 }
 
 func (s *Struct) DataType() arrow.DataType {
 	return s.Type
 }
+
+// ByteSize returns the in-memory size of the type of the stored value, after following pointers.
+// It returns 1 when there is nothing to measure (invalid scalar, nil pointer or untyped nil value).
+func (s *Struct) ByteSize() int64 {
+	if !s.Valid {
+		return 1 // for nil bitmap
+	}
+	v := reflect.ValueOf(s.Value)
+	for v.Kind() == reflect.Pointer {
+		if v.IsNil() {
+			return 1 // for nil bitmap
+		}
+		v = v.Elem()
+	}
+	if !v.IsValid() {
+		// untyped nil: calling Type on the zero reflect.Value would panic
+		return 1
+	}
+	return int64(v.Type().Size())
+}
+
+var (
+	_ Scalar = (*Struct)(nil)
+)
diff --git a/scalar/time.go b/scalar/time.go
index f060a32be4..4b98a2f284 100644
--- a/scalar/time.go
+++ b/scalar/time.go
@@ -94,3 +94,7 @@ func (s *Time) Set(value any) error {
 		return s.Int.Set(value)
 	}
 }
+
+var (
+	_ Scalar = (*Time)(nil)
+)
diff --git a/scalar/timestamp.go b/scalar/timestamp.go
index fbc412f349..2db203fc41 100644
--- a/scalar/timestamp.go
+++ b/scalar/timestamp.go
@@ -149,3 +149,10 @@ func (s *Timestamp) DecodeText(src []byte) error {
 		return &ValidationError{Type: s.DataType(), Msg: "cannot parse timestamp", Value: sbuf, Err: err}
 	}
 }
+
+// ByteSize returns arrow.TimestampSizeBytes.
+func (*Timestamp) ByteSize() int64 { return int64(arrow.TimestampSizeBytes) }
+
+var (
+	_ Scalar = (*Timestamp)(nil)
+)
diff --git
a/scalar/uint.go b/scalar/uint.go
index e5a6bc7452..39bb7039aa 100644
--- a/scalar/uint.go
+++ b/scalar/uint.go
@@ -246,3 +246,10 @@ func (s *Uint) getBitWidth() uint8 {
 	}
 	return s.BitWidth
 }
+
+// ByteSize returns the bit width reported by getBitWidth, converted to whole bytes.
+func (s *Uint) ByteSize() int64 { return int64(s.getBitWidth() / 8) }
+
+var (
+	_ Scalar = (*Uint)(nil)
+)
diff --git a/scalar/uuid.go b/scalar/uuid.go
index 715a13cccb..9e5e8cecf6 100644
--- a/scalar/uuid.go
+++ b/scalar/uuid.go
@@ -100,6 +100,13 @@ func (s *UUID) Set(src any) error {
 	return nil
 }
 
+// ByteSize returns the length of the stored value.
+func (s *UUID) ByteSize() int64 { return int64(len(s.Value)) }
+
+var (
+	_ Scalar = (*UUID)(nil)
+)
+
 // parseUUID converts a string UUID in standard form to a byte array.
 func parseUUID(src string) (dst [16]byte, err error) {
 	switch len(src) {
diff --git a/scheduler/batch.go b/scheduler/batch.go
new file mode 100644
index 0000000000..7845107f76
--- /dev/null
+++ b/scheduler/batch.go
@@ -0,0 +1,207 @@
+package scheduler
+
+import (
+	"context"
+	"sync"
+	"time"
+
+	"github.com/apache/arrow/go/v16/arrow/array"
+	"github.com/apache/arrow/go/v16/arrow/memory"
+	"github.com/cloudquery/plugin-sdk/v4/message"
+	"github.com/cloudquery/plugin-sdk/v4/scalar"
+	"github.com/cloudquery/plugin-sdk/v4/schema"
+	"github.com/cloudquery/plugin-sdk/v4/writers"
+)
+
+// batcher fans resources out to one worker per table; each worker batches them into arrow records sent to res.
+type batcher struct {
+	ctx     context.Context
+	ctxDone <-chan struct{}
+
+	res chan<- message.SyncMessage
+
+	maxRows  int
+	maxBytes int64
+	timeout  time.Duration
+
+	// using sync primitives by value here implies that batcher is to be used by pointer only
+	// workers is a sync.Map rather than a map + mutex pair
+	// because worker allocation & lookup falls into one of the sync.Map use-cases,
+	// namely, ever-growing cache (write once, read many times).
+	workers sync.Map // k = table name, v = *worker
+	wg      sync.WaitGroup
+}
+
+// worker accumulates the rows of a single table until a limit or the timeout is hit.
+type worker struct {
+	ch                 chan *schema.Resource
+	flush              chan chan struct{}
+	curRows, maxRows   int
+	curBytes, maxBytes int64
+	builder            *array.RecordBuilder // we can reuse that
+	res                chan<- message.SyncMessage
+}
+
+// send emits the accumulated rows as a single record and resets the counters.
+// It must be called on len(rows) > 0.
+// It reports false when done was closed before the record could be handed over;
+// without that escape hatch a cancelled sync with no reader left on res would block the worker forever.
+func (w *worker) send(done <-chan struct{}) bool {
+	record := w.builder.NewRecord()
+	// we need to reserve here as NewRecord (& underlying NewArray calls) reset the memory
+	w.builder.Reserve(w.maxRows)
+	w.curRows, w.curBytes = 0, 0 // reset
+
+	select {
+	case w.res <- &message.SyncInsert{Record: record}:
+		return true
+	case <-done:
+		return false
+	}
+}
+
+// work is the worker loop: it appends incoming resources to the builder and sends a record
+// whenever a limit is reached, the ticker fires or a flush is requested.
+// It returns once the input channel is closed or done is closed.
+func (w *worker) work(done <-chan struct{}, timeout time.Duration) {
+	ticker := writers.NewTicker(timeout)
+	defer ticker.Stop()
+	tickerCh := ticker.Chan()
+
+	for {
+		select {
+		case r, ok := <-w.ch:
+			if !ok {
+				if w.curRows > 0 {
+					w.send(done)
+				}
+				return
+			}
+
+			v := r.GetValues()
+			vBytes := v.ByteSize()
+			// check if append will cause overflow
+			// (never flush an empty builder: a single row above the limit is sent on its own below)
+			if w.curRows > 0 && w.maxBytes > 0 && w.curBytes+vBytes > w.maxBytes {
+				if !w.send(done) {
+					return
+				}
+				ticker.Reset(timeout)
+			}
+
+			// append to builder
+			scalar.AppendToRecordBuilder(w.builder, v)
+			w.curRows++
+			w.curBytes += vBytes
+			// check if we need to flush
+			if (w.maxRows > 0 && w.curRows >= w.maxRows) ||
+				(w.maxBytes > 0 && w.curBytes >= w.maxBytes) { // > happens for a single row above the limit
+				if !w.send(done) {
+					return
+				}
+				ticker.Reset(timeout)
+			}
+
+		case <-tickerCh:
+			if w.curRows > 0 && !w.send(done) {
+				return
+			}
+
+		case ch := <-w.flush:
+			delivered := true
+			if w.curRows > 0 {
+				delivered = w.send(done)
+				ticker.Reset(timeout)
+			}
+			close(ch)
+			if !delivered {
+				return
+			}
+
+		case <-done:
+			// this means the request was cancelled
+			return // after this NO other call will succeed
+		}
+	}
+}
+
+// process routes res to the worker of its table, starting that worker on first use.
+func (b *batcher) process(res *schema.Resource) {
+	table := res.Table
+	// already running worker
+	v, loaded := b.workers.Load(table.Name)
+	if loaded {
+		b.enqueue(v.(*worker), res)
+		return
+	}
+
+	// we alloc only ch here, as it may be needed right away
+	// for instance, if another goroutine will get the value allocated by us
+	wr := &worker{ch: make(chan *schema.Resource, 5)} // 5 is quite enough
+	v, loaded = b.workers.LoadOrStore(table.Name, wr)
+	if loaded {
+		// means that the worker was already in the sync.Map, so we just discard the wr value
+		close(wr.ch)                // for GC
+		b.enqueue(v.(*worker), res) // send res to the already allocated worker
+		return
+	}
+
+	// fill in the required data
+	// start wr
+	b.wg.Add(1)
+	go func() {
+		defer b.wg.Done()
+
+		// fill in the worker fields
+		wr.flush = make(chan chan struct{})
+		wr.maxRows = b.maxRows
+		wr.maxBytes = b.maxBytes
+		wr.builder = array.NewRecordBuilder(memory.DefaultAllocator, table.ToArrowSchema())
+		wr.res = b.res
+		wr.builder.Reserve(b.maxRows)
+
+		// start processing
+		wr.work(b.ctxDone, b.timeout)
+	}()
+
+	b.enqueue(wr, res)
+}
+
+// enqueue hands res over to w unless the request gets cancelled first.
+// A worker stops draining its channel once the context is done, so a plain send could block forever.
+func (b *batcher) enqueue(w *worker, res *schema.Resource) {
+	select {
+	case w.ch <- res:
+	case <-b.ctxDone:
+	}
+}
+
+// close closes the input channel of every worker and waits for all of them to return.
+func (b *batcher) close() {
+	b.workers.Range(func(_, v any) bool {
+		close(v.(*worker).ch)
+		return true
+	})
+	b.wg.Wait()
+}
+
+// newBatcher returns a batcher that sends records of at most maxRows rows / maxBytes bytes to res.
+func newBatcher(ctx context.Context, res chan<- message.SyncMessage, maxRows int, maxBytes int64, timeout time.Duration) *batcher {
+	return &batcher{
+		ctx:      ctx,
+		ctxDone:  ctx.Done(),
+		res:      res,
+		maxRows:  maxRows,
+		maxBytes: maxBytes,
+		timeout:  timeout,
+	}
+}
+
+// newDefaultBatcher returns a batcher limited to 50 rows, 50 MiB and a 5 second timeout.
+func newDefaultBatcher(ctx context.Context, res chan<- message.SyncMessage) *batcher {
+	const (
+		rows    = 50
+		bytes   = 50 * (1 << 20) // 50 MiB
+		timeout = 5 * time.Second
+	)
+	return newBatcher(ctx, res, rows, bytes, timeout)
+}
diff --git a/scheduler/benchmark_test.go.backup b/scheduler/benchmark_test.go.backup
index 5e16b26447..e9a8d627b6 100644
--- a/scheduler/benchmark_test.go.backup
+++ b/scheduler/benchmark_test.go.backup
@@ -9,7 +9,7 @@ import (
 	"testing"
 	"time"
 
-	"github.com/apache/arrow/go/v15/arrow"
+	"github.com/apache/arrow/go/v16/arrow"
 	"github.com/cloudquery/plugin-pb-go/specs"
 	"github.com/cloudquery/plugin-sdk/v4/schema"
 	"github.com/rs/zerolog"
diff --git a/scheduler/scheduler.go b/scheduler/scheduler.go
index e37fb99581..aa12e8a021 100644
--- a/scheduler/scheduler.go
+++ b/scheduler/scheduler.go
@@ -9,8 +9,6 @@ import (
 	"sync/atomic"
 	"time"
 
-	"github.com/apache/arrow/go/v16/arrow"
-
 	"github.com/cloudquery/plugin-sdk/v4/caser"
 	"github.com/cloudquery/plugin-sdk/v4/message"
 	"github.com/cloudquery/plugin-sdk/v4/schema"
@@ -209,22 +207,22 @@ func (s *Scheduler) Sync(ctx context.Context, client schema.ClientMeta, tables s
 			panic(fmt.Errorf("unknown scheduler %s", s.strategy.String()))
 		}
 	}()
+
+	b := newDefaultBatcher(ctx, res)
+	defer b.close() // wait for all resources to be processed
+	done := ctx.Done() // no need to do the lookups in loop
 	for resource := range resources {
 		select {
-		case res <- &message.SyncInsert{Record: resourceToRecord(resource)}:
-		case <-ctx.Done():
+		case <-done:
 			s.logger.Debug().Msg("sync context cancelled")
 			return context.Cause(ctx)
+		default:
+			b.process(resource)
 		}
 	}
 	return context.Cause(ctx)
 }
 
-func resourceToRecord(resource *schema.Resource) arrow.Record {
-	vector := resource.GetValues()
-	return vector.ToArrowRecord(resource.Table.ToArrowSchema())
-}
-
 func (s *syncClient) logTablesMetrics(tables schema.Tables, client Client) {
 	clientName := client.ID()
 	for _, table := range tables {
diff --git a/scheduler/scheduler_test.go b/scheduler/scheduler_test.go
index c46e4ba754..efd7efc03e 100644
--- a/scheduler/scheduler_test.go
+++ b/scheduler/scheduler_test.go
@@ -394,22 +394,22 @@ func TestScheduler_Cancellation(t *testing.T) {
 	data := make([]any, 100)
 
 	tests := []struct {
-		name         string
-		data         []any
-		cancel       bool
-		messageCount int
+		name           string
+		data           []any
+		cancel         bool
+		messagesOrRows int
 	}{
 		{
-			name:         "should consume all message",
-			data:         data,
-			cancel:       false,
-			messageCount: len(data) + 1, // 9 data + 1 migration message
+			name:           "should consume all message",
+			data:           data,
+			cancel:         false,
+			messagesOrRows: len(data) + 1, // one row per data item + 1 migration message
 		},
 		{
-			name:         "should not consume all message on cancel",
-			data:         data,
-			cancel:       true,
-			messageCount: len(data) + 1, // 9 data + 1 migration message
+			name:           "should not consume all message on cancel",
+			data:           data,
+			cancel:         true,
+			messagesOrRows: len(data) + 1, // one row per data item + 1 migration message
 		},
 	}
 
@@ -443,18 +443,22 @@ func TestScheduler_Cancellation(t *testing.T) {
 				close(messages)
 			}()
 
-			messageConsumed := 0
-			for range messages {
+			messagesOrRows := 0
+			for msg := range messages {
 				if tc.cancel {
 					cancel()
 				}
-				messageConsumed++
+				if r, ok := msg.(*message.SyncInsert); ok {
+					messagesOrRows += int(r.Record.NumRows())
+				} else {
+					messagesOrRows++
+				}
 			}
 
 			if tc.cancel {
-				assert.NotEqual(t, tc.messageCount, messageConsumed)
+				assert.NotEqual(t, tc.messagesOrRows, messagesOrRows)
 			} else {
-				assert.Equal(t, tc.messageCount, messageConsumed)
+				assert.Equal(t, tc.messagesOrRows, messagesOrRows)
 			}
 		})
 	}
diff --git a/writers/batchwriter/batchwriter.go b/writers/batchwriter/batchwriter.go
index 13909fc3db..1082199388 100644
--- a/writers/batchwriter/batchwriter.go
+++ b/writers/batchwriter/batchwriter.go
@@ -6,7 +6,7 @@ import (
 	"sync"
 	"time"
 
-	"github.com/apache/arrow/go/v16/arrow/util"
+	"github.com/cloudquery/plugin-sdk/v4/internal/batch"
 	"github.com/cloudquery/plugin-sdk/v4/message"
 	"github.com/cloudquery/plugin-sdk/v4/schema"
 	"github.com/cloudquery/plugin-sdk/v4/writers"
@@ -122,7 +122,7 @@ func (w *BatchWriter) Close(context.Context) error {
 }
 
 func (w *BatchWriter) worker(ctx context.Context, tableName string, ch <-chan *message.WriteInsert, flush <-chan chan bool) {
-	var bytes, rows int64
+	limit := batch.CappedAt(w.batchSizeBytes, w.batchSize)
 	resources := make([]*message.WriteInsert, 0, w.batchSize) // at least we have 1 row per record
 
 	ticker := writers.NewTicker(w.batchTimeout)
@@ -134,47 +134,53 @@ func (w *BatchWriter) worker(ctx context.Context, tableName string, ch <-chan *m
 		w.flushTable(ctx, tableName, resources)
 		clear(resources)
 		resources = resources[:0]
-		bytes, rows = 0, 0
+		limit.Reset()
 	}
+
 	for {
 		select {
 		case r, ok := <-ch:
 			if !ok {
-				if rows > 0 {
+				if limit.Rows() > 0 {
 					w.flushTable(ctx, tableName, resources)
 				}
 				return
 			}
 
-			recordRows, recordBytes := r.Record.NumRows(), util.TotalRecordSize(r.Record)
-			if (w.batchSize > 0 && rows+recordRows > w.batchSize) ||
-				(w.batchSizeBytes > 0 && bytes+recordBytes > w.batchSizeBytes) {
-				if rows == 0 {
-					// New record overflows batch by itself.
-					// Flush right away.
-					// TODO: slice
-					resources = append(resources, r)
-					send()
-					ticker.Reset(w.batchTimeout)
-					continue
-				}
-				// rows > 0
+			if r.Record.NumRows() == 0 {
+				// skip empty ones
+				continue
+			}
+
+			add, toFlush, rest := batch.SliceRecord(r.Record, limit)
+			if add != nil {
+				resources = append(resources, &message.WriteInsert{Record: add.Record})
+			}
+			if len(toFlush) > 0 || rest != nil || limit.ReachedLimit() {
+				// flush current batch
 				send()
 				ticker.Reset(w.batchTimeout)
 			}
-			if recordRows > 0 {
-				// only save records with rows
-				resources = append(resources, r)
-				rows += recordRows
-				bytes += recordBytes
+			for _, sliceToFlush := range toFlush {
+				resources = append(resources, &message.WriteInsert{Record: sliceToFlush})
+				send()
+				ticker.Reset(w.batchTimeout)
+			}
+
+			// set the remainder
+			if rest != nil {
+				resources = append(resources, &message.WriteInsert{Record: rest.Record})
+				// send() above reset the counters, so the buffered remainder has to be counted again;
+				// otherwise Rows() stays 0 and the remainder is never flushed on timeout, flush or close
+				limit.AddRows(rest.NumRows())
 			}
 
 		case <-tickerCh:
-			if rows > 0 {
+			if limit.Rows() > 0 {
 				send()
 			}
 		case done := <-flush:
-			if rows > 0 {
+			if limit.Rows() > 0 {
 				send()
 				ticker.Reset(w.batchTimeout)
 			}
diff --git a/writers/mixedbatchwriter/mixedbatchwriter.go b/writers/mixedbatchwriter/mixedbatchwriter.go
index a53d3344c9..788a5657ac 100644
--- a/writers/mixedbatchwriter/mixedbatchwriter.go
+++ b/writers/mixedbatchwriter/mixedbatchwriter.go
@@ -4,7 +4,7 @@ import (
 	"context"
 	"time"
 
-	"github.com/apache/arrow/go/v16/arrow/util"
+	"github.com/cloudquery/plugin-sdk/v4/internal/batch"
 	"github.com/cloudquery/plugin-sdk/v4/message"
 	"github.com/cloudquery/plugin-sdk/v4/writers"
 	"github.com/rs/zerolog"
@@ -92,8 +92,7 @@ func (w *MixedBatchWriter) Write(ctx context.Context, msgChan <-chan message.Wri
 	insert := &insertBatchManager{
 		batch:     make([]*message.WriteInsert, 0, w.batchSize),
 		writeFunc: w.client.InsertBatch,
-		maxRows:   w.batchSize,
-		maxBytes:  w.batchSizeBytes,
+		limit:     batch.CappedAt(w.batchSizeBytes, w.batchSize),
 		logger:    w.logger,
 	}
 	deleteStale := &batchManager[message.WriteDeleteStales, *message.WriteDeleteStale]{
@@ -199,47 +198,62 @@ func (m *batchManager[A, T]) flush(ctx context.Context) error {
 
 // special batch manager for insert messages that also keeps track of the total size of the batch
 type insertBatchManager struct {
-	batch              []*message.WriteInsert
-	writeFunc          func(ctx context.Context, messages message.WriteInserts) error
-	curRows, maxRows   int64
-	curBytes, maxBytes int64
-	logger             zerolog.Logger
+	batch     []*message.WriteInsert
+	writeFunc func(ctx context.Context, messages message.WriteInserts) error
+	limit     *batch.Cap
+	logger    zerolog.Logger
 }
 
+// append slices msg.Record against the batch limits and flushes every part that fills a batch.
 func (m *insertBatchManager) append(ctx context.Context, msg *message.WriteInsert) error {
-	recordRows, recordBytes := msg.Record.NumRows(), util.TotalRecordSize(msg.Record)
-	if (m.maxRows > 0 && m.curRows+recordRows > m.maxRows) ||
-		(m.maxBytes > 0 && m.curBytes+recordBytes > m.maxBytes) {
+	add, toFlush, rest := batch.SliceRecord(msg.Record, m.limit)
+	if add != nil {
+		m.batch = append(m.batch, &message.WriteInsert{Record: add.Record})
+	}
+	if len(toFlush) > 0 || rest != nil || m.limit.ReachedLimit() {
+		// flush current batch
+		if err := m.flush(ctx); err != nil {
+			return err
+		}
+	}
+	for _, sliceToFlush := range toFlush {
+		m.batch = append(m.batch, &message.WriteInsert{Record: sliceToFlush})
 		if err := m.flush(ctx); err != nil {
 			return err
 		}
 	}
 
-	if recordRows > 0 {
-		// only save records with rows
-		m.batch = append(m.batch, msg)
-		m.curRows += recordRows
-		m.curBytes += recordBytes
+	// set the remainder
+	if rest != nil {
+		m.batch = append(m.batch, &message.WriteInsert{Record: rest.Record})
+		// the flush above reset the counters; count the buffered remainder again
+		m.limit.AddRows(rest.NumRows())
 	}
 	return nil
 }
 
+// flush writes the buffered inserts, if any, and resets the limit counters.
 func (m *insertBatchManager) flush(ctx context.Context) error {
-	if m.curRows == 0 {
-		// no rows to insert
+	if len(m.batch) == 0 {
+		// nothing buffered; drop counters that may have been set for a remainder not appended yet
+		m.limit.Reset()
 		return nil
 	}
+	var rows int64
+	for _, ins := range m.batch {
+		rows += ins.Record.NumRows()
+	}
 	start := time.Now()
 	err := m.writeFunc(ctx, m.batch)
 	if err != nil {
-		m.logger.Err(err).Int64("len", m.curRows).Dur("duration", time.Since(start)).Msg("failed to write batch")
+		m.logger.Err(err).Int64("len", rows).Dur("duration", time.Since(start)).Msg("failed to write batch")
 		return err
 	}
-	m.logger.Debug().Int64("len", m.curRows).Dur("duration", time.Since(start)).Msg("batch written successfully")
+	m.logger.Debug().Int64("len", rows).Dur("duration", time.Since(start)).Msg("batch written successfully")
 
 	clear(m.batch) // GC can work
 	m.batch = m.batch[:0]
-	m.curRows, m.curBytes = 0, 0
+	m.limit.Reset()
 	return nil
 }
 
diff --git a/writers/streamingbatchwriter/streamingbatchwriter.go b/writers/streamingbatchwriter/streamingbatchwriter.go
index afa8b65577..91cca852ed 100644
--- a/writers/streamingbatchwriter/streamingbatchwriter.go
+++ b/writers/streamingbatchwriter/streamingbatchwriter.go
@@ -25,7 +25,7 @@ import (
 	"sync"
 	"time"
 
-	"github.com/apache/arrow/go/v16/arrow/util"
+	"github.com/cloudquery/plugin-sdk/v4/internal/batch"
 	"github.com/cloudquery/plugin-sdk/v4/message"
 	"github.com/cloudquery/plugin-sdk/v4/schema"
 	"github.com/cloudquery/plugin-sdk/v4/writers"
@@ -233,9 +233,9 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
 			flush: make(chan chan bool),
 			errCh: errCh,
 
-			batchSizeRows: w.batchSizeRows,
-			batchTimeout:  w.batchTimeout,
-			tickerFn:      w.tickerFn,
+			limit:        batch.CappedAt(0, w.batchSizeRows),
+			batchTimeout: w.batchTimeout,
+			tickerFn:     w.tickerFn,
 		}
 
 		w.workersWaitGroup.Add(1)
@@ -257,9 +257,9 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
 			flush: make(chan chan bool),
 			errCh: errCh,
 
-			batchSizeRows: w.batchSizeRows,
-			batchTimeout:  w.batchTimeout,
-			tickerFn:      w.tickerFn,
+			limit:        batch.CappedAt(0, w.batchSizeRows),
+			batchTimeout: w.batchTimeout,
+			tickerFn:     w.tickerFn,
 		}
 
 		w.workersWaitGroup.Add(1)
@@ -283,10 +283,9 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
 			flush: make(chan chan bool),
 			errCh: errCh,
 
-			batchSizeRows:  w.batchSizeRows,
-			batchSizeBytes: w.batchSizeBytes,
-			batchTimeout:   w.batchTimeout,
-			tickerFn:       w.tickerFn,
+			limit:        batch.CappedAt(w.batchSizeBytes, w.batchSizeRows),
+			batchTimeout: w.batchTimeout,
+			tickerFn:     w.tickerFn,
 		}
 		w.workersLock.Lock()
 		wrOld, ok := w.insertWorkers[tableName]
@@ -320,9 +319,9 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
 			flush: make(chan chan bool),
 			errCh: errCh,
 
-			batchSizeRows: w.batchSizeRows,
-			batchTimeout:  w.batchTimeout,
-			tickerFn:      w.tickerFn,
+			limit:        batch.CappedAt(0, w.batchSizeRows),
+			batchTimeout: w.batchTimeout,
+			tickerFn:     w.tickerFn,
 		}
 
 		w.workersWaitGroup.Add(1)
@@ -341,19 +340,17 @@ type streamingWorkerManager[T message.WriteMessage] struct {
 	flush chan chan bool
 	errCh chan<- error
 
-	batchSizeRows  int64
-	batchSizeBytes int64
-	batchTimeout   time.Duration
-	tickerFn       writers.TickerFunc
+	limit        *batch.Cap
+	batchTimeout time.Duration
+	tickerFn     writers.TickerFunc
 }
 
 func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup, tableName string) {
 	defer wg.Done()
 	var (
-		clientCh            chan T
-		clientErrCh         chan error
-		open                bool
-		sizeBytes, sizeRows int64
+		clientCh    chan T
+		clientErrCh chan error
+		open        bool
 	)
 
 	ensureOpened := func() {
@@ -382,7 +379,7 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
 			}
 		}
 		open = false
-		sizeBytes, sizeRows = 0, 0
+		s.limit.Reset()
 	}
 	defer closeFlush()
 
@@ -398,44 +395,47 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
 				return
 			}
 
-			recordRows := int64(1) // at least 1 row for messages without records
-			var recordBytes int64
 			if ins, ok := any(r).(*message.WriteInsert); ok {
-				recordBytes = util.TotalRecordSize(ins.Record)
-				recordRows = ins.Record.NumRows()
-			}
-
-			if (s.batchSizeRows > 0 && sizeRows+recordRows > s.batchSizeRows) ||
-				(s.batchSizeBytes > 0 && sizeBytes+recordBytes > s.batchSizeBytes) {
-				if sizeRows == 0 {
-					// New record overflows batch by itself.
-					// Flush right away.
-					// TODO: slice
+				add, toFlush, rest := batch.SliceRecord(ins.Record, s.limit)
+				if add != nil {
 					ensureOpened()
-					clientCh <- r
+					clientCh <- any(&message.WriteInsert{Record: add.Record}).(T)
+				}
+				if len(toFlush) > 0 || rest != nil || s.limit.ReachedLimit() {
+					// flush current batch
+					closeFlush()
+					ticker.Reset(s.batchTimeout)
+				}
+				for _, sliceToFlush := range toFlush {
+					ensureOpened()
+					clientCh <- any(&message.WriteInsert{Record: sliceToFlush}).(T)
 					closeFlush()
 					ticker.Reset(s.batchTimeout)
-					continue
 				}
-				// sizeRows > 0
-				closeFlush()
-				ticker.Reset(s.batchTimeout)
-			}
 
-			if recordRows > 0 {
-				// only save records with rows
+				// set the remainder
+				if rest != nil {
+					ensureOpened()
+					clientCh <- any(&message.WriteInsert{Record: rest.Record}).(T)
+					// closeFlush above reset the counters; count the remainder again so the timeout/flush paths see it
+					s.limit.AddRows(rest.NumRows())
+				}
+			} else {
 				ensureOpened()
 				clientCh <- r
-				sizeRows += recordRows
-				sizeBytes += recordBytes
+				s.limit.AddRows(1)
+				if s.limit.ReachedLimit() {
+					closeFlush()
+					ticker.Reset(s.batchTimeout)
+				}
 			}
 
 		case <-tickerCh:
-			if sizeRows > 0 {
+			if s.limit.Rows() > 0 {
 				closeFlush()
 			}
 		case done := <-s.flush:
-			if sizeRows > 0 {
+			if s.limit.Rows() > 0 {
 				closeFlush()
 				ticker.Reset(s.batchTimeout)
 			}