/**
 * Copyright (C) 2015 ZeroTurnaround OU
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.zeroturnaround.javarebel.integration.util;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

import org.zeroturnaround.bundled.javassist.CannotCompileException;
import org.zeroturnaround.bundled.javassist.ClassPool;
import org.zeroturnaround.bundled.javassist.CtBehavior;
import org.zeroturnaround.bundled.javassist.CtClass;
import org.zeroturnaround.bundled.javassist.CtConstructor;
import org.zeroturnaround.bundled.javassist.CtField;
import org.zeroturnaround.bundled.javassist.CtMethod;
import org.zeroturnaround.bundled.javassist.CtNewMethod;
import org.zeroturnaround.bundled.javassist.Modifier;
import org.zeroturnaround.bundled.javassist.NotFoundException;
import org.zeroturnaround.bundled.javassist.bytecode.MethodInfo;
import org.zeroturnaround.bundled.javassist.expr.ExprEditor;
import org.zeroturnaround.bundled.javassist.expr.MethodCall;
import org.zeroturnaround.javarebel.LoggerFactory;

/**
 * This util is for the happy path of CtClass methods if you don'l like it then don't use this method and catch your own exceptions
 *
 * @author Andres Luuk
 */
public final class JavassistUtil {

  public static void init() {
  }

  public static boolean hasClass(final ClassPool cp, final String className) {
    try {
      return cp.getOrNull(className) != null;
    }
    catch (RuntimeException e) {
      return false;
    }
  }

  public static boolean hasDeclaredMethod(final CtClass ctClass, final String methodName) {
    try {
      return ctClass != null && ctClass.getDeclaredMethod(methodName) != null;
    }
    catch (NotFoundException e) {
      return false;
    }
    catch (RuntimeException e) {
      return false;
    }
  }

  public static boolean hasDeclaredMethod(final ClassPool cp, final CtClass ctClass, final String methodName, final String[] param) {
    try {
      return ctClass.getDeclaredMethod(methodName, cp.get(param)) != null;
    }
    catch (NotFoundException e) {
      return false;
    }
    catch (RuntimeException e) {
      return false;
    }
  }

  public static boolean hasDeclaredMethod(final ClassPool cp, final String klazz, final String methodName) {
    try {
      CtClass ctClass = cp.get(klazz);
      return ctClass != null && ctClass.getDeclaredMethod(methodName) != null;
    }
    catch (NotFoundException e) {
      return false;
    }
    catch (RuntimeException e) {
      return false;
    }
  }

  public static boolean hasDeclaredMethod(final ClassPool cp, final String klazz, final String methodName, final String[] param) {
    try {
      CtClass ctClass = cp.get(klazz);
      return ctClass != null && ctClass.getDeclaredMethod(methodName, cp.get(param)) != null;
    }
    catch (NotFoundException e) {
      return false;
    }
    catch (RuntimeException e) {
      return false;
    }
  }

  public static boolean hasDeclaredConstructor(final ClassPool cp, final String klazz, final String... param) {
    try {
      CtClass ctClass = cp.get(klazz);
      return ctClass != null && ctClass.getDeclaredConstructor(cp.get(param)) != null;
    }
    catch (NotFoundException e) {
      return false;
    }
    catch (RuntimeException e) {
      return false;
    }
  }

  public static boolean hasDeclaredConstructor(final CtClass ctClass, final String... param) {
    try {
      return ctClass != null && ctClass.getDeclaredConstructor(ctClass.getClassPool().get(param)) != null;
    }
    catch (NotFoundException e) {
      return false;
    }
    catch (RuntimeException e) {
      return false;
    }
  }

  /**
   *
   * @return CtConstructor or null if not found
   */
  public static CtConstructor getDeclaredConstructor(final ClassPool cp, final String klazz, final String... param) {
    try {
      return getDeclaredConstructor(cp.get(klazz), param);
    }
    catch (NotFoundException e) {
      return null;
    }
    catch (RuntimeException e) {
      return null;
    }
  }

