Skip to content

Commit

Permalink
fix: handle projections after join chain
Browse files Browse the repository at this point in the history
  • Loading branch information
tokoko committed Aug 28, 2024
1 parent a3b3f03 commit 3d12921
Show file tree
Hide file tree
Showing 21 changed files with 10,303 additions and 5,100 deletions.
54 changes: 40 additions & 14 deletions ibis_substrait/compiler/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,25 +815,22 @@ def filter(
)


@translate.register(ops.Project)
def project(
op: ops.Project,
*,
def apply_projection(
schema_len: int,
relation: stalg.Rel,
values,
compiler: SubstraitCompiler,
child_rel_field_offsets: Mapping[ops.TableNode, int] | None = None,
**kwargs: Any,
) -> stalg.Rel:
relation = translate(
op.parent, compiler=compiler, child_rel_field_offsets=child_rel_field_offsets
)
mapping_counter = itertools.count(len(op.parent.schema))
child_rel_field_offsets,
kwargs,
):
mapping_counter = itertools.count(schema_len)

return stalg.Rel(
project=stalg.ProjectRel(
input=relation,
common=stalg.RelCommon(
emit=stalg.RelCommon.Emit(
output_mapping=[next(mapping_counter) for _ in op.values]
output_mapping=[next(mapping_counter) for _ in values]
)
),
expressions=[
Expand All @@ -843,12 +840,34 @@ def project(
child_rel_field_offsets=child_rel_field_offsets,
**kwargs,
)
for k, v in op.values.items()
for k, v in values.items()
],
)
)


@translate.register(ops.Project)
def project(
    op: ops.Project,
    *,
    compiler: SubstraitCompiler,
    child_rel_field_offsets: Mapping[ops.TableNode, int] | None = None,
    **kwargs: Any,
) -> stalg.Rel:
    """Translate a ``Project`` operation into a Substrait relation.

    The parent of ``op`` is translated first; the resulting relation is then
    handed to ``apply_projection`` together with the projected values.

    Parameters
    ----------
    op
        The projection operation to translate.
    compiler
        Compiler instance forwarded to the nested translations.
    child_rel_field_offsets
        Optional per-table field offsets, forwarded unchanged to both the
        parent translation and ``apply_projection``.
    **kwargs
        Extra keyword arguments; passed to ``apply_projection`` as a single
        dict (they are not forwarded to the parent translation).

    Returns
    -------
    stalg.Rel
        Whatever ``apply_projection`` produces for the translated parent.
    """
    # Translate the input relation before anything else is evaluated.
    parent_rel = translate(
        op.parent,
        compiler=compiler,
        child_rel_field_offsets=child_rel_field_offsets,
    )

    # The parent's column count is the starting index for the projection.
    parent_width = len(op.parent.schema)

    return apply_projection(
        schema_len=parent_width,
        relation=parent_rel,
        values=op.values,
        compiler=compiler,
        child_rel_field_offsets=child_rel_field_offsets,
        kwargs=kwargs,
    )


@translate.register(ops.Sort)
def sort(
op: ops.Sort,
Expand Down Expand Up @@ -944,7 +963,14 @@ def join(

relation = stalg.Rel(join=rel)

return relation
return apply_projection(
schema_len=offset,
relation=relation,
values=op.values,
compiler=compiler,
child_rel_field_offsets=child_rel_field_offsets,
kwargs=kwargs,
)


@translate.register(ops.Limit)
Expand Down
3 changes: 2 additions & 1 deletion ibis_substrait/tests/compiler/parity_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,5 @@ def execute(self, plan) -> pa.Table:
df = df.with_column_renamed(
column_name, plan.relations[0].root.names[column_number]
)
return df.to_arrow_table()
record_batch = df.collect()
return pa.Table.from_batches(record_batch)

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Loading

0 comments on commit 3d12921

Please sign in to comment.