1   /**
2    * Copyright (c) 2000-2010 Liferay, Inc. All rights reserved.
3    *
4    * This library is free software; you can redistribute it and/or modify it under
5    * the terms of the GNU Lesser General Public License as published by the Free
6    * Software Foundation; either version 2.1 of the License, or (at your option)
7    * any later version.
8    *
9    * This library is distributed in the hope that it will be useful, but WITHOUT
10   * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
11   * FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
12   * details.
13   */
14  
15  package com.liferay.portal.dao.shard;
16  
17  import com.liferay.counter.service.persistence.CounterPersistence;
18  import com.liferay.portal.NoSuchCompanyException;
19  import com.liferay.portal.kernel.exception.PortalException;
20  import com.liferay.portal.kernel.exception.SystemException;
21  import com.liferay.portal.kernel.log.Log;
22  import com.liferay.portal.kernel.log.LogFactoryUtil;
23  import com.liferay.portal.kernel.util.InfrastructureUtil;
24  import com.liferay.portal.kernel.util.InitialThreadLocal;
25  import com.liferay.portal.kernel.util.StringPool;
26  import com.liferay.portal.kernel.util.StringUtil;
27  import com.liferay.portal.model.Company;
28  import com.liferay.portal.model.Shard;
29  import com.liferay.portal.security.auth.CompanyThreadLocal;
30  import com.liferay.portal.service.CompanyLocalServiceUtil;
31  import com.liferay.portal.service.ShardLocalServiceUtil;
32  import com.liferay.portal.service.persistence.ClassNamePersistence;
33  import com.liferay.portal.service.persistence.CompanyPersistence;
34  import com.liferay.portal.service.persistence.ReleasePersistence;
35  import com.liferay.portal.service.persistence.ShardPersistence;
36  import com.liferay.portal.util.PropsValues;
37  
38  import java.util.HashMap;
39  import java.util.Map;
40  import java.util.Stack;
41  
42  import javax.sql.DataSource;
43  
44  import org.aspectj.lang.ProceedingJoinPoint;
45  
46  /**
47   * <a href="ShardAdvice.java.html"><b><i>View Source</i></b></a>
48   *
49   * @author Michael Young
50   * @author Alexander Chow
51   */
52  public class ShardAdvice {
53  
54      public void afterPropertiesSet() {
55          if (_shardDataSourceTargetSource == null) {
56              _shardDataSourceTargetSource =
57                  (ShardDataSourceTargetSource)InfrastructureUtil.
58                      getShardDataSourceTargetSource();
59          }
60  
61          if (_shardSessionFactoryTargetSource == null) {
62              _shardSessionFactoryTargetSource =
63                  (ShardSessionFactoryTargetSource)InfrastructureUtil.
64                      getShardSessionFactoryTargetSource();
65          }
66      }
67  
68      public Object invokeByParameter(ProceedingJoinPoint proceedingJoinPoint)
69          throws Throwable {
70  
71          Object[] arguments = proceedingJoinPoint.getArgs();
72  
73          long companyId = (Long)arguments[0];
74  
75          Shard shard = ShardLocalServiceUtil.getShard(
76              Company.class.getName(), companyId);
77  
78          String shardName = shard.getName();
79  
80          if (_log.isInfoEnabled()) {
81              _log.info(
82                  "Service being set to shard " + shardName + " for " +
83                      _getSignature(proceedingJoinPoint));
84          }
85  
86          Object returnValue = null;
87  
88          pushCompanyService(shardName);
89  
90          try {
91              returnValue = proceedingJoinPoint.proceed();
92          }
93          finally {
94              popCompanyService();
95          }
96  
97          return returnValue;
98      }
99  
100     public Object invokeCompanyService(ProceedingJoinPoint proceedingJoinPoint)
101         throws Throwable {
102 
103         String methodName = proceedingJoinPoint.getSignature().getName();
104         Object[] arguments = proceedingJoinPoint.getArgs();
105 
106         String shardName = PropsValues.SHARD_DEFAULT_NAME;
107 
108         if (methodName.equals("addCompany")) {
109             String webId = (String)arguments[0];
110             String virtualHost = (String)arguments[1];
111             String mx = (String)arguments[2];
112             shardName = (String)arguments[3];
113 
114             shardName = _getCompanyShardName(webId, virtualHost, mx, shardName);
115 
116             arguments[3] = shardName;
117         }
118         else if (methodName.equals("checkCompany")) {
119             String webId = (String)arguments[0];
120 
121             if (!webId.equals(PropsValues.COMPANY_DEFAULT_WEB_ID)) {
122                 if (arguments.length == 3) {
123                     String mx = (String)arguments[1];
124                     shardName = (String)arguments[2];
125 
126                     shardName = _getCompanyShardName(
127                         webId, null, mx, shardName);
128 
129                     arguments[2] = shardName;
130                 }
131 
132                 try {
133                     Company company = CompanyLocalServiceUtil.getCompanyByWebId(
134                         webId);
135 
136                     shardName = company.getShardName();
137                 }
138                 catch (NoSuchCompanyException nsce) {
139                 }
140             }
141         }
142         else if (methodName.startsWith("update")) {
143             long companyId = (Long)arguments[0];
144 
145             Shard shard = ShardLocalServiceUtil.getShard(
146                 Company.class.getName(), companyId);
147 
148             shardName = shard.getName();
149         }
150         else {
151             return proceedingJoinPoint.proceed();
152         }
153 
154         if (_log.isInfoEnabled()) {
155             _log.info(
156                 "Company service being set to shard " + shardName + " for " +
157                     _getSignature(proceedingJoinPoint));
158         }
159 
160         Object returnValue = null;
161 
162         pushCompanyService(shardName);
163 
164         try {
165             returnValue = proceedingJoinPoint.proceed(arguments);
166         }
167         finally {
168             popCompanyService();
169         }
170 
171         return returnValue;
172     }
173 
174     public Object invokeGlobally(ProceedingJoinPoint proceedingJoinPoint)
175         throws Throwable {
176 
177         _globalCall.set(new Object());
178 
179         try {
180             if (_log.isInfoEnabled()) {
181                 _log.info(
182                     "All shards invoked for " +
183                         _getSignature(proceedingJoinPoint));
184             }
185 
186             for (String shardName : PropsValues.SHARD_AVAILABLE_NAMES) {
187                 _shardDataSourceTargetSource.setDataSource(shardName);
188                 _shardSessionFactoryTargetSource.setSessionFactory(shardName);
189 
190                 proceedingJoinPoint.proceed();
191             }
192         }
193         finally {
194             _globalCall.set(null);
195         }
196 
197         return null;
198     }
199 
200     public Object invokePersistence(ProceedingJoinPoint proceedingJoinPoint)
201         throws Throwable {
202 
203         if ((_shardDataSourceTargetSource == null) ||
204             (_shardSessionFactoryTargetSource == null)) {
205 
206             return proceedingJoinPoint.proceed();
207         }
208 
209         Object target = proceedingJoinPoint.getTarget();
210 
211         if (target instanceof ClassNamePersistence ||
212             target instanceof CompanyPersistence ||
213             target instanceof CounterPersistence ||
214             target instanceof ReleasePersistence ||
215             target instanceof ShardPersistence) {
216 
217             _shardDataSourceTargetSource.setDataSource(
218                 PropsValues.SHARD_DEFAULT_NAME);
219             _shardSessionFactoryTargetSource.setSessionFactory(
220                 PropsValues.SHARD_DEFAULT_NAME);
221 
222             if (_log.isDebugEnabled()) {
223                 _log.debug(
224                     "Using default shard for " +
225                         _getSignature(proceedingJoinPoint));
226             }
227 
228             return proceedingJoinPoint.proceed();
229         }
230 
231         if (_globalCall.get() == null) {
232             _setShardNameByCompany();
233 
234             String shardName = _getShardName();
235 
236             _shardDataSourceTargetSource.setDataSource(shardName);
237             _shardSessionFactoryTargetSource.setSessionFactory(shardName);
238 
239             if (_log.isInfoEnabled()) {
240                 _log.info(
241                     "Using shard name " + shardName + " for " +
242                         _getSignature(proceedingJoinPoint));
243             }
244 
245             return proceedingJoinPoint.proceed();
246         }
247         else {
248             return proceedingJoinPoint.proceed();
249         }
250     }
251 
252     public void setShardDataSourceTargetSource(
253         ShardDataSourceTargetSource shardDataSourceTargetSource) {
254 
255         _shardDataSourceTargetSource = shardDataSourceTargetSource;
256     }
257 
258     public void setShardSessionFactoryTargetSource(
259         ShardSessionFactoryTargetSource shardSessionFactoryTargetSource) {
260 
261         _shardSessionFactoryTargetSource = shardSessionFactoryTargetSource;
262     }
263 
264     protected DataSource getDataSource() {
265         return _shardDataSourceTargetSource.getDataSource();
266     }
267 
268     protected String popCompanyService() {
269         return _getCompanyServiceStack().pop();
270     }
271 
272     protected void pushCompanyService(long companyId) {
273         try {
274             Shard shard = ShardLocalServiceUtil.getShard(
275                 Company.class.getName(), companyId);
276 
277             String shardName = shard.getName();
278 
279             pushCompanyService(shardName);
280         }
281         catch (Exception e) {
282             _log.error(e, e);
283         }
284     }
285 
286     protected void pushCompanyService(String shardName) {
287         _getCompanyServiceStack().push(shardName);
288     }
289 
290     private Stack<String> _getCompanyServiceStack() {
291         Stack<String> companyServiceStack = _companyServiceStack.get();
292 
293         if (companyServiceStack == null) {
294             companyServiceStack = new Stack<String>();
295 
296             _companyServiceStack.set(companyServiceStack);
297         }
298 
299         return companyServiceStack;
300     }
301 
302     private String _getCompanyShardName(
303         String webId, String virtualHost, String mx, String shardName) {
304 
305         Map<String, String> shardParams = new HashMap<String, String>();
306 
307         shardParams.put("webId", webId);
308         shardParams.put("mx", mx);
309 
310         if (virtualHost != null) {
311             shardParams.put("virtualHost", virtualHost);
312         }
313 
314         shardName = ShardUtil.getShardSelector().getShardName(
315             ShardUtil.COMPANY_SCOPE, shardName, shardParams);
316 
317         return shardName;
318     }
319 
320     private String _getShardName() {
321         return _shardName.get();
322     }
323 
324     private String _getSignature(ProceedingJoinPoint proceedingJoinPoint) {
325         String methodName = StringUtil.extractLast(
326             proceedingJoinPoint.getTarget().getClass().getName(),
327             StringPool.PERIOD);
328 
329         methodName +=
330             StringPool.PERIOD + proceedingJoinPoint.getSignature().getName() +
331                 "()";
332 
333         return methodName;
334     }
335 
336     private void _setShardName(String shardName) {
337         _shardName.set(shardName);
338     }
339 
340     private void _setShardNameByCompany() throws Throwable {
341         Stack<String> companyServiceStack = _getCompanyServiceStack();
342 
343         if (companyServiceStack.isEmpty()) {
344             long companyId = CompanyThreadLocal.getCompanyId();
345 
346             _setShardNameByCompanyId(companyId);
347         }
348         else {
349             String shardName = companyServiceStack.peek();
350 
351             _setShardName(shardName);
352         }
353     }
354 
355     private void _setShardNameByCompanyId(long companyId)
356         throws PortalException, SystemException {
357 
358         if (companyId == 0) {
359             _setShardName(PropsValues.SHARD_DEFAULT_NAME);
360         }
361         else {
362             Shard shard = ShardLocalServiceUtil.getShard(
363                 Company.class.getName(), companyId);
364 
365             String shardName = shard.getName();
366 
367             _setShardName(shardName);
368         }
369     }
370 
371     private static Log _log = LogFactoryUtil.getLog(ShardAdvice.class);
372 
373     private static ThreadLocal<Stack<String>> _companyServiceStack =
374         new ThreadLocal<Stack<String>>();
375     private static ThreadLocal<Object> _globalCall = new ThreadLocal<Object>();
376     private static ThreadLocal<String> _shardName =
377         new InitialThreadLocal<String>(PropsValues.SHARD_DEFAULT_NAME);
378 
379     private ShardDataSourceTargetSource _shardDataSourceTargetSource;
380     private ShardSessionFactoryTargetSource _shardSessionFactoryTargetSource;
381 
382 }