  /**
   *
   * @return CtConstructor or null if not found
   */
  public static CtConstructor getDeclaredConstructor(final CtClass ctClass, final String... param) {
    try {
      if (ctClass != null) {
        return ctClass.getDeclaredConstructor(ctClass.getClassPool().get(param));
      }
      return null;
    }
    catch (NotFoundException e) {
      return null;
    }
    catch (RuntimeException e) {
      return null;
    }
  }

  public static boolean hasMethod(CtClass ctClass, String methodName) {
    try {
      for (CtMethod ctMethod : ctClass.getMethods()) {
        if (methodName.equals(ctMethod.getName()))
          return true;
      }
      return false;
    }
    catch (RuntimeException e) {
      return false;
    }
  }

  public static boolean hasMethod(final CtClass ctClass, final String methodName, final String methodDesc) {
    try {
      return ctClass.getMethod(methodName, methodDesc) != null;
    }
    catch (NotFoundException e) {
      return false;
    }
    catch (RuntimeException e) {
      return false;
    }
  }

  /**
   * Returns the first class out of the given ones that is found in the classpool or {@code null} if none is found.
   */
  public static CtClass getFirstExistingClass(ClassPool cp, String... classNames) {
    for (String className : classNames) {
      try {
        return cp.get(className);
      }
      catch (NotFoundException ignored) {
      }
    }
    return null;
  }

  public static CtField getFirstExistingField(CtClass ctClass, String... fieldNames) {
    if (ctClass == null)
      return null;
    for (String name : fieldNames) {
      try {
        return ctClass.getDeclaredField(name);
      }
      catch (NotFoundException ignored) {
      }
    }
    return null;
  }

  public static CtMethod getFirstExistingMethod(CtClass ctClass, String... methodNames) {
    if (ctClass == null)
      return null;
    for (String name : methodNames) {
      try {
        return ctClass.getDeclaredMethod(name);
      }
      catch (NotFoundException ignored) {
      }
    }
    return null;
  }

  public static CtMethod getFirstExistingMethod(final CtMethod first, final ClassPool cp, final CtClass ctClass, final String methodName, Class<?>... param) {
    if (first == null) {
      try {
        String[] classNames = new String[param.length];
        for (int i = 0; i < classNames.length; i++) {
          classNames[i] = param[i].getName();
        }
        return getFirstExistingMethod(null, cp, ctClass, methodName, classNames);
      }
      catch (Exception e) {
      }
    }
    return first;
  }

  public static CtMethod getFirstExistingMethod(final CtMethod first, final ClassPool cp, final CtClass ctClass, final String methodName, String... param) {
    if (first == null) {
      try {
        return ctClass.getDeclaredMethod(methodName, cp.get(param));
      }
      catch (Exception e) {
      }
    }
    return first;
  }

  public static boolean hasDeclaredField(CtClass ctClass, String fieldName) {
    try {
      ctClass.getDeclaredField(fieldName);
      return true;
    }
    catch (NotFoundException e) {
      return false;
    }
    catch (RuntimeException e) {
      return false;
    }
  }

  public static boolean hasDeclaredField(final ClassPool cp, final String klazz, String fieldName) {
    try {
      CtClass ctClass = cp.get(klazz);
      return ctClass != null && ctClass.getDeclaredField(fieldName) != null;
    }
    catch (NotFoundException e) {
      return false;
    }
    catch (RuntimeException e) {
      return false;
    }
  }

  public static CtField getDeclaredField(CtClass ctClass, String fieldName) {
    try {
      return ctClass.getDeclaredField(fieldName);
    }
    catch (NotFoundException e) {
      return null;
    }
  }

  public static boolean hasMethodOnField(CtClass ctClass, String fieldName, String methodName, String... paramClasses) {
    try {
      CtField field = ctClass.getField(fieldName);
      return field.getType().getDeclaredMethod(methodName, ctClass.getClassPool().get(paramClasses)) != null;
    }
    catch (NotFoundException e) {
      return false;
    }
    catch (RuntimeException e) {
      return false;
    }
  }

