/*
* Copyright 2014 Basho Technologies Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.basho.riak.client.core.netty;
import com.basho.riak.client.core.RiakMessage;
import com.basho.riak.client.core.util.Constants;
import com.basho.riak.protobuf.RiakMessageCodes;
import com.basho.riak.protobuf.RiakPB;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.concurrent.DefaultPromise;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import java.io.IOException;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import javax.net.ssl.SSLEngine;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
*
* @author Brian Roach <roach at basho dot com>
*/
public class RiakSecurityDecoder extends ByteToMessageDecoder
{
private final CountDownLatch promiseLatch = new CountDownLatch(1);
private final SSLEngine sslEngine;
private final String username;
private final String password;
private final Logger logger = LoggerFactory.getLogger(RiakSecurityDecoder.class);
private volatile DefaultPromise<Void> promise;
private enum State { TLS_START, TLS_WAIT, SSL_WAIT, AUTH_WAIT }
private volatile State state = State.TLS_START;
public RiakSecurityDecoder(SSLEngine engine, String username, String password)
{
this.sslEngine = engine;
this.username = username;
this.password = password;
}
@Override
protected void decode(ChannelHandlerContext chc, ByteBuf in, List<Object> out) throws Exception
{
// Make sure we have 4 bytes
if (in.readableBytes() >= 4)
{
in.markReaderIndex();
int length = in.readInt();
// See if we have the full frame.
if (in.readableBytes() < length)
{
in.resetReaderIndex();
}
else
{
byte code = in.readByte();
byte[] protobuf = new byte[length - 1];
in.readBytes(protobuf);
switch(state)
{
case TLS_WAIT:
switch(code)
{
case RiakMessageCodes.MSG_StartTls:
logger.debug("Received MSG_RpbStartTls reply");
// change state
this.state = State.SSL_WAIT;
// insert SSLHandler
SslHandler sslHandler = new SslHandler(sslEngine);
// get promise
Future<Channel> hsFuture = sslHandler.handshakeFuture();
// register callback
hsFuture.addListener(new SslListener());
// Add handler
chc.channel().pipeline().addFirst(Constants.SSL_HANDLER, sslHandler);
break;
case RiakMessageCodes.MSG_ErrorResp:
logger.debug("Received MSG_ErrorResp reply to startTls");
promise.tryFailure((riakErrorToException(protobuf)));
break;
default:
promise.tryFailure(new RiakResponseException(0,
"Invalid return code during StartTLS; " + code));
}
break;
case AUTH_WAIT:
chc.channel().pipeline().remove(this);
switch(code)
{
case RiakMessageCodes.MSG_AuthResp:
logger.debug("Received MSG_RpbAuthResp reply");
promise.trySuccess(null);
break;
case RiakMessageCodes.MSG_ErrorResp:
logger.debug("Received MSG_ErrorResp reply to auth");
promise.tryFailure(riakErrorToException(protobuf));
break;
default:
promise.tryFailure(new RiakResponseException(0,
"Invalid return code during Auth; " + code));
}
break;
default:
// WTF?
logger.error("Received message while not in TLS_WAIT or AUTH_WAIT");
promise.tryFailure(new IllegalStateException("Received message while not in TLS_WAIT or AUTH_WAIT"));
}
}
}
}
private RiakResponseException riakErrorToException(byte[] protobuf)
{
try
{
RiakPB.RpbErrorResp error = RiakPB.RpbErrorResp.parseFrom(protobuf);
return new RiakResponseException(error.getErrcode(), error.getErrmsg().toStringUtf8());
}
catch (InvalidProtocolBufferException ex)
{
return null;
}
}
private void init(ChannelHandlerContext ctx)
{
state = State.TLS_WAIT;
promise = new DefaultPromise<Void>(ctx.executor());
promiseLatch.countDown();
ctx.channel().writeAndFlush(new RiakMessage(RiakMessageCodes.MSG_StartTls,
new byte[0]));
}
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception
{
logger.debug("MyStartTlsDecoder Handler Added");
if (ctx.channel().isActive())
{
init(ctx);
}
}
@Override
public void channelActive(final ChannelHandlerContext ctx) throws Exception {
logger.debug("MyStartTlsDecoder Channel Active");
init(ctx);
}
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception
{
promise.tryFailure(new IOException("Channel closed during auth"));
ctx.fireChannelInactive();
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
throws Exception
{
if (cause.getCause() instanceof javax.net.ssl.SSLHandshakeException)
{
// consume
}
else
{
ctx.fireExceptionCaught(cause);
}
}
public DefaultPromise<Void> getPromise() throws InterruptedException {
promiseLatch.await();
return promise;
}
private class SslListener implements GenericFutureListener<Future<Channel>>
{
@Override
public void operationComplete(Future<Channel> future) throws Exception
{
if (future.isSuccess())
{
Channel c = future.getNow();
state = State.AUTH_WAIT;
RiakPB.RpbAuthReq authReq =
RiakPB.RpbAuthReq.newBuilder()
.setUser(ByteString.copyFromUtf8(username))
.setPassword(ByteString.copyFromUtf8(password))
.build();
c.writeAndFlush(new RiakMessage(RiakMessageCodes.MSG_AuthReq,
authReq.toByteArray()));
}
else
{
promise.tryFailure(future.cause());
}
}
}
}