package cz.trask.adfsauthms.service;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Authenticator;
import java.net.HttpURLConnection;
import java.net.InetSocketAddress;
import java.net.PasswordAuthentication;
import java.net.Proxy;
import java.net.URI;
import java.net.URL;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.security.KeyFactory;
import java.security.MessageDigest;
import java.security.PrivateKey;
import java.security.Signature;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.spec.PKCS8EncodedKeySpec;
import java.util.Base64;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.UUID;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import com.fasterxml.jackson.databind.ObjectMapper;

import cz.trask.adfsauthms.config.AdfsConfig;
import cz.trask.adfsauthms.dto.TokenPayloadIdp;

/**
 * Fetches access tokens from the configured ADFS token endpoint using the
 * {@code client_credentials} grant, authenticating with an RS256-signed JWT client assertion.
 *
 * <p>The token obtained for the default client id (the first entry of the configured client ids)
 * is cached until shortly before the {@code exp} claim of its access token. Tokens requested with
 * an explicit client id override are never cached.
 *
 * <p>The public methods are {@code synchronized}; a token fetch therefore blocks other callers of
 * this instance until it completes or times out.
 */
public class AdfsTokenService {

    private static final Logger logger = LogManager.getLogger(AdfsTokenService.class);

    private static final String PARAM_RESOURCE = "resource";
    private static final String PARAM_CLIENT_ID = "client_id";
    private static final String PARAM_CLIENT_ASSERTION_TYPE = "client_assertion_type";
    private static final String PARAM_CLIENT_ASSERTION = "client_assertion";
    private static final String PARAM_GRANT_TYPE = "grant_type";
    private static final String GRANT_TYPE_CLIENT_CREDENTIALS = "client_credentials";
    private static final String CLIENT_ASSERTION_TYPE = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
    private static final String HEADER_CONTENT_TYPE = "Content-Type";
    private static final String CONTENT_TYPE_FORM = "application/x-www-form-urlencoded";
    private static final String METHOD_POST = "POST";

    private static final String PEM_LABEL_CERTIFICATE = "CERTIFICATE";
    private static final String PEM_LABEL_PRIVATE_KEY = "PRIVATE KEY";

    /** Upper bound for establishing the TCP connection to ADFS (or the proxy). */
    private static final int CONNECT_TIMEOUT_MILLIS = 10_000;
    /** Upper bound for waiting on response data from ADFS. */
    private static final int READ_TIMEOUT_MILLIS = 30_000;
    /** A cached token is considered expired this long before its {@code exp} claim. */
    private static final long EXPIRY_SKEW_MILLIS = 30_000L;
    /** Validity of the generated client assertion, in seconds. */
    private static final long ASSERTION_LIFETIME_SECONDS = 600L;

    private final AdfsConfig config;
    private final ObjectMapper objectMapper;
    private volatile TokenPayloadIdp cachedToken;

    /**
     * Creates the service.
     *
     * @param config       ADFS endpoint, client, key material and proxy settings
     * @param objectMapper mapper used to serialize the assertion and parse ADFS responses
     */
    public AdfsTokenService(AdfsConfig config, ObjectMapper objectMapper) {
        this.config = config;
        this.objectMapper = objectMapper;
    }

    /**
     * Returns an ADFS token, reusing the cached one when possible.
     *
     * @param clientIdOverride client id to authenticate as, or {@code null} to use the default
     *                         client id; a non-null value always bypasses the cache
     * @return the token payload returned by ADFS
     * @throws IllegalArgumentException if a required configuration value is missing
     * @throws IllegalStateException    if ADFS answers with an HTTP status of 400 or above
     * @throws Exception                on I/O, parsing or cryptographic failures
     */
    public synchronized TokenPayloadIdp getToken(String clientIdOverride) throws Exception {
        validateConfig();

        if (clientIdOverride == null && isCachedTokenUsable()) {
            logger.debug("Returning cached ADFS token");
            return cachedToken;
        }

        logger.info("Fetching new ADFS token from {} (Client ID: {})", config.getTokenUrl(),
                clientIdOverride != null ? clientIdOverride : getDefaultClientId());

        TokenPayloadIdp fetchedToken = fetchToken(clientIdOverride);
        if (clientIdOverride == null) {
            this.cachedToken = fetchedToken;
        }
        return fetchedToken;
    }

