Package org.glassfish.grizzly.ssl

Source Code of org.glassfish.grizzly.ssl.SSLBaseFilter$OnWriteCopyCloner

/*
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS HEADER.
*
* Copyright (c) 2012-2013 Oracle and/or its affiliates. All rights reserved.
*
* The contents of this file are subject to the terms of either the GNU
* General Public License Version 2 only ("GPL") or the Common Development
* and Distribution License("CDDL") (collectively, the "License").  You
* may not use this file except in compliance with the License.  You can
* obtain a copy of the License at
* https://glassfish.dev.java.net/public/CDDL+GPL_1_1.html
* or packager/legal/LICENSE.txt.  See the License for the specific
* language governing permissions and limitations under the License.
*
* When distributing the software, include this License Header Notice in each
* file and include the License file at packager/legal/LICENSE.txt.
*
* GPL Classpath Exception:
* Oracle designates this particular file as subject to the "Classpath"
* exception as provided by Oracle in the GPL Version 2 section of the License
* file that accompanied this code.
*
* Modifications:
* If applicable, add the following below the License Header, with the fields
* enclosed by brackets [] replaced by your own identifying information:
* "Portions Copyright [year] [name of copyright owner]"
*
* Contributor(s):
* If you wish your version of this file to be governed by only the CDDL or
* only the GPL Version 2, indicate your decision by adding "[Contributor]
* elects to include this software in this distribution under the [CDDL or GPL
* Version 2] license."  If you don't indicate a single choice of license, a
* recipient has the option to distribute your version of this file under
* either the CDDL, the GPL Version 2 or to extend the choice of license to
* its licensees as provided above.  However, if you add GPL Version 2 code
* and therefore, elected the GPL Version 2 license, then the option applies
* only if the new code is made subject to such option by the copyright
* holder.
*/

package org.glassfish.grizzly.ssl;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.security.cert.Certificate;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.util.Collections;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.logging.Filter;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLEngineResult.Status;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLHandshakeException;
import org.glassfish.grizzly.Buffer;
import org.glassfish.grizzly.Connection;
import org.glassfish.grizzly.Context;
import org.glassfish.grizzly.FileTransfer;
import org.glassfish.grizzly.Grizzly;
import org.glassfish.grizzly.IOEvent;
import org.glassfish.grizzly.IOEventLifeCycleListener;
import org.glassfish.grizzly.ProcessorExecutor;
import org.glassfish.grizzly.ReadResult;
import org.glassfish.grizzly.asyncqueue.MessageCloner;
import org.glassfish.grizzly.filterchain.BaseFilter;
import org.glassfish.grizzly.filterchain.FilterChain;
import org.glassfish.grizzly.filterchain.FilterChainContext;
import org.glassfish.grizzly.filterchain.FilterChainContext.Operation;
import org.glassfish.grizzly.filterchain.FilterChainEvent;
import org.glassfish.grizzly.filterchain.NextAction;
import org.glassfish.grizzly.filterchain.TransportFilter;
import org.glassfish.grizzly.memory.Buffers;
import org.glassfish.grizzly.memory.CompositeBuffer;
import org.glassfish.grizzly.memory.MemoryManager;
import org.glassfish.grizzly.ssl.SSLConnectionContext.Allocator;
import org.glassfish.grizzly.utils.DataStructures;

import static org.glassfish.grizzly.ssl.SSLUtils.*;

/**
* SSL {@link Filter} to operate with SSL encrypted data.
*
* @author Alexey Stashok
*/
public class SSLBaseFilter extends BaseFilter {
    private static final Logger LOGGER = Grizzly.logger(SSLBaseFilter.class);
    protected static final MessageCloner<Buffer> COPY_CLONER = new OnWriteCopyCloner();

    private static final Allocator MM_ALLOCATOR = new Allocator() {
        @Override
        @SuppressWarnings("unchecked")
        public Buffer grow(final SSLConnectionContext sslCtx,
            final Buffer oldBuffer, final int newSize) {
            final MemoryManager mm = sslCtx.getConnection()
                    .getTransport().getMemoryManager();
           
            if (oldBuffer == null) {
                return mm.allocate(newSize);
            } else {
                return mm.reallocate(oldBuffer,
                        newSize);
            }
        }
    };
   
