Package com.salesforce.phoenix.parse

Source Code of com.salesforce.phoenix.parse.ParseNodeRewriter$CompoundNodeFactory

/*******************************************************************************
* Copyright (c) 2013, Salesforce.com, Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
*     Redistributions of source code must retain the above copyright notice,
*     this list of conditions and the following disclaimer.
*     Redistributions in binary form must reproduce the above copyright notice,
*     this list of conditions and the following disclaimer in the documentation
*     and/or other materials provided with the distribution.
*     Neither the name of Salesforce.com nor the names of its contributors may
*     be used to endorse or promote products derived from this software without
*     specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
******************************************************************************/
package com.salesforce.phoenix.parse;

import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import org.apache.hadoop.hbase.filter.CompareFilter.CompareOp;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.salesforce.phoenix.compile.ColumnResolver;
import com.salesforce.phoenix.schema.AmbiguousColumnException;
import com.salesforce.phoenix.schema.ColumnNotFoundException;

/**
*
* Base class for visitors that rewrite the expression node hierarchy
*
* @author jtaylor
* @since 0.1
*/
public class ParseNodeRewriter extends TraverseAllParseNodeVisitor<ParseNode> {
   
    /** Shared factory that this class uses to build every replacement node. */
    protected static final ParseNodeFactory NODE_FACTORY = new ParseNodeFactory();

    public static ParseNode rewrite(ParseNode where, ParseNodeRewriter rewriter) throws SQLException {
        if (where == null) {
            return null;
        }
        rewriter.reset();
        return where.accept(rewriter);
    }
   
    /**
     * Rewrite the select statement by switching any constants to the right hand side
     * of the expression.
     * @param statement the select statement
     * @return new select statement
     * @throws SQLException
     */
    public static SelectStatement rewrite(SelectStatement statement, ParseNodeRewriter rewriter) throws SQLException {
        Map<String,ParseNode> aliasMap = rewriter.getAliasMap();
        List<TableNode> from = statement.getFrom();
        List<TableNode> normFrom = from;
        if (from.size() > 1) {
          for (int i = 1; i < from.size(); i++) {
            TableNode tableNode = from.get(i);
            if (tableNode instanceof JoinTableNode) {
              JoinTableNode joinTableNode = (JoinTableNode) tableNode;
              ParseNode onNode = joinTableNode.getOnNode();
              rewriter.reset();
              ParseNode normOnNode = onNode.accept(rewriter);
              if (onNode == normOnNode) {
                if (from != normFrom) {
                  normFrom.add(tableNode);
                }
                continue;
              }
              if (from == normFrom) {
                normFrom = Lists.newArrayList(from.subList(0, i));
              }
              TableNode normTableNode = NODE_FACTORY.join(joinTableNode.getType(), normOnNode, joinTableNode.getTable());
              normFrom.add(normTableNode);
            } else if (from != normFrom) {
          normFrom.add(tableNode);
        }
          }
        }
        ParseNode where = statement.getWhere();
        ParseNode normWhere = where;
        if (where != null) {
            rewriter.reset();
            normWhere = where.accept(rewriter);
        }
        ParseNode having = statement.getHaving();
        ParseNode normHaving= having;
        if (having != null) {
            rewriter.reset();
            normHaving = having.accept(rewriter);
        }
        List<AliasedNode> selectNodes = statement.getSelect();
        List<AliasedNode> normSelectNodes = selectNodes;
        for (int i = 0; i < selectNodes.size(); i++) {
            AliasedNode aliasedNode = selectNodes.get(i);
            ParseNode selectNode = aliasedNode.getNode();
            rewriter.reset();
            ParseNode normSelectNode = selectNode.accept(rewriter);
            if (selectNode == normSelectNode) {
                if (selectNodes != normSelectNodes) {
                    normSelectNodes.add(aliasedNode);
                }
                continue;
            }
            if (selectNodes == normSelectNodes) {
                normSelectNodes = Lists.newArrayList(selectNodes.subList(0, i));
            }
            AliasedNode normAliasedNode = NODE_FACTORY.aliasedNode(aliasedNode.isCaseSensitve() ? '"' + aliasedNode.getAlias() + '"' : aliasedNode.getAlias(), normSelectNode);
            normSelectNodes.add(normAliasedNode);
        }
        // Add to map in separate pass so that we don't try to use aliases
        // while processing the select expressions
        if (aliasMap != null) {
            for (int i = 0; i < normSelectNodes.size(); i++) {
                AliasedNode aliasedNode = normSelectNodes.get(i);
                ParseNode selectNode = aliasedNode.getNode();
                String alias = aliasedNode.getAlias();
                if (alias != null) {
                    aliasMap.put(alias, selectNode);
                }
            }
        }
       
        List<ParseNode> groupByNodes = statement.getGroupBy();
        List<ParseNode> normGroupByNodes = groupByNodes;
        for (int i = 0; i < groupByNodes.size(); i++) {
            ParseNode groupByNode = groupByNodes.get(i);
            rewriter.reset();
            ParseNode normGroupByNode = groupByNode.accept(rewriter);
            if (groupByNode == normGroupByNode) {
                if (groupByNodes != normGroupByNodes) {
                    normGroupByNodes.add(groupByNode);
                }
                continue;
            }
            if (groupByNodes == normGroupByNodes) {
                normGroupByNodes = Lists.newArrayList(groupByNodes.subList(0, i));
            }
            normGroupByNodes.add(normGroupByNode);
        }
        List<OrderByNode> orderByNodes = statement.getOrderBy();
        List<OrderByNode> normOrderByNodes = orderByNodes;
        for (int i = 0; i < orderByNodes.size(); i++) {
            OrderByNode orderByNode = orderByNodes.get(i);
            ParseNode node = orderByNode.getNode();
            rewriter.reset();
            ParseNode normNode = node.accept(rewriter);
            if (node == normNode) {
                if (orderByNodes != normOrderByNodes) {
                    normOrderByNodes.add(orderByNode);
                }
                continue;
            }
            if (orderByNodes == normOrderByNodes) {
                normOrderByNodes = Lists.newArrayList(orderByNodes.subList(0, i));
            }
            normOrderByNodes.add(NODE_FACTORY.orderBy(normNode, orderByNode.isNullsLast(), orderByNode.isAscending()));
        }
       
        // Return new SELECT statement with updated WHERE clause
        if (normFrom == from &&
            normWhere == where &&
                normHaving == having &&
                selectNodes == normSelectNodes &&
                groupByNodes == normGroupByNodes &&
                orderByNodes == normOrderByNodes) {
            return statement;
        }
        return NODE_FACTORY.select(normFrom, statement.getHint(), statement.isDistinct(),
                normSelectNodes, normWhere, normGroupByNodes, normHaving, normOrderByNodes,
                statement.getLimit(), statement.getBindCount(), statement.isAggregate());
    }
   
