using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Text;

namespace DnsForwarder
{
    /// <summary>
    /// A minimal UDP DNS forwarder. A query is sent to the ".CN" upstream when its name ends in ".cn",
    /// to the upstream named by a dnsmasq-china-list entry when the name (or one of its parent domains)
    /// is listed there, and to the default upstream otherwise. The upstream reply is relayed to the client.
    /// </summary>
    internal class Program
    {
        // Upstream used when no other rule matches.
        private static string _defaultDns;

        // Upstream used for any name ending in ".cn".
        private static string _cnDns;

        // Listed domain -> index into _chinaListDns, built by LoadListFile.
        private static IReadOnlyDictionary<string, int> _chinaList;

        // Distinct upstream servers referenced by the list file(s).
        private static string[] _chinaListDns;

        // The UDP socket clients send their queries to.
        private static Socket _listener;

        // Log lines produced on receive callbacks; drained and printed by the main thread.
        private static BlockingCollection<string> _consoleOutput;

        // Size of each receive buffer; larger datagrams are not supported.
        private const int ReceiveBufferSize = 1500;

        // How long to wait for an upstream reply before giving up on a query.
        private const int UpstreamTimeoutMs = 2000;

        // Port used for listening and as the default upstream port.
        private const int DnsPort = 53;

        private static void Main(string[] args)
        {
            if (args.Length != 4)
            {
                Console.WriteLine("Usage: <listen ip> <default dns server ip> <.CN dns server ip> <dnsmasq-china-list file>");
                Environment.Exit(-1);
                return;
            }

            if (!IPAddress.TryParse(args[0], out var listenAddress))
            {
                Console.WriteLine("Invalid listen address.");
                Environment.Exit(-2);
                return;
            }

            Console.WriteLine("Starting...");
            Console.WriteLine($"Listen address: {args[0]}");
            Console.WriteLine($"Default Server: {_defaultDns = args[1]}");
            Console.WriteLine($".CN DNS Server: {_cnDns = args[2]}");
            Console.WriteLine($"dnsmasq-china-list: {args[3]}");
            Console.WriteLine("Loading list file...");
            try
            {
                LoadListFile(args[3]);
            }
            catch (Exception e) when (e is IOException || e is UnauthorizedAccessException)
            {
                // Report a missing/unreadable list file instead of dying with a raw stack trace.
                Console.WriteLine($"Failed to load list file: {e.Message}");
                Environment.Exit(-3);
                return;
            }
            Console.WriteLine("OK. Listening...");

            // Match the socket family to the listen address so IPv6 addresses work as well as IPv4.
            _listener = new Socket(listenAddress.AddressFamily, SocketType.Dgram, ProtocolType.Udp);
            _listener.Bind(new IPEndPoint(listenAddress, DnsPort));
            _consoleOutput = new BlockingCollection<string>();
            StartHandleRequest();

            // The main thread only prints; request handling happens on the receive callbacks.
            foreach (var line in _consoleOutput.GetConsumingEnumerable())
            {
                Console.WriteLine(line);
            }
        }

        /// <summary>
        /// Creates a wildcard endpoint in the listener's address family; ReceiveFrom rejects an
        /// endpoint whose family differs from the socket's.
        /// </summary>
        private static EndPoint CreateAnyEndPoint() =>
            new IPEndPoint(_listener.AddressFamily == AddressFamily.InterNetworkV6 ? IPAddress.IPv6Any : IPAddress.Any, 0);

        /// <summary>
        /// Posts one asynchronous receive on the listener, completing in <see cref="Callback"/>.
        /// Retries until a receive has been posted; gives up only if the listener was disposed.
        /// </summary>
        private static void StartHandleRequest()
        {
            while (true)
            {
                try
                {
                    var buf = new byte[ReceiveBufferSize];
                    var from = CreateAnyEndPoint();
                    // The buffer travels as the async state so the callback can read the datagram.
                    _listener.BeginReceiveFrom(buf, 0, buf.Length, SocketFlags.None, ref from, Callback, buf);
                    return;
                }
                catch (ObjectDisposedException)
                {
                    // The listener is gone; retrying would spin forever.
                    return;
                }
                catch (Exception e)
                {
                    Console.WriteLine(e);
                }
            }
        }

        /// <summary>
        /// Completes a receive, immediately posts the next one, then forwards the query upstream,
        /// relays the reply (if any) to the client and queues a log line. Exceptions are logged,
        /// never propagated; the next receive is always posted.
        /// </summary>
        private static void Callback(IAsyncResult ar)
        {
            var nextStarted = false;
            var from = CreateAnyEndPoint();
            string domain = null;
            string target = null;
            try
            {
                var count = _listener.EndReceiveFrom(ar, ref from);
                // Post the next receive before the (blocking) upstream round trip so queries overlap.
                StartHandleRequest();
                nextStarted = true;

                var buf = (byte[])ar.AsyncState;
                domain = ExtractDomainName(buf);
                target = MatchServer(domain);
                var dnsResponse = GetDnsResponse(buf, count, target);
                if (dnsResponse != null)
                {
                    _listener.SendTo(dnsResponse, from);
                    _consoleOutput.Add($"{DateTime.Now:yyyyMMdd HH:mm:ss} {from} [{target}]\t{domain}");
                }
                else
                {
                    // No reply within the timeout: nothing is sent, the client will have to retry.
                    _consoleOutput.Add($"{DateTime.Now:yyyyMMdd HH:mm:ss} {from} [{target}]\t{domain} Err:Upstream timeout");
                }
            }
            catch (Exception e)
            {
                _consoleOutput.Add($"{DateTime.Now:yyyyMMdd HH:mm:ss} {from} [{target ?? "Unknown"}]\t{domain ?? "Unknown"} Err:{e.Message}");
            }
            finally
            {
                if (!nextStarted)
                    StartHandleRequest();
            }
        }

