From 9fce3dfc9c534116238326e7f628d6827dc03988 Mon Sep 17 00:00:00 2001 From: acheron <98934430+acheroncrypto@users.noreply.github.com> Date: Fri, 26 Jul 2024 21:21:20 +0200 Subject: [PATCH] ts: Get discriminator lengths dynamically (#3120) --- CHANGELOG.md | 2 + .../anchor/src/coder/borsh/accounts.ts | 41 ++++++++------- .../anchor/src/coder/borsh/discriminator.ts | 4 -- ts/packages/anchor/src/coder/borsh/event.ts | 50 +++++++++---------- ts/packages/anchor/src/coder/borsh/index.ts | 1 - .../anchor/src/coder/borsh/instruction.ts | 31 ++++-------- .../anchor/tests/coder-accounts.spec.ts | 2 - 7 files changed, 60 insertions(+), 71 deletions(-) delete mode 100644 ts/packages/anchor/src/coder/borsh/discriminator.ts diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b72e5032..972f47792 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ The minor version will be incremented upon a breaking change and the patch versi - client: Make `ThreadSafeSigner` trait public ([#3107](https://github.com/coral-xyz/anchor/pull/3107)). - lang: Update `dispatch` function to support dynamic discriminators ([#3104](https://github.com/coral-xyz/anchor/pull/3104)). - lang: Remove the fallback function shortcut in `try_entry` function ([#3109](https://github.com/coral-xyz/anchor/pull/3109)). +- ts: Get discriminator lengths dynamically ([#3120](https://github.com/coral-xyz/anchor/pull/3120)). ### Fixes @@ -44,6 +45,7 @@ The minor version will be incremented upon a breaking change and the patch versi - client: Remove `async_rpc` method ([#3053](https://github.com/coral-xyz/anchor/pull/3053)). - lang: Make discriminator type unsized ([#3098](https://github.com/coral-xyz/anchor/pull/3098)). - lang: Require `Discriminator` trait impl when using the `zero` constraint ([#3118](https://github.com/coral-xyz/anchor/pull/3118)). +- ts: Remove `DISCRIMINATOR_SIZE` constant ([#3120](https://github.com/coral-xyz/anchor/pull/3120)). ## [0.30.1] - 2024-06-20 diff --git a/ts/packages/anchor/src/coder/borsh/accounts.ts b/ts/packages/anchor/src/coder/borsh/accounts.ts index 999813380..6451962c8 100644 --- a/ts/packages/anchor/src/coder/borsh/accounts.ts +++ b/ts/packages/anchor/src/coder/borsh/accounts.ts @@ -1,10 +1,9 @@ import bs58 from "bs58"; import { Buffer } from "buffer"; import { Layout } from "buffer-layout"; -import { Idl } from "../../idl.js"; +import { Idl, IdlDiscriminator } from "../../idl.js"; import { IdlCoder } from "./idl.js"; import { AccountsCoder } from "../index.js"; -import { DISCRIMINATOR_SIZE } from "./discriminator.js"; /** * Encodes and decodes account objects. @@ -15,7 +14,10 @@ export class BorshAccountsCoder /** * Maps account type identifier to a layout. */ - private accountLayouts: Map; + private accountLayouts: Map< + A, + { discriminator: IdlDiscriminator; layout: Layout } + >; public constructor(private idl: Idl) { if (!idl.accounts) { @@ -28,12 +30,18 @@ export class BorshAccountsCoder throw new Error("Accounts require `idl.types`"); } - const layouts: [A, Layout][] = idl.accounts.map((acc) => { + const layouts = idl.accounts.map((acc) => { const typeDef = types.find((ty) => ty.name === acc.name); if (!typeDef) { throw new Error(`Account not found: ${acc.name}`); } - return [acc.name as A, IdlCoder.typeDefLayout({ typeDef, types })]; + return [ + acc.name as A, + { + discriminator: acc.discriminator, + layout: IdlCoder.typeDefLayout({ typeDef, types }), + }, + ] as const; }); this.accountLayouts = new Map(layouts); @@ -45,7 +53,7 @@ export class BorshAccountsCoder if (!layout) { throw new Error(`Unknown account: ${accountName}`); } - const len = layout.encode(account, buffer); + const len = layout.layout.encode(account, buffer); const accountData = buffer.slice(0, len); const discriminator = this.accountDiscriminator(accountName); return Buffer.concat([discriminator, accountData]); @@ -54,32 +62,31 @@ export class BorshAccountsCoder public decode(accountName: A, data: Buffer): T { // Assert the account discriminator is correct. const discriminator = this.accountDiscriminator(accountName); - if (discriminator.compare(data.slice(0, DISCRIMINATOR_SIZE))) { + if (discriminator.compare(data.slice(0, discriminator.length))) { throw new Error("Invalid account discriminator"); } return this.decodeUnchecked(accountName, data); } public decodeAny(data: Buffer): T { - const discriminator = data.slice(0, DISCRIMINATOR_SIZE); - const accountName = Array.from(this.accountLayouts.keys()).find((key) => - this.accountDiscriminator(key).equals(discriminator) - ); - if (!accountName) { - throw new Error("Account not found"); + for (const [name, layout] of this.accountLayouts) { + const givenDisc = data.subarray(0, layout.discriminator.length); + const matches = givenDisc.equals(Buffer.from(layout.discriminator)); + if (matches) return this.decodeUnchecked(name, data); } - return this.decodeUnchecked(accountName, data); + throw new Error("Account not found"); } public decodeUnchecked(accountName: A, acc: Buffer): T { // Chop off the discriminator before decoding. - const data = acc.subarray(DISCRIMINATOR_SIZE); + const discriminator = this.accountDiscriminator(accountName); + const data = acc.subarray(discriminator.length); const layout = this.accountLayouts.get(accountName); if (!layout) { throw new Error(`Unknown account: ${accountName}`); } - return layout.decode(data); + return layout.layout.decode(data); } public memcmp(accountName: A, appendData?: Buffer): any { @@ -94,7 +101,7 @@ export class BorshAccountsCoder public size(accountName: A): number { return ( - DISCRIMINATOR_SIZE + + this.accountDiscriminator(accountName).length + IdlCoder.typeSize({ defined: { name: accountName } }, this.idl) ); } diff --git a/ts/packages/anchor/src/coder/borsh/discriminator.ts b/ts/packages/anchor/src/coder/borsh/discriminator.ts deleted file mode 100644 index effc15a91..000000000 --- a/ts/packages/anchor/src/coder/borsh/discriminator.ts +++ /dev/null @@ -1,4 +0,0 @@ -/** - * Number of bytes in anchor discriminators - */ -export const DISCRIMINATOR_SIZE = 8; diff --git a/ts/packages/anchor/src/coder/borsh/event.ts b/ts/packages/anchor/src/coder/borsh/event.ts index 0ce5c0971..b3686213f 100644 --- a/ts/packages/anchor/src/coder/borsh/event.ts +++ b/ts/packages/anchor/src/coder/borsh/event.ts @@ -1,7 +1,7 @@ import { Buffer } from "buffer"; import { Layout } from "buffer-layout"; import * as base64 from "../../utils/bytes/base64.js"; -import { Idl } from "../../idl.js"; +import { Idl, IdlDiscriminator } from "../../idl.js"; import { IdlCoder } from "./idl.js"; import { EventCoder } from "../index.js"; @@ -9,12 +9,10 @@ export class BorshEventCoder implements EventCoder { /** * Maps account type identifier to a layout. */ - private layouts: Map; - - /** - * Maps base64 encoded event discriminator to event name. - */ - private discriminators: Map; + private layouts: Map< + string, + { discriminator: IdlDiscriminator; layout: Layout } + >; public constructor(idl: Idl) { if (!idl.events) { @@ -27,21 +25,20 @@ export class BorshEventCoder implements EventCoder { throw new Error("Events require `idl.types`"); } - const layouts: [string, Layout][] = idl.events.map((ev) => { + const layouts = idl.events.map((ev) => { const typeDef = types.find((ty) => ty.name === ev.name); if (!typeDef) { throw new Error(`Event not found: ${ev.name}`); } - return [ev.name, IdlCoder.typeDefLayout({ typeDef, types })]; + return [ + ev.name, + { + discriminator: ev.discriminator, + layout: IdlCoder.typeDefLayout({ typeDef, types }), + }, + ] as const; }); this.layouts = new Map(layouts); - - this.discriminators = new Map( - (idl.events ?? []).map((ev) => [ - base64.encode(Buffer.from(ev.discriminator)), - ev.name, - ]) - ); } public decode(log: string): { @@ -55,19 +52,18 @@ export class BorshEventCoder implements EventCoder { } catch (e) { return null; } - const disc = base64.encode(logArr.slice(0, 8)); - // Only deserialize if the discriminator implies a proper event. - const eventName = this.discriminators.get(disc); - if (!eventName) { - return null; + for (const [name, layout] of this.layouts) { + const givenDisc = logArr.subarray(0, layout.discriminator.length); + const matches = givenDisc.equals(Buffer.from(layout.discriminator)); + if (matches) { + return { + name, + data: layout.layout.decode(logArr.subarray(givenDisc.length)), + }; + } } - const layout = this.layouts.get(eventName); - if (!layout) { - throw new Error(`Unknown event: ${eventName}`); - } - const data = layout.decode(logArr.slice(8)); - return { data, name: eventName }; + return null; } } diff --git a/ts/packages/anchor/src/coder/borsh/index.ts b/ts/packages/anchor/src/coder/borsh/index.ts index e3e943531..d64a65287 100644 --- a/ts/packages/anchor/src/coder/borsh/index.ts +++ b/ts/packages/anchor/src/coder/borsh/index.ts @@ -7,7 +7,6 @@ import { Coder } from "../index.js"; export { BorshInstructionCoder } from "./instruction.js"; export { BorshAccountsCoder } from "./accounts.js"; -export { DISCRIMINATOR_SIZE } from "./discriminator.js"; export { BorshEventCoder } from "./event.js"; /** diff --git a/ts/packages/anchor/src/coder/borsh/instruction.ts b/ts/packages/anchor/src/coder/borsh/instruction.ts index 5d45e3319..8961eb802 100644 --- a/ts/packages/anchor/src/coder/borsh/instruction.ts +++ b/ts/packages/anchor/src/coder/borsh/instruction.ts @@ -16,7 +16,7 @@ import { IdlDiscriminator, } from "../../idl.js"; import { IdlCoder } from "./idl.js"; -import { DISCRIMINATOR_SIZE, InstructionCoder } from "../index.js"; +import { InstructionCoder } from "../index.js"; /** * Encodes and decodes program instructions. @@ -28,9 +28,6 @@ export class BorshInstructionCoder implements InstructionCoder { { discriminator: IdlDiscriminator; layout: Layout } >; - // Base58 encoded sighash to instruction layout. - private sighashLayouts: Map; - public constructor(private idl: Idl) { const ixLayouts = idl.instructions.map((ix) => { const name = ix.name; @@ -41,13 +38,6 @@ export class BorshInstructionCoder implements InstructionCoder { return [name, { discriminator: ix.discriminator, layout }] as const; }); this.ixLayouts = new Map(ixLayouts); - - const sighashLayouts = ixLayouts.map( - ([name, { discriminator, layout }]) => { - return [bs58.encode(discriminator), { name, layout }] as const; - } - ); - this.sighashLayouts = new Map(sighashLayouts); } /** @@ -77,17 +67,18 @@ export class BorshInstructionCoder implements InstructionCoder { ix = encoding === "hex" ? Buffer.from(ix, "hex") : bs58.decode(ix); } - const disc = ix.slice(0, DISCRIMINATOR_SIZE); - const data = ix.slice(DISCRIMINATOR_SIZE); - const decoder = this.sighashLayouts.get(bs58.encode(disc)); - if (!decoder) { - return null; + for (const [name, layout] of this.ixLayouts) { + const givenDisc = ix.subarray(0, layout.discriminator.length); + const matches = givenDisc.equals(Buffer.from(layout.discriminator)); + if (matches) { + return { + name, + data: layout.layout.decode(ix.subarray(givenDisc.length)), + }; + } } - return { - name: decoder.name, - data: decoder.layout.decode(data), - }; + return null; } /** diff --git a/ts/packages/anchor/tests/coder-accounts.spec.ts b/ts/packages/anchor/tests/coder-accounts.spec.ts index 7432515b2..6a36dcc9f 100644 --- a/ts/packages/anchor/tests/coder-accounts.spec.ts +++ b/ts/packages/anchor/tests/coder-accounts.spec.ts @@ -1,7 +1,5 @@ import * as assert from "assert"; import { BorshCoder, Idl } from "../src"; -import { DISCRIMINATOR_SIZE } from "../src/coder/borsh/discriminator"; -import { sha256 } from "@noble/hashes/sha256"; describe("coder.accounts", () => { test("Can encode and decode user-defined accounts, including those with consecutive capital letters", () => {