Program.cs 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  using System.Buffers;
  using System.IO.Pipelines;
  using System.Net;
  using System.Net.Security;
  using System.Net.Sockets;
  using System.Security.Authentication;
  using System.Security.Cryptography;
  using System.Security.Cryptography.X509Certificates;
  using System.Text;
  using DnsClient;
  using Microsoft.AspNetCore.Connections;
  using Microsoft.Extensions.Logging.Console;
  #region INIT
  var builder = WebApplication.CreateBuilder(args);
  // Console log format: one line per entry, compact timestamp, coloured output.
  builder.Services.AddLogging(opt =>
  {
      opt.AddSimpleConsole(p =>
      {
          p.TimestampFormat = "[dd HH:mm:ss] ";
          p.SingleLine = true;
          p.ColorBehavior = LoggerColorBehavior.Enabled;
      });
  });
  // Proxy listening port, taken from the optional "ProxyPort" configuration value
  // (command line, environment, appsettings). The default 0 keeps the previous behaviour:
  // the OS picks a free port on every start.
  var proxyPort = builder.Configuration.GetValue("ProxyPort", 0);
  builder.WebHost.UseKestrel(opt =>
  {
      // Every connection accepted on this endpoint is handed straight to NoSniProxyHandler;
      // no HTTP middleware pipeline runs for it.
      opt.Listen(IPAddress.Any, proxyPort, lisOpt =>
      {
          lisOpt.UseConnectionHandler<NoSniProxyHandler>();
      });
  });
  await using var host = builder.Build();
  var logger = host.Services.GetRequiredService<ILogger<Program>>();
  logger.LogInformation("Hello, World!");
  await host.RunAsync();
  35. #endregion INIT
  /// <summary>
  /// Kestrel connection handler implementing a minimal forward proxy that speaks plain HTTP to
  /// the client and TLS to the origin server WITHOUT sending SNI.
  /// </summary>
  /// <remarks>
  /// The client must send an absolute-form request line such as
  /// <c>GET http://example.com/path HTTP/1.1</c>. The target host is resolved through a dedicated
  /// DNS server. A TLS connection is opened to port 443 with an empty target host, so no SNI is
  /// sent. The request line is rewritten to origin-form, and bytes are then relayed in both
  /// directions. Only the first request line of a client connection is rewritten; any later
  /// requests on the same connection are forwarded verbatim to the same upstream server.
  /// </remarks>
  public class NoSniProxyHandler : ConnectionHandler
  {
      // Name of the DNS server used for lookups; resolved through the system resolver
      // (the name suggests it is defined in the hosts file).
      private const string DnsServerName = "reliable-dns-server-in-hosts";

      // Longest request line (excluding the CRLF) accepted from a client.
      private const int MaxLineLength = 4096;

      // Every request is forwarded over TLS to this upstream port.
      private const int UpstreamPort = 443;

      // Upper bound for the upstream TLS handshake.
      private static readonly TimeSpan TlsHandshakeTimeout = TimeSpan.FromSeconds(10);

      private static readonly LookupClient s_lookup = CreateLookupClient();

      /// <summary>Creates the DNS client that queries <see cref="DnsServerName"/>.</summary>
      /// <exception cref="InvalidOperationException">The DNS server name resolved to no address.</exception>
      private static LookupClient CreateLookupClient()
      {
          var dnsServerIp = Dns.GetHostEntry(DnsServerName).AddressList.FirstOrDefault()
              ?? throw new InvalidOperationException($"DNS server '{DnsServerName}' resolved to no address.");
          return new LookupClient(dnsServerIp);
      }

      /// <summary>
      /// Handles one client connection. It parses the proxy request line, connects to the target
      /// over TLS without SNI, and relays traffic until the upstream side finishes or the client
      /// goes away. Protocol, DNS, connect and handshake failures abort the client connection
      /// with a descriptive <see cref="ConnectionAbortedException"/>.
      /// </summary>
      /// <param name="connection">The accepted Kestrel client connection.</param>
      public override async Task OnConnectedAsync(ConnectionContext connection)
      {
          var requestStream = connection.Transport.Input;
          var responseStream = connection.Transport.Output;

          string? firstLine;
          try
          {
              firstLine = await ReadLineAsync(requestStream);
          }
          catch (InvalidDataException)
          {
              connection.Abort(new ConnectionAbortedException("Canceled: First line too long"));
              return;
          }
          if (firstLine is null)
          {
              connection.Abort(new ConnectionAbortedException("Canceled: First line read fail"));
              return;
          }

          // Expected shape: "<method> <absolute-url> <version>". Targets without a host
          // (e.g. CONNECT's "host:port" authority-form, which parses as a bogus scheme) are rejected.
          var firstLineParts = firstLine.Split(' ', 3);
          if (firstLineParts.Length < 3
              || !Uri.TryCreate(firstLineParts[1], UriKind.Absolute, out var uri)
              || string.IsNullOrEmpty(uri.Host))
          {
              connection.Abort(new ConnectionAbortedException("Canceled: First line bad"));
              return;
          }
          var method = firstLineParts[0];
          var ver = firstLineParts[2];
          var targetHost = uri.Host;

          var ip = await ResolveAsync(targetHost);
          if (ip is null)
          {
              connection.Abort(new ConnectionAbortedException($"Canceled: Cannot resolve {targetHost}"));
              return;
          }

          using var tcpClient = new TcpClient();
          try
          {
              await tcpClient.ConnectAsync(new IPEndPoint(ip, UpstreamPort));
          }
          catch (SocketException ex)
          {
              connection.Abort(new ConnectionAbortedException($"Canceled: Cannot connect to {targetHost}", ex));
              return;
          }

          await using var ssl = new SslStream(tcpClient.GetStream());
          var sslOptions = new SslClientAuthenticationOptions
          {
              TargetHost = string.Empty, // Leave this empty to avoid sending SNI
              RemoteCertificateValidationCallback = (_, certificate, chain, errors) => VerifyServerCert(targetHost, certificate, chain, errors),
          };
          using (var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionClosed))
          {
              handshakeCts.CancelAfter(TlsHandshakeTimeout);
              try
              {
                  await ssl.AuthenticateAsClientAsync(sslOptions, handshakeCts.Token);
              }
              catch (Exception ex) when (ex is AuthenticationException or IOException or OperationCanceledException)
              {
                  connection.Abort(new ConnectionAbortedException($"Canceled: TLS handshake with {targetHost} failed", ex));
                  return;
              }
          }

          // Rewrite the absolute-form request line to origin-form. The headers and body that
          // follow are still in the pipe and are forwarded untouched by the relay.
          await ssl.WriteAsync(Encoding.ASCII.GetBytes($"{method} {uri.PathAndQuery} {ver}\r\n"));

          await RelayAsync(requestStream, responseStream, ssl, connection.ConnectionClosed);
      }

      /// <summary>Looks up the first A record of <paramref name="host"/>.</summary>
      /// <returns>The address, or <see langword="null"/> when there is no A record or the query fails.</returns>
      private static async Task<IPAddress?> ResolveAsync(string host)
      {
          try
          {
              var result = await s_lookup.QueryAsync(host, QueryType.A);
              return result.Answers.ARecords().FirstOrDefault()?.Address;
          }
          catch (DnsResponseException)
          {
              return null;
          }
      }

      /// <summary>
      /// Copies client→upstream and upstream→client concurrently. The exchange ends when the
      /// upstream side finishes or fails, or when <paramref name="connectionClosed"/> fires.
      /// The client→upstream copy is then cancelled so the handler can return even while a
      /// keep-alive client still holds its side open.
      /// </summary>
      private static async Task RelayAsync(PipeReader requestStream, PipeWriter responseStream, Stream upstream, CancellationToken connectionClosed)
      {
          using var relayCts = CancellationTokenSource.CreateLinkedTokenSource(connectionClosed);
          var outgoing = requestStream.CopyToAsync(upstream, relayCts.Token);
          try
          {
              await upstream.CopyToAsync(responseStream, relayCts.Token);
          }
          catch (Exception ex) when (ex is IOException or OperationCanceledException)
          {
              // Upstream reset or client disconnected: nothing more can be relayed.
          }
          finally
          {
              relayCts.Cancel();
              try
              {
                  await outgoing;
              }
              catch (Exception ex) when (ex is IOException or OperationCanceledException)
              {
                  // Expected when the client→upstream copy is torn down above.
              }
          }
      }

      /// <summary>Reads one CRLF-terminated line from the client.</summary>
      /// <returns>
      /// The line without its CRLF, or <see langword="null"/> if the read was cancelled or the
      /// client completed its side before sending a full line.
      /// </returns>
      /// <exception cref="InvalidDataException">The line exceeds <see cref="MaxLineLength"/> bytes.</exception>
      private static async Task<string?> ReadLineAsync(PipeReader requestStream)
      {
          while (true)
          {
              var readResult = await requestStream.ReadAsync();
              if (readResult.IsCanceled)
              {
                  requestStream.AdvanceTo(readResult.Buffer.Start);
                  return null;
              }
              var (consumed, examined) = ExtractLine(readResult.Buffer, out var line);
              requestStream.AdvanceTo(consumed, examined);
              if (line is not null) return line;
              // No full line is buffered and no more data will arrive. Without this check the loop
              // would spin forever, because ReadAsync keeps completing immediately on a finished pipe.
              if (readResult.IsCompleted) return null;
          }
      }

      /// <summary>
      /// Tries to extract one CRLF-terminated line (decoded as ASCII) from <paramref name="buffer"/>.
      /// </summary>
      /// <returns>
      /// The consumed/examined positions for <see cref="PipeReader.AdvanceTo(SequencePosition, SequencePosition)"/>.
      /// When no line is available, nothing is consumed and everything is examined.
      /// </returns>
      /// <exception cref="InvalidDataException">The line is, or can no longer be shorter than, <see cref="MaxLineLength"/>.</exception>
      private static (SequencePosition pos, SequencePosition exm) ExtractLine(ReadOnlySequence<byte> buffer, out string? line)
      {
          var reader = new SequenceReader<byte>(buffer);
          if (reader.TryReadTo(out ReadOnlySpan<byte> span, "\r\n"u8))
          {
              if (span.Length > MaxLineLength) throw new InvalidDataException("Too long for line");
              line = Encoding.ASCII.GetString(span);
              return (reader.Position, reader.Position);
          }
          // No CRLF yet. A valid line plus a trailing '\r' is at most MaxLineLength + 1 bytes, so
          // anything longer can never become valid; refuse to keep buffering it.
          if (buffer.Length > MaxLineLength + 1) throw new InvalidDataException("Too long for line");
          line = null;
          return (buffer.Start, buffer.End);
      }

      /// <summary>
      /// Validates the upstream certificate. No SNI/target host is sent, so the platform cannot
      /// check the host name itself; the check is done here against <paramref name="targetHost"/>.
      /// </summary>
      /// <param name="targetHost">Host name taken from the client's request URL.</param>
      /// <param name="certificate">Certificate presented by the server.</param>
      /// <param name="chain">Chain supplied by the platform; it is rebuilt here.</param>
      /// <param name="errs">Errors reported by the platform's own validation.</param>
      /// <returns>
      /// <see langword="true"/> when the certificate is within its validity period, names
      /// <paramref name="targetHost"/>, and chains without non-revocation errors.
      /// </returns>
      private static bool VerifyServerCert(string targetHost, X509Certificate? certificate, X509Chain? chain, SslPolicyErrors errs)
      {
          // The only tolerated platform error is the name mismatch caused by the empty target host.
          if ((errs & ~SslPolicyErrors.RemoteCertificateNameMismatch) != SslPolicyErrors.None) return false;
          if (certificate is not X509Certificate2 cert2) return false;
          // Validity period. NotBefore/NotAfter are expressed in local time, hence DateTime.Now.
          var now = DateTime.Now;
          if (now < cert2.NotBefore || now > cert2.NotAfter) return false;
          // The host must match at least one certificate name. This is enforced even when the
          // platform reported no error, so a certificate issued for another host is never accepted.
          if (!GetAllSubjectAlternativeNames(cert2).Any(name => HostMatches(name, targetHost))) return false;
          // Rebuild the certificate chain.
          if (chain is null) return false;
          chain.ChainPolicy.RevocationMode = X509RevocationMode.NoCheck; // revocation checking is too slow
          chain.ChainPolicy.RevocationFlag = X509RevocationFlag.EntireChain;
          chain.ChainPolicy.UrlRetrievalTimeout = TimeSpan.FromSeconds(10);
          chain.ChainPolicy.VerificationFlags = X509VerificationFlags.NoFlag;
          if (chain.Build(cert2)) return true;
          // Only revocation-related statuses are tolerated; any other status invalidates the certificate.
          return chain.ChainStatus.All(status => status.Status
              is X509ChainStatusFlags.NoError
              or X509ChainStatusFlags.RevocationStatusUnknown
              or X509ChainStatusFlags.OfflineRevocation);
      }

      /// <summary>
      /// Matches a certificate DNS name against <paramref name="host"/>, case-insensitively.
      /// A wildcard (<c>*.example.com</c>) covers exactly one left-most label (RFC 6125 §6.4.3).
      /// It therefore matches <c>a.example.com</c>, but not <c>example.com</c>,
      /// <c>a.b.example.com</c> or <c>evilexample.com</c>.
      /// </summary>
      private static bool HostMatches(string certName, string host)
      {
          if (certName.StartsWith("*.", StringComparison.Ordinal))
          {
              var firstDot = host.IndexOf('.');
              return certName.Length > 2
                  && firstDot > 0
                  && host.AsSpan(firstDot + 1).Equals(certName.AsSpan(2), StringComparison.OrdinalIgnoreCase);
          }
          return certName.Equals(host, StringComparison.OrdinalIgnoreCase);
      }

      /// <summary>
      /// Collects the DNS names a certificate is issued for. These are all DNS entries of the
      /// Subject Alternative Name extension plus <see cref="X509NameType.DnsName"/>, which falls
      /// back to the subject CN.
      /// </summary>
      private static IReadOnlyList<string> GetAllSubjectAlternativeNames(X509Certificate2 cert)
      {
          var names = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
          foreach (var extension in cert.Extensions)
          {
              // The SAN extension may surface as the typed class or as a raw extension with the
              // SAN OID. Decoding the raw bytes with the typed class avoids parsing the
              // platform-specific text produced by AsnEncodedData.Format.
              X509SubjectAlternativeNameExtension? san = extension switch
              {
                  X509SubjectAlternativeNameExtension typed => typed,
                  { Oid.Value: "2.5.29.17" } => new X509SubjectAlternativeNameExtension(extension.RawData, extension.Critical),
                  _ => null,
              };
              if (san is null) continue;
              foreach (var name in san.EnumerateDnsNames()) names.Add(name);
          }
          // Also add the certificate's DNS name (common name).
          var commonName = cert.GetNameInfo(X509NameType.DnsName, false);
          if (!string.IsNullOrEmpty(commonName)) names.Add(commonName);
          return [.. names];
      }
  }