Fixed TokenAccount to be able to find or create the associated token account, or use an old-style account if it exists and the associated token account doesn't.

This commit is contained in:
Geoff Taylor 2021-08-17 17:52:17 +01:00
parent 6a61b55e93
commit b408874a2a
3 changed files with 65 additions and 55 deletions

View File

@ -11,7 +11,6 @@ from decimal import Decimal
from spl.token.client import Token as SplToken from spl.token.client import Token as SplToken
from spl.token.constants import TOKEN_PROGRAM_ID from spl.token.constants import TOKEN_PROGRAM_ID
from solana.publickey import PublicKey from solana.publickey import PublicKey
from spl.token.instructions import get_associated_token_address
sys.path.insert(0, os.path.abspath( sys.path.insert(0, os.path.abspath(
os.path.join(os.path.dirname(__file__), '..'))) os.path.join(os.path.dirname(__file__), '..')))
@ -45,14 +44,8 @@ if account_info is None:
if account_info.owner == mango.SYSTEM_PROGRAM_ADDRESS: if account_info.owner == mango.SYSTEM_PROGRAM_ADDRESS:
# This is a root wallet account - get the associated token account # This is a root wallet account - get the associated token account
associated_token_address = get_associated_token_address(args.address, token.mint) destination: PublicKey = mango.TokenAccount.find_or_create_token_address_to_use(
token_account: typing.Optional[mango.TokenAccount] = mango.TokenAccount.load(context, associated_token_address) context, wallet, args.address, token)
if token_account is None:
# There is no associated token account, so create it
destination = spl_token.create_associated_token_account(args.address)
else:
# The associated token account exists so use it as the destination
destination = associated_token_address
quantity = token.shift_to_native(args.quantity) quantity = token.shift_to_native(args.quantity)

View File

