// Program.cs
  1. using System;
  2. using System.Collections.Concurrent;
  3. using System.Collections.Generic;
  4. using System.Collections.ObjectModel;
  5. using System.IO;
  6. using System.Linq;
  7. using System.Net;
  8. using System.Net.Sockets;
  9. using System.Text;
  namespace DnsForwarder
  {
      /// <summary>
      /// A small UDP DNS forwarder. Each incoming query is relayed, unmodified, to one of
      /// several upstream servers chosen by the queried name: names ending in ".cn" go to
      /// the .CN server, names covered by a dnsmasq-style "server=/domain/ip" list go to
      /// the server named in that list, and everything else goes to the default server.
      /// </summary>
      internal class Program
      {
          /// <summary>UDP port the forwarder listens on and the default upstream port.</summary>
          private const int DnsPort = 53;

          /// <summary>Offset of the question section; the bytes before it are skipped.</summary>
          private const int DnsHeaderLength = 12;

          /// <summary>Largest label length accepted; larger length bytes are rejected.</summary>
          private const int MaxLabelLength = 63;

          /// <summary>Size of the buffer handed to each receive operation.</summary>
          private const int ReceiveBufferSize = 1500;

          /// <summary>How long to wait for an upstream answer, in milliseconds.</summary>
          private const int UpstreamTimeoutMs = 2000;

          private static string _defaultDns;
          private static string _cnDns;

          // Maps a lower-case domain to an index into _chinaListDns.
          private static IReadOnlyDictionary<string, int> _chinaList;
          private static string[] _chinaListDns;

          private static Socket _listener;

          // Log lines produced by receive callbacks; drained and printed by Main.
          private static readonly BlockingCollection<string> _consoleOutput = new BlockingCollection<string>();

          /// <summary>
          /// Entry point. Expects exactly three arguments: the default upstream server, the
          /// upstream server for ".cn" names, and the path of the list file. Loads the list,
          /// binds the listening socket and then prints queued log lines forever.
          /// </summary>
          /// <param name="args">Command-line arguments.</param>
          private static void Main(string[] args)
          {
              if (args.Length != 3)
              {
                  Console.WriteLine("<default dns server ip> <.CN dns server ip> <path to dnsmasq-china-list file>");
                  Environment.Exit(-1);
                  return;
              }

              _defaultDns = args[0];
              _cnDns = args[1];
              var listPath = args[2];

              Console.WriteLine("Starting...");
              Console.WriteLine($"Default Server: {_defaultDns}");
              Console.WriteLine($".CN DNS Server: {_cnDns}");
              Console.WriteLine($"dnsmasq-china-list: {listPath}");
              Console.WriteLine("Loading list file...");
              LoadListFile(listPath);
              Console.WriteLine("OK. Listening...");

              _listener = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
              _listener.Bind(new IPEndPoint(IPAddress.Any, DnsPort));
              StartHandleRequest();

              // Nothing ever marks the collection as complete, so this loop blocks between
              // log lines and runs for the lifetime of the process.
              foreach (var line in _consoleOutput.GetConsumingEnumerable())
              {
                  Console.WriteLine(line);
              }
          }

          /// <summary>
          /// Posts one asynchronous receive on the listening socket. A fresh buffer is used
          /// for every receive and travels to <see cref="Callback"/> as the async state.
          /// </summary>
          private static void StartHandleRequest()
          {
              var buf = new byte[ReceiveBufferSize];
              EndPoint from = new IPEndPoint(IPAddress.Any, 0);
              _listener.BeginReceiveFrom(buf, 0, buf.Length, SocketFlags.None, ref from, Callback, buf);
          }

          /// <summary>
          /// Completes one receive, immediately posts the next one, then forwards the query
          /// upstream and relays the answer (if any) back to the sender. Every failure is
          /// reported through the log queue instead of escaping the callback, because an
          /// unhandled exception here would terminate the process.
          /// </summary>
          /// <param name="ar">Async result whose state is the receive buffer.</param>
          private static void Callback(IAsyncResult ar)
          {
              var buf = (byte[])ar.AsyncState;
              EndPoint from = new IPEndPoint(IPAddress.Any, 0);
              int count;
              try
              {
                  count = _listener.EndReceiveFrom(ar, ref from);
              }
              catch (ObjectDisposedException)
              {
                  // The listener has been closed; there is nothing left to serve.
                  return;
              }
              catch (SocketException e)
              {
                  // A failed receive must not stop the server: log it and keep listening.
                  _consoleOutput.Add($"{DateTime.Now:yyyyMMdd HH:mm:ss} receive ERR:{e.Message}");
                  StartHandleRequest();
                  return;
              }

              // Re-arm before doing the (blocking) upstream round trip so other clients
              // are not held up by this request.
              StartHandleRequest();

              // Placeholders shown in the log when the packet cannot be parsed.
              var domain = "?";
              var target = "?";
              try
              {
                  domain = ExtractDomainName(buf, count);
                  target = MatchServer(domain);
                  var dnsResponse = GetDnsResponse(buf, count, target);
                  if (dnsResponse != null)
                  {
                      _listener.SendTo(dnsResponse, from);
                  }

                  _consoleOutput.Add($"{DateTime.Now:yyyyMMdd HH:mm:ss} {from} [{target}]\t{domain}");
              }
              catch (Exception e)
              {
                  _consoleOutput.Add($"{DateTime.Now:yyyyMMdd HH:mm:ss} {from} [{target}]\t{domain} ERR:{e.Message}");
              }
          }

          /// <summary>
          /// Reads the first question name from a DNS message: a sequence of
          /// length-prefixed labels starting at byte 12 and ending with a zero length byte.
          /// </summary>
          /// <param name="buf">Buffer holding the message.</param>
          /// <param name="count">Number of valid bytes in <paramref name="buf"/>.</param>
          /// <returns>
          /// The labels joined with '.', decoded as ASCII and with their original case;
          /// an empty string when the name is the root (no labels).
          /// </returns>
          /// <exception cref="ArgumentNullException"><paramref name="buf"/> is null.</exception>
          /// <exception cref="FormatException">
          /// The name runs past <paramref name="count"/> or contains a length byte above 63
          /// (such as a compression pointer), which this parser does not follow.
          /// </exception>
          protected static string ExtractDomainName(byte[] buf, int count)
          {
              if (buf == null)
              {
                  throw new ArgumentNullException(nameof(buf));
              }

              // Never trust a count larger than the buffer that was actually supplied.
              var limit = Math.Min(count, buf.Length);
              var name = new StringBuilder();
              var ptr = DnsHeaderLength;

              while (true)
              {
                  if (ptr >= limit)
                  {
                      throw new FormatException("DNS question name is truncated.");
                  }

                  int labelLength = buf[ptr];
                  if (labelLength == 0)
                  {
                      break;
                  }

                  if (labelLength > MaxLabelLength)
                  {
                      throw new FormatException("DNS question name contains an unsupported label type.");
                  }

                  if (ptr + 1 + labelLength > limit)
                  {
                      throw new FormatException("DNS question name is truncated.");
                  }

                  if (name.Length > 0)
                  {
                      name.Append('.');
                  }

                  name.Append(Encoding.ASCII.GetString(buf, ptr + 1, labelLength));
                  ptr += labelLength + 1;
              }

              return name.ToString();
          }

          /// <summary>
          /// Sends the first <paramref name="count"/> bytes of <paramref name="buf"/> to an
          /// upstream server over UDP and waits up to two seconds for a reply.
          /// </summary>
          /// <param name="buf">Buffer holding the query.</param>
          /// <param name="count">Number of bytes of <paramref name="buf"/> to send.</param>
          /// <param name="host">Upstream server host name or address.</param>
          /// <param name="port">Upstream server port.</param>
          /// <returns>The reply datagram, or null when none arrived in time.</returns>
          protected static byte[] GetDnsResponse(byte[] buf, int count, string host, int port = 53)
          {
              using var to = new UdpClient();
              to.Connect(host, port);
              to.Send(buf, count);

              // Bounded wait for the upstream answer; on timeout the client is simply
              // disposed, which abandons the pending receive.
              var asyncResult = to.BeginReceive(null, null);
              asyncResult.AsyncWaitHandle.WaitOne(UpstreamTimeoutMs);
              if (!asyncResult.IsCompleted)
              {
                  return null;
              }

              IPEndPoint remoteEP = null;
              return to.EndReceive(asyncResult, ref remoteEP);
          }

          /// <summary>
          /// Chooses the upstream server for a name. Names ending in ".cn" use the .CN
          /// server; otherwise the name and each of its parent domains are looked up in the
          /// loaded list, longest first; if nothing matches the default server is used.
          /// Matching is case-insensitive and culture-independent.
          /// </summary>
          /// <param name="domain">Queried name, labels separated by '.'.</param>
          /// <returns>The upstream server to forward the query to.</returns>
          protected static string MatchServer(string domain)
          {
              var candidate = domain.ToLowerInvariant();
              if (candidate.EndsWith(".cn", StringComparison.Ordinal))
              {
                  return _cnDns;
              }

              // Try the full name, then strip one leading label at a time.
              while (true)
              {
                  if (_chinaList.TryGetValue(candidate, out var index))
                  {
                      return _chinaListDns[index];
                  }

                  var dot = candidate.IndexOf('.');
                  if (dot < 0)
                  {
                      return _defaultDns;
                  }

                  candidate = candidate.Substring(dot + 1);
              }
          }

          /// <summary>
          /// Loads routing rules from one or more files. Only lines of the exact form
          /// "server=/domain/dns" are used; lines starting with '#', blank lines and any
          /// other line are ignored. Domains are stored lower-cased so that they agree with
          /// the lookups done by <see cref="MatchServer"/>. When a domain appears more than
          /// once, the last occurrence wins. Replaces any previously loaded rules.
          /// </summary>
          /// <param name="paths">Paths of the files to read.</param>
          protected static void LoadListFile(params string[] paths)
          {
              var domainToServer = new Dictionary<string, int>(StringComparer.Ordinal);
              var servers = new List<string>();

              foreach (var rawLine in paths.SelectMany(path => File.ReadAllLines(path)))
              {
                  var line = rawLine.Trim();
                  if (line.Length == 0 || line[0] == '#')
                  {
                      continue;
                  }

                  // A rule must have exactly three '/'-separated fields AND be a server rule.
                  var parts = line.Split('/');
                  if (parts.Length != 3 || parts[0] != "server=")
                  {
                      continue;
                  }

                  var domain = parts[1].ToLowerInvariant();
                  var dns = parts[2];

                  // Servers are de-duplicated; the list is expected to name only a few.
                  var dnsIndex = servers.IndexOf(dns);
                  if (dnsIndex == -1)
                  {
                      servers.Add(dns);
                      dnsIndex = servers.Count - 1;
                  }

                  domainToServer[domain] = dnsIndex;
              }

              _chinaListDns = servers.ToArray();
              _chinaList = new ReadOnlyDictionary<string, int>(domainToServer);
          }
      }
  }