/**
 * Copyright (C) 2010 ZeroTurnaround OU
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.zeroturnaround.javarebel.integration.util;

import java.util.Map;
import java.util.WeakHashMap;

import org.zeroturnaround.javarebel.Logger;
import org.zeroturnaround.javarebel.LoggerFactory;
import org.zeroturnaround.javarebel.RebelServletContext;
import org.zeroturnaround.javarebel.RequestIntegration;
import org.zeroturnaround.javarebel.RequestIntegrationFactory;
import org.zeroturnaround.javarebel.RequestListener;

/**
 * Helper methods for counting currently running HTTP requests.
 * <p>
 * A {@link RebelServletContext} must first be {@link #register(RebelServletContext) registered};
 * a counting {@link RequestListener} is then attached to it through the {@link RequestIntegration}.
 * Contexts are held in a {@link WeakHashMap}, so a registration does not by itself keep the
 * context reachable.
 *
 * @author Rein Raudjärv
 *
 * @see #waitUntilNoRunning(RebelServletContext, long)
 */
public class RequestCountingUtil {
  private static final Logger log = LoggerFactory.getLogger("RequestCountUtil");

  private static final RequestIntegration integration = RequestIntegrationFactory.getInstance();

  private static final Map<RebelServletContext, CountingRequestListener> listeners = new WeakHashMap<RebelServletContext, CountingRequestListener>();

  /**
   * Starts counting requests for the given context.
   * <p>
   * Registering an already registered context is a no-op. Replacing the listener would leave
   * the previous one attached to the request integration, where {@link #unregister} could
   * never detach it, and would also lose the current count.
   *
   * @param context Servlet Context of the application.
   */
  public static synchronized void register(RebelServletContext context) {
    if (listeners.containsKey(context))
      return;
    CountingRequestListener listener = new CountingRequestListener();
    listeners.put(context, listener);
    integration.addRequestListener(context, listener);
  }

  /**
   * Stops counting requests for the given context and detaches its listener.
   * Does nothing if the context is not registered.
   *
   * @param context Servlet Context of the application.
   */
  public static synchronized void unregister(RebelServletContext context) {
    CountingRequestListener listener = listeners.remove(context);
    if (listener != null)
      integration.removeRequestListener(context, listener);
  }

  /**
   * Returns the number of requests currently running in the given context.
   *
   * @param context Servlet Context of the application.
   * @return current number of running requests.
   * @throws IllegalArgumentException if the context has not been registered.
   */
  public static int countRunning(RebelServletContext context) {
    return get(context).counter();
  }

  /**
   * Waits until there are no requests currently running.
   *
   * @param context Servlet Context of the application.
   * @param timeout maximum time to wait for, in milliseconds. A non-positive value does not block.
   * @return <code>true</code> if there are no requests running, <code>false</code> if timeout was reached.
   * @throws IllegalArgumentException if the context has not been registered.
   * @throws InterruptedException if the waiting thread is interrupted.
   */
  public static boolean waitUntilNoRunning(RebelServletContext context, long timeout) throws InterruptedException {
    return get(context).waitForZero(timeout);
  }

  /**
   * Looks up the listener registered for the context, failing if there is none.
   */
  private static synchronized CountingRequestListener get(RebelServletContext context) {
    CountingRequestListener result = listeners.get(context);
    if (result == null)
      throw new IllegalArgumentException("Unknown servlet context " + MiscUtil.identityToString(context));
    return result;
  }

  /**
   * Request listener that tracks the number of in-flight requests.
   * All access to the counter is guarded by the listener's own monitor, which is also
   * used to signal waiters when the count drops to zero.
   */
  private static class CountingRequestListener implements RequestListener {
    private volatile int counter;

    public boolean rawRequest(Object context, Object request, Object response) {
      return false; // do nothing
    }

    public synchronized void beforeRequest() {
      counter++;
    }

    public synchronized void requestFinally() {
      // Guard against an unmatched requestFinally(), presumably from a request that was
      // already running when this listener was registered. A negative count would never
      // read as zero again and waitForZero() would always time out.
      if (counter > 0)
        counter--;
      if (counter == 0)
        notifyAll();
    }

    public int priority() {
      return 0;
    }

    public synchronized int counter() {
      return counter;
    }

    /**
     * Blocks until the counter reaches zero or {@code timeout} milliseconds elapse.
     *
     * @return true if the counter reached zero, false on timeout.
     */
    public synchronized boolean waitForZero(long timeout) throws InterruptedException {
      if (counter == 0) {
        log.log("No running requests (did not wait).");
        return true;
      }

      log.log("Waiting until no running requests (current count: " + counter + ", timeout: " + timeout + " ms)...");
      long start = System.currentTimeMillis();
      // Saturate instead of overflowing for very large timeouts.
      long end = timeout > Long.MAX_VALUE - start ? Long.MAX_VALUE : start + timeout;
      long remaining = end - start;
      // Only wait while time remains. wait(0) would block forever, and a negative
      // argument throws IllegalArgumentException. Looping also absorbs spurious wakeups.
      while (remaining > 0) {
        wait(remaining);
        if (counter == 0) {
          long time = System.currentTimeMillis() - start;
          log.log("No running requests (waited for " + time + " ms).");
          return true;
        }
        remaining = end - System.currentTimeMillis();
      }
      log.log(counter + " requests still running (timeout reached).");
      return false;
    }
  }

}
