教程 · 2020年11月20日 0

Spring Security 与 JWT 整合

内容目录

[lwptoc]

Spring Security 是一个用于 Spring 程序身份认证的框架,里面东西很多,用起来也很方便,但是上手并不是那么容易,最近粗略看了一些教程,总算是把它用到自己的项目里去了,并且跟 JWT 整合了一下,下面详细记录一下我的步骤防止我下次用的时候忘了。 -_-

1. 导入依赖

首先将相关的 jar 导入项目中。
构建配置如下(部分):

Maven

<!-- Spring Security -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-security</artifactId>
</dependency>

<!-- JSON Web Token -->
<!-- https://github.com/jwtk/jjwt -->
<dependency>
    <groupId>io.jsonwebtoken</groupId>
    <artifactId>jjwt-api</artifactId>
    <version>0.11.2</version>
</dependency>
<dependency>
    <groupId>io.jsonwebtoken</groupId>
    <artifactId>jjwt-impl</artifactId>
    <version>0.11.2</version>
    <scope>runtime</scope>
</dependency>
<dependency>
    <groupId>io.jsonwebtoken</groupId>
    <artifactId>jjwt-jackson</artifactId>
    <version>0.11.2</version>
    <scope>compile</scope> <!-- 下面的 JwtUtils 在代码中直接使用了 JacksonDeserializer 来定制反序列化,编译期需要可见,所以不能用 runtime -->
</dependency>

<!-- 其他 -->

Gradle

implementation 'org.springframework.boot:spring-boot-starter-security'
implementation 'io.jsonwebtoken:jjwt-api:0.11.2'
implementation 'io.jsonwebtoken:jjwt-jackson:0.11.2'
runtimeOnly 'io.jsonwebtoken:jjwt-impl:0.11.2'

2. 启用 Spring Security

创建一个 Spring 配置类。

package com.example.config;

import javax.servlet.http.HttpServletResponse;
import com.example.filter.JwtAuthenticationFilter;
import com.example.filter.JwtAuthorizationFilter;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.annotation.Bean;
import org.springframework.http.HttpMethod;
import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder;
import org.springframework.security.config.annotation.method.configuration.EnableGlobalMethodSecurity;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
import org.springframework.security.config.http.SessionCreationPolicy;
import org.springframework.security.core.userdetails.UserDetailsService;
import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder;

/**
 * Spring Security setup for a stateless, JWT-based API.
 *
 * <p>Users are loaded through the {@code userDetailsServiceImpl} bean and their passwords are checked
 * with BCrypt. Logging in is handled by {@link JwtAuthenticationFilter}; every other request is
 * authenticated from its token by {@link JwtAuthorizationFilter}. Method-level annotations such as
 * {@code @PreAuthorize} are enabled via {@code prePostEnabled = true}.
 */
@EnableWebSecurity
@EnableGlobalMethodSecurity(prePostEnabled = true)
public class SecurityConfig extends WebSecurityConfigurerAdapter {

    private final UserDetailsService userDetailsService;

    /**
     * @param userDetailsService the project's own user lookup service, selected by bean name
     */
    public SecurityConfig(@Qualifier("userDetailsServiceImpl") UserDetailsService userDetailsService) {
        this.userDetailsService = userDetailsService;
    }

    /** BCrypt encoder used to verify stored password hashes. */
    @Bean
    public BCryptPasswordEncoder bCryptPasswordEncoder() {
        return new BCryptPasswordEncoder();
    }

    /** Publishes the injected user service as the {@link UserDetailsService} bean. */
    @Bean
    @Override
    public UserDetailsService userDetailsServiceBean() {
        return userDetailsService;
    }

    /** Authenticate credentials against the custom user service, comparing passwords with BCrypt. */
    @Override
    protected void configure(AuthenticationManagerBuilder auth) throws Exception {
        auth.userDetailsService(userDetailsService)
            .passwordEncoder(bCryptPasswordEncoder());
    }

    @Override
    protected void configure(HttpSecurity http) throws Exception {
        // Token-based API: no CORS configurer and no CSRF protection.
        http.cors().disable();
        http.csrf().disable();

        // NOTE(review): no anyRequest() rule is declared, so these URL rules only mention /login and
        // /error; protecting other endpoints relies on method annotations such as @PreAuthorize.
        // Confirm this is intended.
        http.authorizeRequests()
            .antMatchers(HttpMethod.POST, "/login").permitAll()
            .antMatchers("/error").permitAll();

        // Custom filters: one issues tokens on login, the other reads them on every request.
        http.addFilter(new JwtAuthenticationFilter(authenticationManager()));
        http.addFilter(new JwtAuthorizationFilter(authenticationManager()));

        // Never create an HTTP session; each request carries its own token.
        http.sessionManagement().sessionCreationPolicy(SessionCreationPolicy.STATELESS);

        // Map authentication failures to 401 and authorization failures to 403.
        http.exceptionHandling()
            .authenticationEntryPoint((request, response, ex) ->
                response.sendError(HttpServletResponse.SC_UNAUTHORIZED, ex.getMessage()))
            .accessDeniedHandler((request, response, ex) ->
                response.sendError(HttpServletResponse.SC_FORBIDDEN, ex.getMessage()));

        // Allow pages (e.g. the H2 web console) to be rendered inside frames.
        http.headers().frameOptions().disable();
    }
}

