package org.zeroturnaround.javarebel.integration.support;

import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.List;

import org.zeroturnaround.bundled.javassist.CannotCompileException;
import org.zeroturnaround.bundled.javassist.ClassPool;
import org.zeroturnaround.bundled.javassist.CtClass;
import org.zeroturnaround.bundled.javassist.CtConstructor;
import org.zeroturnaround.bundled.javassist.CtField;
import org.zeroturnaround.bundled.javassist.CtMethod;
import org.zeroturnaround.bundled.javassist.CtNewMethod;
import org.zeroturnaround.bundled.javassist.NotFoundException;
import org.zeroturnaround.bundled.javassist.expr.ExprEditor;
import org.zeroturnaround.bundled.javassist.expr.MethodCall;

/**
 * This class provides a fluent API for processing classes. It can be used to create two types of processors: direct and
 * delayed. A {@link DirectProcessor direct processor} takes a {@link ClassPool} and a {@link CtClass} and performs all
 * the processing directly on the given {@link CtClass} instance. A {@link DelayedProcessor delayed processor} gathers
 * a bunch of processing commands, and can later execute them all on a given {@link CtClass}. A delayed processor is
 * also a {@link JavassistClassBytecodeProcessor}.
 * 
 * <p>The {@link #process(ClassPool, CtClass) process} method is used to create a direct processor. Methods like {@link
 * #implement(Class)} and {@link #implement(String)} are used to create delayed processors.
 * 
 * <p>Here is how one would use a direct processor:
 * <pre>{@code
 * CBPs.process(cp, ctClass)
 *     .importPackage("org.zeroturnaround.javarebel")
 *     .importPackage("org.zeroturnaround.javarebel.integration.util")
 *     
 *     .addInterface(ClassEventListener.class)
 *     
 *     .addLoggerField("jrLog", "MyPlugin")
 *     
 *     .addMethod("" +
 *         "public void onClassEvent(int eventType, Class klass) {" +
 *         "  jrLog.info(\"Class reload event received for \" + klass);" +
 *         "}")
 *     
 *     .addMethod("" +
 *         "public int priority() {" +
 *         "  return PRIORITY_DEFAULT;" +
 *         "}")
 *     
 *     .insertAfterLeafCtors("{" +
 *         "  ReloaderFactory.getInstance().addClassReloadListener(WeakUtil.weakCEL(this));" +
 *         "}");
 * }</pre>
 * 
 * <p>Here is how one could use a delayed processor in a plugin class to implement simple single-method interfaces:
 * <pre>{@code
 * integration.addIntegrationProcessor(
 *   "com.ibm.ws.jaxws.webcontainer.LibertyJaxWsServlet",
 *   CBPs.implement(JrLibertyJaxWsServlet.class)
 *       .withMethod("" +
 *           "public void setWsEndpointInfo(Object epInfo) {" +
 *           "  endpoint.setEndpointInfo(epInfo);" +
 *           "}"));
 * }</pre>
 */
public abstract class CBPs {
  
  /**
   * Starts a delayed processor whose first recorded command adds the interface with the given name.
   *
   * @param intfName fully-qualified name of the interface to add
   * @return the delayed processor, ready for further chained commands
   */
  public static DelayedProcessor implement(String intfName) {
    DelayedProcessor delayed = new DelayedProcessor();
    return delayed.addInterface(intfName);
  }
  
  /**
   * Starts a delayed processor whose first recorded command adds the given interface.
   *
   * @param intf the interface to add; only its name is recorded
   * @return the delayed processor, ready for further chained commands
   */
  public static DelayedProcessor implement(Class<?> intf) {
    DelayedProcessor delayed = new DelayedProcessor();
    return delayed.addInterface(intf);
  }
  
  /**
   * Creates a direct processor, i.e. one whose commands are applied to {@code ctClass} as soon as they are called.
   *
   * @param cp class pool used for class lookups and imports
   * @param ctClass the class to be processed
   * @return a direct processor bound to the given class
   */
  public static DirectProcessor process(ClassPool cp, CtClass ctClass) {
    DirectProcessor direct = new DirectProcessorImpl(cp, ctClass);
    return direct;
  }

