cosmos-sdk/orm/encoding/ormkv/unique_key.go

205 lines
5.1 KiB
Go

package ormkv
import (
"bytes"
"io"
"github.com/cosmos/cosmos-sdk/orm/types/ormerrors"
"google.golang.org/protobuf/reflect/protoreflect"
)
// UniqueKeyCodec is the codec for unique indexes.
//
// The index fields are encoded into the store key by keyCodec. Primary key
// fields that are not already part of the index fields are encoded into the
// store value by valueCodec.
type UniqueKeyCodec struct {
// pkFieldOrder has one entry per primary key field, in primary key order,
// recording where that field's value can be found.
pkFieldOrder []struct {
// inKey is true when the field is one of the index fields (its value
// lives in the index key) and false when it lives in the value.
inKey bool
// i is the position of the field within the index key values when
// inKey is true, otherwise its position within the value values.
i int
}
// keyCodec encodes and decodes the index fields (the store key).
keyCodec *KeyCodec
// valueCodec encodes and decodes the primary key fields that are not
// index fields (the store value).
valueCodec *KeyCodec
}
// Compile-time check that *UniqueKeyCodec satisfies IndexCodec.
var _ IndexCodec = (*UniqueKeyCodec)(nil)
// NewUniqueKeyCodec builds the codec for a unique index on messageType.
//
// indexFields are encoded into the key (using prefix) and every field of
// primaryKeyFields that is not also an index field is encoded into the value.
// An InvalidTableDefinition error is returned when either field list is
// empty; errors from constructing the underlying key codecs are returned
// unchanged.
func NewUniqueKeyCodec(prefix []byte, messageType protoreflect.MessageType, indexFields, primaryKeyFields []protoreflect.Name) (*UniqueKeyCodec, error) {
	switch {
	case len(indexFields) == 0:
		return nil, ormerrors.InvalidTableDefinition.Wrapf("index fields are empty")
	case len(primaryKeyFields) == 0:
		return nil, ormerrors.InvalidTableDefinition.Wrapf("primary key fields are empty")
	}

	keyCodec, err := NewKeyCodec(prefix, messageType, indexFields)
	if err != nil {
		return nil, err
	}

	// Position of each index field within the index key, by field name.
	keyPositions := make(map[protoreflect.Name]int, len(keyCodec.fieldDescriptors))
	for pos, fd := range keyCodec.fieldDescriptors {
		keyPositions[fd.Name()] = pos
	}

	var valueFields []protoreflect.Name
	pkFieldOrder := make([]struct {
		inKey bool
		i     int
	}, 0, len(primaryKeyFields))
	for _, name := range primaryKeyFields {
		pos, inKey := keyPositions[name]
		if !inKey {
			// Not an index field: it becomes the next field of the value.
			pos = len(valueFields)
			valueFields = append(valueFields, name)
		}
		pkFieldOrder = append(pkFieldOrder, struct {
			inKey bool
			i     int
		}{inKey: inKey, i: pos})
	}

	valueCodec, err := NewKeyCodec(nil, messageType, valueFields)
	if err != nil {
		return nil, err
	}

	codec := &UniqueKeyCodec{
		pkFieldOrder: pkFieldOrder,
		keyCodec:     keyCodec,
		valueCodec:   valueCodec,
	}
	return codec, nil
}
// DecodeIndexKey decodes the index field values from k and, when k holds a
// value for every index field, reconstructs the primary key from those values
// together with the values decoded from v.
//
// When decoding k ends with io.EOF, the values returned by the key codec are
// passed back with a nil primary key and the io.EOF error. When k decodes
// without error but yields fewer values than there are index fields, those
// values are returned with a nil primary key and a nil error. In both cases v
// is not decoded. Any other decoding error is returned with nil values.
func (u UniqueKeyCodec) DecodeIndexKey(k, v []byte) (indexFields, primaryKey []protoreflect.Value, err error) {
	indexFields, err = u.keyCodec.DecodeKey(bytes.NewReader(k))
	switch {
	case err == io.EOF:
		// k is only a prefix of a full index key
		return indexFields, nil, err
	case err != nil:
		return nil, nil, err
	case len(indexFields) < len(u.keyCodec.fieldCodecs):
		// k is only a prefix of a full index key
		return indexFields, nil, nil
	}

	valueFields, err := u.valueCodec.DecodeKey(bytes.NewReader(v))
	if err != nil {
		return nil, nil, err
	}

	return indexFields, u.extractPrimaryKey(indexFields, valueFields), nil
}
// extractPrimaryKey assembles the primary key values, in primary key field
// order, by picking each one from keyValues or valueValues as recorded in
// pkFieldOrder.
func (u UniqueKeyCodec) extractPrimaryKey(keyValues, valueValues []protoreflect.Value) []protoreflect.Value {
	pk := make([]protoreflect.Value, len(u.pkFieldOrder))
	for idx, order := range u.pkFieldOrder {
		source := valueValues
		if order.inKey {
			source = keyValues
		}
		pk[idx] = source[order.i]
	}
	return pk
}
// DecodeEntry decodes the key/value pair k, v into an *IndexKeyEntry marked as
// unique. Any error from DecodeIndexKey is returned as-is with a nil entry.
func (u UniqueKeyCodec) DecodeEntry(k, v []byte) (Entry, error) {
	indexValues, primaryKey, err := u.DecodeIndexKey(k, v)
	if err != nil {
		return nil, err
	}

	entry := &IndexKeyEntry{
		TableName:   u.MessageType().Descriptor().FullName(),
		Fields:      u.keyCodec.fieldNames,
		IsUnique:    true,
		IndexValues: indexValues,
		PrimaryKey:  primaryKey,
	}
	return entry, nil
}
// EncodeEntry encodes an *IndexKeyEntry into the key and value bytes of the
// unique index.
//
// The key is the encoding of entry's IndexValues. The value is the encoding
// of those primary key values that are not already part of the index key.
// Primary key values that are also index fields are not stored again, but
// must compare equal to the corresponding index value.
//
// A BadDecodeEntry error is returned when entry is not an *IndexKeyEntry,
// when the primary key has the wrong number of values, when IndexValues is
// too short to contain an index field that the primary key refers to, or when
// a shared field differs between the primary key and the index values. Errors
// from the underlying codecs are returned unchanged.
func (u UniqueKeyCodec) EncodeEntry(entry Entry) (k, v []byte, err error) {
	indexEntry, ok := entry.(*IndexKeyEntry)
	if !ok {
		return nil, nil, ormerrors.BadDecodeEntry
	}

	k, err = u.keyCodec.EncodeKey(indexEntry.IndexValues)
	if err != nil {
		return nil, nil, err
	}

	if len(indexEntry.PrimaryKey) != len(u.pkFieldOrder) {
		return nil, nil, ormerrors.BadDecodeEntry.Wrapf("wrong primary key length")
	}

	var values []protoreflect.Value
	for i, value := range indexEntry.PrimaryKey {
		fieldOrder := u.pkFieldOrder[i]
		if !fieldOrder.inKey {
			// goes in values because it is not present in the index key otherwise
			values = append(values, value)
			continue
		}

		// Guard the lookup below: nothing visible here guarantees that
		// IndexValues holds a value for every index field, and indexing past
		// its end would panic instead of reporting a bad entry.
		if fieldOrder.i >= len(indexEntry.IndexValues) {
			return nil, nil, ormerrors.BadDecodeEntry.Wrapf("index values are missing field %d referenced by the primary key", fieldOrder.i)
		}

		// does not go in values, but we need to verify that the value in index values matches the primary key value
		if u.keyCodec.fieldCodecs[fieldOrder.i].Compare(value, indexEntry.IndexValues[fieldOrder.i]) != 0 {
			return nil, nil, ormerrors.BadDecodeEntry.Wrapf("value in primary key does not match corresponding value in index key")
		}
	}

	v, err = u.valueCodec.EncodeKey(values)
	return k, v, err
}
// EncodeKVFromMessage encodes message into the index key (via the key codec)
// and the index value (via the value codec). If encoding the key fails, nil
// key and value are returned with the error.
func (u UniqueKeyCodec) EncodeKVFromMessage(message protoreflect.Message) (k, v []byte, err error) {
	if _, k, err = u.keyCodec.EncodeKeyFromMessage(message); err != nil {
		return nil, nil, err
	}

	_, v, err = u.valueCodec.EncodeKeyFromMessage(message)
	return k, v, err
}
// GetFieldNames returns the field names reported by the index key codec.
func (u UniqueKeyCodec) GetFieldNames() []protoreflect.Name {
	names := u.keyCodec.GetFieldNames()
	return names
}
// GetKeyCodec returns the codec used for the index key.
func (u UniqueKeyCodec) GetKeyCodec() *KeyCodec {
	keyCodec := u.keyCodec
	return keyCodec
}
// GetValueCodec returns the codec used for the index value, i.e. the primary
// key fields that are not part of the index key.
func (u UniqueKeyCodec) GetValueCodec() *KeyCodec {
	valueCodec := u.valueCodec
	return valueCodec
}
// CompareKeys compares two sets of index key values by delegating to the
// index key codec and returns its result.
func (u UniqueKeyCodec) CompareKeys(key1, key2 []protoreflect.Value) int {
	result := u.keyCodec.CompareKeys(key1, key2)
	return result
}
// EncodeKeyFromMessage encodes the index key for message by delegating to the
// index key codec, returning the key values, the encoded key and any error.
func (u UniqueKeyCodec) EncodeKeyFromMessage(message protoreflect.Message) (keyValues []protoreflect.Value, key []byte, err error) {
	keyValues, key, err = u.keyCodec.EncodeKeyFromMessage(message)
	return keyValues, key, err
}
// IsFullyOrdered reports whether the index key codec is fully ordered, as
// determined by the key codec itself.
func (u UniqueKeyCodec) IsFullyOrdered() bool {
	ordered := u.keyCodec.IsFullyOrdered()
	return ordered
}
// MessageType returns the message type held by the index key codec.
func (u UniqueKeyCodec) MessageType() protoreflect.MessageType {
	keyCodec := u.keyCodec
	return keyCodec.messageType
}