Package org.springframework.integration.x.ip.websocket

Source Code of org.springframework.integration.x.ip.websocket.WebSocketTcpConnectionInterceptorFactory$WebSocketTcpConnectionInterceptor

/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*      http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.integration.x.ip.websocket;

import java.util.Date;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.ConcurrentHashMap;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.core.serializer.Deserializer;
import org.springframework.integration.Message;
import org.springframework.integration.MessageHandlingException;
import org.springframework.integration.MessageHeaders;
import org.springframework.integration.MessagingException;
import org.springframework.integration.aggregator.ResequencingMessageGroupProcessor;
import org.springframework.integration.aggregator.ResequencingMessageHandler;
import org.springframework.integration.channel.DirectChannel;
import org.springframework.integration.context.IntegrationObjectSupport;
import org.springframework.integration.core.MessageHandler;
import org.springframework.integration.endpoint.EventDrivenConsumer;
import org.springframework.integration.ip.tcp.connection.TcpConnection;
import org.springframework.integration.ip.tcp.connection.TcpConnectionInterceptorFactory;
import org.springframework.integration.ip.tcp.connection.TcpConnectionInterceptorSupport;
import org.springframework.integration.ip.tcp.connection.TcpConnectionSupport;
import org.springframework.integration.ip.tcp.connection.TcpNioConnection;
import org.springframework.integration.support.MessageBuilder;
import org.springframework.integration.x.ip.websocket.WebSocketEvent.WebSocketEventType;
import org.springframework.integration.x.ip.websocket.WebSocketSerializer.WebSocketState;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.Assert;