  /**
   * Builds an {@link ExprEditor} that replaces calls to methods named {@code methodName} with the given source code;
   * calls to methods with any other name are left untouched. Example usage:
   * <pre>{@code
   * CBPs.process(cp, ctClass)
   *     .instrument("createPersistentAutomaticTimers",
   *         CBPs.replaceMethodCall("existsTaskGroup", "$_ = false;"))
   * }</pre>
   *
   * @param methodName name of the called method to match; only the name is compared
   * @param src source code that replaces each matching call
   * @return an editor that can be passed to {@code instrument} or {@code instrumentAll}
   */
  public static ExprEditor replaceMethodCall(final String methodName, final String src) {
    return new ExprEditor() {
      @Override
      public void edit(MethodCall call) throws CannotCompileException {
        boolean nameMatches = call.getMethodName().equals(methodName);
        if (!nameMatches) {
          return;
        }
        call.replace(src);
      }
    };
  }

  /**
   * Fluent set of processing commands that operate on a single class. Every command returns the processor to use for
   * the next command in the chain, which allows the {@code skip*} commands to disable the remainder of a chain.
   */
  public interface DirectProcessor {
    /**
     * If the current class does not have a declared field with the given name, then all following processing commands
     * will be turned into no-ops.
     * 
     * <p>The following code will execute successfully even if the current class does not have a field named {@code
     * fooCache}, the {@code getFoo} method will simply not be instrumented.
     * <pre>{@code
     * CBPs.process(cp, ctClass)
     *     .skipIfNoField("fooCache")
     *     .insertBefore("getFoo", "fooCache.clear();");
     * }</pre>
     */
    DirectProcessor skipIfNoField(String fieldName);

    /**
     * If the current class does not have a declared method with the given name, then all following processing
     * commands will be turned into no-ops.
     */
    DirectProcessor skipIfNoMethod(String methodName);

    /**
     * If the class pool cannot find a class with the given name, then all following processing commands will be
     * turned into no-ops.
     */
    DirectProcessor skipIfNoClass(String className);

    /**
     * Imports the given package into the class pool.
     */
    DirectProcessor importPackage(String pkg);

    /**
     * Imports the given class into the class pool, using its fully-qualified name.
     */
    DirectProcessor importClass(Class<?> klass);

    /**
     * Adds the interface with the given name, looked up from the class pool, to the current class.
     */
    DirectProcessor addInterface(String intfName)
        throws NotFoundException;

    /**
     * Adds the given interface, looked up from the class pool by its name, to the current class.
     */
    DirectProcessor addInterface(Class<?> intf)
        throws NotFoundException;

    /**
     * Adds a field to the current class.
     *
     * @param field source code of the field declaration
     */
    DirectProcessor addField(String field)
        throws CannotCompileException;

    /**
     * Clears the {@code final} modifier of the declared field with the given name. A missing field is not treated as
     * an error.
     */
    DirectProcessor removeFinalFromField(String field)
        throws CannotCompileException;

    /**
     * Adds a {@code private final static} field of type {@code org.zeroturnaround.javarebel.Logger} that is
     * initialized with {@code LoggerFactory.getLogger(loggerName)}.
     */
    DirectProcessor addLoggerField(String fieldName, String loggerName)
        throws CannotCompileException;

    /**
     * Adds a method to the current class.
     *
     * @param method source code of the complete method
     */
    DirectProcessor addMethod(String method)
        throws CannotCompileException;

    /**
     * Declares a local variable in the declared method with the given name and inserts the assignment
     * {@code variableName = initExpr;} at the beginning of that method.
     */
    DirectProcessor addLocalVariableTo(
        String methodName, Class<?> variableClass, String variableName, String initExpr)
        throws CannotCompileException, NotFoundException;

    /**
     * Inserts the given source at the end of the static initializer of the current class.
     */
    DirectProcessor insertAfterClinit(String src)
        throws CannotCompileException;

    /**
     * Inserts the given source at the beginning of the body of constructors that call a super constructor.
     */
    DirectProcessor insertBeforeLeafCtors(String src)
        throws CannotCompileException;

