package codegen import ( "fmt" "strings" "google.golang.org/protobuf/compiler/protogen" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/types/dynamicpb" ormv1alpha1 "github.com/cosmos/cosmos-sdk/api/cosmos/orm/v1alpha1" "github.com/cosmos/cosmos-sdk/orm/internal/fieldnames" "github.com/cosmos/cosmos-sdk/orm/model/ormtable" ) type tableGen struct { fileGen msg *protogen.Message table *ormv1alpha1.TableDescriptor primaryKeyFields fieldnames.FieldNames fields map[protoreflect.Name]*protogen.Field uniqueIndexes []*ormv1alpha1.SecondaryIndexDescriptor ormTable ormtable.Table } func newTableGen(fileGen fileGen, msg *protogen.Message, table *ormv1alpha1.TableDescriptor) (*tableGen, error) { t := &tableGen{fileGen: fileGen, msg: msg, table: table, fields: map[protoreflect.Name]*protogen.Field{}} t.primaryKeyFields = fieldnames.CommaSeparatedFieldNames(table.PrimaryKey.Fields) for _, field := range msg.Fields { t.fields[field.Desc.Name()] = field } uniqIndexes := make([]*ormv1alpha1.SecondaryIndexDescriptor, 0) for _, idx := range t.table.Index { if idx.Unique { uniqIndexes = append(uniqIndexes, idx) } } t.uniqueIndexes = uniqIndexes var err error t.ormTable, err = ormtable.Build(ormtable.Options{ MessageType: dynamicpb.NewMessageType(msg.Desc), TableDescriptor: table, }) return t, err } func (t tableGen) gen() { t.genStoreInterface() t.genIterator() t.genIndexKeys() t.genStruct() t.genStoreImpl() t.genStoreImplGuard() t.genConstructor() } func (t tableGen) genStoreInterface() { t.P("type ", t.messageStoreInterfaceName(t.msg), " interface {") t.P("Insert(ctx ", contextPkg.Ident("Context"), ", ", t.param(t.msg.GoIdent.GoName), " *", t.QualifiedGoIdent(t.msg.GoIdent), ") error") if t.table.PrimaryKey.AutoIncrement { t.P("InsertReturningID(ctx ", contextPkg.Ident("Context"), ", ", t.param(t.msg.GoIdent.GoName), " *", t.QualifiedGoIdent(t.msg.GoIdent), ") (uint64, error)") } t.P("Update(ctx ", contextPkg.Ident("Context"), ", ", t.param(t.msg.GoIdent.GoName), " *", t.QualifiedGoIdent(t.msg.GoIdent), ") error") t.P("Save(ctx ", contextPkg.Ident("Context"), ", ", t.param(t.msg.GoIdent.GoName), " *", t.QualifiedGoIdent(t.msg.GoIdent), ") error") t.P("Delete(ctx ", contextPkg.Ident("Context"), ", ", t.param(t.msg.GoIdent.GoName), " *", t.QualifiedGoIdent(t.msg.GoIdent), ") error") t.P("Has(ctx ", contextPkg.Ident("Context"), ", ", t.fieldsArgs(t.primaryKeyFields.Names()), ") (found bool, err error)") t.P("Get(ctx ", contextPkg.Ident("Context"), ", ", t.fieldsArgs(t.primaryKeyFields.Names()), ") (*", t.QualifiedGoIdent(t.msg.GoIdent), ", error)") for _, idx := range t.uniqueIndexes { t.genUniqueIndexSig(idx) } t.P("List(ctx ", contextPkg.Ident("Context"), ", prefixKey ", t.indexKeyInterfaceName(), ", opts ...", ormListPkg.Ident("Option"), ") ", "(", t.iteratorName(), ", error)") t.P("ListRange(ctx ", contextPkg.Ident("Context"), ", from, to ", t.indexKeyInterfaceName(), ", opts ...", ormListPkg.Ident("Option"), ") ", "(", t.iteratorName(), ", error)") t.P() t.P("doNotImplement()") t.P("}") t.P() } // returns the has and get (in that order) function signature for unique indexes. func (t tableGen) uniqueIndexSig(idx *ormv1alpha1.SecondaryIndexDescriptor) (string, string) { fieldsSlc := strings.Split(idx.Fields, ",") camelFields := t.fieldsToCamelCase(idx.Fields) hasFuncName := "HasBy" + camelFields getFuncName := "GetBy" + camelFields args := t.fieldArgsFromStringSlice(fieldsSlc) hasFuncSig := fmt.Sprintf("%s (ctx context.Context, %s) (found bool, err error)", hasFuncName, args) getFuncSig := fmt.Sprintf("%s (ctx context.Context, %s) (*%s, error)", getFuncName, args, t.msg.GoIdent.GoName) return hasFuncSig, getFuncSig } func (t tableGen) genUniqueIndexSig(idx *ormv1alpha1.SecondaryIndexDescriptor) { hasSig, getSig := t.uniqueIndexSig(idx) t.P(hasSig) t.P(getSig) } func (t tableGen) iteratorName() string { return t.msg.GoIdent.GoName + "Iterator" } func (t tableGen) getSig() string { res := "Get" + t.msg.GoIdent.GoName + "(" res += t.fieldsArgs(t.primaryKeyFields.Names()) res += ") (*" + t.QualifiedGoIdent(t.msg.GoIdent) + ", error)" return res } func (t tableGen) hasSig() string { t.P("Has(ctx ", contextPkg.Ident("Context"), ", ", t.fieldsArgs(t.primaryKeyFields.Names()), ") (found bool, err error)") return "" } func (t tableGen) listSig() string { res := "List" + t.msg.GoIdent.GoName + "(" res += t.indexKeyInterfaceName() res += ") (" res += t.iteratorName() res += ", error)" return res } func (t tableGen) fieldArgsFromStringSlice(names []string) string { args := make([]string, len(names)) for i, name := range names { args[i] = t.fieldArg(protoreflect.Name(name)) } return strings.Join(args, ",") } func (t tableGen) fieldsArgs(names []protoreflect.Name) string { var params []string for _, name := range names { params = append(params, t.fieldArg(name)) } return strings.Join(params, ",") } func (t tableGen) fieldArg(name protoreflect.Name) string { typ, pointer := t.GeneratedFile.FieldGoType(t.fields[name]) if pointer { typ = "*" + typ } return string(name) + " " + typ } func (t tableGen) genStruct() { t.P("type ", t.messageStoreReceiverName(t.msg), " struct {") if t.table.PrimaryKey.AutoIncrement { t.P("table ", ormTablePkg.Ident("AutoIncrementTable")) } else { t.P("table ", ormTablePkg.Ident("Table")) } t.P("}") t.storeStructName() } func (t tableGen) genStoreImpl() { receiverVar := "this" receiver := fmt.Sprintf("func (%s %s) ", receiverVar, t.messageStoreReceiverName(t.msg)) varName := t.param(t.msg.GoIdent.GoName) varTypeName := t.QualifiedGoIdent(t.msg.GoIdent) // these methods all have the same impl sans their names. so we can just loop and replace. methods := []string{"Insert", "Update", "Save", "Delete"} for _, method := range methods { t.P(receiver, method, "(ctx ", contextPkg.Ident("Context"), ", ", varName, " *", varTypeName, ") error {") t.P("return ", receiverVar, ".table.", method, "(ctx, ", varName, ")") t.P("}") t.P() } if t.table.PrimaryKey.AutoIncrement { t.P(receiver, "InsertReturningID(ctx ", contextPkg.Ident("Context"), ", ", varName, " *", varTypeName, ") (uint64, error) {") t.P("return ", receiverVar, ".table.InsertReturningID(ctx, ", varName, ")") t.P("}") t.P() } // Has t.P(receiver, "Has(ctx ", contextPkg.Ident("Context"), ", ", t.fieldsArgs(t.primaryKeyFields.Names()), ") (found bool, err error) {") t.P("return ", receiverVar, ".table.PrimaryKey().Has(ctx, ", t.primaryKeyFields.String(), ")") t.P("}") t.P() // Get t.P(receiver, "Get(ctx ", contextPkg.Ident("Context"), ", ", t.fieldsArgs(t.primaryKeyFields.Names()), ") (*", varTypeName, ", error) {") t.P("var ", varName, " ", varTypeName) t.P("found, err := ", receiverVar, ".table.PrimaryKey().Get(ctx, &", varName, ", ", t.primaryKeyFields.String(), ")") t.P("if !found {") t.P("return nil, err") t.P("}") t.P("return &", varName, ", err") t.P("}") t.P() for _, idx := range t.uniqueIndexes { fields := strings.Split(idx.Fields, ",") hasName, getName := t.uniqueIndexSig(idx) // has t.P("func (", receiverVar, " ", t.messageStoreReceiverName(t.msg), ") ", hasName, "{") t.P("return ", receiverVar, ".table.GetIndexByID(", idx.Id, ").(", ormTablePkg.Ident("UniqueIndex"), ").Has(ctx,") for _, field := range fields { t.P(field, ",") } t.P(")") t.P("}") t.P() // get varName := t.param(t.msg.GoIdent.GoName) varTypeName := t.msg.GoIdent.GoName t.P("func (", receiverVar, " ", t.messageStoreReceiverName(t.msg), ") ", getName, "{") t.P("var ", varName, " ", varTypeName) t.P("found, err := ", receiverVar, ".table.GetIndexByID(", idx.Id, ").(", ormTablePkg.Ident("UniqueIndex"), ").Get(ctx, &", varName, ",") for _, field := range fields { t.P(field, ",") } t.P(")") t.P("if !found {") t.P("return nil, err") t.P("}") t.P("return &", varName, ", nil") t.P("}") t.P() } // List t.P(receiver, "List(ctx ", contextPkg.Ident("Context"), ", prefixKey ", t.indexKeyInterfaceName(), ", opts ...", ormListPkg.Ident("Option"), ") (", t.iteratorName(), ", error) {") t.P("opts = append(opts, ", ormListPkg.Ident("Prefix"), "(prefixKey.values()...))") t.P("it, err := ", receiverVar, ".table.GetIndexByID(prefixKey.id()).Iterator(ctx, opts...)") t.P("return ", t.iteratorName(), "{it}, err") t.P("}") t.P() // ListRange t.P(receiver, "ListRange(ctx ", contextPkg.Ident("Context"), ", from, to ", t.indexKeyInterfaceName(), ", opts ...", ormListPkg.Ident("Option"), ") (", t.iteratorName(), ", error) {") t.P("opts = append(opts, ", ormListPkg.Ident("Start"), "(from.values()...), ", ormListPkg.Ident("End"), "(to.values()...))") t.P("it, err := ", receiverVar, ".table.GetIndexByID(from.id()).Iterator(ctx, opts...)") t.P("return ", t.iteratorName(), "{it}, err") t.P("}") t.P() t.P(receiver, "doNotImplement() {}") t.P() } func (t tableGen) genStoreImplGuard() { t.P("var _ ", t.messageStoreInterfaceName(t.msg), " = ", t.messageStoreReceiverName(t.msg), "{}") } func (t tableGen) genConstructor() { iface := t.messageStoreInterfaceName(t.msg) t.P("func New", iface, "(db ", ormdbPkg.Ident("ModuleDB"), ") (", iface, ", error) {") t.P("table := db.GetTable(&", t.msg.GoIdent.GoName, "{})") t.P("if table == nil {") t.P("return nil,", ormErrPkg.Ident("TableNotFound.Wrap"), "(string((&", t.msg.GoIdent.GoName, "{}).ProtoReflect().Descriptor().FullName()))") t.P("}") if t.table.PrimaryKey.AutoIncrement { t.P( "return ", t.messageStoreReceiverName(t.msg), "{table.(", ormTablePkg.Ident("AutoIncrementTable"), ")}, nil", ) } else { t.P("return ", t.messageStoreReceiverName(t.msg), "{table}, nil") } t.P("}") }