Skip to content

Commit

Permalink
fix: handle projections after join chain (ibis-project#1119)
Browse files Browse the repository at this point in the history
Ibis IR has a bit of a quirky behavior around projections, namely one can define projections both in ops.Project object and also as part of ops.JoinChain, not sure why that's the case. Current compiler handles only projections defined in ops.Project, therefore any selects written after a join chain are ignored.

The PR fixes the issue by injecting a projection node after every join chain in the plan. Technically this is not always necessary, because if the values contained in the ops.JoinChain do only column renames and reorders (no new expressions) then substrait's Emit message inside JoinRel can be used as well, but that would require inspecting all values first to make sure that nothing other than field references are present in there.


---------

Co-authored-by: tokoko <togurgenidze@gmail.com>
  • Loading branch information
tokoko and tokoko authored Aug 29, 2024
1 parent a3b3f03 commit 7a8b320
Show file tree
Hide file tree
Showing 21 changed files with 10,303 additions and 5,100 deletions.
54 changes: 40 additions & 14 deletions ibis_substrait/compiler/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,25 +815,22 @@ def filter(
)


@translate.register(ops.Project)
def project(
op: ops.Project,
*,
def apply_projection(
schema_len: int,
relation: stalg.Rel,
values: Mapping[str, ops.Value],
compiler: SubstraitCompiler,
child_rel_field_offsets: Mapping[ops.TableNode, int] | None = None,
**kwargs: Any,
) -> stalg.Rel:
relation = translate(
op.parent, compiler=compiler, child_rel_field_offsets=child_rel_field_offsets
)
mapping_counter = itertools.count(len(op.parent.schema))
child_rel_field_offsets: Mapping[ops.TableNode, int] | None,
kwargs: Mapping,
) -> stalg.ReadRel:
mapping_counter = itertools.count(schema_len)

return stalg.Rel(
project=stalg.ProjectRel(
input=relation,
common=stalg.RelCommon(
emit=stalg.RelCommon.Emit(
output_mapping=[next(mapping_counter) for _ in op.values]
output_mapping=[next(mapping_counter) for _ in values]
)
),
expressions=[
Expand All @@ -843,12 +840,34 @@ def project(
child_rel_field_offsets=child_rel_field_offsets,
**kwargs,
)
for k, v in op.values.items()
for k, v in values.items()
],
)
)


@translate.register(ops.Project)
def project(
op: ops.Project,
*,
compiler: SubstraitCompiler,
child_rel_field_offsets: Mapping[ops.TableNode, int] | None = None,
**kwargs: Any,
) -> stalg.Rel:
relation = translate(
op.parent, compiler=compiler, child_rel_field_offsets=child_rel_field_offsets
)

return apply_projection(
schema_len=len(op.parent.schema),
relation=relation,
values=op.values,
compiler=compiler,
child_rel_field_offsets=child_rel_field_offsets,
kwargs=kwargs,
)


@translate.register(ops.Sort)
def sort(
op: ops.Sort,
Expand Down Expand Up @@ -944,7 +963,14 @@ def join(

relation = stalg.Rel(join=rel)

return relation
return apply_projection(
schema_len=offset,
relation=relation,
values=op.values,
compiler=compiler,
child_rel_field_offsets=child_rel_field_offsets,
kwargs=kwargs,
)


@translate.register(ops.Limit)
Expand Down
3 changes: 2 additions & 1 deletion ibis_substrait/tests/compiler/parity_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,5 @@ def execute(self, plan) -> pa.Table:
df = df.with_column_renamed(
column_name, plan.relations[0].root.names[column_number]
)
return df.to_arrow_table()
record_batch = df.collect()
return pa.Table.from_batches(record_batch)

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Loading

0 comments on commit 7a8b320

Please sign in to comment.