cosmos-sdk/orm/internal/codegen/query.go

320 lines
9.1 KiB
Go

package codegen
import (
"bytes"
"fmt"
"os"
ormv1 "cosmossdk.io/api/cosmos/orm/v1"
"github.com/iancoleman/strcase"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"github.com/cosmos/cosmos-sdk/orm/internal/fieldnames"
)
// queryProtoGen generates a proto3 file declaring a query service, and its
// request/response messages, for the ORM tables and singletons in the
// embedded protogen.File.
//
// Methods use value receivers, but every field is a reference type (pointer
// or map), so output accumulated by one method is visible to the others.
type queryProtoGen struct {
	*protogen.File
	// imports is the set of .proto import paths the generated file needs;
	// gen emits them in sorted order.
	imports map[string]bool
	// svc accumulates the service definition (rpc declarations).
	svc *writer
	// msgs accumulates the request/response message declarations.
	msgs *writer
	// outFile receives the rendered .proto source; gen writes to it and closes it.
	outFile *os.File
}
// gen renders the query service and its request/response messages for every
// ORM table and singleton declared in g.File, then writes the resulting
// proto3 source to g.outFile.
//
// g.outFile is closed before gen returns on every path, including when
// generation or writing fails. If everything else succeeded, an error from
// Close is returned so that a failed flush is not silently ignored.
func (g queryProtoGen) gen() (err error) {
	defer func() {
		// Always release the output file, but keep the first (more
		// informative) error if one already occurred.
		if closeErr := g.outFile.Close(); err == nil {
			err = closeErr
		}
	}()

	// The generated query file references the table messages, so it must
	// import the file that declares them.
	g.imports[g.Desc.Path()] = true

	g.svc.F("// %s queries the state of the tables specified by %s.", g.queryServiceName(), g.Desc.Path())
	g.svc.F("service %s {", g.queryServiceName())
	g.svc.Indent()
	for _, msg := range g.Messages {
		opts := msg.Desc.Options()

		tableDesc := proto.GetExtension(opts, ormv1.E_Table).(*ormv1.TableDescriptor)
		if tableDesc != nil {
			if err = g.genTableRPCMethods(msg, tableDesc); err != nil {
				return err
			}
		}

		singletonDesc := proto.GetExtension(opts, ormv1.E_Singleton).(*ormv1.SingletonDescriptor)
		if singletonDesc != nil {
			if err = g.genSingletonRPCMethods(msg); err != nil {
				return err
			}
		}
	}
	g.svc.Dedent()
	g.svc.F("}")
	g.svc.F("")

	outBuf := newWriter()
	outBuf.F("// Code generated by protoc-gen-go-cosmos-orm-proto. DO NOT EDIT.")
	outBuf.F(`syntax = "proto3";`)
	outBuf.F("package %s;", g.Desc.Package())
	outBuf.F("")

	// Sort imports so the generated output is deterministic.
	imports := maps.Keys(g.imports)
	slices.Sort(imports)
	for _, i := range imports {
		outBuf.F(`import "%s";`, i)
	}
	outBuf.F("")

	if _, err = outBuf.Write(g.svc.Bytes()); err != nil {
		return err
	}
	if _, err = outBuf.Write(g.msgs.Bytes()); err != nil {
		return err
	}
	_, err = g.outFile.Write(outBuf.Bytes())
	return err
}
// genTableRPCMethods writes the query RPCs and request/response messages for
// one ORM table:
//   - Get<Table>, looking a row up by its primary key;
//   - Get<Table>By<Fields>, one per unique secondary index;
//   - List<Table>, running paginated prefix or range queries against any index.
//
// It returns an error if the table has no primary key, or if the primary key
// or any index names a field that msg does not declare.
func (g queryProtoGen) genTableRPCMethods(msg *protogen.Message, desc *ormv1.TableDescriptor) error {
	if desc.PrimaryKey == nil {
		return fmt.Errorf("table %s has no primary key", msg.Desc.Name())
	}
	if err := g.genTableGetMethod(msg, desc); err != nil {
		return err
	}
	if err := g.genTableUniqueIndexMethods(msg, desc); err != nil {
		return err
	}
	return g.genTableListMethod(msg, desc)
}

