Skip to content

Commit

Permalink
Openfga: use reference branches in intersection (#365)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ulimo authored Feb 12, 2024
1 parent f446fad commit 341fda2
Show file tree
Hide file tree
Showing 7 changed files with 254 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,18 @@ private sealed class Result
private readonly HashSet<string> _stopTypes;
private readonly HashSet<TypeReference> visitedTypes;
private readonly HashSet<TypeReference> loopFoundTypes;
private readonly List<ZanzibarRelation> _relations;

public FlowtideZanzibarConverter(AuthorizationModel authorizationModel, HashSet<string> stopTypes)
{
this.authorizationModel = authorizationModel;
_stopTypes = stopTypes;
visitedTypes = new HashSet<TypeReference>();
loopFoundTypes = new HashSet<TypeReference>();
_relations = new List<ZanzibarRelation>();
}

public ZanzibarRelation Parse(string type, string relation)
public List<ZanzibarRelation> Parse(string type, string relation)
{
var typeDefinition = authorizationModel.TypeDefinitions.Find(x => x.Type.Equals(type, StringComparison.OrdinalIgnoreCase));

Expand All @@ -104,7 +106,8 @@ public ZanzibarRelation Parse(string type, string relation)
}

var result = VisitRelationDefinition(relationDefinition, relation, typeDefinition);
return result.Relation;
_relations.Add(result.Relation);
return _relations;
}

