diff --git a/src/MongoDB.Bson/Serialization/BsonClassMap.cs b/src/MongoDB.Bson/Serialization/BsonClassMap.cs index 87239a33bf5..0e66161e684 100644 --- a/src/MongoDB.Bson/Serialization/BsonClassMap.cs +++ b/src/MongoDB.Bson/Serialization/BsonClassMap.cs @@ -14,6 +14,7 @@ */ using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Collections.ObjectModel; using System.Linq; @@ -61,6 +62,9 @@ public class BsonClassMap private int _extraElementsMemberIndex = -1; private List _knownTypes = new List(); + private ConcurrentDictionary _discriminatorToTypeMap = new(); + private ConcurrentDictionary _typeToDiscriminatorMap = new(); + // constructors /// /// Initializes a new instance of the BsonClassMap class. @@ -228,6 +232,10 @@ public bool IsRootClass get { return _isRootClass; } } + internal ConcurrentDictionary DiscriminatorToTypeMap => _discriminatorToTypeMap; + + internal ConcurrentDictionary TypeToDiscriminatorMap => _typeToDiscriminatorMap; + /// /// Gets the known types of this class. /// @@ -695,6 +703,11 @@ public BsonClassMap Freeze() } _frozen = true; + if (_discriminator != null) + { + AddKnownDiscriminator(_discriminator, _classType); + } + // use a queue to postpone processing of known types until we get back to the first level call to Freeze // this avoids infinite recursion when going back down the inheritance tree while processing known types foreach (var knownType in _knownTypes) @@ -1140,6 +1153,36 @@ public void SetExtraElementsMember(BsonMemberMap memberMap) _extraElementsMemberMap = memberMap; } + internal void AddKnownDiscriminator(BsonValue discriminator, Type type) + { + if (!_classType.IsAssignableFrom(type)) + { + throw new ArgumentException($"Type \"{type}\" is not assignable to \"{_classType}\".", nameof(type)); + } + + if (_classType == typeof(object) || _classType == typeof(ValueType) || _classType.IsInterface) + { + return; + } + + if (_baseClassMap != null) + { + _baseClassMap.AddKnownDiscriminator(discriminator, type); + } + + var knownType = _discriminatorToTypeMap.GetOrAdd(discriminator, type); + if (knownType != type) + { + throw new ArgumentException($"Duplicate discriminator value \"{discriminator}\".", nameof(discriminator)); + } + + var knownDiscriminator = _typeToDiscriminatorMap.GetOrAdd(type, discriminator); + if (!knownDiscriminator.Equals(discriminator)) + { + throw new ArgumentException($"Duplicate derived type \"{type}\".", nameof(type)); + } + } + /// /// Adds a known type to the class map. /// @@ -1331,31 +1374,37 @@ internal IDiscriminatorConvention GetDiscriminatorConvention() IDiscriminatorConvention LookupDiscriminatorConvention() { - var classMap = this; - while (classMap != null) + if (_discriminatorConvention != null) { - if (classMap._discriminatorConvention != null) - { - return classMap._discriminatorConvention; - } + return _discriminatorConvention; + } - if (BsonSerializer.IsDiscriminatorConventionRegisteredAtThisLevel(classMap._classType)) + if (BsonSerializer.IsDiscriminatorConventionRegisteredAtThisLevel(_classType)) + { + return BsonSerializer.LookupDiscriminatorConvention(_classType); + } + + if (_isRootClass) + { + return StandardDiscriminatorConvention.Hierarchical; + } + + if (_baseClassMap.ClassType == typeof(object)) + { + return new BsonClassMapScalarDiscriminatorConvention("_t", this); + } + else + { + var discriminatorConvention = _baseClassMap.GetDiscriminatorConvention(); + if (discriminatorConvention is BsonClassMapScalarDiscriminatorConvention) { - // in this case LookupDiscriminatorConvention below will find it - break; + return new BsonClassMapScalarDiscriminatorConvention(discriminatorConvention.ElementName, this); } - - if (classMap._isRootClass) + else { - // in this case auto-register a hierarchical convention for the root class and look it up as usual below - BsonSerializer.GetOrRegisterDiscriminatorConvention(classMap._classType, StandardDiscriminatorConvention.Hierarchical); - break; + return discriminatorConvention; } - - classMap = classMap._baseClassMap; } - - return BsonSerializer.LookupDiscriminatorConvention(_classType); } } diff --git a/src/MongoDB.Bson/Serialization/BsonClassMapScalarDiscriminatorConvention.cs b/src/MongoDB.Bson/Serialization/BsonClassMapScalarDiscriminatorConvention.cs new file mode 100644 index 00000000000..792b5fc6d33 --- /dev/null +++ b/src/MongoDB.Bson/Serialization/BsonClassMapScalarDiscriminatorConvention.cs @@ -0,0 +1,122 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using MongoDB.Bson.Serialization.Conventions; + +namespace MongoDB.Bson.Serialization +{ + /// + /// A scalar discriminator convention for class mapped types. + /// + public class BsonClassMapScalarDiscriminatorConvention : StandardDiscriminatorConvention, IScalarDiscriminatorConvention + { + private readonly BsonClassMap _classMap; + + // cached map + private readonly ConcurrentDictionary _typeToDiscriminatorsForTypeAndSubTypesMap = new(); + + /// + /// Gets the class map. + /// + public BsonClassMap ClassMap => _classMap; + + /// + /// Initializes a new instance of BsonClassMapScalarDiscriminatorConvention. + /// + /// The discriminator element name. + /// The class map. + public BsonClassMapScalarDiscriminatorConvention(string elementName, BsonClassMap classMap) + : base(elementName) + { + _classMap = classMap ?? throw new ArgumentNullException(nameof(classMap)); + } + + /// + /// Gets the actual type. + /// + /// The nominal type. + /// The discriminator value. + /// The actual type. + protected override Type GetActualType(Type nominalType, BsonValue discriminator) + { + if (_classMap.DiscriminatorToTypeMap.TryGetValue(discriminator, out Type actualType)) + { + return actualType; + } + + if (nominalType == typeof(object) && discriminator.IsString) + { + actualType = TypeNameDiscriminator.GetActualType(discriminator.AsString); + if (actualType != null) + { + return actualType; + } + } + + throw new BsonSerializationException($"No type found for discriminator value: {discriminator}."); + } + + /// + /// Gets the discriminator value. + /// + /// The nominal type. + /// The actual type. + /// The discriminator value. + public override BsonValue GetDiscriminator(Type nominalType, Type actualType) + { + if (actualType == nominalType && !_classMap.DiscriminatorIsRequired) + { + return null; + } + + if (_classMap.TypeToDiscriminatorMap.TryGetValue(actualType, out BsonValue discriminator)) + { + return discriminator; + } + + throw new BsonSerializationException($"No discriminator value found for type: \"{actualType}\"."); + } + + /// + /// Gets the discriminator values for a type and all of its sub types. + /// + /// The type. + /// The discriminator values for a type and all of its sub types. + public BsonValue[] GetDiscriminatorsForTypeAndSubTypes(Type type) + { + return _typeToDiscriminatorsForTypeAndSubTypesMap.GetOrAdd(type, MapTypeToDiscriminatorsForTypeAndSubTypes); + } + + private BsonValue[] MapTypeToDiscriminatorsForTypeAndSubTypes(Type type) + { + var discriminators = new List(); + foreach (var entry in _classMap.TypeToDiscriminatorMap) + { + var discriminatedType = entry.Key; + if (type.IsAssignableFrom(discriminatedType)) + { + var discriminator = entry.Value; + discriminators.Add(discriminator); + } + } + + return discriminators.OrderBy(x => x).ToArray(); + } + } +} diff --git a/src/MongoDB.Bson/Serialization/Conventions/StandardDiscriminatorConvention.cs b/src/MongoDB.Bson/Serialization/Conventions/StandardDiscriminatorConvention.cs index ca4d2cbc117..356af416bfe 100644 --- a/src/MongoDB.Bson/Serialization/Conventions/StandardDiscriminatorConvention.cs +++ b/src/MongoDB.Bson/Serialization/Conventions/StandardDiscriminatorConvention.cs @@ -96,35 +96,27 @@ obj is StandardDiscriminatorConvention other && /// The actual type. public Type GetActualType(IBsonReader bsonReader, Type nominalType) { - // the BsonReader is sitting at the value whose actual type needs to be found - var bsonType = bsonReader.GetCurrentBsonType(); - if (bsonType == BsonType.Document) + // ensure KnownTypes of nominalType are registered (so IsTypeDiscriminated returns correct answer) + BsonSerializer.EnsureKnownTypesAreRegistered(nominalType); + + // we can skip looking for a discriminator if nominalType has no discriminated sub types + if (!BsonSerializer.IsTypeDiscriminated(nominalType)) { - // ensure KnownTypes of nominalType are registered (so IsTypeDiscriminated returns correct answer) - BsonSerializer.EnsureKnownTypesAreRegistered(nominalType); + return nominalType; + } - // we can skip looking for a discriminator if nominalType has no discriminated sub types - if (BsonSerializer.IsTypeDiscriminated(nominalType)) - { - var bookmark = bsonReader.GetBookmark(); - bsonReader.ReadStartDocument(); - var actualType = nominalType; - if (bsonReader.FindElement(_elementName)) - { - var context = BsonDeserializationContext.CreateRoot(bsonReader); - var discriminator = BsonValueSerializer.Instance.Deserialize(context); - if (discriminator.IsBsonArray) - { - discriminator = discriminator.AsBsonArray.Last(); // last item is leaf class discriminator - } - actualType = BsonSerializer.LookupActualType(nominalType, discriminator); - } - bsonReader.ReturnToBookmark(bookmark); - return actualType; - } + var discriminator = ReadDiscriminator(bsonReader); + if (discriminator == null) + { + return nominalType; + } + + if (discriminator.IsBsonArray) + { + discriminator = discriminator.AsBsonArray.Last(); // last item is leaf class discriminator } - return nominalType; + return GetActualType(nominalType, discriminator); } /// @@ -137,5 +129,47 @@ public Type GetActualType(IBsonReader bsonReader, Type nominalType) /// public override int GetHashCode() => 0; + + // protected methods + /// + /// Gets the actual type. + /// + /// The nominal type. + /// The discriminator. + /// The actual type. + protected virtual Type GetActualType(Type nominalType, BsonValue discriminator) + { + return BsonSerializer.LookupActualType(nominalType, discriminator); + } + + /// + /// Reads the discriminator. + /// + /// The bsonReader. + /// The discriminator, or null if no discriminator was found. + protected BsonValue ReadDiscriminator(IBsonReader bsonReader) + { + // the BsonReader is sitting at the value whose actual type needs to be found + var bsonType = bsonReader.GetCurrentBsonType(); + if (bsonType == BsonType.Document) + { + var bookmark = bsonReader.GetBookmark(); + try + { + bsonReader.ReadStartDocument(); + if (bsonReader.FindElement(_elementName)) + { + var context = BsonDeserializationContext.CreateRoot(bsonReader); + return BsonValueSerializer.Instance.Deserialize(context); + } + } + finally + { + bsonReader.ReturnToBookmark(bookmark); + } + } + + return null; + } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstSimplifier.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstSimplifier.cs index 51e25375d76..8b21af17157 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstSimplifier.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstSimplifier.cs @@ -13,10 +13,13 @@ * limitations under the License. */ +using System.Linq; using MongoDB.Bson; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Filters; +using MongoDB.Driver.Linq.Linq3Implementation.Ast.Stages; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Visitors; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; namespace MongoDB.Driver.Linq.Linq3Implementation.Ast.Optimizers { @@ -277,6 +280,30 @@ elemMatchOperation.Filter is AstFieldOperationFilter elemFilter && } } + public override AstNode VisitFilterExpression(AstFilterExpression node) + { + var inputExpression = VisitAndConvert(node.Input); + var condExpression = VisitAndConvert(node.Cond); + var limitExpression = VisitAndConvert(node.Limit); + + if (condExpression is AstConstantExpression condConstantExpression && + condConstantExpression.Value is BsonBoolean condBsonBoolean) + { + if (condBsonBoolean.Value && limitExpression == null) + { + // { $cond : { input : , as : "x", cond : true } } => + return inputExpression; + } + else + { + // { $cond : { input : , as : "x", cond : false } } => [] + return AstExpression.Constant(new BsonArray()); + } + } + + return node.Update(inputExpression, condExpression, limitExpression); + } + public override AstNode VisitGetFieldExpression(AstGetFieldExpression node) { if (TrySimplifyAsFieldPath(node, out var simplified)) @@ -374,17 +401,56 @@ public override AstNode VisitNotFilterOperation(AstNotFilterOperation node) return base.VisitNotFilterOperation(node); } + public override AstNode VisitPipeline(AstPipeline node) + { + var stages = VisitAndConvert(node.Stages); + + // { $match : { } } => remove redundant stage + if (stages.Any(stage => IsMatchEverythingStage(stage))) + { + stages = stages.Where(stage => !IsMatchEverythingStage(stage)).AsReadOnlyList(); + } + + return node.Update(stages); + + static bool IsMatchEverythingStage(AstStage stage) + { + return + stage is AstMatchStage matchStage && + matchStage.Filter is AstMatchesEverythingFilter; + } + } + public override AstNode VisitUnaryExpression(AstUnaryExpression node) { + var arg = VisitAndConvert(node.Arg); + // { $first : } => { $arrayElemAt : [, 0] } (or -1 for $last) if (node.Operator == AstUnaryOperator.First || node.Operator == AstUnaryOperator.Last) { - var simplifiedArg = VisitAndConvert(node.Arg); var index = node.Operator == AstUnaryOperator.First ? 0 : -1; - return AstExpression.ArrayElemAt(simplifiedArg, index); + return AstExpression.ArrayElemAt(arg, index); + } + + // { $not : booleanConstant } => !booleanConstant + if (node.Operator is AstUnaryOperator.Not && + arg is AstConstantExpression argConstantExpression && + argConstantExpression.Value is BsonBoolean argBsonBoolean) + { + return AstExpression.Constant(!argBsonBoolean.Value); + } + + // { $not : { $eq : [expr1, expr2] } } => { $ne : [expr1, expr2] } + // { $not : { $ne : [expr1, expr2] } } => { $eq : [expr1, expr2] } + if (node.Operator is AstUnaryOperator.Not && + arg is AstBinaryExpression argBinaryExpression && + argBinaryExpression.Operator is AstBinaryOperator.Eq or AstBinaryOperator.Ne) + { + var oppositeComparisonOperator = argBinaryExpression.Operator == AstBinaryOperator.Eq ? AstBinaryOperator.Ne : AstBinaryOperator.Eq; + return AstExpression.Binary(oppositeComparisonOperator, argBinaryExpression.Arg1, argBinaryExpression.Arg2); } - return base.VisitUnaryExpression(node); + return node.Update(arg); } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodCallExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodCallExpressionToAggregationExpressionTranslator.cs index 57e592c43bc..479f56cb3ca 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodCallExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodCallExpressionToAggregationExpressionTranslator.cs @@ -62,6 +62,7 @@ public static AggregationExpression Translate(TranslationContext context, Method case "IsNullOrWhiteSpace": return IsNullOrWhiteSpaceMethodToAggregationExpressionTranslator.Translate(context, expression); case "IsSubsetOf": return IsSubsetOfMethodToAggregationExpressionTranslator.Translate(context, expression); case "Locf": return LocfMethodToAggregationExpressionTranslator.Translate(context, expression); + case "OfType": return OfTypeMethodToAggregationExpressionTranslator.Translate(context, expression); case "Parse": return ParseMethodToAggregationExpressionTranslator.Translate(context, expression); case "Pow": return PowMethodToAggregationExpressionTranslator.Translate(context, expression); case "Push": return PushMethodToAggregationExpressionTranslator.Translate(context, expression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/OfTypeMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/OfTypeMethodToAggregationExpressionTranslator.cs new file mode 100644 index 00000000000..561e055a5d8 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/OfTypeMethodToAggregationExpressionTranslator.cs @@ -0,0 +1,92 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Conventions; +using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators +{ + internal static class OfTypeMethodToAggregationExpressionTranslator + { + private static MethodInfo[] __ofTypeMethods = + { + EnumerableMethod.OfType, + QueryableMethod.OfType + }; + + public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression) + { + var method = expression.Method; + var arguments = expression.Arguments; + + if (method.IsOneOf(__ofTypeMethods)) + { + var sourceExpression = arguments[0]; + var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); + NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation); + + var sourceAst = sourceTranslation.Ast; + var sourceSerializer = sourceTranslation.Serializer; + if (sourceSerializer is IWrappedValueSerializer wrappedValueSerializer) + { + sourceAst = AstExpression.GetField(sourceAst, wrappedValueSerializer.FieldName); + sourceSerializer = wrappedValueSerializer.ValueSerializer; + } + var itemSerializer = ArraySerializerHelper.GetItemSerializer(sourceSerializer); + + var nominalType = itemSerializer.ValueType; + var nominalTypeSerializer = itemSerializer; + var actualType = method.GetGenericArguments().Single(); + var actualTypeSerializer = BsonSerializer.LookupSerializer(actualType); + + AstExpression ast; + if (nominalType == actualType) + { + ast = sourceAst; + } + else + { + var discriminatorConvention = nominalTypeSerializer.GetDiscriminatorConvention(); + var itemVar = AstExpression.Var("item"); + var discriminatorField = AstExpression.GetField(itemVar, discriminatorConvention.ElementName); + + var ofTypeExpression = discriminatorConvention switch + { + IHierarchicalDiscriminatorConvention hierarchicalDiscriminatorConvention => DiscriminatorAstExpression.TypeIs(discriminatorField, hierarchicalDiscriminatorConvention, nominalType, actualType), + IScalarDiscriminatorConvention scalarDiscriminatorConvention => DiscriminatorAstExpression.TypeIs(discriminatorField, scalarDiscriminatorConvention, nominalType, actualType), + _ => throw new ExpressionNotSupportedException(expression, because: "OfType is not supported with the configured discriminator convention") + }; + + ast = AstExpression.Filter( + input: sourceAst, + cond: ofTypeExpression, + @as: "item"); + } + + var resultSerializer = NestedAsQueryableSerializer.CreateIEnumerableOrNestedAsQueryableSerializer(expression.Type, actualTypeSerializer); + return new AggregationExpression(expression, ast, resultSerializer); + } + + throw new ExpressionNotSupportedException(expression); + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WhereMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WhereMethodToAggregationExpressionTranslator.cs index 1eff6481de5..b37f0edfb4d 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WhereMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WhereMethodToAggregationExpressionTranslator.cs @@ -40,9 +40,17 @@ public static AggregationExpression Translate(TranslationContext context, Method { var sourceExpression = arguments[0]; var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); - var itemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer); NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation); + var sourceAst = sourceTranslation.Ast; + var sourceSerializer = sourceTranslation.Serializer; + if (sourceSerializer is IWrappedValueSerializer wrappedValueSerializer) + { + sourceAst = AstExpression.GetField(sourceAst, wrappedValueSerializer.FieldName); + sourceSerializer = wrappedValueSerializer.ValueSerializer; + } + var itemSerializer = ArraySerializerHelper.GetItemSerializer(sourceSerializer); + var predicateLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); var predicateParameter = predicateLambda.Parameters[0]; var predicateSymbol = context.CreateSymbol(predicateParameter, itemSerializer); @@ -57,7 +65,7 @@ public static AggregationExpression Translate(TranslationContext context, Method } var ast = AstExpression.Filter( - sourceTranslation.Ast, + sourceAst, predicateTranslation.Ast, @as: predicateSymbol.Var.Name, limitTranslation?.Ast); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/TypeIsExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/TypeIsExpressionToAggregationExpressionTranslator.cs index ef429a863db..f7a127a640c 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/TypeIsExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/TypeIsExpressionToAggregationExpressionTranslator.cs @@ -13,9 +13,7 @@ * limitations under the License. */ -using System.Linq; using System.Linq.Expressions; -using MongoDB.Bson; using MongoDB.Bson.Serialization; using MongoDB.Bson.Serialization.Conventions; using MongoDB.Bson.Serialization.Serializers; diff --git a/tests/MongoDB.Bson.Tests/Serialization/Serializers/DiscriminatorTests.cs b/tests/MongoDB.Bson.Tests/Serialization/Serializers/DiscriminatorTests.cs index 826b0969da6..78fe458a3aa 100644 --- a/tests/MongoDB.Bson.Tests/Serialization/Serializers/DiscriminatorTests.cs +++ b/tests/MongoDB.Bson.Tests/Serialization/Serializers/DiscriminatorTests.cs @@ -40,7 +40,7 @@ private class C : A { } - [BsonDiscriminator("D~", RootClass = true)] + [BsonDiscriminator("D~")] private class D : A { } @@ -212,7 +212,7 @@ public void TestSerializeDAsD() { D d = new D { P = "x" }; var json = d.ToJson(); - var expected = ("{ '_t' : 'D~', 'P' : 'x' }").Replace("'", "\""); + var expected = ("{ 'P' : 'x' }").Replace("'", "\""); Assert.Equal(expected, json); var bson = d.ToBson(); @@ -329,7 +329,7 @@ public void TestSerializeGAsObject() { G g = new G { P = "x" }; var json = g.ToJson(); - var expected = ("{ '_t' : ['D~', 'G~'], 'P' : 'x' }").Replace("'", "\""); + var expected = ("{ '_t' : 'G~', 'P' : 'x' }").Replace("'", "\""); Assert.Equal(expected, json); var bson = g.ToBson(); @@ -342,7 +342,7 @@ public void TestSerializeGAsA() { G g = new G { P = "x" }; var json = g.ToJson(); - var expected = ("{ '_t' : ['D~', 'G~'], 'P' : 'x' }").Replace("'", "\""); + var expected = ("{ '_t' : 'G~', 'P' : 'x' }").Replace("'", "\""); Assert.Equal(expected, json); var bson = g.ToBson(); @@ -355,7 +355,7 @@ public void TestSerializeGAsD() { G g = new G { P = "x" }; var json = g.ToJson(); - var expected = ("{ '_t' : ['D~', 'G~'], 'P' : 'x' }").Replace("'", "\""); + var expected = ("{ '_t' : 'G~', 'P' : 'x' }").Replace("'", "\""); Assert.Equal(expected, json); var bson = g.ToBson(); @@ -368,7 +368,7 @@ public void TestSerializeGAsG() { G g = new G { P = "x" }; var json = g.ToJson(); - var expected = ("{ '_t' : ['D~', 'G~'], 'P' : 'x' }").Replace("'", "\""); + var expected = ("{ 'P' : 'x' }").Replace("'", "\""); Assert.Equal(expected, json); var bson = g.ToBson(); @@ -381,7 +381,7 @@ public void TestSerializeHAsObject() { H h = new H { P = "x" }; var json = h.ToJson(); - var expected = ("{ '_t' : ['D~', 'G~', 'H~'], 'P' : 'x' }").Replace("'", "\""); + var expected = ("{ '_t' : 'H~', 'P' : 'x' }").Replace("'", "\""); Assert.Equal(expected, json); var bson = h.ToBson(); @@ -394,7 +394,7 @@ public void TestSerializeHAsA() { H h = new H { P = "x" }; var json = h.ToJson(); - var expected = ("{ '_t' : ['D~', 'G~', 'H~'], 'P' : 'x' }").Replace("'", "\""); + var expected = ("{ '_t' : 'H~', 'P' : 'x' }").Replace("'", "\""); Assert.Equal(expected, json); var bson = h.ToBson(); @@ -407,7 +407,7 @@ public void TestSerializeHAsD() { H h = new H { P = "x" }; var json = h.ToJson(); - var expected = ("{ '_t' : ['D~', 'G~', 'H~'], 'P' : 'x' }").Replace("'", "\""); + var expected = ("{ '_t' : 'H~', 'P' : 'x' }").Replace("'", "\""); Assert.Equal(expected, json); var bson = h.ToBson(); @@ -420,7 +420,7 @@ public void TestSerializeHAsG() { H h = new H { P = "x" }; var json = h.ToJson(); - var expected = ("{ '_t' : ['D~', 'G~', 'H~'], 'P' : 'x' }").Replace("'", "\""); + var expected = ("{ '_t' : 'H~', 'P' : 'x' }").Replace("'", "\""); Assert.Equal(expected, json); var bson = h.ToBson(); @@ -433,7 +433,7 @@ public void TestSerializeHAsH() { H h = new H { P = "x" }; var json = h.ToJson(); - var expected = ("{ '_t' : ['D~', 'G~', 'H~'], 'P' : 'x' }").Replace("'", "\""); + var expected = ("{ 'P' : 'x' }").Replace("'", "\""); Assert.Equal(expected, json); var bson = h.ToBson(); diff --git a/tests/MongoDB.Bson.Tests/Serialization/Serializers/KnownTypesTests.cs b/tests/MongoDB.Bson.Tests/Serialization/Serializers/KnownTypesTests.cs index 1d0af120b66..618a25892c1 100644 --- a/tests/MongoDB.Bson.Tests/Serialization/Serializers/KnownTypesTests.cs +++ b/tests/MongoDB.Bson.Tests/Serialization/Serializers/KnownTypesTests.cs @@ -34,7 +34,6 @@ private class B : A { } - [BsonDiscriminator(RootClass = true)] [BsonKnownTypes(typeof(E))] private class C : A { @@ -77,7 +76,7 @@ public void TestDeserializeEAsA() { var document = new BsonDocument { - { "_t", new BsonArray { "C", "E" } }, + { "_t", "E" }, { "P", "x" } }; @@ -86,7 +85,7 @@ public void TestDeserializeEAsA() Assert.IsType(rehydrated); var json = rehydrated.ToJson(); - var expected = "{ '_t' : ['C', 'E'], 'P' : 'x' }".Replace("'", "\""); + var expected = "{ '_t' : 'E', 'P' : 'x' }".Replace("'", "\""); Assert.Equal(expected, json); Assert.True(bson.SequenceEqual(rehydrated.ToBson())); } diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp3140Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp3140Tests.cs index 468bca31027..22b790e0193 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp3140Tests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp3140Tests.cs @@ -135,7 +135,7 @@ public void OrElse_with_first_clause_that_evaluates_to_true_should_simplify_to_t .Where(x => currentUser.Factory == null || x.FactoryId == currentUser.Factory.Id); var stages = Translate(collection, queryable); - AssertStages(stages, "{ $match : { } }"); + AssertStages(stages); } [Fact] @@ -159,7 +159,7 @@ public void OrElse_with_second_clause_that_evaluates_to_true_should_simplify_to_ .Where(x => x.FactoryId != 0 || currentUser.Factory == null); var stages = Translate(collection, queryable); - AssertStages(stages, "{ $match : { } }"); + AssertStages(stages); } [Fact] diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4100FilterTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4100FilterTests.cs index 2e702574c4b..93dbbf8fee5 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4100FilterTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4100FilterTests.cs @@ -47,7 +47,7 @@ public void Contains_with_string_constant_and_char_constant_should_work() var stages = Translate(collection, queryable); - AssertStages(stages, "{ $match : { } }"); + AssertStages(stages); } [Fact] @@ -127,7 +127,7 @@ public void Contains_with_string_constant_and_char_constant_and_comparisonType_s var stages = Translate(collection, queryable); - AssertStages(stages, "{ $match : { } }"); + AssertStages(stages); } #endif @@ -230,7 +230,7 @@ public void Contains_with_string_constant_and_char_value_and_invalid_comparisonT var stages = Translate(collection, queryable); - AssertStages(stages, "{ $match : { } }"); + AssertStages(stages); } #endif @@ -301,7 +301,7 @@ public void Contains_with_string_field_and_string_constant_and_comparisonType_sh #if !NETFRAMEWORK [Theory] [InlineData(StringComparison.CurrentCulture, "{ $match : { _id : { $type : -1 } } }")] - [InlineData(StringComparison.CurrentCultureIgnoreCase, "{ $match : { } }")] + [InlineData(StringComparison.CurrentCultureIgnoreCase, null)] public void Contains_with_string_constant_and_string_constant_and_comparisonType_should_work(StringComparison comparisonType, string expectedStage) { var collection = GetCollection(); @@ -368,9 +368,9 @@ public void Contains_with_string_field_and_string_value_and_invalid_comparisonTy #if !NETFRAMEWORK [Theory] [InlineData(StringComparison.InvariantCulture, "{ $match : { _id : { $type : -1 } } }")] - [InlineData(StringComparison.InvariantCultureIgnoreCase, "{ $match : { } }")] + [InlineData(StringComparison.InvariantCultureIgnoreCase, null)] [InlineData(StringComparison.Ordinal, "{ $match : { _id : { $type : -1 } } }")] - [InlineData(StringComparison.OrdinalIgnoreCase, "{ $match : { } }")] + [InlineData(StringComparison.OrdinalIgnoreCase, null)] public void Contains_with_string_constant_and_string_value_and_invalid_comparisonType_should_(StringComparison comparisonType, string expectedStage) { var collection = GetCollection(); @@ -533,7 +533,7 @@ public void EndsWith_with_string_field_and_string_constant_and_ignoreCase_and_cu [Theory] [InlineData(false, "{ $match : { _id : { $type : -1 } } }")] - [InlineData(true, "{ $match : { } }")] + [InlineData(true, null)] public void EndsWith_with_string_constant_and_string_constant_and_ignoreCase_and_culture_should_work(bool ignoreCase, string expectedStage) { var collection = GetCollection(); @@ -591,7 +591,7 @@ public void EndsWith_with_string_field_and_string_value_and_ignoreCase_and_inval [Theory] [InlineData(false, "{ $match : { _id : { $type : -1 } } }")] - [InlineData(true, "{ $match : { } }")] + [InlineData(true, null)] public void EndsWith_with_string_constant_and_string_value_and_ignoreCase_and_invalid_culture_should_work(bool ignoreCase, string expectedStage) { var collection = GetCollection(); @@ -619,7 +619,7 @@ public void EndsWith_with_string_field_and_string_constant_and_comparisonType_sh [Theory] [InlineData(StringComparison.CurrentCulture, "{ $match : { _id : { $type : -1 } } }")] - [InlineData(StringComparison.CurrentCultureIgnoreCase, "{ $match : { } }")] + [InlineData(StringComparison.CurrentCultureIgnoreCase, null)] public void EndsWith_with_string_constant_and_string_constant_and_comparisonType_should_work(StringComparison comparisonType, string expectedStage) { var collection = GetCollection(); @@ -717,7 +717,7 @@ public void StartsWith_with_string_constant_and_char_constant_should_work() var stages = Translate(collection, queryable); - AssertStages(stages, "{ $match : { } }"); + AssertStages(stages); } #endif @@ -843,7 +843,7 @@ public void StartsWith_with_string_field_and_string_constant_and_ignoreCase_and_ [Theory] [InlineData(false, "{ $match : { _id : { $type : -1 } } }")] - [InlineData(true, "{ $match : { } }")] + [InlineData(true, null)] public void StartsWith_with_string_constant_and_string_constant_and_ignoreCase_and_culture_should_work(bool ignoreCase, string expectedStage) { var collection = GetCollection(); @@ -901,7 +901,7 @@ public void StartsWith_with_string_field_and_string_value_and_ignoreCase_and_inv [Theory] [InlineData(false, "{ $match : { _id : { $type : -1 } } }")] - [InlineData(true, "{ $match : { } }")] + [InlineData(true, null)] public void StartsWith_with_string_constant_and_string_value_and_ignoreCase_and_invalid_culture_should_work(bool ignoreCase, string expectedStage) { var collection = GetCollection(); @@ -929,7 +929,7 @@ public void StartsWith_with_string_field_and_string_constant_and_comparisonType_ [Theory] [InlineData(StringComparison.CurrentCulture, "{ $match : { _id : { $type : -1 } } }")] - [InlineData(StringComparison.CurrentCultureIgnoreCase, "{ $match : { } }")] + [InlineData(StringComparison.CurrentCultureIgnoreCase, null)] public void StartsWith_with_string_constant_and_string_constant_and_comparisonType_should_work(StringComparison comparisonType, string expectedStage) { var collection = GetCollection(); diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4116Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4116Tests.cs index 68751e1622c..e5d552010a4 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4116Tests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4116Tests.cs @@ -23,7 +23,7 @@ public class CSharp4116Tests : Linq3IntegrationTest { [Theory] [InlineData(false, "{ $match : { _id : { $type : -1 } } }")] - [InlineData(true, "{ $match : { } }")] + [InlineData(true, null)] public void Optimize_match_with_expr(bool value, string expectedStage) { var collection = GetCollection(); diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5356Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5356Tests.cs new file mode 100644 index 00000000000..ab3e2276ae4 --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5356Tests.cs @@ -0,0 +1,606 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq; +using FluentAssertions; +using MongoDB.Bson.Serialization.Attributes; +using MongoDB.Driver.Linq; +using MongoDB.Bson.Serialization.Serializers; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira +{ + public class CSharp5356Tests : Linq3IntegrationTest + { + [Fact] + public void Documents_should_be_serialized_as_expected() + { + var collection = GetCollection(); + + var documents = collection.AsQueryable().As(BsonDocumentSerializer.Instance).ToList(); + + documents.Count.Should().Be(3); + documents[0].Should().Be("{ _id : 1, _t : 'Cat' }"); + documents[1].Should().Be("{ _id : 2, _t : 'Dog' }"); + documents[2].Should().Be("{ _id : 3, _t : 'Snake' }"); + } + + [Fact] + public void OfType_Animal_should_work() + { + var collection = GetCollection(); + + var queryable = collection.AsQueryable() + .OfType(); + + var stages = Translate(collection, queryable); + stages.Count.Should().Be(0); + + var results = queryable.ToList(); + results.Select(x => x.Id).Should().Equal(1, 2, 3); + } + + [Fact] + public void OfType_Mammal_should_work() + { + var collection = GetCollection(); + + var queryable = collection.AsQueryable() + .OfType(); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $match : { _t : { $in : ['Cat', 'Dog', 'Mammal'] } } }"); + + var results = queryable.ToList(); + results.Select(x => x.Id).Should().Equal(1, 2); + } + + [Fact] + public void OfType_Cat_should_work() + { + var collection = GetCollection(); + + var queryable = collection.AsQueryable() + .OfType(); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $match : { _t : 'Cat' } }"); + + var results = queryable.ToList(); + results.Select(x => x.Id).Should().Equal(1); + } + + [Fact] + public void OfType_Reptile_should_work() + { + var collection = GetCollection(); + + var queryable = collection.AsQueryable() + .OfType(); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $match : { _t : { $in : ['Reptile', 'Snake'] } } }"); + + var results = queryable.ToList(); + results.Select(x => x.Id).Should().Equal(3); + } + + [Fact] + public void OfType_Snake_should_work() + { + var collection = GetCollection(); + + var queryable = collection.AsQueryable() + .OfType(); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $match : { _t : 'Snake' } }"); + + var results = queryable.ToList(); + results.Select(x => x.Id).Should().Equal(3); + } + + [Fact] + public void Where_is_Animal_should_work() + { + var collection = GetCollection(); + + var queryable = collection.AsQueryable() + .Where(x => x is Animal); + + var stages = Translate(collection, queryable); + AssertStages(stages); + + var results = queryable.ToList(); + results.Select(x => x.Id).Should().Equal(1, 2, 3); + } + + [Fact] + public void Where_is_Mammal_should_work() + { + var collection = GetCollection(); + + var queryable = collection.AsQueryable() + .Where(x => x is Mammal); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $match : { _t : { $in : ['Cat', 'Dog', 'Mammal'] } } }"); + + var results = queryable.ToList(); + results.Select(x => x.Id).Should().Equal(1, 2); + } + + [Fact] + public void Where_is_Cat_should_work() + { + var collection = GetCollection(); + + var queryable = collection.AsQueryable() + .Where(x => x is Cat); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $match : { _t : 'Cat' } }"); + + var results = queryable.ToList(); + results.Select(x => x.Id).Should().Equal(1); + } + + [Fact] + public void Where_is_Reptile_should_work() + { + var collection = GetCollection(); + + var queryable = collection.AsQueryable() + .Where(x => x is Reptile); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $match : { _t : { $in : ['Reptile', 'Snake'] } } }"); + + var results = queryable.ToList(); + results.Select(x => x.Id).Should().Equal(3); + } + + [Fact] + public void Where_is_Snake_should_work() + { + var collection = GetCollection(); + + var queryable = collection.AsQueryable() + .Where(x => x is Snake); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $match : { _t : 'Snake' } }"); + + var results = queryable.ToList(); + results.Select(x => x.Id).Should().Equal(3); + } + + [Fact] + public void Where_not_is_Animal_should_work() + { + var collection = GetCollection(); + + var queryable = collection.AsQueryable() + .Where(x => !(x is Animal)); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $match : { _id : { $type : -1 } } }"); + + var results = queryable.ToList(); + results.Count.Should().Be(0); + } + + [Fact] + public void Where_not_is_Mammal_should_work() + { + var collection = GetCollection(); + + var queryable = collection.AsQueryable() + .Where(x => !(x is Mammal)); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $match : { _t : { $nin : ['Cat', 'Dog', 'Mammal'] } } }"); + + var results = queryable.ToList(); + results.Select(x => x.Id).Should().Equal(3); + } + + [Fact] + public void Where_not_is_Cat_should_work() + { + var collection = GetCollection(); + + var queryable = collection.AsQueryable() + .Where(x => !(x is Cat)); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $match : { _t : { $ne : 'Cat' } } }"); + + var results = queryable.ToList(); + results.Select(x => x.Id).Should().Equal(2, 3); + } + + [Fact] + public void Where_not_is_Reptile_should_work() + { + var collection = GetCollection(); + + var queryable = collection.AsQueryable() + .Where(x => !(x is Reptile)); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $match : { _t : { $nin : ['Reptile', 'Snake'] } } }"); + + var results = queryable.ToList(); + results.Select(x => x.Id).Should().Equal(1, 2); + } + + [Fact] + public void Where_not_is_Snake_should_work() + { + var collection = GetCollection(); + + var queryable = collection.AsQueryable() + .Where(x => !(x is Snake)); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $match : { _t : { $ne : 'Snake' } } }"); + + var results = queryable.ToList(); + results.Select(x => x.Id).Should().Equal(1, 2); + } + + [Fact] + public void Array_OfType_Animal_should_work() + { + var collection = GetCollection(); + + var queryable = collection.AsQueryable() + .GroupBy(x => 1, (key, grouping) => grouping.ToArray()) + .Select(x => x.OfType().ToArray()); + + var stages = Translate(collection, queryable); + AssertStages( + stages, + "{ $group : { _id : 1, _elements : { $push : '$$ROOT' } } }", + "{ $project : { _v : '$_elements', _id : 0 } }", + "{ $project : { _v : '$_v', _id : 0 } }"); + + var result = queryable.Single(); + result.Select(x => x.Id).Should().Equal(1, 2, 3); + } + + [Fact] + public void Array_OfType_Mammal_should_work() + { + var collection = GetCollection(); + + var queryable = collection.AsQueryable() + .GroupBy(x => 1, (key, grouping) => grouping.ToArray()) + .Select(x => x.OfType().ToArray()); + + var stages = Translate(collection, queryable); + AssertStages( + stages, + "{ $group : { _id : 1, _elements : { $push : '$$ROOT' } } }", + "{ $project : { _v : '$_elements', _id : 0 } }", + "{ $project : { _v : { $filter : { input : '$_v', as : 'item', cond : { $in : ['$$item._t', ['Cat', 'Dog', 'Mammal']] } } }, _id : 0 } }"); + + var result = queryable.Single(); + result.Select(x => x.Id).Should().Equal(1, 2); + } + + [Fact] + public void Array_OfType_Cat_should_work() + { + var collection = GetCollection(); + + var queryable = collection.AsQueryable() + .GroupBy(x => 1, (key, grouping) => grouping.ToArray()) + .Select(x => x.OfType().ToArray()); + + var stages = Translate(collection, queryable); + AssertStages( + stages, + "{ $group : { _id : 1, _elements : { $push : '$$ROOT' } } }", + "{ $project : { _v : '$_elements', _id : 0 } }", + "{ $project : { _v : { $filter : { input : '$_v', as : 'item', cond : { $eq : ['$$item._t', 'Cat'] } } }, _id : 0 } }"); + + var result = queryable.Single(); + result.Select(x => x.Id).Should().Equal(1); + } + + [Fact] + public void Array_OfType_Reptile_should_work() + { + var collection = GetCollection(); + + var queryable = collection.AsQueryable() + .GroupBy(x => 1, (key, grouping) => grouping.ToArray()) + .Select(x => x.OfType().ToArray()); + + var stages = Translate(collection, queryable); + AssertStages( + stages, + "{ $group : { _id : 1, _elements : { $push : '$$ROOT' } } }", + "{ $project : { _v : '$_elements', _id : 0 } }", + "{ $project : { _v : { $filter : { input : '$_v', as : 'item', cond : { $in : ['$$item._t', ['Reptile', 'Snake']] } } }, _id : 0 } }"); + + var result = queryable.Single(); + result.Select(x => x.Id).Should().Equal(3); + } + + [Fact] + public void Array_OfType_Snake_should_work() + { + var collection = GetCollection(); + + var queryable = collection.AsQueryable() + .GroupBy(x => 1, (key, grouping) => grouping.ToArray()) + .Select(x => x.OfType().ToArray()); + + var stages = Translate(collection, queryable); + AssertStages( + stages, + "{ $group : { _id : 1, _elements : { $push : '$$ROOT' } } }", + "{ $project : { _v : '$_elements', _id : 0 } }", + "{ $project : { _v : { $filter : { input : '$_v', as : 'item', cond : { $eq : ['$$item._t', 'Snake'] } } }, _id : 0 } }"); + + var result = queryable.Single(); + result.Select(x => x.Id).Should().Equal(3); + } + + [Fact] + public void Array_Where_is_Animal_should_work() + { + var collection = GetCollection(); + + var queryable = collection.AsQueryable() + .GroupBy(x => 1, (key, grouping) => grouping.ToArray()) + .Select(x => x.Where(x => x is Animal).ToArray()); + + var stages = Translate(collection, queryable); + AssertStages( + stages, + "{ $group : { _id : 1, _elements : { $push : '$$ROOT' } } }", + "{ $project : { _v : '$_elements', _id : 0 } }", + "{ $project : { _v : '$_v', _id : 0 } }"); + + var result = queryable.Single(); + result.Select(x => x.Id).Should().Equal(1, 2, 3); + } + + [Fact] + public void Array_Where_is_Mammal_should_work() + { + var collection = GetCollection(); + + var queryable = collection.AsQueryable() + .GroupBy(x => 1, (key, grouping) => grouping.ToArray()) + .Select(x => x.Where(x => x is Mammal).ToArray()); + + var stages = Translate(collection, queryable); + AssertStages( + stages, + "{ $group : { _id : 1, _elements : { $push : '$$ROOT' } } }", + "{ $project : { _v : '$_elements', _id : 0 } }", + "{ $project : { _v: { $filter : { input : '$_v', as : 'x', cond : { $in : ['$$x._t', ['Cat', 'Dog', 'Mammal']] } } }, _id : 0 } }"); + + var result = queryable.Single(); + result.Select(x => x.Id).Should().Equal(1, 2); + } + + [Fact] + public void Array_Where_is_Cat_should_work() + { + var collection = GetCollection(); + + var queryable = collection.AsQueryable() + .GroupBy(x => 1, (key, grouping) => grouping.ToArray()) + .Select(x => x.Where(x => x is Cat).ToArray()); + + var stages = Translate(collection, queryable); + AssertStages( + stages, + "{ $group : { _id : 1, _elements : { $push : '$$ROOT' } } }", + "{ $project : { _v : '$_elements', _id : 0 } }", + "{ $project : { _v : { $filter : { input : '$_v', as : 'x', cond : { $eq : ['$$x._t', 'Cat'] } } }, _id : 0 } }"); + + var result = queryable.Single(); + result.Select(x => x.Id).Should().Equal(1); + } + + [Fact] + public void Array_Where_is_Reptile_should_work() + { + var collection = GetCollection(); + + var queryable = collection.AsQueryable() + .GroupBy(x => 1, (key, grouping) => grouping.ToArray()) + .Select(x => x.Where(x => x is Reptile).ToArray()); + + var stages = Translate(collection, queryable); + AssertStages( + stages, + "{ $group : { _id : 1, _elements : { $push : '$$ROOT' } } }", + "{ $project : { _v : '$_elements', _id : 0 } }", + "{ $project : { _v : { $filter : { input : '$_v', as : 'x', cond : { $in : ['$$x._t', ['Reptile', 'Snake']] } } }, _id : 0 } }"); + + var result = queryable.Single(); + result.Select(x => x.Id).Should().Equal(3); + } + + [Fact] + public void Array_Where_is_Snake_should_work() + { + var collection = GetCollection(); + + var queryable = collection.AsQueryable() + .GroupBy(x => 1, (key, grouping) => grouping.ToArray()) + .Select(x => x.Where(x => x is Snake).ToArray()); + + var stages = Translate(collection, queryable); + AssertStages( + stages, + "{ $group : { _id : 1, _elements : { $push : '$$ROOT' } } }", + "{ $project : { _v : '$_elements', _id : 0 } }", + "{ $project : { _v : { $filter : { input : '$_v', as : 'x', cond : { $eq : ['$$x._t', 'Snake'] } } }, _id : 0 } }"); + + var result = queryable.Single(); + result.Select(x => x.Id).Should().Equal(3); + } + + [Fact] + public void Array_Where_not_is_Animal_should_work() + { + var collection = GetCollection(); + + var queryable = collection.AsQueryable() + .GroupBy(x => 1, (key, grouping) => grouping.ToArray()) + .Select(x => x.Where(x => !(x is Animal)).ToArray()); + + var stages = Translate(collection, queryable); + AssertStages( + stages, + "{ $group : { _id : 1, _elements : { $push : '$$ROOT' } } }", + "{ $project : { _v : '$_elements', _id : 0 } }", + "{ $project : { _v : [], _id : 0 } }"); + + var result = queryable.Single(); + result.Select(x => x.Id).Should().Equal(); + } + + [Fact] + public void Array_Where_not_is_Mammal_should_work() + { + var collection = GetCollection(); + + var queryable = collection.AsQueryable() + .GroupBy(x => 1, (key, grouping) => grouping.ToArray()) + .Select(x => x.Where(x => !(x is Mammal)).ToArray()); + + var stages = Translate(collection, queryable); + AssertStages( + stages, + "{ $group : { _id : 1, _elements : { $push : '$$ROOT' } } }", + "{ $project : { _v : '$_elements', _id : 0 } }", + "{ $project : { _v : { $filter : { input : '$_v', as : 'x', cond : { $not : { $in : ['$$x._t', ['Cat', 'Dog', 'Mammal']] } } } }, _id : 0 } }"); + + var result = queryable.Single(); + result.Select(x => x.Id).Should().Equal(3); + } + + [Fact] + public void Array_Where_not_is_Cat_should_work() + { + var collection = GetCollection(); + + var queryable = collection.AsQueryable() + .GroupBy(x => 1, (key, grouping) => grouping.ToArray()) + .Select(x => x.Where(x => !(x is Cat)).ToArray()); + + var stages = Translate(collection, queryable); + AssertStages( + stages, + "{ $group : { _id : 1, _elements : { $push : '$$ROOT' } } }", + "{ $project : { _v : '$_elements', _id : 0 } }", + "{ $project : { _v : { $filter : { input : '$_v', as : 'x', cond : { $ne : ['$$x._t', 'Cat'] } } } , _id : 0 } }"); + + var result = queryable.Single(); + result.Select(x => x.Id).Should().Equal(2, 3); + } + + [Fact] + public void Array_Where_not_is_Reptile_should_work() + { + var collection = GetCollection(); + + var queryable = collection.AsQueryable() + .GroupBy(x => 1, (key, grouping) => grouping.ToArray()) + .Select(x => x.Where(x => !(x is Reptile)).ToArray()); + + var stages = Translate(collection, queryable); + AssertStages( + stages, + "{ $group : { _id : 1, _elements : { $push : '$$ROOT' } } }", + "{ $project : { _v : '$_elements', _id : 0 } }", + "{ $project : { _v : { $filter : { input : '$_v', as : 'x', cond : { $not : { $in : ['$$x._t', ['Reptile', 'Snake']] } } } }, _id : 0 } }"); + + var result = queryable.Single(); + result.Select(x => x.Id).Should().Equal(1, 2); + } + + [Fact] + public void Array_Where_not_is_Snake_should_work() + { + var collection = GetCollection(); + + var queryable = collection.AsQueryable() + .GroupBy(x => 1, (key, grouping) => grouping.ToArray()) + .Select(x => x.Where(x => !(x is Snake)).ToArray()); + + var stages = Translate(collection, queryable); + AssertStages( + stages, + "{ $group : { _id : 1, _elements : { $push : '$$ROOT' } } }", + "{ $project : { _v : '$_elements', _id : 0 } }", + "{ $project : { _v : { $filter : { input : '$_v', as : 'x', cond : { $ne : ['$$x._t', 'Snake'] } } } , _id : 0 } }"); + + var result = queryable.Single(); + result.Select(x => x.Id).Should().Equal(1, 2); + } + + private IMongoCollection GetCollection() + { + var collection = GetCollection("test"); + CreateCollection( + collection, + new Cat { Id = 1 }, + new Dog { Id = 2 }, + new Snake { Id = 3 }); + return collection; + } + + [BsonKnownTypes(typeof(Mammal))] + [BsonKnownTypes(typeof(Cat))] + [BsonKnownTypes(typeof(Dog))] + [BsonKnownTypes(typeof(Reptile))] + [BsonKnownTypes(typeof(Snake))] + private abstract class Animal + { + public int Id { get; set; } + } + + private abstract class Mammal : Animal + { + } + + private class Cat : Mammal + { + } + + private class Dog : Mammal + { + } + + private class Reptile : Animal + { + } + + private class Snake : Reptile + { + } + } +} diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Linq3IntegrationTest.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Linq3IntegrationTest.cs index 8a9dbe5e579..40d156c757d 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Linq3IntegrationTest.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Linq3IntegrationTest.cs @@ -33,7 +33,7 @@ protected void AssertStages(IEnumerable stages, params string[] ex protected void AssertStages(IEnumerable stages, IEnumerable expectedStages) { - stages.Should().Equal(expectedStages.Select(json => BsonDocument.Parse(json))); + stages.Should().Equal(expectedStages.Where(x => x != null).Select(json => BsonDocument.Parse(json))); } protected void CreateCollection(IMongoCollection collection, IEnumerable documents = null)