@ -5,7 +5,6 @@ import logging
import os import os
import os.path import os.path
import sys import sys
import traceback
import typing import typing
from decimal import Decimal from decimal import Decimal
@ -26,6 +25,8 @@ parser.add_argument("--symbol", type=str, required=True, help="token symbol to s
parser.add_argument("--address", type=PublicKey, parser.add_argument("--address", type=PublicKey,
help="Destination address for the SPL token - can be either the actual token address or the address of the owner of the token address") help="Destination address for the SPL token - can be either the actual token address or the address of the owner of the token address")
parser.add_argument("--quantity", type=Decimal, required=True, help="quantity of token to send") parser.add_argument("--quantity", type=Decimal, required=True, help="quantity of token to send")
parser.add_argument("--wait", action="store_true", default=False,
help="wait until the transaction is confirmed")
parser.add_argument("--dry-run", action="store_true", default=False, parser.add_argument("--dry-run", action="store_true", default=False,
help="runs as read-only and does not perform any transactions") help="runs as read-only and does not perform any transactions")
args = parser.parse_args() args = parser.parse_args()
@ -33,7 +34,6 @@ args = parser.parse_args()
logging.getLogger().setLevel(args.log_level) logging.getLogger().setLevel(args.log_level)
logging.warning(mango.WARNING_DISCLAIMER_TEXT) logging.warning(mango.WARNING_DISCLAIMER_TEXT)
try:
context = mango.ContextBuilder.from_command_line_parameters(args) context = mango.ContextBuilder.from_command_line_parameters(args)
wallet = mango.Wallet.from_command_line_parameters_or_raise(args) wallet = mango.Wallet.from_command_line_parameters_or_raise(args)
@ -50,17 +50,14 @@ try:
source = PublicKey(source_account["pubkey"]) source = PublicKey(source_account["pubkey"])
# Is the address an actual token account? Or is it the SOL address of the owner? # Is the address an actual token account? Or is it the SOL address of the owner?
possible_dest: typing.Optional[mango.TokenAccount] = mango.TokenAccount.load(context, args.address) account_info: typing.Optional[mango.AccountInfo] = mango.AccountInfo.load(context, args.address)
if (possible_dest is not None) and (possible_dest.value.token.mint == token.mint): if account_info is None:
# We successfully loaded the token account. raise Exception(f"Could not find account at address {args.address}.")
destination: PublicKey = args.address
else: if account_info.owner == mango.SYSTEM_PROGRAM_ADDRESS:
destination_accounts = spl_token.get_accounts(args.address) # This is a root wallet account - get the token account to use.
if len(destination_accounts["result"]["value"]) == 0: destination: PublicKey = mango.TokenAccount.find_or_create_token_address_to_use(
raise Exception( context, wallet, args.address, token)
f"Could not find destination account using {args.address} as either owner address or token address.")
destination_account = destination_accounts["result"]["value"][0]
destination = PublicKey(destination_account["pubkey"])
owner = wallet.account owner = wallet.account
amount = int(args.quantity * Decimal(10 ** token.decimals)) amount = int(args.quantity * Decimal(10 ** token.decimals))
@ -76,14 +73,13 @@ try:
print("Skipping actual transfer - dry run.") print("Skipping actual transfer - dry run.")
else: else:
transfer_response = spl_token.transfer(source, destination, owner, amount) transfer_response = spl_token.transfer(source, destination, owner, amount)
transaction_ids = transfer_response["result"] transaction_ids = [transfer_response["result"]]
print(f"Waiting on transaction ID: {transaction_ids}") print(f"Transaction IDs: {transaction_ids}")
if args.wait:
context.client.wait_for_confirmation(transaction_ids) context.client.wait_for_confirmation(transaction_ids)
updated_balance = spl_token.get_balance(source) updated_balance = spl_token.get_balance(source)
updated_balance_text = updated_balance["result"]["value"]["uiAmountString"] updated_balance_text = updated_balance["result"]["value"]["uiAmountString"]
print(f"{text_amount} sent. Balance now: {updated_balance_text} {token.name}") print(f"{text_amount} sent. Balance now: {updated_balance_text} {token.name}")
except Exception as exception: else:
logging.critical(f"send-token stopped because of exception: {exception} - {traceback.format_exc()}") print(f"{text_amount} sent.")
except:
logging.critical(f"send-token stopped because of uncatchable error: {traceback.format_exc()}")

View File

@ -21,6 +21,7 @@ from solana.publickey import PublicKey
from solana.rpc.types import TokenAccountOpts from solana.rpc.types import TokenAccountOpts
from spl.token.client import Token as SplToken from spl.token.client import Token as SplToken
from spl.token.constants import TOKEN_PROGRAM_ID from spl.token.constants import TOKEN_PROGRAM_ID
from spl.token.instructions import get_associated_token_address
from .accountinfo import AccountInfo from .accountinfo import AccountInfo
from .addressableaccount import AddressableAccount from .addressableaccount import AddressableAccount
@ -30,6 +31,7 @@ from .token import Token
from .tokenlookup import TokenLookup from .tokenlookup import TokenLookup
from .tokenvalue import TokenValue from .tokenvalue import TokenValue
from .version import Version from .version import Version
from .wallet import Wallet
# # 🥭 TokenAccount class # # 🥭 TokenAccount class
# #
@ -89,6 +91,25 @@ class TokenAccount(AddressableAccount):
return largest_account return largest_account
@staticmethod
def find_or_create_token_address_to_use(context: Context, wallet: Wallet, owner: PublicKey, token: Token) -> PublicKey:
# This is a root wallet account - get the token account to use.
associated_token_address = get_associated_token_address(owner, token.mint)
token_account: typing.Optional[TokenAccount] = TokenAccount.load(context, associated_token_address)
if token_account is None:
# There is no associated token account. See if they have an old-style non-associated token account.
largest = TokenAccount.fetch_largest_for_owner_and_token(context, owner, token)
if largest is not None:
# There is an old-style account so use that.
return largest.address
# There is no old-style token account either, so create the proper associated token account.
spl_token = SplToken(context.client.compatible_client, token.mint, TOKEN_PROGRAM_ID, wallet.account)
return spl_token.create_associated_token_account(owner)
else:
# The associated token account exists so use it
return associated_token_address
@staticmethod @staticmethod
def from_layout(layout: layouts.TOKEN_ACCOUNT, account_info: AccountInfo, token: Token) -> "TokenAccount": def from_layout(layout: layouts.TOKEN_ACCOUNT, account_info: AccountInfo, token: Token) -> "TokenAccount":
token_value = TokenValue(token, token.shift_to_decimals(layout.amount)) token_value = TokenValue(token, token.shift_to_decimals(layout.amount))