Skip to content

Commit

Permalink
make function calls cached session-wide (dagger#9621)
Browse files Browse the repository at this point in the history
* make function calls cached session-wide

Recently, function calls were "upgraded" from being tainted (i.e. never
cached) to cached-per-client in terms of dagql caching. This helped
fixed a problem with duplicated telemetry.

However, we currently have two layers of caching between dagql and
buildkit and this cache-per-client logic inadvertently ended up also
applying to the buildkit cache key, which increased the amount of cache
invalidation there.

This surprisingly didn't have a super noticeable effect on performance
at least as far as our tests that always run go.

However, the benchmark tests Connor added did end up hitting this,
specifically in BenchmarkLotsOfDeps, which creates a complicated DAG of
function calls that, if not cached correctly, will have O(n^2)
performance (in memory and time).

The fix here changes function calls to be cached in dagql session-wide
and restores the previous buildkit cache key so that we get consistent
session-wide cache hits there again too.

There is some complication from secrets, sockets and other
client-specific resources though. They need to transfer from return
values of function calls to the calling client whether or not the client
hit the cache.

To handle that problem, this change adds support for a "post call"
callback that, if set, dagql will always execute on results whether or
not they were retrieved from cache or an actual execution.

The implementation is admittedly pretty kludgy and ugly right now, but I
think we should go with it for now to fix this problem. We are about to
embark on quite a bit of work around dagql caching so hopefully we can
find a cleaner way of implementing post calls (or some equivalent
replacement) at that time.

Signed-off-by: Erik Sipsma <erik@sipsma.dev>

* unwrap more in telemetry

Signed-off-by: Erik Sipsma <erik@sipsma.dev>

---------

Signed-off-by: Erik Sipsma <erik@sipsma.dev>
  • Loading branch information
sipsma authored Feb 19, 2025
1 parent 7523178 commit 4bb955c
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 22 deletions.
8 changes: 7 additions & 1 deletion core/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ func (iface *InterfaceType) Install(ctx context.Context, dag *dagql.Server) erro
})
}

res, err := callable.Call(ctx, &CallOpts{
postCallRes, err := callable.Call(ctx, &CallOpts{
Inputs: callInputs,
ParentTyped: runtimeVal,
ParentFields: runtimeVal.Fields,
Expand All @@ -287,6 +287,12 @@ func (iface *InterfaceType) Install(ctx context.Context, dag *dagql.Server) erro
if err != nil {
return nil, fmt.Errorf("failed to call interface function %s.%s: %w", ifaceName, fieldDef.Name, err)
}
res := postCallRes.Typed
if postCallRes.PostCall != nil {
if err := postCallRes.PostCall(ctx); err != nil {
return nil, fmt.Errorf("failed to run post-call for %s.%s: %w", ifaceName, fieldDef.Name, err)
}
}

if fnTypeDef.ReturnType.Underlying().Kind != TypeDefKindInterface {
return res, nil
Expand Down
2 changes: 1 addition & 1 deletion core/moddeps.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ func (d *ModDeps) lazilyLoadSchema(ctx context.Context) (
IfaceType: ifaceType,
}, nil
},
CachePerClientObject,
nil,
)
}
}
Expand Down
34 changes: 26 additions & 8 deletions core/modfunc.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"path/filepath"
"strings"
"sync"

bkgw "github.com/moby/buildkit/frontend/gateway/client"
"github.com/moby/buildkit/identity"
Expand All @@ -22,6 +23,7 @@ import (

"github.com/dagger/dagger/analytics"
"github.com/dagger/dagger/dagql"
"github.com/dagger/dagger/engine"
"github.com/dagger/dagger/engine/buildkit"
"github.com/dagger/dagger/engine/server/resource"
"github.com/dagger/dagger/engine/slog"
Expand Down Expand Up @@ -213,7 +215,7 @@ func (fn *ModuleFunction) setCallInputs(ctx context.Context, opts *CallOpts) ([]
return callInputs, nil
}

func (fn *ModuleFunction) Call(ctx context.Context, opts *CallOpts) (t dagql.Typed, rerr error) { //nolint: gocyclo
func (fn *ModuleFunction) Call(ctx context.Context, opts *CallOpts) (t *dagql.PostCallTyped, rerr error) { //nolint: gocyclo
mod := fn.mod

lg := bklog.G(ctx).WithField("module", mod.Name()).WithField("function", fn.metadata.Name)
Expand Down Expand Up @@ -389,18 +391,34 @@ func (fn *ModuleFunction) Call(ctx context.Context, opts *CallOpts) (t dagql.Typ
return nil, fmt.Errorf("failed to collect IDs: %w", err)
}

for _, id := range returnedIDs {
if err := fn.root.AddClientResourcesFromID(ctx, id, clientID, false); err != nil {
return nil, fmt.Errorf("failed to add client resources from ID: %w", err)
}
}

// NOTE: once generalized function caching is enabled we need to ensure that any non-reproducible
// cache entries are linked to the result of this call.
// See the previous implementation of this for a reference:
// https://github.com/dagger/dagger/blob/7c31db76e07c9a17fcdb3f3c4513c915344c1da8/core/modfunc.go#L483

return returnValueTyped, nil
// Function calls are cached per-session, but every client caller needs to add
// secret/socket/etc. resources from the result to their store.
callerClientMemo := sync.Map{}
return &dagql.PostCallTyped{
Typed: returnValueTyped,
PostCall: func(ctx context.Context) error {
// only run this once per calling client, no need to re-add resources
clientMetadata, err := engine.ClientMetadataFromContext(ctx)
if err != nil {
return fmt.Errorf("failed to get client metadata: %w", err)
}
if _, alreadyRan := callerClientMemo.LoadOrStore(clientMetadata.ClientID, struct{}{}); alreadyRan {
return nil
}

for _, id := range returnedIDs {
if err := fn.root.AddClientResourcesFromID(ctx, id, clientID, false); err != nil {
return fmt.Errorf("failed to add client resources from ID: %w", err)
}
}
return nil
},
}, nil
}

func extractError(ctx context.Context, client *buildkit.Client, baseErr error) (dagql.ID[*Error], bool, error) {
Expand Down
18 changes: 8 additions & 10 deletions core/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ func (t *ModuleObjectType) TypeDef() *TypeDef {
}

type Callable interface {
Call(context.Context, *CallOpts) (dagql.Typed, error)
Call(context.Context, *CallOpts) (*dagql.PostCallTyped, error)
ReturnType() (ModType, error)
ArgType(argName string) (ModType, error)
}
Expand Down Expand Up @@ -366,8 +366,7 @@ func (obj *ModuleObject) installConstructor(ctx context.Context, dag *dagql.Serv
Server: dag,
})
},
// cache constructor calls per client; a given client will hit cache when making the same call repeatedly
CachePerClientObject,
nil,
)

return nil
Expand Down Expand Up @@ -476,11 +475,6 @@ func objFun(ctx context.Context, mod *Module, objDef *ObjectTypeDef, fun *Functi
})
return modFun.Call(ctx, opts)
},
// Cache calls per client; a given client will hit cache when making the same call repeatedly.
// We can't *quite* mark them as fully cached across clients in a session, since Call has special
// logic for transferring secrets between cached calls (covered by TestModule/TestSecretNested
// integ tests).
CacheKeyFunc: CachePerClient[*ModuleObject, map[string]dagql.Input],
}, nil
}

