diff --git a/KekUploadServer/Services/UploadService.cs b/KekUploadServer/Services/UploadService.cs index 218b4e2..b310573 100644 --- a/KekUploadServer/Services/UploadService.cs +++ b/KekUploadServer/Services/UploadService.cs @@ -174,6 +174,8 @@ public async Task HandleWebSocket(WebSocket webSocket) string? uploadStreamId = null; while (!receiveResult.CloseStatus.HasValue) { + if(webSocket.State != WebSocketState.Open) + break; receiveResult = await webSocket.ReceiveAsync( new ArraySegment(buffer), CancellationToken.None); @@ -183,6 +185,7 @@ public async Task HandleWebSocket(WebSocket webSocket) var info = Encoding.UTF8.GetString(buffer); const string uploadStreamIdPrefix = webSocketClientPrefix + "UploadStreamId: "; const string uploadTextDataPrefix = webSocketClientPrefix + "TextData: "; + const string finishUploadStreamPrefix = webSocketClientPrefix + "Finish: "; if (info.StartsWith(uploadStreamIdPrefix)) { uploadStreamId = info.Substring(uploadStreamIdPrefix.Length, 32); @@ -203,10 +206,38 @@ public async Task HandleWebSocket(WebSocket webSocket) } else { - _memoryCache.TryGetValue(uploadStreamId, out _); + if (!_memoryCache.TryGetValue(uploadStreamId, out _)) + { + await webSocket.SendAsync(new ArraySegment(Encoding.UTF8.GetBytes(webSocketServerPrefix + "UploadStream has expired, please create a new one!")), WebSocketMessageType.Text, true, CancellationToken.None); + await webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Expired UploadStream!", default); + return; + } var offset = Encoding.UTF8.GetByteCount(uploadTextDataPrefix); await UploadChunk(uploadItem, buffer, null, offset, receiveResult.Count - offset); } + }else if (info.StartsWith(finishUploadStreamPrefix)) + { + if (uploadStreamId == null || uploadItem == null) + { + await webSocket.SendAsync(new ArraySegment(Encoding.UTF8.GetBytes(webSocketServerPrefix + "No valid UploadStreamId specified!!!")), WebSocketMessageType.Text, true, CancellationToken.None); + } + else + { + var offset = Encoding.UTF8.GetByteCount(finishUploadStreamPrefix); + var hash = Encoding.UTF8.GetString(buffer, offset, receiveResult.Count - offset); + var finalHash = await FinalizeHash(uploadItem.Hasher); + if (hash != finalHash) + { + await webSocket.SendAsync(new ArraySegment(Encoding.UTF8.GetBytes(webSocketServerPrefix + "Invalid Hash!")), WebSocketMessageType.Text, true, CancellationToken.None); + await webSocket.CloseAsync(WebSocketCloseStatus.InvalidPayloadData, "Invalid Hash!", default); + return; + } + uploadItem.Hash = finalHash; + var result = await FinishUploadStream(uploadItem); + await webSocket.SendAsync(new ArraySegment(Encoding.UTF8.GetBytes(webSocketServerPrefix + "Id: " + result)), WebSocketMessageType.Text, true, CancellationToken.None); + await webSocket.CloseAsync(WebSocketCloseStatus.PolicyViolation, "Upload successful!", default); + return; + } } break; case WebSocketMessageType.Binary: @@ -216,7 +247,12 @@ public async Task HandleWebSocket(WebSocket webSocket) } else { - _memoryCache.TryGetValue(uploadStreamId, out _); + if (!_memoryCache.TryGetValue(uploadStreamId, out _)) + { + await webSocket.SendAsync(new ArraySegment(Encoding.UTF8.GetBytes(webSocketServerPrefix + "UploadStream has expired, please create a new one!")), WebSocketMessageType.Text, true, CancellationToken.None); + await webSocket.CloseAsync(WebSocketCloseStatus.PolicyViolation, "Expired UploadStream!", default); + return; + } await UploadChunk(uploadItem, buffer, null, 0, receiveResult.Count); } break; @@ -226,7 +262,7 @@ public async Task HandleWebSocket(WebSocket webSocket) } await webSocket.CloseAsync( - receiveResult.CloseStatus.Value, + receiveResult.CloseStatus ?? WebSocketCloseStatus.Empty, receiveResult.CloseStatusDescription, CancellationToken.None); }