3. 创建用户相关的数据实体类和服务

略过,只需要注意让实体类实现 org.springframework.security.core.userdetails.UserDetails 接口,让服务实现 org.springframework.security.core.userdetails.UserDetailsService 接口即可,详细说明查看相关文档。

还有就是 org.springframework.security.core.userdetails.UserDetails#getAuthorities 方法需要返回当前用户的权限,下面会用到的。

4. 编写认证和验证的过滤器

首先是用于认证(登录)的过滤器:它负责处理发往登录接口、携带用户名和密码的请求,验证通过后在响应头中返回一个 token,代码如下:

package com.example.filter;

import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.util.List;
import java.util.stream.Collectors;
import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import com.example.bean.LoginRequest;
import com.example.model.Account;
import com.example.utils.JwtUtils;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.AuthenticationServiceException;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter;
import org.springframework.web.context.WebApplicationContext;

public class JwtAuthenticationFilter extends UsernamePasswordAuthenticationFilter {

    private final AuthenticationManager authenticationManager;
    private final ThreadLocal<Boolean> rememberMe = new ThreadLocal<>();

    public JwtAuthenticationFilter(AuthenticationManager authenticationManager) {
        this.authenticationManager = authenticationManager;
        // 设置验证 URL
        super.setFilterProcessesUrl("/login");
    }

    @Override
    public Authentication attemptAuthentication(HttpServletRequest request, HttpServletResponse response) throws AuthenticationException {
        if (request.getParameter("logout") != null && !request.getParameter("logout").equalsIgnoreCase("false")) {
            response.setHeader(JwtUtils.TOKEN_HEADER, "logout"); // 当请求退出登录时清除 token 即可
            response.setStatus(HttpServletResponse.SC_NO_CONTENT);
            return null;
        }
        if (!request.getMethod().equals("POST")) {
            throw new AuthenticationServiceException(
                "Authentication method not supported: " + request.getMethod());
        }
        WebApplicationContext context = RequestContextUtils.findWebApplicationContext(request);
        if (context == null) throw new IllegalStateException("not in a request");
        ObjectMapper objectMapper = context.getBean(ObjectMapper.class); // 获取 ObjectMapper,也可以从构造器传参注入
        try {
            LoginRequest loginRequest = objectMapper.readValue(request.getInputStream(), LoginRequest.class);
            rememberMe.set(loginRequest.rememberMe);
            UsernamePasswordAuthenticationToken authentication = new UsernamePasswordAuthenticationToken(
                loginRequest.username, loginRequest.password);
            return authenticationManager.authenticate(authentication);
        } catch (IOException e) {
            throw new AuthenticationServiceException(e.getLocalizedMessage(), e);
        }
    }

    @Override
    protected void successfulAuthentication(HttpServletRequest request, HttpServletResponse response, FilterChain chain, Authentication authResult) {
        // 若验证成功,principal 中会包含取到的用户信息(用第 3 步写的服务获取的)
        Account account = (Account) authResult.getPrincipal();
        List<String> authorities = account.getAuthorities().stream().map(GrantedAuthority::getAuthority).collect(Collectors.toList()); // 转换成字符列表
        // 创建 Token
        String token = JwtUtils.getToken(account.getUsername(), authorities, rememberMe.get(), account);
        rememberMe.remove();
        // Http Response Header 中返回 Token
        response.setHeader(JwtUtils.TOKEN_HEADER, token);
        response.setStatus(HttpServletResponse.SC_NO_CONTENT);
    }

    @Override
    protected void unsuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response, AuthenticationException failed) throws IOException {
        response.sendError(HttpServletResponse.SC_UNAUTHORIZED, failed.getMessage());
    }
}

用到的工具类代码如下:

package com.example.bean;
/**
 * 用户登录请求
 */
public class LoginRequest {
    public String username;
    public String password;
    public boolean rememberMe;
}
package com.example.utils;

