2026-06-30 17:38:52 +02:00

286 lines
12 KiB
Java

package cz.trask.adfsauthms.service;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Authenticator;
import java.net.HttpURLConnection;
import java.net.InetSocketAddress;
import java.net.PasswordAuthentication;
import java.net.Proxy;
import java.net.URI;
import java.net.URL;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.security.KeyFactory;
import java.security.MessageDigest;
import java.security.PrivateKey;
import java.security.Signature;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.spec.PKCS8EncodedKeySpec;
import java.util.Base64;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.UUID;
import java.util.regex.Pattern;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import com.fasterxml.jackson.databind.ObjectMapper;
import cz.trask.adfsauthms.config.AdfsConfig;
import cz.trask.adfsauthms.dto.TokenPayloadIdp;
/**
 * Obtains access tokens from an ADFS token endpoint using the OAuth 2.0 client-credentials grant,
 * authenticating the client with a JWT assertion signed by the configured RSA private key.
 *
 * <p>The token obtained for the default client ID (the first entry of {@code clientIds}) is cached
 * and reused until its {@code exp} claim is within {@code EXPIRY_SKEW_SECONDS} seconds of now.
 * Tokens requested for an explicit client ID override are never cached.
 *
 * <p>{@link #getToken(String)} and {@link #invalidateCache()} synchronize on this instance, so at
 * most one token request is in flight per service instance.
 */
public class AdfsTokenService {

    private static final Logger logger = LogManager.getLogger(AdfsTokenService.class);

    private static final String PARAM_RESOURCE = "resource";
    private static final String PARAM_CLIENT_ID = "client_id";
    private static final String PARAM_CLIENT_ASSERTION_TYPE = "client_assertion_type";
    private static final String PARAM_CLIENT_ASSERTION = "client_assertion";
    private static final String PARAM_GRANT_TYPE = "grant_type";
    private static final String GRANT_TYPE_CLIENT_CREDENTIALS = "client_credentials";
    private static final String CLIENT_ASSERTION_TYPE =
            "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
    private static final String HEADER_CONTENT_TYPE = "Content-Type";
    private static final String CONTENT_TYPE_FORM = "application/x-www-form-urlencoded";
    private static final String METHOD_POST = "POST";

    /** Connect timeout for the token endpoint; HttpURLConnection would otherwise wait forever. */
    private static final int CONNECT_TIMEOUT_MS = 10_000;
    /** Read timeout for the token endpoint; HttpURLConnection would otherwise wait forever. */
    private static final int READ_TIMEOUT_MS = 30_000;
    /** Validity window of the signed client assertion, in seconds. */
    private static final long ASSERTION_LIFETIME_SECONDS = 600L;
    /** A cached token is replaced this many seconds before its {@code exp} claim is reached. */
    private static final long EXPIRY_SKEW_SECONDS = 30L;

    private static final Pattern WHITESPACE = Pattern.compile("\\s+");
    /** Matches PEM boundary lines such as {@code -----BEGIN CERTIFICATE-----}. */
    private static final Pattern PEM_BOUNDARY = Pattern.compile("-----(BEGIN|END) [^-]*-----");

    private final AdfsConfig config;
    private final ObjectMapper objectMapper;
    /** Token for the default client ID; written only while holding this instance's monitor. */
    private volatile TokenPayloadIdp cachedToken;

    public AdfsTokenService(AdfsConfig config, ObjectMapper objectMapper) {
        this.config = config;
        this.objectMapper = objectMapper;
    }

    /**
     * Returns an ADFS access token, fetching a new one when necessary.
     *
     * <p>The HTTP request runs while this instance's monitor is held; the connect/read timeouts
     * bound how long concurrent callers can be blocked behind a slow endpoint.
     *
     * @param clientIdOverride client ID to authenticate as, or {@code null} to use the first
     *     configured client ID; only tokens for the default client ID are cached
     * @return the token payload returned by ADFS (possibly the cached one)
     * @throws IllegalArgumentException if required configuration is missing or invalid
     * @throws IllegalStateException if ADFS answers with an HTTP status of 400 or above
     * @throws Exception on network, JSON, or cryptographic failures
     */
    public synchronized TokenPayloadIdp getToken(String clientIdOverride) throws Exception {
        validateConfig();
        if (clientIdOverride == null && isCachedTokenUsable()) {
            logger.debug("Returning cached ADFS token");
            return cachedToken;
        }
        String effectiveClientId = clientIdOverride != null ? clientIdOverride : getDefaultClientId();
        logger.info("Fetching new ADFS token from {} (Client ID: {})", config.getTokenUrl(),
                effectiveClientId);
        TokenPayloadIdp fetchedToken = fetchToken(effectiveClientId);
        if (clientIdOverride == null) {
            this.cachedToken = fetchedToken;
        }
        return fetchedToken;
    }

