Skip to content

Commit

Permalink
feat: revamp config (#230)
Browse files Browse the repository at this point in the history
  • Loading branch information
fortuna authored May 7, 2024
1 parent 4cdcf58 commit 2e2d2af
Show file tree
Hide file tree
Showing 22 changed files with 419 additions and 203 deletions.
279 changes: 146 additions & 133 deletions x/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,187 +22,194 @@ import (
"strings"

"github.com/Jigsaw-Code/outline-sdk/transport"
"github.com/Jigsaw-Code/outline-sdk/transport/split"
"github.com/Jigsaw-Code/outline-sdk/transport/tlsfrag"
)

// ConfigParser enables the creation of stream and packet dialers based on a config. The config is
// ConfigToDialer enables the creation of stream and packet dialers based on a config. The config is
// extensible by registering wrappers for config subtypes.
type ConfigParser struct {
sdWrapers map[string]WrapStreamDialerFunc
pdWrappers map[string]WrapPacketDialerFunc
type ConfigToDialer struct {
// Base StreamDialer to create direct stream connections. If you need direct stream connections, this must not be nil.
BaseStreamDialer transport.StreamDialer
// Base PacketDialer to create direct packet connections. If you need direct packet connections, this must not be nil.
BasePacketDialer transport.PacketDialer
sdBuilders map[string]NewStreamDialerFunc
pdBuilders map[string]NewPacketDialerFunc
}

// NewDefaultConfigParser creates a [ConfigParser] with a set of default wrappers already registered.
func NewDefaultConfigParser() *ConfigParser {
p := new(ConfigParser)
// NewStreamDialerFunc wraps a Dialer based on the wrapConfig. The innerSD and innerPD functions can provide a base Stream and Packet Dialers if needed.
type NewStreamDialerFunc func(innerSD func() (transport.StreamDialer, error), innerPD func() (transport.PacketDialer, error), wrapConfig *url.URL) (transport.StreamDialer, error)

// NewPacketDialerFunc wraps a Dialer based on the wrapConfig. The innerSD and innerPD functions can provide a base Stream and Packet Dialers if needed.
type NewPacketDialerFunc func(innerSD func() (transport.StreamDialer, error), innerPD func() (transport.PacketDialer, error), wrapConfig *url.URL) (transport.PacketDialer, error)

// NewDefaultConfigToDialer creates a [ConfigToDialer] with a set of default wrappers already registered.
func NewDefaultConfigToDialer() *ConfigToDialer {
p := new(ConfigToDialer)
p.BaseStreamDialer = &transport.TCPDialer{}
p.BasePacketDialer = &transport.UDPDialer{}

// Please keep the list in alphabetical order.
p.RegisterStreamDialerWrapper("doh", wrapStreamDialerWithDOH)
p.RegisterPacketDialerWrapper("doh", func(baseDialer transport.PacketDialer, wrapConfig *url.URL) (transport.PacketDialer, error) {
return nil, errors.New("doh is not supported for PacketDialers")
})
p.RegisterStreamDialerType("do53", wrapStreamDialerWithDO53)

p.RegisterStreamDialerWrapper("override", wrapStreamDialerWithOverride)
p.RegisterPacketDialerWrapper("override", wrapPacketDialerWithOverride)
p.RegisterStreamDialerType("doh", wrapStreamDialerWithDOH)

p.RegisterStreamDialerWrapper("socks5", wrapStreamDialerWithSOCKS5)
p.RegisterPacketDialerWrapper("socks5", func(baseDialer transport.PacketDialer, wrapConfig *url.URL) (transport.PacketDialer, error) {
return nil, errors.New("socks5 is not supported for PacketDialers")
})
p.RegisterStreamDialerType("override", wrapStreamDialerWithOverride)
p.RegisterPacketDialerType("override", wrapPacketDialerWithOverride)

p.RegisterStreamDialerWrapper("split", func(baseDialer transport.StreamDialer, wrapConfig *url.URL) (transport.StreamDialer, error) {
prefixBytesStr := wrapConfig.Opaque
prefixBytes, err := strconv.Atoi(prefixBytesStr)
if err != nil {
return nil, fmt.Errorf("prefixBytes is not a number: %v. Split config should be in split:<number> format", prefixBytesStr)
}
return split.NewStreamDialer(baseDialer, int64(prefixBytes))
})
p.RegisterPacketDialerWrapper("split", func(baseDialer transport.PacketDialer, wrapConfig *url.URL) (transport.PacketDialer, error) {
return nil, errors.New("split is not supported for PacketDialers")
})
p.RegisterStreamDialerType("socks5", wrapStreamDialerWithSOCKS5)

p.RegisterStreamDialerWrapper("ss", wrapStreamDialerWithShadowsocks)
p.RegisterPacketDialerWrapper("ss", wrapPacketDialerWithShadowsocks)
p.RegisterStreamDialerType("split", wrapStreamDialerWithSplit)

p.RegisterStreamDialerWrapper("tls", wrapStreamDialerWithTLS)
p.RegisterPacketDialerWrapper("tls", func(baseDialer transport.PacketDialer, wrapConfig *url.URL) (transport.PacketDialer, error) {
return nil, errors.New("tls is not supported for PacketDialers")
})
p.RegisterStreamDialerType("ss", wrapStreamDialerWithShadowsocks)
p.RegisterPacketDialerType("ss", wrapPacketDialerWithShadowsocks)

p.RegisterStreamDialerWrapper("tlsfrag", func(baseDialer transport.StreamDialer, wrapConfig *url.URL) (transport.StreamDialer, error) {
p.RegisterStreamDialerType("tls", wrapStreamDialerWithTLS)

p.RegisterStreamDialerType("tlsfrag", func(innerSD func() (transport.StreamDialer, error), innerPD func() (transport.PacketDialer, error), wrapConfig *url.URL) (transport.StreamDialer, error) {
sd, err := innerSD()
if err != nil {
return nil, err
}
lenStr := wrapConfig.Opaque
fixedLen, err := strconv.Atoi(lenStr)
if err != nil {
return nil, fmt.Errorf("invalid tlsfrag option: %v. It should be in tlsfrag:<number> format", lenStr)
}
return tlsfrag.NewFixedLenStreamDialer(baseDialer, fixedLen)
})
p.RegisterPacketDialerWrapper("tlsfrag", func(baseDialer transport.PacketDialer, wrapConfig *url.URL) (transport.PacketDialer, error) {
return nil, errors.New("tlsfrag is not supported for PacketDialers")
return tlsfrag.NewFixedLenStreamDialer(sd, fixedLen)
})

return p
}

// WrapStreamDialerFunc wraps a [transport.StreamDialer] based on the wrapConfig.
type WrapStreamDialerFunc func(dialer transport.StreamDialer, wrapConfig *url.URL) (transport.StreamDialer, error)

// RegisterStreamDialerWrapper will register a wrapper for stream dialers under the given subtype.
func (p *ConfigParser) RegisterStreamDialerWrapper(subtype string, wrapper WrapStreamDialerFunc) error {
if p.sdWrapers == nil {
p.sdWrapers = make(map[string]WrapStreamDialerFunc)
// RegisterStreamDialerType will register a wrapper for stream dialers under the given subtype.
func (p *ConfigToDialer) RegisterStreamDialerType(subtype string, newDialer NewStreamDialerFunc) error {
if p.sdBuilders == nil {
p.sdBuilders = make(map[string]NewStreamDialerFunc)
}

if _, found := p.sdWrapers[subtype]; found {
if _, found := p.sdBuilders[subtype]; found {
return fmt.Errorf("config parser %v for StreamDialer added twice", subtype)
}
p.sdWrapers[subtype] = wrapper
p.sdBuilders[subtype] = newDialer
return nil
}

// WrapPacketDialerFunc wraps a [transport.PacketDialer] based on the wrapConfig.
type WrapPacketDialerFunc func(dialer transport.PacketDialer, wrapConfig *url.URL) (transport.PacketDialer, error)

// RegisterPacketDialerWrapper will register a wrapper for packet dialers under the given subtype.
func (p *ConfigParser) RegisterPacketDialerWrapper(subtype string, wrapper WrapPacketDialerFunc) error {
if p.pdWrappers == nil {
p.pdWrappers = make(map[string]WrapPacketDialerFunc)
// RegisterPacketDialerType will register a wrapper for packet dialers under the given subtype.
func (p *ConfigToDialer) RegisterPacketDialerType(subtype string, newDialer NewPacketDialerFunc) error {
if p.pdBuilders == nil {
p.pdBuilders = make(map[string]NewPacketDialerFunc)
}

if _, found := p.pdWrappers[subtype]; found {
return fmt.Errorf("config parser %v for PacketDialer added twice", subtype)
if _, found := p.pdBuilders[subtype]; found {
return fmt.Errorf("config parser %v for StreamDialer added twice", subtype)
}
p.pdWrappers[subtype] = wrapper
p.pdBuilders[subtype] = newDialer
return nil
}

func parseConfigPart(oneDialerConfig string) (*url.URL, error) {
oneDialerConfig = strings.TrimSpace(oneDialerConfig)
if oneDialerConfig == "" {
return nil, errors.New("empty config part")
func parseConfig(configText string) ([]*url.URL, error) {
parts := strings.Split(strings.TrimSpace(configText), "|")
if len(parts) == 1 && parts[0] == "" {
return []*url.URL{}, nil
}
// Make it "<scheme>:" if it's only "<scheme>" to parse as a URL.
if !strings.Contains(oneDialerConfig, ":") {
oneDialerConfig += ":"
urls := make([]*url.URL, 0, len(parts))
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
return nil, errors.New("empty config part")
}
// Make it "<scheme>:" if it's only "<scheme>" to parse as a URL.
if !strings.Contains(part, ":") {
part += ":"
}
url, err := url.Parse(part)
if err != nil {
return nil, fmt.Errorf("part is not a valid URL: %w", err)
}
urls = append(urls, url)
}
url, err := url.Parse(oneDialerConfig)
return urls, nil
}

// NewStreamDialer creates a [Dialer] according to transportConfig, using dialer as the
// base [Dialer]. The given dialer must not be nil.
func (p *ConfigToDialer) NewStreamDialer(transportConfig string) (transport.StreamDialer, error) {
parts, err := parseConfig(transportConfig)
if err != nil {
return nil, fmt.Errorf("part is not a valid URL: %w", err)
return nil, err
}
return url, nil
return p.newStreamDialer(parts)
}

// WrapStreamDialer creates a [transport.StreamDialer] according to transportConfig, using dialer as the
// base [transport.StreamDialer]. The given dialer must not be nil.
func (p *ConfigParser) WrapStreamDialer(dialer transport.StreamDialer, transportConfig string) (transport.StreamDialer, error) {
if dialer == nil {
return nil, errors.New("base dialer must not be nil")
}
transportConfig = strings.TrimSpace(transportConfig)
if transportConfig == "" {
return dialer, nil
// NewPacketDialer creates a [Dialer] according to transportConfig, using dialer as the
// base [Dialer]. The given dialer must not be nil.
func (p *ConfigToDialer) NewPacketDialer(transportConfig string) (transport.PacketDialer, error) {
parts, err := parseConfig(transportConfig)
if err != nil {
return nil, err
}
for _, part := range strings.Split(transportConfig, "|") {
url, err := parseConfigPart(part)
if err != nil {
return nil, err
}
w, ok := p.sdWrapers[url.Scheme]
if !ok {
return nil, fmt.Errorf("config scheme '%v' is not supported", url.Scheme)
}
dialer, err = w(dialer, url)
if err != nil {
return nil, err
return p.newPacketDialer(parts)
}

func (p *ConfigToDialer) newStreamDialer(configParts []*url.URL) (transport.StreamDialer, error) {
if len(configParts) == 0 {
if p.BaseStreamDialer == nil {
return nil, fmt.Errorf("base StreamDialer must not be nil")
}
return p.BaseStreamDialer, nil
}
thisURL := configParts[len(configParts)-1]
innerConfig := configParts[:len(configParts)-1]
newDialer, ok := p.sdBuilders[thisURL.Scheme]
if !ok {
return nil, fmt.Errorf("config scheme '%v' is not supported for Stream Dialers", thisURL.Scheme)
}
return dialer, nil
newSD := func() (transport.StreamDialer, error) {
return p.newStreamDialer(innerConfig)
}
newPD := func() (transport.PacketDialer, error) {
return p.newPacketDialer(innerConfig)
}
return newDialer(newSD, newPD, thisURL)
}

// WrapPacketDialer creates a [transport.PacketDialer] according to transportConfig, using dialer as the
// base [transport.PacketDialer]. The given dialer must not be nil.
func (p *ConfigParser) WrapPacketDialer(dialer transport.PacketDialer, transportConfig string) (transport.PacketDialer, error) {
if dialer == nil {
return nil, errors.New("base dialer must not be nil")
func (p *ConfigToDialer) newPacketDialer(configParts []*url.URL) (transport.PacketDialer, error) {
if len(configParts) == 0 {
if p.BasePacketDialer == nil {
return nil, fmt.Errorf("base PacketDialer must not be nil")
}
return p.BasePacketDialer, nil
}
transportConfig = strings.TrimSpace(transportConfig)
if transportConfig == "" {
return dialer, nil
thisURL := configParts[len(configParts)-1]
innerConfig := configParts[:len(configParts)-1]
newDialer, ok := p.pdBuilders[thisURL.Scheme]
if !ok {
return nil, fmt.Errorf("config scheme '%v' is not supported for Packet Dialers", thisURL.Scheme)
}
for _, part := range strings.Split(transportConfig, "|") {
url, err := parseConfigPart(part)
if err != nil {
return nil, err
}
w, ok := p.pdWrappers[url.Scheme]
if !ok {
return nil, fmt.Errorf("config scheme '%v' is not supported", url.Scheme)
}
dialer, err = w(dialer, url)
if err != nil {
return nil, err
}
newSD := func() (transport.StreamDialer, error) {
return p.newStreamDialer(innerConfig)
}
newPD := func() (transport.PacketDialer, error) {
return p.newPacketDialer(innerConfig)
}
return dialer, nil
return newDialer(newSD, newPD, thisURL)
}

// NewpacketListener creates a new [transport.PacketListener] according to the given config,
// the config must contain only one "ss://" segment.
// TODO: make NewPacketListener configurable.
func NewPacketListener(transportConfig string) (transport.PacketListener, error) {
if transportConfig = strings.TrimSpace(transportConfig); transportConfig == "" {
parts, err := parseConfig(transportConfig)
if err != nil {
return nil, err
}
if len(parts) == 0 {
return nil, errors.New("config is required")
}
if strings.Contains(transportConfig, "|") {
if len(parts) > 1 {
return nil, errors.New("multi-part config is not supported")
}

url, err := parseConfigPart(transportConfig)
if err != nil {
return nil, fmt.Errorf("failed to parse config: %w", err)
}
url := parts[0]
// Please keep scheme list sorted.
switch strings.ToLower(url.Scheme) {
case "ss":
Expand All @@ -214,34 +221,40 @@ func NewPacketListener(transportConfig string) (transport.PacketListener, error)
}

func SanitizeConfig(transportConfig string) (string, error) {
parts, err := parseConfig(transportConfig)
if err != nil {
return "", err
}

// Do nothing if the config is empty
if transportConfig == "" {
if len(parts) == 0 {
return "", nil
}
// Split the string into parts
parts := strings.Split(transportConfig, "|")

// Iterate through each part
for i, part := range parts {
u, err := parseConfigPart(part)
if err != nil {
return "", fmt.Errorf("failed to parse config part: %w", err)
}
textParts := make([]string, len(parts))
for i, u := range parts {
scheme := strings.ToLower(u.Scheme)
switch scheme {
case "ss":
parts[i], _ = sanitizeShadowsocksURL(u)
textParts[i], err = sanitizeShadowsocksURL(u)
if err != nil {
return "", err
}
case "socks5":
parts[i], _ = sanitizeSocks5URL(u)
textParts[i], err = sanitizeSocks5URL(u)
if err != nil {
return "", err
}
case "override", "split", "tls", "tlsfrag":
// No sanitization needed
parts[i] = u.String()
textParts[i] = u.String()
default:
parts[i] = scheme + "://UNKNOWN"
textParts[i] = scheme + "://UNKNOWN"
}
}
// Join the parts back into a string
return strings.Join(parts, "|"), nil
return strings.Join(textParts, "|"), nil
}

func sanitizeSocks5URL(u *url.URL) (string, error) {
Expand Down
Loading

0 comments on commit 2e2d2af

Please sign in to comment.