    private static final SSLConnectionContext.Allocator OUTPUT_BUFFER_ALLOCATOR =
            new SSLConnectionContext.Allocator() {
        @Override
        public Buffer grow(final SSLConnectionContext sslCtx,
            final Buffer oldBuffer, final int newSize) {
           
            return allocateOutputBuffer(newSize);
        }
    };
   
    private final SSLEngineConfigurator serverSSLEngineConfigurator;
    private final boolean renegotiateOnClientAuthWant;

    protected final Set<HandshakeListener> handshakeListeners =
            Collections.newSetFromMap(
            DataStructures.<HandshakeListener, Boolean>getConcurrentMap(2));
   
    private long handshakeTimeoutMillis = -1;
   
    // ------------------------------------------------------------ Constructors


    public SSLBaseFilter() {
        this(null);
    }

    /**
     * Build <tt>SSLFilter</tt> with the given {@link SSLEngineConfigurator}.
     *
     * @param serverSSLEngineConfigurator SSLEngine configurator for server side connections
     */
    public SSLBaseFilter(SSLEngineConfigurator serverSSLEngineConfigurator) {
        this(serverSSLEngineConfigurator, true);
    }


    /**
     * Build <tt>SSLFilter</tt> with the given {@link SSLEngineConfigurator}.
     *
     * @param serverSSLEngineConfigurator SSLEngine configurator for server side connections
     */
    public SSLBaseFilter(SSLEngineConfigurator serverSSLEngineConfigurator,
                     boolean renegotiateOnClientAuthWant) {
       
        this.renegotiateOnClientAuthWant = renegotiateOnClientAuthWant;
        if (serverSSLEngineConfigurator == null) {
            this.serverSSLEngineConfigurator = new SSLEngineConfigurator(
                    SSLContextConfigurator.DEFAULT_CONFIG.createSSLContext(),
                    false, false, false);
        } else {
            this.serverSSLEngineConfigurator = serverSSLEngineConfigurator;
        }
    }

    public void addHandshakeListener(final HandshakeListener listener) {
        handshakeListeners.add(listener);
    }
   
    public void removeHandshakeListener(final HandshakeListener listener) {
        handshakeListeners.remove(listener);
    }

    /**
     * Returns the handshake timeout, <code>-1</code> if blocking handshake mode
     * is disabled (default).
     */
    public long getHandshakeTimeout(final TimeUnit timeUnit) {
        if (handshakeTimeoutMillis < 0) {
            return -1;
        }
       
        return timeUnit.convert(handshakeTimeoutMillis, TimeUnit.MILLISECONDS);
    }

    /**
     * Sets the handshake timeout.
     * @param handshakeTimeout timeout value, or <code>-1</code> means for
     * non-blocking handshake mode.
     */
    public void setHandshakeTimeout(final long handshakeTimeout,
            final TimeUnit timeUnit) {
        if (handshakeTimeout < 0) {
            handshakeTimeoutMillis = -1;
        } else {
            this.handshakeTimeoutMillis =
                    TimeUnit.MILLISECONDS.convert(handshakeTimeout, timeUnit);
        }
    }

   
    public TransportFilter createOptimizedTransportFilter(final TransportFilter childFilter) {
        return new SSLTransportFilterWrapper(childFilter);
    }

    @Override
    public void onFilterChainChanged(final FilterChain filterChain) {
        final boolean isSsl = filterChain.indexOfType(SSLBaseFilter.class) >= 0;
        final int sslTransportFilterIdx = filterChain.indexOfType(SSLTransportFilterWrapper.class);
       
        if (isSsl) {
            if (sslTransportFilterIdx == -1) {
                final int transportFilterIdx =
                        filterChain.indexOfType(TransportFilter.class);
                if (transportFilterIdx >= 0) {
                    filterChain.set(transportFilterIdx,
                            createOptimizedTransportFilter(
                            (TransportFilter) filterChain.get(transportFilterIdx)));
                }
            }
        } else { // SSLBaseFilter has been removed
            if (sslTransportFilterIdx >= 0) {
                SSLTransportFilterWrapper wrapper =
                        (SSLTransportFilterWrapper) filterChain.get(sslTransportFilterIdx);
                filterChain.set(sslTransportFilterIdx, wrapper.transportFilter);
            }
        }
    }

   
    // ----------------------------------------------------- Methods from Filter


