Skip to content

Commit 6879923

Browse files
authored
Merge pull request #6040 from neo4j/fix-distinct
Enable distinct for new aggregations
2 parents 98bf3bc + ea74a80 commit 6879923

File tree

7 files changed

+207
-30
lines changed

7 files changed

+207
-30
lines changed

packages/graphql/src/translate/queryAST/ast/fields/aggregation-fields/AggregationAttributeField.ts

+2-1
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,13 @@ export class AggregationAttributeField extends AggregationField {
6262
const projection = new Cypher.Return([this.createAggregationExpr(listVar), returnVar]);
6363

6464
return new Cypher.With(target)
65+
.distinct()
6566
.orderBy([Cypher.size(aggrProp), "DESC"])
6667
.with([Cypher.collect(aggrProp), listVar])
6768
.return(projection);
6869
}
6970

70-
return new Cypher.With(target).return([this.getAggregationExpr(target), returnVar]);
71+
return new Cypher.With(target).distinct().return([this.getAggregationExpr(target), returnVar]);
7172
}
7273

7374
private createAggregationExpr(target: Cypher.Variable | Cypher.Property): Cypher.Expr {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Licensed under the Apache License, Version 2.0 (the "License");
8+
* you may not use this file except in compliance with the License.
9+
* You may obtain a copy of the License at
10+
*
11+
* http://www.apache.org/licenses/LICENSE-2.0
12+
*
13+
* Unless required by applicable law or agreed to in writing, software
14+
* distributed under the License is distributed on an "AS IS" BASIS,
15+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
* See the License for the specific language governing permissions and
17+
* limitations under the License.
18+
*/
19+
20+
import Cypher from "@neo4j/cypher-builder";
21+
import type { AttributeAdapter } from "../../../../../schema-model/attribute/model-adapters/AttributeAdapter";
22+
import { filterFields, renameFields } from "../../../../../utils/utils";
23+
import type { QueryASTNode } from "../../QueryASTNode";
24+
import { AggregationField } from "./AggregationField";
25+
26+
export class DeprecatedAggregationAttributeField extends AggregationField {
27+
private attribute: AttributeAdapter;
28+
private aggregationProjection: Record<string, string>;
29+
30+
constructor({
31+
alias,
32+
attribute,
33+
aggregationProjection,
34+
}: {
35+
alias: string;
36+
attribute: AttributeAdapter;
37+
aggregationProjection: Record<string, string>;
38+
}) {
39+
super(alias);
40+
this.attribute = attribute;
41+
this.aggregationProjection = aggregationProjection;
42+
}
43+
44+
public getChildren(): QueryASTNode[] {
45+
return [];
46+
}
47+
48+
public getProjectionField(variable: Cypher.Variable): Record<string, Cypher.Expr> {
49+
return { [this.alias]: variable };
50+
}
51+
52+
public getAggregationExpr(target: Cypher.Variable): Cypher.Expr {
53+
const variable = target.property(this.attribute.databaseName);
54+
return this.createAggregationExpr(variable);
55+
}
56+
57+
public getAggregationProjection(target: Cypher.Variable, returnVar: Cypher.Variable): Cypher.Clause {
58+
if (this.attribute.typeHelper.isString()) {
59+
const aggrProp = target.property(this.attribute.databaseName);
60+
const listVar = new Cypher.NamedVariable("list");
61+
62+
const projection = new Cypher.Return([this.createAggregationExpr(listVar), returnVar]);
63+
64+
return new Cypher.With(target)
65+
.orderBy([Cypher.size(aggrProp), "DESC"])
66+
.with([Cypher.collect(aggrProp), listVar])
67+
.return(projection);
68+
}
69+
70+
return new Cypher.With(target).return([this.getAggregationExpr(target), returnVar]);
71+
}
72+
73+
private createAggregationExpr(target: Cypher.Variable | Cypher.Property): Cypher.Expr {
74+
if (this.attribute.typeHelper.isString()) {
75+
const listVar = new Cypher.NamedVariable("list");
76+
return new Cypher.Map(
77+
this.filterProjection({
78+
longest: Cypher.head(listVar),
79+
shortest: Cypher.last(listVar),
80+
})
81+
);
82+
}
83+
84+
// NOTE: These are types that are treated as numeric by aggregation
85+
if (this.attribute.typeHelper.isNumeric()) {
86+
return new Cypher.Map(
87+
this.filterProjection({
88+
min: Cypher.min(target),
89+
max: Cypher.max(target),
90+
average: Cypher.avg(target),
91+
sum: Cypher.sum(target),
92+
})
93+
);
94+
}
95+
96+
if (this.attribute.typeHelper.isDateTime()) {
97+
return new Cypher.Map(
98+
this.filterProjection({
99+
min: this.createDatetimeProjection(Cypher.min(target)),
100+
max: this.createDatetimeProjection(Cypher.max(target)),
101+
})
102+
);
103+
}
104+
if (this.attribute.typeHelper.isTemporal()) {
105+
return new Cypher.Map(
106+
this.filterProjection({
107+
min: Cypher.min(target),
108+
max: Cypher.max(target),
109+
})
110+
);
111+
}
112+
113+
if (this.attribute.typeHelper.isID()) {
114+
return new Cypher.Map(
115+
this.filterProjection({
116+
shortest: Cypher.min(target),
117+
longest: Cypher.max(target),
118+
})
119+
);
120+
}
121+
throw new Error(`Invalid aggregation type ${this.attribute.type.name}`);
122+
}
123+
124+
// Filters and apply aliases in the projection
125+
private filterProjection(projectionFields: Record<string, Cypher.Expr>): Record<string, Cypher.Expr> {
126+
const filteredFields = filterFields(projectionFields, Object.keys(this.aggregationProjection));
127+
return renameFields(filteredFields, this.aggregationProjection);
128+
}
129+
130+
private createDatetimeProjection(expr: Cypher.Expr) {
131+
return Cypher.apoc.date.convertFormat(expr, "iso_zoned_date_time", "iso_offset_date_time");
132+
}
133+
}

packages/graphql/src/translate/queryAST/factory/FieldFactory.ts

+16-6
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ import { OperationField } from "../ast/fields/OperationField";
3535
import { AggregationAttributeField } from "../ast/fields/aggregation-fields/AggregationAttributeField";
3636
import type { AggregationField } from "../ast/fields/aggregation-fields/AggregationField";
3737
import { CountField } from "../ast/fields/aggregation-fields/CountField";
38+
import { DeprecatedAggregationAttributeField } from "../ast/fields/aggregation-fields/DeprecatedAggregationAttributeField";
3839
import { DeprecatedCountField } from "../ast/fields/aggregation-fields/DeprecatedCountField";
3940
import { AttributeField } from "../ast/fields/attribute-fields/AttributeField";
4041
import { DateTimeField } from "../ast/fields/attribute-fields/DateTimeField";
@@ -134,7 +135,8 @@ export class FieldFactory {
134135

135136
public createAggregationFields(
136137
entity: ConcreteEntityAdapter | RelationshipAdapter | InterfaceEntityAdapter,
137-
rawFields: Record<string, ResolveTree>
138+
rawFields: Record<string, ResolveTree>,
139+
useDeprecatedAttribute = false
138140
): AggregationField[] {
139141
return filterTruthy(
140142
Object.values(rawFields).map((field) => {
@@ -169,11 +171,19 @@ export class FieldFactory {
169171
return acc;
170172
}, {});
171173

172-
return new AggregationAttributeField({
173-
attribute,
174-
alias: field.alias,
175-
aggregationProjection,
176-
});
174+
if (useDeprecatedAttribute) {
175+
return new DeprecatedAggregationAttributeField({
176+
attribute,
177+
alias: field.alias,
178+
aggregationProjection,
179+
});
180+
} else {
181+
return new AggregationAttributeField({
182+
attribute,
183+
alias: field.alias,
184+
aggregationProjection,
185+
});
186+
}
177187
}
178188
})
179189
);

packages/graphql/src/translate/queryAST/factory/Operations/AggregateFactory.ts

+37-8
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,13 @@ export class AggregateFactory {
4848
resolveTree,
4949
context,
5050
extraWhereArgs = {},
51+
isInConnection = false,
5152
}: {
5253
entityOrRel: ConcreteEntityAdapter | RelationshipAdapter | InterfaceEntityAdapter;
5354
resolveTree: ResolveTree;
5455
context: Neo4jGraphQLTranslationContext;
5556
extraWhereArgs?: Record<string, any>;
57+
isInConnection?: boolean;
5658
}): AggregationOperation | CompositeAggregationOperation {
5759
let entity: ConcreteEntityAdapter | InterfaceEntityAdapter;
5860
if (entityOrRel instanceof RelationshipAdapter) {
@@ -92,6 +94,7 @@ export class AggregateFactory {
9294
resolveTree,
9395
context,
9496
whereArgs: resolveTreeWhere,
97+
deprecatedAttributes: !isInConnection,
9598
});
9699
} else {
97100
// RELATIONSHIP WITH INTERFACE TARGET
@@ -139,6 +142,7 @@ export class AggregateFactory {
139142
context,
140143
operation: compositeAggregationOp,
141144
whereArgs: resolveTreeWhere,
145+
deprecatedAttributes: !isInConnection,
142146
});
143147