    private Map<String, ParseNode> getAliasMap() {
        return aliasMap;
    }

    // Resolver consulted by visit(ColumnParseNode) to detect a clash between an
    // alias and a real column; null when the no-arg constructor was used.
    private final ColumnResolver resolver;
    // Maps select-list aliases to their expressions; null unless the rewriter was
    // constructed with a maximum alias count.
    private final Map<String, ParseNode> aliasMap;
    // Raised by newElementList (by the requested size) and lowered by one on
    // every addElement call; zero right after reset().
    private int nodeCount;
   
    public boolean isTopLevel() {
        return nodeCount == 0;
    }
   
    protected ParseNodeRewriter() {
        this.resolver = null;
        this.aliasMap = null;
    }
   
    protected ParseNodeRewriter(ColumnResolver resolver) {
        this.resolver = resolver;
        this.aliasMap = null;
    }
   
    protected ParseNodeRewriter(ColumnResolver resolver, int maxAliasCount) {
        this.resolver = resolver;
        this.aliasMap = Maps.newHashMapWithExpectedSize(maxAliasCount);
    }
   
    protected ColumnResolver getResolver() {
        return resolver;
    }
   
    protected void reset() {
        this.nodeCount = 0;
    }
   
    private static interface CompoundNodeFactory {
        ParseNode createNode(List<ParseNode> children);
    }
   
    private ParseNode leaveCompoundNode(CompoundParseNode node, List<ParseNode> children, CompoundNodeFactory factory) {
        if (children.equals(node.getChildren())) {
            return node;
        } else { // Child nodes have been inverted (because a literal was found on LHS)
            return factory.createNode(children);
        }
    }
   