    @Override
    public NextAction handleEvent(FilterChainContext ctx, FilterChainEvent event) throws IOException {
        if (event.type() == CertificateEvent.TYPE) {
            final CertificateEvent ce = (CertificateEvent) event;
            ce.certs = getPeerCertificateChain(obtainSslConnectionContext(ctx.getConnection()),
                                               ctx,
                                               ce.needClientAuth);
            return ctx.getStopAction();
        }
        return ctx.getInvokeAction();
    }

    @Override
    public NextAction handleRead(final FilterChainContext ctx)
    throws IOException {
        final Connection connection = ctx.getConnection();
        final SSLConnectionContext sslCtx = obtainSslConnectionContext(connection);
        SSLEngine sslEngine = sslCtx.getSslEngine();
       
        if (sslEngine != null && !isHandshaking(sslEngine)) {
            return unwrapAll(ctx, sslCtx);
        } else {
            if (sslEngine == null) {
                sslEngine = serverSSLEngineConfigurator.createSSLEngine();
                sslEngine.beginHandshake();
                sslCtx.configure(sslEngine);
                notifyHandshakeStart(connection);
            }

            final Buffer buffer;
            if (handshakeTimeoutMillis >= 0) {
                buffer = doHandshakeSync(sslCtx, ctx, (Buffer) ctx.getMessage(),
                        handshakeTimeoutMillis);
            } else {
                buffer = makeInputRemainder(sslCtx, ctx,
                        doHandshakeStep(sslCtx, ctx, (Buffer) ctx.getMessage()));
            }
       
            final boolean hasRemaining = buffer != null && buffer.hasRemaining();
           
            final boolean isHandshaking = isHandshaking(sslEngine);
            if (!isHandshaking) {
                notifyHandshakeComplete(connection, sslEngine);
                final FilterChain connectionFilterChain = sslCtx.getNewConnectionFilterChain();
                sslCtx.setNewConnectionFilterChain(null);
                if (connectionFilterChain != null) {
                    if (LOGGER.isLoggable(Level.FINE)) {
                        LOGGER.log(Level.FINE, "Applying new FilterChain after"
                                + "SSLHandshake. Connection={0} filterchain={1}",
                                new Object[]{connection, connectionFilterChain});
                    }
                   
                    connection.setProcessor(connectionFilterChain);

                    if (hasRemaining) {
                        NextAction suspendAction = ctx.getSuspendAction();
                        ctx.setMessage(buffer);
                        ctx.suspend();
                        final FilterChainContext newContext =
                                obtainProtocolChainContext(ctx, connectionFilterChain);
                        ProcessorExecutor.execute(newContext.getInternalContext());
                        return suspendAction;
                    } else {
                        return ctx.getStopAction();
                    }
                }

                if (hasRemaining) {
                    ctx.setMessage(buffer);
                    return unwrapAll(ctx, sslCtx);
                }
            }

            return ctx.getStopAction(buffer);
        }
    }

    @SuppressWarnings("unchecked")
    @Override
    public NextAction handleWrite(final FilterChainContext ctx) throws IOException {
        if (ctx.getMessage() instanceof FileTransfer) {
            throw new IllegalStateException("TLS operations not supported with SendFile messages");
        }

        final Connection connection = ctx.getConnection();
       
        synchronized(connection) {
            final Buffer output =
                    wrapAll(ctx, obtainSslConnectionContext(connection));

            final FilterChainContext.TransportContext transportContext =
                    ctx.getTransportContext();

            ctx.write(null, output,
                    transportContext.getCompletionHandler(),
                    transportContext.getPushBackHandler(),
                    COPY_CLONER,
                    transportContext.isBlocking());

            return ctx.getStopAction();
        }
    }

    // ------------------------------------------------------- Protected Methods

    protected NextAction unwrapAll(final FilterChainContext ctx,
            final SSLConnectionContext sslCtx) throws SSLException {
        Buffer input = ctx.getMessage();
       
        Buffer output = null;
       
        boolean isClosed = false;
       
        _outter:
        do {
            final int len = getSSLPacketSize(input);
           
            if (len == -1 || input.remaining() < len) {
                break;
            }

            final SSLConnectionContext.SslResult result =
                    sslCtx.unwrap(input, output, MM_ALLOCATOR);
           
            if (isHandshaking(sslCtx.getSslEngine())) {
                input = rehandshake(ctx, sslCtx);
                if (input == null) {
                    break;
                }
            }
           
            output = result.getOutput();

            if (result.isError()) {
                output.dispose();
                throw result.getError();
            }
           
            switch (result.getSslEngineResult().getStatus()) {
                case OK:
                    if (input.hasRemaining()) {
                        break;
                    }

                    break _outter;
                case CLOSED:
                    isClosed = true;
                    break _outter;
                default:
                    // Should not reach this point
                    throw new IllegalStateException("Unexpected status: " +
                            result.getSslEngineResult().getStatus());
            }
        } while (true);
       
        if (output != null) {
            output.trim();

            if (output.hasRemaining() || isClosed) {
                ctx.setMessage(output);
                return ctx.getInvokeAction(makeInputRemainder(sslCtx, ctx, input));
            }
        }

        return ctx.getStopAction(makeInputRemainder(sslCtx, ctx, input));
    }

