Package org.apache.tez.runtime.library.shuffle.common

Source Code of org.apache.tez.runtime.library.shuffle.common.HttpConnection

/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements.  See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership.  The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.tez.runtime.library.shuffle.common;

import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.HttpURLConnection;
import java.net.URL;
import java.security.GeneralSecurityException;
import java.util.concurrent.TimeUnit;

import javax.crypto.SecretKey;
import javax.net.ssl.HttpsURLConnection;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.security.ssl.SSLFactory;
import org.apache.tez.runtime.library.common.security.SecureShuffleUtils;
import org.apache.tez.runtime.library.common.shuffle.impl.ShuffleHeader;

import com.google.common.base.Stopwatch;

/**
* HttpConnection which can be used for Unordered / Ordered shuffle.
*/
public class HttpConnection {

  /** Logger shared by this class and its nested builder. */
  private static final Log LOG = LogFactory.getLog(HttpConnection.class);

  /**
   * Basic/unit connection timeout (in milliseconds). connect(int) never sets a
   * single connect timeout larger than this value.
   */
  private final static int UNIT_CONNECT_TIMEOUT = 60 * 1000;

  /** URL this connection is opened against (see setupConnection()). */
  private URL url;
  /** Label included in the log messages emitted by cleanup(boolean). */
  private final String logIdentifier;

  // Shared by many threads. Created by HttpConnectionParamsBuilder.setSSL(...)
  // and destroyed/nulled by cleanupSSLFactory().
  private static SSLFactory sslFactory;

  /** Underlying connection; null until setupConnection() has run. */
  private HttpURLConnection connection;
  /** Stream that cleanup(boolean) closes when it is non-null. */
  private DataInputStream input;

  /** Set to true once connection.connect() has returned normally in connect(int). */
  private boolean connectionSucceeed;
  /**
   * Set to true by cleanup(boolean) and consulted by the retry loop in
   * connect(int); volatile because it is written and read without a lock.
   */
  private volatile boolean cleanup;

  /** Key used to hash the URL message and to verify the server's reply hash. */
  private final SecretKey jobTokenSecret;
  /** Hash of msgToEncode, sent as a request header; computed in setupConnection(). */
  private String encHash;
  /** Message built from the URL by SecureShuffleUtils.buildMsgFrom(url). */
  private String msgToEncode;

  /** Timeouts, keep-alive flag and buffer size used by this connection. */
  private final HttpConnectionParams httpConnParams;
  /** Reset and restarted at the start of each timed operation for log output. */
  private final Stopwatch stopWatch;

  /**
   * HttpConnection
   *
   * @param url
   * @param connParams
   * @param logIdentifier
   * @param jobTokenSecret
   * @throws IOException
   */
  public HttpConnection(URL url, HttpConnectionParams connParams,
      String logIdentifier, SecretKey jobTokenSecret) throws IOException {
    this.logIdentifier = logIdentifier;
    this.jobTokenSecret = jobTokenSecret;
    this.httpConnParams = connParams;
    this.url = url;
    this.stopWatch = new Stopwatch();
    if (LOG.isDebugEnabled()) {
      LOG.debug("MapOutput URL :" + url.toString());
    }
  }

  private void setupConnection() throws IOException {
    connection = (HttpURLConnection) url.openConnection();
    if (sslFactory != null) {
      try {
        ((HttpsURLConnection) connection).setSSLSocketFactory(sslFactory
          .createSSLSocketFactory());
        ((HttpsURLConnection) connection).setHostnameVerifier(sslFactory
          .getHostnameVerifier());
      } catch (GeneralSecurityException ex) {
        throw new IOException(ex);
      }
    }
    // generate hash of the url
    msgToEncode = SecureShuffleUtils.buildMsgFrom(url);
    encHash = SecureShuffleUtils.hashFromString(msgToEncode, jobTokenSecret);

    // put url hash into http header
    connection.addRequestProperty(SecureShuffleUtils.HTTP_HEADER_URL_HASH,
      encHash);
    // set the read timeout
    connection.setReadTimeout(httpConnParams.readTimeout);
    // put shuffle version into http header
    connection.addRequestProperty(ShuffleHeader.HTTP_HEADER_NAME,
      ShuffleHeader.DEFAULT_HTTP_HEADER_NAME);
    connection.addRequestProperty(ShuffleHeader.HTTP_HEADER_VERSION,
      ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION);
  }

  /**
   * Connect to source
   *
   * @return true if connection was successful
   *         false if connection was previously cleaned up
   * @throws IOException upon connection failure
   */
  public boolean connect() throws IOException {
    return connect(httpConnParams.connectionTimeout);
  }