    /** Drops the cached token so that the next default-client request fetches a fresh one. */
    public synchronized void invalidateCache() {
        logger.debug("Invalidating cached ADFS token");
        this.cachedToken = null;
    }

    /** Returns the first configured client id, or {@code null} when none is configured. */
    private String getDefaultClientId() {
        if (config.getClientIds() != null && !config.getClientIds().isEmpty()) {
            return config.getClientIds().get(0);
        }
        return null;
    }

    /** Tells whether a cached token exists whose access token is non-blank and not yet expired. */
    private boolean isCachedTokenUsable() {
        TokenPayloadIdp token = cachedToken;
        if (token == null) {
            return false;
        }
        String accessToken = token.getAccessToken();
        return accessToken != null && !accessToken.isBlank() && !isJwtExpired(accessToken);
    }

    /**
     * Performs the token request against ADFS.
     *
     * @param clientIdOverride client id to use, or {@code null} for the default client id
     * @return the parsed token response
     * @throws IllegalStateException if ADFS answers with an HTTP status of 400 or above
     * @throws Exception             on I/O, parsing or cryptographic failures
     */
    private TokenPayloadIdp fetchToken(String clientIdOverride) throws Exception {
        Proxy proxy = buildProxy();
        if (proxy != null) {
            logger.debug("Using proxy for ADFS token fetch: {}", config.getProxyHost());
        }

        String effectiveClientId = clientIdOverride != null ? clientIdOverride : getDefaultClientId();
        String clientAssertion = generateJwtAssertion(effectiveClientId);
        // The assertion is a bearer credential; never write it to the log.
        logger.debug("Generated JWT client assertion for ADFS (Client ID: {})", effectiveClientId);
        String formData = buildFormData(effectiveClientId, clientAssertion);

        HttpURLConnection connection = null;
        try {
            URL url = new URI(config.getTokenUrl()).toURL();
            connection = (HttpURLConnection) url.openConnection(proxy != null ? proxy : Proxy.NO_PROXY);
            // Without timeouts a stalled endpoint would block this synchronized service forever.
            connection.setConnectTimeout(CONNECT_TIMEOUT_MILLIS);
            connection.setReadTimeout(READ_TIMEOUT_MILLIS);
            connection.setRequestMethod(METHOD_POST);
            connection.setDoInput(true);
            connection.setDoOutput(true);
            connection.setRequestProperty(HEADER_CONTENT_TYPE, CONTENT_TYPE_FORM);

            try (OutputStream outputStream = connection.getOutputStream()) {
                outputStream.write(formData.getBytes(StandardCharsets.UTF_8));
            }

            int responseCode = connection.getResponseCode();
            logger.debug("ADFS response code: {}", responseCode);

            boolean failed = responseCode >= 400;
            String responseBody = readFully(failed ? connection.getErrorStream() : connection.getInputStream());
            if (failed) {
                logger.debug("ADFS error body: {}", responseBody);
                throw new IllegalStateException("ADFS returned HTTP " + responseCode + ": " + responseBody);
            }

            logger.debug("ADFS token fetched successfully");
            return objectMapper.readValue(responseBody, TokenPayloadIdp.class);
        } finally {
            if (connection != null) {
                connection.disconnect();
            }
        }
    }

    /**
     * Builds the HTTP proxy from configuration and, when proxy credentials are configured,
     * installs a JVM-wide default {@link Authenticator} that answers proxy challenges only.
     *
     * @return the proxy, or {@code null} when proxy host or port is not configured
     * @throws IllegalArgumentException if the configured proxy port is not a number
     */
    private Proxy buildProxy() {
        if (isBlank(config.getProxyHost()) || isBlank(config.getProxyPort())) {
            return null;
        }

        final int port;
        try {
            port = Integer.parseInt(config.getProxyPort().trim());
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException("Invalid config value: proxyPort", e);
        }

        if (!isBlank(config.getProxyUser()) && !isBlank(config.getProxyPassword())) {
            Authenticator.setDefault(new Authenticator() {
                @Override
                protected PasswordAuthentication getPasswordAuthentication() {
                    // Only answer proxy challenges; origin servers must not receive proxy credentials.
                    if (getRequestorType() != Authenticator.RequestorType.PROXY) {
                        return null;
                    }
                    return new PasswordAuthentication(config.getProxyUser(),
                            config.getProxyPassword().toCharArray());
                }
            });
        }
        return new Proxy(Proxy.Type.HTTP, new InetSocketAddress(config.getProxyHost(), port));
    }