    /** Returns the first configured client ID, or {@code null} when none is configured. */
    private String getDefaultClientId() {
        if (config.getClientIds() != null && !config.getClientIds().isEmpty()) {
            return config.getClientIds().get(0);
        }
        return null;
    }

    /** Drops the cached default-client token so the next {@link #getToken(String)} refetches. */
    public synchronized void invalidateCache() {
        logger.debug("Invalidating cached ADFS token");
        this.cachedToken = null;
    }

    /**
     * Reports whether the cached token exists, carries an access token and is not about to expire.
     * An access token that cannot be parsed as a JWT with a numeric {@code exp} is treated as
     * unusable, so it is replaced instead of making every subsequent call fail.
     */
    private boolean isCachedTokenUsable() {
        TokenPayloadIdp token = cachedToken;
        if (token == null || isBlank(token.getAccessToken())) {
            return false;
        }
        try {
            return !isJwtExpired(token.getAccessToken());
        } catch (IOException | IllegalArgumentException e) {
            logger.debug("Cached ADFS token could not be inspected; fetching a new one", e);
            return false;
        }
    }

    /**
     * Performs the client-credentials POST against the configured token URL.
     *
     * @param clientId client ID used both in the form body and as the assertion's iss/sub
     * @return the response body deserialized as {@link TokenPayloadIdp}
     * @throws IllegalStateException if ADFS answers with an HTTP status of 400 or above
     */
    private TokenPayloadIdp fetchToken(String clientId) throws Exception {
        Proxy proxy = buildProxy();
        Authenticator proxyAuthenticator = proxy != null ? buildProxyAuthenticator() : null;
        if (proxy != null) {
            logger.debug("Using proxy for ADFS token fetch: {}", config.getProxyHost());
        }
        String clientAssertion = generateJwtAssertion(clientId);
        // The assertion is a signed bearer credential; it must never end up in the logs.
        logger.debug("Generated JWT client assertion for ADFS (Client ID: {})", clientId);
        byte[] requestBody = buildFormData(clientId, clientAssertion).getBytes(StandardCharsets.UTF_8);

        HttpURLConnection connection = null;
        try {
            URL url = new URI(config.getTokenUrl()).toURL();
            connection = (HttpURLConnection) url.openConnection(proxy != null ? proxy : Proxy.NO_PROXY);
            connection.setConnectTimeout(CONNECT_TIMEOUT_MS);
            connection.setReadTimeout(READ_TIMEOUT_MS);
            if (proxyAuthenticator != null) {
                // Per-connection authenticator instead of a JVM-wide Authenticator.setDefault.
                connection.setAuthenticator(proxyAuthenticator);
            }
            connection.setRequestMethod(METHOD_POST);
            connection.setDoInput(true);
            connection.setDoOutput(true);
            connection.setRequestProperty(HEADER_CONTENT_TYPE, CONTENT_TYPE_FORM);
            try (OutputStream outputStream = connection.getOutputStream()) {
                outputStream.write(requestBody);
            }
            int responseCode = connection.getResponseCode();
            logger.debug("ADFS response code: {}", responseCode);
            boolean failed = responseCode >= 400;
            String responseBody =
                    readFully(failed ? connection.getErrorStream() : connection.getInputStream());
            if (failed) {
                logger.debug("ADFS error body: {}", responseBody);
                throw new IllegalStateException("ADFS returned HTTP " + responseCode + ": " + responseBody);
            }
            logger.debug("ADFS token fetched successfully");
            return objectMapper.readValue(responseBody, TokenPayloadIdp.class);
        } finally {
            if (connection != null) {
                connection.disconnect();
            }
        }
    }

