diff --git a/src/Microsoft.OData.Client/DataServiceContext.cs b/src/Microsoft.OData.Client/DataServiceContext.cs index 2c4d9d26ef..a9d6ae7308 100644 --- a/src/Microsoft.OData.Client/DataServiceContext.cs +++ b/src/Microsoft.OData.Client/DataServiceContext.cs @@ -3279,7 +3279,12 @@ protected internal Type DefaultResolveType(string typeName, string fullNamespace if (!this.resolveTypesCache.TryGetValue(typeName, out matchedType)) { - if (ClientTypeUtil.TryResolveType(typeName, fullNamespace, languageDependentNamespace, out matchedType)) + if (ClientTypeUtil.TryResolveType( + this.GetType().GetAssembly(), + typeName, + fullNamespace, + languageDependentNamespace, + out matchedType)) { this.resolveTypesCache.TryAdd(typeName, matchedType); diff --git a/src/Microsoft.OData.Client/Metadata/ClientTypeUtil.cs b/src/Microsoft.OData.Client/Metadata/ClientTypeUtil.cs index 192f31597c..de8231be54 100644 --- a/src/Microsoft.OData.Client/Metadata/ClientTypeUtil.cs +++ b/src/Microsoft.OData.Client/Metadata/ClientTypeUtil.cs @@ -454,7 +454,7 @@ internal static string GetServerDefinedName(MemberInfo memberInfo) /// The server defined type name. internal static string GetServerDefinedTypeName(Type type) { - ODataTypeInfo typeInfo = GetODataTypeInfo(type); + ODataTypeInfo typeInfo = GetODataTypeInfo(type); return typeInfo.ServerDefinedTypeName; } @@ -478,7 +478,7 @@ internal static string GetServerDefinedTypeFullName(Type type) internal static string GetClientFieldName(Type t, string serverDefinedName) { ODataTypeInfo typeInfo = GetODataTypeInfo(t); - + List serverDefinedNames = serverDefinedName.Split(',').Select(name => name.Trim()).ToList(); List clientMemberNames = new List(); foreach (var serverSideName in serverDefinedNames) @@ -618,7 +618,7 @@ internal static bool TryGetContainerProperty(object instance, out IDictionary , SortedDictionary<,> , ConcurrentDictionary<,> , etc - must also have parameterless constructor if (!propertyType.IsInterface() && !propertyType.IsAbstract() && propertyType.GetInstanceConstructor(true, new Type[0]) != null) { @@ -631,7 +631,7 @@ internal static bool TryGetContainerProperty(object instance, out IDictionary)Util.ActivatorCreateInstance(dictionaryType); } else - { + { // Not easy to figure out the implementing type return false; } @@ -688,59 +688,152 @@ private static bool IsOverride(Type type, PropertyInfo propertyInfo) } /// - /// Tries to resolve a type with specified name from the loaded assemblies. + /// Tries to resolve a type with specified name, first from the specified assembly and then from other loaded assemblies. /// + /// Assembly expected to contain the proxy classes. /// Name of the type to resolve. /// Namespace of the type. /// Namespace that the resolved type is expected to be. /// Usually same as but can be different /// where namespace for client types does not match namespace in service types. /// The resolved type. - /// true if type was successfully resolved; otherwise false. - internal static bool TryResolveType(string typeName, string fullNamespace, string languageDependentNamespace, out Type matchedType) - { + /// true if a type with the specified name is successfully resolved; otherwise false. + internal static bool TryResolveType( + Assembly targetAssembly, + string typeName, + string fullNamespace, + string languageDependentNamespace, + out Type matchedType) + { + Debug.Assert(targetAssembly != null, "targetAssembly != null"); Debug.Assert(typeName != null, "typeName != null"); - matchedType = null; - int namespaceLength = fullNamespace?.Length ?? 0; - string serverDefinedName = typeName.Substring(namespaceLength + 1); + int fullNamespaceLength = fullNamespace?.Length ?? 0; + string qualifiedTypeName = string.Concat(languageDependentNamespace, typeName.Substring(fullNamespaceLength)); + string serverDefinedName = fullNamespaceLength > 0 ? typeName.Substring(fullNamespaceLength + 1) : typeName; - // Searching only loaded assemblies, not referenced assemblies - foreach (Assembly assembly in AppDomain.CurrentDomain.GetAssemblies()) + // We first try to look for the type from the assembly expected to contain the proxy classes + matchedType = FindType( + targetAssembly, + qualifiedTypeName, + serverDefinedName, + languageDependentNamespace); + + if (matchedType != null) + { + return true; + } + + var entryAssembly = Assembly.GetEntryAssembly(); + if (entryAssembly != null && !entryAssembly.Equals(targetAssembly)) { - matchedType = assembly.GetType(string.Concat(languageDependentNamespace, typeName.Substring(namespaceLength)), false); + // Next, we try to look for the type from the entry assembly + matchedType = FindType( + entryAssembly, + qualifiedTypeName, + serverDefinedName, + languageDependentNamespace); + if (matchedType != null) { return true; } + } - IEnumerable types = null; - - try + // Searching only loaded assemblies, not referenced assemblies + foreach (Assembly assembly in AppDomain.CurrentDomain.GetAssemblies()) + { + if (assembly.Equals(targetAssembly) + || assembly.Equals(entryAssembly) + || SkipAssembly(assembly)) { - types = assembly.GetTypes(); + continue; } - catch (ReflectionTypeLoadException) + + matchedType = FindType( + assembly, + qualifiedTypeName, + serverDefinedName, + languageDependentNamespace); + + if (matchedType != null) { - // Ignore + return true; } + } - if (types != null) + return false; + } + + /// + /// Searches for a type that matches the specified type name from the specified assembly. + /// + /// Assembly that the specified type is expected to be. + /// The namespace-qualified name of the type. + /// The namespace-qualified name of corresponding server type. + /// Namespace that the specified type is expected to be. + /// The type if found in the specified assembly; otherwise null. + private static Type FindType( + Assembly assembly, + string qualifiedTypeName, + string serverDefinedName, + string languageDependentNamespace) + { + Type matchedType = assembly.GetType(qualifiedTypeName, throwOnError: false); + + if (matchedType != null) + { + return matchedType; + } + + Type[] types = null; + + try + { + types = assembly.GetTypes(); + } + catch (ReflectionTypeLoadException) + { + // Ignore + } + + if (types != null) + { + for (int i = 0; i < types.Length; i++) { - foreach (Type type in types) + Type type = types[i]; + + object[] originalNameAttributes = type.GetCustomAttributes(typeof(OriginalNameAttribute), inherit: true); + if (originalNameAttributes.Length == 0) { - OriginalNameAttribute originalNameAttribute = (OriginalNameAttribute)type.GetCustomAttributes(typeof(OriginalNameAttribute), true).SingleOrDefault(); - if (string.Equals(originalNameAttribute?.OriginalName, serverDefinedName, StringComparison.Ordinal) - && type.Namespace.Equals(languageDependentNamespace, StringComparison.Ordinal)) - { - matchedType = type; - return true; - } + continue; + } + + OriginalNameAttribute originalNameAttribute = (OriginalNameAttribute)originalNameAttributes[0]; + if (string.Equals(originalNameAttribute.OriginalName, serverDefinedName, StringComparison.Ordinal) + && type.Namespace.Equals(languageDependentNamespace, StringComparison.Ordinal)) + { + matchedType = type; } } } - return false; + return matchedType; + } + + /// + /// Checks whether to skip the assembly when trying to find a type to be used for materialization. + /// + /// The assembly to check. + /// true to skip the assembly; otherwise false. + private static bool SkipAssembly(Assembly assembly) + { + return assembly.Equals(typeof(string).Assembly) // mscorlib assembly + || assembly.Equals(typeof(Uri).Assembly) // Common types assembly + || assembly.Equals(typeof(ClientEdmModel).Assembly) // OData client assembly + || assembly.Equals(typeof(ODataItem).Assembly) // OData core assembly + || assembly.Equals(typeof(EdmModel).Assembly) // OData Edm assembly + || assembly.Equals(typeof(Spatial.Geography).Assembly); // Spatial assembly } } } diff --git a/test/FunctionalTests/Tests/DataServices/UnitTests/Client.TDD.Tests/Tests/Materialization/CamelCasedTypeMaterializationTests.cs b/test/FunctionalTests/Tests/DataServices/UnitTests/Client.TDD.Tests/Tests/Materialization/CamelCasedTypeMaterializationTests.cs index 9faaa3ef5e..a8c2e8a63e 100644 --- a/test/FunctionalTests/Tests/DataServices/UnitTests/Client.TDD.Tests/Tests/Materialization/CamelCasedTypeMaterializationTests.cs +++ b/test/FunctionalTests/Tests/DataServices/UnitTests/Client.TDD.Tests/Tests/Materialization/CamelCasedTypeMaterializationTests.cs @@ -46,7 +46,7 @@ public partial class CamelCasedTypeMaterializationTests private const string ServiceUri = "http://tempuri.org/"; private ClientEdmModel clientModel; - private DataServiceContext dataServiceContext; + private Container dataServiceContext; public CamelCasedTypeMaterializationTests() { @@ -125,6 +125,46 @@ public void MaterializationForEntitySetBoundToBaseEntityTypeCollection() Assert.Empty(rectangles[0].Attributes); } + [Fact] + public void MaterializationTypeShouldBeResolvedFromTargetAssembly() + { + var typeName = typeof(Rectangle).FullName; + var typeNamespace = typeName.Substring(0, typeName.LastIndexOf('.')); + + var resolvedType = dataServiceContext.DefaultResolveType(typeName, typeNamespace, typeNamespace); + + Assert.Equal(typeof(Rectangle), resolvedType); + } + + [Fact] + public void MaterializationTypeShouldBeResolvedFromLoadedAssembly() + { + var localClientModel = new ClientEdmModel(ODataProtocolVersion.V4); + // Use of DataServiceContext class directly will cause the materialization type not to be in the target assembly + // since this.GetType().GetAssembly() in DefaultResolveType method will return Microsoft.OData.Client assembly. + // Materialization type will be resolved from a loaded assembly + var localDataServiceContext = new DataServiceContext( + new Uri(ServiceUri), + ODataProtocolVersion.V4, + localClientModel); + localDataServiceContext.UndeclaredPropertyBehavior = UndeclaredPropertyBehavior.Support; + + using (var reader = XmlReader.Create(new StringReader(CamelCasedEdmx))) + { + if (CsdlReader.TryParse(reader, out IEdmModel localServiceModel, out _)) + { + localDataServiceContext.Format.UseJson(localServiceModel); + } + } + + var typeName = typeof(Rectangle).FullName; + var typeNamespace = typeName.Substring(0, typeName.LastIndexOf('.')); + + var resolvedType = localDataServiceContext.DefaultResolveType(typeName, typeNamespace, typeNamespace); + + Assert.Equal(typeof(Rectangle), resolvedType); + } + private void ConfigureOnMessageCreating(string payload) { dataServiceContext.Configurations.RequestPipeline.OnMessageCreating = (args) => @@ -145,7 +185,7 @@ private void ConfigureOnMessageCreating(string payload) private void InitializeEdmModel() { this.clientModel = new ClientEdmModel(ODataProtocolVersion.V4); - this.dataServiceContext = new DataServiceContext(new Uri(ServiceUri), ODataProtocolVersion.V4, this.clientModel); + this.dataServiceContext = new Container(new Uri(ServiceUri), this.clientModel); this.dataServiceContext.UndeclaredPropertyBehavior = UndeclaredPropertyBehavior.Support; this.dataServiceContext.ResolveType = (typeName) => { @@ -186,6 +226,7 @@ private void InitializeEdmModel() namespace NS.Models { + using System; using System.Collections.ObjectModel; using Microsoft.OData.Client; @@ -216,4 +257,13 @@ public Rectangle() [OriginalName("attributes")] public ObservableCollection Attributes { get; set; } } + + internal class Container : DataServiceContext + { + internal Container(Uri serviceRoot, ClientEdmModel model) + : base(serviceRoot, ODataProtocolVersion.V4, model) + { + + } + } }