  public static boolean implementsInterface(CtClass ctClass, String ifaceName) throws NotFoundException {
    try {
      CtClass[] ifaces = ctClass.getInterfaces();
      for (CtClass iface : ifaces) {
        if (iface.getName().equals(ifaceName) || implementsInterface(iface, ifaceName))
          return true;
      }
      CtClass superclass = ctClass.getSuperclass();
      if (superclass != null && !superclass.getName().equals(Object.class.getName()))
        return implementsInterface(superclass, ifaceName);
    }
    catch (RuntimeException e) {
      // It might be that CtClass is made, i.e. ClassPool finds the class resource, but opening a stream from the resource fails.
      // In that case get a wrapped NotFoundException inside RuntimeException
      // anyway, lets continue checking the other deps, maybe those are not broken
      LoggerFactory.getLogger("Javassit").warn("Checking {} failed with: {}", ctClass.getName(), e.getMessage());
    }
    return false;
  }

  /**
   * Finds a declared method by name and params.
   *
   * @return CtMethod or null if method or parameter class not found
   */
  public static CtMethod getDeclaredMethod(CtClass ctClass, String methodName, String... paramClasses) {
    try {
      return ctClass.getDeclaredMethod(methodName, ctClass.getClassPool().get(paramClasses));
    }
    catch (NotFoundException e) {
      return null;
    }
  }

  /**
   * Finds a declared method with the given name and descriptor prefix.
   */
  public static CtMethod getDeclaredMethodByDescriptor(CtClass ctClass, String methodName, String descriptorPrefix) throws NotFoundException {
    for (CtMethod method : ctClass.getDeclaredMethods(methodName)) {
      if (method.getSignature().startsWith(descriptorPrefix))
        return method;
    }
    throw new NotFoundException(methodName);
  }

  /**
   * Finds a declared method by name .
   *
   * @return CtMethod or null if method was not found
   */
  public static CtMethod getDeclaredMethod(CtClass ctClass, String methodName) {
    try {
      return ctClass.getDeclaredMethod(methodName);
    }
    catch (NotFoundException e) {
      return null;
    }
  }

  /**
   * Finds a declared method by name and params.
   *
   * @return CtMethod or null if method or parameter class not found
   */
  public static CtMethod getDeclaredMethod(CtClass ctClass, String methodName, Class<?>... paramClasses) {
    String[] params = new String[paramClasses.length];
    for (int i = 0; i < paramClasses.length; i++) {
      params[i] = paramClasses[i].getName();
    }
    return getDeclaredMethod(ctClass, methodName, params);
  }

  public static class MethodCallReplaceEditor extends ExprEditor {
    private final String methodName;
    private final String replaceStatement;

    public MethodCallReplaceEditor(String methodName, String replaceStatement) {
      this.methodName = methodName;
      this.replaceStatement = replaceStatement;
    }

    @Override
    public void edit(MethodCall m) throws CannotCompileException {
      if (m.getMethodName().equals(methodName)) {
        m.replace(replaceStatement);
      }
    }
  }

  private static class ValidatingMethodCallReplaceEditor extends ExprEditor {
    private final String methodName;
    private final String replaceStatement;
    private final int expectedMatchCount;
    private int matchCount = 0;

    public ValidatingMethodCallReplaceEditor(String methodName, String replaceStatement, int matchCount) {
      this.methodName = methodName;
      this.replaceStatement = replaceStatement;
      this.expectedMatchCount = matchCount;
    }

    @Override
    public void edit(MethodCall m) throws CannotCompileException {
      if (m.getMethodName().equals(methodName)) {
        matchCount++;
        m.replace(replaceStatement);
      }
    }

    @Override
    public boolean doit(CtClass clazz, MethodInfo minfo) throws CannotCompileException {
      boolean result = super.doit(clazz, minfo);
      if (matchCount != expectedMatchCount) {
        throw new CannotCompileException("Expected " + expectedMatchCount + " method calls to " + methodName +
            " in " + clazz.getName() + "#" + minfo + ", found " + matchCount);
      }
      return result;
    }
  }

