001
014
015 package com.liferay.portal.dao.shard;
016
017 import com.liferay.counter.service.persistence.CounterFinder;
018 import com.liferay.counter.service.persistence.CounterPersistence;
019 import com.liferay.portal.NoSuchCompanyException;
020 import com.liferay.portal.kernel.exception.PortalException;
021 import com.liferay.portal.kernel.exception.SystemException;
022 import com.liferay.portal.kernel.log.Log;
023 import com.liferay.portal.kernel.log.LogFactoryUtil;
024 import com.liferay.portal.kernel.util.InfrastructureUtil;
025 import com.liferay.portal.kernel.util.InitialThreadLocal;
026 import com.liferay.portal.kernel.util.StringPool;
027 import com.liferay.portal.kernel.util.StringUtil;
028 import com.liferay.portal.model.Company;
029 import com.liferay.portal.model.Shard;
030 import com.liferay.portal.security.auth.CompanyThreadLocal;
031 import com.liferay.portal.service.CompanyLocalServiceUtil;
032 import com.liferay.portal.service.ShardLocalServiceUtil;
033 import com.liferay.portal.service.persistence.ClassNamePersistence;
034 import com.liferay.portal.service.persistence.CompanyPersistence;
035 import com.liferay.portal.service.persistence.ReleasePersistence;
036 import com.liferay.portal.service.persistence.ShardPersistence;
037 import com.liferay.portal.util.PropsValues;
038
039 import java.util.HashMap;
040 import java.util.Map;
041 import java.util.Stack;
042
043 import javax.sql.DataSource;
044
045 import org.aspectj.lang.ProceedingJoinPoint;
046
047
051 public class ShardAdvice {
052
053 public void afterPropertiesSet() {
054 if (_shardDataSourceTargetSource == null) {
055 _shardDataSourceTargetSource =
056 (ShardDataSourceTargetSource)InfrastructureUtil.
057 getShardDataSourceTargetSource();
058 }
059
060 if (_shardSessionFactoryTargetSource == null) {
061 _shardSessionFactoryTargetSource =
062 (ShardSessionFactoryTargetSource)InfrastructureUtil.
063 getShardSessionFactoryTargetSource();
064 }
065 }
066
067 public Object invokeByParameter(ProceedingJoinPoint proceedingJoinPoint)
068 throws Throwable {
069
070 Object[] arguments = proceedingJoinPoint.getArgs();
071
072 long companyId = (Long)arguments[0];
073
074 Shard shard = ShardLocalServiceUtil.getShard(
075 Company.class.getName(), companyId);
076
077 String shardName = shard.getName();
078
079 if (_log.isInfoEnabled()) {
080 _log.info(
081 "Service being set to shard " + shardName + " for " +
082 _getSignature(proceedingJoinPoint));
083 }
084
085 Object returnValue = null;
086
087 pushCompanyService(shardName);
088
089 try {
090 returnValue = proceedingJoinPoint.proceed();
091 }
092 finally {
093 popCompanyService();
094 }
095
096 return returnValue;
097 }
098
099 public Object invokeCompanyService(ProceedingJoinPoint proceedingJoinPoint)
100 throws Throwable {
101
102 String methodName = proceedingJoinPoint.getSignature().getName();
103 Object[] arguments = proceedingJoinPoint.getArgs();
104
105 String shardName = PropsValues.SHARD_DEFAULT_NAME;
106
107 if (methodName.equals("addCompany")) {
108 String webId = (String)arguments[0];
109 String virtualHost = (String)arguments[1];
110 String mx = (String)arguments[2];
111 shardName = (String)arguments[3];
112
113 shardName = _getCompanyShardName(webId, virtualHost, mx, shardName);
114
115 arguments[3] = shardName;
116 }
117 else if (methodName.equals("checkCompany")) {
118 String webId = (String)arguments[0];
119
120 if (!webId.equals(PropsValues.COMPANY_DEFAULT_WEB_ID)) {
121 if (arguments.length == 3) {
122 String mx = (String)arguments[1];
123 shardName = (String)arguments[2];
124
125 shardName = _getCompanyShardName(
126 webId, null, mx, shardName);
127
128 arguments[2] = shardName;
129 }
130
131 try {
132 Company company = CompanyLocalServiceUtil.getCompanyByWebId(
133 webId);
134
135 shardName = company.getShardName();
136 }
137 catch (NoSuchCompanyException nsce) {
138 }
139 }
140 }
141 else if (methodName.startsWith("update")) {
142 long companyId = (Long)arguments[0];
143
144 Shard shard = ShardLocalServiceUtil.getShard(
145 Company.class.getName(), companyId);
146
147 shardName = shard.getName();
148 }
149 else {
150 return proceedingJoinPoint.proceed();
151 }
152
153 if (_log.isInfoEnabled()) {
154 _log.info(
155 "Company service being set to shard " + shardName + " for " +
156 _getSignature(proceedingJoinPoint));
157 }
158
159 Object returnValue = null;
160
161 pushCompanyService(shardName);
162
163 try {
164 returnValue = proceedingJoinPoint.proceed(arguments);
165 }
166 finally {
167 popCompanyService();
168 }
169
170 return returnValue;
171 }
172
173 public Object invokeGlobally(ProceedingJoinPoint proceedingJoinPoint)
174 throws Throwable {
175
176 _globalCall.set(new Object());
177
178 try {
179 if (_log.isInfoEnabled()) {
180 _log.info(
181 "All shards invoked for " +
182 _getSignature(proceedingJoinPoint));
183 }
184
185 for (String shardName : PropsValues.SHARD_AVAILABLE_NAMES) {
186 _shardDataSourceTargetSource.setDataSource(shardName);
187 _shardSessionFactoryTargetSource.setSessionFactory(shardName);
188
189 proceedingJoinPoint.proceed();
190 }
191 }
192 finally {
193 _globalCall.set(null);
194 }
195
196 return null;
197 }
198
199 public Object invokePersistence(ProceedingJoinPoint proceedingJoinPoint)
200 throws Throwable {
201
202 if ((_shardDataSourceTargetSource == null) ||
203 (_shardSessionFactoryTargetSource == null)) {
204
205 return proceedingJoinPoint.proceed();
206 }
207
208 Object target = proceedingJoinPoint.getTarget();
209
210 if (target instanceof ClassNamePersistence ||
211 target instanceof CompanyPersistence ||
212 target instanceof CounterFinder ||
213 target instanceof CounterPersistence ||
214 target instanceof ReleasePersistence ||
215 target instanceof ShardPersistence) {
216
217 _shardDataSourceTargetSource.setDataSource(
218 PropsValues.SHARD_DEFAULT_NAME);
219 _shardSessionFactoryTargetSource.setSessionFactory(
220 PropsValues.SHARD_DEFAULT_NAME);
221
222 if (_log.isDebugEnabled()) {
223 _log.debug(
224 "Using default shard for " +
225 _getSignature(proceedingJoinPoint));
226 }
227
228 return proceedingJoinPoint.proceed();
229 }
230
231 if (_globalCall.get() == null) {
232 _setShardNameByCompany();
233
234 String shardName = _getShardName();
235
236 _shardDataSourceTargetSource.setDataSource(shardName);
237 _shardSessionFactoryTargetSource.setSessionFactory(shardName);
238
239 if (_log.isInfoEnabled()) {
240 _log.info(
241 "Using shard name " + shardName + " for " +
242 _getSignature(proceedingJoinPoint));
243 }
244
245 return proceedingJoinPoint.proceed();
246 }
247 else {
248 return proceedingJoinPoint.proceed();
249 }
250 }
251
252 public void setShardDataSourceTargetSource(
253 ShardDataSourceTargetSource shardDataSourceTargetSource) {
254
255 _shardDataSourceTargetSource = shardDataSourceTargetSource;
256 }
257
258 public void setShardSessionFactoryTargetSource(
259 ShardSessionFactoryTargetSource shardSessionFactoryTargetSource) {
260
261 _shardSessionFactoryTargetSource = shardSessionFactoryTargetSource;
262 }
263
264 protected DataSource getDataSource() {
265 return _shardDataSourceTargetSource.getDataSource();
266 }
267
268 protected String popCompanyService() {
269 return _getCompanyServiceStack().pop();
270 }
271
272 protected void pushCompanyService(long companyId) {
273 try {
274 Shard shard = ShardLocalServiceUtil.getShard(
275 Company.class.getName(), companyId);
276
277 String shardName = shard.getName();
278
279 pushCompanyService(shardName);
280 }
281 catch (Exception e) {
282 _log.error(e, e);
283 }
284 }
285
286 protected void pushCompanyService(String shardName) {
287 _getCompanyServiceStack().push(shardName);
288 }
289
290 private Stack<String> _getCompanyServiceStack() {
291 Stack<String> companyServiceStack = _companyServiceStack.get();
292
293 if (companyServiceStack == null) {
294 companyServiceStack = new Stack<String>();
295
296 _companyServiceStack.set(companyServiceStack);
297 }
298
299 return companyServiceStack;
300 }
301
302 private String _getCompanyShardName(
303 String webId, String virtualHost, String mx, String shardName) {
304
305 Map<String, String> shardParams = new HashMap<String, String>();
306
307 shardParams.put("webId", webId);
308 shardParams.put("mx", mx);
309
310 if (virtualHost != null) {
311 shardParams.put("virtualHost", virtualHost);
312 }
313
314 shardName = _shardSelector.getShardName(
315 ShardSelector.COMPANY_SCOPE, shardName, shardParams);
316
317 return shardName;
318 }
319
320 private String _getShardName() {
321 return _shardName.get();
322 }
323
324 private String _getSignature(ProceedingJoinPoint proceedingJoinPoint) {
325 String methodName = StringUtil.extractLast(
326 proceedingJoinPoint.getTarget().getClass().getName(),
327 StringPool.PERIOD);
328
329 methodName +=
330 StringPool.PERIOD + proceedingJoinPoint.getSignature().getName() +
331 "()";
332
333 return methodName;
334 }
335
336 private void _setShardName(String shardName) {
337 _shardName.set(shardName);
338 }
339
340 private void _setShardNameByCompany() throws Throwable {
341 Stack<String> companyServiceStack = _getCompanyServiceStack();
342
343 if (companyServiceStack.isEmpty()) {
344 long companyId = CompanyThreadLocal.getCompanyId();
345
346 _setShardNameByCompanyId(companyId);
347 }
348 else {
349 String shardName = companyServiceStack.peek();
350
351 _setShardName(shardName);
352 }
353 }
354
355 private void _setShardNameByCompanyId(long companyId)
356 throws PortalException, SystemException {
357
358 if (companyId == 0) {
359 _setShardName(PropsValues.SHARD_DEFAULT_NAME);
360 }
361 else {
362 Shard shard = ShardLocalServiceUtil.getShard(
363 Company.class.getName(), companyId);
364
365 String shardName = shard.getName();
366
367 _setShardName(shardName);
368 }
369 }
370
371 private static Log _log = LogFactoryUtil.getLog(ShardAdvice.class);
372
373 private static ThreadLocal<Stack<String>> _companyServiceStack =
374 new ThreadLocal<Stack<String>>();
375 private static ThreadLocal<Object> _globalCall = new ThreadLocal<Object>();
376 private static ThreadLocal<String> _shardName =
377 new InitialThreadLocal<String>(
378 ShardAdvice.class + "._shardName", PropsValues.SHARD_DEFAULT_NAME);
379 private static ShardSelector _shardSelector;
380
381 private ShardDataSourceTargetSource _shardDataSourceTargetSource;
382 private ShardSessionFactoryTargetSource _shardSessionFactoryTargetSource;
383
384 static {
385 try {
386 _shardSelector = (ShardSelector)Class.forName(
387 PropsValues.SHARD_SELECTOR).newInstance();
388 }
389 catch (Exception e) {
390 _log.error(e, e);
391 }
392 }
393
394 }