// Program.cs 6.5 KB


  using System;
  using System.Collections.Concurrent;
  using System.Collections.Generic;
  using System.Collections.ObjectModel;
  using System.Globalization;
  using System.IO;
  using System.Linq;
  using System.Net;
  using System.Net.Sockets;
  using System.Text;
  namespace DnsForwarder
  {
      /// <summary>
      /// A small UDP DNS forwarder. Each query received on the listen address is relayed to an
      /// upstream server chosen by its domain name: the ".cn" server for names under .cn, the server
      /// a dnsmasq-china-list file assigns to the name (or to one of its parent domains), or the
      /// default server otherwise. The upstream answer is sent back to the client unchanged.
      /// </summary>
      internal class Program
      {
          /// <summary>DNS port used both for listening and for contacting upstream servers.</summary>
          private const int DnsPort = 53;

          /// <summary>Length of the fixed DNS message header; the question section follows it.</summary>
          private const int DnsHeaderLength = 12;

          /// <summary>Size of the buffer each incoming query is received into.</summary>
          private const int ReceiveBufferSize = 1500;

          /// <summary>How long to wait for an upstream answer before dropping the query.</summary>
          private const int UpstreamTimeoutMilliseconds = 2000;

          /// <summary>Upstream server used when no other rule matches.</summary>
          private static string _defaultDns;

          /// <summary>Upstream server used for names under the .cn TLD.</summary>
          private static string _cnDns;

          /// <summary>Domain from the list file mapped to an index into <see cref="_chinaListDns"/>.</summary>
          private static IReadOnlyDictionary<string, int> _chinaList;

          /// <summary>Distinct upstream servers referenced by the list file.</summary>
          private static string[] _chinaListDns;

          /// <summary>UDP socket that receives client queries and sends the answers back.</summary>
          private static Socket _listener;

          /// <summary>Log lines queued by request handlers and printed by the main thread.</summary>
          private static BlockingCollection<string> _consoleOutput;

          /// <summary>
          /// Entry point. Expects: listen address, default DNS server, .CN DNS server, list file path.
          /// Exit codes: -1 wrong argument count, -2 invalid listen address,
          /// -3 list file could not be read, -4 listen socket could not be bound.
          /// </summary>
          private static void Main(string[] args)
          {
              if (args.Length != 4)
              {
                  Console.WriteLine("<listen address> <default dns server ip> <.CN dns server ip> <path to dnsmasq-china-list file>");
                  Environment.Exit(-1);
                  return;
              }
              if (!IPAddress.TryParse(args[0], out var listenAddress))
              {
                  Console.WriteLine("Invalid listen address.");
                  Environment.Exit(-2);
                  return;
              }

              _defaultDns = args[1];
              _cnDns = args[2];
              Console.WriteLine("Starting...");
              Console.WriteLine($"Listen address: {args[0]}");
              Console.WriteLine($"Default Server: {_defaultDns}");
              Console.WriteLine($".CN DNS Server: {_cnDns}");
              Console.WriteLine($"dnsmasq-china-list: {args[3]}");
              Console.WriteLine("Loading list file...");
              try
              {
                  LoadListFile(args[3]);
              }
              catch (Exception e) when (e is IOException || e is UnauthorizedAccessException)
              {
                  Console.WriteLine($"Failed to load list file: {e.Message}");
                  Environment.Exit(-3);
                  return;
              }

              // The log queue must exist before the first request can be handled.
              _consoleOutput = new BlockingCollection<string>();

              // Use the listen address's own family so IPv6 listen addresses can be bound as well.
              _listener = new Socket(listenAddress.AddressFamily, SocketType.Dgram, ProtocolType.Udp);
              try
              {
                  _listener.Bind(new IPEndPoint(listenAddress, DnsPort));
              }
              catch (SocketException e)
              {
                  Console.WriteLine($"Failed to bind {listenAddress} port {DnsPort}: {e.Message}");
                  Environment.Exit(-4);
                  return;
              }
              Console.WriteLine("OK. Listening...");
              StartHandleRequest();

              // The queue is never completed, so this prints request logs for the life of the process.
              foreach (var message in _consoleOutput.GetConsumingEnumerable())
              {
                  Console.WriteLine(message);
              }
          }

          /// <summary>
          /// Creates a wildcard endpoint of the listener's address family; ReceiveFrom rejects
          /// endpoints whose family differs from the socket's.
          /// </summary>
          private static EndPoint CreateAnyEndPoint() =>
              new IPEndPoint(_listener.AddressFamily == AddressFamily.InterNetworkV6 ? IPAddress.IPv6Any : IPAddress.Any, 0);

          /// <summary>Local time for log lines, formatted with the invariant culture so ':' is always the separator.</summary>
          private static string Timestamp() =>
              DateTime.Now.ToString("yyyyMMdd HH:mm:ss", CultureInfo.InvariantCulture);

          /// <summary>
          /// Posts one asynchronous receive on the listener into a fresh buffer (passed as the async
          /// state). Retries until the receive is posted, printing every failure.
          /// </summary>
          private static void StartHandleRequest()
          {
              while (true)
              {
                  try
                  {
                      var buf = new byte[ReceiveBufferSize];
                      var from = CreateAnyEndPoint();
                      _listener.BeginReceiveFrom(buf, 0, buf.Length, SocketFlags.None, ref from, Callback, buf);
                      return;
                  }
                  catch (Exception e)
                  {
                      Console.WriteLine(e);
                  }
              }
          }

          /// <summary>
          /// Completion callback for a receive: re-arms the listener, forwards the query to the
          /// selected upstream server, relays the answer to the client and queues a log line.
          /// </summary>
          private static void Callback(IAsyncResult ar)
          {
              var nextStarted = false;
              var from = CreateAnyEndPoint();
              string domain = null;
              string target = null;
              try
              {
                  var count = _listener.EndReceiveFrom(ar, ref from);
                  // Post the next receive before the (possibly slow) upstream round trip.
                  StartHandleRequest();
                  nextStarted = true;

                  var buf = (byte[])ar.AsyncState;
                  domain = ExtractDomainName(buf, count);
                  target = MatchServer(domain);
                  var dnsResponse = GetDnsResponse(buf, count, target);
                  if (dnsResponse != null)
                  {
                      _listener.SendTo(dnsResponse, from);
                      _consoleOutput.Add($"{Timestamp()} {from} [{target}]\t{domain}");
                  }
                  else
                  {
                      _consoleOutput.Add($"{Timestamp()} {from} [{target}]\t{domain} Err:Upstream timeout");
                  }
              }
              catch (Exception e)
              {
                  _consoleOutput.Add($"{Timestamp()} {from} [{target ?? "Unknown"}]\t{domain ?? "Unknown"} Err:{e.Message}");
              }
              finally
              {
                  // If EndReceiveFrom threw, the next receive was never posted; post it so serving continues.
                  if (!nextStarted) StartHandleRequest();
              }
          }

          /// <summary>
          /// Decodes the name of the first question in a DNS query into dotted form
          /// (e.g. "www.example.com"). The root name decodes to an empty string.
          /// </summary>
          /// <param name="buf">Buffer holding the DNS message.</param>
          /// <param name="count">Number of valid bytes in <paramref name="buf"/>; negative means the whole buffer.</param>
          /// <exception cref="ArgumentNullException"><paramref name="buf"/> is null.</exception>
          /// <exception cref="InvalidDataException">
          /// The name runs past the valid data, or uses a compression pointer / reserved label type.
          /// </exception>
          protected static string ExtractDomainName(byte[] buf, int count = -1)
          {
              if (buf is null) throw new ArgumentNullException(nameof(buf));
              var end = count < 0 ? buf.Length : Math.Min(count, buf.Length);

              var labels = new List<string>();
              var ptr = DnsHeaderLength;
              while (true)
              {
                  if (ptr >= end) throw new InvalidDataException("Truncated DNS question name.");
                  int length = buf[ptr];
                  if (length == 0) break;
                  // Label lengths are at most 63; the two high bits mark pointers or reserved label types.
                  if ((length & 0xC0) != 0) throw new InvalidDataException("Unsupported label type in DNS question name.");
                  if (ptr + 1 + length > end) throw new InvalidDataException("Truncated DNS question name.");
                  labels.Add(Encoding.ASCII.GetString(buf, ptr + 1, length));
                  ptr += length + 1;
              }
              return string.Join(".", labels);
          }

          /// <summary>
          /// Sends the query to an upstream server and waits up to
          /// <see cref="UpstreamTimeoutMilliseconds"/> for its answer.
          /// </summary>
          /// <param name="buf">Buffer holding the query.</param>
          /// <param name="count">Number of bytes of <paramref name="buf"/> to send.</param>
          /// <param name="host">Upstream server host name or IP address.</param>
          /// <param name="port">Upstream server port.</param>
          /// <returns>The raw upstream answer, or null if none arrived before the timeout.</returns>
          protected static byte[] GetDnsResponse(byte[] buf, int count, string host, int port = DnsPort)
          {
              using var upstream = new UdpClient();
              upstream.Connect(host, port);
              // Set after Connect so the timeout applies to the socket actually used for receiving.
              upstream.Client.ReceiveTimeout = UpstreamTimeoutMilliseconds;
              upstream.Send(buf, count);
              try
              {
                  IPEndPoint remote = null;
                  return upstream.Receive(ref remote);
              }
              catch (SocketException e) when (e.SocketErrorCode == SocketError.TimedOut)
              {
                  return null;
              }
          }

          /// <summary>
          /// Chooses the upstream server for a domain: the .CN server for "cn" and names under it,
          /// otherwise the server the list file assigns to the most specific matching suffix,
          /// otherwise the default server. Requires <see cref="LoadListFile"/> to have run.
          /// </summary>
          protected static string MatchServer(string domain)
          {
              if (string.IsNullOrEmpty(domain)) return _defaultDns;

              var lower = domain.ToLowerInvariant();
              if (lower == "cn" || lower.EndsWith(".cn", StringComparison.Ordinal)) return _cnDns;

              // Try the full name first, then strip one leading label at a time.
              var suffix = lower;
              while (true)
              {
                  if (_chinaList.TryGetValue(suffix, out var index)) return _chinaListDns[index];
                  var dot = suffix.IndexOf('.');
                  if (dot < 0) break;
                  suffix = suffix.Substring(dot + 1);
              }
              return _defaultDns;
          }

          /// <summary>
          /// Loads dnsmasq-style "server=/&lt;domain&gt;/&lt;dns&gt;" lines from the given files into
          /// <see cref="_chinaList"/> and <see cref="_chinaListDns"/>. Blank lines, '#' comments and
          /// lines of any other shape are skipped; a later entry for the same domain wins.
          /// </summary>
          protected static void LoadListFile(params string[] paths)
          {
              var domains = new Dictionary<string, int>(StringComparer.OrdinalIgnoreCase);
              var servers = new List<string>();
              var serverIndexes = new Dictionary<string, int>(StringComparer.Ordinal);

              foreach (var line in paths.SelectMany(path => File.ReadLines(path)))
              {
                  var entry = line.Trim();
                  if (entry.Length == 0 || entry.StartsWith("#", StringComparison.Ordinal)) continue;

                  var parts = entry.Split('/');
                  // Both the shape and the "server=" prefix must match; anything else is ignored.
                  if (parts.Length != 3 || parts[0] != "server=") continue;

                  var domain = parts[1];
                  var dns = parts[2];
                  if (domain.Length == 0 || dns.Length == 0) continue;

                  if (!serverIndexes.TryGetValue(dns, out var index))
                  {
                      index = servers.Count;
                      servers.Add(dns);
                      serverIndexes.Add(dns, index);
                  }
                  domains[domain] = index;
              }

              _chinaListDns = servers.ToArray();
              _chinaList = new ReadOnlyDictionary<string, int>(domains);
          }
      }
  }