diff --git a/docker-compose.yml b/docker-compose.yml index 445651b..98f7640 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -6,6 +6,7 @@ services: ports: [ "8430:80" ] environment: - SECRET=${SECRET} + - MAX_AGE_MS=${MAX_AGE_MS} restart: unless-stopped redis: hostname: sus-redis diff --git a/main.go b/main.go index 15ec3ad..24d88b9 100644 --- a/main.go +++ b/main.go @@ -24,6 +24,8 @@ import ( "log" "net/http" "os" + "strconv" + "time" "github.com/go-redis/redis/v8" "github.com/gorilla/mux" @@ -38,6 +40,7 @@ var client = redis.NewClient(&redis.Options{ var SECRET string var INDEX_GET_REDIRECT = "http://alv.cx" +var MAX_AGE_MS int64 = 500 func main() { r := mux.NewRouter() @@ -59,6 +62,14 @@ func main() { INDEX_GET_REDIRECT = p } + if p, ok := os.LookupEnv("MAX_AGE_MS"); ok { + if v, err := strconv.ParseInt(p, 10, 64); err != nil { + fmt.Printf("Unable to parse environment variable MAX_AGE_MS: %v\n", p) + } else { + MAX_AGE_MS = v + } + } + log.Fatal(http.ListenAndServe(listenAddress, r)) } @@ -74,16 +85,29 @@ func indexHandler(w http.ResponseWriter, r *http.Request) { command := r.PostForm.Get("Command") shortlink := r.PostForm.Get("Shortlink") value := r.PostForm.Get("Value") - fmt.Printf("command: %v, shortlink: %v, value: %v\n", command, shortlink, value) - fmt.Println(shortlink) - fmt.Println(value) + req_timestamp := r.PostForm.Get("Timestamp") + req_timestamp_int, err := strconv.ParseInt(req_timestamp, 10, 64) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("Bad request")) + return + } + + cur_timestamp := time.Now().UnixNano() + if req_timestamp_int+MAX_AGE_MS*1000*1000 < cur_timestamp { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("Bad request")) + return + } + + fmt.Printf("req_timestamp: %v, command: %v, shortlink: %v, value: %v\n", req_timestamp, command, shortlink, value) signature := r.Header.Get("Signature") calculatedSignature := fmt.Sprintf( "SUS-SIGNATURE-%v", getSha256HMACSignature( []byte(SECRET), - command+":"+shortlink+":"+value, + req_timestamp+":"+command+":"+shortlink+":"+value, ), ) diff --git a/readme.md b/readme.md index 0189b28..3d0a654 100644 --- a/readme.md +++ b/readme.md @@ -39,6 +39,7 @@ flag is not provided. - `LISTEN_ADDRESS`---the address the server is listening on (default is `0.0.0.0:80`) - `INDEX_GET_REDIRECT`---the URL the user should be redirected to if they try to access `/` on the server (default is `http://alv.cx`) +- `MAX_AGE_MS`---how old a request can be (in milliseconds) before the server will refuse to process it. (default is 500 milliseconds) ### setting up susmng diff --git a/susmng.py b/susmng.py index 259e905..9f6aa18 100755 --- a/susmng.py +++ b/susmng.py @@ -8,6 +8,7 @@ import pathlib import os import json import sys +import time def get_args(): @@ -51,20 +52,34 @@ def main(args): if args.command == "delete" and args.value != "confirm": print("--value not set to 'confirm'... delete operation may fail") - r = requests.post(f"{'http' if args.http else 'https'}://{server}", - data = { - 'Command': args.command, - 'Shortlink': args.shortlink, - 'Value': args.value, - }, - headers = { - 'Signature': 'SUS-SIGNATURE-' + hmac.new( - secret.encode("UTF-8"), - (args.command+":"+args.shortlink+":"+args.value).encode("UTF-8"), - hashlib.sha256 - ).hexdigest() - } - ) + # accoring to python documentation (https://docs.python.org/3/library/time.html#time.time) + # this function does not explicitly have to use unix time, and implementation is dependent + # platform. + # most platforms (windows, unix) will probably give unix time though. + # + # the server side (main.go file) does explicitly use unix time (time.Now().UnixNano()) to get + # this number, but hopefully there should be no issues on most platforms. + timestamp = str(time.time_ns()) + + data = { + 'Command': args.command, + 'Shortlink': args.shortlink, + 'Value': args.value, + 'Timestamp': timestamp, + } + + headers = { + 'Signature': 'SUS-SIGNATURE-' + hmac.new( + secret.encode("UTF-8"), + (timestamp + ":" + args.command + ":" + args.shortlink + ":" + args.value).encode("UTF-8"), + hashlib.sha256 + ).hexdigest() + } + + print(f"{data=}") + print(f"{headers=}") + + r = requests.post(f"{'http' if args.http else 'https'}://{server}", data=data, headers=headers) print(r, file=sys.stderr) print(r.content.decode().strip()) return 0