/**
 * Copyright (C) 2010 ZeroTurnaround OU
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.zeroturnaround.javarebel.integration.util;

import java.lang.reflect.Method;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.WeakHashMap;

import org.zeroturnaround.javarebel.Logger;
import org.zeroturnaround.javarebel.LoggerFactory;
import org.zeroturnaround.javarebel.RebelServletContext;
import org.zeroturnaround.javarebel.RequestIntegration;
import org.zeroturnaround.javarebel.RequestIntegrationFactory;
import org.zeroturnaround.javarebel.RequestListener;

/**
 * Helper methods for blocking HTTP requests.
 * <p>
 * Blocking is implemented by registering a {@link RequestListener} whose
 * {@code rawRequest} callback parks the calling thread until
 * {@link #unblock(RebelServletContext)} is invoked for the same context.
 * All public methods synchronize on this class.
 *
 * @author Rein Raudjärv
 */
public class RequestBlockingUtil {

  private static final Logger log = LoggerFactory.getLogger("RequestBlockUtil");

  private static final RequestIntegration integration = RequestIntegrationFactory.getInstance();

  /**
   * Listeners currently registered per context. A list is kept (rather than a
   * single listener) so that blocking the same context more than once never
   * loses track of an earlier listener that is still registered and blocking.
   */
  private static final Map<RebelServletContext, List<BlockingRequestListener>> listeners =
      new WeakHashMap<RebelServletContext, List<BlockingRequestListener>>();

  /**
   * Blocks all new requests of the given context until
   * {@link #unblock(RebelServletContext)} is called.
   *
   * @param context the context to block; when <code>null</code> the listener is
   *        registered without a context (presumably for all contexts - confirm
   *        against {@link RequestIntegration})
   */
  public static synchronized void block(RebelServletContext context) {
    BlockingRequestListener listener = new BlockingRequestListener();
    addListener(context, listener);
    log.log("Blocked new requests of " + MiscUtil.identityToString(context));
  }

  /**
   * Blocks new requests of the given context except those whose remote host
   * resolves to one of the addresses of <code>localhost</code>.
   *
   * @param context the context to block; when <code>null</code> the listener is
   *        registered without a context
   */
  public static synchronized void blockExcludingLocalhost(RebelServletContext context) {
    BlockingRequestListener localhostExcludingListener = new LocalhostExcludingBlockingRequestListener();
    addListener(context, localhostExcludingListener);
    log.log("Blocked new non-localhost requests of " + MiscUtil.identityToString(context));
  }

  /**
   * Remembers the listener for the context and registers it with the request
   * integration. Callers must hold the class lock.
   */
  private static void addListener(RebelServletContext context, BlockingRequestListener listener) {
    List<BlockingRequestListener> active = listeners.get(context);
    if (active == null) {
      active = new ArrayList<BlockingRequestListener>();
      listeners.put(context, active);
    }
    // Append instead of overwriting: an already registered listener must stay
    // reachable, otherwise it could never be removed and would block forever.
    active.add(listener);

    if (context == null) {
      integration.addRequestListener(listener);
    }
    else {
      integration.addRequestListener(context, listener);
    }
  }

  /**
   * Removes every blocking listener registered for the given context and
   * resumes all request threads waiting on them.
   *
   * @param context the context passed to {@link #block(RebelServletContext)} or
   *        {@link #blockExcludingLocalhost(RebelServletContext)}
   * @return <code>true</code> if the context was blocked and has been released,
   *         <code>false</code> if no listener was registered for it
   */
  public static synchronized boolean unblock(RebelServletContext context) {
    List<BlockingRequestListener> active = listeners.remove(context);
    if (active == null) {
      log.log("Could not unblock requests of " + MiscUtil.identityToString(context));
      return false;
    }

    for (int i = 0; i < active.size(); i++) {
      BlockingRequestListener listener = active.get(i);
      // Unregister first so no new request reaches the listener, then wake the waiters.
      if (context == null) {
        integration.removeRequestListener(listener);
      }
      else {
        integration.removeRequestListener(context, listener);
      }
      listener.unblock();
    }
    log.log("Unblocked requests of " + MiscUtil.identityToString(context));
    return true;
  }

  /**
   * Request listener that makes each request thread wait on this listener's
   * monitor until {@link #unblock()} is called.
   */
  private static class BlockingRequestListener implements RequestListener {

    /** <code>true</code> while requests must wait; volatile so the unlocked fast-path reads see updates. */
    private volatile boolean block = true;