// genTableGetMethod writes the Get<Table> RPC. Its request carries one field
// per primary key column, and its response carries the matching row.
func (g queryProtoGen) genTableGetMethod(msg *protogen.Message, desc *ormv1.TableDescriptor) error {
	name := msg.Desc.Name()
	fields := msg.Desc.Fields()

	g.svc.F("// Get%s queries the %s table by its primary key.", name, name)
	g.svc.F("rpc Get%s (Get%sRequest) returns (Get%sResponse) {}", name, name, name) // TODO grpc gateway

	g.startRequestType("Get%sRequest", name)
	g.msgs.Indent()
	primaryKeyFields := fieldnames.CommaSeparatedFieldNames(desc.PrimaryKey.Fields)
	for i, fieldName := range primaryKeyFields.Names() {
		field := fields.ByName(fieldName)
		if field == nil {
			return fmt.Errorf("can't find primary key field %s", fieldName)
		}
		g.msgs.F("// %s specifies the value of the %s field in the primary key.", fieldName, fieldName)
		g.msgs.F("%s %s = %d;", g.fieldType(field), fieldName, i+1)
	}
	g.msgs.Dedent()
	g.msgs.F("}")
	g.msgs.F("")

	g.startResponseType("Get%sResponse", name)
	g.msgs.Indent()
	g.msgs.F("// value is the response value.")
	g.msgs.F("%s value = 1;", name)
	g.msgs.Dedent()
	g.msgs.F("}")
	g.msgs.F("")
	return nil
}

// genTableUniqueIndexMethods writes one Get<Table>By<Fields> RPC per unique
// secondary index. Non-unique indexes are skipped; they are only reachable
// through List<Table>.
func (g queryProtoGen) genTableUniqueIndexMethods(msg *protogen.Message, desc *ormv1.TableDescriptor) error {
	name := msg.Desc.Name()
	fields := msg.Desc.Fields()

	for _, idx := range desc.Index {
		if !idx.Unique {
			continue
		}
		fieldsCamel := fieldsToCamelCase(idx.Fields)
		methodName := fmt.Sprintf("Get%sBy%s", name, fieldsCamel)
		g.svc.F("// %s queries the %s table by its %s index", methodName, name, fieldsCamel)
		g.svc.F("rpc %s (%sRequest) returns (%sResponse) {}", methodName, methodName, methodName) // TODO grpc gateway

		g.startRequestType("%sRequest", methodName)
		g.msgs.Indent()
		indexFieldNames := fieldnames.CommaSeparatedFieldNames(idx.Fields)
		for i, fieldName := range indexFieldNames.Names() {
			field := fields.ByName(fieldName)
			if field == nil {
				return fmt.Errorf("can't find unique index field %s", fieldName)
			}
			g.msgs.F("%s %s = %d;", g.fieldType(field), fieldName, i+1)
		}
		g.msgs.Dedent()
		g.msgs.F("}")
		g.msgs.F("")

		g.startResponseType("%sResponse", methodName)
		g.msgs.Indent()
		g.msgs.F("%s value = 1;", name)
		g.msgs.Dedent()
		g.msgs.F("}")
		g.msgs.F("")
	}
	return nil
}

