diff --git a/Driver.cpp b/Driver.cpp index 0cd4a29..a10e28c 100644 --- a/Driver.cpp +++ b/Driver.cpp @@ -897,3 +897,49 @@ OvpnFindPeerVPN6(POVPN_DEVICE device, IN6_ADDR addr) OvpnPeerContext** ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->PeersByVpn6, &pp); return ptr ? (OvpnPeerContext*)*ptr : NULL; } + +VOID +OvpnDeletePeerFromTable(RTL_GENERIC_TABLE *table, OvpnPeerContext *peer, char* tableName) +{ + auto peerId = peer->PeerId; + auto pp = &peer; + + if (RtlDeleteElementGenericTable(table, pp)) { + LOG_INFO("Peer deleted", TraceLoggingValue(tableName, "table"), TraceLoggingValue(peerId, "peer-id")); + + if (InterlockedDecrement(&peer->RefCounter) == 0) { + OvpnPeerCtxFree(peer); + LOG_INFO("Peer freed", TraceLoggingValue(peerId, "peer-id")); + } + } + else { + LOG_INFO("Peer not found", TraceLoggingValue(tableName, "table"), TraceLoggingValue(peerId, "peer-id")); + } +} + +NTSTATUS +OvpnDeletePeer(POVPN_DEVICE device, INT32 peerId) +{ + NTSTATUS status = STATUS_SUCCESS; + + KIRQL kirql = ExAcquireSpinLockExclusive(&device->SpinLock); + + LOG_INFO("Deleting peer", TraceLoggingValue(peerId, "peer-id")); + + // get peer from main table + OvpnPeerContext* peerCtx = OvpnFindPeer(device, peerId); + if (peerCtx == NULL) { + status = STATUS_NOT_FOUND; + LOG_WARN("Peer not found", TraceLoggingValue(peerId, "peer-id")); + } + else { + OvpnDeletePeerFromTable(&device->PeersByVpn4, peerCtx, "vpn4"); + OvpnDeletePeerFromTable(&device->PeersByVpn6, peerCtx, "vpn6"); + + OvpnDeletePeerFromTable(&device->Peers, peerCtx, "peers"); + } + + ExReleaseSpinLockExclusive(&device->SpinLock, kirql); + + return status; +} diff --git a/Driver.h b/Driver.h index f9f2dc1..3ac667e 100644 --- a/Driver.h +++ b/Driver.h @@ -150,3 +150,7 @@ OvpnFindPeerVPN4(_In_ POVPN_DEVICE device, _In_ IN_ADDR addr); _Must_inspect_result_ OvpnPeerContext* OvpnFindPeerVPN6(_In_ POVPN_DEVICE device, _In_ IN6_ADDR addr); + +_Must_inspect_result_ +NTSTATUS +OvpnDeletePeer(_In_ POVPN_DEVICE device, INT32 peerId); diff --git a/timer.cpp b/timer.cpp index 2d93495..29034a0 100644 --- a/timer.cpp +++ b/timer.cpp @@ -117,7 +117,7 @@ static BOOLEAN OvpnTimerRecv(WDFTIMER timer) if (device->Mode == OVPN_MODE_P2P) { status = WdfIoQueueRetrieveNextRequest(device->PendingReadsQueue, &request); if (!NT_SUCCESS(status)) { - LOG_WARN("No pending request for keepalive timeout notification"); + LOG_INFO("No pending request for keepalive timeout notification"); return FALSE; } @@ -125,12 +125,15 @@ static BOOLEAN OvpnTimerRecv(WDFTIMER timer) WdfRequestCompleteWithInformation(request, STATUS_CONNECTION_DISCONNECTED, bytesSent); } else { + (VOID)OvpnDeletePeer(device, peerId); + status = WdfIoQueueRetrieveNextRequest(device->PendingNotificationRequestsQueue, &request); if (!NT_SUCCESS(status)) { - LOG_WARN("Adding keepalive timeout notification to the queue"); + LOG_INFO("Adding keepalive timeout notification to the queue"); return NT_SUCCESS(device->PendingNotificationsQueue.AddEvent(OVPN_NOTIFY_DEL_PEER, peerId, OVPN_DEL_PEER_REASON_EXPIRED)); } else { + LOG_INFO("Notify userspace about expired peer"); OVPN_NOTIFY_EVENT *evt; ULONG_PTR bytesSent = 0; LOG_IF_NOT_NT_SUCCESS(status = WdfRequestRetrieveOutputBuffer(request, sizeof(OVPN_NOTIFY_EVENT), (PVOID*)&evt, nullptr)); @@ -141,8 +144,6 @@ static BOOLEAN OvpnTimerRecv(WDFTIMER timer) bytesSent = sizeof(OVPN_NOTIFY_EVENT); } WdfRequestCompleteWithInformation(request, status, bytesSent); - - // TODO: remove peer } }