    @Override
    public ParseNode visitLeave(AndParseNode node, List<ParseNode> nodes) throws SQLException {
        return leaveCompoundNode(node, nodes, new CompoundNodeFactory() {
            @Override
            public ParseNode createNode(List<ParseNode> children) {
                return NODE_FACTORY.and(children);
            }
        });
    }

    @Override
    public ParseNode visitLeave(OrParseNode node, List<ParseNode> nodes) throws SQLException {
        return leaveCompoundNode(node, nodes, new CompoundNodeFactory() {
            @Override
            public ParseNode createNode(List<ParseNode> children) {
                return NODE_FACTORY.or(children);
            }
        });
    }
   
    @Override
    public ParseNode visitLeave(SubtractParseNode node, List<ParseNode> nodes) throws SQLException {
        return leaveCompoundNode(node, nodes, new CompoundNodeFactory() {
            @Override
            public ParseNode createNode(List<ParseNode> children) {
                return NODE_FACTORY.subtract(children);
            }
        });
    }
   
    @Override
    public ParseNode visitLeave(AddParseNode node, List<ParseNode> nodes) throws SQLException {
        return leaveCompoundNode(node, nodes, new CompoundNodeFactory() {
            @Override
            public ParseNode createNode(List<ParseNode> children) {
                return NODE_FACTORY.add(children);
            }
        });
    }
   
    @Override
    public ParseNode visitLeave(MultiplyParseNode node, List<ParseNode> nodes) throws SQLException {
        return leaveCompoundNode(node, nodes, new CompoundNodeFactory() {
            @Override
            public ParseNode createNode(List<ParseNode> children) {
                return NODE_FACTORY.multiply(children);
            }
        });
    }
   
    @Override
    public ParseNode visitLeave(DivideParseNode node, List<ParseNode> nodes) throws SQLException {
        return leaveCompoundNode(node, nodes, new CompoundNodeFactory() {
            @Override
            public ParseNode createNode(List<ParseNode> children) {
                return NODE_FACTORY.divide(children);
            }
        });
    }
   
    @Override
    public ParseNode visitLeave(final FunctionParseNode node, List<ParseNode> nodes) throws SQLException {
        return leaveCompoundNode(node, nodes, new CompoundNodeFactory() {
            @Override
            public ParseNode createNode(List<ParseNode> children) {
                return NODE_FACTORY.function(node.getName(),children);
            }
        });
    }
   
    @Override
    public ParseNode visitLeave(CaseParseNode node, List<ParseNode> nodes) throws SQLException {
        return leaveCompoundNode(node, nodes, new CompoundNodeFactory() {
            @Override
            public ParseNode createNode(List<ParseNode> children) {
                return NODE_FACTORY.caseWhen(children);
            }
        });
    }

    @Override
    public ParseNode visitLeave(final LikeParseNode node, List<ParseNode> nodes) throws SQLException {
        return leaveCompoundNode(node, nodes, new CompoundNodeFactory() {
            @Override
            public ParseNode createNode(List<ParseNode> children) {
                return NODE_FACTORY.like(children.get(0),children.get(1),node.isNegate());
            }
        });
    }
   
    @Override
    public ParseNode visitLeave(NotParseNode node, List<ParseNode> nodes) throws SQLException {
        return leaveCompoundNode(node, nodes, new CompoundNodeFactory() {
            @Override
            public ParseNode createNode(List<ParseNode> children) {
                return NODE_FACTORY.not(children.get(0));
            }
        });
    }
   
    @Override
    public ParseNode visitLeave(final CastParseNode node, List<ParseNode> nodes) throws SQLException {
        return leaveCompoundNode(node, nodes, new CompoundNodeFactory() {
            @Override
            public ParseNode createNode(List<ParseNode> children) {
                return NODE_FACTORY.cast(children.get(0), node.getDataType());
            }
        });
    }
   
    @Override
    public ParseNode visitLeave(final InListParseNode node, List<ParseNode> nodes) throws SQLException {
        return leaveCompoundNode(node, nodes, new CompoundNodeFactory() {
            @Override
            public ParseNode createNode(List<ParseNode> children) {
                return NODE_FACTORY.inList(children, node.isNegate());
            }
        });
    }
   
    @Override
    public ParseNode visitLeave(final IsNullParseNode node, List<ParseNode> nodes) throws SQLException {
        return leaveCompoundNode(node, nodes, new CompoundNodeFactory() {
            @Override
            public ParseNode createNode(List<ParseNode> children) {
                return NODE_FACTORY.isNull(children.get(0), node.isNegate());
            }
        });
    }
   
