Commit b0d8896d authored by Michael Bryant's avatar Michael Bryant

Initial commit

parents
package main
import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
"runtime"
"strings"
"github.com/BurntSushi/toml"
)
// ProgramName holds the name of the program for building config paths
const ProgramName = "shadowhosts"
// ConfigName holds the expected config file name for building paths
const ConfigName = "config.toml"
// Sep is a shorthand for the system filepath separator
const Sep = string(filepath.Separator)
const DefaultConfig = `# [DANGEROUS] Uncomment to allow redirection entries from remote sources
#allow_redirect = true
# Add additional sources of domains to block here
sources = [
"https://adaway.org/hosts.txt",
"https://hosts-file.net/ad_servers.txt",
"https://pgl.yoyo.org/adservers/serverlist.php?hostformat=hosts&showintro=0&mimetype=plaintext"
]
# Add additional domains to block here
blacklist = []
# Add domains to unblock from online sources here
whitelist = []
# Add redirection rules here, following the example
[redirect]
# "localhost" = "127.0.0.1"
`
// Test that file exists
func fileExists(file string) bool {
_, err := os.Stat(file)
return err == nil
}
func inUserConfigDir() string {
// Then check for configuration in the user's config directory
env := ""
if runtime.GOOS == "linux" || runtime.GOOS == "darwin" || strings.Contains(runtime.GOOS, "bsd") || runtime.GOOS == "dragonfly" {
env = os.Getenv("HOME") + Sep + ".config"
} else if runtime.GOOS == "windows" {
env = os.Getenv("APPDATA")
}
// If found in user configuration directory, return it
if env != "" && env != Sep+".config" {
file := env + Sep + ProgramName + Sep + ConfigName
return file
}
return ""
}
func findConfigFile() (string, error) {
// If configuration file specified, use it
if flags.Config != "" {
if !fileExists(flags.Config) {
return "", fmt.Errorf("Configuration file %s does not exist", flags.Config)
}
return flags.Config, nil
}
// Check for a portable config first
file := "." + Sep + ConfigName
if fileExists(file) {
return file, nil
}
// Check in user config directory
file = inUserConfigDir()
if file != "" && fileExists(file) {
return file, nil
}
// Else if a *nix system, check /etc
if runtime.GOOS == "linux" || runtime.GOOS == "darwin" || strings.Contains(runtime.GOOS, "bsd") || runtime.GOOS == "dragonfly" {
if fileExists("/etc") {
return "/etc" + Sep + ProgramName + Sep + ConfigName, nil
}
}
// If not found, error
return "", fmt.Errorf("could not find an existing configuration file. Use --makeconfig to generate one")
}
// GetHostsConfig returns a HostsConfig or an error if something failed
func GetHostsConfig() (HostsConfig, error) {
// Get configuration file path
config := NewHostsConfig()
configFile, err := findConfigFile()
if err != nil {
return config, err
}
// Parse configuration file
_, err = toml.DecodeFile(configFile, &config)
if err != nil {
return config, err
}
// Return configuration
return config, nil
}
// GenerateConfig creates a default configuration file at the given location
func GenerateConfig(out string) error {
parent := filepath.Dir(out)
err := os.MkdirAll(parent, 0755)
if err != nil {
return err
}
err = ioutil.WriteFile(out, ([]byte)(DefaultConfig), 0644)
return err
}
package main
import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"regexp"
"sort"
"strings"
)
// HostsConfig holds the configuration used to generate the hosts file
type HostsConfig struct {
AllowRedirect bool
Sources map[string]bool
Blacklist map[string]struct{}
Whitelist map[string]struct{}
Redirect map[string]string
}
// NewHostsConfig returns a new instance of a HostsConfig struct
// with maps already initialized.
func NewHostsConfig() HostsConfig {
var conf HostsConfig
conf.Sources = make(map[string]bool)
conf.Blacklist = make(map[string]struct{})
conf.Whitelist = make(map[string]struct{})
conf.Redirect = make(map[string]string)
return conf
}
var domainRegexp = regexp.MustCompile(`^([0-9A-Za-z]([0-9A-Za-z]|\-|\_)*\.)*([0-9A-Za-z]([0-9A-Za-z]|\-|\_)*)$`)
var ipRegexp = regexp.MustCompile(`(2([0-4][0-9]|5[1-5])|(1)?[0-9]{1,2})\.(2([0-4][0-9]|5[1-5])|(1)?[0-9]{1,2}){3}`)
var multiSpaceRegexp = regexp.MustCompile(`(\s){2,}`)
var commentRegexp = regexp.MustCompile(`#.*`)
const fileHeader = `###########################################################################################
## This file was automatically generated by shadowhosts by Shadow53. Do not make changes ##
## to this file directly. Instead, modify the shadowhosts configuration file and re-run ##
## shadowhosts. ##
###########################################################################################
127.0.0.1 localhost localhost.localdomain
`
func invalidDomain(domain string) bool {
return !domainRegexp.MatchString(domain) || domain == "localhost" ||
domain == "localhost.localdomain"
}
// AddBlacklist adds a domain to the blacklist.
// Returns an error if it is not a valid domain
func (h *HostsConfig) AddBlacklist(domain string) error {
if invalidDomain(domain) {
return fmt.Errorf("%s is not a valid domain name", domain)
}
if _, ok := h.Whitelist[domain]; ok {
// TODO: What to return here?
return nil
}
// Add to set
h.Blacklist[domain] = struct{}{}
return nil
}
// AddWhitelist removes a domain from the blacklist, if found.
func (h *HostsConfig) AddWhitelist(domain string) error {
// No need to validate here, as it must be a valid domain
// to be blacklisted. Whitelisting something not blacklisted
// does nothing
h.Whitelist[domain] = struct{}{}
// Only remove from the set if it exists
if _, ok := h.Blacklist[domain]; ok {
delete(h.Blacklist, domain)
}
return nil
}
// AddSource adds a URL source to download a hosts list from
func (h *HostsConfig) AddSource(src string) error {
_, err := url.ParseRequestURI(src)
if err != nil {
return err
}
// Set to false to add to the set and mark unparsed
h.Sources[src] = false
return nil
}
// AddRedirect adds a redirection rule from `domain` to `ip`
func (h *HostsConfig) AddRedirect(domain, ip string) error {
if invalidDomain(domain) {
return fmt.Errorf("%s is not a valid domain name", domain)
}
if ip != "127.0.0.1" && !ipRegexp.MatchString(ip) {
return fmt.Errorf("%s is not a valid ip address", ip)
}
// Add to whitelist so redirection takes precedence
h.AddWhitelist(domain)
h.Redirect[domain] = ip
return nil
}
// DownloadSources downloads non-downloaded sources and
// adds the entries to the HostsConfig
func (h *HostsConfig) DownloadSources() error {
for url := range h.Sources {
if !h.Sources[url] {
// Download file
resp, err := http.Get(url)
if err != nil {
return err
}
defer resp.Body.Close()
// Error if not success
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return fmt.Errorf("Error while connecting to %v:\n Received non-ok status code %v", url, resp.StatusCode)
}
// Read contents of Body
content, err := ioutil.ReadAll(resp.Body)
if err != nil {
return err
}
// Parse Body contents as string and split into lines
entries := strings.Split(string(content[:]), "\n")
for _, entry := range entries {
//fmt.Println("Original entry: " + entry)
// Remove comments
entry = commentRegexp.ReplaceAllString(entry, "")
//fmt.Println("Without comments: " + entry)
// Trim leading and trailing spaces
entry = strings.TrimSpace(entry)
//fmt.Println("Trimmed spaces: " + entry)
// Ignore now-empty lines
if entry != "" {
// Collapse all whitespace into a single space character
entry = multiSpaceRegexp.ReplaceAllString(entry, " ")
strs := strings.Split(entry, " ")
ip := strs[0]
for _, domain := range strs[1:] {
if domain != "localhost" && domain != "localhost.localdomain" {
if ip == "0.0.0.0" || ip == "127.0.0.1" {
err = h.AddBlacklist(domain)
if err != nil {
return err
}
} else if h.AllowRedirect {
err = h.AddRedirect(domain, ip)
if err != nil {
return nil
}
}
}
}
}
}
h.Sources[url] = true
}
}
return nil
}
// GenerateHosts generates an alphabetically sorted list of host entries as a string
func (h HostsConfig) GenerateHosts() []byte {
// Create sorted list of all domains
var domains sort.StringSlice = make([]string, len(h.Blacklist)+len(h.Redirect))
for domain := range h.Blacklist {
domains = append(domains, domain)
}
for domain := range h.Redirect {
domains = append(domains, domain)
}
domains.Sort()
var buf bytes.Buffer
buf.WriteString(fileHeader)
for _, domain := range domains {
if _, ok := h.Blacklist[domain]; ok {
buf.WriteString("0.0.0.0 " + domain + "\n")
} else if ip, ok := h.Redirect[domain]; ok {
buf.WriteString(ip + " " + domain + "\n")
}
}
return buf.Bytes()
}
// UnmarshalTOML is used by the toml package to parse and set values in the HostsConfig
func (h *HostsConfig) UnmarshalTOML(data interface{}) error {
if data == nil {
return nil
}
var err error
d := data.(map[string]interface{})
if d["sources"] != nil {
srcs, ok := d["sources"].([]interface{})
if !ok {
return fmt.Errorf("Could not parse sources as an array. Received %v", d["sources"])
}
for _, src := range srcs {
srcStr, ok := src.(string)
if !ok {
return fmt.Errorf("Could not parse source URL as a string. Received %s", src)
}
err = h.AddSource(srcStr)
if err != nil {
return nil
}
}
}
if d["whitelist"] != nil {
wlist, ok := d["whitelist"].([]interface{})
if !ok {
return fmt.Errorf("Could not parse whitelist as an array. Received %v", d["whitelist"])
}
for _, domain := range wlist {
domainStr, ok := domain.(string)
if !ok {
return fmt.Errorf("Could not parse domain as a string. Received %s", domain)
}
err = h.AddWhitelist(domainStr)
if err != nil {
return err
}
}
}
if d["blacklist"] != nil {
blist, ok := d["blacklist"].([]interface{})
if !ok {
return fmt.Errorf("Could not parse blacklist as an array. Received %v", d["blacklist"])
}
for _, domain := range blist {
domainStr, ok := domain.(string)
if !ok {
return fmt.Errorf("Could not parse domain as a string. Received %s", domain)
}
err = h.AddBlacklist(domainStr)
if err != nil {
return err
}
}
}
if d["redirect"] != nil {
redirects, ok := d["redirect"].(map[string]interface{})
if !ok {
return fmt.Errorf("Could not parse redirect as a domain-ip mapping. Received %v", d["redirect"])
}
for domain, ip := range redirects {
ipString, ok := ip.(string)
if !ok {
return fmt.Errorf("Could not parse redirect IP address as a string. Received %v", ip)
}
err = h.AddRedirect(domain, ipString)
if err != nil {
return err
}
}
}
if d["allow_redirect"] != nil {
allowRedir, ok := d["allow_redirect"].(bool)
if !ok {
return fmt.Errorf("Could not parse \"allow_redirect\" as a boolean. Received %v", d["allow_redirect"])
}
h.AllowRedirect = allowRedir
}
return nil
}
package main
import (
"flag"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"runtime"
)
var flags struct {
Config string
Output string
MkDir bool
GenConfig bool
}
var flagsAreInit = false
func initFlags() {
if !flag.Parsed() {
flag.StringVar(&flags.Config, "config", "", "Path to a configuration file for shadowhosts to use")
flag.StringVar(&flags.Output, "out", "", "Path to file to output hosts file or default configuration to. File will be truncated if it exists")
flag.BoolVar(&flags.MkDir, "mkdir", false, "Set to true to make any missing parent directories for the file specified in --out")
flag.BoolVar(&flags.GenConfig, "genconfig", false, "Generate a default configuration and exit")
flag.Parse()
}
}
func getHostsFile() string {
// Windows is special
if runtime.GOOS == "windows" {
return os.Getenv("SystemRoot") + Sep + "System32" + Sep + "drivers" + Sep + "etc" + Sep + "hosts"
}
// Everyone else makes it available here
return "/etc/hosts"
}
func main() {
initFlags()
if flags.GenConfig {
userConfig := inUserConfigDir()
if flags.Output != "" {
GenerateConfig(flags.Output)
} else if userConfig != "" {
GenerateConfig(userConfig)
} else if !fileExists("." + Sep + ConfigName) {
GenerateConfig("." + Sep + ConfigName)
} else {
fmt.Println("Could not determine where to put the configuration file. Try again using the --out flag.")
os.Exit(1)
}
os.Exit(0)
}
config, err := GetHostsConfig()
if err != nil {
fmt.Println(err)
os.Exit(1)
}
// Figure out which file to output to
var out string
if flags.Output != "" {
out = flags.Output
} else {
out = getHostsFile()
}
// Deal with parent directories
parent := filepath.Dir(out)
if !fileExists(parent) {
if flags.MkDir {
os.MkdirAll(parent, 0755)
} else {
fmt.Printf("Parent folders of %s don't exist. Create them yourself or pass the --mkdir flag to shadowhosts.\n", out)
os.Exit(1)
}
}
err = config.DownloadSources()
if err != nil {
fmt.Println(err)
os.Exit(1)
}
err = ioutil.WriteFile(out, config.GenerateHosts(), 0644)
if err != nil {
if os.IsPermission(err) {
fmt.Printf("Could not write to file %s - you may need to run this program with admin privileges\n", out)
os.Exit(1)
} else {
fmt.Printf("Could not write to file %s: %s\n", out, err)
os.Exit(1)
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment