package org.xins.common.servlet.container;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.net.URISyntaxException;
import java.net.URL;
import java.net.URLClassLoader;
import java.security.CodeSource;
import java.util.ArrayList;
import java.util.List;
import java.util.StringTokenizer;
import java.util.jar.JarEntry;
import java.util.jar.JarInputStream;
/**
 * Factory for the class loader used to run a servlet packaged in a WAR file.
 *
 * <p>Depending on the selected mode, the created loader searches the WAR file itself,
 * its {@code WEB-INF/classes/} directory, the libraries of its {@code WEB-INF/lib/}
 * directory (unpacked to temporary files), the current class path and/or the XINS
 * libraries. Except for {@link #USE_CURRENT_CLASSPATH}, the created loader looks in
 * its own URLs before delegating to its parent ("child first").
 */
public class ServletClassLoader {

    /**
     * Use the class loader that loaded this class; the WAR file is not used.
     */
    public static final int USE_CURRENT_CLASSPATH = 1;

    /**
     * Use the WAR file and its {@code WEB-INF/classes/} directory plus every entry of
     * the {@code java.class.path} system property.
     */
    public static final int USE_CLASSPATH_LIB = 2;

    /**
     * Use the WAR file and its {@code WEB-INF/classes/} directory plus the jar files
     * located next to {@code xins-common.jar} and in the sibling {@code lib} directory.
     */
    public static final int USE_XINS_LIB = 3;

    /**
     * Use the WAR file, its {@code WEB-INF/classes/} directory and the jar files of its
     * {@code WEB-INF/lib/} directory.
     */
    public static final int USE_WAR_LIB = 4;

    /**
     * Use the WAR file, every entry of {@code java.class.path}, and the jar files of
     * {@code WEB-INF/lib/} whose file name does not match a jar already on the class
     * path. {@code WEB-INF/classes/} is not added in this mode.
     */
    public static final int USE_WAR_EXTERNAL_LIB = 5;

    /** Prefix of the WAR entries holding the web application libraries. */
    private static final String WAR_LIB_PREFIX = "WEB-INF/lib/";

    /**
     * Creates the class loader for the given WAR file and installs it as the context
     * class loader of the current thread.
     *
     * @param warFile
     *    the WAR file containing the servlet, cannot be {@code null}.
     * @param mode
     *    one of the {@code USE_*} constants of this class.
     * @return
     *    the class loader. For {@link #USE_CURRENT_CLASSPATH} this is the loader of this
     *    class, and the context class loader of the thread is left untouched.
     * @throws IOException
     *    if the WAR file cannot be read or one of its libraries cannot be unpacked.
     */
    public static ClassLoader getServletClassLoader(File warFile, int mode) throws IOException {
        if (mode == USE_CURRENT_CLASSPATH) {
            return ServletClassLoader.class.getClassLoader();
        }

        List<URL> urlList = new ArrayList<URL>();
        URL warURL = warFile.toURI().toURL();
        urlList.add(warURL);
        if (mode != USE_WAR_EXTERNAL_LIB) {
            // Derived from the escaped file URL so that characters such as '#' or spaces
            // in the path cannot corrupt (or be parsed as a fragment of) the jar: URL.
            urlList.add(new URL("jar:" + warURL + "!/WEB-INF/classes/"));
        }

        // File names of the jars found on the class path; their WEB-INF/lib copies are skipped.
        List<String> standardLibs = new ArrayList<String>();
        if (mode == USE_XINS_LIB) {
            addXinsLibraries(urlList);
        }
        if (mode == USE_CLASSPATH_LIB || mode == USE_WAR_EXTERNAL_LIB) {
            addClassPathEntries(urlList, standardLibs);
        }
        if (mode == USE_WAR_LIB || mode == USE_WAR_EXTERNAL_LIB) {
            addWarLibraries(warFile, standardLibs, urlList);
        }

        URL[] urls = urlList.toArray(new URL[urlList.size()]);
        ClassLoader loader = new ChildFirstClassLoader(urls, ServletClassLoader.class.getClassLoader());
        Thread.currentThread().setContextClassLoader(loader);
        return loader;
    }

    /**
     * Adds the jar files located in the directory of {@code xins-common.jar} and in its
     * sibling {@code lib} directory. Nothing is added for a location that cannot be
     * determined or a directory that cannot be listed.
     */
    private static void addXinsLibraries(List<URL> urlList) throws IOException {
        File commonJar = locateXinsCommonJar();
        if (commonJar == null) {
            return;
        }
        File baseDir = commonJar.getParentFile();
        addJarFiles(baseDir, urlList);
        addJarFiles(new File(baseDir, ".." + File.separator + "lib"), urlList);
    }

    /**
     * Returns the code source of this class when it is {@code xins-common.jar}, otherwise
     * {@code $XINS_HOME/build/xins-common.jar}.
     *
     * @return the jar location, or {@code null} when neither can be determined.
     */
    private static File locateXinsCommonJar() {
        CodeSource codeSource = ServletClassLoader.class.getProtectionDomain().getCodeSource();
        if (codeSource != null && codeSource.getLocation() != null) {
            try {
                // Proper URL-to-path conversion: decodes escapes and keeps the leading
                // '/' of Unix paths, unlike stripping the "file:/" prefix by hand.
                File location = new File(codeSource.getLocation().toURI());
                if (location.getName().endsWith("xins-common.jar")) {
                    return location;
                }
            } catch (URISyntaxException ignored) {
                // Unusable code source URL; fall back to XINS_HOME below.
            } catch (IllegalArgumentException ignored) {
                // Not a plain file: URI; fall back to XINS_HOME below.
            }
        }
        String xinsHome = System.getenv("XINS_HOME");
        if (xinsHome == null) {
            return null;
        }
        return new File(xinsHome + File.separator + "build" + File.separator + "xins-common.jar");
    }

