/*
* Copyright 2004-2011 H2 Group. Multiple-Licensed under the H2 License,
* Version 1.0, and under the Eclipse Public License, Version 1.0
* (http://h2database.com/html/license.html).
* Initial Developer: H2 Group
*/
package org.h2.test.utils;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.HashMap;
import java.util.TreeMap;
import java.util.TreeSet;
import org.h2.util.New;
import org.h2.util.SourceCompiler;
/**
 * A code generator for class proxies.
 * <p>
 * The generator writes Java source for a class that forwards every
 * overridable method to an {@link InvocationHandler}, compiles it with the
 * {@link SourceCompiler}, and caches the resulting class per proxied class.
 */
public class ProxyCodeGenerator {

    /** The package the generated proxy classes are placed in. */
    private static final String PROXY_PACKAGE = "bytecode";

    /** The suffix appended to the simple class name to build the proxy name. */
    private static final String PROXY_SUFFIX = "Proxy";

    private static final SourceCompiler compiler = new SourceCompiler();
    private static final HashMap<Class<?>, Class<?>> proxyMap = New.hashMap();

    /** The fully qualified names to import, sorted. */
    private final TreeSet<String> imports = new TreeSet<String>();

    /** The methods to generate, keyed by signature (name and parameter types). */
    private final TreeMap<String, Method> methods = new TreeMap<String, Method>();

    private String packageName;
    private String className;
    private Class<?> extendsClass;
    private Constructor<?> constructor;

    /**
     * Check whether there is already a proxy class generated.
     *
     * @param c the class
     * @return true if yes
     */
    public static boolean isGenerated(Class<?> c) {
        return proxyMap.containsKey(c);
    }

    /**
     * Generate a proxy class. The returned class extends the given class.
     * The result is cached, so the source is generated and compiled only
     * once per class.
     *
     * @param c the class to extend
     * @return the proxy class
     * @throws ClassNotFoundException if the compiler could not provide the
     *             generated class
     */
    public static Class<?> getClassProxy(Class<?> c) throws ClassNotFoundException {
        Class<?> cached = proxyMap.get(c);
        if (cached != null) {
            return cached;
        }
        // TODO how to extend a class with private constructor
        // TODO call right constructor
        // TODO use the right package
        ProxyCodeGenerator cg = new ProxyCodeGenerator();
        cg.setPackageName(PROXY_PACKAGE);
        cg.generateClassProxy(c);
        StringWriter sw = new StringWriter();
        cg.write(new PrintWriter(sw));
        String code = sw.toString();
        // use the name the generator actually declared, so that the compiled
        // name and the generated source can not get out of sync
        String proxy = PROXY_PACKAGE + "." + cg.className;
        compiler.setSource(proxy, code);
        Class<?> px = compiler.getClass(proxy);
        proxyMap.put(c, px);
        return px;
    }

    private void setPackageName(String packageName) {
        this.packageName = packageName;
    }

    /**
     * Generate a class that implements all static methods of the given class,
     * but as non-static.
     *
     * @param clazz the class whose non-private static methods are proxied
     */
    void generateStaticProxy(Class<?> clazz) {
        reset();
        addImport(InvocationHandler.class);
        addImport(Method.class);
        addImport(clazz);
        // the simple name is used because a dotted name (nested class) is not
        // a valid class declaration
        className = clazz.getSimpleName() + PROXY_SUFFIX;
        for (Method m : clazz.getDeclaredMethods()) {
            int mod = m.getModifiers();
            if (Modifier.isStatic(mod) && !Modifier.isPrivate(mod)) {
                addMethod(m);
            }
        }
    }

    /**
     * Collect everything needed to write a proxy that extends the given
     * class: the overridable methods of the class and all its superclasses,
     * and the non-private constructor with the fewest parameters.
     *
     * @param clazz the class to extend
     */
    private void generateClassProxy(Class<?> clazz) {
        reset();
        addImport(InvocationHandler.class);
        addImport(Method.class);
        addImport(clazz);
        className = clazz.getSimpleName() + PROXY_SUFFIX;
        extendsClass = clazz;
        int doNotOverride = Modifier.FINAL | Modifier.STATIC |
                Modifier.PRIVATE | Modifier.ABSTRACT;
        // signatures that were declared final further down the hierarchy;
        // the superclass versions of those must not be overridden either
        TreeSet<String> finalSignatures = new TreeSet<String>();
        for (Class<?> dc = clazz; dc != null; dc = dc.getSuperclass()) {
            addImport(dc);
            for (Method m : dc.getDeclaredMethods()) {
                if (m.isBridge()) {
                    // compiler generated; the real method is handled instead
                    continue;
                }
                String signature = getMethodName(m);
                if (finalSignatures.contains(signature)) {
                    continue;
                }
                int mod = m.getModifiers();
                if ((mod & doNotOverride) == 0) {
                    addMethod(m);
                } else if (Modifier.isFinal(mod) && !Modifier.isStatic(mod)
                        && !Modifier.isPrivate(mod)) {
                    finalSignatures.add(signature);
                }
            }
        }
        for (Constructor<?> candidate : clazz.getDeclaredConstructors()) {
            if (Modifier.isPrivate(candidate.getModifiers())) {
                continue;
            }
            if (constructor == null || candidate.getParameterTypes().length <
                    constructor.getParameterTypes().length) {
                constructor = candidate;
            }
        }
    }

