1   /**
2    * Copyright (c) 2000-2009 Liferay, Inc. All rights reserved.
3    *
4    * The contents of this file are subject to the terms of the Liferay Enterprise
5    * Subscription License ("License"). You may not use this file except in
6    * compliance with the License. You can obtain a copy of the License by
7    * contacting Liferay, Inc. See the License for the specific language governing
8    * permissions and limitations under the License, including but not limited to
9    * distribution rights of the Software.
10   *
11   * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
12   * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
13   * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
14   * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
15   * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
16   * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
17   * SOFTWARE.
18   */
19  
20  package com.liferay.portal.dao.shard;
21  
22  import com.liferay.counter.service.persistence.CounterPersistence;
23  import com.liferay.portal.NoSuchCompanyException;
24  import com.liferay.portal.PortalException;
25  import com.liferay.portal.SystemException;
26  import com.liferay.portal.kernel.log.Log;
27  import com.liferay.portal.kernel.log.LogFactoryUtil;
28  import com.liferay.portal.kernel.util.InitialThreadLocal;
29  import com.liferay.portal.kernel.util.StringPool;
30  import com.liferay.portal.kernel.util.StringUtil;
31  import com.liferay.portal.model.Company;
32  import com.liferay.portal.model.Shard;
33  import com.liferay.portal.security.auth.CompanyThreadLocal;
34  import com.liferay.portal.service.CompanyLocalServiceUtil;
35  import com.liferay.portal.service.ShardLocalServiceUtil;
36  import com.liferay.portal.service.persistence.ClassNamePersistence;
37  import com.liferay.portal.service.persistence.CompanyPersistence;
38  import com.liferay.portal.service.persistence.ReleasePersistence;
39  import com.liferay.portal.service.persistence.ShardPersistence;
40  import com.liferay.portal.util.PropsValues;
41  
42  import java.util.HashMap;
43  import java.util.Map;
44  import java.util.Stack;
45  
46  import javax.sql.DataSource;
47  
48  import org.aspectj.lang.ProceedingJoinPoint;
49  
50  /**
51   * <a href="ShardAdvice.java.html"><b><i>View Source</i></b></a>
52   *
53   * @author Michael Young
54   * @author Alexander Chow
55   *
56   */
57  public class ShardAdvice {
58  
59      public Object invokeAccountService(ProceedingJoinPoint proceedingJoinPoint)
60          throws Throwable {
61  
62          String methodName = proceedingJoinPoint.getSignature().getName();
63          Object[] arguments = proceedingJoinPoint.getArgs();
64  
65          String shardName = PropsValues.SHARD_DEFAULT_NAME;
66  
67          if (methodName.equals("getAccount") && (arguments.length == 2)) {
68              long companyId = (Long)arguments[0];
69  
70              Shard shard = ShardLocalServiceUtil.getShard(
71                  Company.class.getName(), companyId);
72  
73              shardName = shard.getName();
74          }
75          else {
76              return proceedingJoinPoint.proceed();
77          }
78  
79          if (_log.isInfoEnabled()) {
80              _log.info(
81                  "Company service being set to shard " + shardName + " for " +
82                      _getSignature(proceedingJoinPoint));
83          }
84  
85          Object returnValue = null;
86  
87          pushCompanyService(shardName);
88  
89          try {
90              returnValue = proceedingJoinPoint.proceed();
91          }
92          finally {
93              popCompanyService();
94          }
95  
96          return returnValue;
97      }
98  
99      public Object invokeCompanyService(ProceedingJoinPoint proceedingJoinPoint)
100         throws Throwable {
101 
102         String methodName = proceedingJoinPoint.getSignature().getName();
103         Object[] arguments = proceedingJoinPoint.getArgs();
104 
105         String shardName = PropsValues.SHARD_DEFAULT_NAME;
106 
107         if (methodName.equals("addCompany")) {
108             String webId = (String)arguments[0];
109             String virtualHost = (String)arguments[1];
110             String mx = (String)arguments[2];
111             shardName = (String)arguments[3];
112 
113             shardName = _getCompanyShardName(webId, virtualHost, mx, shardName);
114 
115             arguments[3] = shardName;
116         }
117         else if (methodName.equals("checkCompany")) {
118             String webId = (String)arguments[0];
119 
120             if (!webId.equals(PropsValues.COMPANY_DEFAULT_WEB_ID)) {
121                 if (arguments.length == 3) {
122                     String mx = (String)arguments[1];
123                     shardName = (String)arguments[2];
124 
125                     shardName = _getCompanyShardName(
126                         webId, null, mx, shardName);
127 
128                     arguments[2] = shardName;
129                 }
130 
131                 try {
132                     Company company = CompanyLocalServiceUtil.getCompanyByWebId(
133                         webId);
134 
135                     shardName = company.getShardName();
136                 }
137                 catch (NoSuchCompanyException nsce) {
138                 }
139             }
140         }
141         else if (methodName.startsWith("update")) {
142             long companyId = (Long)arguments[0];
143 
144             Shard shard = ShardLocalServiceUtil.getShard(
145                 Company.class.getName(), companyId);
146 
147             shardName = shard.getName();
148         }
149         else {
150             return proceedingJoinPoint.proceed();
151         }
152 
153         if (_log.isInfoEnabled()) {
154             _log.info(
155                 "Company service being set to shard " + shardName + " for " +
156                     _getSignature(proceedingJoinPoint));
157         }
158 
159         Object returnValue = null;
160 
161         pushCompanyService(shardName);
162 
163         try {
164             returnValue = proceedingJoinPoint.proceed(arguments);
165         }
166         finally {
167             popCompanyService();
168         }
169 
170         return returnValue;
171     }
172 
173     public Object invokeGlobally(ProceedingJoinPoint proceedingJoinPoint)
174         throws Throwable {
175 
176         _globalCallThreadLocal.set(new Object());
177 
178         try {
179             if (_log.isInfoEnabled()) {
180                 _log.info(
181                     "All shards invoked for " +
182                         _getSignature(proceedingJoinPoint));
183             }
184 
185             for (String shardName : PropsValues.SHARD_AVAILABLE_NAMES) {
186                 _shardDataSourceTargetSource.setDataSource(shardName);
187                 _shardSessionFactoryTargetSource.setSessionFactory(shardName);
188 
189                 proceedingJoinPoint.proceed();
190             }
191         }
192         finally {
193             _globalCallThreadLocal.set(null);
194         }
195 
196         return null;
197     }
198 
199     public Object invokePersistence(ProceedingJoinPoint proceedingJoinPoint)
200         throws Throwable {
201 
202         Object target = proceedingJoinPoint.getTarget();
203 
204         if (target instanceof ClassNamePersistence ||
205             target instanceof CompanyPersistence ||
206             target instanceof CounterPersistence ||
207             target instanceof ReleasePersistence ||
208             target instanceof ShardPersistence) {
209 
210             _shardDataSourceTargetSource.setDataSource(
211                 PropsValues.SHARD_DEFAULT_NAME);
212             _shardSessionFactoryTargetSource.setSessionFactory(
213                 PropsValues.SHARD_DEFAULT_NAME);
214 
215             if (_log.isDebugEnabled()) {
216                 _log.debug(
217                     "Using default shard for " +
218                         _getSignature(proceedingJoinPoint));
219             }
220 
221             return proceedingJoinPoint.proceed();
222         }
223 
224         if (_globalCallThreadLocal.get() == null) {
225             _setShardNameByCompany();
226 
227             String shardName = _getShardName();
228 
229             _shardDataSourceTargetSource.setDataSource(shardName);
230             _shardSessionFactoryTargetSource.setSessionFactory(shardName);
231 
232             if (_log.isInfoEnabled()) {
233                 _log.info(
234                     "Using shard name " + shardName + " for " +
235                         _getSignature(proceedingJoinPoint));
236             }
237 
238             return proceedingJoinPoint.proceed();
239         }
240         else {
241             return proceedingJoinPoint.proceed();
242         }
243     }
244 
245     public Object invokeUserService(ProceedingJoinPoint proceedingJoinPoint)
246         throws Throwable {
247 
248         String methodName = proceedingJoinPoint.getSignature().getName();
249         Object[] arguments = proceedingJoinPoint.getArgs();
250 
251         String shardName = PropsValues.SHARD_DEFAULT_NAME;
252 
253         if (methodName.equals("searchCount")) {
254             long companyId = (Long)arguments[0];
255 
256             Shard shard = ShardLocalServiceUtil.getShard(
257                 Company.class.getName(), companyId);
258 
259             shardName = shard.getName();
260         }
261         else {
262             return proceedingJoinPoint.proceed();
263         }
264 
265         if (_log.isInfoEnabled()) {
266             _log.info(
267                 "Company service being set to shard " + shardName + " for " +
268                     _getSignature(proceedingJoinPoint));
269         }
270 
271         Object returnValue = null;
272 
273         pushCompanyService(shardName);
274 
275         try {
276             returnValue = proceedingJoinPoint.proceed();
277         }
278         finally {
279             popCompanyService();
280         }
281 
282         return returnValue;
283     }
284 
285     public void setShardDataSourceTargetSource(
286         ShardDataSourceTargetSource shardDataSourceTargetSource) {
287 
288         _shardDataSourceTargetSource = shardDataSourceTargetSource;
289     }
290 
291     public void setShardSessionFactoryTargetSource(
292         ShardSessionFactoryTargetSource shardSessionFactoryTargetSource) {
293 
294         _shardSessionFactoryTargetSource = shardSessionFactoryTargetSource;
295     }
296 
297     protected DataSource getDataSource() {
298         return _shardDataSourceTargetSource.getDataSource();
299     }
300 
301     protected String popCompanyService() {
302         return _getCompanyServiceStack().pop();
303     }
304 
305     protected void pushCompanyService(long companyId) {
306         try {
307             Shard shard = ShardLocalServiceUtil.getShard(
308                 Company.class.getName(), companyId);
309 
310             String shardName = shard.getName();
311 
312             pushCompanyService(shardName);
313         }
314         catch (Exception e) {
315             _log.error(e, e);
316         }
317     }
318 
319     protected void pushCompanyService(String shardName) {
320         _getCompanyServiceStack().push(shardName);
321     }
322 
323     private Stack<String> _getCompanyServiceStack() {
324         Stack<String> companyServiceStack = _companyServiceStack.get();
325 
326         if (companyServiceStack == null) {
327             companyServiceStack = new Stack<String>();
328 
329             _companyServiceStack.set(companyServiceStack);
330         }
331 
332         return companyServiceStack;
333     }
334 
335     private String _getCompanyShardName(
336         String webId, String virtualHost, String mx, String shardName) {
337 
338         Map<String, String> shardParams = new HashMap<String, String>();
339 
340         shardParams.put("webId", webId);
341         shardParams.put("mx", mx);
342 
343         if (virtualHost != null) {
344             shardParams.put("virtualHost", virtualHost);
345         }
346 
347         shardName = ShardUtil.getShardSelector().getShardName(
348             ShardUtil.COMPANY_SCOPE, shardName, shardParams);
349 
350         return shardName;
351     }
352 
353     private String _getShardName() {
354         return _shardNameThreadLocal.get();
355     }
356 
357     private String _getSignature(ProceedingJoinPoint proceedingJoinPoint) {
358         String methodName = StringUtil.extractLast(
359             proceedingJoinPoint.getTarget().getClass().getName(),
360             StringPool.PERIOD);
361 
362         methodName +=
363             StringPool.PERIOD + proceedingJoinPoint.getSignature().getName() +
364                 "()";
365 
366         return methodName;
367     }
368 
369     private void _setShardName(String shardName) {
370         _shardNameThreadLocal.set(shardName);
371     }
372 
373     private void _setShardNameByCompany() throws Throwable {
374         Stack<String> companyServiceStack = _getCompanyServiceStack();
375 
376         if (companyServiceStack.isEmpty()) {
377             long companyId = CompanyThreadLocal.getCompanyId();
378 
379             _setShardNameByCompanyId(companyId);
380         }
381         else {
382             String shardName = companyServiceStack.peek();
383 
384             _setShardName(shardName);
385         }
386     }
387 
388     private void _setShardNameByCompanyId(long companyId)
389         throws PortalException, SystemException {
390 
391         if (companyId == 0) {
392             _setShardName(PropsValues.SHARD_DEFAULT_NAME);
393         }
394         else {
395             Shard shard = ShardLocalServiceUtil.getShard(
396                 Company.class.getName(), companyId);
397 
398             String shardName = shard.getName();
399 
400             _setShardName(shardName);
401         }
402     }
403 
404     private static Log _log = LogFactoryUtil.getLog(ShardAdvice.class);
405 
406     private static ThreadLocal<Stack<String>> _companyServiceStack =
407         new ThreadLocal<Stack<String>>();
408     private static ThreadLocal<Object> _globalCallThreadLocal =
409         new ThreadLocal<Object>();
410     private static ThreadLocal<String> _shardNameThreadLocal =
411         new InitialThreadLocal<String>(PropsValues.SHARD_DEFAULT_NAME);
412 
413     private ShardDataSourceTargetSource _shardDataSourceTargetSource;
414     private ShardSessionFactoryTargetSource _shardSessionFactoryTargetSource;
415 
416 }