diff --git a/shared/src/commonMain/kotlin/dev/sasikanth/rss/reader/opml/OpmlManager.kt b/shared/src/commonMain/kotlin/dev/sasikanth/rss/reader/opml/OpmlManager.kt index 326571d58..cbc0746d5 100644 --- a/shared/src/commonMain/kotlin/dev/sasikanth/rss/reader/opml/OpmlManager.kt +++ b/shared/src/commonMain/kotlin/dev/sasikanth/rss/reader/opml/OpmlManager.kt @@ -32,8 +32,11 @@ import kotlin.time.measureTime import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll import kotlinx.coroutines.cancelChildren import kotlinx.coroutines.channels.ProducerScope +import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.coroutines.flow.SharedFlow @@ -41,7 +44,6 @@ import kotlinx.coroutines.flow.channelFlow import kotlinx.coroutines.flow.collect import kotlinx.coroutines.flow.onCompletion import kotlinx.coroutines.flow.onEach -import kotlinx.coroutines.joinAll import kotlinx.coroutines.launch import kotlinx.coroutines.withContext import me.tatarka.inject.annotations.Inject @@ -174,36 +176,55 @@ class OpmlManager( } private fun addOpmlSources(sources: List): Flow = channelFlow { - val totalSourcesCount = sources.size + val feeds = sources.filterIsInstance() + val groups = sources.filterIsInstance() + val totalSourcesCount = feeds.size + groups.flatMap { it.feeds }.size val processedFeedsCount = AtomicInt(0) - if (sources.size > IMPORT_CHUNKS) { - sources.reversed().chunked(IMPORT_CHUNKS).forEach { sourcesInChunk -> - sourcesInChunk.map { source -> launch { createSourceInDB(source) } }.joinAll() + if (feeds.isNotEmpty()) { + addFeeds(feeds, processedFeedsCount, totalSourcesCount) + } - val size = processedFeedsCount.addAndGet(sourcesInChunk.size) - sendProgress(size, totalSourcesCount) - } - } else { - sources.reversed().forEachIndexed { index, source -> - launch { createSourceInDB(source) }.join() + if (groups.isNotEmpty()) { + // Since groups can contain multiple feeds, we don't want to add them in parallel + groups.forEach { group -> + val feedIds = addFeeds(group.feeds, processedFeedsCount, totalSourcesCount) + val groupId = rssRepository.createGroup(group.title) - sendProgress(index, totalSourcesCount) + rssRepository.addFeedIdsToGroups(groupIds = setOf(groupId), feedIds = feedIds) } } } - private suspend fun createSourceInDB(source: OpmlSource) { - when (source) { - is OpmlFeed -> { - addFeed(source) - } - is OpmlFeedGroup -> { - val groupId = rssRepository.createGroup(source.title) - val feedIds = source.feeds.mapNotNull { feed -> addFeed(feed) } + private suspend fun ProducerScope.addFeeds( + feeds: List, + processedFeedsCount: AtomicInt, + totalFeedsCount: Int, + ): List { + return coroutineScope { + val ids: List = + feeds + .reversed() + .chunked(IMPORT_CHUNKS) + .map { sourcesInChunk -> + val ids = + sourcesInChunk + .map { feed -> + async { + addFeed(feed).also { + val progressIndex = processedFeedsCount.incrementAndGet() + send(calculateProgress(progressIndex, totalFeedsCount)) + } + } + } + .awaitAll() + .filterNotNull() - rssRepository.addFeedIdsToGroups(groupIds = setOf(groupId), feedIds = feedIds) - } + return@map ids + } + .flatten() + + return@coroutineScope ids } } @@ -217,10 +238,8 @@ class OpmlManager( return result.feedId } - private suspend fun ProducerScope.sendProgress(progressIndex: Int, totalFeedCount: Int) { - // We are converting the total feed count to float - // so that we can get the precise progress like 0.1, 0.2..etc., - send(((progressIndex / totalFeedCount.toFloat()) * 100).roundToInt()) + private fun calculateProgress(progressIndex: Int, totalFeedCount: Int): Int { + return ((progressIndex / totalFeedCount.toFloat()) * 100).roundToInt() } }