/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kylin.query.util;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.core.Values;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexSlot;
import org.apache.calcite.rex.RexVisitor;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.DateString;
import org.apache.calcite.util.TimestampString;
import org.apache.commons.lang3.StringUtils;
import org.apache.kylin.guava30.shaded.common.base.Preconditions;
import org.apache.kylin.guava30.shaded.common.collect.Lists;
import org.apache.kylin.metadata.datatype.DataType;
import org.apache.kylin.metadata.model.TblColRef;
import org.apache.kylin.query.relnode.ContextUtil;
import org.apache.kylin.query.relnode.OlapAggregateRel;
import org.apache.kylin.query.relnode.OlapJoinRel;
import org.apache.kylin.query.relnode.OlapProjectRel;
import org.apache.kylin.query.relnode.OlapTableScan;

/**
 * Static helpers for inspecting and rewriting Calcite {@link RexNode} expressions used during query planning.
 */
public class RexUtils {
    private RexUtils() {
    }

    /**
     * Checks whether a join's condition uses, on either of its inputs, columns that come from more than one side
     * of a join nested beneath that input.
     *
     * @param join the join whose condition is inspected
     * @return {@code true} if the condition columns belonging to the left input, or those belonging to the right
     *         input, cannot be traced back to a single side of every nested join below it
     */
    public static boolean joinMoreThanOneTable(Join join) {
        Set<Integer> left = new HashSet<>();
        Set<Integer> right = new HashSet<>();
        Set<Integer> indexes = getAllInputRefs(join.getCondition()).stream()
                .map(RexSlot::getIndex)
                .collect(Collectors.toSet());
        splitJoinInputIndex(join, indexes, left, right);
        return !colsComeFromSameSideOfJoin(join.getLeft(), left)
                || !colsComeFromSameSideOfJoin(join.getRight(), right);
    }

    /**
     * Walks down from {@code rel} and checks that the given output column indexes never straddle both inputs of
     * a join. Projections are resolved to the input columns they reference.
     *
     * @param rel     the node whose output columns are checked
     * @param indexes output column indexes of {@code rel}
     * @return {@code false} as soon as some join below receives columns from both of its inputs
     */
    private static boolean colsComeFromSameSideOfJoin(RelNode rel, Set<Integer> indexes) {
        if (rel instanceof Join) {
            Join join = (Join) rel;
            Set<Integer> left = new HashSet<>();
            Set<Integer> right = new HashSet<>();
            splitJoinInputIndex(join, indexes, left, right);
            if (left.isEmpty()) {
                return colsComeFromSameSideOfJoin(join.getRight(), right);
            }
            if (right.isEmpty()) {
                return colsComeFromSameSideOfJoin(join.getLeft(), left);
            }
            return false;
        }
        if (rel instanceof Project) {
            Project project = (Project) rel;
            // Map each projected column to every input column its expression references.
            Set<Integer> inputIndexes = indexes.stream()
                    .map(idx -> project.getProjects().get(idx))
                    .flatMap(rex -> getAllInputRefs(rex).stream())
                    .map(RexSlot::getIndex)
                    .collect(Collectors.toSet());
            return colsComeFromSameSideOfJoin(project.getInput(), inputIndexes);
        }
        if (rel instanceof TableScan || rel instanceof Values || rel.getInputs().isEmpty()) {
            // A leaf is a single source, so its columns trivially come from one side.
            return true;
        }
        // Indexes are forwarded unchanged to the first input. This assumes the node keeps its input's field
        // positions. NOTE(review): confirm this for node types such as Aggregate.
        return colsComeFromSameSideOfJoin(rel.getInput(0), indexes);
    }

    /**
     * Splits join output column indexes into left-input and right-input indexes. Right-side indexes are rebased
     * so that they are relative to the right input.
     *
     * @param joinRel           the join whose left field count decides the split
     * @param indexes           join output column indexes
     * @param leftInputIndexes  receives indexes that belong to the left input (unchanged)
     * @param rightInputIndexes receives indexes that belong to the right input, minus the left field count
     */
    public static void splitJoinInputIndex(Join joinRel, Collection<Integer> indexes, Set<Integer> leftInputIndexes,
            Set<Integer> rightInputIndexes) {
        int leftFieldCount = joinRel.getLeft().getRowType().getFieldCount();
        for (Integer idx : indexes) {
            if (idx < leftFieldCount) {
                leftInputIndexes.add(idx);
            } else {
                rightInputIndexes.add(idx - leftFieldCount);
            }
        }
    }

