1   /**
2    * Copyright (c) 2000-2010 Liferay, Inc. All rights reserved.
3    *
4    * The contents of this file are subject to the terms of the Liferay Enterprise
5    * Subscription License ("License"). You may not use this file except in
6    * compliance with the License. You can obtain a copy of the License by
7    * contacting Liferay, Inc. See the License for the specific language governing
8    * permissions and limitations under the License, including but not limited to
9    * distribution rights of the Software.
10   *
11   *
12   *
13   */
14  
15  package com.liferay.portal.dao.shard;
16  
17  import com.liferay.counter.service.persistence.CounterPersistence;
18  import com.liferay.portal.NoSuchCompanyException;
19  import com.liferay.portal.PortalException;
20  import com.liferay.portal.SystemException;
21  import com.liferay.portal.kernel.log.Log;
22  import com.liferay.portal.kernel.log.LogFactoryUtil;
23  import com.liferay.portal.kernel.util.InfrastructureUtil;
24  import com.liferay.portal.kernel.util.InitialThreadLocal;
25  import com.liferay.portal.kernel.util.StringPool;
26  import com.liferay.portal.kernel.util.StringUtil;
27  import com.liferay.portal.model.Company;
28  import com.liferay.portal.model.Portlet;
29  import com.liferay.portal.model.Shard;
30  import com.liferay.portal.security.auth.CompanyThreadLocal;
31  import com.liferay.portal.service.CompanyLocalServiceUtil;
32  import com.liferay.portal.service.ShardLocalServiceUtil;
33  import com.liferay.portal.service.persistence.ClassNamePersistence;
34  import com.liferay.portal.service.persistence.CompanyPersistence;
35  import com.liferay.portal.service.persistence.ReleasePersistence;
36  import com.liferay.portal.service.persistence.ShardPersistence;
37  import com.liferay.portal.util.PropsValues;
38  
39  import java.util.HashMap;
40  import java.util.Map;
41  import java.util.Stack;
42  
43  import javax.sql.DataSource;
44  
45  import org.aspectj.lang.ProceedingJoinPoint;
46  
47  /**
48   * <a href="ShardAdvice.java.html"><b><i>View Source</i></b></a>
49   *
50   * @author Michael Young
51   * @author Alexander Chow
52   */
53  public class ShardAdvice {
54  
55      public void afterPropertiesSet() {
56          if (_shardDataSourceTargetSource == null) {
57              _shardDataSourceTargetSource =
58                  (ShardDataSourceTargetSource)InfrastructureUtil.
59                      getShardDataSourceTargetSource();
60          }
61  
62          if (_shardSessionFactoryTargetSource == null) {
63              _shardSessionFactoryTargetSource =
64                  (ShardSessionFactoryTargetSource)InfrastructureUtil.
65                      getShardSessionFactoryTargetSource();
66          }
67      }
68  
69      public Object invokeByParameter(ProceedingJoinPoint proceedingJoinPoint)
70          throws Throwable {
71  
72          Object[] arguments = proceedingJoinPoint.getArgs();
73  
74          long companyId = (Long)arguments[0];
75  
76          Shard shard = ShardLocalServiceUtil.getShard(
77              Company.class.getName(), companyId);
78  
79          String shardName = shard.getName();
80  
81          if (_log.isInfoEnabled()) {
82              _log.info(
83                  "Service being set to shard " + shardName + " for " +
84                      _getSignature(proceedingJoinPoint));
85          }
86  
87          Object returnValue = null;
88  
89          pushCompanyService(shardName);
90  
91          try {
92              returnValue = proceedingJoinPoint.proceed();
93          }
94          finally {
95              popCompanyService();
96          }
97  
98          return returnValue;
99      }
100 
101     public Object invokeCompanyService(ProceedingJoinPoint proceedingJoinPoint)
102         throws Throwable {
103 
104         String methodName = proceedingJoinPoint.getSignature().getName();
105         Object[] arguments = proceedingJoinPoint.getArgs();
106 
107         String shardName = PropsValues.SHARD_DEFAULT_NAME;
108 
109         if (methodName.equals("addCompany")) {
110             String webId = (String)arguments[0];
111             String virtualHost = (String)arguments[1];
112             String mx = (String)arguments[2];
113             shardName = (String)arguments[3];
114 
115             shardName = _getCompanyShardName(webId, virtualHost, mx, shardName);
116 
117             arguments[3] = shardName;
118         }
119         else if (methodName.equals("checkCompany")) {
120             String webId = (String)arguments[0];
121 
122             if (!webId.equals(PropsValues.COMPANY_DEFAULT_WEB_ID)) {
123                 if (arguments.length == 3) {
124                     String mx = (String)arguments[1];
125                     shardName = (String)arguments[2];
126 
127                     shardName = _getCompanyShardName(
128                         webId, null, mx, shardName);
129 
130                     arguments[2] = shardName;
131                 }
132 
133                 try {
134                     Company company = CompanyLocalServiceUtil.getCompanyByWebId(
135                         webId);
136 
137                     shardName = company.getShardName();
138                 }
139                 catch (NoSuchCompanyException nsce) {
140                 }
141             }
142         }
143         else if (methodName.startsWith("update")) {
144             long companyId = (Long)arguments[0];
145 
146             Shard shard = ShardLocalServiceUtil.getShard(
147                 Company.class.getName(), companyId);
148 
149             shardName = shard.getName();
150         }
151         else {
152             return proceedingJoinPoint.proceed();
153         }
154 
155         if (_log.isInfoEnabled()) {
156             _log.info(
157                 "Company service being set to shard " + shardName + " for " +
158                     _getSignature(proceedingJoinPoint));
159         }
160 
161         Object returnValue = null;
162 
163         pushCompanyService(shardName);
164 
165         try {
166             returnValue = proceedingJoinPoint.proceed(arguments);
167         }
168         finally {
169             popCompanyService();
170         }
171 
172         return returnValue;
173     }
174 
175     /**
176      * Invoke a join point across all shards while ignoring the company service
177      * stack.
178      *
179      * @see #invokeIteratively
180      */
181     public Object invokeGlobally(ProceedingJoinPoint proceedingJoinPoint)
182         throws Throwable {
183 
184         _globalCall.set(new Object());
185 
186         try {
187             if (_log.isInfoEnabled()) {
188                 _log.info(
189                     "All shards invoked for " +
190                         _getSignature(proceedingJoinPoint));
191             }
192 
193             for (String shardName : PropsValues.SHARD_AVAILABLE_NAMES) {
194                 _shardDataSourceTargetSource.setDataSource(shardName);
195                 _shardSessionFactoryTargetSource.setSessionFactory(shardName);
196 
197                 proceedingJoinPoint.proceed();
198             }
199         }
200         finally {
201             _globalCall.set(null);
202         }
203 
204         return null;
205     }
206 
207     /**
208      * Invoke a join point across all shards while using the company service
209      * stack.
210      *
211      * @see #invokeGlobally
212      */
213     public Object invokeIteratively(ProceedingJoinPoint proceedingJoinPoint)
214         throws Throwable {
215 
216         if (_log.isInfoEnabled()) {
217             _log.info(
218                 "Iterating through all shards for " +
219                     _getSignature(proceedingJoinPoint));
220         }
221 
222         for (String shardName : PropsValues.SHARD_AVAILABLE_NAMES) {
223             pushCompanyService(shardName);
224 
225             try {
226                 proceedingJoinPoint.proceed();
227             }
228             finally {
229                 popCompanyService();
230             }
231         }
232 
233         return null;
234     }
235 
236     public Object invokePersistence(ProceedingJoinPoint proceedingJoinPoint)
237         throws Throwable {
238 
239         if ((_shardDataSourceTargetSource == null) ||
240             (_shardSessionFactoryTargetSource == null)) {
241 
242             return proceedingJoinPoint.proceed();
243         }
244 
245         Object target = proceedingJoinPoint.getTarget();
246 
247         if (target instanceof ClassNamePersistence ||
248             target instanceof CompanyPersistence ||
249             target instanceof CounterPersistence ||
250             target instanceof ReleasePersistence ||
251             target instanceof ShardPersistence) {
252 
253             _shardDataSourceTargetSource.setDataSource(
254                 PropsValues.SHARD_DEFAULT_NAME);
255             _shardSessionFactoryTargetSource.setSessionFactory(
256                 PropsValues.SHARD_DEFAULT_NAME);
257 
258             if (_log.isDebugEnabled()) {
259                 _log.debug(
260                     "Using default shard for " +
261                         _getSignature(proceedingJoinPoint));
262             }
263 
264             return proceedingJoinPoint.proceed();
265         }
266 
267         if (_globalCall.get() == null) {
268             _setShardNameByCompany();
269 
270             String shardName = _getShardName();
271 
272             _shardDataSourceTargetSource.setDataSource(shardName);
273             _shardSessionFactoryTargetSource.setSessionFactory(shardName);
274 
275             if (_log.isInfoEnabled()) {
276                 _log.info(
277                     "Using shard name " + shardName + " for " +
278                         _getSignature(proceedingJoinPoint));
279             }
280 
281             return proceedingJoinPoint.proceed();
282         }
283         else {
284             return proceedingJoinPoint.proceed();
285         }
286     }
287 
288     public Object invokePortletService(ProceedingJoinPoint proceedingJoinPoint)
289         throws Throwable {
290 
291         String methodName = proceedingJoinPoint.getSignature().getName();
292         Object[] arguments = proceedingJoinPoint.getArgs();
293 
294         if (arguments.length == 0) {
295             return proceedingJoinPoint.proceed();
296         }
297 
298         Object argument = arguments[0];
299 
300         long companyId = -1;
301 
302         if (argument instanceof Long) {
303             if (methodName.equals("checkPortlets") ||
304                 methodName.equals("clonePortlet") ||
305                 methodName.equals("getPortletById") ||
306                 methodName.equals("getPortletByStrutsPath") ||
307                 methodName.equals("getPortlets") ||
308                 methodName.equals("hasPortlet") ||
309                 methodName.equals("updatePortlet")) {
310 
311                 companyId = (Long)argument;
312             }
313         }
314         else if (argument instanceof Portlet) {
315             if (methodName.equals("checkPortlet") ||
316                 methodName.equals("deployRemotePortlet") ||
317                 methodName.equals("destroyPortlet") ||
318                 methodName.equals("destroyRemotePortlet")) {
319 
320                 Portlet portlet = (Portlet)argument;
321 
322                 companyId = portlet.getCompanyId();
323             }
324         }
325 
326         if (companyId <= 0) {
327             return proceedingJoinPoint.proceed();
328         }
329 
330         if (_log.isInfoEnabled()) {
331             _log.info(
332                 "Company service being set to shard of companyId " +
333                     companyId + " for " + _getSignature(proceedingJoinPoint));
334         }
335 
336         Object returnValue = null;
337 
338         pushCompanyService(companyId);
339 
340         try {
341             returnValue = proceedingJoinPoint.proceed(arguments);
342         }
343         finally {
344             popCompanyService();
345         }
346 
347         return returnValue;
348     }
349 
350     public void setShardDataSourceTargetSource(
351         ShardDataSourceTargetSource shardDataSourceTargetSource) {
352 
353         _shardDataSourceTargetSource = shardDataSourceTargetSource;
354     }
355 
356     public void setShardSessionFactoryTargetSource(
357         ShardSessionFactoryTargetSource shardSessionFactoryTargetSource) {
358 
359         _shardSessionFactoryTargetSource = shardSessionFactoryTargetSource;
360     }
361 
362     protected String getCurrentShardName() {
363         String shardName = _getCompanyServiceStack().peek();
364 
365         if (shardName == null) {
366             shardName = PropsValues.SHARD_DEFAULT_NAME;
367         }
368 
369         return shardName;
370     }
371 
372     protected DataSource getDataSource() {
373         return _shardDataSourceTargetSource.getDataSource();
374     }
375 
376     protected String popCompanyService() {
377         return _getCompanyServiceStack().pop();
378     }
379 
380     protected void pushCompanyService(long companyId) {
381         try {
382             Shard shard = ShardLocalServiceUtil.getShard(
383                 Company.class.getName(), companyId);
384 
385             String shardName = shard.getName();
386 
387             pushCompanyService(shardName);
388         }
389         catch (Exception e) {
390             _log.error(e, e);
391         }
392     }
393 
394     protected void pushCompanyService(String shardName) {
395         _getCompanyServiceStack().push(shardName);
396     }
397 
398     private Stack<String> _getCompanyServiceStack() {
399         Stack<String> companyServiceStack = _companyServiceStack.get();
400 
401         if (companyServiceStack == null) {
402             companyServiceStack = new Stack<String>();
403 
404             _companyServiceStack.set(companyServiceStack);
405         }
406 
407         return companyServiceStack;
408     }
409 
410     private String _getCompanyShardName(
411         String webId, String virtualHost, String mx, String shardName) {
412 
413         Map<String, String> shardParams = new HashMap<String, String>();
414 
415         shardParams.put("webId", webId);
416         shardParams.put("mx", mx);
417 
418         if (virtualHost != null) {
419             shardParams.put("virtualHost", virtualHost);
420         }
421 
422         shardName = _shardSelector.getShardName(
423             ShardSelector.COMPANY_SCOPE, shardName, shardParams);
424 
425         return shardName;
426     }
427 
428     private String _getShardName() {
429         return _shardName.get();
430     }
431 
432     private String _getSignature(ProceedingJoinPoint proceedingJoinPoint) {
433         String methodName = StringUtil.extractLast(
434             proceedingJoinPoint.getTarget().getClass().getName(),
435             StringPool.PERIOD);
436 
437         methodName +=
438             StringPool.PERIOD + proceedingJoinPoint.getSignature().getName() +
439                 "()";
440 
441         return methodName;
442     }
443 
444     private void _setShardName(String shardName) {
445         _shardName.set(shardName);
446     }
447 
448     private void _setShardNameByCompany() throws Throwable {
449         Stack<String> companyServiceStack = _getCompanyServiceStack();
450 
451         if (companyServiceStack.isEmpty()) {
452             long companyId = CompanyThreadLocal.getCompanyId();
453 
454             _setShardNameByCompanyId(companyId);
455         }
456         else {
457             String shardName = companyServiceStack.peek();
458 
459             _setShardName(shardName);
460         }
461     }
462 
463     private void _setShardNameByCompanyId(long companyId)
464         throws PortalException, SystemException {
465 
466         if (companyId == 0) {
467             _setShardName(PropsValues.SHARD_DEFAULT_NAME);
468         }
469         else {
470             Shard shard = ShardLocalServiceUtil.getShard(
471                 Company.class.getName(), companyId);
472 
473             String shardName = shard.getName();
474 
475             _setShardName(shardName);
476         }
477     }
478 
479     private static Log _log = LogFactoryUtil.getLog(ShardAdvice.class);
480 
481     private static ThreadLocal<Stack<String>> _companyServiceStack =
482         new ThreadLocal<Stack<String>>();
483     private static ThreadLocal<Object> _globalCall = new ThreadLocal<Object>();
484     private static ThreadLocal<String> _shardName =
485         new InitialThreadLocal<String>(PropsValues.SHARD_DEFAULT_NAME);
486     private static ShardSelector _shardSelector;
487 
488     private ShardDataSourceTargetSource _shardDataSourceTargetSource;
489     private ShardSessionFactoryTargetSource _shardSessionFactoryTargetSource;
490 
491     static {
492         try {
493             _shardSelector = (ShardSelector)Class.forName(
494                 PropsValues.SHARD_SELECTOR).newInstance();
495         }
496         catch (Exception e) {
497             _log.error(e, e);
498         }
499     }
500 
501 }