  /**
   * Connect to source with specific timeout
   *
   * @param connectionTimeout
   * @return true if connection was successful
   *         false if connection was previously cleaned up
   * @throws IOException upon connection failure
   */
  public boolean connect(int connectionTimeout) throws IOException {
    stopWatch.reset().start();
    if (connection == null) {
      setupConnection();
    }
    int unit = 0;
    if (connectionTimeout < 0) {
      throw new IOException("Invalid timeout " + "[timeout = "
          + connectionTimeout + " ms]");
    } else if (connectionTimeout > 0) {
      unit = Math.min(UNIT_CONNECT_TIMEOUT, connectionTimeout);
    }
    // set the connect timeout to the unit-connect-timeout
    connection.setConnectTimeout(unit);
    int connectionFailures = 0;
    while (true) {
      try {
        connection.connect();
        connectionSucceeed = true;
        break;
      } catch (IOException ioe) {
        // Don't attempt another connect if already cleanedup.
        if (!cleanup) {
          LOG.info("Cleanup is set to true. Not attempting to"
              + " connect again. Last exception was: ["
              + ioe.getClass().getName() + ", " + ioe.getMessage() + "]");
          return false;
        }
        // update the total remaining connect-timeout
        connectionTimeout -= unit;
        // throw an exception if we have waited for timeout amount of time
        // note that the updated value if timeout is used here
        if (connectionTimeout == 0) {
          throw ioe;
        }
        // reset the connect timeout for the last try
        if (connectionTimeout < unit) {
          unit = connectionTimeout;
          // reset the connect time out for the final connect
          connection.setConnectTimeout(unit);
        }
        connectionFailures++;
      }
    }
    if (LOG.isDebugEnabled()) {
      LOG.debug("Time taken to connect to " + url.toString() +
        " " + stopWatch.elapsedTime(TimeUnit.MILLISECONDS) + " ms; connectionFailures="+ connectionFailures);
    }
    return true;
  }

  public void validate() throws IOException {
    stopWatch.reset().start();
    int rc = connection.getResponseCode();
    if (rc != HttpURLConnection.HTTP_OK) {
      throw new IOException("Got invalid response code " + rc + " from " + url
          + ": " + connection.getResponseMessage());
    }
    // get the shuffle version
    if (!ShuffleHeader.DEFAULT_HTTP_HEADER_NAME.equals(connection
      .getHeaderField(ShuffleHeader.HTTP_HEADER_NAME))
        || !ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION.equals(connection
          .getHeaderField(ShuffleHeader.HTTP_HEADER_VERSION))) {
      throw new IOException("Incompatible shuffle response version");
    }
    // get the replyHash which is HMac of the encHash we sent to the server
    String replyHash =
        connection
          .getHeaderField(SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH);
    if (replyHash == null) {
      throw new IOException("security validation of TT Map output failed");
    }
    LOG.debug("url=" + msgToEncode + ";encHash=" + encHash + ";replyHash="
        + replyHash);
    // verify that replyHash is HMac of encHash
    SecureShuffleUtils.verifyReply(replyHash, encHash, jobTokenSecret);
    LOG.info("for url=" + msgToEncode +
      " sent hash and receievd reply " + stopWatch.elapsedTime(TimeUnit.MILLISECONDS) + " ms");
  }

  /**
   * Get the inputstream from the connection
   *
   * @return DataInputStream
   * @throws IOException
   */
  public DataInputStream getInputStream() throws IOException {
    stopWatch.reset().start();
    DataInputStream input = null;
    if (connectionSucceeed) {
      input =
          new DataInputStream(new BufferedInputStream(
            connection.getInputStream(), httpConnParams.bufferSize));
    }
    if (LOG.isDebugEnabled()) {
      LOG.debug("Time taken to getInputStream (connect) " + url +
        " " + stopWatch.elapsedTime(TimeUnit.MILLISECONDS) + " ms");
    }
    return input;
  }

  /**
   * Cleanup ssl factory. Should be called after all threads are shutdown.
   */
  public synchronized static void cleanupSSLFactory() {
    if (sslFactory != null) {
      sslFactory.destroy();
      sslFactory = null;
    }
  }

