/**
 * Copyright (C) 2010 ZeroTurnaround OU
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.zeroturnaround.javarebel.integration.support;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Method;
import java.net.JarURLConnection;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLConnection;
import java.net.URLStreamHandler;
import java.util.Map;
import java.util.jar.Attributes;
import java.util.jar.Manifest;

import org.zeroturnaround.bundled.javassist.ByteArrayClassPath;
import org.zeroturnaround.bundled.javassist.CannotCompileException;
import org.zeroturnaround.bundled.javassist.ClassPath;
import org.zeroturnaround.bundled.javassist.ClassPool;
import org.zeroturnaround.bundled.javassist.CtClass;
import org.zeroturnaround.bundled.javassist.CtField;
import org.zeroturnaround.bundled.javassist.NotFoundException;
import org.zeroturnaround.bundled.javassist.bytecode.Descriptor;
import org.zeroturnaround.javarebel.ClassBytecodeProcessor;
import org.zeroturnaround.javarebel.IntegrationFactory;
import org.zeroturnaround.javarebel.Logger;
import org.zeroturnaround.javarebel.LoggerFactory;
import org.zeroturnaround.javarebel.StopWatch;
import org.zeroturnaround.javarebel.integration.util.MiscUtil;
import org.zeroturnaround.javarebel.integration.util.ReflectionUtil;
import org.zeroturnaround.javarebel.integration.util.SecurityController;
import org.zeroturnaround.javarebel.integration.util.SecurityController.PrivilegedAction;

/**
 * Javassist-based bytecode processor callback.
 * 
 * <p>
 * Any sub class is required to implement the
 * {@link #process(ClassPool, ClassLoader, CtClass)} method.
 *
 * <p>
 * For every processed class a fresh {@link RebelClassPool} is built, the class is
 * parsed into a {@link CtClass}, handed to the subclass and serialized back. When
 * processing fails, the failure is logged and the original bytecode is returned
 * unchanged (JRebel's own linkage errors and {@link ClassCircularityError}s are rethrown).
 */
public abstract class JavassistClassBytecodeProcessor implements ClassBytecodeProcessor {

  private static final Logger log = LoggerFactory.getLogger("SDK-CBP");

  /**
   * Name of the synthetic field holding the comma-separated list of CBP class names
   * already applied to a class (used by {@link #withDuplicatePatchingProtection()}).
   */
  private static final String FIELD = "_jr$$jrAlreadyPatched";

  private Map<String, String> implicitClassNames = null;
  private boolean avoidDuplicatePatching = false;

  static {
    RebelClassPool.init();
  }

  /**
   * Processes the raw bytecode of a class that is being loaded.
   *
   * @param cl the class loader loading the class, may be {@code null}.
   * @param name the class name, '/'- or '.'-separated.
   * @param bytecode the original bytecode.
   * @return the processed bytecode, or {@code bytecode} unchanged when the name is empty
   *         or processing failed.
   */
  @Override
  public byte[] process(final ClassLoader cl, String name, byte[] bytecode) {
    if (name == null || name.length() == 0) {
      log.warn("skipping CBP because classname is empty; classloader: " + MiscUtil.identityToString(cl));
      return bytecode;
    }

    final String classname = name.replace('/', '.');
    RebelClassPool cp = buildClassPool(cl, classname, bytecode);

    StopWatch sw = log.createStopWatch("JavassistClassBytecodeProcessor");
    try {
      return process(cp, cl, classname, bytecode);
    }
    catch (Throwable e) {
      // JRebel's own linkage errors and class circularity errors must reach the caller
      // instead of being treated as an ordinary CBP failure
      if ((e instanceof LinkageError && e.getClass().getName().startsWith("com.zeroturnaround.")) || e instanceof ClassCircularityError)
        throw (Error) e;
      log.errorEcho("Class '" + classname + "' could not be processed by " + MiscUtil.identityToString(getClass()), e);
      // Hackish solution to dump the bytecode of the failed class into the rebel.log for our debugging
      log.error(name, "DUMP", bytecode);
      // only failures of JRebel's own processors are reported
      String myName = getClass().getName();
      if (myName.startsWith("org.zeroturnaround.") || myName.startsWith("com.zeroturnaround."))
        reportCBPFailure(e, cl, classname);
      return bytecode;
    }
    finally {
      cp.clearCache();
      sw.stop();
    }
  }

  /**
   * @return the processor priority; {@code 0} by default.
   */
  @Override
  public int priority() {
    return 0;
  }