    /**
     * Inserts the given source at the end of constructors that call a super constructor.
     */
    DirectProcessor insertAfterLeafCtors(String src)
        throws CannotCompileException;

    /**
     * Inserts the given source at the beginning of the declared method with the given name.
     */
    DirectProcessor insertBefore(String methodName, String src)
        throws NotFoundException, CannotCompileException;

    /**
     * Inserts the given source at the beginning of the declared method with the given name and parameter types.
     */
    DirectProcessor insertBefore(String methodName, CtClass[] paramTypes, String src)
        throws NotFoundException, CannotCompileException;

    /**
     * Inserts the given source at the end of the declared method with the given name.
     */
    DirectProcessor insertAfter(String methodName, String src)
        throws NotFoundException, CannotCompileException;

    /**
     * Instruments a declared method with the given name. If there are multiple methods with the same name,
     * only one of them will be instrumented.
     * 
     * @see #instrumentAll(String, ExprEditor)
     */
    DirectProcessor instrument(String methodName, ExprEditor exprEditor)
        throws NotFoundException, CannotCompileException;

    /**
     * Instruments all declared methods with the given name.
     */
    DirectProcessor instrumentAll(String methodName, ExprEditor exprEditor)
        throws NotFoundException, CannotCompileException;
  }

  /**
   * {@link DirectProcessor} that applies every command immediately to the wrapped {@link CtClass}, using the given
   * {@link ClassPool} for class lookups and imports. The {@code skip*} commands return the shared
   * {@link NopDirectProcessor} when their condition is not met, which disables the rest of the chain.
   */
  private static class DirectProcessorImpl implements DirectProcessor {
    private final ClassPool cp;
    private final CtClass ctClass;

    DirectProcessorImpl(ClassPool cp, CtClass ctClass) {
      this.cp = cp;
      this.ctClass = ctClass;
    }

    // ---- conditional skipping ----

    public DirectProcessor skipIfNoField(String fieldName) {
      try {
        ctClass.getDeclaredField(fieldName);
      }
      catch (NotFoundException fieldMissing) {
        // No such declared field: switch the chain over to the no-op processor.
        return NopDirectProcessor.INSTANCE;
      }
      return this;
    }

    public DirectProcessor skipIfNoMethod(String methodName) {
      try {
        ctClass.getDeclaredMethod(methodName);
      }
      catch (NotFoundException methodMissing) {
        // No such declared method: switch the chain over to the no-op processor.
        return NopDirectProcessor.INSTANCE;
      }
      return this;
    }

    public DirectProcessor skipIfNoClass(String className) {
      if (cp.getOrNull(className) == null) {
        return NopDirectProcessor.INSTANCE;
      }
      return this;
    }

    // ---- imports ----

    public DirectProcessor importPackage(String pkg) {
      cp.importPackage(pkg);
      return this;
    }

    public DirectProcessor importClass(Class<?> klass) {
      // NOTE(review): relies on ClassPool.importPackage accepting a fully-qualified class name -- confirm against
      // the bundled Javassist version.
      cp.importPackage(klass.getName());
      return this;
    }

    // ---- structural additions ----

    public DirectProcessor addInterface(String intfName) throws NotFoundException {
      ctClass.addInterface(cp.get(intfName));
      return this;
    }

    public DirectProcessor addInterface(Class<?> intf) throws NotFoundException {
      ctClass.addInterface(cp.get(intf.getName()));
      return this;
    }

    public DirectProcessor addField(String field) throws CannotCompileException {
      ctClass.addField(CtField.make(field, ctClass));
      return this;
    }

    public DirectProcessor removeFinalFromField(String field) {
      try {
        CtField declaredField = ctClass.getDeclaredField(field);
        declaredField.setModifiers(declaredField.getModifiers() & ~Modifier.FINAL);
      }
      catch (NotFoundException ignored) {
        // Best effort: a class without this field has nothing to un-finalize, so the command is a no-op.
      }
      return this;
    }

