Browse Source

Rev: async, split console output

HOME 3 years ago
parent
commit
4dcf9a4d87
1 changed files with 72 additions and 11 deletions
  1. 72 11
      DnsForwarder/Program.cs

+ 72 - 11
DnsForwarder/Program.cs

@@ -1,19 +1,18 @@
 using System;
+using System.Collections.Concurrent;
 using System.Collections.Generic;
 using System.IO;
 using System.Linq;
 using System.Net;
 using System.Net.Sockets;
-using System.Runtime.InteropServices;
 using System.Text;
 
 namespace DnsForwarder
 {
-    internal class Program
+    internal class Program : ProgramRev1
     {
-        private static string DefaultServer;
-        private static string ChinaServer;
-        private static Dictionary<string, string> ChinaList;
+        private static Socket Listener;
+        private static BlockingCollection<string> ConsoleOutout;
 
         private static void Main(string[] args)
         {
@@ -31,6 +30,68 @@ namespace DnsForwarder
 
             Console.WriteLine("Loading list file...");
             ChinaList = LoadListFile(args[2]).ToDictionary(p => p[0], p => p[1]);
+            Console.WriteLine("OK. Listening...");
+
+            Listener = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
+            Listener.Bind(new IPEndPoint(IPAddress.Any, 53));
+
+            ConsoleOutout = new BlockingCollection<string>();
+
+            StartHandleRequest();
+
+            while (true)
+            {
+                Console.WriteLine(ConsoleOutout.Take());
+            }
+        }
+
+        private static void StartHandleRequest()
+        {
+            var buf = new byte[1500];
+            EndPoint from = new IPEndPoint(IPAddress.Any, 0);
+            Listener.BeginReceiveFrom(buf, 0, buf.Length, SocketFlags.None, ref from, Callback, buf);
+        }
+
+        private static void Callback(IAsyncResult ar)
+        {
+            EndPoint from = new IPEndPoint(IPAddress.Any, 0);
+            var count = Listener.EndReceiveFrom(ar, ref from);
+            var buf = (byte[])ar.AsyncState;
+
+            StartHandleRequest();
+
+            var domain = ExtractDomainName(buf, count);
+
+            var target = MatchServer(domain);
+
+            ConsoleOutout.Add($"{DateTime.Now:yyyyMMdd HH:mm:ss} {from} [{target}]\t{domain}");
+
+            Listener.SendTo(GetDnsResponse(buf, count, target), from);
+        }
+    }
+
+    internal class ProgramRev1
+    {
+        protected static string DefaultServer;
+        protected static string ChinaServer;
+        protected static Dictionary<string, string> ChinaList;
+
+        private static void MainRev1(string[] args)
+        {
+            if (args.Length != 3)
+            {
+                Console.WriteLine("<default dns server ip> <china dns server ip> <path to dnsmasq-china-list file>");
+                Environment.Exit(-1);
+                return;
+            }
+
+            Console.WriteLine("Starting...");
+            Console.WriteLine($"Default Server:{DefaultServer = args[0]}");
+            Console.WriteLine($"China DNS Server:{ChinaServer = args[1]}");
+            Console.WriteLine($"dnsmasq-china-list file:{args[2]}");
+
+            Console.WriteLine("Loading list file...");
+            ChinaList = LoadListFile(args[2]).ToDictionary(p => p[0], p => p[1]);
 
             Console.WriteLine("OK. Listing...");
 
@@ -52,7 +113,7 @@ namespace DnsForwarder
             }
         }
 
-        private static string ExtractDomainName(byte[] buf, int count)
+        protected static string ExtractDomainName(byte[] buf, int count)
         {
             var lst = new List<byte[]>();
 
@@ -77,16 +138,16 @@ namespace DnsForwarder
             return Encoding.ASCII.GetString(bufDomain, 1, bufDomain.Length - 1);
         }
 
-        private static byte[] GetDnsResponse(byte[] buf, int count, string host, int port = 53)
+        protected static byte[] GetDnsResponse(byte[] buf, int count, string host, int port = 53)
         {
             using var to = new UdpClient();
             to.Connect(host, 53);
             to.Send(buf, count);
             IPEndPoint outEp = null;
-            return to.Receive(ref outEp);
+            return to.Receive(ref outEp); //TODO: Handle Upstream TimeOut
         }
 
-        private static string MatchServer(string domain)
+        protected static string MatchServer(string domain)
         {
             var lower = domain.ToLower();
             if (lower.EndsWith(".cn")) return ChinaServer;
@@ -98,11 +159,11 @@ namespace DnsForwarder
                 var d = string.Join(".", parts.Take(i + 1).Reverse());
                 if (ChinaList.TryGetValue(d, out var tar)) return tar;
             }
-            
+
             return DefaultServer;
         }
 
-        private static string[][] LoadListFile(string path)
+        protected static string[][] LoadListFile(string path)
         {
             var lines = File.ReadAllLines(path);
             var items = lines.Select(p =>