1
19
20 package com.liferay.portal.dao.shard;
21
22 import com.liferay.counter.service.persistence.CounterPersistence;
23 import com.liferay.portal.NoSuchCompanyException;
24 import com.liferay.portal.PortalException;
25 import com.liferay.portal.SystemException;
26 import com.liferay.portal.kernel.log.Log;
27 import com.liferay.portal.kernel.log.LogFactoryUtil;
28 import com.liferay.portal.kernel.util.InitialThreadLocal;
29 import com.liferay.portal.kernel.util.StringPool;
30 import com.liferay.portal.kernel.util.StringUtil;
31 import com.liferay.portal.model.Company;
32 import com.liferay.portal.model.Shard;
33 import com.liferay.portal.security.auth.CompanyThreadLocal;
34 import com.liferay.portal.service.CompanyLocalServiceUtil;
35 import com.liferay.portal.service.ShardLocalServiceUtil;
36 import com.liferay.portal.service.persistence.ClassNamePersistence;
37 import com.liferay.portal.service.persistence.CompanyPersistence;
38 import com.liferay.portal.service.persistence.ReleasePersistence;
39 import com.liferay.portal.service.persistence.ShardPersistence;
40 import com.liferay.portal.util.PropsValues;
41
42 import java.util.HashMap;
43 import java.util.Map;
44 import java.util.Stack;
45
46 import javax.sql.DataSource;
47
48 import org.aspectj.lang.ProceedingJoinPoint;
49
50
57 public class ShardAdvice {
58
59 public Object invokeAccountService(ProceedingJoinPoint proceedingJoinPoint)
60 throws Throwable {
61
62 String methodName = proceedingJoinPoint.getSignature().getName();
63 Object[] arguments = proceedingJoinPoint.getArgs();
64
65 String shardName = PropsValues.SHARD_DEFAULT_NAME;
66
67 if (methodName.equals("getAccount") && (arguments.length == 2)) {
68 long companyId = (Long)arguments[0];
69
70 Shard shard = ShardLocalServiceUtil.getShard(
71 Company.class.getName(), companyId);
72
73 shardName = shard.getName();
74 }
75 else {
76 return proceedingJoinPoint.proceed();
77 }
78
79 if (_log.isInfoEnabled()) {
80 _log.info(
81 "Company service being set to shard " + shardName + " for " +
82 _getSignature(proceedingJoinPoint));
83 }
84
85 Object returnValue = null;
86
87 pushCompanyService(shardName);
88
89 try {
90 returnValue = proceedingJoinPoint.proceed();
91 }
92 finally {
93 popCompanyService();
94 }
95
96 return returnValue;
97 }
98
99 public Object invokeCompanyService(ProceedingJoinPoint proceedingJoinPoint)
100 throws Throwable {
101
102 String methodName = proceedingJoinPoint.getSignature().getName();
103 Object[] arguments = proceedingJoinPoint.getArgs();
104
105 String shardName = PropsValues.SHARD_DEFAULT_NAME;
106
107 if (methodName.equals("addCompany")) {
108 String webId = (String)arguments[0];
109 String virtualHost = (String)arguments[1];
110 String mx = (String)arguments[2];
111 shardName = (String)arguments[3];
112
113 shardName = _getCompanyShardName(webId, virtualHost, mx, shardName);
114
115 arguments[3] = shardName;
116 }
117 else if (methodName.equals("checkCompany")) {
118 String webId = (String)arguments[0];
119
120 if (!webId.equals(PropsValues.COMPANY_DEFAULT_WEB_ID)) {
121 if (arguments.length == 3) {
122 String mx = (String)arguments[1];
123 shardName = (String)arguments[2];
124
125 shardName = _getCompanyShardName(
126 webId, null, mx, shardName);
127
128 arguments[2] = shardName;
129 }
130
131 try {
132 Company company = CompanyLocalServiceUtil.getCompanyByWebId(
133 webId);
134
135 shardName = company.getShardName();
136 }
137 catch (NoSuchCompanyException nsce) {
138 }
139 }
140 }
141 else if (methodName.startsWith("update")) {
142 long companyId = (Long)arguments[0];
143
144 Shard shard = ShardLocalServiceUtil.getShard(
145 Company.class.getName(), companyId);
146
147 shardName = shard.getName();
148 }
149 else {
150 return proceedingJoinPoint.proceed();
151 }
152
153 if (_log.isInfoEnabled()) {
154 _log.info(
155 "Company service being set to shard " + shardName + " for " +
156 _getSignature(proceedingJoinPoint));
157 }
158
159 Object returnValue = null;
160
161 pushCompanyService(shardName);
162
163 try {
164 returnValue = proceedingJoinPoint.proceed(arguments);
165 }
166 finally {
167 popCompanyService();
168 }
169
170 return returnValue;
171 }
172
173 public Object invokeGlobally(ProceedingJoinPoint proceedingJoinPoint)
174 throws Throwable {
175
176 _globalCallThreadLocal.set(new Object());
177
178 try {
179 if (_log.isInfoEnabled()) {
180 _log.info(
181 "All shards invoked for " +
182 _getSignature(proceedingJoinPoint));
183 }
184
185 for (String shardName : PropsValues.SHARD_AVAILABLE_NAMES) {
186 _shardDataSourceTargetSource.setDataSource(shardName);
187 _shardSessionFactoryTargetSource.setSessionFactory(shardName);
188
189 proceedingJoinPoint.proceed();
190 }
191 }
192 finally {
193 _globalCallThreadLocal.set(null);
194 }
195
196 return null;
197 }
198
199 public Object invokePersistence(ProceedingJoinPoint proceedingJoinPoint)
200 throws Throwable {
201
202 Object target = proceedingJoinPoint.getTarget();
203
204 if (target instanceof ClassNamePersistence ||
205 target instanceof CompanyPersistence ||
206 target instanceof CounterPersistence ||
207 target instanceof ReleasePersistence ||
208 target instanceof ShardPersistence) {
209
210 _shardDataSourceTargetSource.setDataSource(
211 PropsValues.SHARD_DEFAULT_NAME);
212 _shardSessionFactoryTargetSource.setSessionFactory(
213 PropsValues.SHARD_DEFAULT_NAME);
214
215 if (_log.isDebugEnabled()) {
216 _log.debug(
217 "Using default shard for " +
218 _getSignature(proceedingJoinPoint));
219 }
220
221 return proceedingJoinPoint.proceed();
222 }
223
224 if (_globalCallThreadLocal.get() == null) {
225 _setShardNameByCompany();
226
227 String shardName = _getShardName();
228
229 _shardDataSourceTargetSource.setDataSource(shardName);
230 _shardSessionFactoryTargetSource.setSessionFactory(shardName);
231
232 if (_log.isInfoEnabled()) {
233 _log.info(
234 "Using shard name " + shardName + " for " +
235 _getSignature(proceedingJoinPoint));
236 }
237
238 return proceedingJoinPoint.proceed();
239 }
240 else {
241 return proceedingJoinPoint.proceed();
242 }
243 }
244
245 public Object invokeUserService(ProceedingJoinPoint proceedingJoinPoint)
246 throws Throwable {
247
248 String methodName = proceedingJoinPoint.getSignature().getName();
249 Object[] arguments = proceedingJoinPoint.getArgs();
250
251 String shardName = PropsValues.SHARD_DEFAULT_NAME;
252
253 if (methodName.equals("searchCount")) {
254 long companyId = (Long)arguments[0];
255
256 Shard shard = ShardLocalServiceUtil.getShard(
257 Company.class.getName(), companyId);
258
259 shardName = shard.getName();
260 }
261 else {
262 return proceedingJoinPoint.proceed();
263 }
264
265 if (_log.isInfoEnabled()) {
266 _log.info(
267 "Company service being set to shard " + shardName + " for " +
268 _getSignature(proceedingJoinPoint));
269 }
270
271 Object returnValue = null;
272
273 pushCompanyService(shardName);
274
275 try {
276 returnValue = proceedingJoinPoint.proceed();
277 }
278 finally {
279 popCompanyService();
280 }
281
282 return returnValue;
283 }
284
285 public void setShardDataSourceTargetSource(
286 ShardDataSourceTargetSource shardDataSourceTargetSource) {
287
288 _shardDataSourceTargetSource = shardDataSourceTargetSource;
289 }
290
291 public void setShardSessionFactoryTargetSource(
292 ShardSessionFactoryTargetSource shardSessionFactoryTargetSource) {
293
294 _shardSessionFactoryTargetSource = shardSessionFactoryTargetSource;
295 }
296
297 protected DataSource getDataSource() {
298 return _shardDataSourceTargetSource.getDataSource();
299 }
300
301 protected String popCompanyService() {
302 return _getCompanyServiceStack().pop();
303 }
304
305 protected void pushCompanyService(long companyId) {
306 try {
307 Shard shard = ShardLocalServiceUtil.getShard(
308 Company.class.getName(), companyId);
309
310 String shardName = shard.getName();
311
312 pushCompanyService(shardName);
313 }
314 catch (Exception e) {
315 _log.error(e, e);
316 }
317 }
318
319 protected void pushCompanyService(String shardName) {
320 _getCompanyServiceStack().push(shardName);
321 }
322
323 private Stack<String> _getCompanyServiceStack() {
324 Stack<String> companyServiceStack = _companyServiceStack.get();
325
326 if (companyServiceStack == null) {
327 companyServiceStack = new Stack<String>();
328
329 _companyServiceStack.set(companyServiceStack);
330 }
331
332 return companyServiceStack;
333 }
334
335 private String _getCompanyShardName(
336 String webId, String virtualHost, String mx, String shardName) {
337
338 Map<String, String> shardParams = new HashMap<String, String>();
339
340 shardParams.put("webId", webId);
341 shardParams.put("mx", mx);
342
343 if (virtualHost != null) {
344 shardParams.put("virtualHost", virtualHost);
345 }
346
347 shardName = ShardUtil.getShardSelector().getShardName(
348 ShardUtil.COMPANY_SCOPE, shardName, shardParams);
349
350 return shardName;
351 }
352
353 private String _getShardName() {
354 return _shardNameThreadLocal.get();
355 }
356
357 private String _getSignature(ProceedingJoinPoint proceedingJoinPoint) {
358 String methodName = StringUtil.extractLast(
359 proceedingJoinPoint.getTarget().getClass().getName(),
360 StringPool.PERIOD);
361
362 methodName +=
363 StringPool.PERIOD + proceedingJoinPoint.getSignature().getName() +
364 "()";
365
366 return methodName;
367 }
368
369 private void _setShardName(String shardName) {
370 _shardNameThreadLocal.set(shardName);
371 }
372
373 private void _setShardNameByCompany() throws Throwable {
374 Stack<String> companyServiceStack = _getCompanyServiceStack();
375
376 if (companyServiceStack.isEmpty()) {
377 long companyId = CompanyThreadLocal.getCompanyId();
378
379 _setShardNameByCompanyId(companyId);
380 }
381 else {
382 String shardName = companyServiceStack.peek();
383
384 _setShardName(shardName);
385 }
386 }
387
388 private void _setShardNameByCompanyId(long companyId)
389 throws PortalException, SystemException {
390
391 if (companyId == 0) {
392 _setShardName(PropsValues.SHARD_DEFAULT_NAME);
393 }
394 else {
395 Shard shard = ShardLocalServiceUtil.getShard(
396 Company.class.getName(), companyId);
397
398 String shardName = shard.getName();
399
400 _setShardName(shardName);
401 }
402 }
403
404 private static Log _log = LogFactoryUtil.getLog(ShardAdvice.class);
405
406 private static ThreadLocal<Stack<String>> _companyServiceStack =
407 new ThreadLocal<Stack<String>>();
408 private static ThreadLocal<Object> _globalCallThreadLocal =
409 new ThreadLocal<Object>();
410 private static ThreadLocal<String> _shardNameThreadLocal =
411 new InitialThreadLocal<String>(PropsValues.SHARD_DEFAULT_NAME);
412
413 private ShardDataSourceTargetSource _shardDataSourceTargetSource;
414 private ShardSessionFactoryTargetSource _shardSessionFactoryTargetSource;
415
416 }