286 lines
12 KiB
Java
286 lines
12 KiB
Java
package cz.trask.adfsauthms.service;
|
|
|
|
import java.io.ByteArrayInputStream;
|
|
import java.io.ByteArrayOutputStream;
|
|
import java.io.InputStream;
|
|
import java.io.OutputStream;
|
|
import java.net.Authenticator;
|
|
import java.net.HttpURLConnection;
|
|
import java.net.InetSocketAddress;
|
|
import java.net.PasswordAuthentication;
|
|
import java.net.Proxy;
|
|
import java.net.URI;
|
|
import java.net.URL;
|
|
import java.net.URLEncoder;
|
|
import java.nio.charset.StandardCharsets;
|
|
import java.security.KeyFactory;
|
|
import java.security.MessageDigest;
|
|
import java.security.PrivateKey;
|
|
import java.security.Signature;
|
|
import java.security.cert.CertificateFactory;
|
|
import java.security.cert.X509Certificate;
|
|
import java.security.spec.PKCS8EncodedKeySpec;
|
|
import java.util.Base64;
|
|
import java.util.LinkedHashMap;
|
|
import java.util.Map;
|
|
import java.util.UUID;
|
|
|
|
import org.apache.logging.log4j.LogManager;
|
|
import org.apache.logging.log4j.Logger;
|
|
|
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
|
|
|
import cz.trask.adfsauthms.config.AdfsConfig;
|
|
import cz.trask.adfsauthms.dto.TokenPayloadIdp;
|
|
|
|
public class AdfsTokenService {
|
|
|
|
private static final Logger logger = LogManager.getLogger(AdfsTokenService.class);
|
|
|
|
private static final String PARAM_RESOURCE = "resource";
|
|
private static final String PARAM_CLIENT_ID = "client_id";
|
|
private static final String PARAM_CLIENT_ASSERTION_TYPE = "client_assertion_type";
|
|
private static final String PARAM_CLIENT_ASSERTION = "client_assertion";
|
|
private static final String PARAM_GRANT_TYPE = "grant_type";
|
|
|
|
private static final String GRANT_TYPE_CLIENT_CREDENTIALS = "client_credentials";
|
|
private static final String CLIENT_ASSERTION_TYPE =
|
|
"urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
|
|
|
|
private static final String HEADER_CONTENT_TYPE = "Content-Type";
|
|
private static final String CONTENT_TYPE_FORM = "application/x-www-form-urlencoded";
|
|
private static final String METHOD_POST = "POST";
|
|
|
|
private final AdfsConfig config;
|
|
private final ObjectMapper objectMapper;
|
|
|
|
private volatile TokenPayloadIdp cachedToken;
|
|
|
|
public AdfsTokenService(AdfsConfig config, ObjectMapper objectMapper) {
|
|
this.config = config;
|
|
this.objectMapper = objectMapper;
|
|
}
|
|
|
|
public synchronized TokenPayloadIdp getToken(String clientIdOverride) throws Exception {
|
|
validateConfig();
|
|
if (clientIdOverride == null && cachedToken != null && cachedToken.getAccessToken() != null
|
|
&& !cachedToken.getAccessToken().isBlank()
|
|
&& !isJwtExpired(cachedToken.getAccessToken())) {
|
|
logger.debug("Returning cached ADFS token");
|
|
return cachedToken;
|
|
}
|
|
|
|
logger.info("Fetching new ADFS token from {} (Client ID: {})", config.getTokenUrl(),
|
|
clientIdOverride != null ? clientIdOverride : getDefaultClientId());
|
|
TokenPayloadIdp fetchedToken = fetchToken(clientIdOverride);
|
|
if (clientIdOverride == null) {
|
|
this.cachedToken = fetchedToken;
|
|
}
|
|
return fetchedToken;
|
|
}
|
|
|
|
private String getDefaultClientId() {
|
|
if (config.getClientIds() != null && !config.getClientIds().isEmpty()) {
|
|
return config.getClientIds().get(0);
|
|
}
|
|
return null;
|
|
}
|
|
|
|
public synchronized void invalidateCache() {
|
|
logger.debug("Invalidating cached ADFS token");
|
|
this.cachedToken = null;
|
|
}
|
|
|
|
private TokenPayloadIdp fetchToken(String clientIdOverride) throws Exception {
|
|
Proxy proxy = buildProxy();
|
|
if (proxy != null) {
|
|
logger.debug("Using proxy for ADFS token fetch: {}", config.getProxyHost());
|
|
}
|
|
String effectiveClientId = clientIdOverride != null ? clientIdOverride : getDefaultClientId();
|
|
String clientAssertion = generateJwtAssertion(effectiveClientId);
|
|
logger.debug("Generated JWT client assertion for ADFS: {}", clientAssertion);
|
|
String formData = buildFormData(effectiveClientId, clientAssertion);
|
|
|
|
HttpURLConnection connection = null;
|
|
try {
|
|
URL url = new URI(config.getTokenUrl()).toURL();
|
|
connection = (HttpURLConnection) url.openConnection(proxy != null ? proxy : Proxy.NO_PROXY);
|
|
connection.setRequestMethod(METHOD_POST);
|
|
connection.setDoInput(true);
|
|
connection.setDoOutput(true);
|
|
connection.setRequestProperty(HEADER_CONTENT_TYPE, CONTENT_TYPE_FORM);
|
|
|
|
try (OutputStream outputStream = connection.getOutputStream()) {
|
|
outputStream.write(formData.getBytes(StandardCharsets.UTF_8));
|
|
}
|
|
|
|
int responseCode = connection.getResponseCode();
|
|
logger.debug("ADFS response code: {}", responseCode);
|
|
InputStream responseStream = responseCode >= 400 ? connection.getErrorStream() : connection.getInputStream();
|
|
String responseBody = readFully(responseStream);
|
|
if (responseCode >= 400) {
|
|
logger.debug("ADFS error body: {}", responseBody);
|
|
throw new IllegalStateException("ADFS returned HTTP " + responseCode + ": " + responseBody);
|
|
}
|
|
logger.debug("ADFS token fetched successfully");
|
|
return objectMapper.readValue(responseBody, TokenPayloadIdp.class);
|
|
} finally {
|
|
if (connection != null) {
|
|
connection.disconnect();
|
|
}
|
|
}
|
|
}
|
|
|
|
private Proxy buildProxy() {
|
|
if (isBlank(config.getProxyHost()) || isBlank(config.getProxyPort())) {
|
|
return null;
|
|
}
|
|
|
|
int port = Integer.parseInt(config.getProxyPort());
|
|
if (!isBlank(config.getProxyUser()) && !isBlank(config.getProxyPassword())) {
|
|
Authenticator.setDefault(new Authenticator() {
|
|
@Override
|
|
protected PasswordAuthentication getPasswordAuthentication() {
|
|
return new PasswordAuthentication(config.getProxyUser(), config.getProxyPassword().toCharArray());
|
|
}
|
|
});
|
|
}
|
|
return new Proxy(Proxy.Type.HTTP, new InetSocketAddress(config.getProxyHost(), port));
|
|
}
|
|
|
|
private String buildFormData(String clientId, String clientAssertion) throws Exception {
|
|
Map<String, String> params = new LinkedHashMap<>();
|
|
params.put(PARAM_RESOURCE, config.getResource());
|
|
params.put(PARAM_CLIENT_ID, clientId);
|
|
params.put(PARAM_CLIENT_ASSERTION_TYPE, CLIENT_ASSERTION_TYPE);
|
|
params.put(PARAM_CLIENT_ASSERTION, clientAssertion);
|
|
params.put(PARAM_GRANT_TYPE, GRANT_TYPE_CLIENT_CREDENTIALS);
|
|
|
|
StringBuilder builder = new StringBuilder();
|
|
for (Map.Entry<String, String> entry : params.entrySet()) {
|
|
if (builder.length() > 0) {
|
|
builder.append('&');
|
|
}
|
|
builder.append(URLEncoder.encode(entry.getKey(), StandardCharsets.UTF_8.name()));
|
|
builder.append('=');
|
|
builder.append(URLEncoder.encode(entry.getValue(), StandardCharsets.UTF_8.name()));
|
|
}
|
|
return builder.toString();
|
|
}
|
|
|
|
private String generateJwtAssertion(String clientId) throws Exception {
|
|
String certPem = new String(Base64.getDecoder().decode(config.getCertificate().replaceAll("\\s+", "")), StandardCharsets.UTF_8)
|
|
.replace("-----BEGIN CERTIFICATE-----", "")
|
|
.replace("-----END CERTIFICATE-----", "")
|
|
.replaceAll("\\s+", "");
|
|
X509Certificate cert = getCertificate(certPem);
|
|
|
|
String keyPem = new String(Base64.getDecoder().decode(config.getPrivateKey().replaceAll("\\s+", "")), StandardCharsets.UTF_8)
|
|
.replace("-----BEGIN PRIVATE KEY-----", "")
|
|
.replace("-----END PRIVATE KEY-----", "")
|
|
.replaceAll("\\s+", "");
|
|
PrivateKey privateKey = getPrivateKey(keyPem);
|
|
|
|
MessageDigest sha1 = MessageDigest.getInstance("SHA-1");
|
|
String x5t = Base64.getUrlEncoder().withoutPadding().encodeToString(sha1.digest(cert.getEncoded()));
|
|
|
|
long now = System.currentTimeMillis() / 1000L;
|
|
|
|
Map<String, Object> header = new LinkedHashMap<>();
|
|
header.put("alg", "RS256");
|
|
header.put("typ", "JWT");
|
|
header.put("x5t", x5t);
|
|
header.put("kid", x5t);
|
|
|
|
Map<String, Object> claims = new LinkedHashMap<>();
|
|
claims.put("iss", clientId);
|
|
claims.put("sub", clientId);
|
|
claims.put("aud", config.getAudience());
|
|
claims.put("jti", UUID.randomUUID().toString());
|
|
claims.put("iat", now);
|
|
claims.put("nbf", now);
|
|
claims.put("exp", now + 600);
|
|
|
|
String headerPart = Base64.getUrlEncoder().withoutPadding()
|
|
.encodeToString(objectMapper.writeValueAsBytes(header));
|
|
String claimsPart = Base64.getUrlEncoder().withoutPadding()
|
|
.encodeToString(objectMapper.writeValueAsBytes(claims));
|
|
String signingInput = headerPart + "." + claimsPart;
|
|
|
|
Signature signature = Signature.getInstance("SHA256withRSA");
|
|
signature.initSign(privateKey);
|
|
signature.update(signingInput.getBytes(StandardCharsets.UTF_8));
|
|
String signaturePart = Base64.getUrlEncoder().withoutPadding().encodeToString(signature.sign());
|
|
return signingInput + "." + signaturePart;
|
|
}
|
|
|
|
private boolean isJwtExpired(String jwt) throws Exception {
|
|
String[] parts = jwt.split("\\.");
|
|
if (parts.length != 3) {
|
|
throw new IllegalArgumentException("Invalid JWT format");
|
|
}
|
|
|
|
byte[] decodedBytes = Base64.getUrlDecoder().decode(parts[1]);
|
|
Map<String, Object> claims = objectMapper.readValue(decodedBytes, Map.class);
|
|
Number exp = (Number) claims.get("exp");
|
|
if (exp == null) {
|
|
throw new IllegalArgumentException("JWT does not contain 'exp' claim");
|
|
}
|
|
return exp.longValue() * 1000L < System.currentTimeMillis();
|
|
}
|
|
|
|
private X509Certificate getCertificate(String certificatePem) throws Exception {
|
|
String normalized = certificatePem.replace("-----BEGIN CERTIFICATE-----", "")
|
|
.replace("-----END CERTIFICATE-----", "")
|
|
.replaceAll("\\s", "");
|
|
byte[] certBytes = Base64.getDecoder().decode(normalized);
|
|
CertificateFactory factory = CertificateFactory.getInstance("X.509");
|
|
return (X509Certificate) factory.generateCertificate(new ByteArrayInputStream(certBytes));
|
|
}
|
|
|
|
private PrivateKey getPrivateKey(String privateKeyPem) throws Exception {
|
|
String normalized = privateKeyPem.replaceAll("-----BEGIN (.*) PRIVATE KEY-----", "")
|
|
.replaceAll("-----END (.*) PRIVATE KEY-----", "")
|
|
.replaceAll("\\s", "");
|
|
byte[] keyBytes = Base64.getDecoder().decode(normalized);
|
|
PKCS8EncodedKeySpec spec = new PKCS8EncodedKeySpec(keyBytes);
|
|
KeyFactory keyFactory = KeyFactory.getInstance("RSA");
|
|
return keyFactory.generatePrivate(spec);
|
|
}
|
|
|
|
private String readFully(InputStream inputStream) throws Exception {
|
|
if (inputStream == null) {
|
|
return "";
|
|
}
|
|
|
|
try (InputStream in = inputStream; ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
|
|
byte[] buffer = new byte[4096];
|
|
for (int read = in.read(buffer); read >= 0; read = in.read(buffer)) {
|
|
baos.write(buffer, 0, read);
|
|
}
|
|
return baos.toString(StandardCharsets.UTF_8.name());
|
|
}
|
|
}
|
|
|
|
private void validateConfig() {
|
|
requireValue(config.getTokenUrl(), "tokenUrl");
|
|
requireValue(config.getAudience(), "audience");
|
|
requireValue(config.getResource(), "resource");
|
|
if (config.getClientIds() == null || config.getClientIds().isEmpty()) {
|
|
throw new IllegalArgumentException("Missing or empty config value: clientIds");
|
|
}
|
|
requireValue(config.getCertificate(), "certificate");
|
|
requireValue(config.getPrivateKey(), "privateKey");
|
|
}
|
|
|
|
private void requireValue(String value, String name) {
|
|
if (isBlank(value)) {
|
|
throw new IllegalArgumentException("Missing or empty config value: " + name);
|
|
}
|
|
}
|
|
|
|
private boolean isBlank(String value) {
|
|
return value == null || value.trim().isEmpty();
|
|
}
|
|
}
|