  /**
   * Builds the class pool used for a single {@link #process(ClassLoader, String, byte[])} call.
   * Lookup order: the trace-logging path (only when trace is enabled), the bytecode being
   * processed, the given class loader (if any) and finally the class path of this class.
   * The last two ignore classes from the {@code java.} packages.
   */
  private RebelClassPool buildClassPool(ClassLoader cl, String classname, byte[] bytecode) {
    RebelClassPool cp = new RebelClassPool(systemClassPool, this, classname);
    if (log.isTraceEnabled())
      cp.appendClassPath(new RebelClassPath());
    cp.appendClassPath(new MyByteArrayClassPath(classname, bytecode));
    if (cl != null) {
      cp.appendClassPath(ignoreJava(new RestrictedLoaderClassPath(cl)));
    }
    cp.appendClassPath(ignoreJava(new RestrictedClassClassPath()));
    return cp;
  }

  /**
   * Loads {@code classname} from the pool, lets the subclass modify it and returns the
   * resulting bytecode. Returns {@code bytecode} unchanged when the class was already
   * patched by this processor (see {@link #withDuplicatePatchingProtection()}) or when
   * its generic signature cannot be read.
   */
  protected byte[] process(ClassPool cp, ClassLoader cl, String classname, byte[] bytecode) throws Exception {
    CtClass cc = cp.get(classname);
    cc.defrost();

    if (checkAndMarkAvoidDuplicatePatching(this, cc)) {
      log.info("Class: {} already patched, skipping: {}", cc.getName(), getClass().getName());
      return bytecode;
    }

    CtClass[] interfaces = null;
    String genericSignature;
    try {
      genericSignature = cc.getGenericSignature();
    }
    catch (RuntimeException e) {
      log.warn("Got issues in signature: {}, skipping: {}", e.getMessage(), getClass().getName());
      return bytecode;
    }
    if (genericSignature != null) {
      try {
        // dependencies for processed class could be missing
        interfaces = cc.getInterfaces();
      }
      catch (NotFoundException ignored) {
        // signature fix-up below is skipped when the original interfaces are unknown
      }
    }

    process(cp, cl, cc);

    // JR-10847
    // CtClass.addInterface does not update class signature. This causes problems
    // when code assumes that Class.getInterfaces and Class.getGenericInterfaces
    // return the same amount of elements. Here we add descriptors of new interfaces
    // to existing generic signature.
    if (genericSignature != null && interfaces != null) {
      CtClass[] newInterfaces = cc.getInterfaces();
      if (interfaces.length < newInterfaces.length) {
        StringBuilder signature = new StringBuilder(genericSignature);
        // start from the first new interface
        for (int i = interfaces.length; i < newInterfaces.length; i++) {
          signature.append(Descriptor.of(newInterfaces[i]));
        }
        cc.setGenericSignature(signature.toString());
      }
    }

    return cc.toBytecode();
  }

  /**
   * Can modify the class name request from javassist.
   * Used to improve performance to skip pointless name checks.
   */
  public void setImplicitClassNames(Map<String, String> implicitClassNames) {
    this.implicitClassNames = implicitClassNames;
  }

  Map<String, String> getImplicitClassNames() {
    return this.implicitClassNames;
  }

  /**
   * Says to javassist if the path can even be a class name.
   * Used to improve performance to skip pointless name checks.
   *
   * @return true when the simple name starts with an upper-case letter or '$', or the
   *         name looks like one of our obfuscated classes; false for an empty name or a
   *         name ending with '.'.
   */
  public boolean acceptPathAsClass(String classname) {
    int dot = classname.lastIndexOf('.');
    if (dot + 1 >= classname.length()) {
      // empty name or trailing '.': there is no simple name to inspect
      return false;
    }
    char c = classname.charAt(dot + 1);
    return Character.isUpperCase(c) || c == '$'
        || dot > 20 && classname.startsWith("com.zeroturnaround."); // our obfuscated classes
  }

  /**
   * Tells whether the name is one of the primitive type names or {@code void}.
   */
  public boolean acceptPathAsPrimitive(String classname) {
    int dot = classname.lastIndexOf('.');
    return dot == -1
        && ("void".equals(classname) || "byte".equals(classname) || "short".equals(classname)
            || "int".equals(classname) || "long".equals(classname) || "float".equals(classname)
            || "double".equals(classname) || "char".equals(classname) || "boolean".equals(classname));
  }

  /**
   * Modifies the class to be loaded.
   *
   * @param cp the container of <code>CtClass</code> objects.
   * @param cl the class loader loading the given class.
   * @param ctClass the class representation.
   */
  public abstract void process(ClassPool cp, ClassLoader cl, CtClass ctClass) throws Exception;