    /**
     * Rewrites expressions of the form (a, b, c) = (1, 2) as a = 1 and b = 2 and c is null
     * as this is equivalent and already optimized
     * @param lhs
     * @param rhs
     * @param andNodes
     * @throws SQLException
     */
    private void rewriteRowValueConstuctorEqualityComparison(ParseNode lhs, ParseNode rhs, List<ParseNode> andNodes) throws SQLException {
        if (lhs instanceof RowValueConstructorParseNode && rhs instanceof RowValueConstructorParseNode) {
            int i = 0;
            for (; i < Math.min(lhs.getChildren().size(),rhs.getChildren().size()); i++) {
                rewriteRowValueConstuctorEqualityComparison(lhs.getChildren().get(i), rhs.getChildren().get(i), andNodes);
            }
            for (; i < lhs.getChildren().size(); i++) {
                rewriteRowValueConstuctorEqualityComparison(lhs.getChildren().get(i), null, andNodes);
            }
            for (; i < rhs.getChildren().size(); i++) {
                rewriteRowValueConstuctorEqualityComparison(null, rhs.getChildren().get(i), andNodes);
            }
        } else if (lhs instanceof RowValueConstructorParseNode) {
            rewriteRowValueConstuctorEqualityComparison(lhs.getChildren().get(0), rhs, andNodes);
            for (int i = 1; i < lhs.getChildren().size(); i++) {
                rewriteRowValueConstuctorEqualityComparison(lhs.getChildren().get(i), null, andNodes);
            }
        } else if (rhs instanceof RowValueConstructorParseNode) {
            rewriteRowValueConstuctorEqualityComparison(lhs, rhs.getChildren().get(0), andNodes);
            for (int i = 1; i < rhs.getChildren().size(); i++) {
                rewriteRowValueConstuctorEqualityComparison(null, rhs.getChildren().get(i), andNodes);
            }
        } else if (lhs == null && rhs == null) { // null == null will end up making the query degenerate
            andNodes.add(NODE_FACTORY.comparison(CompareOp.EQUAL, null, null).accept(this));
        } else if (lhs == null) { // AND rhs IS NULL
            andNodes.add(NODE_FACTORY.isNull(rhs, false).accept(this));
        } else if (rhs == null) { // AND lhs IS NULL
            andNodes.add(NODE_FACTORY.isNull(lhs, false).accept(this));
        } else { // AND lhs = rhs
            andNodes.add(NODE_FACTORY.comparison(CompareOp.EQUAL, lhs, rhs).accept(this));
        }
    }
   
    @Override
    public ParseNode visitLeave(final ComparisonParseNode node, List<ParseNode> nodes) throws SQLException {
        ParseNode normNode = leaveCompoundNode(node, nodes, new CompoundNodeFactory() {
            @Override
            public ParseNode createNode(List<ParseNode> children) {
                return NODE_FACTORY.comparison(node.getFilterOp(), children.get(0), children.get(1));
            }
        });
       
        CompareOp op = node.getFilterOp();
        if (op == CompareOp.EQUAL || op == CompareOp.NOT_EQUAL) {
            // Rewrite row value constructor in = or != expression, as this is the same as if it was
            // used in an equality expression for each individual part.
            ParseNode lhs = normNode.getChildren().get(0);
            ParseNode rhs = normNode.getChildren().get(1);
            if (lhs instanceof RowValueConstructorParseNode || rhs instanceof RowValueConstructorParseNode) {
                List<ParseNode> andNodes = Lists.newArrayListWithExpectedSize(Math.max(lhs.getChildren().size(), rhs.getChildren().size()));
                rewriteRowValueConstuctorEqualityComparison(lhs,rhs,andNodes);
                normNode = NODE_FACTORY.and(andNodes);
                if (op == CompareOp.NOT_EQUAL) {
                    normNode = NODE_FACTORY.not(normNode);
                }
            }
        }
        return normNode;
    }
   
    @Override
    public ParseNode visitLeave(final BetweenParseNode node, List<ParseNode> nodes) throws SQLException {
        return leaveCompoundNode(node, nodes, new CompoundNodeFactory() {
            @Override
            public ParseNode createNode(List<ParseNode> children) {
                if(node.isNegate()) {
                    return NODE_FACTORY.not(NODE_FACTORY.and(children));
                } else {
                    return NODE_FACTORY.and(children);
                }
            }
        });
    }
   