  /**
   * Cleanup the connection.
   *
   * @param disconnect
   *          Close the connection if this is true; otherwise respect keepalive
   * @throws IOException
   */
  public void cleanup(boolean disconnect) throws IOException {
    cleanup = true;
    stopWatch.reset().start();
    try {
      if (input != null) {
        LOG.info("Closing input on " + logIdentifier);
        input.close();
      }
      if (httpConnParams.keepAlive && connectionSucceeed) {
        // Refer:
        // http://docs.oracle.com/javase/6/docs/technotes/guides/net/http-keepalive.html
        readErrorStream(connection.getErrorStream());
      }
      if (connection != null && (disconnect || !httpConnParams.keepAlive)) {
        LOG.info("Closing connection on " + logIdentifier);
        connection.disconnect();
      }
    } catch (IOException e) {
      if (LOG.isDebugEnabled()) {
        LOG.debug("Exception while shutting down fetcher " + logIdentifier, e);
      } else {
        LOG.info("Exception while shutting down fetcher " + logIdentifier
            + ": " + e.getMessage());
      }
    }
    if (LOG.isDebugEnabled()) {
      LOG.debug("Time taken to cleanup connection to " + url +
        " " + stopWatch.elapsedTime(TimeUnit.MILLISECONDS) + " ms");
    }
  }

  /**
   * Cleanup the error stream if any, for keepAlive connections
   *
   * @param errorStream
   */
  private void readErrorStream(InputStream errorStream) {
    if (errorStream == null) {
      return;
    }
    try {
      DataOutputBuffer errorBuffer = new DataOutputBuffer();
      IOUtils.copyBytes(errorStream, errorBuffer, 4096);
      IOUtils.closeStream(errorBuffer);
      IOUtils.closeStream(errorStream);
    } catch (IOException ioe) {
      // ignore
    }
  }

  public static class HttpConnectionParams {
    private boolean keepAlive;
    private int keepAliveMaxConnections;
    private int connectionTimeout;
    private int readTimeout;
    private int bufferSize;
    private boolean sslShuffle;

    public boolean getKeepAlive() {
      return keepAlive;
    }

    public int getKeepAliveMaxConnections() {
      return keepAliveMaxConnections;
    }

    public int getConnectionTimeout() {
      return connectionTimeout;
    }

    public int getReadTimeout() {
      return readTimeout;
    }

    public void setReadTimeout(int readTimeout) {
      this.readTimeout = readTimeout;
    }

    public int getBufferSize() {
      return bufferSize;
    }

    public boolean isSSLShuffleEnabled() {
      return sslShuffle;
    }

    public String toString() {
      StringBuilder sb = new StringBuilder();
      sb.append("keepAlive=").append(keepAlive).append(", ");
      sb.append("keepAliveMaxConnections=").append(keepAliveMaxConnections).append(", ");
      sb.append("connectionTimeout=").append(connectionTimeout).append(", ");
      sb.append("readTimeout=").append(readTimeout).append(", ");
      sb.append("bufferSize=").append(bufferSize).append(", ");
      sb.append("sslShuffle=").append(sslShuffle);
      return sb.toString();
    }
  }

  public static class HttpConnectionParamsBuilder {
    private HttpConnectionParams params;

    public HttpConnectionParamsBuilder() {
      params = new HttpConnectionParams();
    }

    public HttpConnectionParamsBuilder setKeepAlive(boolean keepAlive,
        int keepAliveMaxConnections) {
      params.keepAlive = keepAlive;
      params.keepAliveMaxConnections = keepAliveMaxConnections;
      return this;
    }

    public HttpConnectionParamsBuilder setTimeout(int connectionTimeout,
        int readTimeout) {
      params.connectionTimeout = connectionTimeout;
      params.readTimeout = readTimeout;
      return this;
    }

    public synchronized HttpConnectionParamsBuilder setSSL(boolean sslEnabled,
        Configuration conf) {
      params.sslShuffle = sslEnabled;
      if(sslEnabled) {
        //Create sslFactory if it is null or if it was destroyed earlier
        if (sslFactory == null || sslFactory.getKeystoresFactory()
            .getTrustManagers() == null) {
          LOG.info("Initializing SSL factory in HttpConnection");
          sslFactory = new SSLFactory(SSLFactory.Mode.CLIENT, conf);
          try {
            sslFactory.init();
          } catch (Exception ex) {
            sslFactory.destroy();
            sslFactory = null;
            throw new RuntimeException(ex);
          }
        }
      } else {
        //Defensive: In case SSL was initied earlier, clean it up.
        cleanupSSLFactory();
      }
      return this;
    }

    public HttpConnectionParamsBuilder setBufferSize(int bufferSize) {
      params.bufferSize = bufferSize;
      return this;
    }

    public HttpConnectionParams build() {
      return params;
    }
  }
}
TOP

Related Classes of org.apache.tez.runtime.library.shuffle.common.HttpConnection

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.