1
14
15 package com.liferay.portal.dao.shard;
16
17 import com.liferay.counter.service.persistence.CounterPersistence;
18 import com.liferay.portal.NoSuchCompanyException;
19 import com.liferay.portal.kernel.exception.PortalException;
20 import com.liferay.portal.kernel.exception.SystemException;
21 import com.liferay.portal.kernel.log.Log;
22 import com.liferay.portal.kernel.log.LogFactoryUtil;
23 import com.liferay.portal.kernel.util.InfrastructureUtil;
24 import com.liferay.portal.kernel.util.InitialThreadLocal;
25 import com.liferay.portal.kernel.util.StringPool;
26 import com.liferay.portal.kernel.util.StringUtil;
27 import com.liferay.portal.model.Company;
28 import com.liferay.portal.model.Shard;
29 import com.liferay.portal.security.auth.CompanyThreadLocal;
30 import com.liferay.portal.service.CompanyLocalServiceUtil;
31 import com.liferay.portal.service.ShardLocalServiceUtil;
32 import com.liferay.portal.service.persistence.ClassNamePersistence;
33 import com.liferay.portal.service.persistence.CompanyPersistence;
34 import com.liferay.portal.service.persistence.ReleasePersistence;
35 import com.liferay.portal.service.persistence.ShardPersistence;
36 import com.liferay.portal.util.PropsValues;
37
38 import java.util.HashMap;
39 import java.util.Map;
40 import java.util.Stack;
41
42 import javax.sql.DataSource;
43
44 import org.aspectj.lang.ProceedingJoinPoint;
45
46
52 public class ShardAdvice {
53
54 public void afterPropertiesSet() {
55 if (_shardDataSourceTargetSource == null) {
56 _shardDataSourceTargetSource =
57 (ShardDataSourceTargetSource)InfrastructureUtil.
58 getShardDataSourceTargetSource();
59 }
60
61 if (_shardSessionFactoryTargetSource == null) {
62 _shardSessionFactoryTargetSource =
63 (ShardSessionFactoryTargetSource)InfrastructureUtil.
64 getShardSessionFactoryTargetSource();
65 }
66 }
67
68 public Object invokeByParameter(ProceedingJoinPoint proceedingJoinPoint)
69 throws Throwable {
70
71 Object[] arguments = proceedingJoinPoint.getArgs();
72
73 long companyId = (Long)arguments[0];
74
75 Shard shard = ShardLocalServiceUtil.getShard(
76 Company.class.getName(), companyId);
77
78 String shardName = shard.getName();
79
80 if (_log.isInfoEnabled()) {
81 _log.info(
82 "Service being set to shard " + shardName + " for " +
83 _getSignature(proceedingJoinPoint));
84 }
85
86 Object returnValue = null;
87
88 pushCompanyService(shardName);
89
90 try {
91 returnValue = proceedingJoinPoint.proceed();
92 }
93 finally {
94 popCompanyService();
95 }
96
97 return returnValue;
98 }
99
100 public Object invokeCompanyService(ProceedingJoinPoint proceedingJoinPoint)
101 throws Throwable {
102
103 String methodName = proceedingJoinPoint.getSignature().getName();
104 Object[] arguments = proceedingJoinPoint.getArgs();
105
106 String shardName = PropsValues.SHARD_DEFAULT_NAME;
107
108 if (methodName.equals("addCompany")) {
109 String webId = (String)arguments[0];
110 String virtualHost = (String)arguments[1];
111 String mx = (String)arguments[2];
112 shardName = (String)arguments[3];
113
114 shardName = _getCompanyShardName(webId, virtualHost, mx, shardName);
115
116 arguments[3] = shardName;
117 }
118 else if (methodName.equals("checkCompany")) {
119 String webId = (String)arguments[0];
120
121 if (!webId.equals(PropsValues.COMPANY_DEFAULT_WEB_ID)) {
122 if (arguments.length == 3) {
123 String mx = (String)arguments[1];
124 shardName = (String)arguments[2];
125
126 shardName = _getCompanyShardName(
127 webId, null, mx, shardName);
128
129 arguments[2] = shardName;
130 }
131
132 try {
133 Company company = CompanyLocalServiceUtil.getCompanyByWebId(
134 webId);
135
136 shardName = company.getShardName();
137 }
138 catch (NoSuchCompanyException nsce) {
139 }
140 }
141 }
142 else if (methodName.startsWith("update")) {
143 long companyId = (Long)arguments[0];
144
145 Shard shard = ShardLocalServiceUtil.getShard(
146 Company.class.getName(), companyId);
147
148 shardName = shard.getName();
149 }
150 else {
151 return proceedingJoinPoint.proceed();
152 }
153
154 if (_log.isInfoEnabled()) {
155 _log.info(
156 "Company service being set to shard " + shardName + " for " +
157 _getSignature(proceedingJoinPoint));
158 }
159
160 Object returnValue = null;
161
162 pushCompanyService(shardName);
163
164 try {
165 returnValue = proceedingJoinPoint.proceed(arguments);
166 }
167 finally {
168 popCompanyService();
169 }
170
171 return returnValue;
172 }
173
174 public Object invokeGlobally(ProceedingJoinPoint proceedingJoinPoint)
175 throws Throwable {
176
177 _globalCall.set(new Object());
178
179 try {
180 if (_log.isInfoEnabled()) {
181 _log.info(
182 "All shards invoked for " +
183 _getSignature(proceedingJoinPoint));
184 }
185
186 for (String shardName : PropsValues.SHARD_AVAILABLE_NAMES) {
187 _shardDataSourceTargetSource.setDataSource(shardName);
188 _shardSessionFactoryTargetSource.setSessionFactory(shardName);
189
190 proceedingJoinPoint.proceed();
191 }
192 }
193 finally {
194 _globalCall.set(null);
195 }
196
197 return null;
198 }
199
200 public Object invokePersistence(ProceedingJoinPoint proceedingJoinPoint)
201 throws Throwable {
202
203 if ((_shardDataSourceTargetSource == null) ||
204 (_shardSessionFactoryTargetSource == null)) {
205
206 return proceedingJoinPoint.proceed();
207 }
208
209 Object target = proceedingJoinPoint.getTarget();
210
211 if (target instanceof ClassNamePersistence ||
212 target instanceof CompanyPersistence ||
213 target instanceof CounterPersistence ||
214 target instanceof ReleasePersistence ||
215 target instanceof ShardPersistence) {
216
217 _shardDataSourceTargetSource.setDataSource(
218 PropsValues.SHARD_DEFAULT_NAME);
219 _shardSessionFactoryTargetSource.setSessionFactory(
220 PropsValues.SHARD_DEFAULT_NAME);
221
222 if (_log.isDebugEnabled()) {
223 _log.debug(
224 "Using default shard for " +
225 _getSignature(proceedingJoinPoint));
226 }
227
228 return proceedingJoinPoint.proceed();
229 }
230
231 if (_globalCall.get() == null) {
232 _setShardNameByCompany();
233
234 String shardName = _getShardName();
235
236 _shardDataSourceTargetSource.setDataSource(shardName);
237 _shardSessionFactoryTargetSource.setSessionFactory(shardName);
238
239 if (_log.isInfoEnabled()) {
240 _log.info(
241 "Using shard name " + shardName + " for " +
242 _getSignature(proceedingJoinPoint));
243 }
244
245 return proceedingJoinPoint.proceed();
246 }
247 else {
248 return proceedingJoinPoint.proceed();
249 }
250 }
251
252 public void setShardDataSourceTargetSource(
253 ShardDataSourceTargetSource shardDataSourceTargetSource) {
254
255 _shardDataSourceTargetSource = shardDataSourceTargetSource;
256 }
257
258 public void setShardSessionFactoryTargetSource(
259 ShardSessionFactoryTargetSource shardSessionFactoryTargetSource) {
260
261 _shardSessionFactoryTargetSource = shardSessionFactoryTargetSource;
262 }
263
264 protected DataSource getDataSource() {
265 return _shardDataSourceTargetSource.getDataSource();
266 }
267
268 protected String popCompanyService() {
269 return _getCompanyServiceStack().pop();
270 }
271
272 protected void pushCompanyService(long companyId) {
273 try {
274 Shard shard = ShardLocalServiceUtil.getShard(
275 Company.class.getName(), companyId);
276
277 String shardName = shard.getName();
278
279 pushCompanyService(shardName);
280 }
281 catch (Exception e) {
282 _log.error(e, e);
283 }
284 }
285
286 protected void pushCompanyService(String shardName) {
287 _getCompanyServiceStack().push(shardName);
288 }
289
290 private Stack<String> _getCompanyServiceStack() {
291 Stack<String> companyServiceStack = _companyServiceStack.get();
292
293 if (companyServiceStack == null) {
294 companyServiceStack = new Stack<String>();
295
296 _companyServiceStack.set(companyServiceStack);
297 }
298
299 return companyServiceStack;
300 }
301
302 private String _getCompanyShardName(
303 String webId, String virtualHost, String mx, String shardName) {
304
305 Map<String, String> shardParams = new HashMap<String, String>();
306
307 shardParams.put("webId", webId);
308 shardParams.put("mx", mx);
309
310 if (virtualHost != null) {
311 shardParams.put("virtualHost", virtualHost);
312 }
313
314 shardName = ShardUtil.getShardSelector().getShardName(
315 ShardUtil.COMPANY_SCOPE, shardName, shardParams);
316
317 return shardName;
318 }
319
320 private String _getShardName() {
321 return _shardName.get();
322 }
323
324 private String _getSignature(ProceedingJoinPoint proceedingJoinPoint) {
325 String methodName = StringUtil.extractLast(
326 proceedingJoinPoint.getTarget().getClass().getName(),
327 StringPool.PERIOD);
328
329 methodName +=
330 StringPool.PERIOD + proceedingJoinPoint.getSignature().getName() +
331 "()";
332
333 return methodName;
334 }
335
336 private void _setShardName(String shardName) {
337 _shardName.set(shardName);
338 }
339
340 private void _setShardNameByCompany() throws Throwable {
341 Stack<String> companyServiceStack = _getCompanyServiceStack();
342
343 if (companyServiceStack.isEmpty()) {
344 long companyId = CompanyThreadLocal.getCompanyId();
345
346 _setShardNameByCompanyId(companyId);
347 }
348 else {
349 String shardName = companyServiceStack.peek();
350
351 _setShardName(shardName);
352 }
353 }
354
355 private void _setShardNameByCompanyId(long companyId)
356 throws PortalException, SystemException {
357
358 if (companyId == 0) {
359 _setShardName(PropsValues.SHARD_DEFAULT_NAME);
360 }
361 else {
362 Shard shard = ShardLocalServiceUtil.getShard(
363 Company.class.getName(), companyId);
364
365 String shardName = shard.getName();
366
367 _setShardName(shardName);
368 }
369 }
370
371 private static Log _log = LogFactoryUtil.getLog(ShardAdvice.class);
372
373 private static ThreadLocal<Stack<String>> _companyServiceStack =
374 new ThreadLocal<Stack<String>>();
375 private static ThreadLocal<Object> _globalCall = new ThreadLocal<Object>();
376 private static ThreadLocal<String> _shardName =
377 new InitialThreadLocal<String>(PropsValues.SHARD_DEFAULT_NAME);
378
379 private ShardDataSourceTargetSource _shardDataSourceTargetSource;
380 private ShardSessionFactoryTargetSource _shardSessionFactoryTargetSource;
381
382 }