import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Jws;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.SignatureAlgorithm;
import io.jsonwebtoken.io.DecodingException;
import io.jsonwebtoken.jackson.io.JacksonDeserializer;
import io.jsonwebtoken.lang.Maps;
import io.jsonwebtoken.security.Keys;
import java.util.Arrays;
import java.util.Date;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import javax.crypto.SecretKey;
import javax.xml.bind.DatatypeConverter;
import net.wlgzs.grading.model.Account;
import org.json.JSONException;
import org.json.JSONObject;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.context.SecurityContextHolder;

public class JwtUtils {

    public static final long EXPIRATION = 20 * 60L; // 20 分钟
    public static final long EXPIRATION_REMEMBER = 60 * 60 * 24 * 7L; // 一周
    public static final String JWT_SECRET_KEY = "自己生成一个密钥"; // 生成方式:DatatypeConverter.printBase64Binary(Keys.secretKeyFor(SignatureAlgorithm.HS256).getEncoded())
    public static final String TOKEN_HEADER = "Authorization";
    public static final String TOKEN_PREFIX = "Bearer ";
    public static final String TOKEN_TYPE = "JWT";

    private static final byte[] apiKeySecretBytes = DatatypeConverter.parseBase64Binary(JWT_SECRET_KEY);
    private static final SecretKey secretKey = Keys.hmacShaKeyFor(apiKeySecretBytes);

    private JwtUtils() {
        throw new IllegalStateException();
    }

    public static String getToken(String username, List<String> roles, boolean isRememberMe, Account account) {
        long expiration = isRememberMe ? EXPIRATION_REMEMBER : EXPIRATION;
        final Date createdDate = new Date();
        final Date expirationDate = new Date(createdDate.getTime() + expiration * 1000);
        String token = Jwts.builder()
            .setHeaderParam("typ", TOKEN_TYPE)
            .setHeaderParam("cty", account.getClass().getName()) // 这里我的 Account 有派生类,所以存了个类名
            .signWith(secretKey, SignatureAlgorithm.HS256)
            .claim("rol", String.join(",", roles))
            .setIssuer("JWT 示例")
            .setIssuedAt(createdDate)
            .setSubject(username)
            .setExpiration(expirationDate)
            .claim("account", account)
            .compact();
        return TOKEN_PREFIX + token;
    }

    public static boolean isTokenExpired(Claims tokenBody) {
        Date expiredDate = tokenBody.getExpiration();
        return expiredDate.before(new Date());
    }

    public static String getUsernameByToken(String token) {
        return parseToken(token).getBody().getSubject();
    }

    public static List<? extends GrantedAuthority> getUserRolesByToken(String token) {
        String role = (String) parseToken(token).getBody()
            .get("rol");
        return Arrays.stream(role.split(","))
            .map(String::trim)
            .map(Account.AdminLevel::valueOf)
            .collect(Collectors.toList());
    }

    public static Account getAccountByTokenBody(Jws<Claims> token) {
        Account account;
        try {
            account = (Account) token.getBody().get("account", Class.forName(token.getHeader().getContentType()));
        } catch (ClassNotFoundException e) {
            throw new IllegalArgumentException("accountType is not a known account class", e);
        }
        return account;
    }

    public static Optional<Account> getCurrentAccount() {
        // 这一步从当前请求上下文中获取已经验证的账号对象
        Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
        return Optional.ofNullable(authentication)
            .map(auth -> (Account) auth.getDetails());
    }

    @SuppressWarnings("unchecked")
    public static Class<? extends Account> getAccountType(String token) {
        if (token.isEmpty()) throw new DecodingException("token 为空");
        if (token.indexOf('.') <= 1) throw new DecodingException("token 格式错误");
        StringBuilder headerBase64 = new StringBuilder(token.substring(0, token.indexOf('.')).replace('-', '+').replace('_', '/'));
        while (headerBase64.length() % 4 != 0) headerBase64.append("=");
        String header = new String(DatatypeConverter.parseBase64Binary(headerBase64.toString()));
        try {
            JSONObject jsonObject = new JSONObject(header);
            String cty = jsonObject.getString("cty");
            return (Class<? extends Account>) Class.forName(cty);
        } catch (JSONException | ClassNotFoundException e) {
            throw new DecodingException("accountType is not a known account class", e);
        }
    }

    @SuppressWarnings({"unchecked", "rawtypes"})
    public static Jws<Claims> parseToken(String token) {
        // ======== 由于我的 Account 有派生类,所以先解析是哪种账号 ========
        Class<? extends Account> accountClass = getAccountType(token);
        // ==================
        return Jwts.parserBuilder()
            .setSigningKey(secretKey)
            .deserializeJsonWith(new JacksonDeserializer(Maps.of("account", accountClass).build()))
            .build()
            .parseClaimsJws(token);
    }

}

其次是用于授权的过滤器:它从请求头中取出 token 并解析验证,验证通过后把对应的用户信息放入 SecurityContext;没有有效 token 的请求则不会被设置身份。