    /**
     * Forget everything collected by a previous generate call.
     */
    private void reset() {
        imports.clear();
        methods.clear();
        extendsClass = null;
        constructor = null;
    }

    /**
     * Register a method to be generated, together with the imports for its
     * return, parameter, and exception types. A method whose signature is
     * already registered is ignored.
     *
     * @param m the method
     */
    private void addMethod(Method m) {
        String signature = getMethodName(m);
        if (methods.containsKey(signature)) {
            // already declared in a subclass
            return;
        }
        addImport(m.getReturnType());
        for (Class<?> c : m.getParameterTypes()) {
            addImport(c);
        }
        for (Class<?> c : m.getExceptionTypes()) {
            addImport(c);
        }
        methods.put(signature, m);
    }

    /**
     * Build the key that identifies a method: its name followed by the names
     * of the parameter types. The return type is intentionally not part of
     * the key, so that a method with a covariant return type is recognized
     * as overriding the superclass method.
     *
     * @param m the method
     * @return the signature key
     */
    private static String getMethodName(Method m) {
        StringBuilder buff = new StringBuilder(m.getName());
        for (Class<?> p : m.getParameterTypes()) {
            buff.append(' ').append(p.getName());
        }
        return buff.toString();
    }

    /**
     * Add the import needed to reference the given type in the generated
     * source. Arrays are reduced to their element type; primitive types,
     * types in java.lang, and types without a package need no import.
     *
     * @param c the type
     */
    private void addImport(Class<?> c) {
        while (c.isArray()) {
            c = c.getComponentType();
        }
        if (c.isPrimitive()) {
            return;
        }
        // nested types are written as Outer.Inner in the generated source,
        // so the outermost class is the one that needs to be imported
        Class<?> outer = c;
        while (outer.getEnclosingClass() != null) {
            outer = outer.getEnclosingClass();
        }
        String name = outer.getName();
        int dot = name.lastIndexOf('.');
        if (dot < 0) {
            // no package: can not be imported
            return;
        }
        if (!"java.lang".equals(name.substring(0, dot))) {
            imports.add(name);
        }
    }

    private static String getClassName(Class<?> c) {
        return getClassName(c, false);
    }

    /**
     * Get the name used to reference a type in the generated source: the
     * simple name, prefixed by the simple names of all enclosing classes.
     *
     * @param c the type
     * @param varArg if the type is the array type of a vararg parameter, in
     *            which case the last array dimension is written as "..."
     * @return the name
     */
    private static String getClassName(Class<?> c, boolean varArg) {
        if (varArg) {
            c = c.getComponentType();
        }
        // array classes have no enclosing class, so the dimensions are
        // stripped first and appended again after the name is resolved
        StringBuilder dimensions = new StringBuilder();
        while (c.isArray()) {
            dimensions.append("[]");
            c = c.getComponentType();
        }
        String s = c.getSimpleName();
        for (Class<?> outer = c.getEnclosingClass(); outer != null;
                outer = outer.getEnclosingClass()) {
            s = outer.getSimpleName() + "." + s;
        }
        s += dimensions;
        if (varArg) {
            return s + "...";
        }
        return s;
    }

    /**
     * Write the source code of the proxy class.
     *
     * @param writer the target; it is flushed but not closed
     */
    private void write(PrintWriter writer) {
        if (packageName != null) {
            writer.println("package " + packageName + ";");
        }
        for (String imp : imports) {
            writer.println("import " + imp + ";");
        }
        writer.print("public class " + className);
        if (extendsClass != null) {
            writer.print(" extends " + getClassName(extendsClass));
        }
        writer.println(" {");
        writeConstructors(writer);
        writeConvertException(writer);
        for (Method m : methods.values()) {
            writeMethod(writer, m);
        }
        writer.println("}");
        writer.flush();
    }

    /**
     * Write the handler field and the two constructors of the proxy class.
     *
     * @param writer the target
     */
    private void writeConstructors(PrintWriter writer) {
        writer.println(" private final InvocationHandler ih;");
        // NOTE(review): the default handler invokes the method on the proxy
        // itself; as the generated overrides forward to the handler, this
        // looks like it re-enters the override - confirm the no-arg
        // constructor is only meant as a placeholder
        writer.println(" public " + className + "() {");
        writer.println(" this(new InvocationHandler() {");
        writer.println(" public Object invoke(Object proxy,");
        writer.println(" Method method, Object[] args) throws Throwable {");
        writer.println(" return method.invoke(proxy, args);");
        writer.println(" }});");
        writer.println(" }");
        writer.println(" public " + className + "(InvocationHandler ih) {");
        if (constructor != null) {
            // call the selected super constructor with default values
            writer.print(" super(");
            int i = 0;
            for (Class<?> p : constructor.getParameterTypes()) {
                if (i > 0) {
                    writer.print(", ");
                }
                writer.print(getDefaultValue(p));
                i++;
            }
            writer.println(");");
        }
        writer.println(" this.ih = ih;");
        writer.println(" }");
    }