  /**
   * Avoids duplicate patching a given class when it has already been patched.
   * The mechanism for checking this is to add a field to the class with a comma-separated
   * list of CBP names that have been applied to the class and to check this list for a
   * match before running the cbp.
   */
  public JavassistClassBytecodeProcessor withDuplicatePatchingProtection() {
    avoidDuplicatePatching = true;
    return this;
  }

  /**
   * @return true - if the class has been marked processed by the processor;
   *         false - if the class has not been processed by the processor (or duplicate
   *         patching protection is off). In the latter case with protection on, the class
   *         is also marked as processed by the processor.
   */
  private boolean checkAndMarkAvoidDuplicatePatching(JavassistClassBytecodeProcessor processor, CtClass ctClass) throws CannotCompileException {
    if (!avoidDuplicatePatching) {
      return false;
    }
    String processorName = processor.getClass().getName();
    String appliedNames = "";
    try {
      CtField cbpsField = ctClass.getDeclaredField(FIELD);
      Object constant = cbpsField.getConstantValue();
      if (constant != null) {
        appliedNames = constant.toString();
      }
      // Delimit on both sides so only whole entries match; a plain substring test would
      // treat "a.MyCBP" as applied when "org.a.MyCBP" is in the list.
      if (("," + appliedNames + ",").contains("," + processorName + ",")) {
        return true;
      }
      ctClass.removeField(cbpsField);
    }
    catch (NotFoundException ignored) {
      // no processor with duplicate patching protection has touched this class yet
    }
    String newNames = appliedNames.length() == 0 ? processorName : appliedNames + "," + processorName;
    ctClass.addField(CtField.make((ctClass.isInterface() ? "public" : "private") + " static final String " + FIELD + " = \"" + newNames + "\";", ctClass));
    return false;
  }

  /**
   * Emit log messages about sought classes.
   */
  private static class RebelClassPath implements ClassPath {

    @Override
    public URL find(String classname) {
      if (log.isTraceEnabled() && classname != null && !classname.startsWith("java.")) {
        log.trace("Searching for Javassist resource " + classname);
      }
      return null;
    }

    @Override
    public InputStream openClassfile(String classname) throws NotFoundException {
      if (log.isTraceEnabled()) {
        log.trace("Searching for Javassist class " + classname);
      }
      throw new NotFoundException(classname);
    }
  }

  /**
   * Names under {@code java.lang} for classes that do not live there (the real ones are in
   * {@code java.util}, {@code java.net}, {@code java.io}, ...). Presumably these come from
   * resolving simple names against the implicit {@code java.lang} import; there is no need
   * to search for them.
   */
  private static final String[] MISPLACED_JAVA_LANG_PREFIXES = {
      "java.lang.Set",
      "java.lang.HashSet",
      "java.lang.Map",
      "java.lang.HashMap",
      "java.lang.WeakHashMap",
      "java.lang.IdentityHashMap",
      "java.lang.Vector",
      "java.lang.List",
      "java.lang.ArrayList",
      "java.lang.LinkedList",
      "java.lang.Collections",
      "java.lang.URL",
      "java.lang.File",
      "java.lang.Logger",
      "java.lang.Method",
      "java.lang.InputStream",
      "java.lang.Iterator"
  };

  /** JRebel classes, no need to search for them in the {@code java.} packages. */
  private static final String[] JREBEL_SIMPLE_NAME_SUFFIXES = {
      ".WeakUtil",
      ".StopWatch",
      ".MonitorUtil",
      ".MiscUtil",
      ".LoggerFactory",
      ".ReloaderUtil",
      ".ResourceUtil",
      ".ResourceUtils",
      ".RebelSource",
      ".ClassResourceSource",
      ".ClassEventListener",
      ".FileMonitorAdapter",
      ".Integration",
      ".IntegrationFactory",
      ".Reloader",
      ".ReloaderFactory",
      ".RequestIntegration",
      ".RequestIntegrationFactory",
      ".RebelXmlIntegration",
      ".RebelXmlIntegrationFactory"
  };

  /** Rejects lookups listed in {@link #MISPLACED_JAVA_LANG_PREFIXES} and {@link #JREBEL_SIMPLE_NAME_SUFFIXES}. */
  private static final ClassFilter pickPackagesForNames = new ClassFilter() {
    @Override
    public final boolean accept(String classname) {
      if (classname.startsWith("java.lang") && startsWithAny(classname, MISPLACED_JAVA_LANG_PREFIXES))
        return false;
      return !(classname.startsWith("java.") && endsWithAny(classname, JREBEL_SIMPLE_NAME_SUFFIXES));
    }
  };

