package restx.security;
import com.google.common.base.Optional;
import com.google.common.base.Predicate;
import com.google.common.base.Predicates;
import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableCollection;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import restx.RestxRequest;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import static com.google.common.base.Preconditions.checkNotNull;
/**
* User: xavierhanin
* Date: 2/8/13
* Time: 1:29 PM
*/
public class StdCORSAuthorizer implements CORSAuthorizer {
private static final Logger logger = LoggerFactory.getLogger(StdCORSAuthorizer.class);
public static class Builder {
private Predicate<CharSequence> originMatcher = Predicates.alwaysTrue();
private Predicate<CharSequence> pathMatcher = Predicates.alwaysTrue();
private ImmutableCollection<String> allowedMethods = ImmutableSet.of("GET");
private ImmutableCollection<String> allowedHeaders = ImmutableSet.of();
private Optional<Boolean> allowCredentials = Optional.absent();
private int maxAge = 1728000;
public Builder setOriginMatcher(Predicate<CharSequence> originMatcher) {
this.originMatcher = originMatcher;
return this;
}
public Builder setPathMatcher(Predicate<CharSequence> pathMatcher) {
this.pathMatcher = pathMatcher;
return this;
}
public Builder setAllowedMethods(ImmutableCollection<String> allowedMethods) {
this.allowedMethods = allowedMethods;
return this;
}
public Builder setAllowedHeaders(ImmutableCollection<String> allowedHeaders) {
this.allowedHeaders = allowedHeaders;
return this;
}
public Builder setAllowCredentials(Optional<Boolean> allowCredentials) {
this.allowCredentials = allowCredentials;
return this;
}
public Builder setMaxAge(final int maxAge) {
this.maxAge = maxAge;
return this;
}
public StdCORSAuthorizer build() {
return new StdCORSAuthorizer(originMatcher, pathMatcher, allowedMethods, allowedHeaders, allowCredentials, maxAge);
}
}
public static Builder builder() {
return new Builder();
}
private final Predicate<CharSequence> originMatcher;
private final Predicate<CharSequence> pathMatcher;
private final ImmutableCollection<String> allowedMethods;
private final ImmutableCollection<String> allowedHeaders;
private final Optional<Boolean> allowCredentials;
private final int maxAge;
public StdCORSAuthorizer(Predicate<CharSequence> originMatcher, Predicate<CharSequence> pathMatcher,
ImmutableCollection<String> allowedMethods, ImmutableCollection<String> allowedHeaders,
Optional<Boolean> allowCredentials, int maxAge) {
this.maxAge = maxAge;
this.originMatcher = checkNotNull(originMatcher);
this.pathMatcher = checkNotNull(pathMatcher);
this.allowedMethods = checkNotNull(allowedMethods);
this.allowedHeaders = checkNotNull(toLowerCase(allowedHeaders));
this.allowCredentials = checkNotNull(allowCredentials);
}
private ImmutableCollection<String> toLowerCase(ImmutableCollection<String> strings) {
ImmutableList.Builder<String> builder = ImmutableList.builder();
for (String string : strings) {
builder.add(string.toLowerCase(Locale.ENGLISH));
}
return builder.build();
}
@Override
public Optional<CORS> checkCORS(RestxRequest request, String origin, String method, String restxPath) {
if (originMatcher.apply(origin) && pathMatcher.apply(restxPath)) {
if (!Iterables.contains(allowedMethods, method)) {
logger.debug("CORS request not accepted by {}: method not allowed {}\n" +
"REQUEST => {}", this, method, request);
return Optional.of(CORS.reject());
}
if (!checkAllowed(request, "Access-Control-Request-Methods", allowedMethods)) {
return Optional.of(CORS.reject());
}
if (!checkAllowed(request, "Access-Control-Request-Headers", allowedHeaders)) {
return Optional.of(CORS.reject());
}
return Optional.of(CORS.accept(origin, allowedMethods, allowedHeaders, allowCredentials, maxAge));
}
return Optional.absent();
}
private boolean checkAllowed(RestxRequest request, String headerName, ImmutableCollection<String> allowed) {
Optional<String> requestProperty = request.getHeader(headerName);
if (requestProperty.isPresent()) {
for (String s : Splitter.on(',').trimResults().split(requestProperty.get())) {
if (!allowed.contains(s.toLowerCase(Locale.ENGLISH))) {
logger.debug("CORS request not accepted by {}: {} not allowed: {}\nREQUEST => {}", this, headerName, s, request);
return false;
}
}
}
return true;
}
@Override
public String toString() {
return "StdCORSAuthorizer{" +
"originMatcher=" + originMatcher +
", pathMatcher=" + pathMatcher +
", allowedMethods=" + allowedMethods +
", allowedHeaders=" + allowedHeaders +
", allowCredentials=" + allowCredentials +
'}';
}
}