    /**
     * Builds an HTTP proxy from {@code proxyHost}/{@code proxyPort}, or returns {@code null} when
     * either value is blank.
     *
     * @throws IllegalArgumentException if {@code proxyPort} is not an integer
     */
    private Proxy buildProxy() {
        if (isBlank(config.getProxyHost()) || isBlank(config.getProxyPort())) {
            return null;
        }
        String rawPort = config.getProxyPort().trim();
        int port;
        try {
            port = Integer.parseInt(rawPort);
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException("Invalid config value: proxyPort=" + rawPort, e);
        }
        return new Proxy(Proxy.Type.HTTP, new InetSocketAddress(config.getProxyHost(), port));
    }

    /**
     * Returns an authenticator supplying {@code proxyUser}/{@code proxyPassword}, or {@code null}
     * when either is blank. It answers proxy challenges only, so the proxy credentials are never
     * offered to the ADFS server itself.
     */
    private Authenticator buildProxyAuthenticator() {
        final String user = config.getProxyUser();
        final String password = config.getProxyPassword();
        if (isBlank(user) || isBlank(password)) {
            return null;
        }
        return new Authenticator() {
            @Override
            protected PasswordAuthentication getPasswordAuthentication() {
                if (getRequestorType() != Authenticator.RequestorType.PROXY) {
                    return null;
                }
                return new PasswordAuthentication(user, password.toCharArray());
            }
        };
    }

    /**
     * Builds the URL-encoded form body (resource, client_id, client_assertion_type,
     * client_assertion, grant_type — in that order) for the token request.
     */
    private String buildFormData(String clientId, String clientAssertion) {
        Map<String, String> params = new LinkedHashMap<>();
        params.put(PARAM_RESOURCE, config.getResource());
        params.put(PARAM_CLIENT_ID, clientId);
        params.put(PARAM_CLIENT_ASSERTION_TYPE, CLIENT_ASSERTION_TYPE);
        params.put(PARAM_CLIENT_ASSERTION, clientAssertion);
        params.put(PARAM_GRANT_TYPE, GRANT_TYPE_CLIENT_CREDENTIALS);
        StringBuilder builder = new StringBuilder();
        for (Map.Entry<String, String> entry : params.entrySet()) {
            if (builder.length() > 0) {
                builder.append('&');
            }
            builder.append(URLEncoder.encode(entry.getKey(), StandardCharsets.UTF_8));
            builder.append('=');
            builder.append(URLEncoder.encode(entry.getValue(), StandardCharsets.UTF_8));
        }
        return builder.toString();
    }

    /**
     * Creates an RS256-signed JWT client assertion for {@code clientId}.
     *
     * <p>The header carries {@code x5t} (and the same value as {@code kid}): the base64url SHA-1
     * thumbprint of the DER-encoded certificate. The claims use the client ID as iss and sub,
     * the configured audience as aud, a random jti, and a validity window of
     * {@code ASSERTION_LIFETIME_SECONDS} starting now.
     */
    private String generateJwtAssertion(String clientId) throws Exception {
        X509Certificate cert = getCertificate(decodePemConfigValue(config.getCertificate()));
        PrivateKey privateKey = getPrivateKey(decodePemConfigValue(config.getPrivateKey()));

        MessageDigest sha1 = MessageDigest.getInstance("SHA-1");
        String x5t = Base64.getUrlEncoder().withoutPadding().encodeToString(sha1.digest(cert.getEncoded()));
        long now = System.currentTimeMillis() / 1000L;

        Map<String, Object> header = new LinkedHashMap<>();
        header.put("alg", "RS256");
        header.put("typ", "JWT");
        header.put("x5t", x5t);
        header.put("kid", x5t);

        Map<String, Object> claims = new LinkedHashMap<>();
        claims.put("iss", clientId);
        claims.put("sub", clientId);
        claims.put("aud", config.getAudience());
        claims.put("jti", UUID.randomUUID().toString());
        claims.put("iat", now);
        claims.put("nbf", now);
        claims.put("exp", now + ASSERTION_LIFETIME_SECONDS);

        Base64.Encoder encoder = Base64.getUrlEncoder().withoutPadding();
        String signingInput = encoder.encodeToString(objectMapper.writeValueAsBytes(header))
                + "." + encoder.encodeToString(objectMapper.writeValueAsBytes(claims));
        Signature signature = Signature.getInstance("SHA256withRSA");
        signature.initSign(privateKey);
        signature.update(signingInput.getBytes(StandardCharsets.UTF_8));
        return signingInput + "." + encoder.encodeToString(signature.sign());
    }

