1   /**
2    * Copyright (c) 2000-2009 Liferay, Inc. All rights reserved.
3    *
4    *
5    *
6    *
7    * The contents of this file are subject to the terms of the Liferay Enterprise
8    * Subscription License ("License"). You may not use this file except in
9    * compliance with the License. You can obtain a copy of the License by
10   * contacting Liferay, Inc. See the License for the specific language governing
11   * permissions and limitations under the License, including but not limited to
12   * distribution rights of the Software.
13   *
14   * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15   * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16   * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17   * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18   * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19   * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
20   * SOFTWARE.
21   */
22  
23  package com.liferay.portal.dao.shard;
24  
25  import com.liferay.counter.service.persistence.CounterPersistence;
26  import com.liferay.portal.NoSuchCompanyException;
27  import com.liferay.portal.PortalException;
28  import com.liferay.portal.SystemException;
29  import com.liferay.portal.kernel.log.Log;
30  import com.liferay.portal.kernel.log.LogFactoryUtil;
31  import com.liferay.portal.kernel.util.InitialThreadLocal;
32  import com.liferay.portal.kernel.util.StringPool;
33  import com.liferay.portal.kernel.util.StringUtil;
34  import com.liferay.portal.model.Company;
35  import com.liferay.portal.model.Shard;
36  import com.liferay.portal.security.auth.CompanyThreadLocal;
37  import com.liferay.portal.service.CompanyLocalServiceUtil;
38  import com.liferay.portal.service.ShardLocalServiceUtil;
39  import com.liferay.portal.service.persistence.ClassNamePersistence;
40  import com.liferay.portal.service.persistence.CompanyPersistence;
41  import com.liferay.portal.service.persistence.ReleasePersistence;
42  import com.liferay.portal.service.persistence.ShardPersistence;
43  import com.liferay.portal.util.PropsValues;
44  
45  import java.util.HashMap;
46  import java.util.Map;
47  import java.util.Stack;
48  
49  import javax.sql.DataSource;
50  
51  import org.aspectj.lang.ProceedingJoinPoint;
52  
53  /**
54   * <a href="ShardAdvice.java.html"><b><i>View Source</i></b></a>
55   *
56   * @author Michael Young
57   * @author Alexander Chow
58   */
59  public class ShardAdvice {
60  
61      public Object invokeByParameter(ProceedingJoinPoint proceedingJoinPoint)
62          throws Throwable {
63  
64          Object[] arguments = proceedingJoinPoint.getArgs();
65  
66          long companyId = (Long)arguments[0];
67  
68          Shard shard = ShardLocalServiceUtil.getShard(
69              Company.class.getName(), companyId);
70  
71          String shardName = shard.getName();
72  
73          if (_log.isInfoEnabled()) {
74              _log.info(
75                  "Service being set to shard " + shardName + " for " +
76                      _getSignature(proceedingJoinPoint));
77          }
78  
79          Object returnValue = null;
80  
81          pushCompanyService(shardName);
82  
83          try {
84              returnValue = proceedingJoinPoint.proceed();
85          }
86          finally {
87              popCompanyService();
88          }
89  
90          return returnValue;
91      }
92  
93      public Object invokeCompanyService(ProceedingJoinPoint proceedingJoinPoint)
94          throws Throwable {
95  
96          String methodName = proceedingJoinPoint.getSignature().getName();
97          Object[] arguments = proceedingJoinPoint.getArgs();
98  
99          String shardName = PropsValues.SHARD_DEFAULT_NAME;
100 
101         if (methodName.equals("addCompany") && (arguments.length > 3)) {
102             String webId = (String)arguments[0];
103             String virtualHost = (String)arguments[1];
104             String mx = (String)arguments[2];
105             shardName = (String)arguments[3];
106 
107             shardName = _getCompanyShardName(webId, virtualHost, mx, shardName);
108 
109             arguments[3] = shardName;
110         }
111         else if (methodName.equals("checkCompany")) {
112             String webId = (String)arguments[0];
113 
114             if (!webId.equals(PropsValues.COMPANY_DEFAULT_WEB_ID)) {
115                 if (arguments.length == 3) {
116                     String mx = (String)arguments[1];
117                     shardName = (String)arguments[2];
118 
119                     shardName = _getCompanyShardName(
120                         webId, null, mx, shardName);
121 
122                     arguments[2] = shardName;
123                 }
124 
125                 try {
126                     Company company = CompanyLocalServiceUtil.getCompanyByWebId(
127                         webId);
128 
129                     shardName = company.getShardName();
130                 }
131                 catch (NoSuchCompanyException nsce) {
132                 }
133             }
134         }
135         else if (methodName.startsWith("update")) {
136             long companyId = (Long)arguments[0];
137 
138             Shard shard = ShardLocalServiceUtil.getShard(
139                 Company.class.getName(), companyId);
140 
141             shardName = shard.getName();
142         }
143         else {
144             return proceedingJoinPoint.proceed();
145         }
146 
147         if (_log.isInfoEnabled()) {
148             _log.info(
149                 "Company service being set to shard " + shardName + " for " +
150                     _getSignature(proceedingJoinPoint));
151         }
152 
153         Object returnValue = null;
154 
155         pushCompanyService(shardName);
156 
157         try {
158             returnValue = proceedingJoinPoint.proceed(arguments);
159         }
160         finally {
161             popCompanyService();
162         }
163 
164         return returnValue;
165     }
166 
167     public Object invokeGlobally(ProceedingJoinPoint proceedingJoinPoint)
168         throws Throwable {
169 
170         _globalCallThreadLocal.set(new Object());
171 
172         try {
173             if (_log.isInfoEnabled()) {
174                 _log.info(
175                     "All shards invoked for " +
176                         _getSignature(proceedingJoinPoint));
177             }
178 
179             for (String shardName : PropsValues.SHARD_AVAILABLE_NAMES) {
180                 _shardDataSourceTargetSource.setDataSource(shardName);
181                 _shardSessionFactoryTargetSource.setSessionFactory(shardName);
182 
183                 proceedingJoinPoint.proceed();
184             }
185         }
186         finally {
187             _globalCallThreadLocal.set(null);
188         }
189 
190         return null;
191     }
192 
193     public Object invokePersistence(ProceedingJoinPoint proceedingJoinPoint)
194         throws Throwable {
195 
196         Object target = proceedingJoinPoint.getTarget();
197 
198         if (target instanceof ClassNamePersistence ||
199             target instanceof CompanyPersistence ||
200             target instanceof CounterPersistence ||
201             target instanceof ReleasePersistence ||
202             target instanceof ShardPersistence) {
203 
204             _shardDataSourceTargetSource.setDataSource(
205                 PropsValues.SHARD_DEFAULT_NAME);
206             _shardSessionFactoryTargetSource.setSessionFactory(
207                 PropsValues.SHARD_DEFAULT_NAME);
208 
209             if (_log.isDebugEnabled()) {
210                 _log.debug(
211                     "Using default shard for " +
212                         _getSignature(proceedingJoinPoint));
213             }
214 
215             return proceedingJoinPoint.proceed();
216         }
217 
218         if (_globalCallThreadLocal.get() == null) {
219             _setShardNameByCompany();
220 
221             String shardName = _getShardName();
222 
223             _shardDataSourceTargetSource.setDataSource(shardName);
224             _shardSessionFactoryTargetSource.setSessionFactory(shardName);
225 
226             if (_log.isInfoEnabled()) {
227                 _log.info(
228                     "Using shard name " + shardName + " for " +
229                         _getSignature(proceedingJoinPoint));
230             }
231 
232             return proceedingJoinPoint.proceed();
233         }
234         else {
235             return proceedingJoinPoint.proceed();
236         }
237     }
238 
239     public void setShardDataSourceTargetSource(
240         ShardDataSourceTargetSource shardDataSourceTargetSource) {
241 
242         _shardDataSourceTargetSource = shardDataSourceTargetSource;
243     }
244 
245     public void setShardSessionFactoryTargetSource(
246         ShardSessionFactoryTargetSource shardSessionFactoryTargetSource) {
247 
248         _shardSessionFactoryTargetSource = shardSessionFactoryTargetSource;
249     }
250 
251     protected DataSource getDataSource() {
252         return _shardDataSourceTargetSource.getDataSource();
253     }
254 
255     protected String popCompanyService() {
256         return _getCompanyServiceStack().pop();
257     }
258 
259     protected void pushCompanyService(long companyId) {
260         try {
261             Shard shard = ShardLocalServiceUtil.getShard(
262                 Company.class.getName(), companyId);
263 
264             String shardName = shard.getName();
265 
266             pushCompanyService(shardName);
267         }
268         catch (Exception e) {
269             _log.error(e, e);
270         }
271     }
272 
273     protected void pushCompanyService(String shardName) {
274         _getCompanyServiceStack().push(shardName);
275     }
276 
277     private Stack<String> _getCompanyServiceStack() {
278         Stack<String> companyServiceStack = _companyServiceStack.get();
279 
280         if (companyServiceStack == null) {
281             companyServiceStack = new Stack<String>();
282 
283             _companyServiceStack.set(companyServiceStack);
284         }
285 
286         return companyServiceStack;
287     }
288 
289     private String _getCompanyShardName(
290         String webId, String virtualHost, String mx, String shardName) {
291 
292         Map<String, String> shardParams = new HashMap<String, String>();
293 
294         shardParams.put("webId", webId);
295         shardParams.put("mx", mx);
296 
297         if (virtualHost != null) {
298             shardParams.put("virtualHost", virtualHost);
299         }
300 
301         shardName = ShardUtil.getShardSelector().getShardName(
302             ShardUtil.COMPANY_SCOPE, shardName, shardParams);
303 
304         return shardName;
305     }
306 
307     private String _getShardName() {
308         return _shardNameThreadLocal.get();
309     }
310 
311     private String _getSignature(ProceedingJoinPoint proceedingJoinPoint) {
312         String methodName = StringUtil.extractLast(
313             proceedingJoinPoint.getTarget().getClass().getName(),
314             StringPool.PERIOD);
315 
316         methodName +=
317             StringPool.PERIOD + proceedingJoinPoint.getSignature().getName() +
318                 "()";
319 
320         return methodName;
321     }
322 
323     private void _setShardName(String shardName) {
324         _shardNameThreadLocal.set(shardName);
325     }
326 
327     private void _setShardNameByCompany() throws Throwable {
328         Stack<String> companyServiceStack = _getCompanyServiceStack();
329 
330         if (companyServiceStack.isEmpty()) {
331             long companyId = CompanyThreadLocal.getCompanyId();
332 
333             _setShardNameByCompanyId(companyId);
334         }
335         else {
336             String shardName = companyServiceStack.peek();
337 
338             _setShardName(shardName);
339         }
340     }
341 
342     private void _setShardNameByCompanyId(long companyId)
343         throws PortalException, SystemException {
344 
345         if (companyId == 0) {
346             _setShardName(PropsValues.SHARD_DEFAULT_NAME);
347         }
348         else {
349             Shard shard = ShardLocalServiceUtil.getShard(
350                 Company.class.getName(), companyId);
351 
352             String shardName = shard.getName();
353 
354             _setShardName(shardName);
355         }
356     }
357 
358     private static Log _log = LogFactoryUtil.getLog(ShardAdvice.class);
359 
360     private static ThreadLocal<Stack<String>> _companyServiceStack =
361         new ThreadLocal<Stack<String>>();
362     private static ThreadLocal<Object> _globalCallThreadLocal =
363         new ThreadLocal<Object>();
364     private static ThreadLocal<String> _shardNameThreadLocal =
365         new InitialThreadLocal<String>(PropsValues.SHARD_DEFAULT_NAME);
366 
367     private ShardDataSourceTargetSource _shardDataSourceTargetSource;
368     private ShardSessionFactoryTargetSource _shardSessionFactoryTargetSource;
369 
370 }