Package org.jruby.rack

Source Code of org.jruby.rack.RubyFirstRackFilter$ResponseCapture

package org.jruby.rack;

import com.strobecorp.kirk.RewindableInputStream;
import org.jruby.rack.servlet.ServletRackEnvironment;
import org.jruby.rack.servlet.ServletRackResponseEnvironment;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
import java.io.*;
import java.net.URLDecoder;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;

/**
* This filter implementation is nearly the same as
* org.jruby.rack.RackFilter, except the order of execution is flipped.
* Instead of invoking servlets first and falling back to Ruby, in this
* filter Ruby is given a chance to handle the request first, and only if an
* error is returned does the filter continue to the rest of the
* application. This allows us to incrementally implement the application in
* Ruby while still taking advantage of existing servlets and JSPs.
*
* A version of this filter will probably appear in a future version of
* JRuby-Rack.
*
* @author nicksieger
*/
public class RubyFirstRackFilter implements Filter {
    private RackContext context;
    private RackDispatcher dispatcher;

    public RubyFirstRackFilter() {
    }

    /** Construct a new dispatcher with the servlet context */
    public void init(FilterConfig config) throws ServletException {
        this.context = (RackContext) config.getServletContext().getAttribute(RackApplicationFactory.RACK_CONTEXT);
        this.dispatcher = new DefaultRackDispatcher(this.context);
    }

    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
        throws IOException, ServletException {
        RequestCapture          reqCapture   = new RequestCapture((HttpServletRequest) request);
        ResponseCapture         respCapture  = new ResponseCapture((HttpServletResponse) response);
        RackEnvironment         env          = new ServletRackEnvironment(reqCapture, context);
        RackResponseEnvironment responseEnv  = new ServletRackResponseEnvironment(respCapture);

        request.setAttribute(RackEnvironment.DYNAMIC_REQS_ONLY, Boolean.TRUE);
        dispatcher.process(env, responseEnv);

        if (respCapture.isError()) {
            reqCapture.reset();
            response.reset();
            chain.doFilter(reqCapture, response);
        }
    }

    public void destroy() {
    }

    private static class RequestCapture extends HttpServletRequestWrapper {
        private RewindableInputStream inputStream;
        private Map<String,String[]> requestParams;
        public RequestCapture(HttpServletRequest request) {
            super(request);
        }

        @Override public BufferedReader getReader() throws IOException {
            if (inputStream != null) {
                String enc = getCharacterEncoding();
                if (enc == null) {
                    enc = "UTF-8";
                }
                return new BufferedReader(new InputStreamReader(inputStream, enc));
            } else {
                return super.getReader();
            }
        }

        @Override public ServletInputStream getInputStream() throws IOException {
            if (inputStream == null) {
                inputStream = new RewindableInputStream(super.getInputStream());
            }
            return new ServletInputStream() {
                @Override
                public long skip(long l) throws IOException {
                    return inputStream.skip(l);
                }

                @Override
                public int available() throws IOException {
                    return inputStream.available();
                }

                @Override
                public void close() throws IOException {
                    inputStream.close();
                }

                @Override
                public void mark(int i) {
                    inputStream.mark(i);
                }

                @Override
                public void reset() throws IOException {
                    inputStream.reset();
                }

                @Override
                public boolean markSupported() {
                    return inputStream.markSupported();
                }

                @Override
                public int read(byte[] bytes) throws IOException {
                    return inputStream.read(bytes);
                }

                @Override
                public int read(byte[] bytes, int i, int i1) throws IOException {
                    return inputStream.read(bytes, i, i1);
                }

                @Override
                public int read() throws IOException {
                    return inputStream.read();
                }
            };
        }

        @Override
        public String getParameter(String name) {
            if (getReParsedParameterMap() != null) {
                String[] values = requestParams.get(name);
                if (values != null) {
                    return values[0];
                }
                return null;
            } else {
                return super.getParameter(name);
            }
        }

