diff --git a/src/client/Microsoft.Identity.Client/Http/Retry/HttpRetryCondition.cs b/src/client/Microsoft.Identity.Client/Http/Retry/HttpRetryConditions.cs similarity index 95% rename from src/client/Microsoft.Identity.Client/Http/Retry/HttpRetryCondition.cs rename to src/client/Microsoft.Identity.Client/Http/Retry/HttpRetryConditions.cs index 8b2231cf4a..81886b1e4c 100644 --- a/src/client/Microsoft.Identity.Client/Http/Retry/HttpRetryCondition.cs +++ b/src/client/Microsoft.Identity.Client/Http/Retry/HttpRetryConditions.cs @@ -27,6 +27,21 @@ public static bool DefaultManagedIdentity(HttpResponse response, Exception excep }; } + /// + /// Retry policy specific to Imds v1 and v2 Probe. + /// Extends Imds retry policy but excludes 404 status code. + /// + public static bool ImdsProbe(HttpResponse response, Exception exception) + { + if (!Imds(response, exception)) + { + return false; + } + + // If Imds would retry but the status code is 404, don't retry + return (int)response.StatusCode is not 404; + } + /// /// Retry policy specific to IMDS Managed Identity. /// @@ -62,21 +77,6 @@ public static bool RegionDiscovery(HttpResponse response, Exception exception) return (int)response.StatusCode is not (404 or 408); } - /// - /// Retry policy specific to CSR Metadata Probe. - /// Extends Imds retry policy but excludes 404 status code. - /// - public static bool CsrMetadataProbe(HttpResponse response, Exception exception) - { - if (!Imds(response, exception)) - { - return false; - } - - // If Imds would retry but the status code is 404, don't retry - return (int)response.StatusCode is not 404; - } - /// /// Retry condition for /token and /authorize endpoints /// diff --git a/src/client/Microsoft.Identity.Client/Http/Retry/CsrMetadataProbeRetryPolicy.cs b/src/client/Microsoft.Identity.Client/Http/Retry/ImdsProbeRetryPolicy.cs similarity index 66% rename from src/client/Microsoft.Identity.Client/Http/Retry/CsrMetadataProbeRetryPolicy.cs rename to src/client/Microsoft.Identity.Client/Http/Retry/ImdsProbeRetryPolicy.cs index 71de66726d..b939e6e13e 100644 --- a/src/client/Microsoft.Identity.Client/Http/Retry/CsrMetadataProbeRetryPolicy.cs +++ b/src/client/Microsoft.Identity.Client/Http/Retry/ImdsProbeRetryPolicy.cs @@ -5,11 +5,11 @@ namespace Microsoft.Identity.Client.Http.Retry { - internal class CsrMetadataProbeRetryPolicy : ImdsRetryPolicy + internal class ImdsProbeRetryPolicy : ImdsRetryPolicy { protected override bool ShouldRetry(HttpResponse response, Exception exception) { - return HttpRetryConditions.CsrMetadataProbe(response, exception); + return HttpRetryConditions.ImdsProbe(response, exception); } } } diff --git a/src/client/Microsoft.Identity.Client/Http/Retry/RetryPolicyFactory.cs b/src/client/Microsoft.Identity.Client/Http/Retry/RetryPolicyFactory.cs index e190f1ba4d..f9c97bb6c3 100644 --- a/src/client/Microsoft.Identity.Client/Http/Retry/RetryPolicyFactory.cs +++ b/src/client/Microsoft.Identity.Client/Http/Retry/RetryPolicyFactory.cs @@ -14,12 +14,12 @@ public virtual IRetryPolicy GetRetryPolicy(RequestType requestType) case RequestType.STS: case RequestType.ManagedIdentityDefault: return new DefaultRetryPolicy(requestType); + case RequestType.ImdsProbe: + return new ImdsProbeRetryPolicy(); case RequestType.Imds: return new ImdsRetryPolicy(); case RequestType.RegionDiscovery: return new RegionDiscoveryRetryPolicy(); - case RequestType.CsrMetadataProbe: - return new CsrMetadataProbeRetryPolicy(); default: throw new ArgumentOutOfRangeException(nameof(requestType), requestType, "Unknown request type."); } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs index 29c1b4a4db..eb322a08a8 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs @@ -11,7 +11,10 @@ using Microsoft.Identity.Client.ApiConfig.Parameters; using Microsoft.Identity.Client.Core; using Microsoft.Identity.Client.Http; +using Microsoft.Identity.Client.Http.Retry; using Microsoft.Identity.Client.Internal; +using Microsoft.Identity.Client.ManagedIdentity.V2; +using Microsoft.Identity.Client.OAuth2; namespace Microsoft.Identity.Client.ManagedIdentity { @@ -19,9 +22,10 @@ internal class ImdsManagedIdentitySource : AbstractManagedIdentity { // IMDS constants. Docs for IMDS are available here https://docs.microsoft.com/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http // used in unit tests as well + public const string ApiVersionQueryParam = "api-version"; public const string DefaultImdsBaseEndpoint= "http://169.254.169.254"; - private const string ImdsTokenPath = "/metadata/identity/oauth2/token"; public const string ImdsApiVersion = "2018-02-01"; + public const string ImdsTokenPath = "/metadata/identity/oauth2/token"; private const string DefaultMessage = "[Managed Identity] Service request failed."; @@ -36,6 +40,11 @@ internal class ImdsManagedIdentitySource : AbstractManagedIdentity private static string s_cachedBaseEndpoint = null; + public static AbstractManagedIdentity Create(RequestContext requestContext) + { + return new ImdsManagedIdentitySource(requestContext); + } + internal ImdsManagedIdentitySource(RequestContext requestContext) : base(requestContext, ManagedIdentitySource.Imds) { @@ -51,7 +60,7 @@ protected override Task CreateRequestAsync(string resour ManagedIdentityRequest request = new(HttpMethod.Get, _imdsEndpoint); request.Headers.Add("Metadata", "true"); - request.QueryParameters["api-version"] = ImdsApiVersion; + request.QueryParameters[ApiVersionQueryParam] = ImdsApiVersion; request.QueryParameters["resource"] = resource; switch (_requestContext.ServiceBundle.Config.ManagedIdentityId.IdType) @@ -211,5 +220,106 @@ public static Uri GetValidatedEndpoint( return builder.Uri; } + + public static string ImdsQueryParamsHelper( + RequestContext requestContext, + string apiVersionQueryParam, + string imdsApiVersion) + { + var queryParams = $"{apiVersionQueryParam}={imdsApiVersion}"; + + var userAssignedIdQueryParam = GetUserAssignedIdQueryParam( + requestContext.ServiceBundle.Config.ManagedIdentityId.IdType, + requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId, + requestContext.Logger); + + if (userAssignedIdQueryParam != null) + { + queryParams += $"&{userAssignedIdQueryParam.Value.Key}={userAssignedIdQueryParam.Value.Value}"; + } + + return queryParams; + } + + public static async Task ProbeImdsEndpointAsync( + RequestContext requestContext, + ImdsVersion imdsVersion, + CancellationToken cancellationToken) + { + string apiVersionQueryParam; + string imdsApiVersion; + string imdsEndpoint; + string imdsStringHelper; + + switch (imdsVersion) + { + case ImdsVersion.V2: +#if NET462 + requestContext.Logger.Info("[Managed Identity] IMDSv2 flow is not supported on .NET Framework 4.6.2. Cryptographic operations required for managed identity authentication are unavailable on this platform. Skipping IMDSv2 probe."); + return false; +#else + apiVersionQueryParam = ImdsV2ManagedIdentitySource.ApiVersionQueryParam; + imdsApiVersion = ImdsV2ManagedIdentitySource.ImdsV2ApiVersion; + imdsEndpoint = ImdsV2ManagedIdentitySource.CsrMetadataPath; + imdsStringHelper = "IMDSv2"; + break; +#endif + case ImdsVersion.V1: + apiVersionQueryParam = ApiVersionQueryParam; + imdsApiVersion = ImdsApiVersion; + imdsEndpoint = ImdsTokenPath; + imdsStringHelper = "IMDSv1"; + break; + + default: + throw new ArgumentOutOfRangeException(nameof(imdsVersion), imdsVersion, null); + } + + var queryParams = ImdsQueryParamsHelper(requestContext, apiVersionQueryParam, imdsApiVersion); + + // probe omits the "Metadata: true" header and then treats 400 Bad Request as success + var headers = new Dictionary + { + { OAuth2Header.XMsCorrelationId, requestContext.CorrelationId.ToString() } + }; + + IRetryPolicyFactory retryPolicyFactory = requestContext.ServiceBundle.Config.RetryPolicyFactory; + IRetryPolicy retryPolicy = retryPolicyFactory.GetRetryPolicy(RequestType.ImdsProbe); + + HttpResponse response = null; + + try + { + response = await requestContext.ServiceBundle.HttpManager.SendRequestAsync( + GetValidatedEndpoint(requestContext.Logger, imdsEndpoint, queryParams), + headers, + body: null, + method: HttpMethod.Get, + logger: requestContext.Logger, + doNotThrow: false, + mtlsCertificate: null, + validateServerCertificate: null, + cancellationToken: cancellationToken, + retryPolicy: retryPolicy) + .ConfigureAwait(false); + } + catch (Exception ex) + { + requestContext.Logger.Info($"[Managed Identity] {imdsStringHelper} probe endpoint failure. Exception occurred while sending request to probe endpoint: {ex}"); + return false; + } + + // probe omits the "Metadata: true" header and then treats 400 Bad Request as success + if (response.StatusCode == HttpStatusCode.BadRequest) + { + requestContext.Logger.Info(() => $"[Managed Identity] {imdsStringHelper} managed identity is available."); + return true; + } + else + { + requestContext.Logger.Info(() => $"[Managed Identity] {imdsStringHelper} managed identity is not available. Status code: {response.StatusCode}, Body: {response.Body}"); + return false; + } + } } } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs index a8683f7281..0ae3cd10f9 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs @@ -2,7 +2,6 @@ // Licensed under the MIT License. using System; -using System.Collections.Concurrent; using System.IO; using System.Security.Cryptography.X509Certificates; using System.Threading; @@ -41,12 +40,15 @@ internal async Task SendTokenRequestForManagedIdentityA AcquireTokenForManagedIdentityParameters parameters, CancellationToken cancellationToken) { - AbstractManagedIdentity msi = await GetOrSelectManagedIdentitySourceAsync(requestContext, parameters.IsMtlsPopRequested).ConfigureAwait(false); + AbstractManagedIdentity msi = await GetOrSelectManagedIdentitySourceAsync(requestContext, parameters.IsMtlsPopRequested, cancellationToken).ConfigureAwait(false); return await msi.AuthenticateAsync(parameters, cancellationToken).ConfigureAwait(false); } // This method tries to create managed identity source for different sources, if none is created then defaults to IMDS. - private async Task GetOrSelectManagedIdentitySourceAsync(RequestContext requestContext, bool isMtlsPopRequested) + private async Task GetOrSelectManagedIdentitySourceAsync( + RequestContext requestContext, + bool isMtlsPopRequested, + CancellationToken cancellationToken) { using (requestContext.Logger.LogMethodDuration()) { @@ -58,7 +60,7 @@ private async Task GetOrSelectManagedIdentitySourceAsyn if (s_sourceName == ManagedIdentitySource.None) { // First invocation: detect and cache - source = await GetManagedIdentitySourceAsync(requestContext, isMtlsPopRequested).ConfigureAwait(false); + source = await GetManagedIdentitySourceAsync(requestContext, isMtlsPopRequested, cancellationToken).ConfigureAwait(false); } else { @@ -66,20 +68,19 @@ private async Task GetOrSelectManagedIdentitySourceAsyn source = s_sourceName; } - // If the source has already been set to ImdsV2 (via this method, - // or GetManagedIdentitySourceAsync in ManagedIdentityApplication.cs) and mTLS PoP was NOT requested - // In this case, we need to fall back to ImdsV1, because ImdsV2 currently only supports mTLS PoP requests + // If the source has already been set to ImdsV2 (via this method, or GetManagedIdentitySourceAsync in ManagedIdentityApplication.cs) + // and mTLS PoP was NOT requested: fall back to ImdsV1, because ImdsV2 currently only supports mTLS PoP requests if (source == ManagedIdentitySource.ImdsV2 && !isMtlsPopRequested) { requestContext.Logger.Info("[Managed Identity] ImdsV2 detected, but mTLS PoP was not requested. Falling back to ImdsV1 for this request only. Please use the \"WithMtlsProofOfPossession\" API to request a token via ImdsV2."); // Do NOT modify s_sourceName; keep cached ImdsV2 so future PoP // requests can leverage it. - source = ManagedIdentitySource.DefaultToImds; + source = ManagedIdentitySource.Imds; } // If the source is determined to be ImdsV1 and mTLS PoP was requested, // throw an exception since ImdsV1 does not support mTLS PoP - if (source == ManagedIdentitySource.DefaultToImds && isMtlsPopRequested) + if (source == ManagedIdentitySource.Imds && isMtlsPopRequested) { throw new MsalClientException( MsalError.MtlsPopTokenNotSupportedinImdsV1, @@ -94,7 +95,8 @@ private async Task GetOrSelectManagedIdentitySourceAsyn ManagedIdentitySource.CloudShell => CloudShellManagedIdentitySource.Create(requestContext), ManagedIdentitySource.AzureArc => AzureArcManagedIdentitySource.Create(requestContext), ManagedIdentitySource.ImdsV2 => ImdsV2ManagedIdentitySource.Create(requestContext), - _ => new ImdsManagedIdentitySource(requestContext) + ManagedIdentitySource.Imds => ImdsManagedIdentitySource.Create(requestContext), + _ => throw new MsalClientException(MsalError.ManagedIdentityAllSourcesUnavailable, MsalErrorMessage.ManagedIdentityAllSourcesUnavailable) }; } } @@ -103,39 +105,58 @@ private async Task GetOrSelectManagedIdentitySourceAsyn // This method is perf sensitive any changes should be benchmarked. internal async Task GetManagedIdentitySourceAsync( RequestContext requestContext, - bool isMtlsPopRequested) + bool isMtlsPopRequested, + CancellationToken cancellationToken) { // First check env vars to avoid the probe if possible - ManagedIdentitySource source = GetManagedIdentitySourceNoImdsV2(requestContext.Logger); - - // If a source is detected via env vars, or - // a source wasn't detected (it defaulted to ImdsV1) and MtlsPop was NOT requested, - // use the source. - // (don't trigger the ImdsV2 probe endpoint if MtlsPop was NOT requested) - if (source != ManagedIdentitySource.DefaultToImds || !isMtlsPopRequested) + ManagedIdentitySource source = GetManagedIdentitySourceNoImds(requestContext.Logger); + if (source != ManagedIdentitySource.None) { s_sourceName = source; return source; } - // Otherwise, probe IMDSv2 - var response = await ImdsV2ManagedIdentitySource.GetCsrMetadataAsync(requestContext, probeMode: true).ConfigureAwait(false); - if (response != null) + // skip the ImdsV2 probe if MtlsPop was NOT requested + if (isMtlsPopRequested) + { + var imdsV2Response = await ImdsManagedIdentitySource.ProbeImdsEndpointAsync(requestContext, ImdsVersion.V2, cancellationToken).ConfigureAwait(false); + if (imdsV2Response) + { + requestContext.Logger.Info("[Managed Identity] ImdsV2 detected."); + s_sourceName = ManagedIdentitySource.ImdsV2; + return s_sourceName; + } + } + else + { + requestContext.Logger.Info("[Managed Identity] Mtls Pop was not requested; skipping ImdsV2 probe."); + } + + var imdsV1Response = await ImdsManagedIdentitySource.ProbeImdsEndpointAsync(requestContext, ImdsVersion.V1, cancellationToken).ConfigureAwait(false); + if (imdsV1Response) { - requestContext.Logger.Info("[Managed Identity] ImdsV2 detected."); - s_sourceName = ManagedIdentitySource.ImdsV2; + requestContext.Logger.Info("[Managed Identity] ImdsV1 detected."); + s_sourceName = ManagedIdentitySource.Imds; return s_sourceName; } - requestContext.Logger.Info("[Managed Identity] IMDSv2 probe failed. Defaulting to IMDSv1."); - s_sourceName = ManagedIdentitySource.DefaultToImds; + requestContext.Logger.Info($"[Managed Identity] {MsalErrorMessage.ManagedIdentityAllSourcesUnavailable}"); + s_sourceName = ManagedIdentitySource.None; return s_sourceName; } - // Detect managed identity source based on the availability of environment variables. - // The result of this method is not cached because reading environment variables is cheap. - // This method is perf sensitive any changes should be benchmarked. - internal static ManagedIdentitySource GetManagedIdentitySourceNoImdsV2(ILoggerAdapter logger = null) + /// + /// Detects the managed identity source based on the availability of environment variables. + /// It does not probe IMDS, but it checks for all other sources. + /// This method does not cache its result, as reading environment variables is inexpensive. + /// It is performance sensitive; any changes should be benchmarked. + /// + /// Optional logger for diagnostic output. + /// + /// The detected based on environment variables. + /// Returns ManagedIdentitySource.None if no environment-based source is detected. + /// + internal static ManagedIdentitySource GetManagedIdentitySourceNoImds(ILoggerAdapter logger = null) { string identityEndpoint = EnvironmentVariables.IdentityEndpoint; string identityHeader = EnvironmentVariables.IdentityHeader; @@ -177,7 +198,7 @@ internal static ManagedIdentitySource GetManagedIdentitySourceNoImdsV2(ILoggerAd } else { - return ManagedIdentitySource.DefaultToImds; + return ManagedIdentitySource.None; } } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentitySource.cs index 0b687fe7bb..faf3e6f16c 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentitySource.cs @@ -48,6 +48,7 @@ public enum ManagedIdentitySource /// Indicates that the source is defaulted to IMDS since no environment variables are set. /// This is used to detect the managed identity source. /// + [Obsolete("In use only to support the now obsolete GetManagedIdentitySource API. Will be removed in a future version. Use GetManagedIdentitySourceAsync instead.")] DefaultToImds, /// diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs index 404c619d8b..114b6bbc01 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs @@ -2,7 +2,6 @@ // Licensed under the MIT License. using System; -using System.Collections.Concurrent; using System.Collections.Generic; using System.Linq; using System.Net; @@ -29,20 +28,15 @@ internal class ImdsV2ManagedIdentitySource : AbstractManagedIdentity private readonly IMtlsCertificateCache _mtlsCache; // used in unit tests + public const string ApiVersionQueryParam = "cred-api-version"; public const string ImdsV2ApiVersion = "2.0"; public const string CsrMetadataPath = "/metadata/identity/getplatformmetadata"; public const string CertificateRequestPath = "/metadata/identity/issuecredential"; public const string AcquireEntraTokenPath = "/oauth2/v2.0/token"; - public static async Task GetCsrMetadataAsync( - RequestContext requestContext, - bool probeMode) + public static async Task GetCsrMetadataAsync(RequestContext requestContext) { -#if NET462 - requestContext.Logger.Info("[Managed Identity] IMDSv2 flow is not supported on .NET Framework 4.6.2. Cryptographic operations required for managed identity authentication are unavailable on this platform. Skipping IMDSv2 probe."); - return await Task.FromResult(null).ConfigureAwait(false); -#else - var queryParams = ImdsV2QueryParamsHelper(requestContext); + var queryParams = ImdsManagedIdentitySource.ImdsQueryParamsHelper(requestContext, ApiVersionQueryParam, ImdsV2ApiVersion); var headers = new Dictionary { @@ -51,7 +45,7 @@ public static async Task GetCsrMetadataAsync( }; IRetryPolicyFactory retryPolicyFactory = requestContext.ServiceBundle.Config.RetryPolicyFactory; - IRetryPolicy retryPolicy = retryPolicyFactory.GetRetryPolicy(RequestType.CsrMetadataProbe); + IRetryPolicy retryPolicy = retryPolicyFactory.GetRetryPolicy(RequestType.Imds); HttpResponse response = null; @@ -72,45 +66,28 @@ public static async Task GetCsrMetadataAsync( } catch (Exception ex) { - if (probeMode) - { - requestContext.Logger.Info($"[Managed Identity] IMDSv2 CSR endpoint failure. Exception occurred while sending request to CSR metadata endpoint: {ex}"); - return null; - } - else - { - ThrowProbeFailedException( - "ImdsV2ManagedIdentitySource.GetCsrMetadataAsync failed.", - ex); - } + ThrowCsrMetadataRequestException( + "ImdsV2ManagedIdentitySource.GetCsrMetadataAsync failed.", + ex); } if (response.StatusCode != HttpStatusCode.OK) { - if (probeMode) - { - requestContext.Logger.Info(() => $"[Managed Identity] IMDSv2 managed identity is not available. Status code: {response.StatusCode}, Body: {response.Body}"); - return null; - } - else - { - ThrowProbeFailedException( - $"ImdsV2ManagedIdentitySource.GetCsrMetadataAsync failed due to HTTP error. Status code: {response.StatusCode} Body: {response.Body}", - null, - (int)response.StatusCode); - } + ThrowCsrMetadataRequestException( + $"ImdsV2ManagedIdentitySource.GetCsrMetadataAsync failed due to HTTP error. Status code: {response.StatusCode} Body: {response.Body}", + null, + (int)response.StatusCode); } - if (!ValidateCsrMetadataResponse(response, requestContext.Logger, probeMode)) + if (!ValidateCsrMetadataResponse(response, requestContext.Logger)) { return null; } - return TryCreateCsrMetadata(response, requestContext.Logger, probeMode); -#endif + return TryCreateCsrMetadata(response, requestContext.Logger); } - private static void ThrowProbeFailedException( + private static void ThrowCsrMetadataRequestException( String errorMessage, Exception ex = null, int? statusCode = null) @@ -125,8 +102,7 @@ private static void ThrowProbeFailedException( private static bool ValidateCsrMetadataResponse( HttpResponse response, - ILoggerAdapter logger, - bool probeMode) + ILoggerAdapter logger) { string serverHeader = response.HeadersAsDictionary .FirstOrDefault((kvp) => { @@ -135,34 +111,18 @@ private static bool ValidateCsrMetadataResponse( if (serverHeader == null) { - if (probeMode) - { - logger.Info(() => $"[Managed Identity] IMDSv2 managed identity is not available. 'server' header is missing from the CSR metadata response. Body: {response.Body}"); - return false; - } - else - { - ThrowProbeFailedException( - $"ImdsV2ManagedIdentitySource.GetCsrMetadataAsync failed because response doesn't have server header. Status code: {response.StatusCode} Body: {response.Body}", - null, - (int)response.StatusCode); - } + ThrowCsrMetadataRequestException( + $"ImdsV2ManagedIdentitySource.GetCsrMetadataAsync failed because response doesn't have server header. Status code: {response.StatusCode} Body: {response.Body}", + null, + (int)response.StatusCode); } if (!serverHeader.Contains("IMDS", StringComparison.OrdinalIgnoreCase)) { - if (probeMode) - { - logger.Info(() => $"[Managed Identity] IMDSv2 managed identity is not available. The 'server' header format is invalid. Extracted server header: {serverHeader}"); - return false; - } - else - { - ThrowProbeFailedException( - $"ImdsV2ManagedIdentitySource.GetCsrMetadataAsync failed because the 'server' header format is invalid. Extracted server header: {serverHeader}. Status code: {response.StatusCode} Body: {response.Body}", - null, - (int)response.StatusCode); - } + ThrowCsrMetadataRequestException( + $"ImdsV2ManagedIdentitySource.GetCsrMetadataAsync failed because the 'server' header format is invalid. Extracted server header: {serverHeader}. Status code: {response.StatusCode} Body: {response.Body}", + null, + (int)response.StatusCode); } return true; @@ -170,13 +130,12 @@ private static bool ValidateCsrMetadataResponse( private static CsrMetadata TryCreateCsrMetadata( HttpResponse response, - ILoggerAdapter logger, - bool probeMode) + ILoggerAdapter logger) { CsrMetadata csrMetadata = JsonHelper.DeserializeFromJson(response.Body); if (!CsrMetadata.ValidateCsrMetadata(csrMetadata)) { - ThrowProbeFailedException( + ThrowCsrMetadataRequestException( $"ImdsV2ManagedIdentitySource.GetCsrMetadataAsync failed because the CsrMetadata response is invalid. Status code: {response.StatusCode} Body: {response.Body}", null, (int)response.StatusCode); @@ -212,7 +171,7 @@ private async Task ExecuteCertificateRequestAsync( string csr, ManagedIdentityKeyInfo managedIdentityKeyInfo) { - var queryParams = ImdsV2QueryParamsHelper(_requestContext); + var queryParams = ImdsManagedIdentitySource.ImdsQueryParamsHelper(_requestContext, ApiVersionQueryParam, ImdsV2ApiVersion); // TODO: add bypass_cache query param in case of token revocation. Boolean: true/false @@ -300,7 +259,7 @@ private async Task ExecuteCertificateRequestAsync( protected override async Task CreateRequestAsync(string resource) { - CsrMetadata csrMetadata = await GetCsrMetadataAsync(_requestContext, false).ConfigureAwait(false); + CsrMetadata csrMetadata = await GetCsrMetadataAsync(_requestContext).ConfigureAwait(false); string certCacheKey = _requestContext.ServiceBundle.Config.ClientId; @@ -381,23 +340,6 @@ protected override async Task CreateRequestAsync(string return request; } - private static string ImdsV2QueryParamsHelper(RequestContext requestContext) - { - var queryParams = $"cred-api-version={ImdsV2ApiVersion}"; - - var userAssignedIdQueryParam = ImdsManagedIdentitySource.GetUserAssignedIdQueryParam( - requestContext.ServiceBundle.Config.ManagedIdentityId.IdType, - requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId, - requestContext.Logger); - - if (userAssignedIdQueryParam != null) - { - queryParams += $"&{userAssignedIdQueryParam.Value.Key}={userAssignedIdQueryParam.Value.Value}"; - } - - return queryParams; - } - /// /// Obtains an attestation JWT for the KeyGuard/CSR payload using the configured /// attestation provider and normalized endpoint. diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsVersion.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsVersion.cs new file mode 100644 index 0000000000..9da6a5a511 --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsVersion.cs @@ -0,0 +1,7 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace Microsoft.Identity.Client.ManagedIdentity.V2 +{ + internal enum ImdsVersion { V1, V2 } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentityApplication.cs b/src/client/Microsoft.Identity.Client/ManagedIdentityApplication.cs index ea1fdb9c37..9a0e2e46e9 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentityApplication.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentityApplication.cs @@ -56,18 +56,18 @@ public AcquireTokenForManagedIdentityParameterBuilder AcquireTokenForManagedIden } /// - public async Task GetManagedIdentitySourceAsync() + public async Task GetManagedIdentitySourceAsync(CancellationToken cancellationToken) { if (ManagedIdentityClient.s_sourceName != ManagedIdentitySource.None) { return ManagedIdentityClient.s_sourceName; } - // Create a temporary RequestContext for the CSR metadata probe request. - var csrMetadataProbeRequestContext = new RequestContext(this.ServiceBundle, Guid.NewGuid(), null, CancellationToken.None); + // Create a temporary RequestContext for the logger and the IMDS probe request. + var requestContext = new RequestContext(this.ServiceBundle, Guid.NewGuid(), null, cancellationToken); // GetManagedIdentitySourceAsync might return ImdsV2 = true, but it still requires .WithMtlsProofOfPossesion on the Managed Identity Application object to hit the ImdsV2 flow - return await ManagedIdentityClient.GetManagedIdentitySourceAsync(csrMetadataProbeRequestContext, isMtlsPopRequested: true).ConfigureAwait(false); + return await ManagedIdentityClient.GetManagedIdentitySourceAsync(requestContext, isMtlsPopRequested: true, cancellationToken).ConfigureAwait(false); } /// @@ -77,7 +77,16 @@ public async Task GetManagedIdentitySourceAsync() [Obsolete("Use GetManagedIdentitySourceAsync() instead. \"ManagedIdentityApplication mi = miBuilder.Build() as ManagedIdentityApplication;\"")] public static ManagedIdentitySource GetManagedIdentitySource() { - return ManagedIdentityClient.GetManagedIdentitySourceNoImdsV2(); + var source = ManagedIdentityClient.GetManagedIdentitySourceNoImds(); + + return source == ManagedIdentitySource.None +#pragma warning disable CS0618 + // ManagedIdentitySource.DefaultToImds is marked obsolete, but is intentionally used here as a sentinel value to support legacy detection logic. + // This value signals that none of the environment-based managed identity sources were detected. + ? ManagedIdentitySource.DefaultToImds +#pragma warning restore CS0618 + : source; + } } } diff --git a/src/client/Microsoft.Identity.Client/MsalError.cs b/src/client/Microsoft.Identity.Client/MsalError.cs index 526718e7df..688458a3cf 100644 --- a/src/client/Microsoft.Identity.Client/MsalError.cs +++ b/src/client/Microsoft.Identity.Client/MsalError.cs @@ -1227,5 +1227,10 @@ public static class MsalError /// mTLS PoP tokens are not supported in IMDS V1. /// public const string MtlsPopTokenNotSupportedinImdsV1 = "mtls_pop_token_not_supported_in_imds_v1"; + + /// + /// All managed identity sources are unavailable. + /// + public const string ManagedIdentityAllSourcesUnavailable = "managed_identity_all_sources_unavailable"; } } diff --git a/src/client/Microsoft.Identity.Client/MsalErrorMessage.cs b/src/client/Microsoft.Identity.Client/MsalErrorMessage.cs index 59742eead3..efff53181c 100644 --- a/src/client/Microsoft.Identity.Client/MsalErrorMessage.cs +++ b/src/client/Microsoft.Identity.Client/MsalErrorMessage.cs @@ -450,5 +450,6 @@ public static string InvalidTokenProviderResponseValue(string invalidValueName) public const string InvalidCertificate = "The certificate received from the Imds server is invalid."; public const string CannotSwitchBetweenImdsVersionsForPreview = "ImdsV2 is currently experimental - A Bearer token has already been received; Please restart the application to receive a mTLS PoP token."; public const string MtlsPopTokenNotSupportedinImdsV1 = "A mTLS PoP token cannot be requested because the application\'s source was determined to be ImdsV1."; + public const string ManagedIdentityAllSourcesUnavailable = "All Managed Identity sources are unavailable."; } } diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net462/PublicAPI.Shipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net462/PublicAPI.Shipped.txt index 8d65bf72b7..201cd0ba14 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/net462/PublicAPI.Shipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/net462/PublicAPI.Shipped.txt @@ -1086,7 +1086,6 @@ const Microsoft.Identity.Client.MsalError.MtlsNotSupportedForManagedIdentity = " const Microsoft.Identity.Client.MsalError.MtlsPopTokenNotSupportedinImdsV1 = "mtls_pop_token_not_supported_in_imds_v1" -> string Microsoft.Identity.Client.IMsalMtlsHttpClientFactory Microsoft.Identity.Client.IMsalMtlsHttpClientFactory.GetHttpClient(System.Security.Cryptography.X509Certificates.X509Certificate2 x509Certificate2) -> System.Net.Http.HttpClient -Microsoft.Identity.Client.ManagedIdentityApplication.GetManagedIdentitySourceAsync() -> System.Threading.Tasks.Task Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource.ImdsV2 = 8 -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource Microsoft.Identity.Client.ManagedIdentityApplicationBuilder.WithExtraQueryParameters(System.Collections.Generic.IDictionary extraQueryParameters) -> Microsoft.Identity.Client.ManagedIdentityApplicationBuilder static Microsoft.Identity.Client.ApplicationBase.ResetStateForTest() -> void diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net462/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net462/PublicAPI.Unshipped.txt index 5f282702bb..3241ccd9cd 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/net462/PublicAPI.Unshipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/net462/PublicAPI.Unshipped.txt @@ -1 +1,2 @@ - \ No newline at end of file +const Microsoft.Identity.Client.MsalError.ManagedIdentityAllSourcesUnavailable = "managed_identity_all_sources_unavailable" -> string +Microsoft.Identity.Client.ManagedIdentityApplication.GetManagedIdentitySourceAsync(System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net472/PublicAPI.Shipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net472/PublicAPI.Shipped.txt index 8d65bf72b7..201cd0ba14 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/net472/PublicAPI.Shipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/net472/PublicAPI.Shipped.txt @@ -1086,7 +1086,6 @@ const Microsoft.Identity.Client.MsalError.MtlsNotSupportedForManagedIdentity = " const Microsoft.Identity.Client.MsalError.MtlsPopTokenNotSupportedinImdsV1 = "mtls_pop_token_not_supported_in_imds_v1" -> string Microsoft.Identity.Client.IMsalMtlsHttpClientFactory Microsoft.Identity.Client.IMsalMtlsHttpClientFactory.GetHttpClient(System.Security.Cryptography.X509Certificates.X509Certificate2 x509Certificate2) -> System.Net.Http.HttpClient -Microsoft.Identity.Client.ManagedIdentityApplication.GetManagedIdentitySourceAsync() -> System.Threading.Tasks.Task Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource.ImdsV2 = 8 -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource Microsoft.Identity.Client.ManagedIdentityApplicationBuilder.WithExtraQueryParameters(System.Collections.Generic.IDictionary extraQueryParameters) -> Microsoft.Identity.Client.ManagedIdentityApplicationBuilder static Microsoft.Identity.Client.ApplicationBase.ResetStateForTest() -> void diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net472/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net472/PublicAPI.Unshipped.txt index 5f282702bb..3241ccd9cd 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/net472/PublicAPI.Unshipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/net472/PublicAPI.Unshipped.txt @@ -1 +1,2 @@ - \ No newline at end of file +const Microsoft.Identity.Client.MsalError.ManagedIdentityAllSourcesUnavailable = "managed_identity_all_sources_unavailable" -> string +Microsoft.Identity.Client.ManagedIdentityApplication.GetManagedIdentitySourceAsync(System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net8.0-android/PublicAPI.Shipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net8.0-android/PublicAPI.Shipped.txt index 54b77db373..4597fbd530 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/net8.0-android/PublicAPI.Shipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/net8.0-android/PublicAPI.Shipped.txt @@ -1052,7 +1052,6 @@ const Microsoft.Identity.Client.MsalError.MtlsNotSupportedForManagedIdentity = " const Microsoft.Identity.Client.MsalError.MtlsPopTokenNotSupportedinImdsV1 = "mtls_pop_token_not_supported_in_imds_v1" -> string Microsoft.Identity.Client.IMsalMtlsHttpClientFactory Microsoft.Identity.Client.IMsalMtlsHttpClientFactory.GetHttpClient(System.Security.Cryptography.X509Certificates.X509Certificate2 x509Certificate2) -> System.Net.Http.HttpClient -Microsoft.Identity.Client.ManagedIdentityApplication.GetManagedIdentitySourceAsync() -> System.Threading.Tasks.Task Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource.ImdsV2 = 8 -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource Microsoft.Identity.Client.ManagedIdentityApplicationBuilder.WithExtraQueryParameters(System.Collections.Generic.IDictionary extraQueryParameters) -> Microsoft.Identity.Client.ManagedIdentityApplicationBuilder static Microsoft.Identity.Client.ApplicationBase.ResetStateForTest() -> void diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net8.0-android/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net8.0-android/PublicAPI.Unshipped.txt index 5f282702bb..3241ccd9cd 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/net8.0-android/PublicAPI.Unshipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/net8.0-android/PublicAPI.Unshipped.txt @@ -1 +1,2 @@ - \ No newline at end of file +const Microsoft.Identity.Client.MsalError.ManagedIdentityAllSourcesUnavailable = "managed_identity_all_sources_unavailable" -> string +Microsoft.Identity.Client.ManagedIdentityApplication.GetManagedIdentitySourceAsync(System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net8.0-ios/PublicAPI.Shipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net8.0-ios/PublicAPI.Shipped.txt index cddad35aae..a9bd19a52a 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/net8.0-ios/PublicAPI.Shipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/net8.0-ios/PublicAPI.Shipped.txt @@ -1054,7 +1054,6 @@ const Microsoft.Identity.Client.MsalError.MtlsNotSupportedForManagedIdentity = " const Microsoft.Identity.Client.MsalError.MtlsPopTokenNotSupportedinImdsV1 = "mtls_pop_token_not_supported_in_imds_v1" -> string Microsoft.Identity.Client.IMsalMtlsHttpClientFactory Microsoft.Identity.Client.IMsalMtlsHttpClientFactory.GetHttpClient(System.Security.Cryptography.X509Certificates.X509Certificate2 x509Certificate2) -> System.Net.Http.HttpClient -Microsoft.Identity.Client.ManagedIdentityApplication.GetManagedIdentitySourceAsync() -> System.Threading.Tasks.Task Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource.ImdsV2 = 8 -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource Microsoft.Identity.Client.ManagedIdentityApplicationBuilder.WithExtraQueryParameters(System.Collections.Generic.IDictionary extraQueryParameters) -> Microsoft.Identity.Client.ManagedIdentityApplicationBuilder static Microsoft.Identity.Client.ApplicationBase.ResetStateForTest() -> void diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net8.0-ios/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net8.0-ios/PublicAPI.Unshipped.txt index 5f282702bb..3241ccd9cd 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/net8.0-ios/PublicAPI.Unshipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/net8.0-ios/PublicAPI.Unshipped.txt @@ -1 +1,2 @@ - \ No newline at end of file +const Microsoft.Identity.Client.MsalError.ManagedIdentityAllSourcesUnavailable = "managed_identity_all_sources_unavailable" -> string +Microsoft.Identity.Client.ManagedIdentityApplication.GetManagedIdentitySourceAsync(System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net8.0/PublicAPI.Shipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net8.0/PublicAPI.Shipped.txt index c88bee10b4..7bfb7d24c6 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/net8.0/PublicAPI.Shipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/net8.0/PublicAPI.Shipped.txt @@ -1048,7 +1048,6 @@ const Microsoft.Identity.Client.MsalError.MtlsNotSupportedForManagedIdentity = " const Microsoft.Identity.Client.MsalError.MtlsPopTokenNotSupportedinImdsV1 = "mtls_pop_token_not_supported_in_imds_v1" -> string Microsoft.Identity.Client.IMsalMtlsHttpClientFactory Microsoft.Identity.Client.IMsalMtlsHttpClientFactory.GetHttpClient(System.Security.Cryptography.X509Certificates.X509Certificate2 x509Certificate2) -> System.Net.Http.HttpClient -Microsoft.Identity.Client.ManagedIdentityApplication.GetManagedIdentitySourceAsync() -> System.Threading.Tasks.Task Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource.ImdsV2 = 8 -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource Microsoft.Identity.Client.ManagedIdentityApplicationBuilder.WithExtraQueryParameters(System.Collections.Generic.IDictionary extraQueryParameters) -> Microsoft.Identity.Client.ManagedIdentityApplicationBuilder static Microsoft.Identity.Client.ApplicationBase.ResetStateForTest() -> void diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net8.0/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net8.0/PublicAPI.Unshipped.txt index 5f282702bb..3241ccd9cd 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/net8.0/PublicAPI.Unshipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/net8.0/PublicAPI.Unshipped.txt @@ -1 +1,2 @@ - \ No newline at end of file +const Microsoft.Identity.Client.MsalError.ManagedIdentityAllSourcesUnavailable = "managed_identity_all_sources_unavailable" -> string +Microsoft.Identity.Client.ManagedIdentityApplication.GetManagedIdentitySourceAsync(System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task diff --git a/src/client/Microsoft.Identity.Client/PublicApi/netstandard2.0/PublicAPI.Shipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/netstandard2.0/PublicAPI.Shipped.txt index 40d355e131..115ba8489a 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/netstandard2.0/PublicAPI.Shipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/netstandard2.0/PublicAPI.Shipped.txt @@ -1048,7 +1048,6 @@ const Microsoft.Identity.Client.MsalError.MtlsNotSupportedForManagedIdentity = " const Microsoft.Identity.Client.MsalError.MtlsPopTokenNotSupportedinImdsV1 = "mtls_pop_token_not_supported_in_imds_v1" -> string Microsoft.Identity.Client.IMsalMtlsHttpClientFactory Microsoft.Identity.Client.IMsalMtlsHttpClientFactory.GetHttpClient(System.Security.Cryptography.X509Certificates.X509Certificate2 x509Certificate2) -> System.Net.Http.HttpClient -Microsoft.Identity.Client.ManagedIdentityApplication.GetManagedIdentitySourceAsync() -> System.Threading.Tasks.Task Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource.ImdsV2 = 8 -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource Microsoft.Identity.Client.ManagedIdentityApplicationBuilder.WithExtraQueryParameters(System.Collections.Generic.IDictionary extraQueryParameters) -> Microsoft.Identity.Client.ManagedIdentityApplicationBuilder static Microsoft.Identity.Client.ApplicationBase.ResetStateForTest() -> void diff --git a/src/client/Microsoft.Identity.Client/PublicApi/netstandard2.0/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/netstandard2.0/PublicAPI.Unshipped.txt index 5f282702bb..3241ccd9cd 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/netstandard2.0/PublicAPI.Unshipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/netstandard2.0/PublicAPI.Unshipped.txt @@ -1 +1,2 @@ - \ No newline at end of file +const Microsoft.Identity.Client.MsalError.ManagedIdentityAllSourcesUnavailable = "managed_identity_all_sources_unavailable" -> string +Microsoft.Identity.Client.ManagedIdentityApplication.GetManagedIdentitySourceAsync(System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task diff --git a/src/client/Microsoft.Identity.Client/RequestType.cs b/src/client/Microsoft.Identity.Client/RequestType.cs index 272bcfa5c9..0339d8145d 100644 --- a/src/client/Microsoft.Identity.Client/RequestType.cs +++ b/src/client/Microsoft.Identity.Client/RequestType.cs @@ -19,18 +19,18 @@ internal enum RequestType ManagedIdentityDefault, /// - /// Instance Metadata Service (IMDS) request, used for obtaining tokens from the Azure VM metadata endpoint. + /// Instance Metadata Service (IMDS) v1 and v2 probe request, used to probe IMDS v1 and v2 managed identities to determine if they are available. /// - Imds, + ImdsProbe, /// - /// Region Discovery request, used for region discovery operations with exponential backoff retry strategy. + /// Instance Metadata Service (IMDS) request, used for obtaining tokens from the Azure VM metadata endpoint. /// - RegionDiscovery, + Imds, /// - /// CSR Metadata Probe request, used to probe an IMDSv2 managed identity for metadata to be used in acquiring a token. + /// Region Discovery request, used for region discovery operations with exponential backoff retry strategy. /// - CsrMetadataProbe + RegionDiscovery } } diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs index 0cd3aab2a6..4bfd94d5a7 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs @@ -614,6 +614,93 @@ public static MsalTokenResponse CreateMsalRunTimeBrokerTokenResponse(string acce }; } + public static MockHttpMessageHandler MockImdsProbe( + ImdsVersion imdsVersion, + UserAssignedIdentityId userAssignedIdentityId = UserAssignedIdentityId.None, + string userAssignedId = null, + bool success = true, + bool retry = false) + { + string apiVersionQueryParam; + string imdsApiVersion; + string imdsEndpoint; + + switch (imdsVersion) + { + case ImdsVersion.V2: + apiVersionQueryParam = ImdsV2ManagedIdentitySource.ApiVersionQueryParam; + imdsApiVersion = ImdsV2ManagedIdentitySource.ImdsV2ApiVersion; + imdsEndpoint = ImdsV2ManagedIdentitySource.CsrMetadataPath; + break; + + case ImdsVersion.V1: + apiVersionQueryParam = ImdsManagedIdentitySource.ApiVersionQueryParam; + imdsApiVersion = ImdsManagedIdentitySource.ImdsApiVersion; + imdsEndpoint = ImdsManagedIdentitySource.ImdsTokenPath; + break; + + default: + throw new ArgumentOutOfRangeException(nameof(imdsVersion), imdsVersion, null); + } + + HttpStatusCode statusCode; + + if (success) + { + statusCode = HttpStatusCode.BadRequest; // IMDS probe success returns 400 Bad Request + } + else + { + if (retry) + { + statusCode = HttpStatusCode.InternalServerError; + } + else + { + statusCode = HttpStatusCode.NotFound; + } + } + + IDictionary expectedQueryParams = new Dictionary(); + IDictionary expectedRequestHeaders = new Dictionary(); + IList presentRequestHeaders = new List + { + OAuth2Header.XMsCorrelationId + }; + + if (userAssignedIdentityId != UserAssignedIdentityId.None && userAssignedId != null) + { + var userAssignedIdQueryParam = ImdsManagedIdentitySource.GetUserAssignedIdQueryParam( + (ManagedIdentityIdType)userAssignedIdentityId, userAssignedId, null); + expectedQueryParams.Add(userAssignedIdQueryParam.Value.Key, userAssignedIdQueryParam.Value.Value); + } + expectedQueryParams.Add(apiVersionQueryParam, imdsApiVersion); + + var handler = new MockHttpMessageHandler() + { + ExpectedUrl = $"{ImdsManagedIdentitySource.DefaultImdsBaseEndpoint}{imdsEndpoint}", + ExpectedMethod = HttpMethod.Get, + ExpectedQueryParams = expectedQueryParams, + ExpectedRequestHeaders = expectedRequestHeaders, + PresentRequestHeaders = presentRequestHeaders, + ResponseMessage = new HttpResponseMessage(statusCode) + { + Content = new StringContent(""), + } + }; + + return handler; + } + + public static MockHttpMessageHandler MockImdsProbeFailure( + ImdsVersion imdsVersion, + UserAssignedIdentityId userAssignedIdentityId = UserAssignedIdentityId.None, + string userAssignedId = null, + bool retry = false) + { + return MockImdsProbe(imdsVersion, userAssignedIdentityId, userAssignedId, success: false, retry: retry); + } + public static MockHttpMessageHandler MockCsrResponse( HttpStatusCode statusCode = HttpStatusCode.OK, string responseServerHeader = "IMDS/150.870.65.1854", @@ -626,9 +713,9 @@ public static MockHttpMessageHandler MockCsrResponse( IDictionary expectedQueryParams = new Dictionary(); IDictionary expectedRequestHeaders = new Dictionary(); IList presentRequestHeaders = new List - { - OAuth2Header.XMsCorrelationId - }; + { + OAuth2Header.XMsCorrelationId + }; if (userAssignedIdentityId != UserAssignedIdentityId.None && userAssignedId != null) { @@ -666,7 +753,6 @@ public static MockHttpMessageHandler MockCsrResponse( return handler; } - // used for unit tests in ManagedIdentityTests.cs public static MockHttpMessageHandler MockCsrResponseFailure() { // 400 doesn't trigger the retry policy @@ -674,19 +760,19 @@ public static MockHttpMessageHandler MockCsrResponseFailure() } public static MockHttpMessageHandler MockCertificateRequestResponse( - UserAssignedIdentityId userAssignedIdentityId = UserAssignedIdentityId.None, - string userAssignedId = null, - string certificate = TestConstants.ValidRawCertificate, - string clientIdOverride = null, - string tenantIdOverride = null, - string mtlsEndpointOverride = null) + UserAssignedIdentityId userAssignedIdentityId = UserAssignedIdentityId.None, + string userAssignedId = null, + string certificate = TestConstants.ValidRawCertificate, + string clientIdOverride = null, + string tenantIdOverride = null, + string mtlsEndpointOverride = null) { IDictionary expectedQueryParams = new Dictionary(); IDictionary expectedRequestHeaders = new Dictionary(); IList presentRequestHeaders = new List - { - OAuth2Header.XMsCorrelationId - }; + { + OAuth2Header.XMsCorrelationId + }; if (userAssignedIdentityId != UserAssignedIdentityId.None && userAssignedId != null) { diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpMessageHandler.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpMessageHandler.cs index cdfbb5432b..061d5b15ce 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpMessageHandler.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpMessageHandler.cs @@ -43,7 +43,6 @@ internal class MockHttpMessageHandler : HttpClientHandler protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) { - ActualRequestMessage = request; if (ExceptionToThrow != null) diff --git a/tests/Microsoft.Identity.Test.Unit/Helpers/TestRetryPolicies.cs b/tests/Microsoft.Identity.Test.Unit/Helpers/TestRetryPolicies.cs index 07a2ba9a11..79a68473ba 100644 --- a/tests/Microsoft.Identity.Test.Unit/Helpers/TestRetryPolicies.cs +++ b/tests/Microsoft.Identity.Test.Unit/Helpers/TestRetryPolicies.cs @@ -40,9 +40,9 @@ internal override Task DelayAsync(int milliseconds) } } - internal class TestCsrMetadataProbeRetryPolicy : CsrMetadataProbeRetryPolicy + internal class TestImdsProbeRetryPolicy : ImdsProbeRetryPolicy { - public TestCsrMetadataProbeRetryPolicy() : base() { } + public TestImdsProbeRetryPolicy() : base() { } internal override Task DelayAsync(int milliseconds) { diff --git a/tests/Microsoft.Identity.Test.Unit/Helpers/TestRetryPolicyFactory.cs b/tests/Microsoft.Identity.Test.Unit/Helpers/TestRetryPolicyFactory.cs index 2ed0c98f0d..c501026cd7 100644 --- a/tests/Microsoft.Identity.Test.Unit/Helpers/TestRetryPolicyFactory.cs +++ b/tests/Microsoft.Identity.Test.Unit/Helpers/TestRetryPolicyFactory.cs @@ -16,12 +16,12 @@ public virtual IRetryPolicy GetRetryPolicy(RequestType requestType) case RequestType.STS: case RequestType.ManagedIdentityDefault: return new TestDefaultRetryPolicy(requestType); + case RequestType.ImdsProbe: + return new TestImdsProbeRetryPolicy(); case RequestType.Imds: return new TestImdsRetryPolicy(); case RequestType.RegionDiscovery: return new TestRegionDiscoveryRetryPolicy(); - case RequestType.CsrMetadataProbe: - return new TestCsrMetadataProbeRetryPolicy(); default: throw new ArgumentOutOfRangeException(nameof(requestType), requestType, "Unknown request type."); } diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/AppServiceTests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/AppServiceTests.cs index c78924c248..508ace58cc 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/AppServiceTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/AppServiceTests.cs @@ -70,7 +70,7 @@ public async Task TestAppServiceUpgradeScenario( ManagedIdentityApplication mi = miBuilder.Build() as ManagedIdentityApplication; - Assert.AreEqual(expectedManagedIdentitySource, await mi.GetManagedIdentitySourceAsync().ConfigureAwait(false)); + Assert.AreEqual(expectedManagedIdentitySource, await mi.GetManagedIdentitySourceAsync(ManagedIdentityTests.ImdsProbesCancellationToken).ConfigureAwait(false)); } } } diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsTests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsTests.cs index 9853d54283..204074482d 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsTests.cs @@ -3,10 +3,13 @@ using System; using System.Net; +using System.Net.Http; +using System.Threading; using System.Threading.Tasks; using Microsoft.Identity.Client; using Microsoft.Identity.Client.AppConfig; using Microsoft.Identity.Client.ManagedIdentity; +using Microsoft.Identity.Client.ManagedIdentity.V2; using Microsoft.Identity.Test.Common.Core.Helpers; using Microsoft.Identity.Test.Common.Core.Mocks; using Microsoft.Identity.Test.Unit.Helpers; @@ -39,11 +42,10 @@ public async Task ImdsFails404TwiceThenSucceeds200Async( .WithHttpManager(httpManager) .WithRetryPolicyFactory(_testRetryPolicyFactory); - // Disable cache to avoid pollution - - IManagedIdentityApplication mi = miBuilder.Build(); + ManagedIdentityTests.MockImdsV1Probe(httpManager, ManagedIdentitySource.Imds, userAssignedIdentityId, userAssignedId); + // Simulate two 404s (to trigger retries), then a successful response const int Num404Errors = 2; for (int i = 0; i < Num404Errors; i++) @@ -98,11 +100,10 @@ public async Task ImdsFails410FourTimesThenSucceeds200Async( .WithHttpManager(httpManager) .WithRetryPolicyFactory(_testRetryPolicyFactory); - // Disable cache to avoid pollution - - IManagedIdentityApplication mi = miBuilder.Build(); + ManagedIdentityTests.MockImdsV1Probe(httpManager, ManagedIdentitySource.Imds, userAssignedIdentityId, userAssignedId); + // Simulate four 410s (to trigger retries), then a successful response const int Num410Errors = 4; for (int i = 0; i < Num410Errors; i++) @@ -157,11 +158,10 @@ public async Task ImdsFails410PermanentlyAsync( .WithHttpManager(httpManager) .WithRetryPolicyFactory(_testRetryPolicyFactory); - // Disable cache to avoid pollution - - IManagedIdentityApplication mi = miBuilder.Build(); + ManagedIdentityTests.MockImdsV1Probe(httpManager, ManagedIdentitySource.Imds, userAssignedIdentityId, userAssignedId); + // Simulate permanent 410s (to trigger the maximum number of retries) const int Num410Errors = 1 + TestImdsRetryPolicy.LinearStrategyNumRetries; // initial request + maximum number of retries for (int i = 0; i < Num410Errors; i++) @@ -213,11 +213,10 @@ public async Task ImdsFails504PermanentlyAsync( .WithHttpManager(httpManager) .WithRetryPolicyFactory(_testRetryPolicyFactory); - // Disable cache to avoid pollution - - IManagedIdentityApplication mi = miBuilder.Build(); + ManagedIdentityTests.MockImdsV1Probe(httpManager, ManagedIdentitySource.Imds, userAssignedIdentityId, userAssignedId); + // Simulate permanent 504s (to trigger the maximum number of retries) const int Num504Errors = 1 + TestImdsRetryPolicy.ExponentialStrategyNumRetries; // initial request + maximum number of retries for (int i = 0; i < Num504Errors; i++) @@ -269,11 +268,10 @@ public async Task ImdsFails400WhichIsNonRetriableAndRetryPolicyIsNotTriggeredAsy .WithHttpManager(httpManager) .WithRetryPolicyFactory(_testRetryPolicyFactory); - // Disable cache to avoid pollution - - IManagedIdentityApplication mi = miBuilder.Build(); + ManagedIdentityTests.MockImdsV1Probe(httpManager, ManagedIdentitySource.Imds, userAssignedIdentityId, userAssignedId); + httpManager.AddManagedIdentityMockHandler( ManagedIdentityTests.ImdsEndpoint, ManagedIdentityTests.Resource, @@ -321,11 +319,10 @@ public async Task ImdsFails500AndRetryPolicyIsDisabledAndNotTriggeredAsync( .WithHttpManager(httpManager) .WithRetryPolicyFactory(_testRetryPolicyFactory); - // Disable cache to avoid pollution - - IManagedIdentityApplication mi = miBuilder.Build(); + ManagedIdentityTests.MockImdsV1Probe(httpManager, ManagedIdentitySource.Imds, userAssignedIdentityId, userAssignedId); + httpManager.AddManagedIdentityMockHandler( ManagedIdentityTests.ImdsEndpoint, ManagedIdentityTests.Resource, @@ -367,11 +364,10 @@ public async Task ImdsRetryPolicyLifeTimeIsPerRequestAsync() .WithHttpManager(httpManager) .WithRetryPolicyFactory(_testRetryPolicyFactory); - // Disable cache to avoid pollution - - IManagedIdentityApplication mi = miBuilder.Build(); + ManagedIdentityTests.MockImdsV1Probe(httpManager, ManagedIdentitySource.Imds); + // Simulate permanent errors (to trigger the maximum number of retries) const int Num504Errors = 1 + TestImdsRetryPolicy.ExponentialStrategyNumRetries; // initial request + maximum number of retries for (int i = 0; i < Num504Errors; i++) @@ -438,5 +434,36 @@ await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) Assert.AreEqual(Num504Errors, requestsMade); } } + + [TestMethod] + public async Task ProbeImdsEndpointAsync_TimesOutAfterOneSecond() + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned); + + miBuilder + .WithHttpManager(httpManager) + .WithRetryPolicyFactory(_testRetryPolicyFactory); + + var managedIdentityApp = miBuilder.Build(); + + httpManager.AddMockHandler(MockHelpers.MockImdsProbeFailure(ImdsVersion.V2)); + httpManager.AddMockHandler(MockHelpers.MockImdsProbe(ImdsVersion.V1)); + + var imdsProbesCancellationToken = new CancellationTokenSource(TimeSpan.FromSeconds(0)).Token; // timeout immediately + + var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync(imdsProbesCancellationToken).ConfigureAwait(false); + Assert.AreEqual(ManagedIdentitySource.None, miSource); // Probe timed out, no source available + + var ex = await Assert.ThrowsExceptionAsync(async () => + await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync().ConfigureAwait(false) + ).ConfigureAwait(false); + + Assert.AreEqual(MsalError.ManagedIdentityAllSourcesUnavailable, ex.ErrorCode); + } + } } } diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index da89c2c3fa..7e61d7d908 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -2,9 +2,7 @@ // Licensed under the MIT License. using System; -using System.Collections.Generic; using System.IO; -using System.Net; using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; using System.Threading; @@ -92,7 +90,7 @@ private async Task CreateManagedIdentityAsync( bool addProbeMock = true, bool addSourceCheck = true, ManagedIdentityKeyType managedIdentityKeyType = ManagedIdentityKeyType.InMemory, - bool imdsV2 = true) // false indicates imdsV1 + ImdsVersion imdsVersion = ImdsVersion.V2) { ManagedIdentityApplicationBuilder miBuilder = null; @@ -110,33 +108,28 @@ private async Task CreateManagedIdentityAsync( .WithHttpManager(httpManager) .WithRetryPolicyFactory(_testRetryPolicyFactory); - if (imdsV2) + if (imdsVersion == ImdsVersion.V2) { miBuilder.WithCsrFactory(_testCsrFactory); } var managedIdentityApp = miBuilder.Build(); - if (!imdsV2) + if (imdsVersion == ImdsVersion.V1) { + httpManager.AddMockHandler(MockHelpers.MockImdsProbeFailure(ImdsVersion.V2, userAssignedIdentityId, userAssignedId)); + httpManager.AddMockHandler(MockHelpers.MockImdsProbe(ImdsVersion.V1, userAssignedIdentityId, userAssignedId)); return managedIdentityApp; } if (addProbeMock) { - if (uami) - { - httpManager.AddMockHandler(MockHelpers.MockCsrResponse(userAssignedIdentityId: userAssignedIdentityId, userAssignedId: userAssignedId)); - } - else - { - httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); - } + httpManager.AddMockHandler(MockHelpers.MockImdsProbe(ImdsVersion.V2, userAssignedIdentityId, userAssignedId)); } if (addSourceCheck) { - var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); + var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync(ManagedIdentityTests.ImdsProbesCancellationToken).ConfigureAwait(false); Assert.AreEqual(ManagedIdentitySource.ImdsV2, miSource); } @@ -352,9 +345,7 @@ public async Task ImdsV2EndpointsAreNotAvailableButMtlsPopTokenWasRequested( { SetEnvironmentVariables(ManagedIdentitySource.Imds, TestConstants.ImdsEndpoint); - var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId, imdsV2: false).ConfigureAwait(false); - - httpManager.AddMockHandler(MockHelpers.MockCsrResponseFailure()); + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId, imdsVersion: ImdsVersion.V1).ConfigureAwait(false); var ex = await Assert.ThrowsExceptionAsync(async () => await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) @@ -382,6 +373,7 @@ public async Task ApplicationsCannotSwitchBetweenImdsVersionsForPreview( var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId, managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard).ConfigureAwait(false); + // IMDSv1 request mock httpManager.AddManagedIdentityMockHandler( ManagedIdentityTests.ImdsEndpoint, ManagedIdentityTests.Resource, @@ -400,7 +392,7 @@ public async Task ApplicationsCannotSwitchBetweenImdsVersionsForPreview( Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); // even though the app fell back to ImdsV1, the source should still be ImdsV2 - var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); + var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync(ManagedIdentityTests.ImdsProbesCancellationToken).ConfigureAwait(false); Assert.AreEqual(ManagedIdentitySource.ImdsV2, miSource); // none of the mocks from AddMocksToGetEntraToken are needed since checking the cache occurs before the network requests @@ -415,165 +407,169 @@ await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Res } #endregion Failure Tests + #region Probe Tests [TestMethod] - public async Task GetCsrMetadataAsyncSucceeds() + public async Task ProbeImdsEndpointAsyncSucceeds() { using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); - var handler = httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); + httpManager.AddMockHandler(MockHelpers.MockImdsProbe(ImdsVersion.V2)); await CreateManagedIdentityAsync(httpManager, addProbeMock: false).ConfigureAwait(false); } } [TestMethod] - public async Task GetCsrMetadataAsyncSucceedsAfterRetry() + public async Task ProbeImdsEndpointAsyncSucceedsAfterRetry() { using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); - // First attempt fails with INTERNAL_SERVER_ERROR (500) - httpManager.AddMockHandler(MockHelpers.MockCsrResponse(HttpStatusCode.InternalServerError)); + // `retry: true` indicates a retriable status code will be returned + httpManager.AddMockHandler(MockHelpers.MockImdsProbeFailure(ImdsVersion.V2, retry: true)); - // Second attempt succeeds (defined inside of CreateSAMIAsync) + // Second attempt succeeds (defined inside of CreateManagedIdentityAsync) await CreateManagedIdentityAsync(httpManager).ConfigureAwait(false); } } [TestMethod] - public async Task GetCsrMetadataAsyncFailsWithMissingServerHeader() + public async Task ProbeImdsEndpointAsyncFails404WhichIsNonRetriableAndRetryPolicyIsNotTriggeredAsync() { using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); - httpManager.AddMockHandler(MockHelpers.MockCsrResponse(responseServerHeader: null)); + // `retry: false` indicates a retriable status code will be returned + httpManager.AddMockHandler(MockHelpers.MockImdsProbeFailure(ImdsVersion.V2, retry: false)); + httpManager.AddMockHandler(MockHelpers.MockImdsProbe(ImdsVersion.V1)); var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, addProbeMock: false, addSourceCheck: false).ConfigureAwait(false); - var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); - Assert.AreEqual(ManagedIdentitySource.DefaultToImds, miSource); + var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync(ManagedIdentityTests.ImdsProbesCancellationToken).ConfigureAwait(false); + Assert.AreEqual(ManagedIdentitySource.Imds, miSource); } } + #endregion Probe Tests - [TestMethod] - public async Task GetCsrMetadataAsyncFailsWithInvalidFormat() + #region Fallback Behavior Tests + // Verifies non-mTLS request after IMDSv2 detection falls back per-request to IMDSv1 (Bearer), + [DataTestMethod] + [DataRow(UserAssignedIdentityId.None, null)] + [DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId)] + public async Task NonMtlsRequest_FallsBackToImdsV1( + UserAssignedIdentityId userAssignedIdentityId, + string userAssignedId) { using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { + ManagedIdentityClient.ResetSourceForTest(); SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); - httpManager.AddMockHandler(MockHelpers.MockCsrResponse(responseServerHeader: "I_MDS/150.870.65.1854")); + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId, managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard).ConfigureAwait(false); - var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, addProbeMock: false, addSourceCheck: false).ConfigureAwait(false); + // IMDSv1 request mock + httpManager.AddManagedIdentityMockHandler( + ManagedIdentityTests.ImdsEndpoint, + ManagedIdentityTests.Resource, + MockHelpers.GetMsiSuccessfulResponse(), + ManagedIdentitySource.Imds, + userAssignedIdentityId: userAssignedIdentityId, + userAssignedId: userAssignedId); - var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); - Assert.AreEqual(ManagedIdentitySource.DefaultToImds, miSource); + var result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + //.WithMtlsProofOfPossession() - excluding this will cause fallback to ImdsV1 + .ExecuteAsync().ConfigureAwait(false); + + Assert.AreEqual(Bearer, result.TokenType); + Assert.IsNull(result.BindingCertificate, "Bearer token should not have binding certificate."); + Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); + + // indicates ImdsV2 is still available + var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync(ManagedIdentityTests.ImdsProbesCancellationToken).ConfigureAwait(false); + Assert.AreEqual(ManagedIdentitySource.ImdsV2, miSource); } } [TestMethod] - public async Task GetCsrMetadataAsyncFailsAfterMaxRetries() + public async Task ImdsV2ProbeFailsMaxRetries_FallsBackToImdsV1() { using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); - const int Num500Errors = 1 + TestCsrMetadataProbeRetryPolicy.ExponentialStrategyNumRetries; + const int Num500Errors = 1 + TestImdsProbeRetryPolicy.ExponentialStrategyNumRetries; for (int i = 0; i < Num500Errors; i++) { - httpManager.AddMockHandler(MockHelpers.MockCsrResponse(HttpStatusCode.InternalServerError)); + // `retry: true` indicates a retriable status code will be returned + httpManager.AddMockHandler(MockHelpers.MockImdsProbeFailure(ImdsVersion.V2, retry: true)); } + httpManager.AddMockHandler(MockHelpers.MockImdsProbe(ImdsVersion.V1)); + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, addProbeMock: false, addSourceCheck: false).ConfigureAwait(false); - var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); - Assert.AreEqual(ManagedIdentitySource.DefaultToImds, miSource); + var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync(ManagedIdentityTests.ImdsProbesCancellationToken).ConfigureAwait(false); + Assert.AreEqual(ManagedIdentitySource.Imds, miSource); } } + #endregion + #region CSR Metadata Tests [TestMethod] - public async Task GetCsrMetadataAsyncFails404WhichIsNonRetriableAndRetryPolicyIsNotTriggeredAsync() + public async Task GetCsrMetadataAsyncFailsWithMissingServerHeader() { using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); - httpManager.AddMockHandler(MockHelpers.MockCsrResponse(HttpStatusCode.NotFound)); + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager).ConfigureAwait(false); - var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, addProbeMock: false, addSourceCheck: false).ConfigureAwait(false); + httpManager.AddMockHandler(MockHelpers.MockCsrResponse(responseServerHeader: null)); + + var ex = await Assert.ThrowsExceptionAsync(async () => + await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .ExecuteAsync().ConfigureAwait(false) + ).ConfigureAwait(false); - var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); - Assert.AreEqual(ManagedIdentitySource.DefaultToImds, miSource); + Assert.AreEqual(MsalError.ManagedIdentityRequestFailed, ex.ErrorCode); } } - [DataTestMethod] - [DataRow(UserAssignedIdentityId.None, null)] // SAMI - [DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId)] // UAMI - [DataRow(UserAssignedIdentityId.ResourceId, TestConstants.MiResourceId)] // UAMI - [DataRow(UserAssignedIdentityId.ObjectId, TestConstants.ObjectId)] // UAMI - public async Task ProbeDoesNotFireWhenMtlsPopNotRequested( - UserAssignedIdentityId userAssignedIdentityId, - string userAssignedId) + [TestMethod] + public async Task GetCsrMetadataAsyncFailsWithInvalidFormat() { using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { - SetEnvironmentVariables(ManagedIdentitySource.Imds, TestConstants.ImdsEndpoint); - - ManagedIdentityApplicationBuilder miBuilder = null; - - var uami = userAssignedIdentityId != UserAssignedIdentityId.None && userAssignedId != null; - if (uami) - { - miBuilder = CreateMIABuilder(userAssignedId, userAssignedIdentityId); - } - else - { - miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned); - } - - miBuilder - .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory); + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); - var managedIdentityApp = miBuilder.Build(); + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager).ConfigureAwait(false); - // mock probe to show ImdsV2 is available - httpManager.AddMockHandler(MockHelpers.MockCsrResponse(userAssignedIdentityId: userAssignedIdentityId, userAssignedId: userAssignedId)); - - var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); - // this indicates ImdsV2 is available - Assert.AreEqual(ManagedIdentitySource.ImdsV2, miSource); - - httpManager.AddManagedIdentityMockHandler( - ManagedIdentityTests.ImdsEndpoint, - ManagedIdentityTests.Resource, - MockHelpers.GetMsiSuccessfulResponse(), - ManagedIdentitySource.Imds, - userAssignedId: userAssignedId, - userAssignedIdentityId: userAssignedIdentityId); + httpManager.AddMockHandler(MockHelpers.MockCsrResponse(responseServerHeader: "I_MDS/150.870.65.1854")); - // ImdsV1 flow will be used since .WithMtlsProofOfPossession() is not used here - var result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource).ExecuteAsync().ConfigureAwait(false); + var ex = await Assert.ThrowsExceptionAsync(async () => + await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .ExecuteAsync().ConfigureAwait(false) + ).ConfigureAwait(false); - Assert.IsNotNull(result); - Assert.IsNotNull(result.AccessToken); - Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); + Assert.AreEqual(MsalError.ManagedIdentityRequestFailed, ex.ErrorCode); } } + #endregion CSR Metadata Tests - #region Cuid Tests + #region CSR Generation Tests [TestMethod] public void TestCsrGeneration_OnlyVmId() { @@ -600,7 +596,6 @@ public void TestCsrGeneration_VmIdAndVmssId() var (csr, _) = Csr.Generate(rsa, TestConstants.ClientId, TestConstants.TenantId, cuid); CsrValidator.ValidateCsrContent(csr, TestConstants.ClientId, TestConstants.TenantId, cuid); } - #endregion [DataTestMethod] [DataRow("Invalid@#$%Certificate!")] @@ -611,6 +606,7 @@ public void TestCsrGeneration_BadCert_ThrowsMsalServiceException(string badCert) Assert.ThrowsException(() => CsrValidator.ParseRawCsr(badCert)); } + #endregion CSR Generation Tests #region AttachPrivateKeyToCert Tests [TestMethod] @@ -1529,42 +1525,5 @@ private static void AssertCertSubjectCnDc(X509Certificate2 cert, string expected } #endregion - - #region Fallback Behavior Tests - // Verifies non-mTLS request after IMDSv2 detection falls back per-request to IMDSv1 (Bearer), - [DataTestMethod] - [DataRow(UserAssignedIdentityId.None, null)] - [DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId)] - public async Task NonMtlsRequest_FallbackToImdsV1( - UserAssignedIdentityId idKind, - string idValue) - { - using (new EnvVariableContext()) - using (var httpManager = new MockHttpManager()) - { - ManagedIdentityClient.ResetSourceForTest(); - SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); - - var mi = await CreateManagedIdentityAsync(httpManager, idKind, idValue, managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard).ConfigureAwait(false); - - // Fallback token (Bearer) mock - httpManager.AddManagedIdentityMockHandler( - ManagedIdentityTests.ImdsEndpoint, - ManagedIdentityTests.Resource, - MockHelpers.GetMsiSuccessfulResponse(), - ManagedIdentitySource.Imds, - userAssignedIdentityId: idKind, - userAssignedId: idValue); - - var token = await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) - // No .WithMtlsProofOfPossession() => triggers fallback - .ExecuteAsync().ConfigureAwait(false); - - Assert.AreEqual(Bearer, token.TokenType); - Assert.IsNull(token.BindingCertificate, "Bearer token should not have binding certificate."); - Assert.AreEqual(TokenSource.IdentityProvider, token.AuthenticationResultMetadata.TokenSource); - } - } - #endregion } } diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs index 7ae48f4a54..a4039d0791 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs @@ -4,7 +4,6 @@ using System; using System.Collections.Generic; using System.Diagnostics; -using System.Linq; using System.Net; using System.Net.Http; using System.Net.Sockets; @@ -14,6 +13,7 @@ using Microsoft.Identity.Client.AppConfig; using Microsoft.Identity.Client.Internal; using Microsoft.Identity.Client.ManagedIdentity; +using Microsoft.Identity.Client.ManagedIdentity.V2; using Microsoft.Identity.Client.TelemetryCore.Internal.Events; using Microsoft.Identity.Test.Common; using Microsoft.Identity.Test.Common.Core.Helpers; @@ -40,21 +40,37 @@ public class ManagedIdentityTests : TestBase internal const string ExpectedErrorCode = "ErrorCode"; internal const string ExpectedCorrelationId = "Some GUID"; + internal static CancellationToken ImdsProbesCancellationToken = new CancellationTokenSource(TimeSpan.FromMinutes(5)).Token; // never timeout for the unit tests + private readonly TestRetryPolicyFactory _testRetryPolicyFactory = new TestRetryPolicyFactory(); + // MtlsPop is disabled for all these tests, so no need to mock IMDSv2 probe here + internal static void MockImdsV1Probe( + MockHttpManager httpManager, + ManagedIdentitySource managedIdentitySource, + UserAssignedIdentityId userAssignedIdentityId = UserAssignedIdentityId.None, + string userAssignedId = null) + { + if (managedIdentitySource == ManagedIdentitySource.Imds) + { + httpManager.AddMockHandler(MockHelpers.MockImdsProbe(ImdsVersion.V1, userAssignedIdentityId, userAssignedId)); + } + } + [DataTestMethod] - [DataRow("http://127.0.0.1:41564/msi/token/", ManagedIdentitySource.AppService, ManagedIdentitySource.AppService)] - [DataRow(AppServiceEndpoint, ManagedIdentitySource.AppService, ManagedIdentitySource.AppService)] - [DataRow(ImdsEndpoint, ManagedIdentitySource.Imds, ManagedIdentitySource.DefaultToImds)] - [DataRow(null, ManagedIdentitySource.Imds, ManagedIdentitySource.DefaultToImds)] - [DataRow(AzureArcEndpoint, ManagedIdentitySource.AzureArc, ManagedIdentitySource.AzureArc)] - [DataRow(CloudShellEndpoint, ManagedIdentitySource.CloudShell, ManagedIdentitySource.CloudShell)] - [DataRow(ServiceFabricEndpoint, ManagedIdentitySource.ServiceFabric, ManagedIdentitySource.ServiceFabric)] - [DataRow(MachineLearningEndpoint, ManagedIdentitySource.MachineLearning, ManagedIdentitySource.MachineLearning)] + [DataRow("http://127.0.0.1:41564/msi/token/", ManagedIdentitySource.AppService)] + [DataRow(AppServiceEndpoint, ManagedIdentitySource.AppService)] + [DataRow(ImdsEndpoint, ManagedIdentitySource.Imds)] + [DataRow(null, ManagedIdentitySource.Imds)] + [DataRow(ImdsEndpoint, ManagedIdentitySource.ImdsV2)] + [DataRow(null, ManagedIdentitySource.ImdsV2)] + [DataRow(AzureArcEndpoint, ManagedIdentitySource.AzureArc)] + [DataRow(CloudShellEndpoint, ManagedIdentitySource.CloudShell)] + [DataRow(ServiceFabricEndpoint, ManagedIdentitySource.ServiceFabric)] + [DataRow(MachineLearningEndpoint, ManagedIdentitySource.MachineLearning)] public async Task GetManagedIdentityTests( string endpoint, - ManagedIdentitySource managedIdentitySource, - ManagedIdentitySource expectedManagedIdentitySource) + ManagedIdentitySource managedIdentitySource) { using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) @@ -66,7 +82,17 @@ public async Task GetManagedIdentityTests( ManagedIdentityApplication mi = miBuilder.Build() as ManagedIdentityApplication; - Assert.AreEqual(expectedManagedIdentitySource, await mi.GetManagedIdentitySourceAsync().ConfigureAwait(false)); + if (managedIdentitySource == ManagedIdentitySource.ImdsV2) + { + httpManager.AddMockHandler(MockHelpers.MockImdsProbe(ImdsVersion.V2)); + } + else if (managedIdentitySource == ManagedIdentitySource.Imds) + { + httpManager.AddMockHandler(MockHelpers.MockImdsProbeFailure(ImdsVersion.V2)); + httpManager.AddMockHandler(MockHelpers.MockImdsProbe(ImdsVersion.V1)); + } + + Assert.AreEqual(managedIdentitySource, await mi.GetManagedIdentitySourceAsync(ImdsProbesCancellationToken).ConfigureAwait(false)); } } @@ -99,6 +125,8 @@ public async Task SAMIHappyPathAsync( var mi = miBuilder.Build(); + MockImdsV1Probe(httpManager, managedIdentitySource); + httpManager.AddManagedIdentityMockHandler( endpoint, Resource, @@ -148,6 +176,8 @@ public async Task UAMIHappyPathAsync( var mi = miBuilder.Build(); + MockImdsV1Probe(httpManager, managedIdentitySource, userAssignedIdentityId, userAssignedId); + httpManager.AddManagedIdentityMockHandler( endpoint, Resource, @@ -194,6 +224,8 @@ public async Task ManagedIdentityDifferentScopesTestAsync( var mi = miBuilder.Build(); + MockImdsV1Probe(httpManager, managedIdentitySource); + httpManager.AddManagedIdentityMockHandler( endpoint, Resource, @@ -251,6 +283,8 @@ public async Task ManagedIdentityForceRefreshTestAsync( var mi = miBuilder.Build(); + MockImdsV1Probe(httpManager, managedIdentitySource); + httpManager.AddManagedIdentityMockHandler( endpoint, Resource, @@ -310,6 +344,8 @@ public async Task ManagedIdentityWithClaimsAndCapabilitiesTestAsync( var mi = miBuilder.Build(); + MockImdsV1Probe(httpManager, managedIdentitySource); + httpManager.AddManagedIdentityMockHandler( endpoint, Resource, @@ -372,6 +408,8 @@ public async Task ManagedIdentityWithClaimsTestAsync( var mi = miBuilder.Build(); + MockImdsV1Probe(httpManager, managedIdentitySource); + httpManager.AddManagedIdentityMockHandler( endpoint, Resource, @@ -443,6 +481,8 @@ public async Task ManagedIdentityTestWrongScopeAsync(string resource, ManagedIde var mi = miBuilder.Build(); + MockImdsV1Probe(httpManager, managedIdentitySource); + httpManager.AddManagedIdentityMockHandler(endpoint, resource, MockHelpers.GetMsiErrorResponse(managedIdentitySource), managedIdentitySource, statusCode: HttpStatusCode.InternalServerError); httpManager.AddManagedIdentityMockHandler(endpoint, resource, MockHelpers.GetMsiErrorResponse(managedIdentitySource), @@ -546,6 +586,8 @@ public async Task ManagedIdentityErrorResponseNoPayloadTestAsync(ManagedIdentity var mi = miBuilder.Build(); + MockImdsV1Probe(httpManager, managedIdentitySource); + httpManager.AddManagedIdentityMockHandler(endpoint, "scope", "", managedIdentitySource, statusCode: HttpStatusCode.InternalServerError); httpManager.AddManagedIdentityMockHandler(endpoint, "scope", "", @@ -585,6 +627,8 @@ public async Task ManagedIdentityNullResponseAsync(ManagedIdentitySource managed var mi = miBuilder.Build(); + MockImdsV1Probe(httpManager, managedIdentitySource); + httpManager.AddManagedIdentityMockHandler( endpoint, Resource, @@ -622,6 +666,8 @@ public async Task ManagedIdentityUnreachableNetworkAsync(ManagedIdentitySource m var mi = miBuilder.Build(); + MockImdsV1Probe(httpManager, managedIdentitySource); + httpManager.AddFailingRequest(new HttpRequestException("A socket operation was attempted to an unreachable network.", new SocketException(10051))); @@ -709,10 +755,7 @@ public async Task ManagedIdentityCacheTestAsync() var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - - - var mi = miBuilder.BuildConcrete(); CancellationTokenSource cts = new CancellationTokenSource(); @@ -773,7 +816,6 @@ public async Task ManagedIdentityExpiresOnTestAsync(int expiresInHours, bool ref Assert.AreEqual(ApiEvent.ApiIds.AcquireTokenForSystemAssignedManagedIdentity, builder.CommonParameters.ApiId); Assert.AreEqual(refreshOnHasValue, result.AuthenticationResultMetadata.RefreshOn.HasValue); Assert.IsTrue(result.ExpiresOn > DateTimeOffset.UtcNow, "The token's ExpiresOn should be in the future."); - } } @@ -815,10 +857,7 @@ public async Task ManagedIdentityIsProActivelyRefreshedAsync() var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - - - var mi = miBuilder.BuildConcrete(); httpManager.AddManagedIdentityMockHandler( @@ -842,13 +881,11 @@ public async Task ManagedIdentityIsProActivelyRefreshedAsync() MockHelpers.GetMsiSuccessfulResponse(), ManagedIdentitySource.AppService); - // Act Trace.WriteLine("4. ATM - should perform an RT refresh"); result = await mi.AcquireTokenForManagedIdentity(Resource) .ExecuteAsync() .ConfigureAwait(false); - // Assert TestCommon.YieldTillSatisfied(() => httpManager.QueueSize == 0); Assert.IsNotNull(result); @@ -883,10 +920,7 @@ public async Task ProactiveRefresh_CancelsSuccessfully_Async() var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithLogging(LocalLogCallback) .WithHttpManager(httpManager); - - - var mi = miBuilder.BuildConcrete(); httpManager.AddManagedIdentityMockHandler( @@ -906,12 +940,10 @@ public async Task ProactiveRefresh_CancelsSuccessfully_Async() cts.Cancel(); cts.Dispose(); - // Act result = await mi.AcquireTokenForManagedIdentity(Resource) .ExecuteAsync(cancellationToken) .ConfigureAwait(false); - // Assert Assert.IsTrue(TestCommon.YieldTillSatisfied(() => wasErrorLogged)); void LocalLogCallback(LogLevel level, string message, bool containsPii) @@ -941,10 +973,7 @@ public async Task ParallelRequests_CallTokenEndpointOnceAsync() var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - - - var mi = miBuilder.BuildConcrete(); httpManager.AddManagedIdentityMockHandler( @@ -1020,6 +1049,8 @@ public async Task InvalidJsonResponseHandling(ManagedIdentitySource managedIdent var mi = miBuilder.Build(); + MockImdsV1Probe(httpManager, managedIdentitySource); + httpManager.AddManagedIdentityMockHandler( endpoint, "scope", @@ -1060,6 +1091,8 @@ public async Task ManagedIdentityRequestTokensForDifferentScopesTestAsync( var mi = miBuilder.Build(); + MockImdsV1Probe(httpManager, managedIdentitySource); + // Mock handler for the initial resource request httpManager.AddManagedIdentityMockHandler(endpoint, initialResource, MockHelpers.GetMsiSuccessfulResponse(), managedIdentitySource); @@ -1090,20 +1123,18 @@ public async Task ManagedIdentityRequestTokensForDifferentScopesTestAsync( } } - [DataTestMethod] - [DataRow(ManagedIdentitySource.AppService)] - [DataRow(ManagedIdentitySource.Imds)] - public async Task UnsupportedManagedIdentitySource_ThrowsExceptionDuringTokenAcquisitionAsync( - ManagedIdentitySource managedIdentitySource) + // probe will fail for IMDS (due to unsupported endpoint) before Token Acquisition is attempted + [TestMethod] + public async Task UnsupportedNonImdsManagedIdentitySource_ThrowsExceptionDuringTokenAcquisitionAsync() { string UnsupportedEndpoint = "unsupported://endpoint"; using (new EnvVariableContext()) { - SetEnvironmentVariables(managedIdentitySource, UnsupportedEndpoint); + SetEnvironmentVariables(ManagedIdentitySource.AppService, UnsupportedEndpoint); - // Create the Managed Identity Application - var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned); + var miBuilder = ManagedIdentityApplicationBuilder + .Create(ManagedIdentityId.SystemAssigned); var mi = miBuilder.Build(); @@ -1117,6 +1148,32 @@ await mi.AcquireTokenForManagedIdentity("https://management.azure.com") } } + [TestMethod] + public async Task UnavailableManagedIdentitySource_ThrowsExceptionDuringTokenAcquisitionAsync() + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.Imds, ImdsEndpoint); + + var miBuilder = ManagedIdentityApplicationBuilder + .Create(ManagedIdentityId.SystemAssigned) + .WithHttpManager(httpManager); + + var mi = miBuilder.Build(); + + httpManager.AddMockHandler(MockHelpers.MockImdsProbeFailure(ImdsVersion.V1)); + + var ex = await Assert.ThrowsExceptionAsync(async () => + await mi.AcquireTokenForManagedIdentity("https://management.azure.com") + .ExecuteAsync() + .ConfigureAwait(false)).ConfigureAwait(false); + + Assert.IsNotNull(ex); + Assert.AreEqual(MsalError.ManagedIdentityAllSourcesUnavailable, ex.ErrorCode); + } + } + [TestMethod] public async Task MixedUserAndSystemAssignedManagedIdentityTestAsync() { @@ -1130,9 +1187,10 @@ public async Task MixedUserAndSystemAssignedManagedIdentityTestAsync() string SystemAssignedClientId = "system_assigned_managed_identity"; // Create a builder for user-assigned identity - var userAssignedBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.WithUserAssignedClientId(UserAssignedClientId)) + var userAssignedBuilder = ManagedIdentityApplicationBuilder + .Create(ManagedIdentityId + .WithUserAssignedClientId(UserAssignedClientId)) .WithHttpManager(httpManager); - userAssignedBuilder.Config.AccessorOptions = null; @@ -1158,7 +1216,6 @@ public async Task MixedUserAndSystemAssignedManagedIdentityTestAsync() // Verify user-assigned cache entries userAssignedCacheRecorder.AssertAccessCounts(1, 1); - // Create a builder for system-assigned identity var systemAssignedBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); @@ -1166,10 +1223,8 @@ public async Task MixedUserAndSystemAssignedManagedIdentityTestAsync() var systemAssignedMI = systemAssignedBuilder.BuildConcrete(); - // Record token cache access for system-assigned identity var systemAssignedCacheRecorder = systemAssignedMI.AppTokenCacheInternal.RecordAccess(); - // Mock handler for system-assigned token httpManager.AddManagedIdentityMockHandler( AppServiceEndpoint, Resource, @@ -1314,6 +1369,8 @@ public async Task ManagedIdentityWithCapabilitiesTestAsync( var mi = miBuilder.Build(); + MockImdsV1Probe(httpManager, managedIdentitySource); + httpManager.AddManagedIdentityMockHandler( endpoint, Resource, @@ -1338,43 +1395,44 @@ public async Task ManagedIdentityWithCapabilitiesTestAsync( } } - [TestMethod] - public void ValidateServerCertificate_OnlySetForServiceFabric() + [DataTestMethod] + [DataRow(ManagedIdentitySource.AppService)] + [DataRow(ManagedIdentitySource.AzureArc)] + [DataRow(ManagedIdentitySource.CloudShell)] + [DataRow(ManagedIdentitySource.Imds)] + [DataRow(ManagedIdentitySource.ImdsV2)] + [DataRow(ManagedIdentitySource.ServiceFabric)] + [DataRow(ManagedIdentitySource.MachineLearning)] + public void ValidateServerCertificate_OnlySetForServiceFabric(ManagedIdentitySource managedIdentitySource) { using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { - // Test all managed identity sources - foreach (ManagedIdentitySource sourceType in Enum.GetValues(typeof(ManagedIdentitySource)) - .Cast() - .Where(s => s != ManagedIdentitySource.None && s != ManagedIdentitySource.DefaultToImds && s != ManagedIdentitySource.ImdsV2)) - { - // Create a managed identity source for each type - AbstractManagedIdentity managedIdentity = CreateManagedIdentitySource(sourceType, httpManager); + // Create a managed identity source for each type + AbstractManagedIdentity managedIdentity = CreateManagedIdentitySource(managedIdentitySource, httpManager); - // Check if ValidateServerCertificate is set based on the source type - bool shouldHaveCallback = sourceType == ManagedIdentitySource.ServiceFabric; - bool hasCallback = managedIdentity.GetValidationCallback() != null; + // Check if ValidateServerCertificate is set based on the source type + bool shouldHaveCallback = managedIdentitySource == ManagedIdentitySource.ServiceFabric; + bool hasCallback = managedIdentity.GetValidationCallback() != null; - Assert.AreEqual( - shouldHaveCallback, - hasCallback, - $"For source type {sourceType}, ValidateServerCertificate should {(shouldHaveCallback ? "" : "not ")}be set"); + Assert.AreEqual( + shouldHaveCallback, + hasCallback, + $"For source type {managedIdentitySource}, ValidateServerCertificate should {(shouldHaveCallback ? "" : "not ")}be set"); - // For ServiceFabric, verify it's set to the right method - if (sourceType == ManagedIdentitySource.ServiceFabric) - { - Assert.IsNotNull(managedIdentity.GetValidationCallback(), - "ServiceFabric should have ValidateServerCertificate set"); + // For ServiceFabric, verify it's set to the right method + if (managedIdentitySource == ManagedIdentitySource.ServiceFabric) + { + Assert.IsNotNull(managedIdentity.GetValidationCallback(), + "ServiceFabric should have ValidateServerCertificate set"); - Assert.IsInstanceOfType(managedIdentity, typeof(ServiceFabricManagedIdentitySource), - "ServiceFabric managed identity should be of type ServiceFabricManagedIdentitySource"); - } - else - { - Assert.IsNull(managedIdentity.GetValidationCallback(), - $"Non-ServiceFabric source type {sourceType} should not have ValidateServerCertificate set"); - } + Assert.IsInstanceOfType(managedIdentity, typeof(ServiceFabricManagedIdentitySource), + "ServiceFabric managed identity should be of type ServiceFabricManagedIdentitySource"); + } + else + { + Assert.IsNull(managedIdentity.GetValidationCallback(), + $"Non-ServiceFabric source type {managedIdentitySource} should not have ValidateServerCertificate set"); } } } @@ -1397,9 +1455,6 @@ private AbstractManagedIdentity CreateManagedIdentitySource(ManagedIdentitySourc switch (sourceType) { - case ManagedIdentitySource.ServiceFabric: - managedIdentity = ServiceFabricManagedIdentitySource.Create(requestContext); - break; case ManagedIdentitySource.AppService: managedIdentity = AppServiceManagedIdentitySource.Create(requestContext); break; @@ -1412,9 +1467,15 @@ private AbstractManagedIdentity CreateManagedIdentitySource(ManagedIdentitySourc case ManagedIdentitySource.Imds: managedIdentity = new ImdsManagedIdentitySource(requestContext); break; + case ManagedIdentitySource.ImdsV2: + managedIdentity = new ImdsV2ManagedIdentitySource(requestContext); + break; case ManagedIdentitySource.MachineLearning: managedIdentity = MachineLearningManagedIdentitySource.Create(requestContext); break; + case ManagedIdentitySource.ServiceFabric: + managedIdentity = ServiceFabricManagedIdentitySource.Create(requestContext); + break; default: throw new NotSupportedException($"Unsupported managed identity source type: {sourceType}"); } @@ -1437,7 +1498,8 @@ public async Task ManagedIdentityWithExtraQueryParametersTestAsync() { "custom_param", "custom_value" } }; - var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) + var miBuilder = ManagedIdentityApplicationBuilder + .Create(ManagedIdentityId.SystemAssigned) .WithExperimentalFeatures(true) .WithExtraQueryParameters(extraQueryParameters) .WithHttpManager(httpManager);