private Result VisitRelationDefinition(Userset relationDefinition, string relationName, TypeDefinition typeDefinition)
Expand Down Expand Up @@ -168,12 +171,23 @@ private Result VisitIntersection(Usersets userset, string relationName, TypeDefi
}
else
{
var rootRelReference = new ZanzibarRelationReference()
{
ReferenceId = _relations.Count
};
_relations.Add(rootRel);
var subRelReference = new ZanzibarRelationReference()
{
ReferenceId = _relations.Count
};
_relations.Add(subRel.Relation);
List<ZanzibarRelation> relations = new List<ZanzibarRelation>();
var joinEqual = new ZanzibarJoinOnUserTypeId()
{
Left = rootRel,
Right = subRel.Relation
Left = rootRelReference,
Right = subRelReference
};

relations.Add(joinEqual);
var leftHasWildcard = resultTypes.Any(x => x.Wildcard);
var rightHasWildcard = subRel.ResultTypes.Any(x => x.Wildcard);
Expand All @@ -182,8 +196,8 @@ private Result VisitIntersection(Usersets userset, string relationName, TypeDefi
{
relations.Add(new ZanzibarJoinIntersectWildcard()
{
Left = rootRel,
Right = subRel.Relation,
Left = rootRelReference,
Right = subRelReference,
LeftWildcard = true
});
}
Expand All @@ -192,8 +206,8 @@ private Result VisitIntersection(Usersets userset, string relationName, TypeDefi
{
relations.Add(new ZanzibarJoinIntersectWildcard()
{
Left = rootRel,
Right = subRel.Relation,
Left = rootRelReference,
Right = subRelReference,
LeftWildcard = false
});
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Licensed under the Apache License, Version 2.0 (the "License")
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace FlowtideDotNet.Connector.OpenFGA.Internal.Models
{
internal class ZanzibarRelationReference : ZanzibarRelation
{
public int ReferenceId { get; set; }
public override T Accept<T, TState>(ZanzibarVisitor<T, TState> visitor, TState state)
{
return visitor.VisitZanzibarRelationReference(this, state);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,5 +75,10 @@ public virtual T VisitZanzibarUnion(ZanzibarUnion union, TState state)
{
throw new NotImplementedException();
}

public virtual T VisitZanzibarRelationReference(ZanzibarRelationReference relationReference, TState state)
{
throw new NotImplementedException();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -677,5 +677,14 @@ public override Relation VisitZanzibarUnion(ZanzibarUnion union, object? state)

return setRelation;
}

public override Relation VisitZanzibarRelationReference(ZanzibarRelationReference relationReference, object? state)
{
return new ReferenceRelation()
{
ReferenceOutputLength = 6,
RelationId = relationReference.ReferenceId
};
}
}
}
26 changes: 16 additions & 10 deletions src/FlowtideDotNet.Connector.OpenFGA/OpenFgaToFlowtide.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
// limitations under the License.

using FlowtideDotNet.Connector.OpenFGA.Internal;
using FlowtideDotNet.Connector.OpenFGA.Internal.Models;
using FlowtideDotNet.Substrait;
using FlowtideDotNet.Substrait.Relations;
using OpenFga.Sdk.Model;
Expand All @@ -35,9 +36,20 @@ public static Plan Convert(AuthorizationModel authorizationModel, string type, s
}
}

var zanzibarRelation = new FlowtideZanzibarConverter(authorizationModel, stopTypes.ToHashSet()).Parse(type, relation);
var outputPlan = new Plan()
{
Relations = new List<Relation>()
};

var zanzibarRelations = new FlowtideZanzibarConverter(authorizationModel, stopTypes.ToHashSet()).Parse(type, relation);
var visitor = new ZanzibarToFlowtideVisitor(inputTypeName);
var flowtideRelation = visitor.Visit(zanzibarRelation, default);

for (int i = 0; i < zanzibarRelations.Count - 1; i++)
{
var flowtideReferenceRelation = visitor.Visit(zanzibarRelations[i], default);
outputPlan.Relations.Add(flowtideReferenceRelation);
}
var flowtideRelation = visitor.Visit(zanzibarRelations[zanzibarRelations.Count - 1], default);

var rootRelation = new RootRelation()
{
Expand All @@ -52,14 +64,8 @@ public static Plan Convert(AuthorizationModel authorizationModel, string type, s
"object_id"
}
};

var outputPlan = new Plan()
{
Relations = new List<Substrait.Relations.Relation>()
{
rootRelation
}
};
outputPlan.Relations.Add(rootRelation);


return outputPlan;
}
Expand Down
160 changes: 160 additions & 0 deletions src/FlowtideDotNet.Substrait/Modifier/ReferenceRemapVisitor.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
// Licensed under the Apache License, Version 2.0 (the "License")
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

using FlowtideDotNet.Substrait.Relations;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace FlowtideDotNet.Substrait.Modifier
{
internal class ReferenceRemapVisitor : RelationVisitor<Relation, object?>
{
private readonly Dictionary<int, int> idMap;

public ReferenceRemapVisitor(Dictionary<int, int> idMap)
{
this.idMap = idMap;
}

public override Relation VisitReadRelation(ReadRelation readRelation, object? state)
{
return readRelation;
}

public override Relation VisitFilterRelation(FilterRelation filterRelation, object? state)
{
filterRelation.Input = Visit(filterRelation.Input, state);
return filterRelation;
}

public override Relation VisitJoinRelation(JoinRelation joinRelation, object? state)
{
joinRelation.Left = Visit(joinRelation.Left, state);
joinRelation.Right = Visit(joinRelation.Right, state);
return joinRelation;
}

public override Relation VisitNormalizationRelation(NormalizationRelation normalizationRelation, object? state)
{
normalizationRelation.Input = Visit(normalizationRelation.Input, state);
return normalizationRelation;
}

public override Relation VisitProjectRelation(ProjectRelation projectRelation, object? state)
{
projectRelation.Input = Visit(projectRelation.Input, state);
return projectRelation;
}

public override Relation VisitWriteRelation(WriteRelation writeRelation, object? state)
{
writeRelation.Input = Visit(writeRelation.Input, state);
return writeRelation;
}

public override Relation VisitPlanRelation(PlanRelation planRelation, object? state)
{
return planRelation;
}

public override Relation VisitReferenceRelation(ReferenceRelation referenceRelation, object? state)
{
if (idMap.TryGetValue(referenceRelation.RelationId, out int newId))
{
referenceRelation.RelationId = newId;
}
return referenceRelation;
}

public override Relation VisitRootRelation(RootRelation rootRelation, object? state)
{
rootRelation.Input = Visit(rootRelation.Input, state);
return rootRelation;
}

public override Relation VisitSetRelation(SetRelation setRelation, object? state)
{
for (int i = 0; i < setRelation.Inputs.Count; i++)
{
setRelation.Inputs[i] = Visit(setRelation.Inputs[i], state);
}
return setRelation;
}

public override Relation VisitMergeJoinRelation(MergeJoinRelation mergeJoinRelation, object? state)
{
mergeJoinRelation.Left = Visit(mergeJoinRelation.Left, state);
mergeJoinRelation.Right = Visit(mergeJoinRelation.Right, state);
return mergeJoinRelation;
}

public override Relation VisitAggregateRelation(AggregateRelation aggregateRelation, object? state)
{
aggregateRelation.Input = Visit(aggregateRelation.Input, state);
return aggregateRelation;
}

public override Relation VisitIterationRelation(IterationRelation iterationRelation, object? state)
{
if (iterationRelation.Input != null)
{
iterationRelation.Input = Visit(iterationRelation.Input, state);
}
iterationRelation.LoopPlan = Visit(iterationRelation.LoopPlan, state);

return iterationRelation;
}

public override Relation VisitIterationReferenceReadRelation(IterationReferenceReadRelation iterationReferenceReadRelation, object? state)
{
return iterationReferenceReadRelation;
}

public override Relation VisitUnwrapRelation(UnwrapRelation unwrapRelation, object? state)
{
unwrapRelation.Input = Visit(unwrapRelation.Input, state);
return unwrapRelation;
}

public override Relation VisitVirtualTableReadRelation(VirtualTableReadRelation virtualTableReadRelation, object? state)
{
return virtualTableReadRelation;
}

public override Relation VisitBufferRelation(BufferRelation bufferRelation, object? state)
{
bufferRelation.Input = Visit(bufferRelation.Input, state);
return bufferRelation;
}

public override Relation VisitFetchRelation(FetchRelation fetchRelation, object? state)
{
fetchRelation.Input = Visit(fetchRelation.Input, state);
return fetchRelation;
}

public override Relation VisitSortRelation(SortRelation sortRelation, object? state)
{
sortRelation.Input = Visit(sortRelation.Input, state);
return sortRelation;
}

public override Relation VisitTopNRelation(TopNRelation topNRelation, object? state)
{
topNRelation.Input = Visit(topNRelation.Input, state);
return topNRelation;
}
}
}
15 changes: 13 additions & 2 deletions src/FlowtideDotNet.Substrait/PlanModifier.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,12 @@ public Plan Modify()
Relations = new List<FlowtideDotNet.Substrait.Relations.Relation>()
};
Dictionary<string, ReferenceInfo> subPlanNameToId = new Dictionary<string, ReferenceInfo>(StringComparer.OrdinalIgnoreCase);
foreach(var subplan in _subplans)
foreach (var subplan in _subplans)
{
Dictionary<int, int> oldRelationToNewMap = new Dictionary<int, int>();
var referenceRemapVisitor = new ReferenceRemapVisitor(oldRelationToNewMap);

bool containsRootRelation = subplan.Value.Relations.Any(x => x is RootRelation);
// TODO: Must remap reference relations from sub plans to their new id.
for (int i = 0; i < subplan.Value.Relations.Count; i++)
{
Expand All @@ -95,14 +99,21 @@ public Plan Modify()
else
{
var relationId = newPlan.Relations.Count;
oldRelationToNewMap.Add(i, relationId);
subPlanNameToId.Add(subplan.Key, new ReferenceInfo(relationId, rootRelation.Input.OutputLength));
referenceRemapVisitor.Visit(rootRelation.Input, default);
newPlan.Relations.Add(rootRelation.Input);
}
}
else
{
var relationId = newPlan.Relations.Count;
subPlanNameToId.Add(subplan.Key, new ReferenceInfo(relationId, relation.OutputLength));
oldRelationToNewMap.Add(i, relationId);
if (!containsRootRelation)
{
subPlanNameToId.Add(subplan.Key, new ReferenceInfo(relationId, relation.OutputLength));
}
referenceRemapVisitor.Visit(relation, default);
newPlan.Relations.Add(relation);
}
}
Expand Down

0 comments on commit 341fda2

Please sign in to comment.