    public DirectProcessor addLoggerField(String fieldName, String loggerName) throws CannotCompileException {
      return addField("private final static org.zeroturnaround.javarebel.Logger " + fieldName
          + " = org.zeroturnaround.javarebel.LoggerFactory.getLogger(\"" + loggerName + "\");");
    }

    public DirectProcessor addMethod(String method) throws CannotCompileException {
      ctClass.addMethod(CtNewMethod.make(method, ctClass));
      return this;
    }

    public DirectProcessor addLocalVariableTo(String methodName, Class<?> variableClass, String variableName, String initExpr) throws CannotCompileException, NotFoundException {
      CtMethod declaredMethod = ctClass.getDeclaredMethod(methodName);
      declaredMethod.addLocalVariable(variableName, cp.get(variableClass.getName()));
      // The variable is only declared above; assign it at the very start of the method body.
      declaredMethod.insertBefore(variableName + " = " + initExpr + ";");
      return this;
    }

    // ---- code insertion ----

    public DirectProcessor insertAfterClinit(String src) throws CannotCompileException {
      ctClass.makeClassInitializer().insertAfter(src);
      return this;
    }

    public DirectProcessor insertBeforeLeafCtors(String src) throws CannotCompileException {
      // Iterate over ALL declared constructors, exactly like insertAfterLeafCtors does. The previously used
      // getConstructors() leaves out private constructors, so private leaf constructors were silently skipped.
      for (CtConstructor ctor : ctClass.getDeclaredConstructors()) {
        // Only constructors that call super(...) are instrumented; this(...)-delegating ones are left alone.
        if (ctor.callsSuper()) {
          ctor.insertBeforeBody(src);
        }
      }
      return this;
    }

    public DirectProcessor insertAfterLeafCtors(String src) throws CannotCompileException {
      for (CtConstructor ctor : ctClass.getDeclaredConstructors()) {
        // Only constructors that call super(...) are instrumented; this(...)-delegating ones are left alone.
        if (ctor.callsSuper()) {
          ctor.insertAfter(src);
        }
      }
      return this;
    }

    public DirectProcessor insertBefore(String methodName, String src) throws NotFoundException, CannotCompileException {
      ctClass.getDeclaredMethod(methodName).insertBefore(src);
      return this;
    }

    public DirectProcessor insertBefore(String methodName, CtClass[] paramTypes, String src) throws NotFoundException, CannotCompileException {
      ctClass.getDeclaredMethod(methodName, paramTypes).insertBefore(src);
      return this;
    }

    public DirectProcessor insertAfter(String methodName, String src) throws NotFoundException, CannotCompileException {
      ctClass.getDeclaredMethod(methodName).insertAfter(src);
      return this;
    }

    // ---- expression-level instrumentation ----

    public DirectProcessor instrument(String methodName, ExprEditor exprEditor) throws NotFoundException, CannotCompileException {
      ctClass.getDeclaredMethod(methodName).instrument(exprEditor);
      return this;
    }

    public DirectProcessor instrumentAll(String methodName, ExprEditor exprEditor) throws NotFoundException, CannotCompileException {
      CtMethod[] methods = ctClass.getDeclaredMethods(methodName);
      for (CtMethod method : methods) {
        method.instrument(exprEditor);
      }
      return this;
    }
  }

  /**
   * Null-object {@link DirectProcessor}: every command is ignored and hands back this same instance, so a chain that
   * has reached it stays disabled until its end. Returned by the {@code skip*} methods of
   * {@link DirectProcessorImpl} when their condition is not met.
   */
  private static class NopDirectProcessor implements DirectProcessor {

    /** The single shared instance; the class holds no state. */
    static final NopDirectProcessor INSTANCE = new NopDirectProcessor();

    // ---- conditional skipping: already skipping, nothing more to decide ----

    public DirectProcessor skipIfNoField(String fieldName) { return this; }

    public DirectProcessor skipIfNoMethod(String methodName) { return this; }

    public DirectProcessor skipIfNoClass(String className) { return this; }

    // ---- imports ----

    public DirectProcessor importPackage(String pkg) { return this; }

    public DirectProcessor importClass(Class<?> klass) { return this; }