    /**
     * Counts the calls in {@code condition}, including nested ones, whose operator's runtime class is exactly
     * {@code sqlOperator}. Subclasses of that class are not counted.
     *
     * @param condition   expression to scan
     * @param sqlOperator exact operator class to match
     * @return number of matching calls
     */
    public static int countOperatorCall(RexNode condition, final Class<? extends SqlOperator> sqlOperator) {
        final AtomicInteger matchCount = new AtomicInteger(0);
        RexVisitorImpl<Void> callCounter = new RexVisitorImpl<Void>(true) {
            @Override
            public Void visitCall(RexCall call) {
                if (call.getOperator().getClass().equals(sqlOperator)) {
                    matchCount.incrementAndGet();
                }
                return super.visitCall(call);
            }
        };
        condition.accept(callCounter);
        return matchCount.get();
    }

    /**
     * Collects every {@link RexInputRef} reachable from {@code rexNode}, descending only through
     * {@link RexCall} operands.
     *
     * @param rexNode expression to scan
     * @return the input references found; empty for literals and other non-call leaves
     */
    public static Set<RexInputRef> getAllInputRefs(RexNode rexNode) {
        if (rexNode instanceof RexInputRef) {
            return Collections.singleton((RexInputRef) rexNode);
        }
        if (rexNode instanceof RexCall) {
            return getAllInputRefsCall((RexCall) rexNode);
        }
        return Collections.emptySet();
    }

    /** Unions the input references of all operands of {@code rexCall}. */
    private static Set<RexInputRef> getAllInputRefsCall(RexCall rexCall) {
        return rexCall.getOperands().stream()
                .flatMap(rexNode -> getAllInputRefs(rexNode).stream())
                .collect(Collectors.toSet());
    }