        @Override
        public Map getParameterMap() {
            if (getReParsedParameterMap() != null) {
                return requestParams;
            } else {
                return super.getParameterMap();
            }
        }

        @Override
        public Enumeration getParameterNames() {
            if (getReParsedParameterMap() != null) {
                return new Enumeration() {
                    Iterator keys = requestParams.keySet().iterator();
                    public boolean hasMoreElements() {
                        return keys.hasNext();
                    }

                    public Object nextElement() {
                        return keys.next();
                    }
                };
            } else {
                return super.getParameterNames();
            }
        }

        @Override
        public String[] getParameterValues(String name) {
            if (getReParsedParameterMap() != null) {
                return requestParams.get(name);
            } else {
                return super.getParameterValues(name);
            }
        }

        private Map getReParsedParameterMap() {
            if (requestParams != null) {
                return requestParams;
            }
            if (inputStream == null || getContentType() == null ||
                    !getContentType().equals("application/x-www-form-urlencoded")) {
                return null;
            }
            // Need to re-parse form params from the request
            // All this because you can't mix use of request#getParameter
            // and request#getInputStream in the Servlet API.
            requestParams = new HashMap<String,String[]>();
            String line = "";
            try {
                line = getReader().readLine();
            } catch (IOException e) {
            }
            String[] pairs = line.split("\\&");
            for (int i = 0; i < pairs.length; i++) {
                try {
                    String[] fields = pairs[i].split("=", 2);
                    String key = URLDecoder.decode(fields[0], "UTF-8");
                    String value = null;
                    if (fields.length == 2) {
                        value = URLDecoder.decode(fields[1], "UTF-8");
                    }
                    if (value != null) {
                        String[] newValues;
                        if (requestParams.containsKey(key)) {
                            String[] values = requestParams.get(key);
                            newValues = new String[values.length + 1];
                            System.arraycopy(values, 0, newValues, 0, values.length);
                            newValues[values.length] = value;
                        } else {
                            newValues = new String[1];
                            newValues[0] = value;
                        }
                        requestParams.put(key, newValues);
                    }
                } catch (UnsupportedEncodingException e) {
                }
            }
            return requestParams;
        }

        public void reset() throws IOException {
            if (inputStream != null) {
                inputStream.rewind();
            }
        }
    }

    private static class ResponseCapture extends HttpServletResponseWrapper {
        private int status = 200;

        public ResponseCapture(HttpServletResponse response) {
            super(response);
        }

        @Override public void sendError(int status, String message) throws IOException {
            this.status = status;
        }

        @Override public void sendError(int status) throws IOException {
            this.status = status;
        }

        @Override public void sendRedirect(String path) throws IOException {
            this.status = 302;
            super.sendRedirect(path);
        }

        @Override public void setStatus(int status) {
            this.status = status;
            if (!isError()) {
                super.setStatus(status);
            }
        }

        @Override public void setStatus(int status, String message) {
            this.status = status;
            if (!isError()) {
                super.setStatus(status, message);
            }
        }

        @Override public void flushBuffer() throws IOException {
            if (!isError()) {
                super.flushBuffer();
            }
        }

        @Override public ServletOutputStream getOutputStream() throws IOException {
            if (isError()) {
                // swallow output, because we're going to discard it
                return new ServletOutputStream() {
                    @Override public void write(int b) throws IOException {
                    }
                };
            } else {
                return super.getOutputStream();
            }
        }

        @Override
        public PrintWriter getWriter() throws IOException {
            if (isError()) {
                // swallow output, because we're going to discard it
                return new PrintWriter(new OutputStream() {
                    @Override public void write(int i) throws IOException {
                    }
                });
            } else {
                return super.getWriter();
            }
        }

        private boolean isError() {
            return status >= 400;
        }
    }
}
TOP

Related Classes of org.jruby.rack.RubyFirstRackFilter$ResponseCapture

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.