使用 spring-authorization-server 时,是否可以将令牌存储在 Redis 中?在 spring-security-oauth 中,我们可以这样定义 TokenStore:
/**
 * Registers a spring-security-oauth {@code RedisTokenStore} built on {@code redisConnectionFactory},
 * with its keys prefixed by {@code redisTokenPrefix}.
 */
@Bean
public TokenStore redisTokenStore() {
    final RedisTokenStore store = new RedisTokenStore(redisConnectionFactory);
    store.setPrefix(redisTokenPrefix);
    return store;
}
理论上,我可以自己实现 OAuth2AuthorizationService 接口,但也许还有更简单、更优雅的解决方案?
我通过以下方式将令牌信息存储到 Redis 集群中(使用的是 Spring Authorization Server 1.2.3):
MyTokenAuthService:该类实现 OAuth2AuthorizationService 并调用缓存方法来存储 OAuth2Authorization 对象。
JedisRefreshAccessToken:该缓存类使用 Java ObjectOutputStream 将 OAuth2Authorization 序列化为字节数组后存入 Redis,并使用 Java ObjectInputStream 将从 Redis 读取的字节数组反序列化回 OAuth2Authorization 对象。
MyJedis:此类使用 JedisCluster 类向 Redis 缓存写入和读取字节数组。
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
import org.springframework.stereotype.Component;
import in.org.cris.superapp.authserver.cache.CRISJedisRefreshAccessToken;
/**
 * {@link OAuth2AuthorizationService} that delegates all persistence to the Redis-backed
 * {@link JedisRefreshAccessToken} cache.
 */
@Component
public class MyTokenAuthService implements OAuth2AuthorizationService {
    private static final Logger logger = LoggerFactory.getLogger(MyTokenAuthService.class);

    @Autowired JedisRefreshAccessToken oauth2TokenCache;

    /** Stores (or overwrites) {@code authorization} in the cache. */
    @Override
    public void save(OAuth2Authorization authorization) {
        logger.info(" save");
        oauth2TokenCache.saveToken(authorization);
    }

    /** Removes the authorization with the same id, together with its token index entries. */
    @Override
    public void remove(OAuth2Authorization authorization) {
        logger.info(" remove {}", authorization.getId());
        oauth2TokenCache.removeToken(authorization.getId());
    }

    /** @return the cached authorization with this id, or {@code null} if none is cached */
    @Override
    public OAuth2Authorization findById(String id) {
        logger.info(" findById {}", id);
        return oauth2TokenCache.findByIdFromCache(id);
    }