  public static ExprEditor strictMethodCallReplacer(String methodName, String replaceStatement) {
    return new ValidatingMethodCallReplaceEditor(methodName, replaceStatement, 1);
  }

  /**
   * Checks if method with given name is visible from caller
   */
  public static boolean isMethodVisible(CtClass caller, String name) {
    CtClass declaring = caller;
    while (declaring != null) {
      try {
        int modifiers = declaring.getDeclaredMethod(name).getModifiers();
        if (declaring == caller || !Modifier.isPrivate(modifiers))
          return true;
      }
      catch (NotFoundException ignored) {
      }
      try {
        declaring = declaring.getSuperclass();
      }
      catch (NotFoundException e) {
        return false;
      }
    }
    return false;
  }

  /**
   * Checks if field with given name is visible from user
   */
  public static boolean isFieldVisible(CtClass user, String name) {
    CtClass declaring = user;
    while (declaring != null) {
      try {
        int modifiers = declaring.getDeclaredField(name).getModifiers();
        if (declaring == user || !Modifier.isPrivate(modifiers))
          return true;
      }
      catch (NotFoundException ignored) {
      }
      try {
        declaring = declaring.getSuperclass();
      }
      catch (NotFoundException e) {
        return false;
      }
    }
    return false;
  }

  public static void removeFinalModifer(CtClass ctClass, String field) {
    try {
      CtField f = ctClass.getDeclaredField(field);
      f.setModifiers(f.getModifiers() & ~Modifier.FINAL);
    }
    catch (NotFoundException ignored) {
    }
  }

  public static void makeFieldPublic(CtClass ctClass, String field) {
    try {
      CtField f = ctClass.getDeclaredField(field);
      f.setModifiers(Modifier.setPublic(f.getModifiers()));
    }
    catch (NotFoundException ignored) {
    }
  }

  public static void makeMethodSynthetic(CtMethod method) {
    final int ACC_SYNTHETIC = 0x1000; // From Opcodes.ACC_SYNTHETIC
    method.setModifiers(method.getModifiers() | ACC_SYNTHETIC);
  }

  /**
   * Initializes a named local var from a param. Type is copied.
   *
   * @param param is 1-based as in javassist $1, $2..
   */
  public static void addLocalForParameter(CtMethod ctMethod, int param, String name) throws NotFoundException, CannotCompileException {
    CtClass type = ctMethod.getParameterTypes()[param - 1];
    ctMethod.addLocalVariable(name, type);
    ctMethod.insertBefore(name + " = $" + param + ";");
  }

  /**
   * Initializes a named local var from a param. Type is copied.
   *
   * @param paramClass the first parameter with that class is used as value
   */
  public static void addLocalForParameter(CtMethod ctMethod, String paramClass, String name) throws NotFoundException, CannotCompileException {
    for (int i = 0; i < ctMethod.getParameterTypes().length; i++) {
      CtClass type = ctMethod.getParameterTypes()[i];
      if (paramClass.equals(type.getName())) {
        addLocalForParameter(ctMethod, i + 1, name);
        return;
      }
    }
    throw new NotFoundException("Param from class '" + paramClass + "'not found:");
  }