    /**
     * Decodes a config value holding a base64-encoded PEM document and returns the DER bytes of
     * its payload. Whitespace is ignored at both encoding levels, and any
     * {@code -----BEGIN ...-----} / {@code -----END ...-----} boundary lines are removed.
     *
     * @throws IllegalArgumentException if either encoding level is not valid base64
     */
    private static byte[] decodePemConfigValue(String encodedPem) {
        byte[] pemBytes = Base64.getDecoder().decode(WHITESPACE.matcher(encodedPem).replaceAll(""));
        String pem = new String(pemBytes, StandardCharsets.UTF_8);
        String base64Body = WHITESPACE.matcher(PEM_BOUNDARY.matcher(pem).replaceAll("")).replaceAll("");
        return Base64.getDecoder().decode(base64Body);
    }

    /**
     * Reports whether the JWT's {@code exp} claim is at most {@code EXPIRY_SKEW_SECONDS} away.
     * The signature is not verified; only the payload is decoded.
     *
     * @throws IllegalArgumentException if the token is not three dot-separated parts, the payload
     *     is not base64url, or it has no numeric {@code exp} claim
     * @throws IOException if the payload is not a JSON object
     */
    private boolean isJwtExpired(String jwt) throws IOException {
        String[] parts = jwt.split("\\.");
        if (parts.length != 3) {
            throw new IllegalArgumentException("Invalid JWT format");
        }
        byte[] decodedBytes = Base64.getUrlDecoder().decode(parts[1]);
        Map<?, ?> claims = objectMapper.readValue(decodedBytes, Map.class);
        Object exp = claims == null ? null : claims.get("exp");
        if (!(exp instanceof Number)) {
            throw new IllegalArgumentException("JWT does not contain 'exp' claim");
        }
        long refreshAtMillis = (((Number) exp).longValue() - EXPIRY_SKEW_SECONDS) * 1000L;
        return refreshAtMillis <= System.currentTimeMillis();
    }

    /** Parses DER-encoded bytes as an X.509 certificate. */
    private X509Certificate getCertificate(byte[] certificateDer) throws Exception {
        CertificateFactory factory = CertificateFactory.getInstance("X.509");
        return (X509Certificate) factory.generateCertificate(new ByteArrayInputStream(certificateDer));
    }

    /** Parses DER bytes as an unencrypted PKCS#8 RSA private key. */
    private PrivateKey getPrivateKey(byte[] privateKeyDer) throws Exception {
        KeyFactory keyFactory = KeyFactory.getInstance("RSA");
        return keyFactory.generatePrivate(new PKCS8EncodedKeySpec(privateKeyDer));
    }

    /**
     * Reads the stream to its end as UTF-8 and closes it; a {@code null} stream (e.g. an absent
     * error stream) yields an empty string.
     */
    private String readFully(InputStream inputStream) throws IOException {
        if (inputStream == null) {
            return "";
        }
        try (InputStream in = inputStream; ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
            in.transferTo(baos);
            return baos.toString(StandardCharsets.UTF_8);
        }
    }

    /** Fails fast with {@link IllegalArgumentException} when a mandatory config value is absent. */
    private void validateConfig() {
        requireValue(config.getTokenUrl(), "tokenUrl");
        requireValue(config.getAudience(), "audience");
        requireValue(config.getResource(), "resource");
        if (config.getClientIds() == null || config.getClientIds().isEmpty()) {
            throw new IllegalArgumentException("Missing or empty config value: clientIds");
        }
        requireValue(config.getCertificate(), "certificate");
        requireValue(config.getPrivateKey(), "privateKey");
    }

    private void requireValue(String value, String name) {
        if (isBlank(value)) {
            throw new IllegalArgumentException("Missing or empty config value: " + name);
        }
    }

    /** Returns true for {@code null} or for strings consisting only of characters trimmed by {@link String#trim()}. */
    private boolean isBlank(String value) {
        return value == null || value.trim().isEmpty();
    }
}