    /**
     * Builds the URL-encoded form body of the token request.
     *
     * @param clientId        client id sent as {@code client_id}
     * @param clientAssertion signed JWT sent as {@code client_assertion}
     * @return the {@code application/x-www-form-urlencoded} request body
     */
    private String buildFormData(String clientId, String clientAssertion) {
        Map<String, String> params = new LinkedHashMap<>();
        params.put(PARAM_RESOURCE, config.getResource());
        params.put(PARAM_CLIENT_ID, clientId);
        params.put(PARAM_CLIENT_ASSERTION_TYPE, CLIENT_ASSERTION_TYPE);
        params.put(PARAM_CLIENT_ASSERTION, clientAssertion);
        params.put(PARAM_GRANT_TYPE, GRANT_TYPE_CLIENT_CREDENTIALS);

        StringBuilder builder = new StringBuilder();
        for (Map.Entry<String, String> entry : params.entrySet()) {
            if (builder.length() > 0) {
                builder.append('&');
            }
            builder.append(URLEncoder.encode(entry.getKey(), StandardCharsets.UTF_8));
            builder.append('=');
            builder.append(URLEncoder.encode(entry.getValue(), StandardCharsets.UTF_8));
        }
        return builder.toString();
    }

    /**
     * Creates the RS256-signed JWT used as the client assertion.
     *
     * <p>The header carries the SHA-1 thumbprint of the configured certificate as both
     * {@code x5t} and {@code kid}; the claims use the client id as {@code iss} and {@code sub}
     * and the configured audience as {@code aud}.
     *
     * @param clientId client id the assertion is issued for
     * @return the compact JWT ({@code header.claims.signature})
     * @throws IllegalArgumentException if the configured certificate or key is not valid PEM
     * @throws Exception                on certificate, key or signing failures
     */
    private String generateJwtAssertion(String clientId) throws Exception {
        X509Certificate cert = getCertificate(
                decodeConfiguredPem(config.getCertificate(), PEM_LABEL_CERTIFICATE, "certificate"));
        PrivateKey privateKey = getPrivateKey(
                decodeConfiguredPem(config.getPrivateKey(), PEM_LABEL_PRIVATE_KEY, "privateKey"));

        Base64.Encoder base64Url = Base64.getUrlEncoder().withoutPadding();

        // x5t is defined as the SHA-1 thumbprint of the DER-encoded certificate.
        String x5t = base64Url.encodeToString(MessageDigest.getInstance("SHA-1").digest(cert.getEncoded()));
        long now = System.currentTimeMillis() / 1000L;

        Map<String, Object> header = new LinkedHashMap<>();
        header.put("alg", "RS256");
        header.put("typ", "JWT");
        header.put("x5t", x5t);
        header.put("kid", x5t);

        Map<String, Object> claims = new LinkedHashMap<>();
        claims.put("iss", clientId);
        claims.put("sub", clientId);
        claims.put("aud", config.getAudience());
        claims.put("jti", UUID.randomUUID().toString());
        claims.put("iat", now);
        claims.put("nbf", now);
        claims.put("exp", now + ASSERTION_LIFETIME_SECONDS);

        String signingInput = base64Url.encodeToString(objectMapper.writeValueAsBytes(header))
                + "." + base64Url.encodeToString(objectMapper.writeValueAsBytes(claims));

        Signature signature = Signature.getInstance("SHA256withRSA");
        signature.initSign(privateKey);
        signature.update(signingInput.getBytes(StandardCharsets.UTF_8));

        return signingInput + "." + base64Url.encodeToString(signature.sign());
    }