  /**
   * Wraps method in stopwatch.
   */
  public static void createMethodStopWatch(CtMethod method, String logPrefix, String stopWatchId) throws CannotCompileException, NotFoundException {
    if (!LoggerFactory.getLogger(logPrefix).isEnabled() || method == null) {
      return;
    }
    CtClass ctClass = method.getDeclaringClass();
    String methodName = "__" + method.getName();
    CtMethod newMethod = CtNewMethod.copy(method, methodName, ctClass, null);
    newMethod.setModifiers((newMethod.getModifiers() & ~(Modifier.PUBLIC | Modifier.PROTECTED)) | Modifier.PRIVATE);
    ctClass.addMethod(newMethod);
    CtClass returnType = method.getReturnType();
    String returnLine = CtClass.voidType.equals(returnType) ? "" : "return";
    method.setBody("" +
        "{" +
        "  org.zeroturnaround.javarebel.StopWatch jrStopWatch = " +
        "      org.zeroturnaround.javarebel.LoggerFactory.getLogger(\"" + logPrefix + "\").createStopWatch(\"" + stopWatchId + "\");" +
        "  try {" +
        "   " + returnLine + " " + methodName + "($$);" +
        "  } finally {" +
        "    jrStopWatch.stop();" +
        "  }" +
        "}");
  }

  public static boolean callsMethod(CtMethod ctMethod, final String calledMethodName) throws CannotCompileException {
    final boolean[] found = new boolean[] { false };
    ctMethod.instrument(new ExprEditor() {
      public void edit(MethodCall m) throws CannotCompileException {
        if (m.getMethodName().equals(calledMethodName)) {
          found[0] = true;
        }
      }
    });

    return found[0];
  }

  public static void importUsedPackage(final ClassPool cp, final CtClass ctClass) {
    final Set<String> allPackages = new HashSet<String>();
    for (Iterator<String> it = cp.getImportedPackages(); it.hasNext();) {
      allPackages.add(it.next());
    }
    allPackages.add(null);
    allPackages.add("");
    for (String referencedClass : ctClass.getRefClasses()) {
      if (referencedClass.startsWith("java.")) {
        continue;
      }

      int lastDot = referencedClass.lastIndexOf('.');
      if (lastDot == -1) {
        continue;
      }

      String p = referencedClass.substring(0, lastDot);
      if (allPackages.add(p)) {
        cp.importPackage(p);
      }
    }
  }

  public static void importSignatureUsedPackage(final ClassPool cp, final CtBehavior ctTarget) throws NotFoundException {
    final Set<String> allPackages = new HashSet<String>();
    for (Iterator<String> it = cp.getImportedPackages(); it.hasNext();) {
      allPackages.add(it.next());
    }
    allPackages.add(null);
    allPackages.add("");
    for (CtClass referencedClass : ctTarget.getParameterTypes()) {
      if (allPackages.add(referencedClass.getPackageName())) {
        cp.importPackage(referencedClass.getPackageName());
      }
    }
    if (ctTarget instanceof CtMethod && allPackages.add(((CtMethod) ctTarget).getReturnType().getPackageName())) {
      cp.importPackage(((CtMethod) ctTarget).getReturnType().getPackageName());
    }
    for (CtClass exceptionType : ctTarget.getExceptionTypes()) {
      final String exceptionName = exceptionType.getPackageName();
      if (allPackages.add(exceptionName)) {
        cp.importPackage(exceptionName);
      }
    }
  }

  public static String getReturnTypeFromSignature(final String signature) {
    int l = signature.lastIndexOf(")L");
    return signature.substring(l + 2, signature.length() - 1).replace('/', '.');
  }

  // Based on ASM's Type.getArgumentTypes
  public static List<String> getArgumentTypes(final String methodDescriptor) {
    List<String> argumentTypes = new ArrayList<String>();
    // Skip the first character, which is always a '('.
    int currentOffset = 1;
    // Parse and create the argument types, one at each loop iteration.
    while (methodDescriptor.charAt(currentOffset) != ')') {
      final int currentArgumentTypeOffset = currentOffset;
      while (methodDescriptor.charAt(currentOffset) == '[') {
        currentOffset++;
      }
      if (methodDescriptor.charAt(currentOffset++) == 'L') {
        // Skip the argument descriptor content.
        int semiColumnOffset = methodDescriptor.indexOf(';', currentOffset);
        currentOffset = Math.max(currentOffset, semiColumnOffset + 1);
      }
      argumentTypes.add(methodDescriptor.substring(currentArgumentTypeOffset, currentOffset));
    }
    return argumentTypes;
  }
}