    /**
     * @param token     token value to look up
     * @param tokenType expected token type, or {@code null} to match any type
     * @return the owning authorization, or {@code null} if none matches
     */
    @Override
    public OAuth2Authorization findByToken(String token, OAuth2TokenType tokenType) {
        // tokenType is @Nullable in the OAuth2AuthorizationService contract, so it must not be
        // dereferenced blindly. The token value itself is a bearer credential and is not logged.
        logger.info(" findByToken {}", tokenType != null ? tokenType.getValue() : "(any type)");
        return oauth2TokenCache.findByToken(token, tokenType);
    }
}
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.charset.StandardCharsets;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2DeviceCode;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.OAuth2UserCode;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization.Token;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
import org.springframework.stereotype.Component;
@Component
public class JedisRefreshAccessToken {
private final transient static Logger logger = LoggerFactory.getLogger(JedisRefreshAccessToken.class);
@Autowired MyJedis cache; // It has JedisCluster like methods (get, set with expire seconds).
static final String keyPrefix = "Oauth2Token_";
private static String cacheKey(String id) {
return keyPrefix + id;
}
private static String cacheKey_init(String id) {
return keyPrefix + "init_" + id;
}
private static String cacheTokenKey(String token) {
return keyPrefix + token;
}
public void removeToken(final String id) {
OAuth2Authorization auth = findByIdFromCache(id);
if(auth == null) {
logger.info("No token found to remove. Id " + id);
return;
}
cache.unlink(cacheTokenKey(auth.getAccessToken().getToken().getTokenValue()));
cache.unlink(cacheTokenKey(auth.getRefreshToken().getToken().getTokenValue()));
cache.unlink(cacheKey(auth.getId()).getBytes());
cache.unlink(cacheKey_init(auth.getId()).getBytes());
}
public OAuth2Authorization findByIdFromCache(final String id) {
final String key = cacheKey(id);
byte[] v = cache.get(key.getBytes());
if(v == null || v.length == 0) {
v = cache.get(cacheKey_init(id).getBytes());
}
if(v == null || v.length == 0) {
return null;
}
return readFromBytes(v);
}
public void saveToken(final OAuth2Authorization auth) {
final boolean isComplete = auth.getAccessToken() != null;
final String key = isComplete ? cacheKey(auth.getId()) : cacheKey_init(auth.getId());
byte[] v = getBytes(auth);
logger.info("saveToken Id " + auth.getId());
cache.set(key.getBytes(), v, 100*30*60);
storeTokensByValue(auth);
}
private void storeTokensByValue(final OAuth2Authorization auth) {
Token<Jwt> jwt = auth.getToken(Jwt.class);
Jwt jwtAccessToken = jwt!=null ? jwt.getToken() : null;
if(jwtAccessToken != null) {
cache.set(cacheTokenKey(jwtAccessToken.getTokenValue()), auth.getId(), 30*60);
}
Token<OAuth2AuthorizationCode> authCodeWr = auth.getToken(OAuth2AuthorizationCode.class);
OAuth2AuthorizationCode authCodeToken = authCodeWr!=null ? authCodeWr.getToken() : null;
if(authCodeToken != null) {
cache.set(cacheTokenKey(authCodeToken.getTokenValue()), auth.getId(), 300);
}
Token<OAuth2AccessToken> accessTokenWr = auth.getToken(OAuth2AccessToken.class);
OAuth2AccessToken accessToken = accessTokenWr!=null ? accessTokenWr.getToken() : null;
if(accessToken != null) {
cache.set(cacheTokenKey(accessToken.getTokenValue()), auth.getId(), 30*60);
}
Token<OidcIdToken> oidcTkn = auth.getToken(OidcIdToken.class);
OidcIdToken oidcIdToken = oidcTkn!=null ? oidcTkn.getToken() : null;
if(oidcIdToken != null) {
cache.set(cacheTokenKey(oidcIdToken.getTokenValue()), auth.getId(), 300);
}
Token<OAuth2UserCode> userCode = auth.getToken(OAuth2UserCode.class);
OAuth2UserCode userCodeToken = userCode!=null ? userCode.getToken() : null;
if(userCodeToken != null) {
cache.set(cacheTokenKey(userCodeToken.getTokenValue()), auth.getId(), 300);
}
Token<OAuth2DeviceCode> deviceCode = auth.getToken(OAuth2DeviceCode.class);
OAuth2DeviceCode deviceCodeToken = deviceCode!=null ? deviceCode.getToken() : null;
if(deviceCodeToken != null) {
cache.set(cacheTokenKey(deviceCodeToken.getTokenValue()), auth.getId(), 300);
}
if(auth.getRefreshToken() != null) {
OAuth2RefreshToken refreshToken = auth.getRefreshToken().getToken();
if(refreshToken != null) {
cache.set(cacheTokenKey(refreshToken.getTokenValue()), auth.getId(), 100*30*60);
}
}
}
public OAuth2Authorization findByToken(String token, OAuth2TokenType tokenType) {
final String key = cacheTokenKey(token);
String id = cache.get(key);
if(id == null || id.isEmpty()) {
logger.info("findByToken return null " + token);
return null;
}
byte[] bytes = cache.get(cacheKey(id).getBytes());
if(bytes == null || bytes.length == 0) {
bytes = cache.get(cacheKey_init(id).getBytes());
}
if(bytes == null || bytes.length == 0) {
return null;
}
return readFromBytes(bytes);
}
private OAuth2Authorization readFromBytes(final byte[] bytes) {
try(ByteArrayInputStream bin = new ByteArrayInputStream(bytes)) {
try(ObjectInputStream objIn = new ObjectInputStream(bin)) {
return (OAuth2Authorization)objIn.readObject();
}
}
catch(Exception e) {
logger.error("", e);
throw new RuntimeException(e);
}
}
private byte[] getBytes(final OAuth2Authorization auth) {
try {
try(ByteArrayOutputStream bout = new ByteArrayOutputStream()) {
try(ObjectOutputStream objOut = new ObjectOutputStream(bout)) {
objOut.writeObject(auth);
}
return bout.toByteArray();
}
}
catch(Exception e) {
logger.error("", e);
throw new RuntimeException(e);
}
}
}
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.commons.pool2.impl.GenericObjectPoolConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
import redis.clients.jedis.HostAndPort;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisCluster;
import redis.clients.jedis.Tuple;
import redis.clients.jedis.params.SetParams;
/**
 * Thin static wrapper around one shared {@link JedisCluster} client.
 *
 * <p>The client is created on a background thread by {@link #initializeRedisCache()}. Until that
 * succeeds, every data method throws {@link IllegalStateException}.
 */
