Building a front-end/back-end separated project with Spring Boot 3, Spring Security, auth0 and Kotlin

Creating the project

First, generate the project on start.spring.io.

Language: Kotlin; project type: Gradle - Kotlin; JDK: Oracle JDK 17; database: MariaDB.

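The dependencies are listed below; everything else is the same as generated by Spring Initializr.

Two dependencies were added on top of that: SpringDoc and auth0's java-jwt.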

// build.gradle.kts

val jwtVersion = "4.4.0"
val docVersion = "2.1.0"

// .......

dependencies {
    implementation("org.springframework.boot:spring-boot-starter-data-jpa")
    implementation("org.springframework.boot:spring-boot-starter-security")
    implementation("org.springframework.boot:spring-boot-starter-web")
    implementation("com.fasterxml.jackson.module:jackson-module-kotlin")
    implementation("org.jetbrains.kotlin:kotlin-reflect")
    implementation("com.auth0:java-jwt:$jwtVersion")
    implementation("org.springdoc:springdoc-openapi-starter-webmvc-api:$docVersion")
    implementation("org.springdoc:springdoc-openapi-starter-webmvc-ui:$docVersion")
    runtimeOnly("org.mariadb.jdbc:mariadb-java-client")
    testImplementation("org.springframework.boot:spring-boot-starter-test")
    testImplementation("org.springframework.security:spring-security-test")
}

// ......

Spring Security configuration

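Spring Security is a fairly complete security framework, but out of the box it is geared toward server-rendered pages (form login, redirects, HTTP sessions) rather than a separated front end and back end, so it needs some manual adjustment.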

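Spring Security is customized by implementing the relevant interfaces and registering the implementations in a configuration class. Many interfaces can be replaced; only the ones this project needs are covered here.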

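In short, the JWT scheme works like this: the client calls the login endpoint and receives a token; when it later calls a path that requires permissions, it sends that token in a request header, so the server can tell who the user is and whether they have sufficient permissions.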
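To make "the token" concrete: a signed JWT is three Base64URL-encoded segments separated by dots (header, payload, signature). The JDK-only sketch below decodes the first two segments. decodeJwtParts and DecodedJwtParts are illustrative names, not part of auth0's API, and the function deliberately skips signature verification. Its point is that the payload is only encoded, not encrypted: anyone holding the token can read its claims, so nothing secret (such as a password) should be put into them.

```kotlin
import java.util.Base64

// Hypothetical helper types: the three parts of a JWT after decoding.
data class DecodedJwtParts(val headerJson: String, val payloadJson: String, val signature: String)

// Splits "header.payload.signature" and Base64URL-decodes header and payload.
// It does NOT verify the signature, so it must never be used for authorization decisions.
fun decodeJwtParts(token: String): DecodedJwtParts {
    val parts = token.split(".")
    require(parts.size == 3) { "A signed JWT has exactly three dot-separated parts" }
    // JWTs use the URL-safe alphabet without '=' padding; the JDK URL decoder accepts unpadded input.
    val decoder = Base64.getUrlDecoder()
    val header = String(decoder.decode(parts[0]), Charsets.UTF_8)
    val payload = String(decoder.decode(parts[1]), Charsets.UTF_8)
    return DecodedJwtParts(header, payload, parts[2])
}
```

In the project itself, tokens are checked through auth0's JWT.require(...).build().verify(token) in JwtService, never through code like this.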

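Because this is a separated project and every response is JSON, the default handling of unauthenticated and unauthorized requests (a login-page redirect or an error page) has to be replaced with JSON responses.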

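So we implement two interfaces: AccessDeniedHandler, which controls the response when a logged-in user lacks permission, and AuthenticationEntryPoint, which controls the response when the user is not logged in.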
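Which of the two handlers answers a given request can be summarized with a small, simplified model. This is not Spring Security's implementation; Outcome and decide are hypothetical names, and the model assumes the whitelist and the role rule configured later in SecurityConfig. An anonymous request to a protected path ends in the AuthenticationEntryPoint (code 401), while a logged-in user without the required role ends in the AccessDeniedHandler (code 403).

```kotlin
// Simplified model (not Spring Security's code) of which handler answers a request.
// code = the value SecurityHandler reports; null means the request reaches the controller.
enum class Outcome(val code: Int?) {
    ALLOWED(null),        // the controller produces the response
    NOT_LOGGED_IN(401),   // answered by AuthenticationEntryPoint.commence
    NO_PERMISSION(403)    // answered by AccessDeniedHandler.handle
}

fun decide(isPublicPath: Boolean, isAuthenticated: Boolean, hasRequiredRole: Boolean): Outcome = when {
    isPublicPath -> Outcome.ALLOWED              // whitelisted path: permitAll()
    !isAuthenticated -> Outcome.NOT_LOGGED_IN    // no valid token was found
    !hasRequiredRole -> Outcome.NO_PERMISSION    // token valid, but role missing
    else -> Outcome.ALLOWED
}
```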

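// SecurityHandler.kt

import com.fasterxml.jackson.databind.ObjectMapper
import jakarta.servlet.http.HttpServletRequest
import jakarta.servlet.http.HttpServletResponse
import org.springframework.security.access.AccessDeniedException
import org.springframework.security.core.AuthenticationException
import org.springframework.security.web.AuthenticationEntryPoint
import org.springframework.security.web.access.AccessDeniedHandler
import org.springframework.stereotype.Component

@Component
class SecurityHandler : AccessDeniedHandler, AuthenticationEntryPoint {
    // ObjectMapper is thread-safe once configured, so a single instance is reused
    private val objectMapper = ObjectMapper()

    // Called when a logged-in user lacks the required role
    override fun handle(
        request: HttpServletRequest,
        response: HttpServletResponse,
        accessDeniedException: AccessDeniedException
    ) = writeJson(response, HttpServletResponse.SC_FORBIDDEN, "No permission")

    // Called when the request carries no valid authentication
    override fun commence(
        request: HttpServletRequest,
        response: HttpServletResponse,
        authException: AuthenticationException
    ) = writeJson(response, HttpServletResponse.SC_UNAUTHORIZED, "Not logged in")

    private fun writeJson(response: HttpServletResponse, status: Int, msg: String) {
        response.status = status  // set the real HTTP status too, not only the "code" field
        response.contentType = "application/json;charset=utf-8"
        response.writer.write(objectMapper.writeValueAsString(mapOf("code" to status, "msg" to msg)))
    }
}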

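Because the JWT scheme is stateless, the user's identity must be recovered from the token on every request, before Spring Security's authentication checks run. So we also add a filter that parses the token and place it in front of the login filter (UsernamePasswordAuthenticationFilter).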

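// JwtAuthenticationTokenFilter.kt

import jakarta.annotation.Resource
import jakarta.servlet.FilterChain
import jakarta.servlet.http.HttpServletRequest
import jakarta.servlet.http.HttpServletResponse
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken
import org.springframework.security.core.context.SecurityContextHolder
import org.springframework.security.web.authentication.WebAuthenticationDetailsSource
import org.springframework.web.filter.OncePerRequestFilter

// Note: the @Resource fields are only injected when Spring creates this object as a bean;
// an instance created with a plain constructor call leaves them uninitialized.
class JwtAuthenticationTokenFilter : OncePerRequestFilter() {

    @Resource
    private lateinit var jwtService: JwtService

    @Resource
    private lateinit var userRepository: UserRepository

    private val tokenHeader = "Authorization"
    private val tokenPrefix = "Bearer "  // includes the space, so "BearerXYZ" is not accepted

    override fun doFilterInternal(
        request: HttpServletRequest,
        response: HttpServletResponse,
        filterChain: FilterChain
    ) {
        val authHeader: String? = request.getHeader(tokenHeader)
        if (!authHeader.isNullOrEmpty() && authHeader.startsWith(tokenPrefix)) {
            val authToken = authHeader.substring(tokenPrefix.length).trim()
            val username: String? = jwtService.getUsername(authToken)  // null if the token is invalid or expired
            if (!username.isNullOrEmpty() && SecurityContextHolder.getContext().authentication == null) {
                // This lookup should be cached; for simplicity the database is queried on every request
                val user: User? = userRepository.getUserByUsername(username)
                // Only authenticate users that still exist; a null principal must not count as logged in
                if (user != null) {
                    // Build the authentication object Security uses and put it into the SecurityContextHolder context
                    val authentication = UsernamePasswordAuthenticationToken(user, null, user.getAuthorities())
                    authentication.details = WebAuthenticationDetailsSource().buildDetails(request)
                    SecurityContextHolder.getContext().authentication = authentication
                }
            }
        }
        // Filters form a chain: always pass the request on instead of returning early
        filterChain.doFilter(request, response)
    }
}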
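The header handling at the top of the filter can be pulled out into a pure function and tested on its own. extractBearerToken is a hypothetical helper that the filter above does not call; it assumes the "Authorization: Bearer <token>" convention and compares the scheme name case-insensitively, as HTTP authentication schemes are case-insensitive. Keeping the space in the prefix means a header such as "BearerXYZ" is rejected instead of yielding "XYZ".

```kotlin
// Hypothetical helper: returns the raw token from "Authorization: Bearer <token>", or null.
fun extractBearerToken(authHeader: String?, prefix: String = "Bearer "): String? {
    if (authHeader.isNullOrBlank()) return null
    // The auth scheme name is case-insensitive in HTTP, so "bearer " is accepted too.
    if (!authHeader.startsWith(prefix, ignoreCase = true)) return null
    val token = authHeader.substring(prefix.length).trim()
    // "Bearer " followed by nothing is treated as no token at all.
    return token.ifEmpty { null }
}
```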

JwtService still has to be implemented: it generates a token from a username and, in the other direction, parses the username back out of a token.

// JwtService.kt

import com.auth0.jwt.JWT
import com.auth0.jwt.algorithms.Algorithm
import org.springframework.stereotype.Service
import java.util.*

@Service
class JwtService {
    private val secret = "qwert12345"  // choose your own; in practice load a long random secret from configuration
    private val audience = "http://0.0.0.0:8080/"  // usually the server's domain
    private val issuer = "http://0.0.0.0:8080"

    fun getAlgorithm(): Algorithm {
        return Algorithm.HMAC256(secret)
    }

    // expire: token lifetime in milliseconds
    fun generateToken(username: String, expire: Long): String {
        return JWT.create()
            .withAudience(audience)
            .withIssuer(issuer)
            .withClaim("username", username)
            .withExpiresAt(Date(System.currentTimeMillis() + expire))
            .sign(getAlgorithm())
    }

    fun getUsername(token: String?): String? {
        return try {
            val decoded = JWT
                .require(getAlgorithm())
                .withAudience(audience)
                .withIssuer(issuer)
                .build().verify(token)
            decoded.getClaim("username").asString()
        } catch (e: Exception) {
            null
        }
    }
}
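Algorithm.HMAC256(secret) makes the tokens HS256-signed, that is, signed with HMAC-SHA256 over base64url(header) + "." + base64url(payload). The JDK-only sketch below illustrates that idea and why a forged payload is rejected. It is a conceptual illustration with hypothetical function names, not auth0's implementation, and it skips everything else a real verifier checks (the algorithm header, exp, aud, iss). It also shows why the secret must stay on the server: whoever knows it can mint valid tokens. For that reason a short value like qwert12345 should be replaced by a long random secret; RFC 7518 requires an HS256 key of at least 256 bits.

```kotlin
import java.security.MessageDigest
import java.util.Base64
import javax.crypto.Mac
import javax.crypto.spec.SecretKeySpec

// Conceptual HS256 illustration (not auth0's implementation).
val base64Url: Base64.Encoder = Base64.getUrlEncoder().withoutPadding()

// HMAC-SHA256 of the signing input, Base64URL-encoded without padding.
fun hs256Signature(signingInput: String, secret: String): String {
    val mac = Mac.getInstance("HmacSHA256")
    mac.init(SecretKeySpec(secret.toByteArray(Charsets.UTF_8), "HmacSHA256"))
    return base64Url.encodeToString(mac.doFinal(signingInput.toByteArray(Charsets.UTF_8)))
}

// Hypothetical helper: builds header.payload.signature from raw JSON strings.
fun signedToken(headerJson: String, payloadJson: String, secret: String): String {
    val signingInput = base64Url.encodeToString(headerJson.toByteArray(Charsets.UTF_8)) + "." +
        base64Url.encodeToString(payloadJson.toByteArray(Charsets.UTF_8))
    return signingInput + "." + hs256Signature(signingInput, secret)
}

// Recomputes the signature; any change to header or payload breaks the match.
fun hasValidHs256Signature(token: String, secret: String): Boolean {
    val lastDot = token.lastIndexOf('.')
    if (lastDot <= 0) return false
    val expected = hs256Signature(token.substring(0, lastDot), secret)
    val actual = token.substring(lastDot + 1)
    // MessageDigest.isEqual examines all bytes, so timing does not reveal a matching prefix
    return MessageDigest.isEqual(expected.toByteArray(Charsets.US_ASCII), actual.toByteArray(Charsets.US_ASCII))
}
```

In JwtService the equivalent round trip is generateToken(...) followed by getUsername(...), which returns null for a tampered token because verify throws.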

Finally, we write a configuration class that replaces the default handlers with our own and registers the JWT filter.

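// SecurityConfig.kt

import jakarta.annotation.Resource
import org.springframework.boot.web.servlet.FilterRegistrationBean
import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration
import org.springframework.security.config.annotation.web.builders.HttpSecurity
import org.springframework.security.config.http.SessionCreationPolicy
import org.springframework.security.web.SecurityFilterChain
import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter

@Configuration
class SecurityConfig {

    @Resource
    private lateinit var securityHandler: SecurityHandler

    private val whiteList = arrayOf(
        "/static/**",
        "/resources/**",
        "/api/register",
        "/api/login",
        "/v3/api-docs/**",   // SpringDoc OpenAPI description
        "/swagger-ui/**",    // SpringDoc Swagger UI
        "/swagger-ui.html",
    )

    // Create the filter as a Spring bean so that its @Resource fields get injected
    @Bean
    fun jwtAuthenticationTokenFilter() = JwtAuthenticationTokenFilter()

    // Spring Boot would also register a Filter bean with the servlet container;
    // disable that so the filter runs only inside the Security filter chain
    @Bean
    fun jwtFilterRegistration(filter: JwtAuthenticationTokenFilter) =
        FilterRegistrationBean(filter).apply { isEnabled = false }

    @Bean
    fun filterChain(http: HttpSecurity, jwtFilter: JwtAuthenticationTokenFilter): SecurityFilterChain {
        with(http) {
            cors { it.disable() }  // turns off Security's CORS handling; configure CORS if the front end runs on another origin
            csrf { it.disable() }  // tokens travel in a header, not a cookie session, so CSRF protection is not needed
            formLogin { it.disable() }  // disable the default form login
            authorizeHttpRequests {
                it.requestMatchers(*whiteList).permitAll()  // whitelisted paths need no authentication
                it.anyRequest().hasAnyRole("USER")  // everything else requires the USER role (authority ROLE_USER)
            }

            exceptionHandling {
                it.authenticationEntryPoint(securityHandler).accessDeniedHandler(securityHandler)
            }

            sessionManagement {
                it.sessionCreationPolicy(SessionCreationPolicy.STATELESS)  // never create an HTTP session
            }

            addFilterBefore(jwtFilter, UsernamePasswordAuthenticationFilter::class.java)

            return build()
        }
    }
}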

MVC logic

// User.kt

import jakarta.persistence.*
import org.springframework.security.core.GrantedAuthority
import org.springframework.security.core.authority.SimpleGrantedAuthority

@Entity
@Table(name = "users")
data class User(
    @Id
    @GeneratedValue(strategy = GenerationType.IDENTITY)
    var id: Int,

    @Column(length = 31, nullable = false)
    var username: String,

    // Holds the BCrypt hash, which is 60 characters long, so a length of 31 would truncate it
    @Column(length = 100, nullable = false)
    var password: String,

    @Column(length = 31, nullable = false)
    var roles: String
) {
    // Returns the user's authorities; the roles column stores comma-separated values such as "ROLE_USER,ROLE_ADMIN"
    fun getAuthorities(): MutableCollection<out GrantedAuthority> =
        roles.split(",")
            .map { it.trim() }
            .filter { it.isNotEmpty() }
            .map { SimpleGrantedAuthority(it) }
            .toMutableList()
}
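The roles column and the hasAnyRole("USER") rule in SecurityConfig fit together through the ROLE_ prefix: hasAnyRole prepends ROLE_ to the role names it is given, so the role name passed to it should not include the prefix, while the authority stored in the database must be ROLE_USER. The sketch below is a simplified model of that matching, not Spring's code; parseAuthorities and hasAnyRole here are hypothetical plain-Kotlin functions that work on strings instead of GrantedAuthority.

```kotlin
// Simplified model (not Spring's code) of how stored authorities meet hasAnyRole(...).
val rolePrefix = "ROLE_"

// Same parsing as User.getAuthorities(): comma-separated, trimmed, blanks dropped.
fun parseAuthorities(roles: String): List<String> =
    roles.split(",").map { it.trim() }.filter { it.isNotEmpty() }

// hasAnyRole("USER") is satisfied by the authority "ROLE_USER", because the prefix is prepended.
fun hasAnyRole(authorities: List<String>, vararg roles: String): Boolean =
    roles.any { role -> (rolePrefix + role) in authorities }
```

Storing a bare "USER" in the roles column would therefore never satisfy the rule, which is easy to miss when seeding test users.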

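// UserRepository.kt

import org.springframework.data.repository.CrudRepository
import org.springframework.stereotype.Repository

@Repository
interface UserRepository : CrudRepository<User, Int> {
    // Derived query: Spring Data generates "select ... where username = ?" from the method name
    fun getUserByUsername(username: String): User?
}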

// DemoController.kt

import jakarta.annotation.Resource
import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder
import org.springframework.web.bind.annotation.*

@RestController
@RequestMapping("/api")  // matches the /api/login and /api/register entries in the SecurityConfig whitelist
class DemoController {
    @Resource
    private lateinit var userRepository: UserRepository

    @Resource
    private lateinit var jwtService: JwtService

    private val passwordEncoder = BCryptPasswordEncoder()

    @PostMapping("/login")
    fun login(@RequestBody data: Map<String, String>): Map<String, Any> {
        val username = data["username"]
        val password = data["password"]
        if (username.isNullOrEmpty() || password.isNullOrEmpty()) {
            return mapOf("code" to 400, "msg" to "Login failed: username and password are required")
        }

        val user: User? = userRepository.getUserByUsername(username)
        // Same message whether the user is missing or the password is wrong, so usernames cannot be probed
        if (user == null || !passwordEncoder.matches(password, user.password)) {
            return mapOf("code" to 400, "msg" to "Incorrect username or password")
        }

        // Sessions are stateless, so the authentication is not stored here;
        // the JWT filter rebuilds it from the token on every later request
        val token: String = jwtService.generateToken(user.username, 60L * 60 * 1000)  // valid for 1 hour (in milliseconds)
        return mapOf("code" to 200, "token" to token)
    }

    @GetMapping("/hello")
    fun demo(): Map<String, Any> {
        return mapOf("code" to 200, "msg" to "hello")
    }
}
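On the front-end side the flow is: POST the credentials to the login endpoint, keep the returned token, and send it as "Authorization: Bearer <token>" on later calls. The sketch below builds those two requests with the JDK's java.net.http client (Java 11+). It assumes the server runs at http://localhost:8080 and that the controller is reachable under /api, matching the /api/login whitelist entry; loginRequest and helloRequest are hypothetical names, and the JSON body is built by string interpolation only for brevity.

```kotlin
import java.net.URI
import java.net.http.HttpRequest

// Assumption: local development server address.
val baseUrl = "http://localhost:8080"

// Hypothetical helper: the login call. Real code should use a JSON library so that
// quotes or backslashes in the credentials are escaped correctly.
fun loginRequest(username: String, password: String): HttpRequest {
    val body = """{"username":"$username","password":"$password"}"""
    return HttpRequest.newBuilder(URI.create("$baseUrl/api/login"))
        .header("Content-Type", "application/json")
        .POST(HttpRequest.BodyPublishers.ofString(body))
        .build()
}

// Hypothetical helper: a protected call carrying the token the way the filter expects it.
fun helloRequest(token: String): HttpRequest =
    HttpRequest.newBuilder(URI.create("$baseUrl/api/hello"))
        .header("Authorization", "Bearer $token")
        .GET()
        .build()
```

To actually call the server, pass each request to HttpClient.newHttpClient().send(request, HttpResponse.BodyHandlers.ofString()) and read the token field from the login response. Without the header, /api/hello should be answered by SecurityHandler's 401 JSON instead of reaching the controller.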