How may Entity Framework be wrapped to intercept LINQ expressions right before they are executed?

c# entity-framework expression-trees linq

Question

I want to rewrite certain parts of the LINQ expression just before it is executed, and I'm having trouble injecting my rewriter in the right place (or at all, actually).

Looking at the Entity Framework source (in Reflector), it ultimately boils down to IQueryProvider.Execute, which in EF is connected to the expression by the ObjectContext exposing the internal IQueryProvider Provider { get; } property.

So I created a wrapper class (implementing IQueryProvider) that rewrites the Expression when Execute is called and then passes it on to the original Provider.

The problem is that the field behind .Provider is private ObjectQueryProvider _queryProvider;. ObjectQueryProvider is an internal sealed class, so I cannot create a subclass that adds the rewriting.

Because ObjectContext is so tightly coupled to it, this approach led me to a dead end.

How can this be solved? Am I looking in the wrong direction? Is there some way to inject myself around this ObjectQueryProvider?

Update: The solutions offered all work when you "wrap" the ObjectContext using the Repository pattern, but a solution that allows direct use of the generated ObjectContext subclass would be preferable. That would keep it compatible with the Dynamic Data scaffolding.

12/3/2009 7:12:34 PM

Accepted Answer

Based on Arthur's answer, I've built a working wrapper.

The snippets below let you wrap each LINQ query with your own QueryProvider and IQueryable root. This means you need some control over the initial query (as you will in most cases when using a pattern such as a repository).

The drawback of this solution is that it is not transparent; ideally you would inject something into the entities container at the constructor level.

I've created a compilable implementation, got it working with Entity Framework, and added support for the ObjectQuery.Include method. The expression visitor class can be copied from MSDN.

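using System;
using System.Collections;
using System.Collections.Generic;
using System.Data.Objects; // ObjectQuery<T> (EF 1/4); in EF 6 it lives in System.Data.Entity.Core.Objects
using System.Linq;
using System.Linq.Expressions;

public class QueryTranslator<T> : IOrderedQueryable<T>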
{
    private Expression expression = null;
    private QueryTranslatorProvider<T> provider = null;

    public QueryTranslator(IQueryable source)
    {
        expression = Expression.Constant(this);
        provider = new QueryTranslatorProvider<T>(source);
    }

    public QueryTranslator(IQueryable source, Expression e)
    {
        if (e == null) throw new ArgumentNullException("e");
        expression = e;
        provider = new QueryTranslatorProvider<T>(source);
    }

    public IEnumerator<T> GetEnumerator()
    {
        return ((IEnumerable<T>)provider.ExecuteEnumerable(this.expression)).GetEnumerator();
    }

    System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
    {
        return provider.ExecuteEnumerable(this.expression).GetEnumerator();
    }

    public QueryTranslator<T> Include(String path)
    {
        ObjectQuery<T> possibleObjectQuery = provider.source as ObjectQuery<T>;
        if (possibleObjectQuery != null)
        {
            return new QueryTranslator<T>(possibleObjectQuery.Include(path));
        }
        else
        {
            throw new InvalidOperationException("The Include should only happen at the beginning of a LINQ expression");
        }
    }

    public Type ElementType
    {
        get { return typeof(T); }
    }

    public Expression Expression
    {
        get { return expression; }
    }

    public IQueryProvider Provider
    {
        get { return provider; }
    }
}

public class QueryTranslatorProvider<T> : ExpressionVisitor, IQueryProvider
{
    internal IQueryable source;

    public QueryTranslatorProvider(IQueryable source)
    {
        if (source == null) throw new ArgumentNullException("source");
        this.source = source;
    }

    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");

        return new QueryTranslator<TElement>(source, expression) as IQueryable<TElement>;
    }

    public IQueryable CreateQuery(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");
        Type elementType = expression.Type.GetGenericArguments().First();
        IQueryable result = (IQueryable)Activator.CreateInstance(typeof(QueryTranslator<>).MakeGenericType(elementType),
            new object[] { source, expression });
        return result;
    }

    public TResult Execute<TResult>(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");
        object result = (this as IQueryProvider).Execute(expression);
        return (TResult)result;
    }

    public object Execute(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");

        Expression translated = this.Visit(expression);
        return source.Provider.Execute(translated);
    }

    internal IEnumerable ExecuteEnumerable(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");

        Expression translated = this.Visit(expression);
        return source.Provider.CreateQuery(translated);
    }

    #region Visitors
    protected override Expression VisitConstant(ConstantExpression c)
    {
        // fix up the Expression tree to work with EF again
        if (c.Type == typeof(QueryTranslator<T>))
        {
            return source.Expression;
        }
        else
        {
            return base.VisitConstant(c);
        }
    }
    #endregion
}

Example usage in your repository:

public IQueryable<User> List()
{
    return new QueryTranslator<User>(entities.Users).Include("Department");
}
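
The Visitors region of the provider above is where your own rewrites would go. As a rough, self-contained sketch of what such a rewrite might look like, the following runs against LINQ to Objects instead of Entity Framework and assumes .NET 4 or later, where System.Linq.Expressions.ExpressionVisitor is public. The class names GreaterThanToGreaterOrEqualRewriter and RewriterDemo are hypothetical and only serve the illustration. It covers both paths the provider uses: CreateQuery for enumeration and Execute for scalar results.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical rewriter: turns every "a > b" comparison into "a >= b".
public sealed class GreaterThanToGreaterOrEqualRewriter : ExpressionVisitor
{
    protected override Expression VisitBinary(BinaryExpression node)
    {
        if (node.NodeType == ExpressionType.GreaterThan)
        {
            // Visit the operands first so nested comparisons are rewritten too.
            return Expression.GreaterThanOrEqual(Visit(node.Left), Visit(node.Right));
        }
        return base.VisitBinary(node);
    }
}

public static class RewriterDemo
{
    // Enumeration path: rewrite the tree, then let the original provider build the query.
    public static int[] RunEnumerable()
    {
        IQueryable<int> source = new[] { 1, 2, 3, 4 }.AsQueryable();
        IQueryable<int> query = source.Where(x => x > 2);

        Expression rewritten = new GreaterThanToGreaterOrEqualRewriter().Visit(query.Expression);
        return source.Provider.CreateQuery<int>(rewritten).ToArray();
    }

    // Scalar path: operators such as Count() go through IQueryProvider.Execute.
    public static int RunCount()
    {
        IQueryable<int> source = new[] { 1, 2, 3, 4 }.AsQueryable();
        IQueryable<int> query = source.Where(x => x > 2);

        // Build Queryable.Count<int>(query) by hand so it can be rewritten before execution.
        Expression countCall = Expression.Call(
            typeof(Queryable), "Count", new[] { typeof(int) }, query.Expression);

        Expression rewritten = new GreaterThanToGreaterOrEqualRewriter().Visit(countCall);
        return source.Provider.Execute<int>(rewritten);
    }
}
```

With the rewrite applied, the filter behaves as x >= 2, so RunEnumerable() yields 2, 3, 4 and RunCount() yields 3. In the wrapper above, the same override would sit next to VisitConstant in QueryTranslatorProvider<T>.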
12/4/2009 10:09:23 AM

Popular Answer

I have exactly the source code you need, but I have no idea how to attach a file.

I had to adapt the code, so the following snippets may not compile:

IQueryable:

public class QueryTranslator<T> : IOrderedQueryable<T>
{
    private Expression _expression = null;
    private QueryTranslatorProvider<T> _provider = null;

    public QueryTranslator(IQueryable source)
    {
        _expression = Expression.Constant(this);
        _provider = new QueryTranslatorProvider<T>(source);
    }

    public QueryTranslator(IQueryable source, Expression e)
    {
        if (e == null) throw new ArgumentNullException("e");
        _expression = e;
        _provider = new QueryTranslatorProvider<T>(source);
    }

    public IEnumerator<T> GetEnumerator()
    {
        return ((IEnumerable<T>)_provider.ExecuteEnumerable(this._expression)).GetEnumerator();
    }

    IEnumerator System.Collections.IEnumerable.GetEnumerator()
    {
        return _provider.ExecuteEnumerable(this._expression).GetEnumerator();
    }

    public Type ElementType
    {
        get { return typeof(T); }
    }

    public Expression Expression
    {
        get { return _expression; }
    }

    public IQueryProvider Provider
    {
        get { return _provider; }
    }
}

IQueryProvider:

public class QueryTranslatorProvider<T> : ExpressionTreeTranslator, IQueryProvider
{
    IQueryable _source;

    public QueryTranslatorProvider(IQueryable source)
    {
        if (source == null) throw new ArgumentNullException("source");
        _source = source;
    }

    #region IQueryProvider Members

    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");
        return new QueryTranslator<TElement>(_source, expression) as IQueryable<TElement>;
    }

    public IQueryable CreateQuery(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");

        Type elementType = expression.Type.FindElementTypes().First();
        IQueryable result = (IQueryable)Activator.CreateInstance(typeof(QueryTranslator<>).MakeGenericType(elementType),
            new object[] { _source, expression });
        return result;
    }

    public TResult Execute<TResult>(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");
        object result = (this as IQueryProvider).Execute(expression);
        return (TResult)result;
    }

    public object Execute(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");

        Expression translated = this.Visit(expression);

        return _source.Provider.Execute(translated);            
    }

    internal IEnumerable ExecuteEnumerable(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");

        Expression translated = this.Visit(expression);

        return _source.Provider.CreateQuery(translated);
    }

    #endregion        

    #region Visits
    protected override MethodCallExpression VisitMethodCall(MethodCallExpression m)
    {
        return m;
    }

    protected override Expression VisitUnary(UnaryExpression u)
    {
         return Expression.MakeUnary(u.NodeType, base.Visit(u.Operand), u.Type.ToImplementationType(), u.Method);
    }
    #endregion
}

Usage (warning: adapted code, may not compile):

private Dictionary<Type, object> _table = new Dictionary<Type, object>();
public override IQueryable<T> GetObjectQuery<T>()
{
    // Cache one wrapped query root per entity type.
    Type type = typeof(T);
    if (!_table.ContainsKey(type))
    {
        _table[type] = new QueryTranslator<T>(
            _ctx.CreateQuery<T>("[" + typeof(T).Name + "]"));
    }

    return (IQueryable<T>)_table[type];
}

Expression visitors/translators:

http://blogs.msdn.com/mattwar/archive/2007/07/31/linq-building-an-iqueryable-provider-part-ii.aspx

http://msdn.microsoft.com/en-us/library/bb882521.aspx

EDIT: Added FindElementTypes(). Hopefully all methods are present now.

    /// <summary>
    /// Finds all implemented IEnumerables of the given Type
    /// </summary>
    public static IQueryable<Type> FindIEnumerables(this Type seqType)
    {
        if (seqType == null || seqType == typeof(object) || seqType == typeof(string))
            return new Type[] { }.AsQueryable();

        if (seqType.IsArray || seqType == typeof(IEnumerable))
            return new Type[] { typeof(IEnumerable) }.AsQueryable();

        if (seqType.IsGenericType && seqType.GetGenericArguments().Length == 1 && seqType.GetGenericTypeDefinition() == typeof(IEnumerable<>))
        {
            return new Type[] { seqType, typeof(IEnumerable) }.AsQueryable();
        }

        var result = new List<Type>();

        foreach (var iface in (seqType.GetInterfaces() ?? new Type[] { }))
        {
            result.AddRange(FindIEnumerables(iface));
        }

        return FindIEnumerables(seqType.BaseType).Union(result);
    }

    /// <summary>
    /// Finds all element types provided by a specified sequence type.
    /// "Element types" are T for IEnumerable&lt;T&gt; and object for IEnumerable.
    /// </summary>
    public static IQueryable<Type> FindElementTypes(this Type seqType)
    {
        return seqType.FindIEnumerables().Select(t => t.IsGenericType ? t.GetGenericArguments().Single() : typeof(object));
    }


Licensed under: CC-BY-SA with attribution
Not affiliated with Stack Overflow