Expand All @@ -490,12 +484,16 @@ type CallableField struct {
Return ModType
}

func (f *CallableField) Call(ctx context.Context, opts *CallOpts) (dagql.Typed, error) {
func (f *CallableField) Call(ctx context.Context, opts *CallOpts) (*dagql.PostCallTyped, error) {
val, ok := opts.ParentFields[f.Field.OriginalName]
if !ok {
return nil, fmt.Errorf("field %q not found on object %q", f.Field.Name, opts.ParentFields)
}
return f.Return.ConvertFromSDKResult(ctx, val)
typed, err := f.Return.ConvertFromSDKResult(ctx, val)
if err != nil {
return nil, fmt.Errorf("failed to convert field %q: %w", f.Field.Name, err)
}
return &dagql.PostCallTyped{Typed: typed}, nil
}

func (f *CallableField) ReturnType() (ModType, error) {
Expand Down
9 changes: 8 additions & 1 deletion core/schema/modulesource.go
Original file line number Diff line number Diff line change
Expand Up @@ -2028,7 +2028,7 @@ func (s *moduleSourceSchema) moduleSourceAsModule(
getModDefSpan.End()
return inst, fmt.Errorf("failed to create module definition function for module %q: %w", modName, err)
}
result, err := getModDefFn.Call(getModDefCtx, &core.CallOpts{
postCallRes, err := getModDefFn.Call(getModDefCtx, &core.CallOpts{
Cache: true,
SkipSelfSchema: true,
Server: s.dag,
Expand All @@ -2042,6 +2042,13 @@ func (s *moduleSourceSchema) moduleSourceAsModule(
getModDefSpan.End()
return inst, fmt.Errorf("failed to call module %q to get functions: %w", modName, err)
}
result := postCallRes.Typed
if postCallRes.PostCall != nil {
if err := postCallRes.PostCall(ctx); err != nil {
getModDefSpan.End()
return inst, fmt.Errorf("failed to run post-call for module %q: %w", modName, err)
}
}
resultInst, ok := result.(dagql.Instance[*core.Module])
if !ok {
getModDefSpan.End()
Expand Down
9 changes: 8 additions & 1 deletion core/telemetry.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,14 @@ func AroundFunc(ctx context.Context, self dagql.Object, id *call.ID) (context.Co
// This allows the UI to "simplify" the returned object's ID back to the
// current call's ID, so we can show the user myMod().unit().stdout()
// instead of container().from().[...].stdout().
if obj, ok := res.(dagql.Object); ok {
obj, isObj := res.(dagql.Object)
if !isObj {
// try again unwrapping it
if wrapper, isWrapper := res.(dagql.Wrapper); isWrapper {
obj, isObj = wrapper.Unwrap().(dagql.Object)
}
}
if isObj {
// Don't consider loadFooFromID to be a 'creator' as that would only
// obfuscate the real ID.
//
Expand Down
25 changes: 25 additions & 0 deletions dagql/objects.go
Original file line number Diff line number Diff line change
Expand Up @@ -521,9 +521,34 @@ func (r Instance[T]) call(
}
}

// field implementations can optionally return a wrapped Typed val that has
// a callback that should always run after the field is called
if postCallVal, ok := val.(*PostCallTyped); ok {
val = postCallVal.Typed
if postCallVal.PostCall != nil {
if err := postCallVal.PostCall(ctx); err != nil {
return nil, nil, fmt.Errorf("post-call error: %w", err)
}
}
}

return val, newID, nil
}

// PostCallTyped wraps a Typed value with an additional callback that
// needs to be called after any value is returned, whether the value was from
// cache or not
type PostCallTyped struct {
Typed
PostCall func(context.Context) error
}

var _ Wrapper = PostCallTyped{}

func (p PostCallTyped) Unwrap() Typed {
return p.Typed
}

type View interface {
Contains(string) bool
}
Expand Down

0 comments on commit 4bb955c

Please sign in to comment.