Added AccountInfo.load_by_program() to return AccountInfos since that's nearly always what we want.

This commit is contained in:
Geoff Taylor 2022-02-16 15:14:27 +00:00
parent 6eaca82696
commit d2fb7e50e1
3 changed files with 40 additions and 39 deletions

View File

@ -471,18 +471,15 @@ class Account(AddressableAccount):
# owner is just after mango_group in the layout, and it's a PublicKey which is 32 bytes.
filters = [MemcmpOpts(offset=group_offset, bytes=encode_key(group.address))]
results = context.client.get_program_accounts(
account_infos = AccountInfo.load_by_program(
context,
context.mango_program_address,
memcmp_opts=filters,
data_size=layouts.MANGO_ACCOUNT.sizeof(),
)
cache: Cache = group.fetch_cache(context)
accounts: typing.List[Account] = []
for account_data in results:
address = PublicKey(account_data["pubkey"])
account_info = AccountInfo._from_response_values(
account_data["account"], address
)
for account_info in account_infos:
account = Account.parse(account_info, group, cache)
accounts += [account]
return accounts
@ -500,18 +497,15 @@ class Account(AddressableAccount):
MemcmpOpts(offset=owner_offset, bytes=encode_key(owner)),
]
results = context.client.get_program_accounts(
account_infos = AccountInfo.load_by_program(
context,
context.mango_program_address,
memcmp_opts=filters,
data_size=layouts.MANGO_ACCOUNT.sizeof(),
)
cache: Cache = group.fetch_cache(context)
accounts: typing.List[Account] = []
for account_data in results:
address = PublicKey(account_data["pubkey"])
account_info = AccountInfo._from_response_values(
account_data["account"], address
)
for account_info in account_infos:
account = Account.parse(account_info, group, cache)
accounts += [account]
return accounts
@ -529,18 +523,15 @@ class Account(AddressableAccount):
MemcmpOpts(offset=delegate_offset, bytes=encode_key(delegate)),
]
results = context.client.get_program_accounts(
account_infos = AccountInfo.load_by_program(
context,
context.mango_program_address,
memcmp_opts=filters,
data_size=layouts.MANGO_ACCOUNT.sizeof(),
)
cache: Cache = group.fetch_cache(context)
accounts: typing.List[Account] = []
for account_data in results:
address = PublicKey(account_data["pubkey"])
account_info = AccountInfo._from_response_values(
account_data["account"], address
)
for account_info in account_infos:
account = Account.parse(account_info, group, cache)
accounts += [account]
return accounts

View File

@ -21,7 +21,7 @@ import typing
from decimal import Decimal
from solana.publickey import PublicKey
from solana.rpc.types import RPCResponse
from solana.rpc.types import MemcmpOpts, DataSliceOpts, RPCResponse
from .constants import SOL_DECIMAL_DIVISOR
from .context import Context
@ -102,7 +102,7 @@ class AccountInfo:
@staticmethod
def load_multiple(
context: Context, addresses: typing.Sequence[PublicKey]
) -> typing.List["AccountInfo"]:
) -> typing.Sequence["AccountInfo"]:
# This is a tricky one to get right.
# Some errors this can generate:
# 413 Client Error: Payload Too Large for url
@ -129,6 +129,27 @@ class AccountInfo:
return multiple
@staticmethod
def load_by_program(
context: Context,
pubkey: typing.Union[str, PublicKey],
data_slice: typing.Optional[DataSliceOpts] = None,
data_size: typing.Optional[int] = None,
memcmp_opts: typing.Optional[typing.List[MemcmpOpts]] = None,
) -> typing.Sequence["AccountInfo"]:
all_accounts = context.client.get_program_accounts(
pubkey, data_slice=data_slice, data_size=data_size, memcmp_opts=memcmp_opts
)
all_account_infos = map(
lambda result: AccountInfo._from_response_values(
result["account"], PublicKey(result["pubkey"])
),
all_accounts,
)
return list(all_account_infos)
@staticmethod
def _from_response_values(
response_values: typing.Dict[str, typing.Any], address: PublicKey

View File

@ -144,20 +144,12 @@ class OpenOrders(AddressableAccount):
)
]
results = context.client.get_program_accounts(
account_infos = AccountInfo.load_by_program(
context,
group.serum_program_address,
data_size=layouts.OPEN_ORDERS.sizeof(),
memcmp_opts=filters,
)
account_infos = list(
map(
lambda pair: AccountInfo._from_response_values(pair[0], pair[1]),
[
(result["account"], PublicKey(result["pubkey"]))
for result in results
],
)
)
account_infos_by_address = {
key: value
for key, value in [
@ -197,19 +189,16 @@ class OpenOrders(AddressableAccount):
),
]
results = context.client.get_program_accounts(
program_address, data_size=layouts.OPEN_ORDERS.sizeof(), memcmp_opts=filters
)
accounts = map(
lambda result: AccountInfo._from_response_values(
result["account"], PublicKey(result["pubkey"])
),
results,
account_infos = AccountInfo.load_by_program(
context,
program_address,
data_size=layouts.OPEN_ORDERS.sizeof(),
memcmp_opts=filters,
)
return list(
map(
lambda acc: OpenOrders.parse(acc, base_decimals, quote_decimals),
accounts,
account_infos,
)
)