Introduction

This is the third post in the series, dedicated to OAuth2. Spring Security splits its OAuth2 support into two separate parts: OAuth2 Resource Server and OAuth2 Client. Based on the OAuth2 RFC (RFC 6749), the resource server is the server that hosts the protected resources and accepts requests carrying access tokens. The authorization server issues those tokens, and both roles can be played by the same application. The OAuth2 client, on the other hand, makes requests to protected resources, which often happens when you access resources hosted on cloud providers such as AWS and GCP.

OAuth2 Resource Server

Based on the official docs and what we learned in the previous post, validating user-provided credentials (username/password or token) takes three steps:

  1. An authentication filter, such as the UsernamePasswordAuthenticationFilter registered by formLogin(withDefaults()), transforms the credentials into an Authentication token object (e.g. UsernamePasswordAuthenticationToken).
  2. The token is passed to a ProviderManager, which holds a list of AuthenticationProviders (e.g. DaoAuthenticationProvider) and loops over them to find one that can process the token.
  3. If authentication succeeds, the remaining providers are skipped and a new, authenticated token with the proper authorities is returned. If it fails, an AuthenticationException is thrown and handled further up (by the filter's failure handler, or by ExceptionTranslationFilter for exceptions raised later in the chain), which decides what action to take.
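To make the loop in step 2 concrete, here is a minimal, dependency-free sketch of the control flow. The types Token, Provider and AuthFailed are hypothetical stand-ins for Authentication, AuthenticationProvider and AuthenticationException, not Spring Security's API. As I understand ProviderManager, a provider failure is remembered while the next provider is tried, and the first success is returned.

```java
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import java.util.function.Predicate;

// Hypothetical, dependency-free model of the ProviderManager loop; NOT Spring Security's API.
final class ProviderLoopSketch {

    // Stand-in for an Authentication object (unauthenticated request or authenticated result).
    record Token(String credentials, boolean authenticated, List<String> authorities) {}

    // Stand-in for AuthenticationException.
    static final class AuthFailed extends RuntimeException {
        AuthFailed(String message) {
            super(message);
        }
    }

    // Stand-in for AuthenticationProvider.
    interface Provider {
        boolean supports(Token token);

        // Empty result = "cannot decide"; throwing AuthFailed = "credentials rejected".
        Optional<Token> authenticate(Token token);
    }

    static Provider provider(Predicate<Token> supportsCheck, Function<Token, Optional<Token>> authenticateFn) {
        return new Provider() {
            @Override
            public boolean supports(Token token) {
                return supportsCheck.test(token);
            }

            @Override
            public Optional<Token> authenticate(Token token) {
                return authenticateFn.apply(token);
            }
        };
    }

    static Token authenticate(List<Provider> providers, Token request) {
        AuthFailed lastFailure = null;
        for (Provider provider : providers) {
            if (!provider.supports(request)) {
                continue; // this provider cannot handle this kind of token
            }
            try {
                Optional<Token> result = provider.authenticate(request);
                if (result.isPresent()) {
                    return result.get(); // first success wins; remaining providers are skipped
                }
            } catch (AuthFailed ex) {
                lastFailure = ex; // remember the failure and try the next provider
            }
        }
        throw lastFailure != null ? lastFailure : new AuthFailed("No provider could authenticate the request");
    }
}
```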

BearerTokenAuthenticationFilter

Similar to UsernamePasswordAuthenticationFilter, BearerTokenAuthenticationFilter handles HTTP requests that carry a bearer token in the Authorization header.
The filter is configured with the http.oauth2ResourceServer() method, which offers two configurations: Jwt and OpaqueToken.

  • A Jwt carries readable content: a client can read the information in it simply by decoding the token.
  • An OpaqueToken is only meant to be understood by its issuer. It can carry proprietary information that is not meant for public use.
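Since readability is the key difference, the sketch below shows that a JWT is just three Base64URL-encoded parts joined by dots, so anyone holding it can read the header and payload without a key. The bearer-header pattern is modelled on the one Spring's DefaultBearerTokenResolver uses, but treat it as an approximation; the class JwtPeek and its methods are hypothetical.

```java
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical helper: pull the token out of an Authorization header and read its parts.
final class JwtPeek {
    // Approximation of the bearer pattern used by Spring's DefaultBearerTokenResolver.
    private static final Pattern BEARER =
            Pattern.compile("^Bearer (?<token>[a-zA-Z0-9-._~+/]+=*)$", Pattern.CASE_INSENSITIVE);

    static String extractToken(String authorizationHeader) {
        Matcher matcher = BEARER.matcher(authorizationHeader);
        if (!matcher.matches()) {
            throw new IllegalArgumentException("Not a bearer token");
        }
        return matcher.group("token");
    }

    // index 0 = header, 1 = payload; no key is needed to read either of them
    static String decodeSegment(String token, int index) {
        String[] parts = token.split("\\.");
        if (parts.length != 3) {
            throw new IllegalArgumentException("Expected header.payload.signature");
        }
        return new String(Base64.getUrlDecoder().decode(parts[index]), StandardCharsets.UTF_8);
    }

    static String encodeSegment(String json) {
        return Base64.getUrlEncoder().withoutPadding().encodeToString(json.getBytes(StandardCharsets.UTF_8));
    }
}
```

Only the signature (the third part) needs a key, and only to check integrity; it does not hide the payload.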

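Here we only consider Jwt. To set up a minimal Jwt configuration, you need the default JwtConfigurer (configurer.jwt(withDefaults())) and a JwtDecoder bean; otherwise, the application will not start (unless Spring Boot auto-configures a decoder from its resource-server properties). The code below comes from the Spring Security samples on GitHub.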

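// Inside a @Configuration class; withDefaults() is statically imported from
// org.springframework.security.config.Customizer, and this.publicKey is an RSAPublicKey field.
@Bean
public SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
    return http
            .authorizeHttpRequests(request -> request.anyRequest().authenticated())
            .oauth2ResourceServer(configurer -> configurer.jwt(withDefaults()))
            .build();
}

@Bean
public JwtDecoder jwtDecoder() {
    // verifies RS256 signatures with the configured public key
    return NimbusJwtDecoder.withPublicKey(this.publicKey).build();
}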

The documentation on spring.io explains the process well, so I just want to fill in a few more details.

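The default JwtDecoder built by NimbusJwtDecoder.withPublicKey() uses Nimbus's DefaultJWTProcessor to verify the signature against the configured public key, and by default it only accepts RS256 in the "alg" header. The claims are not checked by Nimbus here: Spring sets a no-op claims verifier on the processor and runs its own validators afterwards, which by default include JwtTimestampValidator (checks "exp" and "nbf"). So, generally speaking, the default BearerTokenAuthenticationFilter setup accepts any token that is signed with RS256 by the matching private key and has not expired.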
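To show what verifying an RS256 signature boils down to, here is a JDK-only sketch with no Nimbus or JJWT involved. RS256 is RSASSA-PKCS1-v1_5 with SHA-256, which the JDK exposes as "SHA256withRSA", computed over the ASCII string "header.payload". Rs256Sketch is a hypothetical helper for illustration, not how DefaultJWTProcessor is implemented internally.

```java
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.Signature;
import java.util.Base64;

// Hypothetical JDK-only illustration of RS256 signing and verification for a compact JWS.
final class Rs256Sketch {
    private static final Base64.Encoder URL_ENCODER = Base64.getUrlEncoder().withoutPadding();

    static KeyPair newRsaKeyPair() {
        try {
            KeyPairGenerator generator = KeyPairGenerator.getInstance("RSA");
            generator.initialize(2048);
            return generator.generateKeyPair();
        } catch (GeneralSecurityException ex) {
            throw new IllegalStateException(ex);
        }
    }

    // Produces header.payload.signature, signing the ASCII bytes of "header.payload".
    static String sign(String headerJson, String payloadJson, PrivateKey privateKey) {
        String signingInput = URL_ENCODER.encodeToString(headerJson.getBytes(StandardCharsets.UTF_8))
                + "." + URL_ENCODER.encodeToString(payloadJson.getBytes(StandardCharsets.UTF_8));
        try {
            Signature signer = Signature.getInstance("SHA256withRSA"); // RS256
            signer.initSign(privateKey);
            signer.update(signingInput.getBytes(StandardCharsets.US_ASCII));
            return signingInput + "." + URL_ENCODER.encodeToString(signer.sign());
        } catch (GeneralSecurityException ex) {
            throw new IllegalStateException(ex);
        }
    }

    // Returns true only if the signature matches "header.payload" under the given public key.
    static boolean verify(String token, PublicKey publicKey) {
        int lastDot = token.lastIndexOf('.');
        if (lastDot <= 0) {
            return false;
        }
        String signingInput = token.substring(0, lastDot);
        byte[] signature = Base64.getUrlDecoder().decode(token.substring(lastDot + 1));
        try {
            Signature verifier = Signature.getInstance("SHA256withRSA");
            verifier.initVerify(publicKey);
            verifier.update(signingInput.getBytes(StandardCharsets.US_ASCII));
            return verifier.verify(signature);
        } catch (GeneralSecurityException ex) {
            return false; // e.g. a malformed signature
        }
    }
}
```

Changing a single character of the payload, or verifying with a different public key, makes verify() return false, which is exactly why a decoder only needs the public key.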

Coding session

In this section, instead of using Nimbus, I use the io.jsonwebtoken JJWT library to build the JWT encoder and decoder. I also take notes on a few experiments along the way.

JwtDecoder

As described above, a Spring application using oauth2ResourceServer() needs a JwtDecoder bean to start. So let's implement one with the JJWT library.

import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Jws;
import io.jsonwebtoken.Jwts;
import java.security.PublicKey;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import lombok.AllArgsConstructor;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2TokenValidator;
import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
import org.springframework.security.oauth2.jwt.BadJwtException;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtException;
import org.springframework.security.oauth2.jwt.JwtValidationException;
import org.springframework.security.oauth2.jwt.JwtValidators;
import org.springframework.security.oauth2.jwt.MappedJwtClaimSetConverter;
import org.springframework.util.StringUtils;

@AllArgsConstructor
public class JJwtDecoder implements JwtDecoder {
    private static final String DECODING_ERROR_MESSAGE_TEMPLATE = "An error occurred while attempting to decode the Jwt: %s";

    // the only constructor argument; @AllArgsConstructor skips the initialized final fields below
    private PublicKey publicKey;

    private final Converter<Map<String, Object>, Map<String, Object>> claimSetConverter = MappedJwtClaimSetConverter
            .withDefaults(Collections.emptyMap());

    private final OAuth2TokenValidator<Jwt> jwtValidator = JwtValidators.createDefault();

    @Override
    public Jwt decode(String token) throws JwtException {
        Jws<Claims> jws = parse(token);
        // verify that the jws is signed with RS256 and has a subject ("sub")
        if (!verify(jws)) {
            throw new BadJwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed token"));
        }
        // construct Spring's Jwt format
        Jwt createdJwt = createJwt(token, jws);
        return validateJwt(createdJwt);
    }

    /**
     * Parse the token into JJWT's Jws format; the signature is verified with the public key.
     * @return Jws<Claims>
     */
    private Jws<Claims> parse(String token) {
        try {
            return Jwts.parser().verifyWith(publicKey)
                    .build().parseSignedClaims(token);
        } catch (Exception ex) {
            throw new BadJwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, ex.getMessage()), ex);
        }
    }

    private Jwt createJwt(String token, Jws<Claims> jws) {
        // It is safe to use jws.getHeader() and jws.getPayload() because both header and claims in JJWT extend Map.
        // The claims need extra conversion because "exp" and "iat" are Long values and the Jwt builder requires Instant.
        Map<String, Object> claims = this.claimSetConverter.convert(jws.getPayload());
        // @formatter:off
        return Jwt.withTokenValue(token)
                .headers((h) -> h.putAll(jws.getHeader()))
                .claims((c) -> c.putAll(claims)) // never null if jws.getPayload() is not null.
                .build();
        // @formatter:on
    }

    private boolean verify(Jws<Claims> jws) {
        String alg = jws.getHeader().getAlgorithm();
        // check if jws uses RS256 signing.
        if (alg == null || !alg.equals("RS256")) {
            return false;
        }
        // check if payload contains a "sub" subject field
        return jws.getPayload().containsKey("sub");
    }

    private Jwt validateJwt(Jwt jwt) {
        OAuth2TokenValidatorResult result = this.jwtValidator.validate(jwt);
        if (result.hasErrors()) {
            Collection<OAuth2Error> errors = result.getErrors();
            String validationErrorString = getJwtValidationExceptionMessage(errors);
            throw new JwtValidationException(validationErrorString, errors);
        }
        return jwt;
    }

    private String getJwtValidationExceptionMessage(Collection<OAuth2Error> errors) {
        for (OAuth2Error oAuth2Error : errors) {
            if (StringUtils.hasLength(oAuth2Error.getDescription())) {
                return String.format(DECODING_ERROR_MESSAGE_TEMPLATE, oAuth2Error.getDescription());
            }
        }
        return "Unable to validate Jwt";
    }
}

To explain the code above:

  1. We use the JJWT library to parse the token into a Jws claim set (verifying its signature), then check that the header's "alg" is RS256 and the payload contains "sub".
  2. claimSetConverter is a set of value converters that turn the token's claims into the value types Spring Security requires.
    • For example, the "exp" and "iat" fields hold Long values, but Spring's Jwt requires Instant values. MappedJwtClaimSetConverter::convertInstant converts the Long to an Instant.
  3. jwtValidator is a set of validation checks that make sure the Jwt is valid.
    • For example, the default validator, JwtTimestampValidator, checks whether the token has expired.
  4. The validateJwt() and getJwtValidationExceptionMessage() methods are borrowed directly from NimbusJwtDecoder.java.
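Steps 2 and 3 can be illustrated without Spring. The sketch below is hypothetical (ClaimTimeSketch is my name, not a Spring class): it converts numeric "exp"/"iat"/"nbf" claims from epoch seconds to Instant, and applies an expiry check with a 60-second clock skew, which I believe is JwtTimestampValidator's default. It also accepts java.util.Date, since the raw value type can depend on the JWT library.

```java
import java.time.Duration;
import java.time.Instant;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;

// Hypothetical helpers mimicking claim conversion and an expiry check; NOT Spring's classes.
final class ClaimTimeSketch {
    static final Duration CLOCK_SKEW = Duration.ofSeconds(60);

    // Returns a copy of the claims with time claims converted to Instant.
    static Map<String, Object> convertTimestamps(Map<String, Object> rawClaims) {
        Map<String, Object> converted = new HashMap<>(rawClaims);
        for (String name : new String[] {"exp", "iat", "nbf"}) {
            Object value = rawClaims.get(name);
            if (value instanceof Number seconds) {
                converted.put(name, Instant.ofEpochSecond(seconds.longValue())); // epoch seconds
            } else if (value instanceof Date date) {
                converted.put(name, date.toInstant());
            }
        }
        return converted;
    }

    // Expired only when "now" is more than CLOCK_SKEW past "exp".
    static boolean isExpired(Map<String, Object> claims, Instant now) {
        if (!(claims.get("exp") instanceof Instant expiresAt)) {
            return false; // no "exp" claim: nothing to check here
        }
        return now.minus(CLOCK_SKEW).isAfter(expiresAt);
    }
}
```

The skew tolerates small clock differences between the issuer and the resource server, so a token is still accepted for up to a minute after its "exp".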

Now the application should start and behave like the example code above.

JwtEncoder

Next, I want to issue a Jws to users who have logged in. For that, I need to implement a JwtEncoder bean.

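import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.security.SecureDigestAlgorithm;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.util.LinkedHashMap;
import java.util.Map;
import lombok.AllArgsConstructor;
import org.springframework.security.oauth2.jwt.JwsHeader;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtClaimsSet;
import org.springframework.security.oauth2.jwt.JwtEncoder;
import org.springframework.security.oauth2.jwt.JwtEncoderParameters;
import org.springframework.security.oauth2.jwt.JwtEncodingException;

@AllArgsConstructor
public class JJwtEncoder implements JwtEncoder {
    // e.g. Jwts.SIG.RS256
    private SecureDigestAlgorithm<PrivateKey, PublicKey> alg;

    private PrivateKey privateKey;

    @Override
    public Jwt encode(JwtEncoderParameters parameters) throws JwtEncodingException {

        // inherit header, if exists
        JwsHeader jwsHeader = parameters.getJwsHeader();
        Map<String, Object> headerMap = new LinkedHashMap<>();
        if (jwsHeader != null && jwsHeader.getHeaders() != null) {
            headerMap.putAll(jwsHeader.getHeaders());
        }
        // the encoding algorithm is decided by alg; other places should not need to know which algorithm is used
        headerMap.put("alg", alg.getId());

        JwtClaimsSet claims = parameters.getClaims();
        // @formatter:off
        String tokenValue = Jwts.builder()
                .header()
                    .add(headerMap)
                    .and()
                .claims()
                    .add(claims.getClaims())
                    .and()
                // sign with the configured algorithm so the signature matches the "alg" header
                .signWith(privateKey, alg)
                .compact();
        // @formatter:on

        return new Jwt(tokenValue, claims.getIssuedAt(), claims.getExpiresAt(), headerMap, claims.getClaims());
    }
}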

There is no extra work here because the caller provides all header and claim information; the encoder only sets the "alg" header field.
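The header rule above is easy to get backwards, so here is a tiny hypothetical helper (HeaderMergeSketch is my name, not part of JJWT or Spring) that mirrors it: caller-supplied headers, such as a "kid" key ID, are copied first, and "alg" is then overwritten by the encoder so a caller cannot request a different algorithm.

```java
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical helper mirroring the header handling in JJwtEncoder.encode().
final class HeaderMergeSketch {
    static Map<String, Object> mergeHeaders(Map<String, Object> callerHeaders, String encoderAlgorithm) {
        Map<String, Object> headers = new LinkedHashMap<>();
        if (callerHeaders != null) {
            headers.putAll(callerHeaders); // keep everything the caller provided
        }
        headers.put("alg", encoderAlgorithm); // the encoder, not the caller, decides the algorithm
        return Collections.unmodifiableMap(headers);
    }
}
```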

Retrieve result from BearerTokenAuthenticationFilter

The original purpose of this section was to implement a repository service for Jwt tokens. Later on, I realized that saving JWTs is a bad idea unless the application needs bookkeeping (write-only), since a JWT is self-contained and can be validated without a lookup. The only use I can think of is refreshing the token after validation, so I rewrote this section as a side note on that.
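If the goal is to refresh a token after validation, one possible policy is to re-issue it only when its remaining lifetime drops below a threshold. The sketch below is a hypothetical, JDK-only illustration of that decision (RefreshPolicySketch, shouldRefresh and nextLifetime are my names, not Spring's); the resulting issuedAt/expiresAt could then go into new claims for the JwtEncoder, for example via JwtClaimsSet.builder().

```java
import java.time.Duration;
import java.time.Instant;

// Hypothetical refresh policy: re-issue a token only when it is close to expiring.
final class RefreshPolicySketch {
    record Lifetime(Instant issuedAt, Instant expiresAt) {}

    // True when the token is still valid but has at most `threshold` of lifetime left.
    static boolean shouldRefresh(Instant expiresAt, Instant now, Duration threshold) {
        if (now.isAfter(expiresAt)) {
            return false; // already expired: must not be refreshed
        }
        return Duration.between(now, expiresAt).compareTo(threshold) <= 0;
    }

    // Lifetime for the re-issued token, starting now.
    static Lifetime nextLifetime(Instant now, Duration ttl) {
        return new Lifetime(now, now.plus(ttl));
    }
}
```

Since BearerTokenAuthenticationFilter already rejects expired tokens with the default validators, the expired branch is only a defensive check.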

PostBearerTokenAuthenticationFilter

The idea is to add a PostBearerTokenAuthenticationFilter right after BearerTokenAuthenticationFilter, so the configuration becomes:

@Bean
public SecurityFilterChain securityFilterChain1(HttpSecurity http) throws Exception {
    return http
            .authorizeHttpRequests(request -> request.anyRequest().authenticated())
            // adds BearerTokenAuthenticationFilter
            .oauth2ResourceServer(configurer -> configurer.jwt(withDefaults()))
            // runs right after BearerTokenAuthenticationFilter
            .addFilterAfter(new PostBearerTokenAuthenticationFilter(), BearerTokenAuthenticationFilter.class)
            .build();
}

The remaining problem is how to retrieve the authentication result from BearerTokenAuthenticationFilter. Reading the filter's source, I found that after the call Authentication authenticationResult = authenticationManager.authenticate(authenticationRequest), the result is put in a SecurityContext object and saved in two places at the same time:

  1. SecurityContextHolderStrategy: saves the context in a thread-local for the current request (the default) or globally for the whole application.
  2. SecurityContextRepository (by default a RequestAttributeSecurityContextRepository): saves the context in the HTTP request's attributes under the key RequestAttributeSecurityContextRepository.class.getName().concat(".SPRING_SECURITY_CONTEXT").
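The thread-local part of the first storage place can be sketched in plain Java. RequestScopedHolder below is a hypothetical class in the spirit of the default thread-local strategy, not Spring's API. As I understand it, Spring clears the holder at the end of each request (SecurityContextHolderFilter does this in a finally block); the sketch mirrors that with runInRequest().

```java
import java.util.function.Supplier;

// Hypothetical holder in the spirit of the default thread-local SecurityContextHolderStrategy.
final class RequestScopedHolder<T> {
    private final ThreadLocal<T> current = new ThreadLocal<>();

    T get() {
        return current.get(); // null if nothing was set on this thread
    }

    void set(T value) {
        current.set(value);
    }

    void clear() {
        current.remove();
    }

    // Runs one "request" on the current thread and always clears the value afterwards,
    // so it cannot leak into the next request served by the same pooled thread.
    <R> R runInRequest(T value, Supplier<R> work) {
        current.set(value);
        try {
            return work.get();
        } finally {
            current.remove();
        }
    }
}
```

Every filter running on the same thread during that request sees the same value, which is why a later filter can read what BearerTokenAuthenticationFilter stored.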

Hence, a filter that modifies the authentication result from BearerTokenAuthenticationFilter can read and re-save the context like this:

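import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.oauth2.jwt.BadJwtException;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken;
import org.springframework.security.web.context.RequestAttributeSecurityContextRepository;
import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.web.filter.OncePerRequestFilter;

public class PostBearerTokenAuthenticationFilter extends OncePerRequestFilter {

    private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
            .getContextHolderStrategy();

    // same default repository that BearerTokenAuthenticationFilter uses
    private SecurityContextRepository securityContextRepository = new RequestAttributeSecurityContextRepository();

    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
            throws ServletException, IOException {
        SecurityContext context = securityContextHolderStrategy.getContext();

        Authentication authentication = context.getAuthentication();
        // make sure the authentication was produced by BearerTokenAuthenticationFilter
        if (authentication instanceof JwtAuthenticationToken jwtAuthenticationToken) {
            Jwt token = jwtAuthenticationToken.getToken();
            if (token.getExpiresAt() == null) {
                throw new BadJwtException("Jwt token does not have an expiration date");
            }
            // do the work; placeholder that keeps the current result until the real logic exists
            Authentication newAuthenticationResult = authentication; // TODO
            context.setAuthentication(newAuthenticationResult);
            // re-save the result in both places
            this.securityContextHolderStrategy.setContext(context);
            this.securityContextRepository.saveContext(context, request, response);
        }

        // continue the filter chain
        filterChain.doFilter(request, response);
    }
}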

The same approach works for any object you want to share between filters within the same request.
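For objects other than the security context, request attributes are the simplest carrier. The sketch below uses a plain Map as a hypothetical stand-in for HttpServletRequest's setAttribute/getAttribute, so it runs without a servlet container. SECURITY_CONTEXT_KEY spells out the key Spring uses; REFRESHED_TOKEN_KEY is a made-up key for our own object.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical stand-in for HttpServletRequest attributes; only SECURITY_CONTEXT_KEY mirrors Spring Security.
final class RequestAttributesSketch {
    static final String SECURITY_CONTEXT_KEY =
            "org.springframework.security.web.context.RequestAttributeSecurityContextRepository"
                    .concat(".SPRING_SECURITY_CONTEXT");

    // Hypothetical key for sharing our own object; prefixing it with a class name avoids collisions.
    static final String REFRESHED_TOKEN_KEY = "com.example.PostBearerTokenAuthenticationFilter.REFRESHED_TOKEN";

    private final Map<String, Object> attributes = new HashMap<>();

    void setAttribute(String name, Object value) {
        attributes.put(name, value);
    }

    Object getAttribute(String name) {
        return attributes.get(name);
    }

    // Typed read that returns null when the attribute is missing or has another type.
    <T> T getAttribute(String name, Class<T> type) {
        Object value = attributes.get(name);
        return type.isInstance(value) ? type.cast(value) : null;
    }
}
```

Unlike a thread-local, request attributes live exactly as long as the request object, so there is nothing to clear afterwards.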

Customize JwtConfigurer

This is just an idea; I have not fully implemented it yet. Take the code below:

@Bean
public SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
    return http
            .authorizeHttpRequests(request -> request.anyRequest().authenticated())
            .formLogin(withDefaults())
            .oauth2ResourceServer(configure -> configure.jwt(withDefaults()))
            .build();
}

I can replace
configure.jwt(withDefaults())
with
configure.jwt(jwtConfigurer -> jwtConfigurer.authenticationManager(new ProviderManager(new JwtAuthenticationProvider(new JJwtDecoder(this.publicKey)))))

In that case, I can do some extra work, for example:

public class MyProviderManager extends ProviderManager {
    public MyProviderManager(AuthenticationProvider... providers) {
        super(providers);
    }

    @Override
    public Authentication authenticate(Authentication authentication) throws AuthenticationException {
        // only return a successful Authentication object if all providers return success
        throw new UnsupportedOperationException("TODO: not implemented yet");
    }
}

class PostJwtAuthenticationProvider implements AuthenticationProvider {
    @Override
    public Authentication authenticate(Authentication authentication) throws AuthenticationException {
        throw new UnsupportedOperationException("TODO: not implemented yet");
    }

    @Override
    public boolean supports(Class<?> authentication) {
        return true;
    }
}

@Bean
public SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
    AuthenticationProvider jwtProvider = new JwtAuthenticationProvider(new JJwtDecoder(this.publicKey));
    AuthenticationProvider postJwtProvider = new PostJwtAuthenticationProvider();
    ProviderManager jwtProviderManager = new MyProviderManager(jwtProvider, postJwtProvider);

    return http
            .authorizeHttpRequests(request -> request.anyRequest().authenticated())
            .formLogin(withDefaults())
            .oauth2ResourceServer(configure -> configure.jwt(
                    jwtConfigurer -> jwtConfigurer.authenticationManager(jwtProviderManager)))
            .build();
}

Unlike ProviderManager, which returns the first successful result, MyProviderManager should authenticate the request only when both jwtProvider and postJwtProvider succeed.
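One way the "all must succeed" rule could work is to chain the providers, feeding each result into the next, so postJwtProvider would see the JwtAuthenticationToken produced by jwtProvider. The sketch below models that with plain UnaryOperators; AllMustSucceedSketch and Rejected are hypothetical names, and a null result stands in for a provider that did not authenticate, similar to how an AuthenticationProvider returns null when it cannot decide.

```java
import java.util.List;
import java.util.function.UnaryOperator;

// Hypothetical "all must succeed" chain, sketching what MyProviderManager.authenticate() could do.
final class AllMustSucceedSketch {
    static final class Rejected extends RuntimeException {
        Rejected(String message) {
            super(message);
        }
    }

    static <A> A authenticateAll(List<UnaryOperator<A>> providers, A request) {
        if (providers.isEmpty()) {
            throw new Rejected("No providers configured"); // never succeed by default
        }
        A current = request;
        for (UnaryOperator<A> provider : providers) {
            A result = provider.apply(current);
            if (result == null) {
                throw new Rejected("A provider did not authenticate the request");
            }
            current = result; // pass the previous result on, e.g. the decoded Jwt to a post-check
        }
        return current;
    }
}
```

Chaining is a design choice: the alternative is to pass the original request to every provider and require each to succeed independently, but then the post-check could not see the decoded Jwt.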

Things to say

The one thing I feel is important but have yet to try is the ObjectPostProcessor class. It seems to act like a wrapper where you can do some extra initialization and cleanup on the objects Spring Security creates, but I cannot be certain.

I think my next post will cover Spring Security's OAuth2 client and test connecting to external services such as GCP and AWS.