    protected Buffer wrapAll(final FilterChainContext ctx,
            final SSLConnectionContext sslCtx) throws SSLException {
       
        final Buffer input = ctx.getMessage();
       
        final Buffer output = sslCtx.wrapAll(input, OUTPUT_BUFFER_ALLOCATOR);

        input.tryDispose();

        return output;
    }

//    protected Buffer doHandshakeStep1(final SSLConnectionContext sslCtx,
//                                     final FilterChainContext ctx,
//                                     Buffer inputBuffer)
//    throws IOException {
//
//        final SSLEngine sslEngine = sslCtx.getSslEngine();
//        final Connection connection = ctx.getConnection();
//
//        final boolean isLoggingFinest = LOGGER.isLoggable(Level.FINEST);
//        try {
//            HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();
//
//            while (true) {
//
//                if (isLoggingFinest) {
//                    LOGGER.log(Level.FINEST, "Loop Engine: {0} handshakeStatus={1}",
//                            new Object[]{sslEngine, sslEngine.getHandshakeStatus()});
//                }
//
//                switch (handshakeStatus) {
//                    case NEED_UNWRAP: {
//
//                        if (isLoggingFinest) {
//                            LOGGER.log(Level.FINEST, "NEED_UNWRAP Engine: {0}", sslEngine);
//                        }
//
//                        if (inputBuffer == null || !inputBuffer.hasRemaining()) {
//                            return inputBuffer;
//                        }
//                       
//                        final int expectedLength = getSSLPacketSize(inputBuffer);
//                        if (expectedLength == -1
//                                || inputBuffer.remaining() < expectedLength) {
//                            return inputBuffer;
//                        }
//
//                        final SSLEngineResult sslEngineResult =
//                                handshakeUnwrap(connection, sslCtx, inputBuffer, null);
//
//                        inputBuffer.shrink();
//
//                        final SSLEngineResult.Status status = sslEngineResult.getStatus();
//
//                        if (status == SSLEngineResult.Status.BUFFER_UNDERFLOW ||
//                                status == SSLEngineResult.Status.BUFFER_OVERFLOW) {
//                            throw new SSLException("SSL unwrap error: " + status);
//                        }
//
//                        handshakeStatus = sslEngine.getHandshakeStatus();
//                        break;
//                    }
//
//                    case NEED_WRAP: {
//                        if (isLoggingFinest) {
//                            LOGGER.log(Level.FINEST, "NEED_WRAP Engine: {0}", sslEngine);
//                        }
//
//                        inputBuffer = makeInputRemainder(sslCtx, ctx, inputBuffer);
//                        final Buffer buffer = handshakeWrap(connection, sslCtx, null);
//
//                        try {
//                            ctx.write(buffer);
//
//                            handshakeStatus = sslEngine.getHandshakeStatus();
//                        } catch (Exception e) {
//                            buffer.dispose();
//                            throw new IOException("Unexpected exception", e);
//                        }
//
//                        break;
//                    }
//
//                    case NEED_TASK: {
//                        if (isLoggingFinest) {
//                            LOGGER.log(Level.FINEST, "NEED_TASK Engine: {0}", sslEngine);
//                        }
//                        executeDelegatedTask(sslEngine);
//                        handshakeStatus = sslEngine.getHandshakeStatus();
//                        break;
//                    }
//
//                    case FINISHED:
//                    case NOT_HANDSHAKING: {
//                        return inputBuffer;
//                    }
//                }
//
//                if (handshakeStatus == HandshakeStatus.FINISHED) {
//                    return inputBuffer;
//                }
//            }
//        } catch (IOException ioe) {
//            notifyHandshakeFailed(connection, ioe);
//            throw ioe;
//        }
//    }

