Skip to content

Commit

Permalink
Support signed queries.
Browse files Browse the repository at this point in the history
  • Loading branch information
q-uint committed Jun 3, 2024
1 parent 1eeaf4b commit 3d11816
Show file tree
Hide file tree
Showing 16 changed files with 3,051 additions and 2,556 deletions.
333 changes: 61 additions & 272 deletions agent.go

Large diffs are not rendered by default.

8 changes: 5 additions & 3 deletions agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,14 @@ func Example_json() {
_ = json.Unmarshal([]byte(raw), &balance)
fmt.Println(balance.E8S)

a, _ := agent.New(agent.Config{})
_ = a.Query(ic.LEDGER_PRINCIPAL, "account_balance_dfx", []any{struct {
a, _ := agent.New(agent.DefaultConfig)
if err := a.Query(ic.LEDGER_PRINCIPAL, "account_balance_dfx", []any{struct {
Account string `json:"account"`
}{
Account: "9523dc824aa062dcd9c91b98f4594ff9c6af661ac96747daef2090b7fe87037d",
}}, []any{&balance}) // Repurposing the balance struct.
}}, []any{&balance}); err != nil {
fmt.Println(err)
}
rawJSON, _ := json.Marshal(balance)
fmt.Println(string(rawJSON))
// Output:
Expand Down
117 changes: 117 additions & 0 deletions call.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package agent

import (
"github.com/aviate-labs/agent-go/candid/idl"
"github.com/aviate-labs/agent-go/principal"
"google.golang.org/protobuf/proto"
)

// Call calls a method on a canister and unmarshals the result into the given values.
func (a Agent) Call(canisterID principal.Principal, methodName string, args []any, values []any) error {
call, err := a.CreateCall(canisterID, methodName, args...)
if err != nil {
return err
}
return call.CallAndWait(values...)
}

// CallProto calls a method on a canister and unmarshals the result into the given proto message.
func (a Agent) CallProto(canisterID principal.Principal, methodName string, in, out proto.Message) error {
payload, err := proto.Marshal(in)
if err != nil {
return err
}
requestID, data, err := a.sign(Request{
Type: RequestTypeCall,
Sender: a.Sender(),
IngressExpiry: a.expiryDate(),
CanisterID: canisterID,
MethodName: methodName,
Arguments: payload,
})
if err != nil {
return err
}
if _, err := a.call(canisterID, data); err != nil {
return err
}
raw, err := a.poll(canisterID, *requestID)
if err != nil {
return err
}
return proto.Unmarshal(raw, out)
}

// CreateCall creates a new Call to the given canister and method.
func (a *Agent) CreateCall(canisterID principal.Principal, methodName string, args ...any) (*Call, error) {
rawArgs, err := idl.Marshal(args)
if err != nil {
return nil, err
}
if len(args) == 0 {
// Default to the empty Candid argument list.
rawArgs = []byte{'D', 'I', 'D', 'L', 0, 0}
}
nonce, err := newNonce()
if err != nil {
return nil, err
}
requestID, data, err := a.sign(Request{
Type: RequestTypeCall,
Sender: a.Sender(),
CanisterID: canisterID,
MethodName: methodName,
Arguments: rawArgs,
IngressExpiry: a.expiryDate(),
Nonce: nonce,
})
if err != nil {
return nil, err
}
return &Call{
a: a,
methodName: methodName,
effectiveCanisterID: effectiveCanisterID(canisterID, args),
requestID: *requestID,
data: data,
}, nil
}

// Call is an intermediate representation of a Call to a canister.
type Call struct {
a *Agent
methodName string
effectiveCanisterID principal.Principal
requestID RequestID
data []byte
}

// Call calls a method on a canister, it does not wait for the result.
func (c Call) Call() error {
c.a.logger.Printf("[AGENT] CALL %s %s (%x)", c.effectiveCanisterID, c.methodName, c.requestID)
_, err := c.a.call(c.effectiveCanisterID, c.data)
return err
}

// CallAndWait calls a method on a canister and waits for the result.
func (c Call) CallAndWait(values ...any) error {
if err := c.Call(); err != nil {
return err
}
return c.Wait(values...)
}

// Wait waits for the result of the Call and unmarshals it into the given values.
func (c Call) Wait(values ...any) error {
raw, err := c.a.poll(c.effectiveCanisterID, c.requestID)
if err != nil {
return err
}
return idl.Unmarshal(raw, values)
}

// WithEffectiveCanisterID sets the effective canister ID for the Call.
func (c *Call) WithEffectiveCanisterID(canisterID principal.Principal) *Call {
c.effectiveCanisterID = canisterID
return c
}
79 changes: 63 additions & 16 deletions certification/certificate.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package certification

import (
"bytes"
"crypto/ed25519"
"encoding/asn1"
"fmt"
"slices"
Expand All @@ -13,7 +14,7 @@ import (
"github.com/fxamacker/cbor/v2"
)

func PublicKeyFromDER(der []byte) (*bls.PublicKey, error) {
func PublicBLSKeyFromDER(der []byte) (*bls.PublicKey, error) {
var seq asn1.RawValue
if _, err := asn1.Unmarshal(der, &seq); err != nil {
return nil, err
Expand Down Expand Up @@ -51,7 +52,7 @@ func PublicKeyFromDER(der []byte) (*bls.PublicKey, error) {
return bls.PublicKeyFromBytes(bs.Bytes)
}

func PublicKeyToDER(publicKey []byte) ([]byte, error) {
func PublicBLSKeyToDER(publicKey []byte) ([]byte, error) {
if len(publicKey) != 96 {
return nil, fmt.Errorf("invalid public key length: %d", len(publicKey))
}
Expand All @@ -67,12 +68,40 @@ func PublicKeyToDER(publicKey []byte) ([]byte, error) {
})
}

func PublicED25519KeyFromDER(der []byte) (*ed25519.PublicKey, error) {
var seq asn1.RawValue
if _, err := asn1.Unmarshal(der, &seq); err != nil {
return nil, err
}
if seq.Tag != asn1.TagSequence {
return nil, fmt.Errorf("invalid tag: %d", seq.Tag)
}
var idSeq asn1.RawValue
rest, err := asn1.Unmarshal(seq.Bytes, &idSeq)
if err != nil {
return nil, err
}
var bs asn1.BitString
if _, err := asn1.Unmarshal(rest, &bs); err != nil {
return nil, err
}
var algoId asn1.ObjectIdentifier
if _, err := asn1.Unmarshal(idSeq.Bytes, &algoId); err != nil {
return nil, err
}
if !algoId.Equal(asn1.ObjectIdentifier{1, 3, 101, 112}) {
return nil, fmt.Errorf("invalid algorithm identifier: %v", algoId)
}
publicKey := ed25519.PublicKey(bs.Bytes)
return &publicKey, nil
}

func VerifyCertificate(
certificate Certificate,
canisterID principal.Principal,
rootPublicKey []byte,
) error {
publicKey, err := PublicKeyFromDER(rootPublicKey)
publicKey, err := PublicBLSKeyFromDER(rootPublicKey)
if err != nil {
return err
}
Expand Down Expand Up @@ -150,7 +179,7 @@ func verifyDelegationCertificate(
if err != nil {
return nil, err
}
var canisterRanges canisterRanges
var canisterRanges CanisterRanges
if err := cbor.Unmarshal(rawRanges, &canisterRanges); err != nil {
return nil, err
}
Expand All @@ -166,7 +195,36 @@ func verifyDelegationCertificate(
if err != nil {
return nil, err
}
return PublicKeyFromDER(rawPublicKey)
return PublicBLSKeyFromDER(rawPublicKey)
}

type CanisterRange struct {
From principal.Principal
To principal.Principal
}

func (c *CanisterRange) UnmarshalCBOR(bytes []byte) error {
var raw [][]byte
if err := cbor.Unmarshal(bytes, &raw); err != nil {
return err
}
if len(raw) != 2 {
return fmt.Errorf("unexpected length: %d", len(raw))
}
c.From = principal.Principal{Raw: raw[0]}
c.To = principal.Principal{Raw: raw[1]}
return nil
}

type CanisterRanges []CanisterRange

func (c CanisterRanges) InRange(canisterID principal.Principal) bool {
for _, r := range c {
if slices.Compare(r.From.Raw, canisterID.Raw) <= 0 && slices.Compare(canisterID.Raw, r.To.Raw) <= 0 {
return true
}
}
return false
}

// Certificate is a certificate gets returned by the IC.
Expand Down Expand Up @@ -211,14 +269,3 @@ func (d *Delegation) UnmarshalCBOR(bytes []byte) error {
}
return nil
}

type canisterRanges [][][]byte

func (c canisterRanges) InRange(canisterID principal.Principal) bool {
for _, pair := range c {
if slices.Compare(pair[0], canisterID.Raw) <= 0 && slices.Compare(canisterID.Raw, pair[1]) <= 0 {
return true
}
}
return false
}
35 changes: 20 additions & 15 deletions certification/hashtree/hashtree.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@ package hashtree

import "fmt"

// Lookup looks up a path in the node.
func Lookup(n Node, path ...Label) ([]byte, error) {
return lookupPath(n, path, 0)
}

// HashTree is a hash tree.
type HashTree struct {
Root Node
Expand All @@ -19,12 +24,12 @@ func (t HashTree) Digest() [32]byte {

// Lookup looks up a path in the hash tree.
func (t HashTree) Lookup(path ...Label) ([]byte, error) {
return lookupPath(t.Root, path, 0)
return Lookup(t.Root, path...)
}

// LookupSubTree looks up a path in the hash tree and returns the sub-tree.
func (t HashTree) LookupSubTree(path ...Label) (Node, error) {
return lookupSubTree(t.Root, path, 0)
return LookupSubTree(t.Root, path...)
}

// MarshalCBOR marshals a hash tree.
Expand All @@ -42,31 +47,31 @@ func (t *HashTree) UnmarshalCBOR(bytes []byte) error {
return nil
}

type PathValuePair struct {
// LookupSubTree looks up a path in the node and returns the sub-tree.
func LookupSubTree(n Node, path ...Label) (Node, error) {
return lookupSubTree(n, path, 0)
}

type PathValuePair[V any] struct {
Path []Label
Value []byte
Value V
}

func AllChildren(n Node) ([]PathValuePair, error) {
func AllChildren(n Node) ([]PathValuePair[Node], error) {
return allChildren(n)
}

// AllPaths returns all non-empty labeled paths in the hash tree, does not include pruned nodes.
func AllPaths(n Node) ([]PathValuePair, error) {
func AllPaths(n Node) ([]PathValuePair[[]byte], error) {
return allLabeled(n, nil)
}

func allChildren(n Node) ([]PathValuePair, error) {
func allChildren(n Node) ([]PathValuePair[Node], error) {
switch n := n.(type) {
case Empty, Pruned, Leaf:
return nil, nil
case Labeled:
switch c := n.Tree.(type) {
case Leaf:
return []PathValuePair{{Path: []Label{n.Label}, Value: c}}, nil
default:
return nil, nil
}
return []PathValuePair[Node]{{Path: []Label{n.Label}, Value: n.Tree}}, nil
case Fork:
left, err := allChildren(n.LeftTree)
if err != nil {
Expand All @@ -82,12 +87,12 @@ func allChildren(n Node) ([]PathValuePair, error) {
}
}

func allLabeled(n Node, path []Label) ([]PathValuePair, error) {
func allLabeled(n Node, path []Label) ([]PathValuePair[[]byte], error) {
switch n := n.(type) {
case Empty, Pruned:
return nil, nil
case Leaf:
return []PathValuePair{{Path: path, Value: n}}, nil
return []PathValuePair[[]byte]{{Path: path, Value: n}}, nil
case Labeled:
return allLabeled(n.Tree, append(path, n.Label))
case Fork:
Expand Down
8 changes: 6 additions & 2 deletions certification/rootKey.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
package certification

// RootKey is the root key of IC main net.
const RootKey = "308182301d060d2b0601040182dc7c0503010201060c2b0601040182dc7c05030201036100814c0e6ec71fab583b08bd81373c255c3c371b2e84863c98a4f1e08b74235d14fb5d9c0cd546d9685f913a0c0b2cc5341583bf4b4392e467db96d65b9bb4cb717112f8472e0d5a4d14505ffd7484b01291091c5f87b98883463f98091a0baaae"
const (
// RootKey is the root key of IC main net.
RootKey = "308182301d060d2b0601040182dc7c0503010201060c2b0601040182dc7c05030201036100814c0e6ec71fab583b08bd81373c255c3c371b2e84863c98a4f1e08b74235d14fb5d9c0cd546d9685f913a0c0b2cc5341583bf4b4392e467db96d65b9bb4cb717112f8472e0d5a4d14505ffd7484b01291091c5f87b98883463f98091a0baaae"
// RootSubnetID is the subnet ID of the (NNS) root subnet.
RootSubnetID = "tdb26-jop6k-aogll-7ltgs-eruif-6kk7m-qpktf-gdiqx-mxtrf-vb5e6-eqe"
)
20 changes: 20 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ func (c Client) ReadState(canisterID principal.Principal, data []byte) ([]byte,
return c.post("read_state", canisterID, data)
}

func (c Client) ReadSubnetState(subnetID principal.Principal, data []byte) ([]byte, error) {
return c.postSubnet("read_state", subnetID, data)
}

// Status returns the status of the IC.
func (c Client) Status() (*Status, error) {
raw, err := c.get("/api/v2/status")
Expand Down Expand Up @@ -106,6 +110,22 @@ func (c Client) post(path string, canisterID principal.Principal, data []byte) (
}
}

func (c Client) postSubnet(path string, subnetID principal.Principal, data []byte) ([]byte, error) {
u := c.url(fmt.Sprintf("/api/v2/subnet/%s/%s", subnetID.Encode(), path))
c.logger.Printf("[CLIENT] POST %s", u)
resp, err := c.client.Post(u, "application/cbor", bytes.NewBuffer(data))
if err != nil {
return nil, err
}
switch resp.StatusCode {
case http.StatusOK:
return io.ReadAll(resp.Body)
default:
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("(%d) %s: %s", resp.StatusCode, resp.Status, body)
}
}

func (c Client) url(p string) string {
u := *c.config.Host
u.Path = path.Join(u.Path, p)
Expand Down
Loading

0 comments on commit 3d11816

Please sign in to comment.