@Component
public class MyJedis {
    private static final Logger logger = LoggerFactory.getLogger(MyJedis.class);

    /** Shared client. Volatile because the init thread publishes it and caller threads read it. */
    private static volatile JedisCluster jedisCluster;

    /** Thread running the current connection attempt, or null. Guarded by {@code MyJedis.class}. */
    private static Thread initerThread;

    /**
     * Tries each configured seed node in turn and publishes the first client that answers a probe
     * command. Clients for failing nodes are closed rather than leaked.
     */
    private static void initCluster() {
        final List<String> hostPorts = Arrays.asList("ip1:port1,ip2:port2".split(","));
        for (final String hp : hostPorts) {
            JedisCluster candidate = null;
            try {
                final String[] hpArr = hp.trim().split(":");
                final String host = hpArr[0];
                final int port = Integer.parseInt(hpArr[1]);
                candidate = new JedisCluster(new HostAndPort(host, port),
                        120, 120, 3, getRedisPassword(), null, getGenericPoolConfig(), false);
                candidate.exists("some-key"); // connectivity probe
                jedisCluster = candidate;
                logger.info("jedis cluster inited.");
                return;
            } catch (Exception e) {
                logger.warn("Could not initialise jedis cluster via seed node " + hp, e);
                closeQuietly(candidate);
            }
        }
        logger.error("jedis cluster initialisation failed for every configured seed node");
    }

    /** Closes a client that failed its probe; failures while closing are only logged. */
    private static void closeQuietly(final JedisCluster cluster) {
        if (cluster == null) {
            return;
        }
        try {
            cluster.close();
        } catch (Exception e) {
            logger.debug("Ignoring failure while closing unusable jedis cluster client", e);
        }
    }

    /**
     * Starts a background connection attempt unless a client is already available or an attempt
     * is still running. Calling it again after a failed attempt retries.
     */
    static synchronized void initializeRedisCache() {
        if (jedisCluster != null || (initerThread != null && initerThread.isAlive())) {
            return;
        }
        // A Thread can only be started once, so each attempt needs a fresh one.
        initerThread = new Thread(() -> {
            if (jedisCluster == null) {
                initCluster();
            }
        }, "jedis-cluster-init");
        initerThread.start();
    }

    private static GenericObjectPoolConfig<Jedis> getGenericPoolConfig() {
        final GenericObjectPoolConfig<Jedis> genericObjectPoolConfig =
                new GenericObjectPoolConfig<Jedis>();
        genericObjectPoolConfig.setTestOnBorrow(true);
        return genericObjectPoolConfig;
    }

    /** @return the shared client; throws if it has not been initialised yet */
    private static JedisCluster cluster() {
        final JedisCluster current = jedisCluster;
        if (current == null) {
            throw new IllegalStateException(
                    "Redis cluster is not initialised; call initializeRedisCache() first");
        }
        return current;
    }

    long unlink(String key) {
        return cluster().unlink(key);
    }

    long unlink(byte[] key) {
        return cluster().unlink(key);
    }

    String get(String key) {
        return cluster().get(key);
    }

    byte[] get(byte[] key) {
        return cluster().get(key);
    }

    boolean exists(String key) {
        return cluster().exists(key);
    }

    long ttl(String key) {
        return cluster().ttl(key);
    }

    String set(String key, String val) {
        return cluster().set(key, val);
    }

    /** Sets {@code key} to {@code val} with an expiry of {@code expireSeconds}. */
    String set(String key, String val, long expireSeconds) {
        return cluster().set(key, val, new SetParams().ex(expireSeconds));
    }

    /** Binary-key variant of {@link #set(String, String, long)}. */
    String set(byte[] key, byte[] val, long expireSeconds) {
        return cluster().set(key, val, new SetParams().ex(expireSeconds));
    }
}