diff --git a/orm/model/ormdb/file.go b/orm/model/ormdb/file.go index ea0bb073e..3dcf1e481 100644 --- a/orm/model/ormdb/file.go +++ b/orm/model/ormdb/file.go @@ -2,7 +2,6 @@ package ormdb import ( "bytes" - "context" "encoding/binary" "math" @@ -20,12 +19,11 @@ import ( ) type fileDescriptorDBOptions struct { - Prefix []byte - ID uint32 - TypeResolver ormtable.TypeResolver - JSONValidator func(proto.Message) error - GetBackend func(context.Context) (ormtable.Backend, error) - GetReadBackend func(context.Context) (ormtable.ReadBackend, error) + Prefix []byte + ID uint32 + TypeResolver ormtable.TypeResolver + JSONValidator func(proto.Message) error + BackendResolver ormtable.BackendResolver } type fileDescriptorDB struct { @@ -63,12 +61,11 @@ func newFileDescriptorDB(fileDescriptor protoreflect.FileDescriptor, options fil } table, err := ormtable.Build(ormtable.Options{ - Prefix: prefix, - MessageType: messageType, - TypeResolver: resolver, - JSONValidator: options.JSONValidator, - GetReadBackend: options.GetReadBackend, - GetBackend: options.GetBackend, + Prefix: prefix, + MessageType: messageType, + TypeResolver: resolver, + JSONValidator: options.JSONValidator, + BackendResolver: options.BackendResolver, }) if err != nil { return nil, err diff --git a/orm/model/ormdb/module.go b/orm/model/ormdb/module.go index ee260ec23..dfdc63f72 100644 --- a/orm/model/ormdb/module.go +++ b/orm/model/ormdb/module.go @@ -6,6 +6,10 @@ import ( "encoding/binary" "math" + "google.golang.org/protobuf/reflect/protoregistry" + + ormv1alpha1 "github.com/cosmos/cosmos-sdk/api/cosmos/orm/v1alpha1" + "github.com/cosmos/cosmos-sdk/orm/types/ormjson" "google.golang.org/protobuf/reflect/protodesc" @@ -21,17 +25,6 @@ import ( "github.com/cosmos/cosmos-sdk/orm/types/ormerrors" ) -// ModuleSchema describes the ORM schema for a module. -type ModuleSchema struct { - // FileDescriptors are the file descriptors that contain ORM tables to use in this schema. - // Each file descriptor must have an unique non-zero uint32 ID associated with it. - FileDescriptors map[uint32]protoreflect.FileDescriptor - - // Prefix is an optional prefix to prepend to all keys. It is recommended - // to leave it empty. - Prefix []byte -} - // ModuleDB defines the ORM database type to be used by modules. type ModuleDB interface { ormtable.Schema @@ -74,17 +67,13 @@ type ModuleDBOptions struct { // will be used JSONValidator func(proto.Message) error - // GetBackend is the function used to retrieve the table backend. - // See ormtable.Options.GetBackend for more details. - GetBackend func(context.Context) (ormtable.Backend, error) - - // GetReadBackend is the function used to retrieve a table read backend. - // See ormtable.Options.GetReadBackend for more details. - GetReadBackend func(context.Context) (ormtable.ReadBackend, error) + // GetBackendResolver returns a backend resolver for the requested storage + // type or an error if this type of storage isn't supported. + GetBackendResolver func(ormv1alpha1.StorageType) (ormtable.BackendResolver, error) } // NewModuleDB constructs a ModuleDB instance from the provided schema and options. -func NewModuleDB(schema ModuleSchema, options ModuleDBOptions) (ModuleDB, error) { +func NewModuleDB(schema *ormv1alpha1.ModuleSchemaDescriptor, options ModuleDBOptions) (ModuleDB, error) { prefix := schema.Prefix db := &moduleDB{ prefix: prefix, @@ -92,29 +81,37 @@ func NewModuleDB(schema ModuleSchema, options ModuleDBOptions) (ModuleDB, error) tablesByName: map[protoreflect.FullName]ormtable.Table{}, } - for id, fileDescriptor := range schema.FileDescriptors { + fileResolver := options.FileResolver + if fileResolver == nil { + fileResolver = protoregistry.GlobalFiles + } + + for _, entry := range schema.SchemaFile { + var backendResolver ormtable.BackendResolver + var err error + if options.GetBackendResolver != nil { + backendResolver, err = options.GetBackendResolver(entry.StorageType) + if err != nil { + return nil, err + } + } + + id := entry.Id + fileDescriptor, err := fileResolver.FindFileByPath(entry.ProtoFileName) + if err != nil { + return nil, err + } + if id == 0 { return nil, ormerrors.InvalidFileDescriptorID.Wrapf("for %s", fileDescriptor.Path()) } opts := fileDescriptorDBOptions{ - ID: id, - Prefix: prefix, - TypeResolver: options.TypeResolver, - JSONValidator: options.JSONValidator, - GetBackend: options.GetBackend, - GetReadBackend: options.GetReadBackend, - } - - if options.FileResolver != nil { - // if a FileResolver is provided, we use that to resolve the file - // and not the one provided as a different pinned file descriptor - // may have been provided - var err error - fileDescriptor, err = options.FileResolver.FindFileByPath(fileDescriptor.Path()) - if err != nil { - return nil, err - } + ID: id, + Prefix: prefix, + TypeResolver: options.TypeResolver, + JSONValidator: options.JSONValidator, + BackendResolver: backendResolver, } fdSchema, err := newFileDescriptorDB(fileDescriptor, opts) diff --git a/orm/model/ormdb/module_test.go b/orm/model/ormdb/module_test.go index e70f4e244..d71504551 100644 --- a/orm/model/ormdb/module_test.go +++ b/orm/model/ormdb/module_test.go @@ -8,11 +8,12 @@ import ( "strings" "testing" + ormv1alpha1 "github.com/cosmos/cosmos-sdk/api/cosmos/orm/v1alpha1" + "github.com/golang/mock/gomock" "github.com/cosmos/cosmos-sdk/orm/testing/ormmocks" - "google.golang.org/protobuf/reflect/protoreflect" "gotest.tools/v3/assert" "gotest.tools/v3/golden" @@ -28,9 +29,12 @@ import ( // These tests use a simulated bank keeper. Addresses and balances use // string and uint64 types respectively for simplicity. -var TestBankSchema = ormdb.ModuleSchema{ - FileDescriptors: map[uint32]protoreflect.FileDescriptor{ - 1: testpb.File_testpb_bank_proto, +var TestBankSchema = &ormv1alpha1.ModuleSchemaDescriptor{ + SchemaFile: []*ormv1alpha1.ModuleSchemaDescriptor_FileEntry{ + { + Id: 1, + ProtoFileName: testpb.File_testpb_bank_proto.Path(), + }, }, } @@ -333,3 +337,33 @@ func TestHooks(t *testing.T) { ) assert.NilError(t, k.Burn(ctx, acct1, denom, 5)) } + +func TestGetBackendResolver(t *testing.T) { + backend := ormtest.NewMemoryBackend() + getResolver := func(storageType ormv1alpha1.StorageType) (ormtable.BackendResolver, error) { + switch storageType { + case ormv1alpha1.StorageType_STORAGE_TYPE_MEMORY: + return func(ctx context.Context) (ormtable.ReadBackend, error) { + return backend, nil + }, nil + default: + return nil, fmt.Errorf("storage type %s unsupported", storageType) + } + } + _, err := ormdb.NewModuleDB(TestBankSchema, ormdb.ModuleDBOptions{ + GetBackendResolver: getResolver, + }) + assert.ErrorContains(t, err, "unsupported") + + _, err = ormdb.NewModuleDB(&ormv1alpha1.ModuleSchemaDescriptor{SchemaFile: []*ormv1alpha1.ModuleSchemaDescriptor_FileEntry{ + { + Id: 1, + ProtoFileName: testpb.File_testpb_bank_proto.Path(), + StorageType: ormv1alpha1.StorageType_STORAGE_TYPE_MEMORY, + }, + }, + }, ormdb.ModuleDBOptions{ + GetBackendResolver: getResolver, + }) + assert.NilError(t, err) +} diff --git a/orm/model/ormtable/auto_increment.go b/orm/model/ormtable/auto_increment.go index 7eeb32f2d..40f625b73 100644 --- a/orm/model/ormtable/auto_increment.go +++ b/orm/model/ormtable/auto_increment.go @@ -23,7 +23,7 @@ type autoIncrementTable struct { } func (t autoIncrementTable) InsertReturningID(ctx context.Context, message proto.Message) (newId uint64, err error) { - backend, err := t.getBackend(ctx) + backend, err := t.getWriteBackend(ctx) if err != nil { return 0, err } @@ -32,7 +32,7 @@ func (t autoIncrementTable) InsertReturningID(ctx context.Context, message proto } func (t autoIncrementTable) Save(ctx context.Context, message proto.Message) error { - backend, err := t.getBackend(ctx) + backend, err := t.getWriteBackend(ctx) if err != nil { return err } @@ -42,7 +42,7 @@ func (t autoIncrementTable) Save(ctx context.Context, message proto.Message) err } func (t autoIncrementTable) Insert(ctx context.Context, message proto.Message) error { - backend, err := t.getBackend(ctx) + backend, err := t.getWriteBackend(ctx) if err != nil { return err } @@ -52,7 +52,7 @@ func (t autoIncrementTable) Insert(ctx context.Context, message proto.Message) e } func (t autoIncrementTable) Update(ctx context.Context, message proto.Message) error { - backend, err := t.getBackend(ctx) + backend, err := t.getWriteBackend(ctx) if err != nil { return err } @@ -137,7 +137,7 @@ func (t autoIncrementTable) ValidateJSON(reader io.Reader) error { } func (t autoIncrementTable) ImportJSON(ctx context.Context, reader io.Reader) error { - backend, err := t.getBackend(ctx) + backend, err := t.getWriteBackend(ctx) if err != nil { return err } diff --git a/orm/model/ormtable/backend.go b/orm/model/ormtable/backend.go index 7d6967a74..8013309d0 100644 --- a/orm/model/ormtable/backend.go +++ b/orm/model/ormtable/backend.go @@ -2,6 +2,7 @@ package ormtable import ( "context" + "fmt" "github.com/cosmos/cosmos-sdk/orm/types/kv" ) @@ -33,11 +34,13 @@ type Backend interface { // ValidateHooks returns a ValidateHooks instance or nil. ValidateHooks() ValidateHooks - // WithValidateHooks returns a copy of this backend with the provided hooks. + // WithValidateHooks returns a copy of this backend with the provided validate hooks. WithValidateHooks(ValidateHooks) Backend + // WriteHooks returns a WriteHooks instance of nil. WriteHooks() WriteHooks + // WithWriteHooks returns a copy of this backend with the provided write hooks. WithWriteHooks(WriteHooks) Backend } @@ -160,6 +163,11 @@ func NewBackend(options BackendOptions) Backend { } } +// BackendResolver resolves a backend from the context or returns an error. +// Callers should type cast the returned ReadBackend to Backend to test whether +// the backend is writable. +type BackendResolver func(context.Context) (ReadBackend, error) + // WrapContextDefault performs the default wrapping of a backend in a context. // This should be used primarily for testing purposes and production code // should use some other framework specific wrapping (for instance using @@ -172,10 +180,16 @@ type contextKeyType string var defaultContextKey = contextKeyType("backend") -func getBackendDefault(ctx context.Context) (Backend, error) { - return ctx.Value(defaultContextKey).(Backend), nil -} +func getBackendDefault(ctx context.Context) (ReadBackend, error) { + value := ctx.Value(defaultContextKey) + if value == nil { + return nil, fmt.Errorf("can't resolve backend") + } -func getReadBackendDefault(ctx context.Context) (ReadBackend, error) { - return ctx.Value(defaultContextKey).(ReadBackend), nil + backend, ok := value.(ReadBackend) + if !ok { + return nil, fmt.Errorf("expected value of type %T, instead got %T", backend, value) + } + + return backend, nil } diff --git a/orm/model/ormtable/build.go b/orm/model/ormtable/build.go index 9e3d7e30d..fcf198888 100644 --- a/orm/model/ormtable/build.go +++ b/orm/model/ormtable/build.go @@ -1,7 +1,6 @@ package ormtable import ( - "context" "fmt" "github.com/cosmos/cosmos-sdk/orm/internal/fieldnames" @@ -53,18 +52,15 @@ type Options struct { // will be used JSONValidator func(proto.Message) error - // GetBackend is an optional function which retrieves a Backend from the context. + // BackendResolver is an optional function which retrieves a Backend from the context. // If it is nil, the default behavior will be to attempt to retrieve a // backend using the method that WrapContextDefault uses. This method - // can be used to imlement things like "store keys" which would allow a + // can be used to implement things like "store keys" which would allow a // table to only be used with a specific backend and to hide direct // access to the backend other than through the table interface. - GetBackend func(context.Context) (Backend, error) - - // GetReadBackend is an optional function which retrieves a ReadBackend from the context. - // If it is nil, the default behavior will be to attempt to retrieve a - // backend using the method that WrapContextDefault uses. - GetReadBackend func(context.Context) (ReadBackend, error) + // Mutating operations will attempt to cast ReadBackend to Backend and + // will return an error if that fails. + BackendResolver BackendResolver } // TypeResolver is an interface that can be used for the protoreflect.UnmarshalOptions.Resolver option. @@ -77,20 +73,15 @@ type TypeResolver interface { func Build(options Options) (Table, error) { messageDescriptor := options.MessageType.Descriptor() - getReadBackend := options.GetReadBackend - if getReadBackend == nil { - getReadBackend = getReadBackendDefault - } - getBackend := options.GetBackend - if getBackend == nil { - getBackend = getBackendDefault + backendResolver := options.BackendResolver + if backendResolver == nil { + backendResolver = getBackendDefault } table := &tableImpl{ primaryKeyIndex: &primaryKeyIndex{ - indexers: []indexer{}, - getBackend: getBackend, - getReadBackend: getReadBackend, + indexers: []indexer{}, + getBackend: backendResolver, }, indexes: []Index{}, indexesByFields: map[fieldnames.FieldNames]concreteIndex{}, @@ -213,7 +204,7 @@ func Build(options Options) (Table, error) { UniqueKeyCodec: uniqCdc, fields: idxFields, primaryKey: pkIndex, - getReadBackend: getReadBackend, + getReadBackend: backendResolver, } table.uniqueIndexesByFields[idxFields] = uniqIdx index = uniqIdx @@ -231,7 +222,7 @@ func Build(options Options) (Table, error) { IndexKeyCodec: idxCdc, fields: idxFields, primaryKey: pkIndex, - getReadBackend: getReadBackend, + getReadBackend: backendResolver, } // non-unique indexes can sometimes be named by several sub-lists of diff --git a/orm/model/ormtable/primary_key.go b/orm/model/ormtable/primary_key.go index 8c8b8b8a3..0ce628528 100644 --- a/orm/model/ormtable/primary_key.go +++ b/orm/model/ormtable/primary_key.go @@ -3,6 +3,8 @@ package ormtable import ( "context" + "github.com/cosmos/cosmos-sdk/orm/types/ormerrors" + "github.com/cosmos/cosmos-sdk/orm/internal/fieldnames" "github.com/cosmos/cosmos-sdk/orm/model/ormlist" @@ -18,14 +20,13 @@ import ( // primaryKeyIndex defines an UniqueIndex for the primary key. type primaryKeyIndex struct { *ormkv.PrimaryKeyCodec - fields fieldnames.FieldNames - indexers []indexer - getBackend func(context.Context) (Backend, error) - getReadBackend func(context.Context) (ReadBackend, error) + fields fieldnames.FieldNames + indexers []indexer + getBackend func(context.Context) (ReadBackend, error) } func (p primaryKeyIndex) List(ctx context.Context, prefixKey []interface{}, options ...ormlist.Option) (Iterator, error) { - backend, err := p.getReadBackend(ctx) + backend, err := p.getBackend(ctx) if err != nil { return nil, err } @@ -34,7 +35,7 @@ func (p primaryKeyIndex) List(ctx context.Context, prefixKey []interface{}, opti } func (p primaryKeyIndex) ListRange(ctx context.Context, from, to []interface{}, options ...ormlist.Option) (Iterator, error) { - backend, err := p.getReadBackend(ctx) + backend, err := p.getBackend(ctx) if err != nil { return nil, err } @@ -45,7 +46,7 @@ func (p primaryKeyIndex) ListRange(ctx context.Context, from, to []interface{}, func (p primaryKeyIndex) doNotImplement() {} func (p primaryKeyIndex) Has(ctx context.Context, key ...interface{}) (found bool, err error) { - backend, err := p.getReadBackend(ctx) + backend, err := p.getBackend(ctx) if err != nil { return false, err } @@ -63,7 +64,7 @@ func (p primaryKeyIndex) has(backend ReadBackend, values []protoreflect.Value) ( } func (p primaryKeyIndex) Get(ctx context.Context, message proto.Message, values ...interface{}) (found bool, err error) { - backend, err := p.getReadBackend(ctx) + backend, err := p.getBackend(ctx) if err != nil { return false, err } @@ -102,8 +103,21 @@ func (p primaryKeyIndex) DeleteRange(ctx context.Context, from, to []interface{} return p.deleteByIterator(ctx, it) } -func (p primaryKeyIndex) doDelete(ctx context.Context, primaryKeyValues []protoreflect.Value) error { +func (p primaryKeyIndex) getWriteBackend(ctx context.Context) (Backend, error) { backend, err := p.getBackend(ctx) + if err != nil { + return nil, err + } + + if writeBackend, ok := backend.(Backend); ok { + return writeBackend, nil + } + + return nil, ormerrors.ReadOnly +} + +func (p primaryKeyIndex) doDelete(ctx context.Context, primaryKeyValues []protoreflect.Value) error { + backend, err := p.getWriteBackend(ctx) if err != nil { return err } @@ -190,7 +204,7 @@ func (p primaryKeyIndex) Fields() string { } func (p primaryKeyIndex) deleteByIterator(ctx context.Context, it Iterator) error { - backend, err := p.getBackend(ctx) + backend, err := p.getWriteBackend(ctx) if err != nil { return err } diff --git a/orm/model/ormtable/singleton.go b/orm/model/ormtable/singleton.go index 3c7418492..ec3dc3cab 100644 --- a/orm/model/ormtable/singleton.go +++ b/orm/model/ormtable/singleton.go @@ -44,7 +44,7 @@ func (t singleton) ValidateJSON(reader io.Reader) error { } func (t singleton) ImportJSON(ctx context.Context, reader io.Reader) error { - backend, err := t.getBackend(ctx) + backend, err := t.getWriteBackend(ctx) if err != nil { return err } diff --git a/orm/model/ormtable/table_impl.go b/orm/model/ormtable/table_impl.go index b0a5d6f49..34d75aa74 100644 --- a/orm/model/ormtable/table_impl.go +++ b/orm/model/ormtable/table_impl.go @@ -48,7 +48,7 @@ func (t tableImpl) GetIndexByID(id uint32) Index { } func (t tableImpl) Save(ctx context.Context, message proto.Message) error { - backend, err := t.getBackend(ctx) + backend, err := t.getWriteBackend(ctx) if err != nil { return err } @@ -57,7 +57,7 @@ func (t tableImpl) Save(ctx context.Context, message proto.Message) error { } func (t tableImpl) Insert(ctx context.Context, message proto.Message) error { - backend, err := t.getBackend(ctx) + backend, err := t.getWriteBackend(ctx) if err != nil { return err } @@ -66,7 +66,7 @@ func (t tableImpl) Insert(ctx context.Context, message proto.Message) error { } func (t tableImpl) Update(ctx context.Context, message proto.Message) error { - backend, err := t.getBackend(ctx) + backend, err := t.getWriteBackend(ctx) if err != nil { return err } @@ -287,7 +287,7 @@ func (t tableImpl) ValidateJSON(reader io.Reader) error { } func (t tableImpl) ImportJSON(ctx context.Context, reader io.Reader) error { - backend, err := t.getBackend(ctx) + backend, err := t.getWriteBackend(ctx) if err != nil { return err } @@ -392,7 +392,7 @@ func (t tableImpl) ID() uint32 { } func (t tableImpl) Has(ctx context.Context, message proto.Message) (found bool, err error) { - backend, err := t.getReadBackend(ctx) + backend, err := t.getBackend(ctx) if err != nil { return false, err } @@ -405,7 +405,7 @@ func (t tableImpl) Has(ctx context.Context, message proto.Message) (found bool, // set on the message. Other fields besides the primary key fields will not // be used for retrieval. func (t tableImpl) Get(ctx context.Context, message proto.Message) (found bool, err error) { - backend, err := t.getReadBackend(ctx) + backend, err := t.getBackend(ctx) if err != nil { return false, err } diff --git a/orm/model/ormtable/table_test.go b/orm/model/ormtable/table_test.go index 78e81a85d..864fa02ec 100644 --- a/orm/model/ormtable/table_test.go +++ b/orm/model/ormtable/table_test.go @@ -8,6 +8,8 @@ import ( "strings" "testing" + dbm "github.com/tendermint/tm-db" + "github.com/cosmos/cosmos-sdk/orm/types/kv" "google.golang.org/protobuf/proto" @@ -695,3 +697,16 @@ func protoValuesToInterfaces(ks []protoreflect.Value) []interface{} { return values } + +func TestReadonly(t *testing.T) { + table, err := ormtable.Build(ormtable.Options{ + MessageType: (&testpb.ExampleTable{}).ProtoReflect().Type(), + }) + assert.NilError(t, err) + readBackend := ormtable.NewReadBackend(ormtable.ReadBackendOptions{ + CommitmentStoreReader: dbm.NewMemDB(), + IndexStoreReader: dbm.NewMemDB(), + }) + ctx := ormtable.WrapContextDefault(readBackend) + assert.ErrorIs(t, ormerrors.ReadOnly, table.Insert(ctx, &testpb.ExampleTable{})) +} diff --git a/orm/types/ormerrors/errors.go b/orm/types/ormerrors/errors.go index 2c38a6412..fb2d597d3 100644 --- a/orm/types/ormerrors/errors.go +++ b/orm/types/ormerrors/errors.go @@ -39,4 +39,5 @@ var ( TableNotFound = errors.New(codespace, 27, "table not found") JSONValidationError = errors.New(codespace, 28, "invalid JSON") NotFound = errors.New(codespace, 29, "not found") + ReadOnly = errors.New(codespace, 30, "database is read-only") )