// genTableListMethod writes the List<Table> RPC. Its request message nests:
//   - an IndexKey message, which is a oneof over the primary key and every
//     secondary index, plus one nested message per index whose fields are all
//     optional so that a key prefix can be given;
//   - a RangeQuery message.
//
// The caller chooses a prefix or range query and may add pagination. The
// response carries the matching rows and a page response.
func (g queryProtoGen) genTableListMethod(msg *protogen.Message, desc *ormv1.TableDescriptor) error {
	name := msg.Desc.Name()
	fields := msg.Desc.Fields()

	g.imports["cosmos/base/query/v1beta1/pagination.proto"] = true
	g.svc.F("// List%s queries the %s table using prefix and range queries against defined indexes.", name, name)
	g.svc.F("rpc List%s (List%sRequest) returns (List%sResponse) {}", name, name, name) // TODO grpc gateway
	g.startRequestType("List%sRequest", name)
	g.msgs.Indent()
	g.msgs.F("// IndexKey specifies the value of an index key to use in prefix and range queries.")
	g.msgs.F("message IndexKey {")
	g.msgs.Indent()

	// indexFields[i] is the comma-separated field list of the i-th index, and
	// fieldNums[i] is its field number in the IndexKey oneof. The primary key
	// takes 1 and each secondary index takes its id + 1.
	indexFields := make([]string, 0, len(desc.Index)+1)
	fieldNums := make([]uint32, 0, len(desc.Index)+1)
	indexFields = append(indexFields, desc.PrimaryKey.Fields)
	fieldNums = append(fieldNums, 1)
	for _, index := range desc.Index {
		indexFields = append(indexFields, index.Fields)
		fieldNums = append(fieldNums, index.Id+1)
	}

	g.msgs.F("// key specifies the index key value.")
	g.msgs.F("oneof key {")
	g.msgs.Indent()
	for i, idxFields := range indexFields {
		fieldName := fieldsToSnakeCase(idxFields)
		typeName := fieldsToCamelCase(idxFields)
		g.msgs.F("// %s specifies the value of the %s index key to use in the query.", fieldName, typeName)
		g.msgs.F("%s %s = %d;", typeName, fieldName, fieldNums[i])
	}
	g.msgs.Dedent()
	g.msgs.F("}")

	for _, idxFields := range indexFields {
		g.msgs.F("")
		g.msgs.F("message %s {", fieldsToCamelCase(idxFields))
		g.msgs.Indent()
		for i, fieldName := range fieldnames.CommaSeparatedFieldNames(idxFields).Names() {
			// Non-unique secondary index fields are not checked elsewhere in
			// this generator. Report a bad field name as an error instead of
			// letting fieldType panic on a nil descriptor.
			field := fields.ByName(fieldName)
			if field == nil {
				return fmt.Errorf("can't find index field %s", fieldName)
			}
			g.msgs.F("// %s is the value of the %s field in the index.", fieldName, fieldName)
			g.msgs.F("// It can be omitted to query for all valid values of that field in this segment of the index.")
			g.msgs.F("optional %s %s = %d;", g.fieldType(field), fieldName, i+1)
		}
		g.msgs.Dedent()
		g.msgs.F("}")
	}
	g.msgs.Dedent()
	g.msgs.F("}")
	g.msgs.F("")

	g.msgs.F("// query specifies the type of query - either a prefix or range query.")
	g.msgs.F("oneof query {")
	g.msgs.Indent()
	g.msgs.F("// prefix_query specifies the index key value to use for the prefix query.")
	g.msgs.F("IndexKey prefix_query = 1;")
	g.msgs.F("// range_query specifies the index key from/to values to use for the range query.")
	g.msgs.F("RangeQuery range_query = 2;")
	g.msgs.Dedent()
	g.msgs.F("}")
	g.msgs.F("// pagination specifies optional pagination parameters.")
	g.msgs.F("cosmos.base.query.v1beta1.PageRequest pagination = 3;")
	g.msgs.F("")

	g.msgs.F("// RangeQuery specifies the from/to index keys for a range query.")
	g.msgs.F("message RangeQuery {")
	g.msgs.Indent()
	g.msgs.F("// from is the index key to use for the start of the range query.")
	g.msgs.F("// To query from the start of an index, specify an index key for that index with empty values.")
	g.msgs.F("IndexKey from = 1;")
	g.msgs.F("// to is the index key to use for the end of the range query.")
	g.msgs.F("// The index key type MUST be the same as the index key type used for from.")
	g.msgs.F("// To query from to the end of an index it can be omitted.")
	g.msgs.F("IndexKey to = 2;")
	g.msgs.Dedent()
	g.msgs.F("}")
	g.msgs.Dedent()
	g.msgs.F("}")
	g.msgs.F("")

	g.startResponseType("List%sResponse", name)
	g.msgs.Indent()
	g.msgs.F("// values are the results of the query.")
	g.msgs.F("repeated %s values = 1;", name)
	g.msgs.F("// pagination is the pagination response.")
	g.msgs.F("cosmos.base.query.v1beta1.PageResponse pagination = 2;")
	g.msgs.Dedent()
	g.msgs.F("}")
	g.msgs.F("")
	return nil
}
// genSingletonRPCMethods writes a Get<Name> RPC for an ORM singleton. It
// also writes the RPC's empty request message and a response message that
// holds the singleton value. It currently never returns an error.
func (g queryProtoGen) genSingletonRPCMethods(msg *protogen.Message) error {
	name := msg.Desc.Name()
	g.svc.F("// Get%s queries the %s singleton.", name, name)
	g.svc.F("rpc Get%s (Get%sRequest) returns (Get%sResponse) {}", name, name, name) // TODO grpc gateway

	// A singleton has no key, so the request message has no fields.
	g.startRequestType("Get%sRequest", name)
	g.msgs.F("}")
	g.msgs.F("")

	// Use startResponseType so the generated doc comment labels this
	// message as a response type rather than a request type.
	g.startResponseType("Get%sResponse", name)
	g.msgs.Indent()
	g.msgs.F("// value is the response value.")
	g.msgs.F("%s value = 1;", name)
	g.msgs.Dedent()
	g.msgs.F("}")
	g.msgs.F("")
	return nil
}
// startRequestType opens the declaration of an RPC request message whose
// name is fmt.Sprintf(format, args...). The caller writes the fields and the
// closing brace.
func (g queryProtoGen) startRequestType(format string, args ...any) {
	const role = "request"
	g.startRequestResponseType(role, format, args...)
}
// startResponseType opens the declaration of an RPC response message whose
// name is fmt.Sprintf(format, args...). The caller writes the fields and the
// closing brace.
func (g queryProtoGen) startResponseType(format string, args ...any) {
	const role = "response"
	g.startRequestResponseType(role, format, args...)
}
// startRequestResponseType writes a doc comment and the opening
// "message <Name> {" line to g.msgs. The message name is
// fmt.Sprintf(format, args...), and typ is the word used to describe the
// message in the comment (the callers pass "request" or "response").
func (g queryProtoGen) startRequestResponseType(typ string, format string, args ...any) {
	typeName := fmt.Sprintf(format, args...)
	serviceName := g.queryServiceName()
	g.msgs.F("// %s is the %s/%s %s type.", typeName, serviceName, typeName, typ)
	g.msgs.F("message %s {", typeName)
}
// queryServiceName returns the name of the generated service: the file's
// short name converted to CamelCase, followed by "Query".
func (g queryProtoGen) queryServiceName() string {
	base := strcase.ToCamel(fileShortName(g.File))
	return base + "Query"
}
// fieldType returns the proto type to use when re-declaring the field
// described by descriptor in the generated query file.
//
// Message and enum fields are referenced by their fully-qualified type name.
// The file declaring that type is recorded in g.imports so the generated file
// imports it. Other fields use the kind's proto keyword (e.g. "uint64",
// "string").
func (g queryProtoGen) fieldType(descriptor protoreflect.FieldDescriptor) string {
	switch descriptor.Kind() {
	case protoreflect.MessageKind:
		message := descriptor.Message()
		g.imports[message.ParentFile().Path()] = true
		return string(message.FullName())
	case protoreflect.EnumKind:
		// Kind.String() would yield the literal "enum", which is not a valid
		// field type, so reference the enum type itself instead.
		enum := descriptor.Enum()
		g.imports[enum.ParentFile().Path()] = true
		return string(enum.FullName())
	default:
		return descriptor.Kind().String()
	}
}
// writer accumulates generated text line by line, prefixing each line with
// the current indentation.
type writer struct {
	*bytes.Buffer
	// indent is the current nesting depth, changed by Indent and Dedent.
	indent int
	// indentStr is written before every line: one space per level of indent,
	// or empty when indent is zero or negative.
	indentStr string
}

// newWriter returns a writer with an empty buffer and no indentation.
func newWriter() *writer {
	buf := new(bytes.Buffer)
	return &writer{Buffer: buf}
}

// F writes one line: the current indentation, then format expanded with args
// (as by fmt.Sprintf), then a newline. It panics if the buffer reports a
// write error.
func (w *writer) F(format string, args ...interface{}) {
	line := w.indentStr + fmt.Sprintf(format, args...) + "\n"
	if _, err := w.WriteString(line); err != nil {
		panic(err)
	}
}

// Indent increases the nesting depth for subsequent lines by one level.
func (w *writer) Indent() {
	w.indent++
	w.updateIndent()
}

// updateIndent recomputes indentStr from indent.
func (w *writer) updateIndent() {
	if w.indent <= 0 {
		w.indentStr = ""
		return
	}
	w.indentStr = string(bytes.Repeat([]byte{' '}, w.indent))
}

// Dedent decreases the nesting depth for subsequent lines by one level.
func (w *writer) Dedent() {
	w.indent--
	w.updateIndent()
}