diff --git a/types/lib/stdlib.go b/types/lib/stdlib.go index caf426585..f075ce3ff 100644 --- a/types/lib/stdlib.go +++ b/types/lib/stdlib.go @@ -2,6 +2,8 @@ package lib import ( "fmt" + "strconv" + "strings" sdk "github.com/cosmos/cosmos-sdk/types" wire "github.com/cosmos/cosmos-sdk/wire" @@ -11,38 +13,45 @@ import ( // It panics when the element type cannot be (un/)marshalled by the codec type ListMapper interface { - // ListMapper dosen't checks index out of range + // ListMapper dosen't check if an index is in bounds // The user should check Len() before doing any actions - Len(sdk.Context) int64 - Get(sdk.Context, int64, interface{}) - // Setting element out of range is harmful; use Push() when adding new elements - Set(sdk.Context, int64, interface{}) - Delete(sdk.Context, int64) + Len(sdk.Context) uint64 + + Get(sdk.Context, uint64, interface{}) error + + // Setting element out of range is harmful + // Use Push() instead of Set() to append a new element + Set(sdk.Context, uint64, interface{}) + + Delete(sdk.Context, uint64) + Push(sdk.Context, interface{}) - Iterate(sdk.Context, interface{}, func(sdk.Context, int64) bool) + + // Iterate*() is used to iterate over all existing elements in the list + // Return true in the continuation to break + + // CONTRACT: No writes may happen within a domain while an iterator exists over it. + IterateRead(sdk.Context, interface{}, func(sdk.Context, uint64) bool) + + // IterateWrite() is safe to write over the domain + IterateWrite(sdk.Context, interface{}, func(sdk.Context, uint64) bool) } type listMapper struct { key sdk.StoreKey cdc *wire.Codec prefix string - lk []byte } func NewListMapper(cdc *wire.Codec, key sdk.StoreKey, prefix string) ListMapper { - lk, err := cdc.MarshalBinary(int64(-1)) - if err != nil { - panic(err) - } return listMapper{ key: key, cdc: cdc, prefix: prefix, - lk: lk, } } -func (lm listMapper) Len(ctx sdk.Context) int64 { +func (lm listMapper) Len(ctx sdk.Context) uint64 { store := ctx.KVStore(lm.key) bz := store.Get(lm.LengthKey()) if bz == nil { @@ -53,28 +62,20 @@ func (lm listMapper) Len(ctx sdk.Context) int64 { store.Set(lm.LengthKey(), zero) return 0 } - var res int64 + var res uint64 if err := lm.cdc.UnmarshalBinary(bz, &res); err != nil { panic(err) } return res } -func (lm listMapper) Get(ctx sdk.Context, index int64, ptr interface{}) { - if index < 0 { - panic(fmt.Errorf("Invalid index in ListMapper.Get(ctx, %d, ptr)", index)) - } +func (lm listMapper) Get(ctx sdk.Context, index uint64, ptr interface{}) error { store := ctx.KVStore(lm.key) bz := store.Get(lm.ElemKey(index)) - if err := lm.cdc.UnmarshalBinary(bz, ptr); err != nil { - panic(err) - } + return lm.cdc.UnmarshalBinary(bz, ptr) } -func (lm listMapper) Set(ctx sdk.Context, index int64, value interface{}) { - if index < 0 { - panic(fmt.Errorf("Invalid index in ListMapper.Set(ctx, %d, value)", index)) - } +func (lm listMapper) Set(ctx sdk.Context, index uint64, value interface{}) { store := ctx.KVStore(lm.key) bz, err := lm.cdc.MarshalBinary(value) if err != nil { @@ -83,10 +84,7 @@ func (lm listMapper) Set(ctx sdk.Context, index int64, value interface{}) { store.Set(lm.ElemKey(index), bz) } -func (lm listMapper) Delete(ctx sdk.Context, index int64) { - if index < 0 { - panic(fmt.Errorf("Invalid index in ListMapper.Delete(ctx, %d)", index)) - } +func (lm listMapper) Delete(ctx sdk.Context, index uint64) { store := ctx.KVStore(lm.key) store.Delete(lm.ElemKey(index)) } @@ -96,13 +94,38 @@ func (lm listMapper) Push(ctx sdk.Context, value interface{}) { lm.Set(ctx, length, value) store := ctx.KVStore(lm.key) - store.Set(lm.LengthKey(), marshalInt64(lm.cdc, length+1)) + store.Set(lm.LengthKey(), marshalUint64(lm.cdc, length+1)) } -func (lm listMapper) Iterate(ctx sdk.Context, ptr interface{}, fn func(sdk.Context, int64) bool) { +func (lm listMapper) IterateRead(ctx sdk.Context, ptr interface{}, fn func(sdk.Context, uint64) bool) { + store := ctx.KVStore(lm.key) + start, end := subspace([]byte(fmt.Sprintf("%s/elem/", lm.prefix))) + iter := store.Iterator(start, end) + for ; iter.Valid(); iter.Next() { + v := iter.Value() + if err := lm.cdc.UnmarshalBinary(v, ptr); err != nil { + panic(err) + } + s := strings.Split(string(iter.Key()), "/") + index, err := strconv.ParseUint(s[len(s)-1], 10, 64) + if err != nil { + panic(err) + } + if fn(ctx, index) { + break + } + } + + iter.Close() +} + +func (lm listMapper) IterateWrite(ctx sdk.Context, ptr interface{}, fn func(sdk.Context, uint64) bool) { length := lm.Len(ctx) - for i := int64(0); i < length; i++ { - lm.Get(ctx, i, ptr) + + for i := uint64(0); i < length; i++ { + if err := lm.Get(ctx, i, ptr); err != nil { + continue + } if fn(ctx, i) { break } @@ -110,11 +133,11 @@ func (lm listMapper) Iterate(ctx sdk.Context, ptr interface{}, fn func(sdk.Conte } func (lm listMapper) LengthKey() []byte { - return []byte(fmt.Sprintf("%s/%d", lm.prefix, lm.lk)) + return []byte(fmt.Sprintf("%s/length", lm.prefix)) } -func (lm listMapper) ElemKey(i int64) []byte { - return []byte(fmt.Sprintf("%s/%d", lm.prefix, i)) +func (lm listMapper) ElemKey(i uint64) []byte { + return []byte(fmt.Sprintf("%s/elem/%020d", lm.prefix, i)) } // QueueMapper is a Mapper interface that provides queue-like functions @@ -124,12 +147,14 @@ type QueueMapper interface { Push(sdk.Context, interface{}) // Popping/Peeking on an empty queue will cause panic // The user should check IsEmpty() before doing any actions - Peek(sdk.Context, interface{}) + Peek(sdk.Context, interface{}) error Pop(sdk.Context) IsEmpty(sdk.Context) bool - // Iterate() removes elements it processed; return true in the continuation to break - Iterate(sdk.Context, interface{}, func(sdk.Context) bool) - Info(sdk.Context) QueueInfo + // Iterate() removes elements it processed + // Return true in the continuation to break + // The interface{} is unmarshalled before the continuation is called + // Starts from the top(head) of the queue + Flush(sdk.Context, interface{}, func(sdk.Context) bool) } type queueMapper struct { @@ -137,100 +162,67 @@ type queueMapper struct { cdc *wire.Codec prefix string lm ListMapper - lk []byte - ik []byte } func NewQueueMapper(cdc *wire.Codec, key sdk.StoreKey, prefix string) QueueMapper { - lk := []byte("list") - ik := []byte("info") return queueMapper{ key: key, cdc: cdc, prefix: prefix, - lm: NewListMapper(cdc, key, prefix+string(lk)), - lk: lk, - ik: ik, + lm: NewListMapper(cdc, key, prefix+"list"), } } -type QueueInfo struct { - // begin <= elems < end - Begin int64 - End int64 -} - -func (info QueueInfo) validateBasic() error { - if info.End < info.Begin || info.Begin < 0 || info.End < 0 { - return fmt.Errorf("Invalid queue information: {Begin: %d, End: %d}", info.Begin, info.End) - } - return nil -} - -func (info QueueInfo) isEmpty() bool { - return info.Begin == info.End -} - -func (qm queueMapper) getQueueInfo(store sdk.KVStore) QueueInfo { - bz := store.Get(qm.InfoKey()) +func (qm queueMapper) getTop(store sdk.KVStore) (res uint64) { + bz := store.Get(qm.TopKey()) if bz == nil { - store.Set(qm.InfoKey(), marshalQueueInfo(qm.cdc, QueueInfo{0, 0})) - return QueueInfo{0, 0} + store.Set(qm.TopKey(), marshalUint64(qm.cdc, 0)) + return 0 } - var info QueueInfo - if err := qm.cdc.UnmarshalBinary(bz, &info); err != nil { + + if err := qm.cdc.UnmarshalBinary(bz, &res); err != nil { panic(err) } - if err := info.validateBasic(); err != nil { - panic(err) - } - return info + + return } -func (qm queueMapper) setQueueInfo(store sdk.KVStore, info QueueInfo) { - bz, err := qm.cdc.MarshalBinary(info) - if err != nil { - panic(err) - } - store.Set(qm.InfoKey(), bz) +func (qm queueMapper) setTop(store sdk.KVStore, top uint64) { + bz := marshalUint64(qm.cdc, top) + store.Set(qm.TopKey(), bz) } func (qm queueMapper) Push(ctx sdk.Context, value interface{}) { - store := ctx.KVStore(qm.key) - info := qm.getQueueInfo(store) - - qm.lm.Set(ctx, info.End, value) - - info.End++ - qm.setQueueInfo(store, info) + qm.lm.Push(ctx, value) } -func (qm queueMapper) Peek(ctx sdk.Context, ptr interface{}) { +func (qm queueMapper) Peek(ctx sdk.Context, ptr interface{}) error { store := ctx.KVStore(qm.key) - info := qm.getQueueInfo(store) - qm.lm.Get(ctx, info.Begin, ptr) + top := qm.getTop(store) + return qm.lm.Get(ctx, top, ptr) } func (qm queueMapper) Pop(ctx sdk.Context) { store := ctx.KVStore(qm.key) - info := qm.getQueueInfo(store) - qm.lm.Delete(ctx, info.Begin) - info.Begin++ - qm.setQueueInfo(store, info) + top := qm.getTop(store) + qm.lm.Delete(ctx, top) + qm.setTop(store, top+1) } func (qm queueMapper) IsEmpty(ctx sdk.Context) bool { store := ctx.KVStore(qm.key) - info := qm.getQueueInfo(store) - return info.isEmpty() + top := qm.getTop(store) + length := qm.lm.Len(ctx) + return top >= length } -func (qm queueMapper) Iterate(ctx sdk.Context, ptr interface{}, fn func(sdk.Context) bool) { +func (qm queueMapper) Flush(ctx sdk.Context, ptr interface{}, fn func(sdk.Context) bool) { store := ctx.KVStore(qm.key) - info := qm.getQueueInfo(store) + top := qm.getTop(store) + length := qm.lm.Len(ctx) - var i int64 - for i = info.Begin; i < info.End; i++ { + var i uint64 + for i = top; i < length; i++ { qm.lm.Get(ctx, i, ptr) qm.lm.Delete(ctx, i) if fn(ctx) { @@ -238,32 +230,24 @@ func (qm queueMapper) Iterate(ctx sdk.Context, ptr interface{}, fn func(sdk.Cont } } - info.Begin = i - qm.setQueueInfo(store, info) + qm.setTop(store, i) } -func (qm queueMapper) Info(ctx sdk.Context) QueueInfo { - store := ctx.KVStore(qm.key) - - return qm.getQueueInfo(store) +func (qm queueMapper) TopKey() []byte { + return []byte(fmt.Sprintf("%s/top", qm.prefix)) } -func (qm queueMapper) InfoKey() []byte { - return []byte(fmt.Sprintf("%s/%s", qm.prefix, qm.ik)) -} - -func marshalQueueInfo(cdc *wire.Codec, info QueueInfo) []byte { - bz, err := cdc.MarshalBinary(info) - if err != nil { - panic(err) - } - return bz -} - -func marshalInt64(cdc *wire.Codec, i int64) []byte { +func marshalUint64(cdc *wire.Codec, i uint64) []byte { bz, err := cdc.MarshalBinary(i) if err != nil { panic(err) } return bz } + +func subspace(prefix []byte) (start, end []byte) { + end = make([]byte, len(prefix)) + copy(end, prefix) + end[len(end)-1]++ + return prefix, end +} diff --git a/types/lib/stdlib_test.go b/types/lib/stdlib_test.go index 197e6cbe4..1058fed6e 100644 --- a/types/lib/stdlib_test.go +++ b/types/lib/stdlib_test.go @@ -9,13 +9,13 @@ import ( abci "github.com/tendermint/abci/types" - store "github.com/cosmos/cosmos-sdk/mock" + "github.com/cosmos/cosmos-sdk/store" sdk "github.com/cosmos/cosmos-sdk/types" wire "github.com/cosmos/cosmos-sdk/wire" ) type S struct { - I int64 + I uint64 B bool } @@ -38,20 +38,39 @@ func TestListMapper(t *testing.T) { var res S lm.Push(ctx, val) - assert.Equal(t, int64(1), lm.Len(ctx)) - lm.Get(ctx, int64(0), &res) + assert.Equal(t, uint64(1), lm.Len(ctx)) + lm.Get(ctx, uint64(0), &res) assert.Equal(t, val, res) val = S{2, false} - lm.Set(ctx, int64(0), val) - lm.Get(ctx, int64(0), &res) + lm.Set(ctx, uint64(0), val) + lm.Get(ctx, uint64(0), &res) assert.Equal(t, val, res) - lm.Iterate(ctx, &res, func(ctx sdk.Context, index int64) (brk bool) { + val = S{100, false} + lm.Push(ctx, val) + assert.Equal(t, uint64(2), lm.Len(ctx)) + lm.Get(ctx, uint64(1), &res) + assert.Equal(t, val, res) + + lm.Delete(ctx, uint64(1)) + assert.Equal(t, uint64(2), lm.Len(ctx)) + + lm.IterateRead(ctx, &res, func(ctx sdk.Context, index uint64) (brk bool) { + var temp S + lm.Get(ctx, index, &temp) + assert.Equal(t, temp, res) + + assert.True(t, index != 1) + return + }) + + lm.IterateWrite(ctx, &res, func(ctx sdk.Context, index uint64) (brk bool) { lm.Set(ctx, index, S{res.I + 1, !res.B}) return }) - lm.Get(ctx, int64(0), &res) + + lm.Get(ctx, uint64(0), &res) assert.Equal(t, S{3, true}, res) } @@ -71,12 +90,12 @@ func TestQueueMapper(t *testing.T) { empty := qm.IsEmpty(ctx) assert.True(t, empty) - assert.Panics(t, func() { qm.Peek(ctx, &res) }) + assert.NotNil(t, qm.Peek(ctx, &res)) qm.Push(ctx, S{1, true}) qm.Push(ctx, S{2, true}) qm.Push(ctx, S{3, true}) - qm.Iterate(ctx, &res, func(ctx sdk.Context) (brk bool) { + qm.Flush(ctx, &res, func(ctx sdk.Context) (brk bool) { if res.I == 3 { brk = true } @@ -84,7 +103,6 @@ func TestQueueMapper(t *testing.T) { }) assert.False(t, qm.IsEmpty(ctx)) - assert.Equal(t, QueueInfo{3, 4}, qm.Info(ctx)) qm.Pop(ctx) assert.True(t, qm.IsEmpty(ctx))