    protected Buffer doHandshakeSync(final SSLConnectionContext sslCtx,
            final FilterChainContext ctx,
            Buffer inputBuffer,
            final long timeoutMillis) throws IOException {
       
        final Connection connection = ctx.getConnection();
        final SSLEngine sslEngine = sslCtx.getSslEngine();
       
        final Buffer tmpAppBuffer = allocateOutputBuffer(sslCtx.getAppBufferSize());
       
        final long oldReadTimeout = connection.getReadTimeout(TimeUnit.MILLISECONDS);
       
        try {
            connection.setReadTimeout(timeoutMillis, TimeUnit.MILLISECONDS);
           
            inputBuffer = makeInputRemainder(sslCtx, ctx,
                    doHandshakeStep(sslCtx, ctx, inputBuffer, tmpAppBuffer));

            while (isHandshaking(sslEngine)) {
                final ReadResult rr = ctx.read();
                final Buffer newBuf = (Buffer) rr.getMessage();
                inputBuffer = Buffers.appendBuffers(ctx.getMemoryManager(),
                        inputBuffer, newBuf);
                inputBuffer = makeInputRemainder(sslCtx, ctx,
                        doHandshakeStep(sslCtx, ctx, inputBuffer, tmpAppBuffer));
            }
        } finally {
            tmpAppBuffer.dispose();
            connection.setReadTimeout(oldReadTimeout, TimeUnit.MILLISECONDS);
        }
       
        return inputBuffer;
    }

    protected Buffer doHandshakeStep(final SSLConnectionContext sslCtx,
                                     final FilterChainContext ctx,
                                     Buffer inputBuffer) throws IOException {
        return doHandshakeStep(sslCtx, ctx, inputBuffer, null);
    }
           
    protected Buffer doHandshakeStep(final SSLConnectionContext sslCtx,
                                     final FilterChainContext ctx,
                                     Buffer inputBuffer,
                                     final Buffer tmpAppBuffer0)
            throws IOException {

        final SSLEngine sslEngine = sslCtx.getSslEngine();
        final Connection connection = ctx.getConnection();
       
        final boolean isLoggingFinest = LOGGER.isLoggable(Level.FINEST);
        Buffer tmpInputToDispose = null;
        Buffer tmpNetBuffer = null;
       
        Buffer tmpAppBuffer = tmpAppBuffer0;
       
        try {
            HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();

            _exitWhile:
           
            while (true) {

                if (isLoggingFinest) {
                    LOGGER.log(Level.FINEST, "Loop Engine: {0} handshakeStatus={1}",
                            new Object[]{sslEngine, sslEngine.getHandshakeStatus()});
                }

                switch (handshakeStatus) {
                    case NEED_UNWRAP: {

                        if (isLoggingFinest) {
                            LOGGER.log(Level.FINEST, "NEED_UNWRAP Engine: {0}", sslEngine);
                        }

                        if (inputBuffer == null || !inputBuffer.hasRemaining()) {
                            break _exitWhile;
                        }

                        final int expectedLength = getSSLPacketSize(inputBuffer);
                        if (expectedLength == -1
                                || inputBuffer.remaining() < expectedLength) {
                            break _exitWhile;
                        }
                       
                        if (tmpAppBuffer == null) {
                            tmpAppBuffer = allocateOutputBuffer(sslCtx.getAppBufferSize());
                        }
                       
                        final SSLEngineResult sslEngineResult =
                                handshakeUnwrap(sslCtx, inputBuffer, tmpAppBuffer);

                        if (!inputBuffer.hasRemaining()) {
                            tmpInputToDispose = inputBuffer;
                            inputBuffer = null;
                        }

                        final SSLEngineResult.Status status = sslEngineResult.getStatus();

                        if (status == Status.BUFFER_UNDERFLOW ||
                                status == Status.BUFFER_OVERFLOW) {
                            throw new SSLException("SSL unwrap error: " + status);
                        }

                        handshakeStatus = sslEngine.getHandshakeStatus();
                        break;
                    }

                    case NEED_WRAP: {
                        if (isLoggingFinest) {
                            LOGGER.log(Level.FINEST, "NEED_WRAP Engine: {0}", sslEngine);
                        }

                        tmpNetBuffer = handshakeWrap(
                                connection, sslCtx, tmpNetBuffer);
                        handshakeStatus = sslEngine.getHandshakeStatus();

                        break;
                    }

                    case NEED_TASK: {
                        if (isLoggingFinest) {
                            LOGGER.log(Level.FINEST, "NEED_TASK Engine: {0}", sslEngine);
                        }
                        executeDelegatedTask(sslEngine);
                        handshakeStatus = sslEngine.getHandshakeStatus();
                        break;
                    }

                    case FINISHED:
                    case NOT_HANDSHAKING: {
                        break _exitWhile;
                    }
                }

                if (handshakeStatus == HandshakeStatus.FINISHED) {
                    break _exitWhile;
                }
            }
        } catch (IOException ioe) {
            notifyHandshakeFailed(connection, ioe);
            throw ioe;
        } finally {
            if (tmpAppBuffer0 == null && tmpAppBuffer != null) {
                tmpAppBuffer.dispose();
            }
           
            if (tmpInputToDispose != null) {
                tmpInputToDispose.tryDispose();
                inputBuffer = null;
            } else if (inputBuffer != null) {
                inputBuffer.shrink();
            }
           
            if (tmpNetBuffer != null) {
                if (inputBuffer != null) {
                    inputBuffer = makeInputRemainder(sslCtx, ctx, inputBuffer);
                }
               
                ctx.write(tmpNetBuffer);
            }
        }
       
        return inputBuffer;
    }
   