  /** Accepts only classes outside of the {@code java.} packages. */
  private static final ClassFilter ignoreJava = new ClassFilter() {
    @Override
    public final boolean accept(String classname) {
      return !classname.startsWith("java.");
    }
  };

  private static final ClassPool systemClassPool = getSystemClassPool();

  /**
   * Builds the shared parent pool: the class path of this class filtered by
   * {@link #pickPackagesForNames}.
   */
  private static ClassPool getSystemClassPool() {
    ClassPool cp = new ClassPool();
    cp.appendClassPath(new FilterClassPathWrapper(new RestrictedClassClassPath(), pickPackagesForNames));
    return cp;
  }

  /** @return true if {@code s} starts with any of {@code prefixes}. */
  private static boolean startsWithAny(String s, String[] prefixes) {
    for (String prefix : prefixes) {
      if (s.startsWith(prefix))
        return true;
    }
    return false;
  }

  /** @return true if {@code s} ends with any of {@code suffixes}. */
  private static boolean endsWithAny(String s, String[] suffixes) {
    for (String suffix : suffixes) {
      if (s.endsWith(suffix))
        return true;
    }
    return false;
  }

  /**
   * Class path wrapper that ignores all classes from java. package.
   */
  private static ClassPath ignoreJava(ClassPath cp) {
    return new FilterClassPathWrapper(cp, ignoreJava);
  }

  /**
   * Delegates to another class path only for names accepted by the filter; rejected names
   * are reported as not found.
   */
  private static class FilterClassPathWrapper implements ClassPath {
    final ClassPath delegate;
    final ClassFilter filter;

    public FilterClassPathWrapper(ClassPath delegate, ClassFilter filter) {
      this.delegate = delegate;
      this.filter = filter;
    }

    @Override
    public URL find(String classname) {
      if (!filter.accept(classname))
        return null;

      return delegate.find(classname);
    }

    @Override
    public InputStream openClassfile(String classname) throws NotFoundException {
      if (!filter.accept(classname))
        throw new NotFoundException(classname);

      return delegate.openClassfile(classname);
    }
  }

  private interface ClassFilter {
    boolean accept(String classname);
  }

  /**
   * Reports a CBP failure without the security manager. Reporting is best-effort: it runs
   * inside the failure handler of {@link #process(ClassLoader, String, byte[])}, so any
   * runtime exception thrown while gathering report data is logged and swallowed.
   */
  private static void reportCBPFailure(final Throwable exception, final ClassLoader cl, final String classname) {
    SecurityController.doWithoutSecurityManager(new PrivilegedAction<Void>() {
      public Void run() {
        try {
          doReportCBPFailure(exception, cl, classname);
        }
        catch (RuntimeException e) {
          log.warn("Got error while trying to report CBP failure for '" + classname + "'", e);
        }
        return null;
      }
    });
  }

  /**
   * Report CBP failure, determine version and name of the framework if possible.
   * Nothing is reported when {@code cl} is {@code null} or the cause chain contains an
   * {@link OutOfMemoryError} or {@link IOException}. The report is sent from a daemon thread.
   */
  private static void doReportCBPFailure(final Throwable exception, ClassLoader cl, String classname) {
    if (cl == null || !isReportable(exception))
      return;

    String version = null;
    String fileName = null;

    // If the resource was loaded from a JAR,
    // try to get the information directly from the JAR
    URL resource = cl.getResource(classname.replace('.', '/').concat(".class"));
    if (resource != null && "jar".equals(resource.getProtocol())) {
      fileName = getArchiveName(resource.getFile());
      try {
        URLConnection connection = resource.openConnection();
        // a custom "jar" protocol handler is not guaranteed to return a JarURLConnection
        if (connection instanceof JarURLConnection) {
          Manifest manifest = ((JarURLConnection) connection).getManifest();
          if (manifest != null) {
            Attributes mainAttributes = manifest.getMainAttributes();
            // getValue yields null (not the string "null") for a missing attribute, so the
            // classloader-based fallback below still gets a chance
            version = mainAttributes.getValue(Attributes.Name.IMPLEMENTATION_VERSION);
            if (version == null)
              version = mainAttributes.getValue(Attributes.Name.SPECIFICATION_VERSION);
          }
        }
      }
      catch (IOException e) {
        log.warn("Got error while trying to determine package/filename via jar for '" + classname + "'", e);
      }
    }

    // If the file based approach fails,
    // then try to determine the package via classloader
    if (fileName == null || version == null) {
      try {
        Package pck = findPackage(cl, getPackageName(classname));
        if (pck != null) {
          if (version == null)
            version = pck.getImplementationVersion() != null ? pck.getImplementationVersion() : pck.getSpecificationVersion();
          if (fileName == null)
            fileName = pck.getImplementationTitle() != null ? pck.getImplementationTitle() : pck.getSpecificationTitle();
        }
      }
      catch (Exception e) {
        log.warn("Got error while trying to determine package/filename via classloader for '" + classname + "'", e);
      }
    }

    final String finalFileName = fileName;
    final String finalVersion = version;

    Thread reportThread = new Thread(new Runnable() {
      public void run() {
        IntegrationFactory.getInstance().reportError(exception, finalVersion, finalFileName);
      }
    });
    reportThread.setName("rebel-error-reporter");
    reportThread.setDaemon(true);
    reportThread.start();
  }