    /**
     * Adds every file of {@code dir} whose name ends with {@code .jar}. Does nothing when
     * {@code dir} is {@code null}, does not exist or cannot be listed.
     */
    private static void addJarFiles(File dir, List<URL> urlList) throws IOException {
        File[] files = (dir == null) ? null : dir.listFiles();
        if (files == null) {
            return;
        }
        for (File file : files) {
            if (file.getName().endsWith(".jar")) {
                urlList.add(file.toURI().toURL());
            }
        }
    }

    /**
     * Adds every entry of the {@code java.class.path} system property and records the
     * file name of each jar entry in {@code standardLibs}.
     */
    private static void addClassPathEntries(List<URL> urlList, List<String> standardLibs)
            throws IOException {
        String classPath = System.getProperty("java.class.path");
        if (classPath == null) {
            return;
        }
        StringTokenizer stClassPath = new StringTokenizer(classPath, File.pathSeparator);
        while (stClassPath.hasMoreTokens()) {
            File entry = new File(stClassPath.nextToken());
            // File.getName() handles both '/' and the platform separator.
            String entryName = entry.getName();
            if (entryName.toLowerCase().endsWith(".jar")) {
                standardLibs.add(entryName);
            }
            urlList.add(entry.toURI().toURL());
        }
    }

    /**
     * Unpacks the jar files of the WAR's {@code WEB-INF/lib/} directory to temporary files
     * and adds them, skipping those whose name (relative to {@code WEB-INF/lib/}) is listed
     * in {@code excludedLibs}.
     */
    private static void addWarLibraries(File warFile, List<String> excludedLibs, List<URL> urlList)
            throws IOException {
        FileInputStream fileStream = new FileInputStream(warFile);
        try {
            JarInputStream jarStream = new JarInputStream(fileStream);
            try {
                for (JarEntry entry = jarStream.getNextJarEntry(); entry != null;
                        entry = jarStream.getNextJarEntry()) {
                    String entryName = entry.getName();
                    if (entryName.startsWith(WAR_LIB_PREFIX) && entryName.endsWith(".jar")
                            && !excludedLibs.contains(entryName.substring(WAR_LIB_PREFIX.length()))) {
                        urlList.add(unpack(jarStream, entryName).toURI().toURL());
                    }
                }
            } finally {
                jarStream.close();
            }
        } finally {
            // Also releases the file when the JarInputStream constructor fails.
            fileStream.close();
        }
    }

    /**
     * Copies the current entry of {@code jarStream} to a new temporary file, which is
     * scheduled for deletion when the JVM exits.
     *
     * @param jarStream
     *    the stream positioned at the entry to copy.
     * @param entryName
     *    the name of the entry, used to derive the temporary file name.
     * @return the temporary file.
     * @throws IOException if the entry cannot be read or the file cannot be written.
     */
    private static File unpack(JarInputStream jarStream, String entryName) throws IOException {
        String libName = entryName.substring(entryName.lastIndexOf('/') + 1, entryName.length() - 4);
        File tempJarFile = File.createTempFile("tmp_" + libName, ".jar");
        tempJarFile.deleteOnExit();
        FileOutputStream out = new FileOutputStream(tempJarFile);
        try {
            byte[] buf = new byte[8192];
            int len;
            while ((len = jarStream.read(buf)) != -1) {
                out.write(buf, 0, len);
            }
        } finally {
            out.close();
        }
        return tempJarFile;
    }

    /**
     * URL class loader that searches its own URLs before delegating to its parent (the
     * reverse of the standard parent-first delegation) for classes and for
     * {@link #getResource(String)}. {@code getResources} keeps the inherited behavior.
     */
    private static class ChildFirstClassLoader extends URLClassLoader {

        public ChildFirstClassLoader(URL[] urls) {
            super(urls);
        }

        public ChildFirstClassLoader(URL[] urls, ClassLoader parent) {
            super(urls, parent);
        }

        /** Made public so that URLs can be added after construction. */
        @Override
        public void addURL(URL url) {
            super.addURL(url);
        }

        @Override
        public Class<?> loadClass(String name) throws ClassNotFoundException {
            return loadClass(name, false);
        }

        /**
         * Loads a class from this loader's URLs first, then from the parent (or the system
         * class loader when there is no parent).
         *
         * <p>Synchronized like {@code ClassLoader.loadClass} of a non-parallel-capable
         * loader: otherwise two threads loading the same class could both reach
         * {@code findClass}, and the second {@code defineClass} would fail with a
         * {@code LinkageError}.
         */
        @Override
        protected synchronized Class<?> loadClass(String name, boolean resolve)
                throws ClassNotFoundException {
            Class<?> c = findLoadedClass(name);
            if (c == null) {
                try {
                    c = findClass(name);
                } catch (ClassNotFoundException ignored) {
                    // Not available locally; delegate below.
                }
            }
            if (c == null) {
                ClassLoader parent = getParent();
                if (parent != null) {
                    c = parent.loadClass(name);
                } else {
                    c = getSystemClassLoader().loadClass(name);
                }
            }
            if (resolve) {
                resolveClass(c);
            }
            return c;
        }

        /**
         * Finds a resource in this loader's URLs first, then in the parent (or the system
         * class loader when there is no parent, mirroring {@code loadClass}).
         */
        @Override
        public URL getResource(String name) {
            URL url = findResource(name);
            if (url == null) {
                ClassLoader parent = getParent();
                if (parent != null) {
                    url = parent.getResource(name);
                } else {
                    url = getSystemClassLoader().getResource(name);
                }
            }
            return url;
        }
    }
}