    /**
     * Performs an SSL renegotiation.
     *
     * @param sslCtx the {@link SSLConnectionContext} associated with this
     *  this renegotiation request.
     * @param context the {@link FilterChainContext} associated with this
     *  this renegotiation request.
     *
     * @throws IOException if an error occurs during SSL renegotiation.
     */
    protected void renegotiate(final SSLConnectionContext sslCtx,
                               final FilterChainContext context)
                               throws IOException {

        final SSLEngine sslEngine = sslCtx.getSslEngine();
        if (sslEngine.getWantClientAuth() && !renegotiateOnClientAuthWant) {
            return;
        }
        final boolean authConfigured =
                (sslEngine.getWantClientAuth()
                        || sslEngine.getNeedClientAuth());
        if (!authConfigured) {
            sslEngine.setNeedClientAuth(true);
        }

        sslEngine.getSession().invalidate();

        try {
            sslEngine.beginHandshake();
        } catch (SSLHandshakeException e) {
            // If we catch SSLHandshakeException at this point it may be due
            // to an older SSL peer that hasn't made its SSL/TLS renegotiation
            // secure.  This will be the case with Oracle's VM older than
            // 1.6.0_22 or native applications using OpenSSL libraries
            // older than 0.9.8m.
            //
            // What we're trying to accomplish here is an attempt to detect
            // this issue and log a useful message for the end user instead
            // of an obscure exception stack trace in the server's log.
            // Note that this probably will only work on Oracle's VM.
            if (e.toString().toLowerCase().contains("insecure renegotiation")) {
                if (LOGGER.isLoggable(Level.SEVERE)) {
                    LOGGER.severe("Secure SSL/TLS renegotiation is not "
                            + "supported by the peer.  This is most likely due"
                            + " to the peer using an older SSL/TLS "
                            + "implementation that does not implement RFC 5746.");
                }
                // we could return null here and let the caller
                // decided what to do, but since the SSLEngine will
                // close the channel making further actions useless,
                // we'll report the entire cause.
            }
            throw e;
        }
       
        try {
            rehandshake(context, sslCtx);
        } finally {
            if (!authConfigured) {
                sslEngine.setNeedClientAuth(false);
            }
        }
    }

    private Buffer rehandshake(final FilterChainContext context,
            final SSLConnectionContext sslCtx) throws SSLException {
        final Connection c = context.getConnection();
       
        notifyHandshakeStart(c);

        try {
            return doHandshakeSync(sslCtx, context, null, handshakeTimeoutMillis);

        } catch (Throwable t) {
            if (LOGGER.isLoggable(Level.FINE)) {
                LOGGER.log(Level.FINE, "Error during handshake", t);
            }
        }
       
        return null;
    }