    /**
     * Waits until this listener is unblocked or the thread is interrupted.
     *
     * @return always <code>false</code>
     */
    public boolean rawRequest(Object context, Object request, Object response) {
      String name = Thread.currentThread().getName();
      if (block) {
        log.log("Blocking request in thread [" + name + "]...");
        synchronized (this) {
          try {
            // Loop guards against spurious wake-ups.
            while (block) {
              wait();
            }
          }
          catch (InterruptedException e) {
            // Stop waiting but keep the interrupt status for the caller.
            Thread.currentThread().interrupt();
          }
        }
        log.log("Resumed request in thread [" + name + "].");
      }
      return false;
    }
    public void beforeRequest() {
      // do nothing
    }
    public void requestFinally() {
      // do nothing
    }
    public int priority() {
      return 0;
    }

    /**
     * Releases all threads waiting in {@link #rawRequest(Object, Object, Object)}.
     * Calling it again after the first call has no effect.
     */
    public void unblock() {
      if (block) {
        synchronized (this) {
          block = false;
          notifyAll();
        }
      }
    }
  }

  /**
   * Blocking listener that lets requests through when their remote host is one
   * of the addresses <code>localhost</code> resolved to at construction time.
   */
  private static class LocalhostExcludingBlockingRequestListener extends BlockingRequestListener {

    private static final String JAVAX_REQUEST = "javax.servlet.http.HttpServletRequest";
    private static final String JAKARTA_REQUEST = "jakarta.servlet.http.HttpServletRequest";

    /** Textual host addresses of <code>localhost</code>; filled once in the constructor. */
    private final Set<String> localAddresses = new TreeSet<String>();

    /**
     * Resolves <code>localhost</code> and records its addresses. If resolution
     * fails the set stays empty, so no request is treated as local.
     */
    LocalhostExcludingBlockingRequestListener() {
      try {
        InetAddress[] addresses = InetAddress.getAllByName("localhost");

        for (int i = 0; i < addresses.length; i++) {
          InetAddress address = addresses[i];
          String host = address.getHostAddress();
          localAddresses.add(host);
        }

        log.log("Local addresses: " + localAddresses);
      }
      catch (UnknownHostException e) {
        log.error(e);
      }
    }

    /**
     * Lets local servlet requests pass and delegates everything else to the
     * blocking implementation of the superclass.
     *
     * @return <code>false</code> in all cases
     */
    @Override
    public boolean rawRequest(Object context, Object request, Object response) {
      // The HttpServletRequest interface is not available on the class path, so we use reflection
      Class<?> requestInterface = request == null ? null : findServletRequestInterface(request.getClass());

      if (requestInterface != null) {
        try {
          // Look the method up on the (public) interface rather than on the runtime
          // class: invoking via a non-public implementation class would fail with
          // IllegalAccessException.
          Method method = requestInterface.getMethod("getRemoteHost");
          String host = (String) method.invoke(request);

          // IPv6 addresses may have a zone index at the end that we need to remove
          if (host != null) {
            int zoneIndex = host.indexOf('%');
            if (zoneIndex != -1) {
              host = host.substring(0, zoneIndex);
            }
          }

          if (host != null && localAddresses.contains(host)) {
            log.log("Got request from local host '" + host + "', excluding from request blocker.");
            return false;
          }
          else {
            log.log("Got request from non-local host '" + host + "', blocking.");
          }
        }
        catch (Exception e) {
          // Reflection failed; fall through and block like any other request.
          log.error(e);
        }
      }

      return super.rawRequest(context, request, response);
    }

    /**
     * Searches the class, its superclasses and all their (super-)interfaces for
     * the javax or jakarta <code>HttpServletRequest</code> interface.
     *
     * @return the matching interface class, or <code>null</code> if the type
     *         does not implement either interface
     */
    private static Class<?> findServletRequestInterface(Class<?> type) {
      for (Class<?> current = type; current != null; current = current.getSuperclass()) {
        Class<?> match = findInInterfaces(current.getInterfaces());
        if (match != null) {
          return match;
        }
      }
      return null;
    }

    /**
     * Depth-first search through the given interfaces and the interfaces they extend.
     */
    private static Class<?> findInInterfaces(Class<?>[] interfaces) {
      for (int i = 0; i < interfaces.length; i++) {
        Class<?> candidate = interfaces[i];
        String interfaceName = candidate.getName();
        if (JAVAX_REQUEST.equals(interfaceName) || JAKARTA_REQUEST.equals(interfaceName)) {
          return candidate;
        }
        Class<?> inherited = findInInterfaces(candidate.getInterfaces());
        if (inherited != null) {
          return inherited;
        }
      }
      return null;
    }

  }

}
