ts: Get discriminator lengths dynamically (#3120)

This commit is contained in:
acheron 2024-07-26 21:21:20 +02:00 committed by GitHub
parent 293ee9142b
commit 9fce3dfc9c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 60 additions and 71 deletions

View File

@ -23,6 +23,7 @@ The minor version will be incremented upon a breaking change and the patch versi
- client: Make `ThreadSafeSigner` trait public ([#3107](https://github.com/coral-xyz/anchor/pull/3107)). - client: Make `ThreadSafeSigner` trait public ([#3107](https://github.com/coral-xyz/anchor/pull/3107)).
- lang: Update `dispatch` function to support dynamic discriminators ([#3104](https://github.com/coral-xyz/anchor/pull/3104)). - lang: Update `dispatch` function to support dynamic discriminators ([#3104](https://github.com/coral-xyz/anchor/pull/3104)).
- lang: Remove the fallback function shortcut in `try_entry` function ([#3109](https://github.com/coral-xyz/anchor/pull/3109)). - lang: Remove the fallback function shortcut in `try_entry` function ([#3109](https://github.com/coral-xyz/anchor/pull/3109)).
- ts: Get discriminator lengths dynamically ([#3120](https://github.com/coral-xyz/anchor/pull/3120)).
### Fixes ### Fixes
@ -44,6 +45,7 @@ The minor version will be incremented upon a breaking change and the patch versi
- client: Remove `async_rpc` method ([#3053](https://github.com/coral-xyz/anchor/pull/3053)). - client: Remove `async_rpc` method ([#3053](https://github.com/coral-xyz/anchor/pull/3053)).
- lang: Make discriminator type unsized ([#3098](https://github.com/coral-xyz/anchor/pull/3098)). - lang: Make discriminator type unsized ([#3098](https://github.com/coral-xyz/anchor/pull/3098)).
- lang: Require `Discriminator` trait impl when using the `zero` constraint ([#3118](https://github.com/coral-xyz/anchor/pull/3118)). - lang: Require `Discriminator` trait impl when using the `zero` constraint ([#3118](https://github.com/coral-xyz/anchor/pull/3118)).
- ts: Remove `DISCRIMINATOR_SIZE` constant ([#3120](https://github.com/coral-xyz/anchor/pull/3120)).
## [0.30.1] - 2024-06-20 ## [0.30.1] - 2024-06-20

View File

@ -1,10 +1,9 @@
import bs58 from "bs58"; import bs58 from "bs58";
import { Buffer } from "buffer"; import { Buffer } from "buffer";
import { Layout } from "buffer-layout"; import { Layout } from "buffer-layout";
import { Idl } from "../../idl.js"; import { Idl, IdlDiscriminator } from "../../idl.js";
import { IdlCoder } from "./idl.js"; import { IdlCoder } from "./idl.js";
import { AccountsCoder } from "../index.js"; import { AccountsCoder } from "../index.js";
import { DISCRIMINATOR_SIZE } from "./discriminator.js";
/** /**
* Encodes and decodes account objects. * Encodes and decodes account objects.
@ -15,7 +14,10 @@ export class BorshAccountsCoder<A extends string = string>
/** /**
* Maps account type identifier to a layout. * Maps account type identifier to a layout.
*/ */
private accountLayouts: Map<A, Layout>; private accountLayouts: Map<
A,
{ discriminator: IdlDiscriminator; layout: Layout }
>;
public constructor(private idl: Idl) { public constructor(private idl: Idl) {
if (!idl.accounts) { if (!idl.accounts) {
@ -28,12 +30,18 @@ export class BorshAccountsCoder<A extends string = string>
throw new Error("Accounts require `idl.types`"); throw new Error("Accounts require `idl.types`");
} }
const layouts: [A, Layout][] = idl.accounts.map((acc) => { const layouts = idl.accounts.map((acc) => {
const typeDef = types.find((ty) => ty.name === acc.name); const typeDef = types.find((ty) => ty.name === acc.name);
if (!typeDef) { if (!typeDef) {
throw new Error(`Account not found: ${acc.name}`); throw new Error(`Account not found: ${acc.name}`);
} }
return [acc.name as A, IdlCoder.typeDefLayout({ typeDef, types })]; return [
acc.name as A,
{
discriminator: acc.discriminator,
layout: IdlCoder.typeDefLayout({ typeDef, types }),
},
] as const;
}); });
this.accountLayouts = new Map(layouts); this.accountLayouts = new Map(layouts);
@ -45,7 +53,7 @@ export class BorshAccountsCoder<A extends string = string>
if (!layout) { if (!layout) {
throw new Error(`Unknown account: ${accountName}`); throw new Error(`Unknown account: ${accountName}`);
} }
const len = layout.encode(account, buffer); const len = layout.layout.encode(account, buffer);
const accountData = buffer.slice(0, len); const accountData = buffer.slice(0, len);
const discriminator = this.accountDiscriminator(accountName); const discriminator = this.accountDiscriminator(accountName);
return Buffer.concat([discriminator, accountData]); return Buffer.concat([discriminator, accountData]);
@ -54,32 +62,31 @@ export class BorshAccountsCoder<A extends string = string>
public decode<T = any>(accountName: A, data: Buffer): T { public decode<T = any>(accountName: A, data: Buffer): T {
// Assert the account discriminator is correct. // Assert the account discriminator is correct.
const discriminator = this.accountDiscriminator(accountName); const discriminator = this.accountDiscriminator(accountName);
if (discriminator.compare(data.slice(0, DISCRIMINATOR_SIZE))) { if (discriminator.compare(data.slice(0, discriminator.length))) {
throw new Error("Invalid account discriminator"); throw new Error("Invalid account discriminator");
} }
return this.decodeUnchecked(accountName, data); return this.decodeUnchecked(accountName, data);
} }
public decodeAny<T = any>(data: Buffer): T { public decodeAny<T = any>(data: Buffer): T {
const discriminator = data.slice(0, DISCRIMINATOR_SIZE); for (const [name, layout] of this.accountLayouts) {
const accountName = Array.from(this.accountLayouts.keys()).find((key) => const givenDisc = data.subarray(0, layout.discriminator.length);
this.accountDiscriminator(key).equals(discriminator) const matches = givenDisc.equals(Buffer.from(layout.discriminator));
); if (matches) return this.decodeUnchecked(name, data);
if (!accountName) {
throw new Error("Account not found");
} }
return this.decodeUnchecked<T>(accountName, data); throw new Error("Account not found");
} }
public decodeUnchecked<T = any>(accountName: A, acc: Buffer): T { public decodeUnchecked<T = any>(accountName: A, acc: Buffer): T {
// Chop off the discriminator before decoding. // Chop off the discriminator before decoding.
const data = acc.subarray(DISCRIMINATOR_SIZE); const discriminator = this.accountDiscriminator(accountName);
const data = acc.subarray(discriminator.length);
const layout = this.accountLayouts.get(accountName); const layout = this.accountLayouts.get(accountName);
if (!layout) { if (!layout) {
throw new Error(`Unknown account: ${accountName}`); throw new Error(`Unknown account: ${accountName}`);
} }
return layout.decode(data); return layout.layout.decode(data);
} }
public memcmp(accountName: A, appendData?: Buffer): any { public memcmp(accountName: A, appendData?: Buffer): any {
@ -94,7 +101,7 @@ export class BorshAccountsCoder<A extends string = string>
public size(accountName: A): number { public size(accountName: A): number {
return ( return (
DISCRIMINATOR_SIZE + this.accountDiscriminator(accountName).length +
IdlCoder.typeSize({ defined: { name: accountName } }, this.idl) IdlCoder.typeSize({ defined: { name: accountName } }, this.idl)
); );
} }

View File

@ -1,4 +0,0 @@
/**
* Number of bytes in anchor discriminators
*/
export const DISCRIMINATOR_SIZE = 8;

View File

@ -1,7 +1,7 @@
import { Buffer } from "buffer"; import { Buffer } from "buffer";
import { Layout } from "buffer-layout"; import { Layout } from "buffer-layout";
import * as base64 from "../../utils/bytes/base64.js"; import * as base64 from "../../utils/bytes/base64.js";
import { Idl } from "../../idl.js"; import { Idl, IdlDiscriminator } from "../../idl.js";
import { IdlCoder } from "./idl.js"; import { IdlCoder } from "./idl.js";
import { EventCoder } from "../index.js"; import { EventCoder } from "../index.js";
@ -9,12 +9,10 @@ export class BorshEventCoder implements EventCoder {
/** /**
* Maps account type identifier to a layout. * Maps account type identifier to a layout.
*/ */
private layouts: Map<string, Layout>; private layouts: Map<
string,
/** { discriminator: IdlDiscriminator; layout: Layout }
* Maps base64 encoded event discriminator to event name. >;
*/
private discriminators: Map<string, string>;
public constructor(idl: Idl) { public constructor(idl: Idl) {
if (!idl.events) { if (!idl.events) {
@ -27,21 +25,20 @@ export class BorshEventCoder implements EventCoder {
throw new Error("Events require `idl.types`"); throw new Error("Events require `idl.types`");
} }
const layouts: [string, Layout<any>][] = idl.events.map((ev) => { const layouts = idl.events.map((ev) => {
const typeDef = types.find((ty) => ty.name === ev.name); const typeDef = types.find((ty) => ty.name === ev.name);
if (!typeDef) { if (!typeDef) {
throw new Error(`Event not found: ${ev.name}`); throw new Error(`Event not found: ${ev.name}`);
} }
return [ev.name, IdlCoder.typeDefLayout({ typeDef, types })]; return [
ev.name,
{
discriminator: ev.discriminator,
layout: IdlCoder.typeDefLayout({ typeDef, types }),
},
] as const;
}); });
this.layouts = new Map(layouts); this.layouts = new Map(layouts);
this.discriminators = new Map<string, string>(
(idl.events ?? []).map((ev) => [
base64.encode(Buffer.from(ev.discriminator)),
ev.name,
])
);
} }
public decode(log: string): { public decode(log: string): {
@ -55,19 +52,18 @@ export class BorshEventCoder implements EventCoder {
} catch (e) { } catch (e) {
return null; return null;
} }
const disc = base64.encode(logArr.slice(0, 8));
// Only deserialize if the discriminator implies a proper event. for (const [name, layout] of this.layouts) {
const eventName = this.discriminators.get(disc); const givenDisc = logArr.subarray(0, layout.discriminator.length);
if (!eventName) { const matches = givenDisc.equals(Buffer.from(layout.discriminator));
return null; if (matches) {
return {
name,
data: layout.layout.decode(logArr.subarray(givenDisc.length)),
};
}
} }
const layout = this.layouts.get(eventName); return null;
if (!layout) {
throw new Error(`Unknown event: ${eventName}`);
}
const data = layout.decode(logArr.slice(8));
return { data, name: eventName };
} }
} }

View File

@ -7,7 +7,6 @@ import { Coder } from "../index.js";
export { BorshInstructionCoder } from "./instruction.js"; export { BorshInstructionCoder } from "./instruction.js";
export { BorshAccountsCoder } from "./accounts.js"; export { BorshAccountsCoder } from "./accounts.js";
export { DISCRIMINATOR_SIZE } from "./discriminator.js";
export { BorshEventCoder } from "./event.js"; export { BorshEventCoder } from "./event.js";
/** /**

View File

@ -16,7 +16,7 @@ import {
IdlDiscriminator, IdlDiscriminator,
} from "../../idl.js"; } from "../../idl.js";
import { IdlCoder } from "./idl.js"; import { IdlCoder } from "./idl.js";
import { DISCRIMINATOR_SIZE, InstructionCoder } from "../index.js"; import { InstructionCoder } from "../index.js";
/** /**
* Encodes and decodes program instructions. * Encodes and decodes program instructions.
@ -28,9 +28,6 @@ export class BorshInstructionCoder implements InstructionCoder {
{ discriminator: IdlDiscriminator; layout: Layout } { discriminator: IdlDiscriminator; layout: Layout }
>; >;
// Base58 encoded sighash to instruction layout.
private sighashLayouts: Map<string, { name: string; layout: Layout }>;
public constructor(private idl: Idl) { public constructor(private idl: Idl) {
const ixLayouts = idl.instructions.map((ix) => { const ixLayouts = idl.instructions.map((ix) => {
const name = ix.name; const name = ix.name;
@ -41,13 +38,6 @@ export class BorshInstructionCoder implements InstructionCoder {
return [name, { discriminator: ix.discriminator, layout }] as const; return [name, { discriminator: ix.discriminator, layout }] as const;
}); });
this.ixLayouts = new Map(ixLayouts); this.ixLayouts = new Map(ixLayouts);
const sighashLayouts = ixLayouts.map(
([name, { discriminator, layout }]) => {
return [bs58.encode(discriminator), { name, layout }] as const;
}
);
this.sighashLayouts = new Map(sighashLayouts);
} }
/** /**
@ -77,17 +67,18 @@ export class BorshInstructionCoder implements InstructionCoder {
ix = encoding === "hex" ? Buffer.from(ix, "hex") : bs58.decode(ix); ix = encoding === "hex" ? Buffer.from(ix, "hex") : bs58.decode(ix);
} }
const disc = ix.slice(0, DISCRIMINATOR_SIZE); for (const [name, layout] of this.ixLayouts) {
const data = ix.slice(DISCRIMINATOR_SIZE); const givenDisc = ix.subarray(0, layout.discriminator.length);
const decoder = this.sighashLayouts.get(bs58.encode(disc)); const matches = givenDisc.equals(Buffer.from(layout.discriminator));
if (!decoder) { if (matches) {
return null; return {
name,
data: layout.layout.decode(ix.subarray(givenDisc.length)),
};
}
} }
return { return null;
name: decoder.name,
data: decoder.layout.decode(data),
};
} }
/** /**

View File

@ -1,7 +1,5 @@
import * as assert from "assert"; import * as assert from "assert";
import { BorshCoder, Idl } from "../src"; import { BorshCoder, Idl } from "../src";
import { DISCRIMINATOR_SIZE } from "../src/coder/borsh/discriminator";
import { sha256 } from "@noble/hashes/sha256";
describe("coder.accounts", () => { describe("coder.accounts", () => {
test("Can encode and decode user-defined accounts, including those with consecutive capital letters", () => { test("Can encode and decode user-defined accounts, including those with consecutive capital letters", () => {