    /**
     * <p>
     * Obtains the certificate chain for this SSL session.  If no certificates
     * are available, and <code>needClientAuth</code> is true, an SSL renegotiation
     * will be be triggered to request the certificates from the client.
     * </p>
     *
     * @param sslCtx the {@link SSLConnectionContext} associated with this
     *  certificate request.
     * @param context the {@link FilterChainContext} associated with this
     *  this certificate request.
     * @param needClientAuth determines whether or not SSL renegotiation will
     *  be attempted to obtain the certificate chain.
     *
     * @return the certificate chain as an <code>Object[]</code>.  If no
     *  certificate chain can be determined, this method will return
     *  <code>null</code>.
     *
     * @throws IOException if an error occurs during renegotiation.
     */
    protected Object[] getPeerCertificateChain(final SSLConnectionContext sslCtx,
                                               final FilterChainContext context,
                                               final boolean needClientAuth)
    throws IOException {

        Certificate[] certs = getPeerCertificates(sslCtx);
        if (certs != null) {
            return certs;
        }

        if (needClientAuth) {
            renegotiate(sslCtx, context);
        }

        certs = getPeerCertificates(sslCtx);

        if (certs == null) {
            return null;
        }

        X509Certificate[] x509Certs = extractX509Certs(certs);

        if (x509Certs == null || x509Certs.length < 1) {
            return null;
        }

        return x509Certs;
    }

    private FilterChainContext obtainProtocolChainContext(
            final FilterChainContext ctx,
            final FilterChain completeProtocolFilterChain) {

        final FilterChainContext newFilterChainContext =
                completeProtocolFilterChain.obtainFilterChainContext(
                        ctx.getConnection(),
                        ctx.getStartIdx(),
                        completeProtocolFilterChain.size(),
                        ctx.getFilterIdx());

        newFilterChainContext.setAddressHolder(ctx.getAddressHolder());
        newFilterChainContext.setMessage(ctx.getMessage());
        newFilterChainContext.getInternalContext().setIoEvent(IOEvent.READ);
        newFilterChainContext.getInternalContext().addLifeCycleListener(
                new InternalProcessingHandler(ctx));

        return newFilterChainContext;
    }


    // --------------------------------------------------------- Private Methods

    private X509Certificate[] extractX509Certs(final Certificate[] certs) {
        final X509Certificate[] x509Certs = new X509Certificate[certs.length];
        for(int i = 0, len = certs.length; i < len; i++) {
            if( certs[i] instanceof X509Certificate ) {
                x509Certs[i] = (X509Certificate)certs[i];
            } else {
                try {
                    final byte [] buffer = certs[i].getEncoded();
                    final CertificateFactory cf =
                            CertificateFactory.getInstance("X.509");
                    ByteArrayInputStream stream = new ByteArrayInputStream(buffer);
                    x509Certs[i] = (X509Certificate)
                    cf.generateCertificate(stream);
                } catch(Exception ex) {
                    LOGGER.log(Level.INFO,
                               "Error translating cert " + certs[i],
                               ex);
                    return null;
                }
            }

            if (LOGGER.isLoggable(Level.FINE)) {
                LOGGER.log(Level.FINE, "Cert #{0} = {1}", new Object[] {i, x509Certs[i]});
            }
        }
        return x509Certs;
    }

    private Certificate[] getPeerCertificates(final SSLConnectionContext sslCtx) {
        try {
            return sslCtx.getSslEngine().getSession().getPeerCertificates();
        } catch( Throwable t ) {
            if (LOGGER.isLoggable(Level.FINE)) {
                LOGGER.log(Level.FINE,"Error getting client certs", t);
            }
            return null;
        }
    }

    protected void notifyHandshakeStart(final Connection connection) {
        if (!handshakeListeners.isEmpty()) {
            for (final HandshakeListener listener : handshakeListeners) {
                listener.onStart(connection);
            }
        }
    }
   
    protected void notifyHandshakeComplete(final Connection<?> connection,
                                          final SSLEngine sslEngine) {

        if (!handshakeListeners.isEmpty()) {
            for (final HandshakeListener listener : handshakeListeners) {
                listener.onComplete(connection);
            }
        }
    }

    protected void notifyHandshakeFailed(final Connection connection,
            final Throwable t) {
    }
   
    // ----------------------------------------------------------- Inner Classes

    public static class CertificateEvent implements FilterChainEvent {

        private static final String TYPE = "CERT_EVENT";

        private Object[] certs;

        private final boolean needClientAuth;


        // -------------------------------------------------------- Constructors


        public CertificateEvent(final boolean needClientAuth) {
            this.needClientAuth = needClientAuth;
        }