144148
return compositeAggregationOp;
@@ -169,6 +173,7 @@ export class AggregateFactory {
169173
resolveTree,
170174
context,
171175
whereArgs: resolveTreeWhere,
176+
deprecatedAttributes: !isInConnection,
172177
});
173178
} else {
174179
// TOP level interface/union
@@ -208,6 +213,7 @@ export class AggregateFactory {
208213
context,
209214
operation: compositeAggregationOp,
210215
whereArgs: resolveTreeWhere,
216+
deprecatedAttributes: !isInConnection,
211217
});
212218
}
213219
}
@@ -248,13 +254,15 @@ export class AggregateFactory {
248254
resolveTree,
249255
context,
250256
whereArgs,
257+
deprecatedAttributes,
251258
}: {
252259
relationship?: RelationshipAdapter;
253260
entity: ConcreteEntityAdapter | InterfaceEntityAdapter;
254261
operation: T;
255262
resolveTree: ResolveTree;
256263
context: Neo4jGraphQLTranslationContext;
257264
whereArgs: Record<string, any>;
265+
deprecatedAttributes: boolean;
258266
}): T {
259267
if (relationship) {
260268
const parsedProjectionFields = this.getAggregationParsedProjectionFields(relationship, resolveTree);
@@ -273,10 +281,19 @@ export class AggregateFactory {
273281

274282
const fields = this.queryASTFactory.fieldFactory.createAggregationFields(
275283
entity,
276-
parsedProjectionFields.fields
284+
parsedProjectionFields.fields,
285+
deprecatedAttributes
286+
);
287+
const nodeFields = this.queryASTFactory.fieldFactory.createAggregationFields(
288+
entity,
289+
nodeRawFields,
290+
deprecatedAttributes
291+
);
292+
const edgeFields = this.queryASTFactory.fieldFactory.createAggregationFields(
293+
relationship,
294+
edgeRawFields,
295+
deprecatedAttributes
277296
);
278-
const nodeFields = this.queryASTFactory.fieldFactory.createAggregationFields(entity, nodeRawFields);
279-
const edgeFields = this.queryASTFactory.fieldFactory.createAggregationFields(relationship, edgeRawFields);
280297
if (isInterfaceEntity(entity)) {
281298
const filters = this.queryASTFactory.filterFactory.createInterfaceNodeFilters({
282299
entity,
@@ -308,7 +325,11 @@ export class AggregateFactory {
308325
...resolveTree.fieldsByTypeName[entity.operations.aggregateTypeNames.node], // Handles both, deprecated and new aggregation parsing
309326
};
310327

311-
const fields = this.queryASTFactory.fieldFactory.createAggregationFields(entity, rawProjectionFields);
328+
const fields = this.queryASTFactory.fieldFactory.createAggregationFields(
329+
entity,
330+
rawProjectionFields,
331+
deprecatedAttributes
332+
);
312333
// TOP Level aggregate in connection
313334
const connectionFields = {
314335
// TOP level connection fields
@@ -320,14 +341,22 @@ export class AggregateFactory {
320341
...nodeResolveTree?.fieldsByTypeName[entity.operations.aggregateTypeNames.node],
321342
};
322343

323-
const nodeFields = this.queryASTFactory.fieldFactory.createAggregationFields(entity, nodeRawFields);
344+
const nodeFields = this.queryASTFactory.fieldFactory.createAggregationFields(
345+
entity,
346+
nodeRawFields,
347+
deprecatedAttributes
348+
);
324349
operation.setNodeFields(nodeFields);
325350
const countResolveTree = findFieldsByNameInFieldsByTypeNameField(connectionFields, "count")[0];
326351

327352
if (countResolveTree) {
328-
const connetionTopFields = this.queryASTFactory.fieldFactory.createAggregationFields(entity, {
329-
count: countResolveTree,
330-
});
353+
const connetionTopFields = this.queryASTFactory.fieldFactory.createAggregationFields(
354+
entity,
355+
{
356+
count: countResolveTree,
357+
},
358+
deprecatedAttributes
359+
);
331360
fields.push(...connetionTopFields);
332361
}
333362

packages/graphql/src/translate/queryAST/factory/Operations/ConnectionFactory.ts

+2
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ export class ConnectionFactory {
267267
resolveTree: resolveTreeAggregate,
268268
context,
269269
extraWhereArgs: whereArgs,
270+
isInConnection: true,
270271
});
271272
// NOTE: This will always be true on 7.x and this attribute should be removed
272273
aggregationOperation.isInConnectionField = true;
@@ -288,6 +289,7 @@ export class ConnectionFactory {
288289
resolveTree: resolveTreeAggregate,
289290
context,
290291
extraWhereArgs: whereArgs,
292+
isInConnection: true,
291293
});
292294
// NOTE: This will always be true on 7.x and this attribute should be removed
293295
aggregationOperation.isInConnectionField = true;

packages/graphql/tests/integration/aggregations/field-level/field-level-aggregations.int.test.ts

+9-8
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ describe("Field Level Aggregations", () => {
5757
CREATE (m)<-[:ACTED_IN { screentime: 60, character: "Terminator" }]-(a1:${typeActor.name} { name: "Arnold", age: 54, born: datetime('1980-07-02')})
5858
CREATE (m)<-[:ACTED_IN { screentime: 50, character: "someone" }]-(a1)
5959
CREATE (m)<-[:ACTED_IN { screentime: 120, character: "Sarah" }]-(:${typeActor.name} {name: "Linda", age:37, born: datetime('2000-02-02')})
60+
CREATE (m)<-[:ACTED_IN { screentime: 120, character: "Sarah" }]-(:${typeActor.name} {name: "John", age:37, born: datetime('2000-02-02')})
6061
`);
6162
});
6263

@@ -89,7 +90,7 @@ describe("Field Level Aggregations", () => {
8990
actorsConnection: {
9091
aggregate: {
9192
count: {
92-
nodes: 2,
93+
nodes: 3,
9394
},
9495
},
9596
},
@@ -124,8 +125,8 @@ describe("Field Level Aggregations", () => {
124125
actorsConnection: {
125126
aggregate: {
126127
count: {
127-
nodes: 2,
128-
edges: 3,
128+
nodes: 3,
129+
edges: 4,
129130
},
130131
},
131132
},
@@ -163,7 +164,7 @@ describe("Field Level Aggregations", () => {
163164
node: {
164165
name: {
165166
longest: "Arnold",
166-
shortest: "Linda",
167+
shortest: "John",
167168
},
168169
},
169170
},
@@ -205,8 +206,8 @@ describe("Field Level Aggregations", () => {
205206
age: {
206207
max: 54,
207208
min: 37,
208-
average: expect.closeTo(48.33),
209-
sum: 145,
209+
average: expect.closeTo(42.67),
210+
sum: 128,
210211
},
211212
},
212213
},
@@ -289,8 +290,8 @@ describe("Field Level Aggregations", () => {
289290
screentime: {
290291
max: 120,
291292
min: 50,
292-
average: expect.closeTo(76.67),
293-
sum: 230,
293+
average: expect.closeTo(87.5),
294+
sum: 350,
294295
},
295296
},
296297
},

0 commit comments

Comments
 (0)