    // ---- structural additions ----

    public DirectProcessor addInterface(String intfName) throws NotFoundException { return this; }

    public DirectProcessor addInterface(Class<?> intf) throws NotFoundException { return this; }

    public DirectProcessor addField(String field) throws CannotCompileException { return this; }

    public DirectProcessor removeFinalFromField(String field) throws CannotCompileException { return this; }

    public DirectProcessor addLoggerField(String fieldName, String loggerName) throws CannotCompileException {
      return this;
    }

    public DirectProcessor addMethod(String method) throws CannotCompileException { return this; }

    public DirectProcessor addLocalVariableTo(
        String methodName, Class<?> variableClass, String variableName, String initExpr)
        throws CannotCompileException, NotFoundException {
      return this;
    }

    // ---- code insertion ----

    public DirectProcessor insertAfterClinit(String src) { return this; }

    public DirectProcessor insertBeforeLeafCtors(String src) throws CannotCompileException { return this; }

    public DirectProcessor insertAfterLeafCtors(String src) throws CannotCompileException { return this; }

    public DirectProcessor insertBefore(String methodName, String src)
        throws NotFoundException, CannotCompileException {
      return this;
    }

    public DirectProcessor insertBefore(String methodName, CtClass[] paramTypes, String src)
        throws NotFoundException, CannotCompileException {
      return this;
    }

    public DirectProcessor insertAfter(String methodName, String src)
        throws NotFoundException, CannotCompileException {
      return this;
    }

    // ---- expression-level instrumentation ----

    public DirectProcessor instrument(String methodName, ExprEditor exprEditor)
        throws NotFoundException, CannotCompileException {
      return this;
    }

    public DirectProcessor instrumentAll(String methodName, ExprEditor exprEditor)
        throws NotFoundException, CannotCompileException {
      return this;
    }
  }
  
  /**
   * A processor that records commands instead of executing them. The recorded commands are replayed through a
   * {@link DirectProcessor} each time {@link #process(ClassPool, ClassLoader, CtClass)} is invoked: first all
   * package imports, then all interfaces, then all methods, each group in the order it was recorded.
   */
  public static class DelayedProcessor extends JavassistClassBytecodeProcessor {

    /** Names of the interfaces to add, in recording order. */
    private final List<String> interfaceNames = new ArrayList<String>();
    /** Packages to import, in recording order. */
    private final List<String> importedPackages = new ArrayList<String>();
    /** Source code of the methods to add, in recording order. */
    private final List<String> methodSources = new ArrayList<String>();

    /**
     * Records a package import.
     */
    public DelayedProcessor importPackage(String pkg) {
      this.importedPackages.add(pkg);
      return this;
    }

    /**
     * Records the interface with the given name for later addition.
     */
    public DelayedProcessor addInterface(String intfName) {
      this.interfaceNames.add(intfName);
      return this;
    }

    /**
     * Records the given interface for later addition; only its name is stored.
     */
    public DelayedProcessor addInterface(Class<?> intf) {
      String name = intf.getName();
      this.interfaceNames.add(name);
      return this;
    }

    /**
     * Records the source code of a method for later addition.
     */
    public DelayedProcessor addMethod(String method) {
      this.methodSources.add(method);
      return this;
    }

    /**
     * Alias of {@link #addMethod(String)} that reads better after {@code CBPs.implement(...)}.
     */
    public DelayedProcessor withMethod(String method) {
      return addMethod(method);
    }

    /**
     * Replays the recorded commands on the given class.
     */
    @Override
    public void process(ClassPool cp, ClassLoader cl, CtClass ctClass) throws Exception {
      // Qualified with CBPs on purpose: an unqualified process(...) would resolve against this class's own methods.
      DirectProcessor chain = CBPs.process(cp, ctClass);

      for (String pkg : importedPackages) {
        chain = chain.importPackage(pkg);
      }

      for (String interfaceName : interfaceNames) {
        chain = chain.addInterface(interfaceName);
      }

      for (String source : methodSources) {
        chain = chain.addMethod(source);
      }
    }
  }
}