package com.example.filter;

import io.jsonwebtoken.Claims;
import io.jsonwebtoken.ExpiredJwtException;
import io.jsonwebtoken.Jws;
import io.jsonwebtoken.JwtException;
import io.jsonwebtoken.Jwts;
import java.io.IOException;
import java.time.Duration;
import java.util.Date;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import com.example.utils.JwtUtils;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.web.authentication.www.BasicAuthenticationFilter;
import org.springframework.util.StringUtils;

@Slf4j
public class JwtAuthorizationFilter extends BasicAuthenticationFilter {

    public JwtAuthorizationFilter(AuthenticationManager authenticationManager) {
        super(authenticationManager);
    }

    @Override
    protected void doFilterInternal(HttpServletRequest request,
                                    HttpServletResponse response,
                                    FilterChain chain) throws IOException, ServletException {
        String token = request.getHeader(JwtUtils.TOKEN_HEADER);
        if (token == null || !token.startsWith(JwtUtils.TOKEN_PREFIX)) {
            SecurityContextHolder.clearContext();
        } else {
            try {
                UsernamePasswordAuthenticationToken authRequest = getAuthentication(token, response);
                log.info("authentication: {}", authRequest);
                if (authRequest == null) {
                    chain.doFilter(request, response);
                    return;
                }
//              Authentication authResult = this.getAuthenticationManager().authenticate(authRequest); // 不需要再这里验证,因为如果 JWT 非法,就已经抛出异常了
                SecurityContextHolder.getContext().setAuthentication(authRequest);
                onSuccessfulAuthentication(request, response, authRequest);
            } catch (AuthenticationException failed) {
                SecurityContextHolder.clearContext();
                onUnsuccessfulAuthentication(request, response, failed);
                if (this.isIgnoreFailure()) {
                    chain.doFilter(request, response);
                } else {
                    this.getAuthenticationEntryPoint().commence(request, response, failed);
                }
                return;
            }
        }

        chain.doFilter(request, response);
    }

    /**
     * 获取用户认证信息 Authentication
     */
    private UsernamePasswordAuthenticationToken getAuthentication(String authorization, HttpServletResponse response) {
        String token = authorization.replaceFirst(JwtUtils.TOKEN_PREFIX, "");
        long renewPeriod = Duration.ofMinutes(20).toMillis();
        try {
            Jws<Claims> jwsToken;
            try {
                jwsToken = JwtUtils.parseToken(token);
            } catch (ExpiredJwtException e) {
                // 过期 20 分钟后仍有效
                // 防止当前用户正在使用中而过 token 期
                if (System.currentTimeMillis() - e.getClaims().getExpiration().getTime() < renewPeriod) {
                    Account account;
                    try {
                        account = (Account) e.getClaims().get("account", Class.forName(e.getHeader().getContentType()));
                        token = JwtUtils.getToken(e.getClaims().getSubject(), Arrays.asList(e.getClaims().get("rol", String.class).split(",")), false, account);
                        token = token.substring(token.indexOf(' ') + 1);
                    } catch (ClassNotFoundException ex) {
                        throw new IllegalArgumentException("accountType is not a known account class", ex);
                    }
                    response.setHeader(JwtUtils.TOKEN_HEADER, JwtUtils.TOKEN_PREFIX + token);
                    jwsToken = JwtUtils.parseToken(token);
                } else throw e;
            }
            Claims tokenBody = jwsToken.getBody();

            String username = tokenBody.getSubject();
            if (!StringUtils.isEmpty(username)) {
                UserDetails userDetails = JwtUtils.getAccountByTokenBody(jwsToken);
                UsernamePasswordAuthenticationToken usernamePasswordAuthenticationToken = new UsernamePasswordAuthenticationToken(username, null, userDetails.getAuthorities());
                usernamePasswordAuthenticationToken.setDetails(userDetails);
                return userDetails.isEnabled() ? usernamePasswordAuthenticationToken : null;
            }
        } catch (JwtException | IllegalArgumentException e) {
            AuthenticationException ex = new BadCredentialsException(e.getLocalizedMessage(), e);
            log.warn("token:{} 无效", token, ex);
            throw ex;
        }
        return null;
    }
}

5. 在请求方法上加上权限要求

/**
 * Returns the account of the authenticated caller.
 *
 * @return the account stored in the current authentication's details, or empty if there is none
 */
@GetMapping("/info")
@PreAuthorize("hasAnyRole('USER')") // required role(s) for this endpoint
public Optional<Account> getInfo() {
    Optional<Account> currentAccount = JwtUtils.getCurrentAccount();
    return currentAccount;
}

OK 到这里就把应用加上身份认证了,点到为止。