    /**
     * Checks whether the given output columns of {@code rel} resolve to plain column references all the way
     * down the tree.
     * <ul>
     * <li>Projections may only pass a column through, optionally wrapped in a single CAST.</li>
     * <li>Aggregates may only expose group keys.</li>
     * <li>Joins are split per input.</li>
     * <li>Any other node forwards the same indexes to each of its inputs.</li>
     * </ul>
     *
     * @param rel           root of the tree to inspect
     * @param columnIndexes output column indexes of {@code rel}
     * @return {@code true} if every column is merely a table column reference
     */
    public static boolean isMerelyTableColumnReference(RelNode rel, Collection<Integer> columnIndexes) {
        if (rel instanceof OlapProjectRel) {
            return isProjectMerelyTableColumnReference((OlapProjectRel) rel, columnIndexes);
        }
        if (rel instanceof OlapAggregateRel) {
            return isAggMerelyTableColumnReference((OlapAggregateRel) rel, columnIndexes);
        }
        if (rel instanceof OlapJoinRel) {
            return isJoinMerelyTableColumnReference(rel, columnIndexes);
        }
        for (RelNode inputRel : rel.getInputs()) {
            if (!isMerelyTableColumnReference(inputRel, columnIndexes)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Join case of {@link #isMerelyTableColumnReference(RelNode, Collection)}. Each input receives the indexes
     * that fall inside its field range, rebased to that input.
     */
    private static boolean isJoinMerelyTableColumnReference(RelNode rel, Collection<Integer> columnIndexes) {
        int offset = 0;
        for (RelNode inputRel : rel.getInputs()) {
            int fieldCount = inputRel.getRowType().getFieldCount();
            Set<Integer> nextInputRefKeys = new HashSet<>();
            for (Integer columnIdx : columnIndexes) {
                int local = columnIdx - offset;
                if (local >= 0 && local < fieldCount) {
                    nextInputRefKeys.add(local);
                }
            }
            if (!isMerelyTableColumnReference(inputRel, nextInputRefKeys)) {
                return false;
            }
            offset += fieldCount;
        }
        return true;
    }

    /**
     * Aggregate case of {@link #isMerelyTableColumnReference(RelNode, Collection)}. Each index must be a position
     * within the rewrite group keys, which are then mapped to input columns. Any index past the group keys
     * (presumably an aggregate-call output) yields {@code false}.
     */
    private static boolean isAggMerelyTableColumnReference(OlapAggregateRel agg, Collection<Integer> columnIndexes) {
        Set<Integer> nextInputRefKeys = new HashSet<>();
        for (Integer columnIdx : columnIndexes) {
            if (columnIdx >= agg.getRewriteGroupKeys().size()) {
                return false;
            }
            nextInputRefKeys.add((Integer) agg.getRewriteGroupKeys().get(columnIdx.intValue()));
        }
        return isMerelyTableColumnReference(agg.getInput(), nextInputRefKeys);
    }

    /**
     * Project case of {@link #isMerelyTableColumnReference(RelNode, Collection)}. Each projected expression must be
     * an input reference, optionally wrapped in one CAST.
     */
    private static boolean isProjectMerelyTableColumnReference(OlapProjectRel project,
            Collection<Integer> columnIndexes) {
        Set<Integer> nextInputRefKeys = new HashSet<>();
        for (Integer columnIdx : columnIndexes) {
            RexNode projExp = project.getProjects().get(columnIdx);
            if (projExp.getKind() == SqlKind.CAST) {
                projExp = ((RexCall) projExp).getOperands().get(0);
            }
            if (!(projExp instanceof RexInputRef)) {
                return false;
            }
            nextInputRefKeys.add(((RexInputRef) projExp).getIndex());
        }
        return isMerelyTableColumnReference(project.getInput(), nextInputRefKeys);
    }

    /**
     * Checks whether every column referenced by {@code condition} is merely a table column reference below
     * {@code rel}.
     *
     * @param rel       the join the condition is evaluated against
     * @param condition the join condition
     * @return see {@link #isMerelyTableColumnReference(RelNode, Collection)}
     */
    public static boolean isMerelyTableColumnReference(OlapJoinRel rel, RexNode condition) {
        Set<Integer> indexes = getAllInputRefs(condition).stream()
                .map(RexSlot::getIndex)
                .collect(Collectors.toSet());
        return isMerelyTableColumnReference((RelNode) rel, indexes);
    }

    /**
     * For an {@code =} predicate, removes a CAST whose argument is a column reference from each operand, so that
     * {@code CAST($0) = $1} becomes {@code $0 = $1}.
     *
     * @param predicateNode any expression; non-calls and non-EQUALS calls are returned unchanged
     * @return the rewritten predicate, or the original when nothing was stripped
     */
    public static RexNode stripOffCastInColumnEqualPredicate(RexNode predicateNode) {
        if (!(predicateNode instanceof RexCall)) {
            return predicateNode;
        }
        RexCall predicate = (RexCall) predicateNode;
        if (predicate.getKind() != SqlKind.EQUALS) {
            return predicate;
        }
        boolean stripped = false;
        List<RexNode> predicateOperands = Lists.newArrayList(predicate.getOperands());
        for (int i = 0; i < predicateOperands.size(); ++i) {
            RexNode child = predicateOperands.get(i);
            if (child instanceof RexCall && child.getKind() == SqlKind.CAST) {
                RexNode castArg = ((RexCall) child).getOperands().get(0);
                if (castArg instanceof RexInputRef) {
                    predicateOperands.set(i, castArg);
                    stripped = true;
                }
            }
        }
        return stripped ? predicate.clone(predicate.getType(), predicateOperands) : predicate;
    }

    /**
     * Converts a string value into a literal whose type matches the column's data type.
     *
     * @param rexBuilder builder used to create types and literals
     * @param value      textual value; {@code colType.parseValue} is applied to it for every type
     * @param colType    the column's data type
     * @return the literal
     * @throws IllegalArgumentException if the data type cannot be mapped to a SQL type or the literal cannot be
     *                                  built
     */
    public static RexNode transformValue2RexLiteral(RexBuilder rexBuilder, String value, DataType colType) {
        // Parsed for every type, including date/timestamp, whose literals are built from the raw string below.
        Object parsedValue = colType.parseValue(value);
        switch (colType.getName()) {
            case "date": {
                String[] splits = StringUtils.split(value.trim(), " ");
                Preconditions.checkArgument(splits.length >= 1, "split %s with error", value);
                return rexBuilder.makeDateLiteral(new DateString(splits[0]));
            }
            case "datetime":
            case "timestamp":
                return makeTimestampLiteral(rexBuilder, value);
            case "string":
            case "varchar":
                return makeTypedLiteral(rexBuilder, parsedValue, SqlTypeName.VARCHAR, colType.getPrecision());
            case "numeric":
            case "decimal":
                return makeTypedLiteral(rexBuilder, parsedValue, SqlTypeName.DECIMAL, colType.getPrecision());
            case "byte":
                return makeTypedLiteral(rexBuilder, parsedValue, SqlTypeName.TINYINT, colType.getPrecision());
            case "int":
            case "int4":
                return makeTypedLiteral(rexBuilder, parsedValue, SqlTypeName.INTEGER, colType.getPrecision());
            case "short":
                return makeTypedLiteral(rexBuilder, parsedValue, SqlTypeName.SMALLINT, colType.getPrecision());
            case "long":
            case "long8":
                return makeTypedLiteral(rexBuilder, parsedValue, SqlTypeName.BIGINT, colType.getPrecision());
            default:
                return makeLiteralByTypeName(rexBuilder, parsedValue, colType);
        }
    }

    /**
     * Builds a TIMESTAMP literal from {@code value}.
     * <ul>
     * <li>The value is trimmed first.</li>
     * <li>A fractional-seconds part made only of zeros is dropped.</li>
     * <li>A value with no time part gets {@code " 00:00:00"} appended.</li>
     * </ul>
     */
    private static RexNode makeTimestampLiteral(RexBuilder rexBuilder, String value) {
        RelDataType relDataType = rexBuilder.getTypeFactory().createSqlType(SqlTypeName.TIMESTAMP);
        String ts = value.trim();
        int dotIndex = ts.indexOf('.');
        // A character check instead of Integer.parseInt, which overflows on long fractions.
        if (dotIndex != -1 && isAllZeros(ts.substring(dotIndex + 1))) {
            ts = ts.substring(0, dotIndex);
        }
        if (StringUtils.split(ts, " ").length == 1) {
            ts = ts + " 00:00:00";
        }
        return rexBuilder.makeTimestampLiteral(new TimestampString(ts), relDataType.getPrecision());
    }

    /** Returns {@code true} if {@code digits} contains only '0' characters. An empty string counts as zero. */
    private static boolean isAllZeros(String digits) {
        return digits.chars().allMatch(c -> c == '0');
    }

    /** Creates a literal of SQL type {@code typeName} with the given precision. */
    private static RexNode makeTypedLiteral(RexBuilder rexBuilder, Object parsedValue, SqlTypeName typeName,
            int precision) {
        RelDataType relDataType = rexBuilder.getTypeFactory().createSqlType(typeName, precision);
        return rexBuilder.makeLiteral(parsedValue, relDataType, false);
    }

    /**
     * Fallback for data types without a dedicated case. The SQL type is looked up by the upper-cased type name,
     * and a precision of -1 means "no precision".
     */
    private static RexNode makeLiteralByTypeName(RexBuilder rexBuilder, Object parsedValue, DataType colType) {
        SqlTypeName sqlTypeName = SqlTypeName.get(colType.getName().toUpperCase(Locale.ROOT));
        if (sqlTypeName == null) {
            throw new IllegalArgumentException(
                    String.format(Locale.ROOT, "%s data type is not supported for filter column", colType));
        }
        try {
            int precision = colType.getPrecision();
            RelDataType relDataType = precision == -1
                    ? rexBuilder.getTypeFactory().createSqlType(sqlTypeName)
                    : rexBuilder.getTypeFactory().createSqlType(sqlTypeName, precision);
            return rexBuilder.makeLiteral(parsedValue, relDataType, false);
        } catch (Exception e) {
            throw new IllegalArgumentException(
                    String.format(Locale.ROOT, "%s data type is not supported for filter column", colType), e);
        }
    }

    /**
     * Resolves a column to an input reference that is unique among all the given table scans.
     *
     * @param tblColRef  the column to resolve
     * @param tableScans candidate scans; every scan whose table name matches is searched
     * @return the input reference for the first scan that contains the column
     * @throws IllegalStateException if no scan contains the column
     */
    public static RexInputRef transformColumn2RexInputRef(TblColRef tblColRef, Set<OlapTableScan> tableScans) {
        for (OlapTableScan tableScan : tableScans) {
            if (!tableScan.getTableName().equals(tblColRef.getTable())) {
                continue;
            }
            int index = tableScan.getColumnRowType().getAllColumns().indexOf(tblColRef);
            if (index >= 0) {
                return ContextUtil.createUniqueInputRefAmongTables(tableScan, index, tableScans);
            }
            // Several scans may share a table name (presumably e.g. a self join), so keep searching instead of
            // failing on the first name match.
        }
        throw new IllegalStateException(
                String.format(Locale.ROOT, "Cannot find column %s in all tableScans", tblColRef.getIdentity()));
    }

    /**
     * Recursively normalizes binary calls so that equivalent expressions share one shape.
     * <ul>
     * <li>An operator whose reversed kind sorts before its own kind is replaced by its reverse operator, with the
     *     operands swapped.</li>
     * <li>Operands of {@code SqlKind.SYMMETRICAL_SAME_ARG_TYPE} calls are ordered by kind, then by hash code.</li>
     * </ul>
     * Calls with any other number of operands only have their operands normalized.
     *
     * @param rexBuilder builder used when a reversed operator call has to be created
     * @param rexNode    expression to normalize; non-calls are returned unchanged
     * @return the normalized expression
     */
    public static RexNode symmetricalExchange(RexBuilder rexBuilder, RexNode rexNode) {
        if (!(rexNode instanceof RexCall)) {
            return rexNode;
        }
        RexCall call = (RexCall) rexNode;
        SqlOperator operator = call.getOperator();
        List<RexNode> operands = call.getOperands();
        if (operands.size() != 2) {
            List<RexNode> newOperands = operands.stream()
                    .map(rex -> symmetricalExchange(rexBuilder, rex))
                    .collect(Collectors.toList());
            return call.clone(call.getType(), newOperands);
        }
        RexNode operand0 = operands.get(0);
        RexNode operand1 = operands.get(1);
        RexNode newOperand0 = symmetricalExchange(rexBuilder, operand0);
        RexNode newOperand1 = symmetricalExchange(rexBuilder, operand1);
        SqlKind kind = operator.getKind();
        if (kind.reverse().compareTo(kind) < 0) {
            SqlOperator reverseOp = operator.reverse();
            if (reverseOp == null) {
                return call.clone(call.getType(), Arrays.asList(newOperand0, newOperand1));
            }
            return rexBuilder.makeCall(call.getType(), reverseOp, Arrays.asList(newOperand1, newOperand0));
        }
        if (rexNode.isA(SqlKind.SYMMETRICAL_SAME_ARG_TYPE) && reorderOperands(operand0, operand1) < 0) {
            return call.clone(call.getType(), Arrays.asList(newOperand1, newOperand0));
        }
        // Reuse the already-normalized operands. Recursing again would repeat the work at every level, which is
        // exponential in nesting depth.
        return call.clone(call.getType(), Arrays.asList(newOperand0, newOperand1));
    }

    /**
     * Orders two operands by kind, then by reversed hash-code order. Only the sign of the result is meaningful.
     * {@link Integer#compare} is used so that large hash codes cannot overflow and flip the sign.
     */
    private static int reorderOperands(RexNode operand0, RexNode operand1) {
        int byKind = operand0.getKind().compareTo(operand1.getKind());
        return byKind != 0 ? byKind : Integer.compare(operand1.hashCode(), operand0.hashCode());
    }
}

