From 4e8172d1a1b86bcd2427206623f07ff84c07284c Mon Sep 17 00:00:00 2001 From: Aaron Craelius Date: Tue, 1 Feb 2022 12:14:46 -0500 Subject: [PATCH] feat(orm): return newly generated ID with auto-increment tables (#11040) ## Description Adds a new interface `AutoIncrementTable` which extends `Table` and has a method `InsertWithID` which returns the newly generated ID. The new ID is also set on the message itself, but it feels like a nice improvement to have this method in real usage. --- ### Author Checklist *All items are required. Please add a note to the item if the item is not applicable and please add links to any relevant follow up issues.* I have... - [ ] included the correct [type prefix](https://github.com/commitizen/conventional-commit-types/blob/v3.0.0/index.json) in the PR title - [ ] added `!` to the type prefix if API or client breaking change - [ ] targeted the correct branch (see [PR Targeting](https://github.com/cosmos/cosmos-sdk/blob/master/CONTRIBUTING.md#pr-targeting)) - [ ] provided a link to the relevant issue or specification - [ ] followed the guidelines for [building modules](https://github.com/cosmos/cosmos-sdk/blob/master/docs/building-modules) - [ ] included the necessary unit and integration [tests](https://github.com/cosmos/cosmos-sdk/blob/master/CONTRIBUTING.md#testing) - [ ] added a changelog entry to `CHANGELOG.md` - [ ] included comments for [documenting Go code](https://blog.golang.org/godoc) - [ ] updated the relevant documentation or specification - [ ] reviewed "Files changed" and left comments if necessary - [ ] confirmed all CI checks have passed ### Reviewers Checklist *All items are required. Please add a note if the item is not applicable and please add your handle next to the items reviewed if you only reviewed selected items.* I have... - [ ] confirmed the correct [type prefix](https://github.com/commitizen/conventional-commit-types/blob/v3.0.0/index.json) in the PR title - [ ] confirmed `!` in the type prefix if API or client breaking change - [ ] confirmed all author checklist items have been addressed - [ ] reviewed state machine logic - [ ] reviewed API design and naming - [ ] reviewed documentation is accurate - [ ] reviewed tests and test coverage - [ ] manually tested (if applicable) --- orm/internal/codegen/codegen.go | 12 +++--- orm/internal/codegen/file.go | 4 -- orm/internal/codegen/index.go | 2 +- orm/internal/codegen/singleton.go | 2 +- orm/internal/codegen/table.go | 29 +++++++++++++-- orm/internal/testpb/test_schema.cosmos_orm.go | 9 ++++- orm/model/ormtable/auto_increment.go | 37 +++++++++++++------ orm/model/ormtable/auto_increment_test.go | 23 ++++++++---- orm/model/ormtable/table.go | 8 ++++ .../ormtable/testdata/test_auto_inc.golden | 23 ++++++++++++ 10 files changed, 113 insertions(+), 36 deletions(-) diff --git a/orm/internal/codegen/codegen.go b/orm/internal/codegen/codegen.go index 20ca63c33..1d52cec2e 100644 --- a/orm/internal/codegen/codegen.go +++ b/orm/internal/codegen/codegen.go @@ -13,12 +13,12 @@ import ( ) const ( - contextPkg = protogen.GoImportPath("context") - protoreflectPackage = protogen.GoImportPath("google.golang.org/protobuf/reflect/protoreflect") - ormListPkg = protogen.GoImportPath("github.com/cosmos/cosmos-sdk/orm/model/ormlist") - ormdbPkg = protogen.GoImportPath("github.com/cosmos/cosmos-sdk/orm/model/ormdb") - ormErrPkg = protogen.GoImportPath("github.com/cosmos/cosmos-sdk/orm/types/ormerrors") - fmtPkg = protogen.GoImportPath("fmt") + contextPkg = protogen.GoImportPath("context") + + ormListPkg = protogen.GoImportPath("github.com/cosmos/cosmos-sdk/orm/model/ormlist") + ormdbPkg = protogen.GoImportPath("github.com/cosmos/cosmos-sdk/orm/model/ormdb") + ormErrPkg = protogen.GoImportPath("github.com/cosmos/cosmos-sdk/orm/types/ormerrors") + ormTablePkg = protogen.GoImportPath("github.com/cosmos/cosmos-sdk/orm/model/ormtable") ) func PluginRunner(p *protogen.Plugin) error { diff --git a/orm/internal/codegen/file.go b/orm/internal/codegen/file.go index 2f2d0772f..84e482c0a 100644 --- a/orm/internal/codegen/file.go +++ b/orm/internal/codegen/file.go @@ -12,10 +12,6 @@ import ( v1alpha1 "github.com/cosmos/cosmos-sdk/api/cosmos/orm/v1alpha1" ) -var ( - tablePkg = protogen.GoImportPath("github.com/cosmos/cosmos-sdk/orm/model/ormtable") -) - type fileGen struct { *generator.GeneratedFile file *protogen.File diff --git a/orm/internal/codegen/index.go b/orm/internal/codegen/index.go index a7024c34f..61c1281e3 100644 --- a/orm/internal/codegen/index.go +++ b/orm/internal/codegen/index.go @@ -28,7 +28,7 @@ func (t tableGen) genIndexKeys() { func (t tableGen) genIterator() { t.P("type ", t.iteratorName(), " struct {") - t.P(tablePkg.Ident("Iterator")) + t.P(ormTablePkg.Ident("Iterator")) t.P("}") t.P() t.genValueFunc() diff --git a/orm/internal/codegen/singleton.go b/orm/internal/codegen/singleton.go index ddf739ed6..ce679131a 100644 --- a/orm/internal/codegen/singleton.go +++ b/orm/internal/codegen/singleton.go @@ -46,7 +46,7 @@ func (s singletonGen) genInterface() { func (s singletonGen) genStruct() { s.P("type ", s.messageStoreReceiverName(s.msg), " struct {") - s.P("table ", tablePkg.Ident("Table")) + s.P("table ", ormTablePkg.Ident("Table")) s.P("}") s.P() } diff --git a/orm/internal/codegen/table.go b/orm/internal/codegen/table.go index cd6e96a8b..93ffd20e6 100644 --- a/orm/internal/codegen/table.go +++ b/orm/internal/codegen/table.go @@ -57,6 +57,9 @@ func (t tableGen) gen() { func (t tableGen) genStoreInterface() { t.P("type ", t.messageStoreInterfaceName(t.msg), " interface {") t.P("Insert(ctx ", contextPkg.Ident("Context"), ", ", t.param(t.msg.GoIdent.GoName), " *", t.QualifiedGoIdent(t.msg.GoIdent), ") error") + if t.table.PrimaryKey.AutoIncrement { + t.P("InsertReturningID(ctx ", contextPkg.Ident("Context"), ", ", t.param(t.msg.GoIdent.GoName), " *", t.QualifiedGoIdent(t.msg.GoIdent), ") (uint64, error)") + } t.P("Update(ctx ", contextPkg.Ident("Context"), ", ", t.param(t.msg.GoIdent.GoName), " *", t.QualifiedGoIdent(t.msg.GoIdent), ") error") t.P("Save(ctx ", contextPkg.Ident("Context"), ", ", t.param(t.msg.GoIdent.GoName), " *", t.QualifiedGoIdent(t.msg.GoIdent), ") error") t.P("Delete(ctx ", contextPkg.Ident("Context"), ", ", t.param(t.msg.GoIdent.GoName), " *", t.QualifiedGoIdent(t.msg.GoIdent), ") error") @@ -144,7 +147,11 @@ func (t tableGen) fieldArg(name protoreflect.Name) string { func (t tableGen) genStruct() { t.P("type ", t.messageStoreReceiverName(t.msg), " struct {") - t.P("table ", tablePkg.Ident("Table")) + if t.table.PrimaryKey.AutoIncrement { + t.P("table ", ormTablePkg.Ident("AutoIncrementTable")) + } else { + t.P("table ", ormTablePkg.Ident("Table")) + } t.P("}") t.storeStructName() } @@ -164,6 +171,13 @@ func (t tableGen) genStoreImpl() { t.P() } + if t.table.PrimaryKey.AutoIncrement { + t.P(receiver, "InsertReturningID(ctx ", contextPkg.Ident("Context"), ", ", varName, " *", varTypeName, ") (uint64, error) {") + t.P("return ", receiverVar, ".table.InsertReturningID(ctx, ", varName, ")") + t.P("}") + t.P() + } + // Has t.P(receiver, "Has(ctx ", contextPkg.Ident("Context"), ", ", t.fieldsArgs(t.primaryKeyFields.Names()), ") (found bool, err error) {") t.P("return ", receiverVar, ".table.PrimaryKey().Has(ctx, ", t.primaryKeyFields.String(), ")") @@ -188,7 +202,7 @@ func (t tableGen) genStoreImpl() { // has t.P("func (", receiverVar, " ", t.messageStoreReceiverName(t.msg), ") ", hasName, "{") t.P("return ", receiverVar, ".table.GetIndexByID(", idx.Id, ").(", - tablePkg.Ident("UniqueIndex"), ").Has(ctx,") + ormTablePkg.Ident("UniqueIndex"), ").Has(ctx,") for _, field := range fields { t.P(field, ",") } @@ -202,7 +216,7 @@ func (t tableGen) genStoreImpl() { t.P("func (", receiverVar, " ", t.messageStoreReceiverName(t.msg), ") ", getName, "{") t.P("var ", varName, " ", varTypeName) t.P("found, err := ", receiverVar, ".table.GetIndexByID(", idx.Id, ").(", - tablePkg.Ident("UniqueIndex"), ").Get(ctx, &", varName, ",") + ormTablePkg.Ident("UniqueIndex"), ").Get(ctx, &", varName, ",") for _, field := range fields { t.P(field, ",") } @@ -246,6 +260,13 @@ func (t tableGen) genConstructor() { t.P("if table == nil {") t.P("return nil,", ormErrPkg.Ident("TableNotFound.Wrap"), "(string((&", t.msg.GoIdent.GoName, "{}).ProtoReflect().Descriptor().FullName()))") t.P("}") - t.P("return ", t.messageStoreReceiverName(t.msg), "{table}, nil") + if t.table.PrimaryKey.AutoIncrement { + t.P( + "return ", t.messageStoreReceiverName(t.msg), "{table.(", + ormTablePkg.Ident("AutoIncrementTable"), ")}, nil", + ) + } else { + t.P("return ", t.messageStoreReceiverName(t.msg), "{table}, nil") + } t.P("}") } diff --git a/orm/internal/testpb/test_schema.cosmos_orm.go b/orm/internal/testpb/test_schema.cosmos_orm.go index 573103371..14eb4f09a 100644 --- a/orm/internal/testpb/test_schema.cosmos_orm.go +++ b/orm/internal/testpb/test_schema.cosmos_orm.go @@ -200,6 +200,7 @@ func NewExampleTableStore(db ormdb.ModuleDB) (ExampleTableStore, error) { type ExampleAutoIncrementTableStore interface { Insert(ctx context.Context, exampleAutoIncrementTable *ExampleAutoIncrementTable) error + InsertReturningID(ctx context.Context, exampleAutoIncrementTable *ExampleAutoIncrementTable) (uint64, error) Update(ctx context.Context, exampleAutoIncrementTable *ExampleAutoIncrementTable) error Save(ctx context.Context, exampleAutoIncrementTable *ExampleAutoIncrementTable) error Delete(ctx context.Context, exampleAutoIncrementTable *ExampleAutoIncrementTable) error @@ -259,7 +260,7 @@ func (this ExampleAutoIncrementTableXIndexKey) WithX(x string) ExampleAutoIncrem } type exampleAutoIncrementTableStore struct { - table ormtable.Table + table ormtable.AutoIncrementTable } func (this exampleAutoIncrementTableStore) Insert(ctx context.Context, exampleAutoIncrementTable *ExampleAutoIncrementTable) error { @@ -278,6 +279,10 @@ func (this exampleAutoIncrementTableStore) Delete(ctx context.Context, exampleAu return this.table.Delete(ctx, exampleAutoIncrementTable) } +func (this exampleAutoIncrementTableStore) InsertReturningID(ctx context.Context, exampleAutoIncrementTable *ExampleAutoIncrementTable) (uint64, error) { + return this.table.InsertReturningID(ctx, exampleAutoIncrementTable) +} + func (this exampleAutoIncrementTableStore) Has(ctx context.Context, id uint64) (found bool, err error) { return this.table.PrimaryKey().Has(ctx, id) } @@ -329,7 +334,7 @@ func NewExampleAutoIncrementTableStore(db ormdb.ModuleDB) (ExampleAutoIncrementT if table == nil { return nil, ormerrors.TableNotFound.Wrap(string((&ExampleAutoIncrementTable{}).ProtoReflect().Descriptor().FullName())) } - return exampleAutoIncrementTableStore{table}, nil + return exampleAutoIncrementTableStore{table.(ormtable.AutoIncrementTable)}, nil } // singleton store diff --git a/orm/model/ormtable/auto_increment.go b/orm/model/ormtable/auto_increment.go index 5e1b6ab5b..24132d696 100644 --- a/orm/model/ormtable/auto_increment.go +++ b/orm/model/ormtable/auto_increment.go @@ -22,13 +22,23 @@ type autoIncrementTable struct { seqCodec *ormkv.SeqCodec } +func (t autoIncrementTable) InsertReturningID(ctx context.Context, message proto.Message) (newId uint64, err error) { + backend, err := t.getBackend(ctx) + if err != nil { + return 0, err + } + + return t.save(backend, message, saveModeInsert) +} + func (t autoIncrementTable) Save(ctx context.Context, message proto.Message) error { backend, err := t.getBackend(ctx) if err != nil { return err } - return t.save(backend, message, saveModeDefault) + _, err = t.save(backend, message, saveModeDefault) + return err } func (t autoIncrementTable) Insert(ctx context.Context, message proto.Message) error { @@ -37,7 +47,8 @@ func (t autoIncrementTable) Insert(ctx context.Context, message proto.Message) e return err } - return t.save(backend, message, saveModeInsert) + _, err = t.save(backend, message, saveModeInsert) + return err } func (t autoIncrementTable) Update(ctx context.Context, message proto.Message) error { @@ -46,10 +57,11 @@ func (t autoIncrementTable) Update(ctx context.Context, message proto.Message) e return err } - return t.save(backend, message, saveModeUpdate) + _, err = t.save(backend, message, saveModeUpdate) + return err } -func (t *autoIncrementTable) save(backend Backend, message proto.Message, mode saveMode) error { +func (t *autoIncrementTable) save(backend Backend, message proto.Message, mode saveMode) (newId uint64, err error) { messageRef := message.ProtoReflect() val := messageRef.Get(t.autoIncField).Uint() writer := newBatchIndexCommitmentWriter(backend) @@ -57,25 +69,25 @@ func (t *autoIncrementTable) save(backend Backend, message proto.Message, mode s if val == 0 { if mode == saveModeUpdate { - return ormerrors.PrimaryKeyInvalidOnUpdate + return 0, ormerrors.PrimaryKeyInvalidOnUpdate } mode = saveModeInsert - key, err := t.nextSeqValue(writer.IndexStore()) + newId, err = t.nextSeqValue(writer.IndexStore()) if err != nil { - return err + return 0, err } - messageRef.Set(t.autoIncField, protoreflect.ValueOfUint64(key)) + messageRef.Set(t.autoIncField, protoreflect.ValueOfUint64(newId)) } else { if mode == saveModeInsert { - return ormerrors.AutoIncrementKeyAlreadySet + return 0, ormerrors.AutoIncrementKeyAlreadySet } mode = saveModeUpdate } - return t.tableImpl.doSave(writer, message, mode) + return newId, t.tableImpl.doSave(writer, message, mode) } func (t *autoIncrementTable) curSeqValue(kv kv.ReadonlyStore) (uint64, error) { @@ -136,7 +148,8 @@ func (t autoIncrementTable) ImportJSON(ctx context.Context, reader io.Reader) er if id == 0 { // we don't have an ID in the JSON, so we call Save to insert and // generate one - return t.save(backend, message, saveModeInsert) + _, err = t.save(backend, message, saveModeInsert) + return err } else { if id > maxID { return fmt.Errorf("invalid ID %d, expected a value <= %d", id, maxID) @@ -223,3 +236,5 @@ func (t *autoIncrementTable) GetTable(message proto.Message) Table { } return nil } + +var _ AutoIncrementTable = &autoIncrementTable{} diff --git a/orm/model/ormtable/auto_increment_test.go b/orm/model/ormtable/auto_increment_test.go index c3404b338..efa54470a 100644 --- a/orm/model/ormtable/auto_increment_test.go +++ b/orm/model/ormtable/auto_increment_test.go @@ -21,8 +21,11 @@ func TestAutoIncrementScenario(t *testing.T) { }) assert.NilError(t, err) + autoTable, ok := table.(ormtable.AutoIncrementTable) + assert.Assert(t, ok) + // first run tests with a split index-commitment store - runAutoIncrementScenario(t, table, ormtable.WrapContextDefault(testkv.NewSplitMemBackend())) + runAutoIncrementScenario(t, autoTable, ormtable.WrapContextDefault(testkv.NewSplitMemBackend())) // now run with shared store and debugging debugBuf := &strings.Builder{} @@ -33,29 +36,35 @@ func TestAutoIncrementScenario(t *testing.T) { Print: func(s string) { debugBuf.WriteString(s + "\n") }, }, ) - runAutoIncrementScenario(t, table, ormtable.WrapContextDefault(store)) + runAutoIncrementScenario(t, autoTable, ormtable.WrapContextDefault(store)) golden.Assert(t, debugBuf.String(), "test_auto_inc.golden") checkEncodeDecodeEntries(t, table, store.IndexStoreReader()) } -func runAutoIncrementScenario(t *testing.T, table ormtable.Table, context context.Context) { +func runAutoIncrementScenario(t *testing.T, table ormtable.AutoIncrementTable, ctx context.Context) { store, err := testpb.NewExampleAutoIncrementTableStore(table) assert.NilError(t, err) - err = store.Save(context, &testpb.ExampleAutoIncrementTable{Id: 5}) + err = store.Save(ctx, &testpb.ExampleAutoIncrementTable{Id: 5}) assert.ErrorContains(t, err, "update") ex1 := &testpb.ExampleAutoIncrementTable{X: "foo", Y: 5} - assert.NilError(t, store.Save(context, ex1)) + assert.NilError(t, store.Save(ctx, ex1)) assert.Equal(t, uint64(1), ex1.Id) + ex2 := &testpb.ExampleAutoIncrementTable{X: "bar", Y: 10} + newId, err := table.InsertReturningID(ctx, ex2) + assert.NilError(t, err) + assert.Equal(t, uint64(2), ex2.Id) + assert.Equal(t, newId, ex2.Id) + buf := &bytes.Buffer{} - assert.NilError(t, table.ExportJSON(context, buf)) + assert.NilError(t, table.ExportJSON(ctx, buf)) assert.NilError(t, table.ValidateJSON(bytes.NewReader(buf.Bytes()))) store2 := ormtable.WrapContextDefault(testkv.NewSplitMemBackend()) assert.NilError(t, table.ImportJSON(store2, bytes.NewReader(buf.Bytes()))) - assertTablesEqual(t, table, context, store2) + assertTablesEqual(t, table, ctx, store2) } func TestBadJSON(t *testing.T) { diff --git a/orm/model/ormtable/table.go b/orm/model/ormtable/table.go index 88625077d..2fb3c07ee 100644 --- a/orm/model/ormtable/table.go +++ b/orm/model/ormtable/table.go @@ -140,3 +140,11 @@ type Schema interface { // GetTable returns the table for the provided message type or nil. GetTable(message proto.Message) Table } + +type AutoIncrementTable interface { + Table + + // InsertReturningID inserts the provided entry in the store and returns the newly + // generated ID for the message or an error. + InsertReturningID(ctx context.Context, message proto.Message) (newId uint64, err error) +} diff --git a/orm/model/ormtable/testdata/test_auto_inc.golden b/orm/model/ormtable/testdata/test_auto_inc.golden index 3f59cb478..48b05fe30 100644 --- a/orm/model/ormtable/testdata/test_auto_inc.golden +++ b/orm/model/ormtable/testdata/test_auto_inc.golden @@ -15,11 +15,28 @@ SET 0301666f6f 0000000000000001 UNIQ testpb.ExampleAutoIncrementTable x : foo -> 1 GET 03808002 01 SEQ testpb.ExampleAutoIncrementTable 1 +GET 03000000000000000002 + PK testpb.ExampleAutoIncrementTable 2 -> {"id":2} +ORM INSERT testpb.ExampleAutoIncrementTable {"id":2,"x":"bar","y":10} +HAS 0301626172 + ERR:EOF +SET 03000000000000000002 1203626172180a + PK testpb.ExampleAutoIncrementTable 2 -> {"id":2,"x":"bar","y":10} +SET 03808002 02 + SEQ testpb.ExampleAutoIncrementTable 2 +SET 0301626172 0000000000000002 + UNIQ testpb.ExampleAutoIncrementTable x : bar -> 2 +GET 03808002 02 + SEQ testpb.ExampleAutoIncrementTable 2 ITERATOR 0300 -> 0301 VALID true KEY 03000000000000000001 1203666f6f1805 PK testpb.ExampleAutoIncrementTable 1 -> {"id":1,"x":"foo","y":5} NEXT + VALID true + KEY 03000000000000000002 1203626172180a + PK testpb.ExampleAutoIncrementTable 2 -> {"id":2,"x":"bar","y":10} + NEXT VALID false ITERATOR 0300 -> 0301 VALID true @@ -28,4 +45,10 @@ ITERATOR 0300 -> 0301 KEY 03000000000000000001 1203666f6f1805 PK testpb.ExampleAutoIncrementTable 1 -> {"id":1,"x":"foo","y":5} NEXT + VALID true + KEY 03000000000000000002 1203626172180a + PK testpb.ExampleAutoIncrementTable 2 -> {"id":2,"x":"bar","y":10} + KEY 03000000000000000002 1203626172180a + PK testpb.ExampleAutoIncrementTable 2 -> {"id":2,"x":"bar","y":10} + NEXT VALID false