Skip to content

Commit

Permalink
feat: lock 구현 (MZ-307) (#81)
Browse files Browse the repository at this point in the history
* feat: lockManager 구현

* feat: 투표하기 락 적용

* fix: 빈 객체 삭제하도록 변경

* fix: actor 생성 연산 수정

* fix: actor key 생성 방식 수정

* fix: key 생성 방식 변경 및 M-thread 테스트 추가

* fix: try lock 로직 변경

* chore: ktlint formatting

* fix: 동시성 테스트 쓰레드, 테스트 횟수 변경 및 로직 에러 코드 세분화
  • Loading branch information
wjdtkdgns authored Aug 19, 2024
1 parent 58ab443 commit e47b056
Show file tree
Hide file tree
Showing 8 changed files with 400 additions and 13 deletions.
14 changes: 14 additions & 0 deletions api/src/main/kotlin/com/oksusu/susu/api/common/lock/LockKey.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package com.oksusu.susu.api.common.lock

private enum class LockType {
VOTE,
;
}

class LockKey {
companion object {
fun getVoteKey(id: Long): String {
return "${LockType.VOTE}_$id"
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package com.oksusu.susu.api.common.lock

interface LockManager {
suspend fun <T> lock(key: String, block: suspend () -> T): T
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
package com.oksusu.susu.api.common.lock

import com.oksusu.susu.common.exception.ErrorCode
import com.oksusu.susu.common.exception.FailToExecuteException
import io.github.oshai.kotlinlogging.KotlinLogging
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.SendChannel
import kotlinx.coroutines.channels.actor
import org.springframework.stereotype.Component
import java.util.LinkedList
import java.util.concurrent.ConcurrentHashMap

private val logger = KotlinLogging.logger { }

private enum class LockReturn {
/**
* 락 실행
*/
PROCESS_LOCK,

/**
* 락 해제됨
*/
UNLOCK,

/**
* 등록된 채널 삭제
*/
DELETE_CHANNEL,

/**
* 락 큐가 안비었음
*/
NOT_EMPTY_QUEUE,

/**
* 락 큐가 비었음
*/
EMPTY_QUEUE,
;
}

private sealed class LockMsg {
/** 락 획득 시도 */
class TryLock(val channel: SendChannel<LockReturn>) : LockMsg()

/** 락 해제 */
class UnLock(val channel: SendChannel<LockReturn>) : LockMsg()

/** 등록된 채널 지우기 */
class DeleteChannel(val channel: SendChannel<LockReturn>) : LockMsg()

/** 큐 비었는지 확인 */
class CheckQueueEmpty(val channel: SendChannel<LockReturn>) : LockMsg()
}

@OptIn(ObsoleteCoroutinesApi::class)
private fun lockActor() = CoroutineScope(Dispatchers.IO).actor<LockMsg>(capacity = 1000) {
// queue 맨 앞 == 락 설정
val lockQueue = LinkedList<SendChannel<LockReturn>>()

for (msg in channel) {
when (msg) {
is LockMsg.TryLock -> {
// 큐에 채널 등록하기
lockQueue.offer(msg.channel)

// 만약 방금 등록한 채널이 큐의 맨 앞이라면 바로 실행
if (lockQueue.peek() == msg.channel) {
msg.channel.send(LockReturn.PROCESS_LOCK)
}
}

is LockMsg.UnLock -> {
// 현재 락을 획득한 채널을 큐에서 삭제
lockQueue.poll()

// 다음 락 획득 대상 notify하기
if (lockQueue.peek() != null) {
lockQueue.peek().send(LockReturn.PROCESS_LOCK)
}

// 락 해제 및 큐 삭제 완료 알리기
msg.channel.send(LockReturn.UNLOCK)
}

is LockMsg.DeleteChannel -> {
if (lockQueue.peek() == msg.channel) {
// 삭제하려는 채널이 큐의 맨 앞일 때, 큐에서 삭제하고 다음꺼 실행
lockQueue.poll()
if (lockQueue.peek() != null) {
lockQueue.peek().send(LockReturn.PROCESS_LOCK)
}
} else {
// 삭제하려는 채널이 큐의 맨 앞이 아닐 때, 큐에서만 삭제
lockQueue.remove(msg.channel)
}

// 삭제 완료 처리 알리기
msg.channel.send(LockReturn.DELETE_CHANNEL)
}

is LockMsg.CheckQueueEmpty -> {
if (lockQueue.peek() == null) {
msg.channel.send(LockReturn.EMPTY_QUEUE)
} else {
msg.channel.send(LockReturn.NOT_EMPTY_QUEUE)
}
}
}
}
}

@Component
class SuspendableLockManager : LockManager {
companion object {
private const val WAIT_TIME = 3000L
private const val LEASE_TIME = 3000L
}

private val actorMap = ConcurrentHashMap<String, SendChannel<LockMsg>>()

override suspend fun <T> lock(key: String, block: suspend () -> T): T {
// lock 관련 리턴 받을 채널
Channel<LockReturn>().run {
val channel = this

// 락 설정
val actor = tryLock(key, channel)

try {
// 로직 실행
return withTimeout(LEASE_TIME) {
block()
}
} catch (e: TimeoutCancellationException) {
// 락 보유 시간 에러 처리
throw FailToExecuteException(ErrorCode.LOCK_TIMEOUT_ERROR)
} catch (e: Exception) {
// 나머지 에러 처리
throw e
} finally {
// 락 해제
releaseLock(actor, channel)

// 큐가 빈 액터 삭제
deleteEmptyQueueActor(channel, key)

logger.info { actorMap }
}
}
}

private suspend fun tryLock(key: String, channel: Channel<LockReturn>): SendChannel<LockMsg> {
val actor = actorMap.compute(key) { _, value ->
val actor = value ?: lockActor()

runBlocking(Dispatchers.Unconfined) {
actor.send(LockMsg.TryLock(channel))
}

actor
} ?: throw FailToExecuteException(ErrorCode.FAIL_TO_GET_LOCK)

try {
withTimeout(WAIT_TIME) {
channel.receive()
}
} catch (e: TimeoutCancellationException) {
// 락 획득 시간 에러 처리
throw FailToExecuteException(ErrorCode.ACQUIRE_LOCK_TIMEOUT)
} catch (e: Exception) {
// 수신 채널 지우기
actor.send(LockMsg.DeleteChannel(channel))
channel.receive()

throw FailToExecuteException(ErrorCode.FAIL_TO_EXECUTE_LOCK)
}

return actor
}

private suspend fun releaseLock(actor: SendChannel<LockMsg>, channel: Channel<LockReturn>) {
actor.send(LockMsg.UnLock(channel))
channel.receive()
}

private suspend fun deleteEmptyQueueActor(channel: Channel<LockReturn>, key: String) {
actorMap.computeIfPresent(key) { _, value ->
val rtn = runBlocking(Dispatchers.Unconfined) {
value.send(LockMsg.CheckQueueEmpty(channel))
channel.receive()
}

if (rtn == LockReturn.EMPTY_QUEUE) {
null
} else {
value
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,16 @@ package com.oksusu.susu.api.post.application

import com.oksusu.susu.api.auth.model.AuthUser
import com.oksusu.susu.api.common.dto.SusuPageRequest
import com.oksusu.susu.api.common.lock.LockKey
import com.oksusu.susu.api.common.lock.LockManager
import com.oksusu.susu.api.count.application.CountService
import com.oksusu.susu.api.event.model.DeleteVoteCountEvent
import com.oksusu.susu.api.post.model.OnboardingVoteOptionCountModel
import com.oksusu.susu.api.post.model.VoteCountModel
import com.oksusu.susu.api.post.model.VoteOptionAndHistoryModel
import com.oksusu.susu.api.post.model.VoteOptionCountModel
import com.oksusu.susu.api.post.model.VoteOptionModel
import com.oksusu.susu.api.post.model.*
import com.oksusu.susu.api.post.model.request.CreateVoteHistoryRequest
import com.oksusu.susu.api.post.model.request.CreateVoteRequest
import com.oksusu.susu.api.post.model.request.OverwriteVoteHistoryRequest
import com.oksusu.susu.api.post.model.request.UpdateVoteRequest
import com.oksusu.susu.api.post.model.response.CreateAndUpdateVoteResponse
import com.oksusu.susu.api.post.model.response.OnboardingVoteResponse
import com.oksusu.susu.api.post.model.response.VoteAllInfoResponse
import com.oksusu.susu.api.post.model.response.VoteAndOptionsWithCountResponse
import com.oksusu.susu.api.post.model.response.VoteWithCountResponse
import com.oksusu.susu.api.post.model.response.*
import com.oksusu.susu.api.post.model.vo.SearchVoteRequest
import com.oksusu.susu.api.user.application.BlockService
import com.oksusu.susu.common.config.SusuConfig
Expand Down Expand Up @@ -56,6 +50,7 @@ class VoteFacade(
private val eventPublisher: ApplicationEventPublisher,
private val onboardingGetVoteConfig: SusuConfig.OnboardingGetVoteConfig,
private val voteValidateService: VoteValidateService,
private val lockManager: LockManager,
) {
private val logger = KotlinLogging.logger { }

Expand Down Expand Up @@ -167,9 +162,11 @@ class VoteFacade(
}

suspend fun vote(user: AuthUser, id: Long, request: CreateVoteHistoryRequest) {
when (request.isCancel) {
true -> cancelVote(user.uid, id, request.optionId)
false -> castVote(user.uid, id, request.optionId)
lockManager.lock(LockKey.getVoteKey(id)) {
when (request.isCancel) {
true -> cancelVote(user.uid, id, request.optionId)
false -> castVote(user.uid, id, request.optionId)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package com.oksusu.susu.api.common.lock

import com.oksusu.susu.api.testExtension.CONCURRENT_COUNT
import com.oksusu.susu.api.testExtension.THREAD_COUNT
import com.oksusu.susu.api.testExtension.executeConcurrency
import io.github.oshai.kotlinlogging.KotlinLogging
import io.kotest.core.spec.style.DescribeSpec
import io.kotest.matchers.equals.shouldBeEqual
import kotlinx.coroutines.*
import java.util.concurrent.CountDownLatch
import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicLong

class SuspendableLockManagerTest : DescribeSpec({
val logger = KotlinLogging.logger { }

val lockManager = SuspendableLockManager()
val countService1 = CountService()
val countService2 = CountService()
val countService3 = CountService()

beforeEach {
countService1.apply { this.counter = 0 }
countService2.apply { this.counter = 0 }
countService3.apply { this.counter = 0 }
}

describe("suspendable lock manager") {
context("락을 설정하면") {
it("여러 쓰레드를 생성해 동작했을 때, 카운트가 동작한 수만큼 증가해야한다.") {
val successCount = AtomicLong()

executeConcurrency(successCount) {
lockManager.lock("1") {
countService1.increase()
logger.info { "1 ${countService1.counter}" }
}
}

countService1.counter shouldBeEqual CONCURRENT_COUNT
successCount.get() shouldBeEqual CONCURRENT_COUNT.toLong()
}

it("여러 쓰레드를 생성해 동작했을 때, 키 별로 락이 지정되고, 카운트가 올바르게 증가해야한다.") {
val successCount = AtomicLong()
val executorService = Executors.newFixedThreadPool(THREAD_COUNT * 3)
val latch = CountDownLatch(CONCURRENT_COUNT * 3)
for (i in 1..CONCURRENT_COUNT) {
executorService.submit {
try {
runBlocking {
lockManager.lock("1") {
countService1.increase()
logger.info { "1 ${countService1.counter}" }
}
}
successCount.getAndIncrement()
} catch (e: Throwable) {
logger.info { e.toString() }
} finally {
latch.countDown()
}
}
executorService.submit {
try {
runBlocking {
lockManager.lock("2") {
countService2.increase()
logger.info { "2 ${countService2.counter}" }
}
}
successCount.getAndIncrement()
} catch (e: Throwable) {
logger.info { e.toString() }
} finally {
latch.countDown()
}
}
executorService.submit {
try {
runBlocking {
lockManager.lock("3") {
countService3.increase()
logger.info { "3 ${countService3.counter}" }
}
}
successCount.getAndIncrement()
} catch (e: Throwable) {
logger.info { e.toString() }
} finally {
latch.countDown()
}
}
}
latch.await()

countService1.counter shouldBeEqual CONCURRENT_COUNT
countService2.counter shouldBeEqual CONCURRENT_COUNT
countService3.counter shouldBeEqual CONCURRENT_COUNT
successCount.get() shouldBeEqual CONCURRENT_COUNT * 3L
}
}
}
})

private class CountService {
var counter: Int = 0

suspend fun increase() {
val curCount = counter
delay(5)
counter = curCount + 1
}
}
Loading

0 comments on commit e47b056

Please sign in to comment.