    /**
     * Write the helper method the generated methods use to re-throw any
     * exception raised by the handler.
     *
     * @param writer the target
     */
    private static void writeConvertException(PrintWriter writer) {
        writer.println(" @SuppressWarnings(\"unchecked\")");
        writer.println(" private static <T extends RuntimeException> T convertException(Throwable e) {");
        writer.println(" if (e instanceof Error) {");
        writer.println(" throw (Error) e;");
        writer.println(" }");
        writer.println(" return (T) e;");
        writer.println(" }");
    }

    /**
     * Write one proxy method that forwards the call to the handler.
     *
     * @param writer the target
     * @param m the method to forward
     */
    private static void writeMethod(PrintWriter writer, Method m) {
        Class<?> retClass = m.getReturnType();
        Class<?>[] pc = m.getParameterTypes();
        writer.print(" ");
        if (Modifier.isProtected(m.getModifiers())) {
            // 'public' would also work
            writer.print("protected ");
        } else {
            writer.print("public ");
        }
        writer.print(getClassName(retClass) +
                " " + m.getName() + "(");
        for (int i = 0; i < pc.length; i++) {
            if (i > 0) {
                writer.print(", ");
            }
            boolean varArg = i == pc.length - 1 && m.isVarArgs();
            writer.print(getClassName(pc[i], varArg) + " p" + i);
        }
        writer.print(")");
        writer.print(" throws RuntimeException");
        for (Class<?> e : m.getExceptionTypes()) {
            writer.print(", ");
            writer.print(getClassName(e));
        }
        writer.println(" {");
        writer.println(" try {");
        writer.print(" ");
        if (retClass != void.class) {
            writer.print("return (");
            writer.print(getCastType(retClass));
            writer.print(") ");
        }
        writer.print("ih.invoke(this, ");
        writer.println(getClassName(m.getDeclaringClass()) +
                ".class.getDeclaredMethod(\"" + m.getName() +
                "\",");
        writer.print(" new Class[] {");
        for (int i = 0; i < pc.length; i++) {
            if (i > 0) {
                writer.print(", ");
            }
            writer.print(getClassName(pc[i]) + ".class");
        }
        writer.println("}),");
        writer.print(" new Object[] {");
        for (int i = 0; i < pc.length; i++) {
            if (i > 0) {
                writer.print(", ");
            }
            writer.print("p" + i);
        }
        writer.println("});");
        writer.println(" } catch (Throwable e) {");
        writer.println(" throw convertException(e);");
        writer.println(" }");
        writer.println(" }");
    }

    /**
     * Get the source code literal used as a placeholder argument of the
     * given type when calling the super constructor.
     *
     * @param p the parameter type
     * @return the literal: "null" for reference types, the zero value for
     *         primitive types
     */
    private static String getDefaultValue(Class<?> p) {
        if (!p.isPrimitive()) {
            return "null";
        }
        if (p == boolean.class) {
            return "false";
        } else if (p == byte.class) {
            return "(byte) 0";
        } else if (p == char.class) {
            return "(char) 0";
        } else if (p == short.class) {
            return "(short) 0";
        } else if (p == int.class) {
            return "0";
        } else if (p == long.class) {
            return "0L";
        } else if (p == float.class) {
            return "0F";
        } else if (p == double.class) {
            return "0D";
        }
        // void is the only other primitive, and not a valid parameter type
        return "";
    }

    /**
     * Get the type name the Object returned by the handler is cast to.
     * For primitive return types this is the wrapper class.
     *
     * @param retClass the (non-void) return type of the method
     * @return the type name to cast to
     */
    private static String getCastType(Class<?> retClass) {
        if (retClass == boolean.class) {
            return "Boolean";
        } else if (retClass == byte.class) {
            return "Byte";
        } else if (retClass == char.class) {
            return "Character";
        } else if (retClass == short.class) {
            return "Short";
        } else if (retClass == int.class) {
            return "Integer";
        } else if (retClass == long.class) {
            return "Long";
        } else if (retClass == float.class) {
            return "Float";
        } else if (retClass == double.class) {
            return "Double";
        }
        return getClassName(retClass);
    }

    /**
     * Format a method call as "name(arg1, arg2, ...)", for example for
     * logging.
     *
     * @param m the method
     * @param args the arguments; may be null (as passed to an invocation
     *            handler for a method without parameters), which is
     *            formatted like an empty argument list
     * @return the formatted call
     */
    public static String methodCallFormatter(Method m, Object... args) {
        StringBuilder buff = new StringBuilder();
        buff.append(m.getName()).append('(');
        if (args != null) {
            for (int i = 0; i < args.length; i++) {
                if (i > 0) {
                    buff.append(", ");
                }
                // append(Object) writes "null" for a null element
                buff.append(args[i]);
            }
        }
        buff.append(")");
        return buff.toString();
    }
}