package solana import ( "errors" "fmt" "sort" "github.com/davecgh/go-spew/spew" bin "github.com/gagliardetto/binary" "github.com/gagliardetto/solana-go/text" "github.com/gagliardetto/treeout" "go.uber.org/zap" ) type Instruction interface { ProgramID() PublicKey // the programID the instruction acts on Accounts() []*AccountMeta // returns the list of accounts the instructions requires Data() ([]byte, error) // the binary encoded instructions } type TransactionOption interface { apply(opts *transactionOptions) } type transactionOptions struct { payer PublicKey } type transactionOptionFunc func(opts *transactionOptions) func (f transactionOptionFunc) apply(opts *transactionOptions) { f(opts) } func TransactionPayer(payer PublicKey) TransactionOption { return transactionOptionFunc(func(opts *transactionOptions) { opts.payer = payer }) } var debugNewTransaction = false type TransactionBuilder struct { instructions []Instruction recentBlockHash Hash opts []TransactionOption } // NewTransactionBuilder creates a new instruction builder. func NewTransactionBuilder() *TransactionBuilder { return &TransactionBuilder{} } // AddInstruction adds the provided instruction to the builder. func (builder *TransactionBuilder) AddInstruction(instruction Instruction) *TransactionBuilder { builder.instructions = append(builder.instructions, instruction) return builder } // SetRecentBlockHash sets the recent blockhash for the instruction builder. func (builder *TransactionBuilder) SetRecentBlockHash(recentBlockHash Hash) *TransactionBuilder { builder.recentBlockHash = recentBlockHash return builder } // WithOpt adds a TransactionOption. func (builder *TransactionBuilder) WithOpt(opt TransactionOption) *TransactionBuilder { builder.opts = append(builder.opts, opt) return builder } // Set transaction fee payer. // If not set, defaults to first signer account of the first instruction. func (builder *TransactionBuilder) SetFeePayer(feePayer PublicKey) *TransactionBuilder { builder.opts = append(builder.opts, TransactionPayer(feePayer)) return builder } // Build builds and returns a *Transaction. func (builder *TransactionBuilder) Build() (*Transaction, error) { return NewTransaction( builder.instructions, builder.recentBlockHash, builder.opts..., ) } func NewTransaction(instructions []Instruction, recentBlockHash Hash, opts ...TransactionOption) (*Transaction, error) { if len(instructions) == 0 { return nil, fmt.Errorf("requires at-least one instruction to create a transaction") } options := transactionOptions{} for _, opt := range opts { opt.apply(&options) } feePayer := options.payer if feePayer.IsZero() { found := false for _, act := range instructions[0].Accounts() { if act.IsSigner { feePayer = act.PublicKey found = true break } } if !found { return nil, fmt.Errorf("cannot determine fee payer. You can ether pass the fee payer via the 'TransactionWithInstructions' option parameter or it falls back to the first instruction's first signer") } } programIDs := make(PublicKeySlice, 0) accounts := []*AccountMeta{} for _, instruction := range instructions { for _, key := range instruction.Accounts() { accounts = append(accounts, key) } programIDs.UniqueAppend(instruction.ProgramID()) } // Add programID to the account list for _, programID := range programIDs { accounts = append(accounts, &AccountMeta{ PublicKey: programID, IsSigner: false, IsWritable: false, }) } // Sort. Prioritizing first by signer, then by writable sort.Slice(accounts, func(i, j int) bool { return accounts[i].less(accounts[j]) }) uniqAccountsMap := map[PublicKey]uint64{} uniqAccounts := []*AccountMeta{} for _, acc := range accounts { if index, found := uniqAccountsMap[acc.PublicKey]; found { uniqAccounts[index].IsWritable = uniqAccounts[index].IsWritable || acc.IsWritable continue } uniqAccounts = append(uniqAccounts, acc) uniqAccountsMap[acc.PublicKey] = uint64(len(uniqAccounts) - 1) } if debugNewTransaction { zlog.Debug("unique account sorted", zap.Int("account_count", len(uniqAccounts))) } // Move fee payer to the front feePayerIndex := -1 for idx, acc := range uniqAccounts { if acc.PublicKey.Equals(feePayer) { feePayerIndex = idx } } if debugNewTransaction { zlog.Debug("current fee payer index", zap.Int("fee_payer_index", feePayerIndex)) } accountCount := len(uniqAccounts) if feePayerIndex < 0 { // fee payer is not part of accounts we want to add it accountCount++ } finalAccounts := make([]*AccountMeta, accountCount) itr := 1 for idx, uniqAccount := range uniqAccounts { if idx == feePayerIndex { uniqAccount.IsSigner = true uniqAccount.IsWritable = true finalAccounts[0] = uniqAccount continue } finalAccounts[itr] = uniqAccount itr++ } if feePayerIndex < 0 { // fee payer is not part of accounts we want to add it feePayerAccount := &AccountMeta{ PublicKey: feePayer, IsSigner: true, IsWritable: true, } finalAccounts[0] = feePayerAccount } message := Message{ RecentBlockhash: recentBlockHash, } accountKeyIndex := map[string]uint16{} for idx, acc := range finalAccounts { if debugNewTransaction { zlog.Debug("transaction account", zap.Int("account_index", idx), zap.Stringer("account_pub_key", acc.PublicKey), ) } message.AccountKeys = append(message.AccountKeys, acc.PublicKey) accountKeyIndex[acc.PublicKey.String()] = uint16(idx) if acc.IsSigner { message.Header.NumRequiredSignatures++ if !acc.IsWritable { message.Header.NumReadonlySignedAccounts++ } continue } if !acc.IsWritable { message.Header.NumReadonlyUnsignedAccounts++ } } if debugNewTransaction { zlog.Debug("message header compiled", zap.Uint8("num_required_signatures", message.Header.NumRequiredSignatures), zap.Uint8("num_readonly_signed_accounts", message.Header.NumReadonlySignedAccounts), zap.Uint8("num_readonly_unsigned_accounts", message.Header.NumReadonlyUnsignedAccounts), ) } for txIdx, instruction := range instructions { accounts = instruction.Accounts() accountIndex := make([]uint16, len(accounts)) for idx, acc := range accounts { accountIndex[idx] = accountKeyIndex[acc.PublicKey.String()] } data, err := instruction.Data() if err != nil { return nil, fmt.Errorf("unable to encode instructions [%d]: %w", txIdx, err) } message.Instructions = append(message.Instructions, CompiledInstruction{ ProgramIDIndex: accountKeyIndex[instruction.ProgramID().String()], AccountCount: bin.Varuint16(uint16(len(accountIndex))), Accounts: accountIndex, DataLength: bin.Varuint16(uint16(len(data))), Data: data, }) } return &Transaction{ Message: message, }, nil } type privateKeyGetter func(key PublicKey) *PrivateKey func (tx *Transaction) MarshalBinary() ([]byte, error) { if len(tx.Signatures) == 0 || len(tx.Signatures) != int(tx.Message.Header.NumRequiredSignatures) { return nil, errors.New("signature verification failed") } messageContent, err := tx.Message.MarshalBinary() if err != nil { return nil, fmt.Errorf("failed to encode tx.Message to binary: %w", err) } var signatureCount []byte bin.EncodeCompactU16Length(&signatureCount, len(tx.Signatures)) output := make([]byte, 0, len(signatureCount)+len(signatureCount)*64+len(messageContent)) output = append(output, signatureCount...) for _, sig := range tx.Signatures { output = append(output, sig[:]...) } output = append(output, messageContent...) return output, nil } func (tx *Transaction) MarshalWithEncoder(encoder *bin.Encoder) error { out, err := tx.MarshalBinary() if err != nil { return err } return encoder.WriteBytes(out, false) } func (tx *Transaction) UnmarshalWithDecoder(decoder *bin.Decoder) (err error) { { numSignatures, err := bin.DecodeCompactU16LengthFromByteReader(decoder) if err != nil { return err } for i := 0; i < numSignatures; i++ { sigBytes, err := decoder.ReadNBytes(64) if err != nil { return err } var sig Signature copy(sig[:], sigBytes) tx.Signatures = append(tx.Signatures, sig) } } { err := tx.Message.UnmarshalWithDecoder(decoder) if err != nil { return err } } return nil } func (tx *Transaction) Sign(getter privateKeyGetter) (out []Signature, err error) { messageContent, err := tx.Message.MarshalBinary() if err != nil { return nil, fmt.Errorf("unable to encode message for signing: %w", err) } signerKeys := tx.Message.signerKeys() for _, key := range signerKeys { privateKey := getter(key) if privateKey == nil { return nil, fmt.Errorf("signer key %q not found. Ensure all the signer keys are in the vault", key.String()) } s, err := privateKey.Sign(messageContent) if err != nil { return nil, fmt.Errorf("failed to signed with key %q: %w", key.String(), err) } tx.Signatures = append(tx.Signatures, s) } return tx.Signatures, nil } func (tx *Transaction) EncodeTree(encoder *text.TreeEncoder) (int, error) { if len(encoder.Docs) == 0 { encoder.Docs = []string{"Transaction"} } tx.EncodeToTree(encoder) return encoder.WriteString(encoder.Tree.String()) } func (tx *Transaction) EncodeToTree(parent treeout.Branches) { parent.ParentFunc(func(txTree treeout.Branches) { txTree.Child("Signatures[]").ParentFunc(func(signaturesBranch treeout.Branches) { for _, sig := range tx.Signatures { signaturesBranch.Child(sig.String()) } }) txTree.Child("Message").ParentFunc(func(messageBranch treeout.Branches) { tx.Message.EncodeToTree(messageBranch) }) }) parent.Child("Instructions[]").ParentFunc(func(message treeout.Branches) { for _, inst := range tx.Message.Instructions { progKey, err := tx.ResolveProgramIDIndex(inst.ProgramIDIndex) if err == nil { decodedInstruction, err := DecodeInstruction(progKey, inst.ResolveInstructionAccounts(&tx.Message), inst.Data) if err == nil { if enToTree, ok := decodedInstruction.(text.EncodableToTree); ok { enToTree.EncodeToTree(message) } else { message.Child(spew.Sdump(decodedInstruction)) } } else { // TODO: log error? message.Child(fmt.Sprintf(text.RedBG("cannot decode instruction for %s program: %s"), progKey, err)) } } else { message.Child(fmt.Sprintf(text.RedBG("cannot ResolveProgramIDIndex: %s"), err)) } } }) }