    @Override
    public ParseNode visit(ColumnParseNode node) throws SQLException {
        // If we're resolving aliases and we have an unqualified ColumnParseNode,
        // check if we find the name in our alias map.
        if (aliasMap != null && node.getTableName() == null) {
            ParseNode aliasedNode = aliasMap.get(node.getName());
            // If we found something, then try to resolve it unless the two nodes are the same
            if (aliasedNode != null && !node.equals(aliasedNode)) {
                try {
                    // If we're able to resolve it, that means we have a conflict
                    resolver.resolveColumn(node.getSchemaName(), node.getTableName(), node.getName());
                    throw new AmbiguousColumnException(node.getName());
                } catch (ColumnNotFoundException e) {
                    // Not able to resolve alias as a column name as well, so we use the alias
                    return aliasedNode;
                }
            }
        }
        return node;
    }

    /** Literals are returned as is; this base rewriter does not replace them. */
    @Override
    public ParseNode visit(LiteralParseNode node) throws SQLException {
        return node;
    }
   
    /** Bind parameters are returned as is; this base rewriter does not replace them. */
    @Override
    public ParseNode visit(BindParseNode node) throws SQLException {
        return node;
    }
   
    /** Wildcards are returned as is; this base rewriter does not replace them. */
    @Override
    public ParseNode visit(WildcardParseNode node) throws SQLException {
        return node;
    }
   
    /** Family wildcards are returned as is; this base rewriter does not replace them. */
    @Override
    public ParseNode visit(FamilyWildcardParseNode node) throws SQLException {
        return node;
    }
   
    @Override
    public List<ParseNode> newElementList(int size) {
        nodeCount += size;
        return new ArrayList<ParseNode>(size);
    }
   
    @Override
    public ParseNode visitLeave(StringConcatParseNode node, List<ParseNode> l) throws SQLException {
        return node;
    }

    @Override
    public void addElement(List<ParseNode> l, ParseNode element) {
        nodeCount--;
        if (element != null) {
            l.add(element);
        }
    }

    @Override
    public ParseNode visitLeave(RowValueConstructorParseNode node, List<ParseNode> children) throws SQLException {
        // Strip trailing nulls from rvc as they have no meaning
        if (children.get(children.size()-1) == null) {
            children = Lists.newArrayList(children);
            do {
                children.remove(children.size()-1);
            } while (children.size() > 0 && children.get(children.size()-1) == null);
            // If we're down to a single child, it's not a rvc anymore
            if (children.size() == 0) {
                return null;
            }
            if (children.size() == 1) {
                return children.get(0);
            }
        }
        // Flatten nested row value constructors, as this makes little sense and adds no information
        List<ParseNode> flattenedChildren = children;
        for (int i = 0; i < children.size(); i++) {
            ParseNode child = children.get(i);
            if (child instanceof RowValueConstructorParseNode) {
                if (flattenedChildren == children) {
                    flattenedChildren = Lists.newArrayListWithExpectedSize(children.size() + child.getChildren().size());
                    flattenedChildren.addAll(children.subList(0, i));
                }
                flattenedChildren.addAll(child.getChildren());
            }
        }
       
        return leaveCompoundNode(node, flattenedChildren, new CompoundNodeFactory() {
            @Override
            public ParseNode createNode(List<ParseNode> children) {
                return NODE_FACTORY.rowValueConstructor(children);
            }
        });
    }

  /** Sequence value references are returned as is; this base rewriter does not replace them. */
  @Override
  public ParseNode visit(SequenceValueParseNode node) throws SQLException {   
    return node;
  }

  @Override
  public ParseNode visitLeave(ArrayConstructorNode node, List<ParseNode> nodes) throws SQLException {
      return leaveCompoundNode(node, nodes, new CompoundNodeFactory() {
            @Override
            public ParseNode createNode(List<ParseNode> children) {
                return NODE_FACTORY.upsertStmtArrayNode(children);
            }
        });
  }
}
TOP

Related Classes of com.salesforce.phoenix.parse.ParseNodeRewriter$CompoundNodeFactory

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., now owned by Oracle Corporation. Contact coftware#gmail.com.