/**
* @author Gary Russell
* @since 3.0
*
*/
public class WebSocketTcpConnectionInterceptorFactory extends IntegrationObjectSupport
    implements TcpConnectionInterceptorFactory {

  private static final Message<WebSocketFrame> PING = MessageBuilder.withPayload(
      new WebSocketFrame(WebSocketFrame.TYPE_PING, "Ping from SI")).build();

  private static final Log logger = LogFactory.getLog(WebSocketTcpConnectionInterceptor.class);

  private final Map<TcpConnection, WebSocketTcpConnectionInterceptor> connections =
      new ConcurrentHashMap<TcpConnection, WebSocketTcpConnectionInterceptor>();

  private volatile TaskScheduler taskScheduler;

  private volatile long pingInterval = 25000;

  private final Runnable pinger = new Runnable() {

    @Override
    public void run() {
      // Add 100ms to allow for heuristics
      long pingFilter = System.currentTimeMillis() - pingInterval + 100;
      for (Entry<TcpConnection, WebSocketTcpConnectionInterceptor> entry : connections.entrySet()) {
        TcpConnection connection = entry.getKey();
        String connectionId = connection.getConnectionId();
        if (entry.getValue().getLastReceiveTime() <= pingFilter) {
          try {
            if (logger.isDebugEnabled()) {
              logger.debug("Sending Ping to " + connectionId);
            }
            connection.send(PING);
          }
          catch (Exception e) {
            logger.error("Failed to send Ping to " + connectionId, e);
            connection.close();
          }
        }
        else {
          if (logger.isTraceEnabled()) {
            logger.trace("Skipping PING for " + connectionId + "  due to recent send");
          }
        }
      }
      if (pingInterval > 0) {
        taskScheduler.schedule(pinger, new Date(System.currentTimeMillis() + pingInterval));
      }
    }
  };

  @Override
  public void setTaskScheduler(TaskScheduler taskScheduler) {
    this.taskScheduler = taskScheduler;
  }

  /**
   * The time between PINGs sent for idle connections. Must be less than half the
   * socket timeout (if any).
   * @param pingInterval
   */
  public void setPingInterval(long pingInterval) {
    this.pingInterval = pingInterval;
  }

  @Override
  protected void onInit() throws Exception {
    super.onInit();
    if (this.pingInterval > 0) {
      if (this.taskScheduler == null) {
        this.taskScheduler = this.getTaskScheduler();
      }
      this.taskScheduler.schedule(this.pinger, new Date(System.currentTimeMillis() + this.pingInterval));
    }
  }

  @Override
  public TcpConnectionInterceptorSupport getInterceptor() {
    return new WebSocketTcpConnectionInterceptor();
  }

  public WebSocketTcpConnectionInterceptor locateInterceptor(TcpConnection connection) {
    return this.connections.get(connection);
  }


  public class WebSocketTcpConnectionInterceptor extends TcpConnectionInterceptorSupport {

    private volatile boolean shook;

    private final DirectChannel resequenceChannel = new DirectChannel();

    private final EventDrivenConsumer resequencer;

    private long lastReceiveTime;

    public WebSocketTcpConnectionInterceptor() {
      super();
      ResequencingMessageHandler handler = new ResequencingMessageHandler(new ResequencingMessageGroupProcessor());
      handler.setReleasePartialSequences(true);
      DirectChannel resequenced = new DirectChannel();
      resequenced.setBeanName("resequencedWSFrames");
      handler.setOutputChannel(resequenced);
      this.resequencer = new EventDrivenConsumer(this.resequenceChannel, handler);
      resequenced.subscribe(new MessageHandler() {

        @Override
        public void handleMessage(Message<?> message) throws MessagingException {
          doOnMessage(message);
        }
      });
      this.resequencer.afterPropertiesSet();
      this.resequencer.start();
    }

    public long getLastReceiveTime() {
      return lastReceiveTime;
    }

    /**
     * When using NIO, we have to resequence the messages because frames may
     * arrive out of order. This is particularly an issue for some of the
     * Autobahn tests where, for example, many pings are sent and the test
     * expects the pongs to come back in the same order.
     */
    @Override
    public boolean onMessage(Message<?> message) {
      this.lastReceiveTime = System.currentTimeMillis();
      if (this.getTheConnection() instanceof TcpNioConnection && message.getHeaders().getCorrelationId() != null) {
        resequenceChannel.send(message);
        return true;
      }
      else {
        return this.doOnMessage(message);
      }
    }

    public boolean doOnMessage(Message<?> message) {
      Assert.isInstanceOf(WebSocketFrame.class, message.getPayload());
      WebSocketFrame payload = (WebSocketFrame) message.getPayload();
      WebSocketState state = getState(message);
      if (logger.isTraceEnabled()) {
        logger.trace(state);
      }
      if (payload.getRsv() > 0) {
        if (logger.isDebugEnabled()) {
          logger.debug("Reserved bits:" + payload.getRsv());
        }
        this.protocolViolation(message);
      }
      else if (payload.getType() == WebSocketFrame.TYPE_CLOSE) {
        try {
          if (logger.isDebugEnabled()) {
            logger.debug("Close, status:" + payload.getStatus());
          }
          // If we initiated the close, just close.
          if (!state.isCloseInitiated()) {
            if (payload.getStatus() < 0) {
              payload.setStatus((short) 1000);
            }
            this.send(message);
          }
          WebSocketEvent event = new WebSocketEvent(this.getTheConnection(),
              WebSocketEventType.WEBSOCKET_CLOSED, state.getPath(), state.getQueryString());
          this.getTheConnection().publishEvent(event);
          this.close();
        }
        catch (Exception e) {
          throw new MessageHandlingException(message, "Send failed", e);
        }
      }
      else if (state == null || state.isCloseInitiated()) {
        if (logger.isWarnEnabled()) {
          logger.warn("Message dropped - close initiated:" + message);
        }
      }
      else if ((payload.getType() & 0xff) == WebSocketFrame.TYPE_INVALID) {
        if (logger.isDebugEnabled()) {
          logger.debug("Invalid:" + payload.getPayload());
        }
        this.protocolViolation(message);
      }
      else if (payload.getType() == WebSocketFrame.TYPE_FRAGMENTED_CONTROL) {
        if (logger.isDebugEnabled()) {
          logger.debug("Fragmented Control Op");
        }
        this.protocolViolation(message);
      }
      else if (payload.getType() == WebSocketFrame.TYPE_PING) {
        try {
          if (logger.isDebugEnabled()) {
            logger.debug("Ping received on " + this.getConnectionId() + ":"
                + new String(payload.getBinary(), "UTF-8"));
          }
          if (payload.getBinary().length > 125) {
            this.protocolViolation(message);
          }
          else {
            WebSocketFrame pong = new WebSocketFrame(WebSocketFrame.TYPE_PONG, payload.getBinary());
            this.send(MessageBuilder.withPayload(pong)
                .copyHeaders(message.getHeaders())
                .build());
          }
        }
        catch (Exception e) {
          throw new MessageHandlingException(message, "Send failed", e);
        }
      }
      else if (payload.getType() == WebSocketFrame.TYPE_PONG) {
        if (logger.isDebugEnabled()) {
          logger.debug("Pong received on " + this.getConnectionId());
        }
      }
      else if (this.shook) {
        return super.onMessage(message);
      }
      else {
        try {
          doHandshake(payload, message.getHeaders());
          this.shook = true;
          WebSocketEvent event = new WebSocketEvent(this.getTheConnection(),
              WebSocketEventType.HANDSHAKE_COMPLETE, state.getPath(), state.getQueryString());
          this.getTheConnection().publishEvent(event);
        }
        catch (Exception e) {
          throw new MessageHandlingException(message, "Handshake failed", e);
        }
      }
      return true;
    }

    private WebSocketState getState(Object object) {
      Object stateKey = null;
      stateKey = this.getTheConnection().getDeserializerStateKey();
      Assert.notNull(stateKey, "StateKey must not be null:" + object);
      WebSocketState state = (WebSocketState) this.getRequiredDeserializer().getState(stateKey);
      Assert.notNull(state, "State must not be null:" + object);
      return state;
    }

    private void protocolViolation(Message<?> message) {
      if (logger.isDebugEnabled()) {
        logger.debug("Protocol violation - closing; " + message);
      }
      WebSocketFrame frame = (WebSocketFrame) message.getPayload();
      String error = "Protocol Error" + frame.getPayload() == null ? "" : (":" + frame.getPayload());
      WebSocketFrame close = new WebSocketFrame(WebSocketFrame.TYPE_CLOSE, error);
      close.setStatus(frame.getType() == WebSocketFrame.TYPE_INVALID_UTF8 ? (short) 1007 : (short) 1002);
      try {
        Object stateKey = this.getTheConnection().getDeserializerStateKey();
        if (stateKey != null) {
          WebSocketState webSocketState = (WebSocketState) this.getRequiredDeserializer().getState(stateKey);
          if (webSocketState != null) {
            webSocketState.setCloseInitiated(true);
          }
          this.send(MessageBuilder.withPayload(close)
              .copyHeaders(message.getHeaders())
              .build());
        }
      }
      catch (Exception e) {
        throw new MessageHandlingException(message, "Send failed", e);
      }
    }

    @Override
    public void close() {
      connections.remove(this.getTheConnection());
      Object stateKey = this.getTheConnection().getDeserializerStateKey();
      if (stateKey != null) {
        this.getRequiredDeserializer().removeState(stateKey);
      }
      super.close();
    }

    private void doHandshake(WebSocketFrame frame, MessageHeaders messageHeaders) throws Exception {
      try {
        WebSocketFrame handshake = this.getRequiredDeserializer().generateHandshake(frame);
        this.send(MessageBuilder.withPayload(handshake)
            .copyHeaders(messageHeaders)
            .build());
      }
      catch (WebSocketUpgradeException e) {
        this.send(MessageBuilder
            .withPayload(
                new WebSocketFrame(WebSocketFrame.TYPE_DATA, "HTTP/1.1 " +
                    e.getMessage() + e.getHeaders()))
            .copyHeaders(messageHeaders)
            .build());
        this.close();
      }
    }

    private WebSocketSerializer getRequiredDeserializer() {
      Deserializer<?> deserializer = this.getDeserializer();
      Assert.state(deserializer instanceof WebSocketSerializer,
          "Deserializer must be a WebSocketSerializer");
      return (WebSocketSerializer) deserializer;
    }

    public Map<String, String> getAdditionalHeaders() {
      Map<String, String> headers = new HashMap<String, String>();
      WebSocketState state = this.getState(this.getConnectionId());
      if (state.getPath() != null) {
        headers.put(WebSocketHeaders.PATH, state.getPath());
      }
      if (state.getQueryString() != null) {
        headers.put(WebSocketHeaders.QUERY_STRING, state.getQueryString());
      }
      return headers;
    }

    @Override
    public void setTheConnection(TcpConnectionSupport theConnection) {
      connections.put(theConnection, this);
      super.setTheConnection(theConnection);
    }
  }

}
TOP

Related Classes of org.springframework.integration.x.ip.websocket.WebSocketTcpConnectionInterceptorFactory$WebSocketTcpConnectionInterceptor

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.