        /// <summary>
        /// Decodes the name of the first question in a DNS message (the question section starts right
        /// after the 12-byte header) into dotted form without a trailing dot; the root name yields ".".
        /// </summary>
        /// <param name="buf">The raw DNS message.</param>
        /// <returns>The dotted question name.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="buf"/> is null.</exception>
        /// <exception cref="FormatException">
        /// The name runs past the end of <paramref name="buf"/>, or a length byte exceeds 63
        /// (compression pointers and reserved label types are not supported).
        /// </exception>
        protected static string ExtractDomainName(byte[] buf)
        {
            const int headerLength = 12;
            const int maxLabelLength = 63;
            if (buf == null)
                throw new ArgumentNullException(nameof(buf));

            var name = new StringBuilder();
            var ptr = headerLength;
            while (true)
            {
                if (ptr >= buf.Length)
                    throw new FormatException("DNS question name is truncated.");

                int labelLength = buf[ptr];
                if (labelLength == 0)
                    break;
                if (labelLength > maxLabelLength)
                    throw new FormatException("Unsupported label type in DNS question name.");
                if (ptr + 1 + labelLength > buf.Length)
                    throw new FormatException("DNS question name is truncated.");

                if (name.Length > 0)
                    name.Append('.');
                name.Append(Encoding.ASCII.GetString(buf, ptr + 1, labelLength));
                ptr += labelLength + 1;
            }

            return name.Length == 0 ? "." : name.ToString();
        }

        /// <summary>
        /// Sends a raw DNS query to an upstream server over UDP and waits for its reply.
        /// </summary>
        /// <param name="buf">Buffer holding the query.</param>
        /// <param name="count">Number of leading bytes of <paramref name="buf"/> to send.</param>
        /// <param name="host">Upstream server: an IP literal or a host name.</param>
        /// <param name="port">Upstream port.</param>
        /// <returns>The reply datagram, or null if none arrived within the timeout.</returns>
        /// <exception cref="SocketException">Socket errors other than a receive timeout.</exception>
        protected static byte[] GetDnsResponse(byte[] buf, int count, string host, int port = DnsPort)
        {
            // For an IP-literal upstream, create the socket in its family so IPv6 upstreams work too.
            var isLiteral = IPAddress.TryParse(host, out var address);
            using var client = isLiteral ? new UdpClient(address.AddressFamily) : new UdpClient();
            client.Client.ReceiveTimeout = UpstreamTimeoutMs;
            if (isLiteral)
                client.Connect(address, port);
            else
                client.Connect(host, port);

            client.Send(buf, count);
            try
            {
                IPEndPoint remote = null;
                return client.Receive(ref remote);
            }
            catch (SocketException e) when (e.SocketErrorCode == SocketError.TimedOut)
            {
                return null;
            }
        }

        /// <summary>
        /// Chooses the upstream server for a query name: the .CN server for names ending in ".cn",
        /// otherwise the server of the most specific listed domain among the name and its parents,
        /// otherwise the default server. Matching is case-insensitive.
        /// </summary>
        /// <param name="domain">Dotted query name.</param>
        /// <returns>The upstream server to use.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="domain"/> is null.</exception>
        protected static string MatchServer(string domain)
        {
            if (domain == null)
                throw new ArgumentNullException(nameof(domain));

            // Invariant lowering avoids culture-specific mappings (e.g. the Turkish dotted/dotless i).
            var lower = domain.ToLowerInvariant();
            if (lower.EndsWith(".cn", StringComparison.Ordinal))
                return _cnDns;

            // Try the full name, then each parent: a.b.example.com, b.example.com, example.com, com.
            var suffix = lower;
            while (true)
            {
                if (_chinaList.TryGetValue(suffix, out var index))
                    return _chinaListDns[index];

                var dot = suffix.IndexOf('.');
                if (dot < 0)
                    break;
                suffix = suffix.Substring(dot + 1);
            }

            return _defaultDns;
        }

        /// <summary>
        /// Loads dnsmasq-style list files made of lines <c>server=/domain/upstream</c> into
        /// <see cref="_chinaList"/> and <see cref="_chinaListDns"/>. Blank lines, lines starting
        /// with '#', and lines not of that exact three-part form are skipped; when a domain appears
        /// more than once, the last entry wins. Domains are matched case-insensitively.
        /// </summary>
        /// <param name="paths">Files to read, in order.</param>
        /// <exception cref="IOException">A file could not be read.</exception>
        protected static void LoadListFile(params string[] paths)
        {
            var domainToServer = new Dictionary<string, int>(StringComparer.OrdinalIgnoreCase);
            var servers = new List<string>();
            var serverIndex = new Dictionary<string, int>(StringComparer.Ordinal);

            foreach (var line in paths.SelectMany(path => File.ReadLines(path)))
            {
                var entry = line.Trim();
                if (entry.Length == 0 || entry.StartsWith("#", StringComparison.Ordinal))
                    continue;

                var parts = entry.Split('/');
                // Skip unless BOTH the shape and the prefix are right; a bare "server=" used to crash here.
                if (parts.Length != 3 || parts[0] != "server=")
                    continue;

                var domain = parts[1];
                var dns = parts[2];
                if (domain.Length == 0)
                    continue;

                if (!serverIndex.TryGetValue(dns, out var index))
                {
                    index = servers.Count;
                    servers.Add(dns);
                    serverIndex.Add(dns, index);
                }
                domainToServer[domain] = index;
            }

            _chinaListDns = servers.ToArray();
            _chinaList = new ReadOnlyDictionary<string, int>(domainToServer);
        }
    }
}