Skip to content
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Identity.Client.Core;

namespace Microsoft.Identity.Client.ManagedIdentity.V2
{
internal interface IMtlsBindingCache
{
Task<Tuple<X509Certificate2, string /*endpoint*/, string /*clientId*/>> GetOrCreateAsync(
string cacheKey,
Func<Task<Tuple<X509Certificate2, string, string>>> factory,
CancellationToken cancellationToken,
ILoggerAdapter logger);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System.Security.Cryptography.X509Certificates;
using Microsoft.Identity.Client.Core;

namespace Microsoft.Identity.Client.ManagedIdentity.V2
{
/// <summary>
/// Persistence interface for IMDSv2 mTLS binding certificates.
/// Implementations must be best-effort and non-throwing.
/// </summary>
internal interface IPersistentCertificateCache
{
/// <summary>
/// Reads the newest valid (≥24h remaining, has private key) entry for the alias.
/// </summary>
bool Read(string alias, out CertificateCacheValue value, ILoggerAdapter logger = null);

/// <summary>
/// Persists the certificate for the alias (best-effort). Implementations should
/// tag entries to allow alias scoping and prune expired duplicates conservatively.
/// </summary>
void Write(string alias, X509Certificate2 cert, string endpointBase, ILoggerAdapter logger = null);

/// <summary>
/// Prunes expired entries for the alias (best-effort).
/// </summary>
void Delete(string alias, ILoggerAdapter logger = null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ internal class ImdsV2ManagedIdentitySource : AbstractManagedIdentity
// Central, process-local cache for mTLS binding (cert + endpoint + canonical client_id).
internal static readonly ICertificateCache s_mtlsCertificateCache = new InMemoryCertificateCache();

// Per-key async de-duplication so concurrent callers don’t double-mint.
internal static readonly ConcurrentDictionary<string, SemaphoreSlim> s_perKeyGates =
new ConcurrentDictionary<string, SemaphoreSlim>(StringComparer.Ordinal);
private readonly IMtlsBindingCache _mtlsCache;

// used in unit tests
public const string ImdsV2ApiVersion = "2.0";
Expand Down Expand Up @@ -195,7 +193,12 @@ public static AbstractManagedIdentity Create(RequestContext requestContext)

internal ImdsV2ManagedIdentitySource(RequestContext requestContext) :
base(requestContext, ManagedIdentitySource.ImdsV2)
{ }
{
IPersistentCertificateCache persisted =
PersistentCertificateCacheFactory.Create(requestContext.Logger);

_mtlsCache = new MtlsBindingCache(s_mtlsCertificateCache, persisted);
}

private async Task<CertificateRequestResponse> ExecuteCertificateRequestAsync(
string clientId,
Expand Down Expand Up @@ -440,65 +443,13 @@ private async Task<string> GetAttestationJwtAsync(
return response.AttestationToken;
}

// ...unchanged usings and class header...

/// <summary>
/// Read-through cache: try cache; if missing, run async factory once (per key),
/// store the result, and return it. Thread-safe for the given cacheKey.
/// </summary>
private static async Task<Tuple<X509Certificate2, string, string>> GetOrCreateMtlsBindingAsync(
private Task<Tuple<X509Certificate2, string, string>> GetOrCreateMtlsBindingAsync(
string cacheKey,
Func<Task<Tuple<X509Certificate2, string, string>>> factory,
CancellationToken cancellationToken,
ILoggerAdapter logger)
{
if (string.IsNullOrWhiteSpace(cacheKey))
throw new ArgumentException("cacheKey must be non-empty.", nameof(cacheKey));
if (factory is null)
throw new ArgumentNullException(nameof(factory));

X509Certificate2 cachedCertificate;
string cachedEndpointBase;
string cachedClientId;

// 1) Only lookup by cacheKey
if (s_mtlsCertificateCache.TryGet(cacheKey, out var cached, logger))
{
cachedCertificate = cached.Certificate;
cachedEndpointBase = cached.Endpoint;
cachedClientId = cached.ClientId;

return Tuple.Create(cachedCertificate, cachedEndpointBase, cachedClientId);
}

// 2) Gate per cacheKey
var gate = s_perKeyGates.GetOrAdd(cacheKey, _ => new SemaphoreSlim(1, 1));
await gate.WaitAsync(cancellationToken).ConfigureAwait(false);

try
{
// Re-check after acquiring the gate
if (s_mtlsCertificateCache.TryGet(cacheKey, out cached, logger))
{
cachedCertificate = cached.Certificate;
cachedEndpointBase = cached.Endpoint;
cachedClientId = cached.ClientId;
return Tuple.Create(cachedCertificate, cachedEndpointBase, cachedClientId);
}

// 3) Mint + cache under the provided cacheKey
var created = await factory().ConfigureAwait(false);

s_mtlsCertificateCache.Set(cacheKey,
new CertificateCacheValue(created.Item1, created.Item2, created.Item3),
logger);

return created;
}
finally
{
gate.Release();
}
return _mtlsCache.GetOrCreateAsync(cacheKey, factory, cancellationToken, logger);
}

internal static void ResetCertCacheForTest()
Expand All @@ -508,14 +459,6 @@ internal static void ResetCertCacheForTest()
{
s_mtlsCertificateCache.Clear();
}

foreach (var gate in s_perKeyGates.Values)
{
try
{ gate.Dispose(); }
catch { }
}
s_perKeyGates.Clear();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Security.Cryptography;
using System.Text;
using System.Threading;
using Microsoft.Identity.Client.PlatformsCommon.Shared;

namespace Microsoft.Identity.Client.ManagedIdentity.V2
{
/// <summary>
/// Executes paramref name="action"/ under a cross-process, per-alias mutex.
/// We attempt 2 namespaces, in order:
/// 1) <c>Global\</c> — preferred so we dedupe across all sessions on the machine
/// (e.g., service + user session). This can be denied by OS policy or missing
/// SeCreateGlobalPrivilege in some contexts.
/// 2) <c>Local\</c> — fallback to still dedupe within the current session when
/// <c>Global\</c> is not permitted.
/// Using both ensures we never throw (persistence is best-effort) while getting
/// machine-wide dedupe when allowed and session-local dedupe otherwise.
/// Notes:
/// - The mutex name is derived from <c>alias</c> (= cacheKey) via SHA-256 hex (truncated)
/// to avoid invalid characters / length issues.
/// - On non-Windows runtimes the Global/Local prefixes are treated as part of the name;
/// behavior remains correct but dedupe scope is platform-defined.
/// - Abandoned mutexes are treated as acquired to avoid blocking after a crash.
/// </summary>

internal static class InterprocessLock
{
// Prefer Global\ for cross-session dedupe; fall back to Local\
// if ACLs block Global\ to remain non-throwing.
public static bool TryWithAliasLock(
string alias,
TimeSpan timeout,
Action action,
Action<string> logVerbose = null)
{
var nameGlobal = GetMutexNameForAlias(alias, preferGlobal: true);
var nameLocal = GetMutexNameForAlias(alias, preferGlobal: false);

foreach (var name in new[] { nameGlobal, nameLocal })
{
try
{
// Create or open existing
using var m = new Mutex(initiallyOwned: false, name);

// Wait to acquire
bool entered;

try
{
entered = m.WaitOne(timeout);
}
catch (AbandonedMutexException)
{
entered = true; // prior holder crashed
}

if (!entered)
{
logVerbose?.Invoke($"[PersistentCert] Skip persist (lock busy '{name}').");
return false;
}

try
{
action();
}
finally
{
try
{
m.ReleaseMutex();
}
catch
{
/* best-effort */
}
}

return true;
}
catch (UnauthorizedAccessException)
{
logVerbose?.Invoke($"[PersistentCert] No access to mutex scope '{name}', trying next.");
continue; // try Local if Global blocked
}
catch (Exception ex)
{
logVerbose?.Invoke($"[PersistentCert] Lock failure '{name}': {ex.Message}");
return false;
}
}

return false;
}

public static string GetMutexNameForAlias(string alias, bool preferGlobal = true)
{
string suffix = HashAlias(Canonicalize(alias));
return (preferGlobal ? @"Global\" : @"Local\") + "MSAL_MI_P_" + suffix;
}

private static string Canonicalize(string alias) => (alias ?? string.Empty).Trim().ToUpperInvariant();

private static string HashAlias(string s)
{
try
{
var hex = new CommonCryptographyManager().CreateSha256HashHex(s);
// Truncate to 32 chars to fit mutex name length limits
return string.IsNullOrEmpty(hex) ? "0" : (hex.Length > 32 ? hex.Substring(0, 32) : hex);
}
catch
{
return "0";
}
}
}
}
Loading