001    /**
002     * Copyright (c) 2000-2010 Liferay, Inc. All rights reserved.
003     *
004     * The contents of this file are subject to the terms of the Liferay Enterprise
005     * Subscription License ("License"). You may not use this file except in
006     * compliance with the License. You can obtain a copy of the License by
007     * contacting Liferay, Inc. See the License for the specific language governing
008     * permissions and limitations under the License, including but not limited to
009     * distribution rights of the Software.
010     *
011     *
012     *
013     */
014    
015    package com.liferay.portal.kernel.util;
016    
017    import com.liferay.portal.kernel.log.Log;
018    import com.liferay.portal.kernel.log.LogFactoryUtil;
019    
020    import java.lang.ref.WeakReference;
021    import java.lang.reflect.InvocationTargetException;
022    import java.lang.reflect.Method;
023    
024    import java.util.ArrayList;
025    import java.util.Collection;
026    import java.util.Iterator;
027    import java.util.List;
028    
029    /**
030     * @author Brian Wing Shun Chan
031     * @author Michael C. Han
032     * @author Shuyang Zhou
033     */
034    public class AggregateClassLoader extends ClassLoader {
035    
036            public static ClassLoader getAggregateClassLoader(
037                    ClassLoader parentClassLoader, ClassLoader[] classLoaders) {
038    
039                    if ((classLoaders == null) || (classLoaders.length == 0)) {
040                            return null;
041                    }
042    
043                    if (classLoaders.length == 1) {
044                            return classLoaders[0];
045                    }
046    
047                    AggregateClassLoader aggregateClassLoader = new AggregateClassLoader(
048                            parentClassLoader);
049    
050                    for (ClassLoader classLoader : classLoaders) {
051                            aggregateClassLoader.addClassLoader(classLoader);
052                    }
053    
054                    return aggregateClassLoader;
055            }
056    
057            public static ClassLoader getAggregateClassLoader(
058                    ClassLoader[] classLoaders) {
059    
060                    if ((classLoaders == null) || (classLoaders.length == 0)) {
061                            return null;
062                    }
063    
064                    return getAggregateClassLoader(classLoaders[0], classLoaders);
065            }
066    
067            public AggregateClassLoader(ClassLoader classLoader) {
068                    _parentClassLoaderReference = new WeakReference<ClassLoader>(
069                            classLoader);
070            }
071    
072            public void addClassLoader(ClassLoader classLoader) {
073                    if (getClassLoaders().contains(classLoader)) {
074                            return;
075                    }
076    
077                    if ((classLoader instanceof AggregateClassLoader) &&
078                            (classLoader.getParent().equals(getParent()))){
079    
080                            AggregateClassLoader aggregateClassLoader =
081                                    (AggregateClassLoader)classLoader;
082    
083                            for (ClassLoader curClassLoader :
084                                            aggregateClassLoader.getClassLoaders()) {
085    
086                                    addClassLoader(curClassLoader);
087                            }
088                    }
089                    else {
090                            _classLoaderReferences.add(
091                                    new WeakReference<ClassLoader>(classLoader));
092                    }
093            }
094    
095            public void addClassLoader(ClassLoader... classLoaders) {
096                    for (ClassLoader classLoader : classLoaders) {
097                            addClassLoader(classLoader);
098                    }
099            }
100    
101            public void addClassLoader(Collection<ClassLoader> classLoaders) {
102                    for (ClassLoader classLoader : classLoaders) {
103                            addClassLoader(classLoader);
104                    }
105            }
106    
107            public boolean equals(Object obj) {
108                    if (this == obj) {
109                            return true;
110                    }
111    
112                    if (!(obj instanceof AggregateClassLoader)) {
113                            return false;
114                    }
115    
116                    AggregateClassLoader aggregateClassLoader = (AggregateClassLoader)obj;
117    
118                    if (_classLoaderReferences.equals(
119                                    aggregateClassLoader._classLoaderReferences) &&
120                            (((getParent() == null) &&
121                              (aggregateClassLoader.getParent() == null)) ||
122                             ((getParent() != null) &&
123                              (getParent().equals(aggregateClassLoader.getParent()))))) {
124    
125                            return true;
126                    }
127    
128                    return false;
129            }
130    
131            public List<ClassLoader> getClassLoaders() {
132                    List<ClassLoader> classLoaders = new ArrayList<ClassLoader>(
133                            _classLoaderReferences.size());
134    
135                    Iterator<WeakReference<ClassLoader>> itr =
136                            _classLoaderReferences.iterator();
137    
138                    while (itr.hasNext()) {
139                            WeakReference<ClassLoader> weakReference = itr.next();
140    
141                            ClassLoader classLoader = weakReference.get();
142    
143                            if (classLoader == null) {
144                                    itr.remove();
145                            }
146                            else {
147                                    classLoaders.add(classLoader);
148                            }
149                    }
150    
151                    return classLoaders;
152            }
153    
154            public int hashCode() {
155                    if (_classLoaderReferences != null) {
156                            return _classLoaderReferences.hashCode();
157                    }
158                    else {
159                            return 0;
160                    }
161            }
162    
163            protected Class<?> findClass(String name) throws ClassNotFoundException {
164                    for (ClassLoader classLoader : getClassLoaders()) {
165                            try {
166                                    return _findClass(classLoader, name);
167                            }
168                            catch (ClassNotFoundException cnfe) {
169                            }
170                    }
171    
172                    throw new ClassNotFoundException("Unable to find class " + name);
173            }
174    
175            protected Class<?> loadClass(String name, boolean resolve)
176                    throws ClassNotFoundException {
177    
178                    Class<?> loadedClass = null;
179    
180                    for (ClassLoader classLoader : getClassLoaders()) {
181                            try {
182                                    loadedClass = _loadClass(classLoader, name, resolve);
183    
184                                    break;
185                            }
186                            catch (ClassNotFoundException cnfe) {
187                            }
188                    }
189    
190                    if (loadedClass == null) {
191                            ClassLoader parentClassLoader = _parentClassLoaderReference.get();
192    
193                            if (parentClassLoader == null) {
194                                    throw new ClassNotFoundException(
195                                            "Parent class loader has been garbage collected");
196                            }
197    
198                            loadedClass = _loadClass(parentClassLoader, name, resolve);
199                    }
200                    else if (resolve) {
201                            resolveClass(loadedClass);
202                    }
203    
204                    return loadedClass;
205            }
206    
207            private static Log _log = LogFactoryUtil.getLog(AggregateClassLoader.class);
208    
209            private List<WeakReference<ClassLoader>> _classLoaderReferences =
210                    new ArrayList<WeakReference<ClassLoader>>();
211    
212            private static Class<?> _findClass(ClassLoader classLoader, String name)
213                    throws ClassNotFoundException {
214    
215                    try {
216                            return (Class<?>) _findClassMethod.invoke(classLoader, name);
217                    }
218                    catch (InvocationTargetException ite) {
219                            throw new ClassNotFoundException(
220                                    "Unable to find class " + name, ite.getTargetException());
221                    }
222                    catch (Exception e) {
223                            throw new ClassNotFoundException("Unable to find class " + name, e);
224                    }
225            }
226    
227            private static Class<?> _loadClass(
228                            ClassLoader classLoader, String name, boolean resolve)
229                    throws ClassNotFoundException {
230    
231                    try {
232                            return (Class<?>) _loadClassMethod.invoke(
233                                    classLoader, name, resolve);
234                    }
235                    catch (InvocationTargetException ite) {
236                            throw new ClassNotFoundException(
237                                    "Unable to load class " + name, ite.getTargetException());
238                    }
239                    catch (Exception e) {
240                            throw new ClassNotFoundException(
241                                    "Unable to load class " + name, e);
242                    }
243            }
244    
245            private static Method _findClassMethod;
246            private static Method _loadClassMethod;
247    
248            private WeakReference<ClassLoader> _parentClassLoaderReference;
249    
250            static {
251                    try {
252                            _findClassMethod = ReflectionUtil.getDeclaredMethod(
253                                    ClassLoader.class, "findClass", String.class);
254                            _loadClassMethod = ReflectionUtil.getDeclaredMethod(
255                                    ClassLoader.class, "loadClass", String.class, boolean.class);
256                    }
257                    catch (Exception e) {
258                            if (_log.isErrorEnabled()) {
259                                    _log.error("Unable to locate required methods", e);
260                            }
261                    }
262            }
263    
264    }