ts: Get discriminator lengths dynamically (#3120)

This commit is contained in:
acheron 2024-07-26 21:21:20 +02:00 committed by GitHub
parent 293ee9142b
commit 9fce3dfc9c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 60 additions and 71 deletions

View File

@ -23,6 +23,7 @@ The minor version will be incremented upon a breaking change and the patch versi
- client: Make `ThreadSafeSigner` trait public ([#3107](https://github.com/coral-xyz/anchor/pull/3107)).
- lang: Update `dispatch` function to support dynamic discriminators ([#3104](https://github.com/coral-xyz/anchor/pull/3104)).
- lang: Remove the fallback function shortcut in `try_entry` function ([#3109](https://github.com/coral-xyz/anchor/pull/3109)).
- ts: Get discriminator lengths dynamically ([#3120](https://github.com/coral-xyz/anchor/pull/3120)).
### Fixes
@ -44,6 +45,7 @@ The minor version will be incremented upon a breaking change and the patch versi
- client: Remove `async_rpc` method ([#3053](https://github.com/coral-xyz/anchor/pull/3053)).
- lang: Make discriminator type unsized ([#3098](https://github.com/coral-xyz/anchor/pull/3098)).
- lang: Require `Discriminator` trait impl when using the `zero` constraint ([#3118](https://github.com/coral-xyz/anchor/pull/3118)).
- ts: Remove `DISCRIMINATOR_SIZE` constant ([#3120](https://github.com/coral-xyz/anchor/pull/3120)).
## [0.30.1] - 2024-06-20

View File

@ -1,10 +1,9 @@
import bs58 from "bs58";
import { Buffer } from "buffer";
import { Layout } from "buffer-layout";
import { Idl } from "../../idl.js";
import { Idl, IdlDiscriminator } from "../../idl.js";
import { IdlCoder } from "./idl.js";
import { AccountsCoder } from "../index.js";
import { DISCRIMINATOR_SIZE } from "./discriminator.js";
/**
* Encodes and decodes account objects.
@ -15,7 +14,10 @@ export class BorshAccountsCoder<A extends string = string>
/**
* Maps account type identifier to a layout.
*/
private accountLayouts: Map<A, Layout>;
private accountLayouts: Map<
A,
{ discriminator: IdlDiscriminator; layout: Layout }
>;
public constructor(private idl: Idl) {
if (!idl.accounts) {
@ -28,12 +30,18 @@ export class BorshAccountsCoder<A extends string = string>
throw new Error("Accounts require `idl.types`");
}
const layouts: [A, Layout][] = idl.accounts.map((acc) => {
const layouts = idl.accounts.map((acc) => {
const typeDef = types.find((ty) => ty.name === acc.name);
if (!typeDef) {
throw new Error(`Account not found: ${acc.name}`);
}
return [acc.name as A, IdlCoder.typeDefLayout({ typeDef, types })];
return [
acc.name as A,
{
discriminator: acc.discriminator,
layout: IdlCoder.typeDefLayout({ typeDef, types }),
},
] as const;
});
this.accountLayouts = new Map(layouts);
@ -45,7 +53,7 @@ export class BorshAccountsCoder<A extends string = string>
if (!layout) {
throw new Error(`Unknown account: ${accountName}`);
}
const len = layout.encode(account, buffer);
const len = layout.layout.encode(account, buffer);
const accountData = buffer.slice(0, len);
const discriminator = this.accountDiscriminator(accountName);
return Buffer.concat([discriminator, accountData]);
@ -54,32 +62,31 @@ export class BorshAccountsCoder<A extends string = string>
public decode<T = any>(accountName: A, data: Buffer): T {
// Assert the account discriminator is correct.
const discriminator = this.accountDiscriminator(accountName);
if (discriminator.compare(data.slice(0, DISCRIMINATOR_SIZE))) {
if (discriminator.compare(data.slice(0, discriminator.length))) {
throw new Error("Invalid account discriminator");
}
return this.decodeUnchecked(accountName, data);
}
public decodeAny<T = any>(data: Buffer): T {
const discriminator = data.slice(0, DISCRIMINATOR_SIZE);
const accountName = Array.from(this.accountLayouts.keys()).find((key) =>
this.accountDiscriminator(key).equals(discriminator)
);
if (!accountName) {
throw new Error("Account not found");
for (const [name, layout] of this.accountLayouts) {
const givenDisc = data.subarray(0, layout.discriminator.length);
const matches = givenDisc.equals(Buffer.from(layout.discriminator));
if (matches) return this.decodeUnchecked(name, data);
}
return this.decodeUnchecked<T>(accountName, data);
throw new Error("Account not found");
}
public decodeUnchecked<T = any>(accountName: A, acc: Buffer): T {
// Chop off the discriminator before decoding.
const data = acc.subarray(DISCRIMINATOR_SIZE);
const discriminator = this.accountDiscriminator(accountName);
const data = acc.subarray(discriminator.length);
const layout = this.accountLayouts.get(accountName);
if (!layout) {
throw new Error(`Unknown account: ${accountName}`);
}
return layout.decode(data);
return layout.layout.decode(data);
}
public memcmp(accountName: A, appendData?: Buffer): any {
@ -94,7 +101,7 @@ export class BorshAccountsCoder<A extends string = string>
public size(accountName: A): number {
return (
DISCRIMINATOR_SIZE +
this.accountDiscriminator(accountName).length +
IdlCoder.typeSize({ defined: { name: accountName } }, this.idl)
);
}

View File

@ -1,4 +0,0 @@
/**
* Number of bytes in anchor discriminators
*/
export const DISCRIMINATOR_SIZE = 8;

View File

@ -1,7 +1,7 @@
import { Buffer } from "buffer";
import { Layout } from "buffer-layout";
import * as base64 from "../../utils/bytes/base64.js";
import { Idl } from "../../idl.js";
import { Idl, IdlDiscriminator } from "../../idl.js";
import { IdlCoder } from "./idl.js";
import { EventCoder } from "../index.js";
@ -9,12 +9,10 @@ export class BorshEventCoder implements EventCoder {
/**
* Maps account type identifier to a layout.
*/
private layouts: Map<string, Layout>;
/**
* Maps base64 encoded event discriminator to event name.
*/
private discriminators: Map<string, string>;
private layouts: Map<
string,
{ discriminator: IdlDiscriminator; layout: Layout }
>;
public constructor(idl: Idl) {
if (!idl.events) {
@ -27,21 +25,20 @@ export class BorshEventCoder implements EventCoder {
throw new Error("Events require `idl.types`");
}
const layouts: [string, Layout<any>][] = idl.events.map((ev) => {
const layouts = idl.events.map((ev) => {
const typeDef = types.find((ty) => ty.name === ev.name);
if (!typeDef) {
throw new Error(`Event not found: ${ev.name}`);
}
return [ev.name, IdlCoder.typeDefLayout({ typeDef, types })];
return [
ev.name,
{
discriminator: ev.discriminator,
layout: IdlCoder.typeDefLayout({ typeDef, types }),
},
] as const;
});
this.layouts = new Map(layouts);
this.discriminators = new Map<string, string>(
(idl.events ?? []).map((ev) => [
base64.encode(Buffer.from(ev.discriminator)),
ev.name,
])
);
}
public decode(log: string): {
@ -55,19 +52,18 @@ export class BorshEventCoder implements EventCoder {
} catch (e) {
return null;
}
const disc = base64.encode(logArr.slice(0, 8));
// Only deserialize if the discriminator implies a proper event.
const eventName = this.discriminators.get(disc);
if (!eventName) {
return null;
for (const [name, layout] of this.layouts) {
const givenDisc = logArr.subarray(0, layout.discriminator.length);
const matches = givenDisc.equals(Buffer.from(layout.discriminator));
if (matches) {
return {
name,
data: layout.layout.decode(logArr.subarray(givenDisc.length)),
};
}
}
const layout = this.layouts.get(eventName);
if (!layout) {
throw new Error(`Unknown event: ${eventName}`);
}
const data = layout.decode(logArr.slice(8));
return { data, name: eventName };
return null;
}
}

View File

@ -7,7 +7,6 @@ import { Coder } from "../index.js";
export { BorshInstructionCoder } from "./instruction.js";
export { BorshAccountsCoder } from "./accounts.js";
export { DISCRIMINATOR_SIZE } from "./discriminator.js";
export { BorshEventCoder } from "./event.js";
/**

View File

@ -16,7 +16,7 @@ import {
IdlDiscriminator,
} from "../../idl.js";
import { IdlCoder } from "./idl.js";
import { DISCRIMINATOR_SIZE, InstructionCoder } from "../index.js";
import { InstructionCoder } from "../index.js";
/**
* Encodes and decodes program instructions.
@ -28,9 +28,6 @@ export class BorshInstructionCoder implements InstructionCoder {
{ discriminator: IdlDiscriminator; layout: Layout }
>;
// Base58 encoded sighash to instruction layout.
private sighashLayouts: Map<string, { name: string; layout: Layout }>;
public constructor(private idl: Idl) {
const ixLayouts = idl.instructions.map((ix) => {
const name = ix.name;
@ -41,13 +38,6 @@ export class BorshInstructionCoder implements InstructionCoder {
return [name, { discriminator: ix.discriminator, layout }] as const;
});
this.ixLayouts = new Map(ixLayouts);
const sighashLayouts = ixLayouts.map(
([name, { discriminator, layout }]) => {
return [bs58.encode(discriminator), { name, layout }] as const;
}
);
this.sighashLayouts = new Map(sighashLayouts);
}
/**
@ -77,17 +67,18 @@ export class BorshInstructionCoder implements InstructionCoder {
ix = encoding === "hex" ? Buffer.from(ix, "hex") : bs58.decode(ix);
}
const disc = ix.slice(0, DISCRIMINATOR_SIZE);
const data = ix.slice(DISCRIMINATOR_SIZE);
const decoder = this.sighashLayouts.get(bs58.encode(disc));
if (!decoder) {
return null;
for (const [name, layout] of this.ixLayouts) {
const givenDisc = ix.subarray(0, layout.discriminator.length);
const matches = givenDisc.equals(Buffer.from(layout.discriminator));
if (matches) {
return {
name,
data: layout.layout.decode(ix.subarray(givenDisc.length)),
};
}
}
return {
name: decoder.name,
data: decoder.layout.decode(data),
};
return null;
}
/**

View File

@ -1,7 +1,5 @@
import * as assert from "assert";
import { BorshCoder, Idl } from "../src";
import { DISCRIMINATOR_SIZE } from "../src/coder/borsh/discriminator";
import { sha256 } from "@noble/hashes/sha256";
describe("coder.accounts", () => {
test("Can encode and decode user-defined accounts, including those with consecutive capital letters", () => {