  /**
   * @return false if the cause chain contains an {@link OutOfMemoryError} or an
   *         {@link IOException} (avoids reporting spammy FileNotFoundException,
   *         NoSuchFileException, etc), true otherwise.
   */
  private static boolean isReportable(Throwable exception) {
    for (Throwable rootCause = exception; rootCause != null; rootCause = rootCause.getCause()) {
      if (rootCause instanceof OutOfMemoryError || rootCause instanceof IOException)
        return false;
    }
    return true;
  }

  /**
   * Extracts the archive name from the file part of a {@code jar:} URL, e.g.
   * {@code "file:/lib/foo-1.0.jar!/com/Foo.class"} gives {@code "foo-1.0"}. When the path
   * has no {@code ".jar"} (e.g. classes inside a {@code .war}), the archive part before the
   * last {@code "!/"} is used as is.
   *
   * @return the archive name, or {@code null} if none can be determined.
   */
  private static String getArchiveName(String jarUrlFile) {
    int end = jarUrlFile.lastIndexOf(".jar");
    if (end == -1) {
      end = jarUrlFile.lastIndexOf("!/");
      if (end == -1)
        return null;
    }
    String archive = jarUrlFile.substring(0, end);
    return archive.substring(archive.lastIndexOf('/') + 1);
  }

  /**
   * Looks up the package via {@code ClassLoader.getDefinedPackage} (JDK 9+), falling back to
   * {@code ClassLoader.getPackage} (JDK 8).
   *
   * @return the package, or {@code null} if neither method is found or the package is unknown.
   */
  private static Package findPackage(ClassLoader cl, String packageName) throws Exception {
    Method packageMethod = ReflectionUtil.getDeclaredMethod(ClassLoader.class, "getDefinedPackage", String.class);
    if (packageMethod == null) {// For JDK8
      packageMethod = ReflectionUtil.getDeclaredMethod(ClassLoader.class, "getPackage", String.class);
    }
    if (packageMethod == null)
      return null;
    return (Package) packageMethod.invoke(cl, packageName);
  }

  private static String getPackageName(String className) {
    int i = className.lastIndexOf('.');
    return i == -1 ? "" : className.substring(0, i);
  }

  /**
   * Byte-array class path that can also produce a URL for its class; the URL serves the
   * held bytecode through {@link BytecodeURLConnection}.
   */
  private static class MyByteArrayClassPath extends ByteArrayClassPath {

    MyByteArrayClassPath(String classname, byte[] bytecode) {
      super(classname, bytecode);
    }

    @Override
    public URL find(final String classname) {
      if (this.classname.equals(classname)) {
        return SecurityController.doWithoutSecurityManager(new PrivilegedAction<URL>() {
          public URL run() {
            String cname = classname.replace('.', '/') + ".class";
            try {
              return new URL(null, "file:/ByteArrayClassPath/" + cname, new BytecodeURLStreamHandler());
            }
            catch (MalformedURLException ignored) {
              // fall through: the class is reported as having no URL
            }
            return null;
          }
        });
      }
      return null;
    }

    private class BytecodeURLStreamHandler extends URLStreamHandler {

      @Override
      protected URLConnection openConnection(final URL u) {
        return new BytecodeURLConnection(u);
      }
    }

    private class BytecodeURLConnection extends URLConnection {

      protected BytecodeURLConnection(URL url) {
        super(url);
      }

      @Override
      public void connect() throws IOException {
        // nothing to connect to: the bytecode is already in memory
      }

      @Override
      public InputStream getInputStream() throws IOException {
        return new ByteArrayInputStream(classfile);
      }

      @Override
      public int getContentLength() {
        return classfile.length;
      }
    }
  }
}