cosmos-sdk/orm/model/ormtable/auto_increment.go

240 lines
5.9 KiB
Go

package ormtable
import (
"context"
"encoding/json"
"fmt"
"io"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"github.com/cosmos/cosmos-sdk/orm/encoding/ormkv"
"github.com/cosmos/cosmos-sdk/orm/types/kv"
"github.com/cosmos/cosmos-sdk/orm/types/ormerrors"
)
// autoIncrementTable is a Table implementation for tables with an
// auto-incrementing uint64 primary key.
type autoIncrementTable struct {
// tableImpl is embedded; record writes and generic entry encoding are
// delegated to it.
*tableImpl
// autoIncField describes the uint64 message field that carries the
// auto-generated ID. A value of 0 in this field is treated as "no ID
// assigned yet" by save.
autoIncField protoreflect.FieldDescriptor
// seqCodec encodes and decodes the stored sequence value, i.e. the last
// ID handed out, under the key returned by seqCodec.Prefix().
seqCodec *ormkv.SeqCodec
}
// InsertReturningID stores message as a new record and returns the ID that
// was generated for it. The message must not already carry an ID; save
// rejects that case in insert mode.
func (t autoIncrementTable) InsertReturningID(ctx context.Context, message proto.Message) (newId uint64, err error) {
	writeBackend, backendErr := t.getWriteBackend(ctx)
	if backendErr != nil {
		return 0, backendErr
	}

	return t.save(ctx, writeBackend, message, saveModeInsert)
}
// Save inserts or updates message. save picks the concrete mode: a zero ID
// results in an insert with a freshly generated ID, a non-zero ID in an
// update.
func (t autoIncrementTable) Save(ctx context.Context, message proto.Message) error {
	writeBackend, backendErr := t.getWriteBackend(ctx)
	if backendErr != nil {
		return backendErr
	}

	// The generated ID is not part of this method's contract, so drop it.
	if _, saveErr := t.save(ctx, writeBackend, message, saveModeDefault); saveErr != nil {
		return saveErr
	}
	return nil
}
// Insert stores message as a new record, generating its ID. It fails if the
// message already has a non-zero ID (see save).
func (t autoIncrementTable) Insert(ctx context.Context, message proto.Message) error {
	writeBackend, backendErr := t.getWriteBackend(ctx)
	if backendErr != nil {
		return backendErr
	}

	// Callers that need the generated ID use InsertReturningID instead.
	if _, saveErr := t.save(ctx, writeBackend, message, saveModeInsert); saveErr != nil {
		return saveErr
	}
	return nil
}
// Update overwrites an existing record. It fails if the message has no ID
// set, i.e. its auto-increment field is 0 (see save).
func (t autoIncrementTable) Update(ctx context.Context, message proto.Message) error {
	writeBackend, backendErr := t.getWriteBackend(ctx)
	if backendErr != nil {
		return backendErr
	}

	if _, saveErr := t.save(ctx, writeBackend, message, saveModeUpdate); saveErr != nil {
		return saveErr
	}
	return nil
}
// save writes message through a batch writer created on backend, resolving
// the effective save mode from the message's auto-increment field:
//
//   - field == 0: the record is new. saveModeUpdate is rejected with
//     ormerrors.PrimaryKeyInvalidOnUpdate; otherwise the next sequence value
//     is generated, stored in the message, and the record is inserted.
//   - field != 0: the record is expected to exist. saveModeInsert is rejected
//     with ormerrors.AutoIncrementKeyAlreadySet; otherwise the record is
//     updated.
//
// newId is the generated ID for inserts and 0 for updates. On any error newId
// is 0 and the message's auto-increment field is left as the caller passed it.
func (t *autoIncrementTable) save(ctx context.Context, backend Backend, message proto.Message, mode saveMode) (newId uint64, err error) {
	messageRef := message.ProtoReflect()
	val := messageRef.Get(t.autoIncField).Uint()
	writer := newBatchIndexCommitmentWriter(backend)
	defer writer.Close()

	if val != 0 {
		if mode == saveModeInsert {
			return 0, ormerrors.AutoIncrementKeyAlreadySet
		}
		return 0, t.tableImpl.doSave(ctx, writer, message, saveModeUpdate)
	}

	if mode == saveModeUpdate {
		return 0, ormerrors.PrimaryKeyInvalidOnUpdate
	}

	newId, err = t.nextSeqValue(writer.IndexStore())
	if err != nil {
		return 0, err
	}
	messageRef.Set(t.autoIncField, protoreflect.ValueOfUint64(newId))

	if err = t.tableImpl.doSave(ctx, writer, message, saveModeInsert); err != nil {
		// Undo the mutation of the caller's message: leaving the unsaved ID
		// in place would make a retry with the same message fail with
		// AutoIncrementKeyAlreadySet (or be treated as an update).
		messageRef.Clear(t.autoIncField)
		return 0, err
	}
	return newId, nil
}
// curSeqValue reads the stored sequence value from kv, using the sequence
// codec's prefix as the key, and decodes it with the sequence codec.
func (t *autoIncrementTable) curSeqValue(kv kv.ReadonlyStore) (uint64, error) {
	seqKey := t.seqCodec.Prefix()

	raw, err := kv.Get(seqKey)
	if err != nil {
		return 0, err
	}

	seq, err := t.seqCodec.DecodeValue(raw)
	return seq, err
}
// nextSeqValue increments the stored sequence value by one, writes the new
// value back to kv and returns it.
func (t *autoIncrementTable) nextSeqValue(kv kv.Store) (uint64, error) {
	current, err := t.curSeqValue(kv)
	if err != nil {
		return 0, err
	}

	next := current + 1
	err = t.setSeqValue(kv, next)
	return next, err
}
// setSeqValue stores seq in kv as the current sequence value, keyed by the
// sequence codec's prefix.
func (t *autoIncrementTable) setSeqValue(kv kv.Store, seq uint64) error {
	seqKey := t.seqCodec.Prefix()
	encoded := t.seqCodec.EncodeValue(seq)
	return kv.Set(seqKey, encoded)
}
// EncodeEntry encodes entry into a key/value pair. Sequence entries are
// handled by the sequence codec; every other entry type is delegated to the
// embedded tableImpl.
func (t autoIncrementTable) EncodeEntry(entry ormkv.Entry) (k, v []byte, err error) {
	switch entry.(type) {
	case *ormkv.SeqEntry:
		return t.seqCodec.EncodeEntry(entry)
	default:
		return t.tableImpl.EncodeEntry(entry)
	}
}
// ValidateJSON checks the JSON read from reader without writing anything (a
// nil backend is passed to the decoder). Each message's ID must not exceed
// the sequence number found in the JSON. The message is then checked by the
// table's custom JSON validator if one is configured, otherwise by
// DefaultJSONValidator.
func (t autoIncrementTable) ValidateJSON(reader io.Reader) error {
	checkMessage := func(message proto.Message, maxID uint64) error {
		id := message.ProtoReflect().Get(t.autoIncField).Uint()
		if id > maxID {
			return fmt.Errorf("invalid ID %d, expected a value <= %d, the highest sequence number", id, maxID)
		}

		if t.customJSONValidator == nil {
			return DefaultJSONValidator(message)
		}
		return t.customJSONValidator(message)
	}

	return t.decodeAutoIncJson(nil, reader, checkMessage)
}
// ImportJSON reads records from reader and writes them to the write backend
// obtained from ctx. Messages without an ID get one generated on insert.
// Messages with an ID keep it, provided it does not exceed the sequence
// number found in the JSON.
func (t autoIncrementTable) ImportJSON(ctx context.Context, reader io.Reader) error {
	backend, err := t.getWriteBackend(ctx)
	if err != nil {
		return err
	}

	importMessage := func(message proto.Message, maxID uint64) error {
		id := message.ProtoReflect().Get(t.autoIncField).Uint()
		switch {
		case id == 0:
			// No ID in the JSON: go through the auto-increment save path so
			// that a new ID is generated.
			_, saveErr := t.save(ctx, backend, message, saveModeInsert)
			return saveErr
		case id > maxID:
			return fmt.Errorf("invalid ID %d, expected a value <= %d, the highest sequence number", id, maxID)
		default:
			// The ID is already set, which the auto-increment save path
			// rejects for inserts. Insert one level down via tableImpl,
			// which has no notion of auto-incrementing IDs.
			return t.tableImpl.save(ctx, backend, message, saveModeInsert)
		}
	}

	return t.decodeAutoIncJson(backend, reader, importMessage)
}
// decodeAutoIncJson decodes the JSON read from reader and calls onMsg for
// every decoded message, passing the sequence number found in the JSON as
// maxID.
//
// The first JSON element is tried as a bare unsigned integer. If it parses,
// it is taken as the sequence number and, when backend is non-nil, written
// to the backend's index store. If it does not parse, it is left to be
// handled as a regular message and maxID stays 0.
//
// backend is nil during validation, in which case nothing is written.
//
// Errors from persisting the sequence number are returned to the caller
// rather than raised as a panic.
func (t autoIncrementTable) decodeAutoIncJson(backend Backend, reader io.Reader, onMsg func(message proto.Message, maxID uint64) error) error {
	decoder, err := t.startDecodeJson(reader)
	if err != nil {
		return err
	}

	var (
		seq    uint64
		seqErr error // failure while persisting seq; the bool callback cannot return it
	)

	err = t.doDecodeJson(decoder,
		func(message json.RawMessage) bool {
			if json.Unmarshal(message, &seq) != nil {
				// Not a sequence number.
				return false
			}

			// backend is nil during validation
			if backend != nil {
				writer := newBatchIndexCommitmentWriter(backend)
				defer writer.Close()

				if seqErr = t.setSeqValue(writer.IndexStore(), seq); seqErr == nil {
					seqErr = writer.Write()
				}
			}

			// The element was a sequence number even if storing it failed;
			// the failure is reported through seqErr.
			return true
		},
		func(message proto.Message) error {
			if seqErr != nil {
				// Stop at the first message instead of importing records
				// against a sequence value that was never stored.
				return seqErr
			}
			return onMsg(message, seq)
		})

	// seqErr happened first, so it takes precedence over any later error.
	if seqErr != nil {
		return seqErr
	}
	return err
}
// ExportJSON writes the table's contents to writer as a JSON array. After
// the opening "[", the current sequence value is written as the first
// element when it is non-zero. The remaining output is produced by
// doExportJSON, which is told through its last argument whether the array
// is still empty (true) or already has the sequence element (false).
func (t autoIncrementTable) ExportJSON(ctx context.Context, writer io.Writer) error {
	backend, err := t.getBackend(ctx)
	if err != nil {
		return err
	}

	if _, err = writer.Write([]byte("[")); err != nil {
		return err
	}

	seq, err := t.curSeqValue(backend.IndexStoreReader())
	if err != nil {
		return err
	}

	if seq == 0 {
		// No sequence value to emit; the array is still empty.
		return t.doExportJSON(ctx, writer, true)
	}

	encodedSeq, err := json.Marshal(seq)
	if err != nil {
		return err
	}
	if _, err = writer.Write(encodedSeq); err != nil {
		return err
	}

	return t.doExportJSON(ctx, writer, false)
}
// GetTable returns this table when message has the same fully-qualified
// protobuf name as the table's message type, and nil otherwise.
func (t *autoIncrementTable) GetTable(message proto.Message) Table {
	got := message.ProtoReflect().Descriptor().FullName()
	want := t.MessageType().Descriptor().FullName()
	if got != want {
		return nil
	}
	return t
}
// Compile-time check that *autoIncrementTable implements AutoIncrementTable.
var _ AutoIncrementTable = (*autoIncrementTable)(nil)