        // --------------------------------------- Methods from FilterChainEvent


        @Override
        public final Object type() {
            return TYPE;
        }


        // ------------------------------------------------------ Public Methods


        public Object[] getCertificates() {
            return certs;
        }

    } // END CertificateEvent

    private class InternalProcessingHandler extends IOEventLifeCycleListener.Adapter {
        private final FilterChainContext parentContext;

        private InternalProcessingHandler(final FilterChainContext parentContext) {
            this.parentContext = parentContext;
        }

        @Override
        public void onComplete(final Context context, Object data) throws IOException {
            parentContext.resume(parentContext.getStopAction());
        }

    } // END InternalProcessingHandler
   
    public static interface HandshakeListener {
        public void onStart(Connection connection);
        public void onComplete(Connection connection);
    }
   
    private final class SSLTransportFilterWrapper extends TransportFilter {
        private final TransportFilter transportFilter;

        public SSLTransportFilterWrapper(final TransportFilter transportFilter) {
            this.transportFilter = transportFilter;
        }
       
        @Override
        public NextAction handleAccept(FilterChainContext ctx) throws IOException {
            return transportFilter.handleAccept(ctx);
        }

        @Override
        public NextAction handleConnect(FilterChainContext ctx) throws IOException {
            return transportFilter.handleConnect(ctx);
        }

        @Override
        public NextAction handleRead(final FilterChainContext ctx) throws IOException {
            final Connection connection = ctx.getConnection();
            final SSLConnectionContext sslCtx =
                    obtainSslConnectionContext(connection);
           
            if (sslCtx.getSslEngine() == null) {
                final SSLEngine sslEngine = serverSSLEngineConfigurator.createSSLEngine();
                sslEngine.beginHandshake();
                sslCtx.configure(sslEngine);
                notifyHandshakeStart(connection);
            }
           
            ctx.setMessage(allowDispose(allocateInputBuffer(sslCtx)));
           
            return transportFilter.handleRead(ctx);
        }

        @Override
        public NextAction handleWrite(FilterChainContext ctx) throws IOException {
            return transportFilter.handleWrite(ctx);
        }

        @Override
        public NextAction handleEvent(FilterChainContext ctx, FilterChainEvent event) throws IOException {
            return transportFilter.handleEvent(ctx, event);
        }

        @Override
        public NextAction handleClose(FilterChainContext ctx) throws IOException {
            return transportFilter.handleClose(ctx);
        }

        @Override
        public void onAdded(FilterChain filterChain) {
            transportFilter.onAdded(filterChain);
        }

        @Override
        public void onFilterChainChanged(FilterChain filterChain) {
            transportFilter.onFilterChainChanged(filterChain);
        }

        @Override
        public void onRemoved(FilterChain filterChain) {
            transportFilter.onRemoved(filterChain);
        }

        @Override
        public void exceptionOccurred(FilterChainContext ctx, Throwable error) {
            transportFilter.exceptionOccurred(ctx, error);
        }

        @Override
        public FilterChainContext createContext(Connection connection, Operation operation) {
            return transportFilter.createContext(connection, operation);
        }
    }
   
    private static final class OnWriteCopyCloner implements MessageCloner<Buffer> {
        @Override
        public Buffer clone(final Connection connection,
                final Buffer originalMessage) {
            final SSLConnectionContext sslCtx = obtainSslConnectionContext(connection);

            final int copyThreshold = sslCtx.getNetBufferSize() / 2;

            final Buffer lastOutputBuffer = sslCtx.resetLastOutputBuffer();

            final int totalRemaining = originalMessage.remaining();

            if (totalRemaining < copyThreshold) {
                return move(connection.getTransport().getMemoryManager(),
                        originalMessage);
            }
            if (lastOutputBuffer.remaining() < copyThreshold) {
                final Buffer tmpBuf =
                        copy(connection.getTransport().getMemoryManager(),
                        originalMessage);

                if (originalMessage.isComposite()) {
                    ((CompositeBuffer) originalMessage).replace(
                            lastOutputBuffer, tmpBuf);
                } else {
                    assert originalMessage == lastOutputBuffer;
                }

                lastOutputBuffer.tryDispose();

                return tmpBuf;
            }


            return originalMessage;
        }
    }
}
TOP

Related Classes of org.glassfish.grizzly.ssl.SSLBaseFilter$OnWriteCopyCloner

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.