    /**
     * Tells whether the given JWT must be treated as expired.
     *
     * <p>A token that is not a three-part JWT, whose payload cannot be parsed, or that has no
     * numeric {@code exp} claim is reported as expired so that the caller fetches a new one
     * instead of failing on every request.
     *
     * @param jwt compact JWT whose {@code exp} claim is inspected
     * @return {@code true} if the token is expired, about to expire, or unreadable
     */
    private boolean isJwtExpired(String jwt) {
        try {
            String[] parts = jwt.split("\\.");
            if (parts.length != 3) {
                logger.debug("Cached ADFS token is not a three-part JWT; treating it as expired");
                return true;
            }

            Map<?, ?> claims = objectMapper.readValue(Base64.getUrlDecoder().decode(parts[1]), Map.class);
            Object exp = claims != null ? claims.get("exp") : null;
            if (!(exp instanceof Number)) {
                logger.debug("Cached ADFS token has no numeric 'exp' claim; treating it as expired");
                return true;
            }

            long expiresAtMillis = ((Number) exp).longValue() * 1000L;
            return expiresAtMillis - EXPIRY_SKEW_MILLIS < System.currentTimeMillis();
        } catch (IOException | IllegalArgumentException e) {
            logger.debug("Cached ADFS token could not be parsed; treating it as expired", e);
            return true;
        }
    }

    /**
     * Converts a configuration value holding Base64-encoded PEM text into the DER bytes of its
     * single PEM block.
     *
     * @param base64Pem  Base64 encoding of the PEM text (whitespace is ignored)
     * @param label      expected PEM label, e.g. {@code CERTIFICATE}
     * @param configName configuration key, used in error messages
     * @return the decoded DER bytes
     * @throws IllegalArgumentException if the value is not valid Base64 or contains PEM
     *                                  boundaries other than the expected ones
     */
    private byte[] decodeConfiguredPem(String base64Pem, String label, String configName) {
        String pem = new String(Base64.getDecoder().decode(base64Pem.replaceAll("\\s+", "")),
                StandardCharsets.UTF_8);
        String body = pem.replace("-----BEGIN " + label + "-----", "")
                .replace("-----END " + label + "-----", "")
                .replaceAll("\\s+", "");
        if (body.contains("-----")) {
            throw new IllegalArgumentException(
                    "Config value " + configName + " must be a single PEM block of type '" + label + "'");
        }
        return Base64.getDecoder().decode(body);
    }

    /** Parses a DER-encoded X.509 certificate. */
    private X509Certificate getCertificate(byte[] certBytes) throws Exception {
        CertificateFactory factory = CertificateFactory.getInstance("X.509");
        return (X509Certificate) factory.generateCertificate(new ByteArrayInputStream(certBytes));
    }

    /** Parses a DER-encoded PKCS#8 RSA private key. */
    private PrivateKey getPrivateKey(byte[] keyBytes) throws Exception {
        return KeyFactory.getInstance("RSA").generatePrivate(new PKCS8EncodedKeySpec(keyBytes));
    }

    /**
     * Reads the stream to its end as UTF-8 text and closes it.
     *
     * @param inputStream stream to drain; may be {@code null}
     * @return the stream content, or an empty string for a {@code null} stream
     * @throws IOException if reading fails
     */
    private String readFully(InputStream inputStream) throws IOException {
        if (inputStream == null) {
            return "";
        }
        try (InputStream in = inputStream; ByteArrayOutputStream buffer = new ByteArrayOutputStream()) {
            in.transferTo(buffer);
            return buffer.toString(StandardCharsets.UTF_8);
        }
    }

    /**
     * Verifies that every mandatory configuration value is present.
     *
     * @throws IllegalArgumentException naming the first missing or empty value
     */
    private void validateConfig() {
        requireValue(config.getTokenUrl(), "tokenUrl");
        requireValue(config.getAudience(), "audience");
        requireValue(config.getResource(), "resource");
        if (config.getClientIds() == null || config.getClientIds().isEmpty()) {
            throw new IllegalArgumentException("Missing or empty config value: clientIds");
        }
        requireValue(config.getCertificate(), "certificate");
        requireValue(config.getPrivateKey(), "privateKey");
    }

    /** Throws {@link IllegalArgumentException} when {@code value} is null or blank. */
    private void requireValue(String value, String name) {
        if (isBlank(value)) {
            throw new IllegalArgumentException("Missing or empty config value: " + name);
        }
    }

    /** Null-safe blank check. */
    private boolean isBlank(String value) {
        return value == null || value.trim().isEmpty();
    }
}