From 93d1108622325cb820197cc4b7e7456384c1c668 Mon Sep 17 00:00:00 2001 From: ShiftHackZ Date: Sat, 30 Dec 2023 16:21:40 +0200 Subject: [PATCH 1/2] Multiple LD Models | Patch 1 --- app/build.gradle | 4 +- .../aisdv1/app/AiStableDiffusionClientApp.kt | 2 + .../aisdv1/app/di/ProvidersModule.kt | 5 + .../local/DownloadableModelLocalDataSource.kt | 57 +++++- .../data/mappers/LocalAiModelMappers.kt | 37 ++++ .../data/preference/PreferenceManagerImpl.kt | 7 + .../DownloadableModelRemoteDataSource.kt | 23 ++- .../DownloadableModelRepositoryImpl.kt | 20 +- .../LocalDiffusionGenerationRepositoryImpl.kt | 8 +- .../datasource/DownloadableModelDataSource.kt | 13 +- .../aisdv1/domain/di/DomainModule.kt | 9 +- .../aisdv1/domain/entity/Configuration.kt | 1 + .../aisdv1/domain/entity/LocalAiModel.kt | 10 + .../domain/preference/PreferenceManager.kt | 1 + .../repository/DownloadableModelRepository.kt | 10 +- .../CheckDownloadedModelUseCase.kt | 7 - .../downloadable/DeleteModelUseCase.kt | 2 +- .../downloadable/DeleteModelUseCaseImpl.kt | 2 +- .../downloadable/DownloadModelUseCase.kt | 2 +- .../downloadable/DownloadModelUseCaseImpl.kt | 2 +- .../downloadable/GetLocalAiModelsUseCase.kt | 8 + ...Impl.kt => GetLocalAiModelsUseCaseImpl.kt} | 6 +- .../downloadable/SelectLocalAiModelUseCase.kt | 7 + .../SelectLocalAiModelUseCaseImpl.kt | 10 + .../settings/GetConfigurationUseCaseImpl.kt | 1 + .../SetServerConfigurationUseCaseImpl.kt | 1 + .../ai/tokenizer/EnglishTextTokenizer.kt | 8 +- .../aisdv1/feature/diffusion/ai/unet/UNet.kt | 5 +- .../feature/diffusion/ai/vae/VaeDecoder.kt | 4 +- .../environment/LocalModelIdProvider.kt | 5 + .../api/sdai/DownloadableModelsRestApi.kt | 6 +- .../api/sdai/DownloadableModelsRestApiImpl.kt | 13 +- .../response/DownloadableModelResponse.kt | 4 + .../aisdv1/presentation/di/ViewModelModule.kt | 3 +- .../screen/setup/ServerSetupContract.kt | 14 +- .../screen/setup/ServerSetupScreen.kt | 104 ++--------- .../screen/setup/ServerSetupViewModel.kt | 152 +++++++++++----- .../screen/setup/mappers/LocalModelMappers.kt | 27 +++ .../widget/item/LocalModelItem.kt | 120 ++++++++++++ .../3.json | 171 ++++++++++++++++++ .../db/persistent/PersistentDatabase.kt | 17 +- .../persistent/contract/LocalModelContract.kt | 10 + .../db/persistent/dao/LocalModelDao.kt | 26 +++ .../db/persistent/entity/LocalModelEntity.kt | 19 ++ .../aisdv1/storage/di/DatabaseModule.kt | 1 + 45 files changed, 768 insertions(+), 196 deletions(-) create mode 100644 data/src/main/java/com/shifthackz/aisdv1/data/mappers/LocalAiModelMappers.kt create mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalAiModel.kt delete mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/CheckDownloadedModelUseCase.kt create mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCase.kt rename domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/{CheckDownloadedModelUseCaseImpl.kt => GetLocalAiModelsUseCaseImpl.kt} (57%) create mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/SelectLocalAiModelUseCase.kt create mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/SelectLocalAiModelUseCaseImpl.kt create mode 100644 feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/environment/LocalModelIdProvider.kt create mode 100644 presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/LocalModelMappers.kt create mode 100644 presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/item/LocalModelItem.kt create mode 100644 storage/schemas/com.shifthackz.aisdv1.storage.db.persistent.PersistentDatabase/3.json create mode 100644 storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/contract/LocalModelContract.kt create mode 100644 storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/dao/LocalModelDao.kt create mode 100644 storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/entity/LocalModelEntity.kt diff --git a/app/build.gradle b/app/build.gradle index ab736b7b..2c179864 100755 --- a/app/build.gradle +++ b/app/build.gradle @@ -14,8 +14,8 @@ android { namespace 'com.shifthackz.aisdv1.app' defaultConfig { applicationId "com.shifthackz.aisdv1.app" - versionName "0.5.3" - versionCode 166 + versionName "0.5.4" + versionCode 167 buildConfigField "String", "IMAGE_CDN_URL", "\"https://random.imagecdn.app\"" buildConfigField "String", "HORDE_AI_URL", "\"https://stablehorde.net\"" diff --git a/app/src/main/java/com/shifthackz/aisdv1/app/AiStableDiffusionClientApp.kt b/app/src/main/java/com/shifthackz/aisdv1/app/AiStableDiffusionClientApp.kt index 07f7295d..6650f771 100755 --- a/app/src/main/java/com/shifthackz/aisdv1/app/AiStableDiffusionClientApp.kt +++ b/app/src/main/java/com/shifthackz/aisdv1/app/AiStableDiffusionClientApp.kt @@ -7,6 +7,7 @@ import com.shifthackz.aisdv1.app.di.featureModule import com.shifthackz.aisdv1.app.di.preferenceModule import com.shifthackz.aisdv1.app.di.providersModule import com.shifthackz.aisdv1.core.common.log.FileLoggingTree +import com.shifthackz.aisdv1.core.common.log.errorLog import com.shifthackz.aisdv1.core.imageprocessing.di.imageProcessingModule import com.shifthackz.aisdv1.core.validation.di.validatorsModule import com.shifthackz.aisdv1.data.di.dataModule @@ -25,6 +26,7 @@ class AiStableDiffusionClientApp : Application() { override fun onCreate() { super.onCreate() StrictMode.setVmPolicy(VmPolicy.Builder().build()) + Thread.currentThread().setUncaughtExceptionHandler { _, t -> errorLog(t) } initializeKoin() initializeLogging() } diff --git a/app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt b/app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt index 5d775258..29af6b38 100755 --- a/app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt +++ b/app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt @@ -11,6 +11,7 @@ import com.shifthackz.aisdv1.domain.feature.auth.AuthorizationStore import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.feature.diffusion.entity.LocalDiffusionFlag import com.shifthackz.aisdv1.feature.diffusion.environment.DeviceNNAPIFlagProvider +import com.shifthackz.aisdv1.feature.diffusion.environment.LocalModelIdProvider import com.shifthackz.aisdv1.network.qualifiers.ApiUrlProvider import com.shifthackz.aisdv1.network.qualifiers.CredentialsProvider import com.shifthackz.aisdv1.network.qualifiers.HordeApiKeyProvider @@ -108,4 +109,8 @@ val providersModule = module { .let(LocalDiffusionFlag::value) } } + + single { + LocalModelIdProvider { get().localModelId } + } } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt index 50e4f3df..0ba9fdfc 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt @@ -1,29 +1,70 @@ package com.shifthackz.aisdv1.data.local import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor -import com.shifthackz.aisdv1.core.common.log.debugLog +import com.shifthackz.aisdv1.data.mappers.mapDomainToEntity +import com.shifthackz.aisdv1.data.mappers.mapEntityToDomain import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource +import com.shifthackz.aisdv1.domain.entity.LocalAiModel +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.storage.db.persistent.dao.LocalModelDao +import com.shifthackz.aisdv1.storage.db.persistent.entity.LocalModelEntity import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Observable import io.reactivex.rxjava3.core.Single import java.io.File internal class DownloadableModelLocalDataSource( private val fileProviderDescriptor: FileProviderDescriptor, + private val dao: LocalModelDao, + private val preferenceManager: PreferenceManager, ) : DownloadableModelDataSource.Local { + override fun getAll(): Single> = dao.query() + .map(List::mapEntityToDomain) + .flatMap { models -> models.withLocalData() } - private val localModelDirectory: File - get() = File(fileProviderDescriptor.localModelDirPath) + override fun getById(id: String) = dao.queryById(id) + .map(LocalModelEntity::mapEntityToDomain) + .flatMap { model -> model.withLocalData() } - override fun exists(): Single = Single.create { emitter -> + override fun getSelected(): Single = Single + .just(preferenceManager.localModelId) + .flatMap(::getById) + .onErrorResumeNext { Single.error(Throwable("No selected model")) } + + override fun select(id: String): Completable = Completable.fromAction { + preferenceManager.localModelId = id + } + + override fun save(list: List) = dao.insertList(list.mapDomainToEntity()) + + override fun isDownloaded(id: String): Single = Single.create { emitter -> try { - val files = (localModelDirectory.listFiles()?.filter { it.isDirectory }) ?: emptyList() - if (!emitter.isDisposed) emitter.onSuccess(localModelDirectory.exists() && files.size == 4) + val localModelDir = getLocalModelDirectory(id) + val files = (localModelDir.listFiles()?.filter { it.isDirectory }) ?: emptyList() + if (!emitter.isDisposed) emitter.onSuccess(localModelDir.exists() && files.size == 4) } catch (e: Exception) { if (!emitter.isDisposed) emitter.onSuccess(false) } } - override fun delete(): Completable = Completable.fromAction { - localModelDirectory.deleteRecursively() + override fun delete(id: String): Completable = Completable.fromAction { + getLocalModelDirectory(id).deleteRecursively() } + + private fun getLocalModelDirectory(id: String): File { + return File("${fileProviderDescriptor.localModelDirPath}/${id}") + } + + private fun List.withLocalData(): Single> = Observable + .fromIterable(this) + .flatMapSingle { model -> model.withLocalData() } + .toList() + + private fun LocalAiModel.withLocalData(): Single = isDownloaded(id) + .map { downloaded -> + copy( + downloaded = downloaded, + selected = preferenceManager.localModelId == id, + ) + } } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/LocalAiModelMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/LocalAiModelMappers.kt new file mode 100644 index 00000000..7b8fbca6 --- /dev/null +++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/LocalAiModelMappers.kt @@ -0,0 +1,37 @@ +package com.shifthackz.aisdv1.data.mappers + +import com.shifthackz.aisdv1.domain.entity.LocalAiModel +import com.shifthackz.aisdv1.network.response.DownloadableModelResponse +import com.shifthackz.aisdv1.storage.db.persistent.entity.LocalModelEntity + +//region RAW --> DOMAIN +fun List.mapRawToDomain(): List = + map(DownloadableModelResponse::mapRawToDomain) + +fun DownloadableModelResponse.mapRawToDomain(): LocalAiModel = with(this) { + LocalAiModel( + id = id ?: "", + name = name ?: "", + size = size ?: "", + sources = sources ?: emptyList(), + ) +} +//endregion + +//region DOMAIN --> ENTITY +fun List.mapDomainToEntity(): List = + map(LocalAiModel::mapDomainToEntity) + +fun LocalAiModel.mapDomainToEntity(): LocalModelEntity = with(this) { + LocalModelEntity(id, name, size, sources) +} +//endregion + +//region ENTITY --> DOMAIN +fun List.mapEntityToDomain(): List = + map(LocalModelEntity::mapEntityToDomain) + +fun LocalModelEntity.mapEntityToDomain(): LocalAiModel = with(this) { + LocalAiModel(id, name, size, sources) +} +//endregion diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt index 3163aa6d..7931a1f1 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt @@ -82,6 +82,12 @@ class PreferenceManagerImpl( .putBoolean(KEY_FORCE_SETUP_AFTER_UPDATE, value) .apply() + override var localModelId: String + get() = preferences.getString(KEY_LOCAL_MODEL_ID, "") ?: "" + set(value) = preferences.edit() + .putString(KEY_LOCAL_MODEL_ID, value) + .apply() + override var localUseNNAPI: Boolean get() = preferences.getBoolean(KEY_LOCAL_NN_API, false) set(value) = preferences.edit() @@ -117,6 +123,7 @@ class PreferenceManagerImpl( private const val KEY_SERVER_SOURCE = "key_server_source" private const val KEY_HORDE_API_KEY = "key_horde_api_key" private const val KEY_LOCAL_NN_API = "key_local_nn_api" + private const val KEY_LOCAL_MODEL_ID = "key_local_model_id" private const val KEY_FORCE_SETUP_AFTER_UPDATE = "force_upd_setup_v0.x.x-v0.5.3" } } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSource.kt index e289be0d..8f6720de 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSource.kt @@ -2,11 +2,15 @@ package com.shifthackz.aisdv1.data.remote import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor import com.shifthackz.aisdv1.core.common.file.unzip +import com.shifthackz.aisdv1.data.mappers.mapRawToDomain import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource import com.shifthackz.aisdv1.domain.entity.DownloadState +import com.shifthackz.aisdv1.domain.entity.LocalAiModel import com.shifthackz.aisdv1.network.api.sdai.DownloadableModelsRestApi +import com.shifthackz.aisdv1.network.response.DownloadableModelResponse import io.reactivex.rxjava3.core.Completable import io.reactivex.rxjava3.core.Observable +import io.reactivex.rxjava3.core.Single import java.io.File internal class DownloadableModelRemoteDataSource( @@ -14,18 +18,21 @@ internal class DownloadableModelRemoteDataSource( private val fileProviderDescriptor: FileProviderDescriptor, ) : DownloadableModelDataSource.Remote { - private val destinationPath = "${fileProviderDescriptor.localModelDirPath}/model.zip" + override fun fetch() = api + .fetchDownloadableModels() + .map(List::mapRawToDomain) - override fun download() = Completable + override fun download(id: String, url: String): Observable = Completable .fromAction { - val dir = File(fileProviderDescriptor.localModelDirPath) - val destination = File(destinationPath) + val dir = File("${fileProviderDescriptor.localModelDirPath}/${id}") + val destination = File(getDestinationPath(id)) if (destination.exists()) destination.delete() if (!dir.exists()) dir.mkdirs() } .andThen( api.downloadModel( - "${fileProviderDescriptor.localModelDirPath}/model.zip", + remoteUrl = url, + localPath = getDestinationPath(id), stateProgress = DownloadState::Downloading, stateComplete = DownloadState::Complete, stateFailed = DownloadState::Error, @@ -43,10 +50,14 @@ internal class DownloadableModelRemoteDataSource( emitter.onError(e) } } - .andThen(Completable.fromAction { File(destinationPath).delete() }) + .andThen(Completable.fromAction { File(getDestinationPath(id)).delete() }) .andThen(chain) } else { chain } } + + private fun getDestinationPath(id: String): String { + return "${fileProviderDescriptor.localModelDirPath}/${id}/model.zip" + } } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImpl.kt index f92d4e5e..5ee0d356 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImpl.kt @@ -2,16 +2,28 @@ package com.shifthackz.aisdv1.data.repository import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository -import io.reactivex.rxjava3.core.Completable internal class DownloadableModelRepositoryImpl( private val remoteDataSource: DownloadableModelDataSource.Remote, private val localDataSource: DownloadableModelDataSource.Local, ) : DownloadableModelRepository { - override fun isModelDownloaded() = localDataSource.exists() + override fun isModelDownloaded(id: String) = localDataSource.isDownloaded(id) - override fun download() = remoteDataSource.download() + override fun download(id: String) = localDataSource + .getById(id) + .flatMapObservable { model -> + remoteDataSource.download(id, model.sources.firstOrNull() ?: "") + } - override fun delete() = localDataSource.delete() + override fun delete(id: String) = localDataSource.delete(id) + + override fun getAll() = remoteDataSource + .fetch() + .flatMapCompletable(localDataSource::save) + .andThen(localDataSource.getAll()) + .onErrorResumeNext { localDataSource.getAll() } + + override fun getById(id: String) = localDataSource.getById(id) + override fun select(id: String) = localDataSource.select(id) } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImpl.kt index 37368669..9d6a1661 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImpl.kt @@ -12,8 +12,6 @@ import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusion import com.shifthackz.aisdv1.domain.gateway.MediaStoreGateway import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepository -import io.reactivex.rxjava3.core.Flowable -import io.reactivex.rxjava3.core.Observable import io.reactivex.rxjava3.core.Single internal class LocalDiffusionGenerationRepositoryImpl( @@ -35,9 +33,9 @@ internal class LocalDiffusionGenerationRepositoryImpl( override fun observeStatus() = localDiffusion.observeStatus() override fun generateFromText(payload: TextToImagePayload) = downloadableLocalDataSource - .exists() - .flatMap { modelDownloaded -> - if (modelDownloaded) generate(payload) + .getSelected() + .flatMap { model -> + if (model.downloaded) generate(payload) else Single.error(Throwable("Model not downloaded")) } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/DownloadableModelDataSource.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/DownloadableModelDataSource.kt index 1f2b6bd8..801d14b9 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/DownloadableModelDataSource.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/DownloadableModelDataSource.kt @@ -1,6 +1,7 @@ package com.shifthackz.aisdv1.domain.datasource import com.shifthackz.aisdv1.domain.entity.DownloadState +import com.shifthackz.aisdv1.domain.entity.LocalAiModel import io.reactivex.rxjava3.core.Completable import io.reactivex.rxjava3.core.Observable import io.reactivex.rxjava3.core.Single @@ -8,11 +9,17 @@ import io.reactivex.rxjava3.core.Single sealed interface DownloadableModelDataSource { interface Remote : DownloadableModelDataSource { - fun download(): Observable + fun fetch(): Single> + fun download(id: String, url: String): Observable } interface Local : DownloadableModelDataSource { - fun exists(): Single - fun delete(): Completable + fun getAll(): Single> + fun getById(id: String): Single + fun getSelected(): Single + fun select(id: String): Completable + fun save(list: List): Completable + fun isDownloaded(id: String): Single + fun delete(id: String): Completable } } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt index 3c84c5b4..96fe7d18 100755 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt @@ -18,12 +18,14 @@ import com.shifthackz.aisdv1.domain.usecase.connectivity.TestHordeApiKeyUseCase import com.shifthackz.aisdv1.domain.usecase.connectivity.TestHordeApiKeyUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.debug.DebugInsertBadBase64UseCase import com.shifthackz.aisdv1.domain.usecase.debug.DebugInsertBadBase64UseCaseImpl -import com.shifthackz.aisdv1.domain.usecase.downloadable.CheckDownloadedModelUseCase -import com.shifthackz.aisdv1.domain.usecase.downloadable.CheckDownloadedModelUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.downloadable.DeleteModelUseCase import com.shifthackz.aisdv1.domain.usecase.downloadable.DeleteModelUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCase import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCaseImpl +import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalAiModelsUseCase +import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalAiModelsUseCaseImpl +import com.shifthackz.aisdv1.domain.usecase.downloadable.SelectLocalAiModelUseCase +import com.shifthackz.aisdv1.domain.usecase.downloadable.SelectLocalAiModelUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.gallery.DeleteGalleryItemUseCase import com.shifthackz.aisdv1.domain.usecase.gallery.DeleteGalleryItemUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.gallery.GetAllGalleryUseCase @@ -88,9 +90,10 @@ internal val useCasesModule = module { factoryOf(::SaveLastResultToCacheUseCaseImpl) bind SaveLastResultToCacheUseCase::class factoryOf(::GetLastResultFromCacheUseCaseImpl) bind GetLastResultFromCacheUseCase::class factoryOf(::ObserveLocalDiffusionProcessStatusUseCaseImpl) bind ObserveLocalDiffusionProcessStatusUseCase::class + factoryOf(::GetLocalAiModelsUseCaseImpl) bind GetLocalAiModelsUseCase::class factoryOf(::DownloadModelUseCaseImpl) bind DownloadModelUseCase::class factoryOf(::DeleteModelUseCaseImpl) bind DeleteModelUseCase::class - factoryOf(::CheckDownloadedModelUseCaseImpl) bind CheckDownloadedModelUseCase::class + factoryOf(::SelectLocalAiModelUseCaseImpl) bind SelectLocalAiModelUseCase::class } internal val debugModule = module { diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt index f135e2f2..89b6d28f 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt @@ -8,4 +8,5 @@ data class Configuration( val source: ServerSource, val hordeApiKey: String, val authCredentials: AuthorizationCredentials, + val localModelId: String, ) diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalAiModel.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalAiModel.kt new file mode 100644 index 00000000..22c0e68a --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalAiModel.kt @@ -0,0 +1,10 @@ +package com.shifthackz.aisdv1.domain.entity + +data class LocalAiModel( + val id: String, + val name: String, + val size: String, + val sources: List, + val downloaded: Boolean = false, + val selected: Boolean = false, +) diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt index c1958c82..50428a95 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt @@ -14,6 +14,7 @@ interface PreferenceManager { var source: ServerSource var hordeApiKey: String var forceSetupAfterUpdate: Boolean + var localModelId: String var localUseNNAPI: Boolean fun observe(): Flowable diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/DownloadableModelRepository.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/DownloadableModelRepository.kt index b350933c..d5a50419 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/DownloadableModelRepository.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/DownloadableModelRepository.kt @@ -1,12 +1,16 @@ package com.shifthackz.aisdv1.domain.repository import com.shifthackz.aisdv1.domain.entity.DownloadState +import com.shifthackz.aisdv1.domain.entity.LocalAiModel import io.reactivex.rxjava3.core.Completable import io.reactivex.rxjava3.core.Observable import io.reactivex.rxjava3.core.Single interface DownloadableModelRepository { - fun isModelDownloaded(): Single - fun download(): Observable - fun delete(): Completable + fun isModelDownloaded(id: String): Single + fun download(id: String): Observable + fun delete(id: String): Completable + fun getAll(): Single> + fun getById(id: String): Single + fun select(id: String): Completable } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/CheckDownloadedModelUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/CheckDownloadedModelUseCase.kt deleted file mode 100644 index a34b3eb2..00000000 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/CheckDownloadedModelUseCase.kt +++ /dev/null @@ -1,7 +0,0 @@ -package com.shifthackz.aisdv1.domain.usecase.downloadable - -import io.reactivex.rxjava3.core.Single - -interface CheckDownloadedModelUseCase { - operator fun invoke(): Single -} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/DeleteModelUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/DeleteModelUseCase.kt index 4f3259cb..9fc4e7c6 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/DeleteModelUseCase.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/DeleteModelUseCase.kt @@ -3,5 +3,5 @@ package com.shifthackz.aisdv1.domain.usecase.downloadable import io.reactivex.rxjava3.core.Completable interface DeleteModelUseCase { - operator fun invoke(): Completable + operator fun invoke(id: String): Completable } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/DeleteModelUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/DeleteModelUseCaseImpl.kt index 2a50d898..b35a0043 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/DeleteModelUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/DeleteModelUseCaseImpl.kt @@ -6,5 +6,5 @@ internal class DeleteModelUseCaseImpl( private val downloadableModelRepository: DownloadableModelRepository, ) : DeleteModelUseCase { - override fun invoke() = downloadableModelRepository.delete() + override fun invoke(id: String) = downloadableModelRepository.delete(id) } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/DownloadModelUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/DownloadModelUseCase.kt index 7587636a..331fa243 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/DownloadModelUseCase.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/DownloadModelUseCase.kt @@ -4,5 +4,5 @@ import com.shifthackz.aisdv1.domain.entity.DownloadState import io.reactivex.rxjava3.core.Observable interface DownloadModelUseCase { - operator fun invoke(): Observable + operator fun invoke(id: String): Observable } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/DownloadModelUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/DownloadModelUseCaseImpl.kt index 48638f5a..6c5cf292 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/DownloadModelUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/DownloadModelUseCaseImpl.kt @@ -6,5 +6,5 @@ internal class DownloadModelUseCaseImpl( private val downloadableModelRepository: DownloadableModelRepository, ) : DownloadModelUseCase { - override fun invoke() = downloadableModelRepository.download() + override fun invoke(id: String) = downloadableModelRepository.download(id) } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCase.kt new file mode 100644 index 00000000..efe71374 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCase.kt @@ -0,0 +1,8 @@ +package com.shifthackz.aisdv1.domain.usecase.downloadable + +import com.shifthackz.aisdv1.domain.entity.LocalAiModel +import io.reactivex.rxjava3.core.Single + +interface GetLocalAiModelsUseCase { + operator fun invoke(): Single> +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/CheckDownloadedModelUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCaseImpl.kt similarity index 57% rename from domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/CheckDownloadedModelUseCaseImpl.kt rename to domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCaseImpl.kt index 57df7f66..7bdbe7e7 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/CheckDownloadedModelUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCaseImpl.kt @@ -2,9 +2,9 @@ package com.shifthackz.aisdv1.domain.usecase.downloadable import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository -internal class CheckDownloadedModelUseCaseImpl( +internal class GetLocalAiModelsUseCaseImpl( private val downloadableModelRepository: DownloadableModelRepository, -) : CheckDownloadedModelUseCase { +) : GetLocalAiModelsUseCase { - override fun invoke() = downloadableModelRepository.isModelDownloaded() + override fun invoke() = downloadableModelRepository.getAll() } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/SelectLocalAiModelUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/SelectLocalAiModelUseCase.kt new file mode 100644 index 00000000..96a3dbbc --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/SelectLocalAiModelUseCase.kt @@ -0,0 +1,7 @@ +package com.shifthackz.aisdv1.domain.usecase.downloadable + +import io.reactivex.rxjava3.core.Completable + +interface SelectLocalAiModelUseCase { + operator fun invoke(id: String): Completable +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/SelectLocalAiModelUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/SelectLocalAiModelUseCaseImpl.kt new file mode 100644 index 00000000..86b2b994 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/SelectLocalAiModelUseCaseImpl.kt @@ -0,0 +1,10 @@ +package com.shifthackz.aisdv1.domain.usecase.downloadable + +import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository + +internal class SelectLocalAiModelUseCaseImpl( + private val downloadableModelRepository: DownloadableModelRepository, +) : SelectLocalAiModelUseCase { + + override fun invoke(id: String) = downloadableModelRepository.select(id) +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt index 9903022c..1739f51a 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt @@ -17,6 +17,7 @@ internal class GetConfigurationUseCaseImpl( source = preferenceManager.source, hordeApiKey = preferenceManager.hordeApiKey, authCredentials = authorizationStore.getAuthorizationCredentials(), + localModelId = preferenceManager.localModelId, ) ) } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt index ad7218df..3d94cfc2 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt @@ -17,5 +17,6 @@ internal class SetServerConfigurationUseCaseImpl( preferenceManager.serverUrl = configuration.serverUrl preferenceManager.demoMode = configuration.demoMode preferenceManager.hordeApiKey = configuration.hordeApiKey + preferenceManager.localModelId = configuration.localModelId } } diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/tokenizer/EnglishTextTokenizer.kt b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/tokenizer/EnglishTextTokenizer.kt index 33fb0cbe..c8ddc672 100644 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/tokenizer/EnglishTextTokenizer.kt +++ b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/tokenizer/EnglishTextTokenizer.kt @@ -15,6 +15,7 @@ import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.ORT import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.ORT_KEY_MODEL_FORMAT import com.shifthackz.aisdv1.feature.diffusion.ai.extensions.halfCorner import com.shifthackz.aisdv1.feature.diffusion.ai.extensions.toArrays +import com.shifthackz.aisdv1.feature.diffusion.environment.LocalModelIdProvider import com.shifthackz.aisdv1.feature.diffusion.environment.OrtEnvironmentProvider import java.io.BufferedReader import java.io.FileInputStream @@ -27,6 +28,7 @@ import java.util.regex.Pattern internal class EnglishTextTokenizer( private val ortEnvironmentProvider: OrtEnvironmentProvider, private val fileProviderDescriptor: FileProviderDescriptor, + private val localModelIdProvider: LocalModelIdProvider, ) : LocalDiffusionTextTokenizer { private val pattern = Pattern.compile(TOKENIZER_REGEX) @@ -45,7 +47,7 @@ internal class EnglishTextTokenizer( val options = OrtSession.SessionOptions() options.addConfigEntry(ORT_KEY_MODEL_FORMAT, ORT) session = ortEnvironmentProvider.get().createSession( - "${fileProviderDescriptor.localModelDirPath}/${LocalDiffusionContract.TOKENIZER_MODEL}", + "${fileProviderDescriptor.localModelDirPath}/${localModelIdProvider.get()}/${LocalDiffusionContract.TOKENIZER_MODEL}", options ) if (!isInitMap) { @@ -203,7 +205,7 @@ internal class EnglishTextTokenizer( private fun loadEncoder(): Map { val map: MutableMap = HashMap() try { - val path = "${fileProviderDescriptor.localModelDirPath}/${LocalDiffusionContract.TOKENIZER_VOCABULARY}" + val path = "${fileProviderDescriptor.localModelDirPath}/${localModelIdProvider.get()}/${LocalDiffusionContract.TOKENIZER_VOCABULARY}" val jsonReader = JsonReader(InputStreamReader(FileInputStream(path))) jsonReader.beginObject() while (jsonReader.hasNext()) { @@ -229,7 +231,7 @@ internal class EnglishTextTokenizer( private fun loadBpeRanks(): Map, Int?> { val result: MutableMap, Int?> = HashMap() try { - val path = "${fileProviderDescriptor.localModelDirPath}/${LocalDiffusionContract.TOKENIZER_MERGES}" + val path = "${fileProviderDescriptor.localModelDirPath}/${localModelIdProvider.get()}/${LocalDiffusionContract.TOKENIZER_MERGES}" val reader = BufferedReader(InputStreamReader(FileInputStream(path))) var line: String var startLine = 1 diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/unet/UNet.kt b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/unet/UNet.kt index d4633085..43d311a4 100644 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/unet/UNet.kt +++ b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/unet/UNet.kt @@ -26,6 +26,7 @@ import com.shifthackz.aisdv1.feature.diffusion.entity.LocalDiffusionTensor import com.shifthackz.aisdv1.feature.diffusion.environment.OrtEnvironmentProvider import com.shifthackz.aisdv1.feature.diffusion.entity.LocalDiffusionFlag import com.shifthackz.aisdv1.feature.diffusion.environment.DeviceNNAPIFlagProvider +import com.shifthackz.aisdv1.feature.diffusion.environment.LocalModelIdProvider import java.nio.IntBuffer import java.util.EnumSet import java.util.Random @@ -37,12 +38,14 @@ internal class UNet( private val deviceNNAPIFlagProvider: DeviceNNAPIFlagProvider, private val ortEnvironmentProvider: OrtEnvironmentProvider, private val fileProviderDescriptor: FileProviderDescriptor, + private val localModelIdProvider: LocalModelIdProvider, ) { private val decoder: VaeDecoder get() = VaeDecoder( ortEnvironmentProvider, fileProviderDescriptor, + localModelIdProvider, deviceNNAPIFlagProvider.get(), ) @@ -62,7 +65,7 @@ internal class UNet( options.addNnapi(EnumSet.of(NNAPIFlags.CPU_DISABLED)) } session = ortEnvironmentProvider.get().createSession( - "${fileProviderDescriptor.localModelDirPath}/${LocalDiffusionContract.UNET_MODEL}", + "${fileProviderDescriptor.localModelDirPath}/${localModelIdProvider.get()}/${LocalDiffusionContract.UNET_MODEL}", options ) } diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/vae/VaeDecoder.kt b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/vae/VaeDecoder.kt index 11564c14..82624134 100644 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/vae/VaeDecoder.kt +++ b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/vae/VaeDecoder.kt @@ -12,12 +12,14 @@ import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.ORT import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.ORT_KEY_MODEL_FORMAT import com.shifthackz.aisdv1.feature.diffusion.environment.OrtEnvironmentProvider import com.shifthackz.aisdv1.feature.diffusion.entity.LocalDiffusionFlag +import com.shifthackz.aisdv1.feature.diffusion.environment.LocalModelIdProvider import java.util.EnumSet import kotlin.math.roundToInt internal class VaeDecoder( private val ortEnvironmentProvider: OrtEnvironmentProvider, private val fileProviderDescriptor: FileProviderDescriptor, + private val localModelIdProvider: LocalModelIdProvider, private val deviceId: Int, ) { @@ -63,7 +65,7 @@ internal class VaeDecoder( options.addNnapi(EnumSet.of(NNAPIFlags.CPU_DISABLED)) } session = ortEnvironmentProvider.get().createSession( - "${fileProviderDescriptor.localModelDirPath}/${LocalDiffusionContract.VAE_MODEL}", + "${fileProviderDescriptor.localModelDirPath}/${localModelIdProvider.get()}/${LocalDiffusionContract.VAE_MODEL}", options ) } diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/environment/LocalModelIdProvider.kt b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/environment/LocalModelIdProvider.kt new file mode 100644 index 00000000..3c24f58b --- /dev/null +++ b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/environment/LocalModelIdProvider.kt @@ -0,0 +1,5 @@ +package com.shifthackz.aisdv1.feature.diffusion.environment + +fun interface LocalModelIdProvider { + fun get(): String +} diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsRestApi.kt b/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsRestApi.kt index 27e34e36..8a777db0 100644 --- a/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsRestApi.kt +++ b/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsRestApi.kt @@ -11,8 +11,12 @@ import java.io.File interface DownloadableModelsRestApi { + @GET("/models.json") + fun fetchDownloadableModels(): Single> + fun downloadModel( - path: String, + remoteUrl: String, + localPath: String, stateProgress: (Int) -> T, stateComplete: (File) -> T, stateFailed: (Throwable) -> T, diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsRestApiImpl.kt b/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsRestApiImpl.kt index 83da10a5..e60eff52 100644 --- a/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsRestApiImpl.kt +++ b/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsRestApiImpl.kt @@ -2,22 +2,25 @@ package com.shifthackz.aisdv1.network.api.sdai import com.shifthackz.aisdv1.network.extensions.saveFile import io.reactivex.rxjava3.core.Observable +import io.reactivex.rxjava3.core.Single import java.io.File internal class DownloadableModelsRestApiImpl( private val rawApi: DownloadableModelsRestApi.RawApi, ) : DownloadableModelsRestApi { + override fun fetchDownloadableModels() = rawApi.fetchDownloadableModels() + override fun downloadModel( - path: String, + remoteUrl: String, + localPath: String, stateProgress: (Int) -> T, stateComplete: (File) -> T, stateFailed: (Throwable) -> T - ): Observable = rawApi - .fetchDownloadableModels() - .map { models -> models.first().sources?.first() ?: "" } + ): Observable = Single + .just(remoteUrl) .flatMap(rawApi::downloadModel) .flatMapObservable { body -> - body.saveFile(path, stateProgress, stateComplete, stateFailed) + body.saveFile(localPath, stateProgress, stateComplete, stateFailed) } } diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/response/DownloadableModelResponse.kt b/network/src/main/java/com/shifthackz/aisdv1/network/response/DownloadableModelResponse.kt index e257ccf0..5f08ac66 100644 --- a/network/src/main/java/com/shifthackz/aisdv1/network/response/DownloadableModelResponse.kt +++ b/network/src/main/java/com/shifthackz/aisdv1/network/response/DownloadableModelResponse.kt @@ -3,8 +3,12 @@ package com.shifthackz.aisdv1.network.response import com.google.gson.annotations.SerializedName data class DownloadableModelResponse( + @SerializedName("id") + val id: String?, @SerializedName("name") val name: String?, + @SerializedName("size") + val size: String?, @SerializedName("sources") val sources: List?, ) diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/di/ViewModelModule.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/di/ViewModelModule.kt index 039adabe..2392fe0a 100755 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/di/ViewModelModule.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/di/ViewModelModule.kt @@ -44,9 +44,10 @@ val viewModelModule = module { testConnectivityUseCase = get(), testHordeApiKeyUseCase = get(), setServerConfigurationUseCase = get(), + selectLocalAiModelUseCase = get(), downloadModelUseCase = get(), deleteModelUseCase = get(), - checkDownloadedModelUseCase = get(), + getLocalAiModelsUseCase = get(), dataPreLoaderUseCase = get(), schedulersProvider = get(), preferenceManager = get(), diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupContract.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupContract.kt index e131d7eb..5fa9c274 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupContract.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupContract.kt @@ -32,8 +32,9 @@ data class ServerSetupState( val originalLogin: String = "", val password: String = "", val originalPassword: String = "", - val localModelDownloaded: Boolean = false, - val downloadState: DownloadState = DownloadState.Unknown, + val localModels: List = emptyList(), +// val localModelDownloaded: Boolean = false, +// val downloadState: DownloadState = DownloadState.Unknown, val passwordVisible: Boolean = false, val serverUrlValidationError: UiText? = null, val loginValidationError: UiText? = null, @@ -107,6 +108,15 @@ data class ServerSetupState( ANONYMOUS, HTTP_BASIC; } + + data class LocalModel( + val id: String, + val name: String, + val size: String, + val downloaded: Boolean = false, + val downloadState: DownloadState = DownloadState.Unknown, + val selected: Boolean = false, + ) } enum class ServerSetupLaunchSource(val key: Int) { diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreen.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreen.kt index 044d3f66..7a30645a 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreen.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreen.kt @@ -5,12 +5,10 @@ package com.shifthackz.aisdv1.presentation.screen.setup import androidx.compose.foundation.background import androidx.compose.foundation.border import androidx.compose.foundation.clickable -import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.Spacer -import androidx.compose.foundation.layout.defaultMinSize import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.height @@ -30,16 +28,12 @@ import androidx.compose.material.icons.filled.Help import androidx.compose.material.icons.filled.Visibility import androidx.compose.material.icons.filled.VisibilityOff import androidx.compose.material.icons.outlined.ArrowBack -import androidx.compose.material.icons.outlined.FileDownload -import androidx.compose.material.icons.outlined.FileDownloadDone -import androidx.compose.material.icons.outlined.FileDownloadOff import androidx.compose.material3.Button import androidx.compose.material3.CenterAlignedTopAppBar import androidx.compose.material3.Checkbox import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.Icon import androidx.compose.material3.IconButton -import androidx.compose.material3.LinearProgressIndicator import androidx.compose.material3.LocalContentColor import androidx.compose.material3.MaterialTheme import androidx.compose.material3.Scaffold @@ -49,7 +43,6 @@ import androidx.compose.material3.TextField import androidx.compose.runtime.Composable import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier -import androidx.compose.ui.draw.clip import androidx.compose.ui.graphics.Color import androidx.compose.ui.res.stringResource import androidx.compose.ui.text.font.FontWeight @@ -65,12 +58,12 @@ import com.shifthackz.aisdv1.core.common.links.LinksProvider import com.shifthackz.aisdv1.core.model.asString import com.shifthackz.aisdv1.core.model.asUiText import com.shifthackz.aisdv1.core.ui.MviScreen -import com.shifthackz.aisdv1.domain.entity.DownloadState import com.shifthackz.aisdv1.presentation.R import com.shifthackz.aisdv1.presentation.utils.Constants import com.shifthackz.aisdv1.presentation.widget.dialog.ErrorDialog import com.shifthackz.aisdv1.presentation.widget.dialog.ProgressDialog import com.shifthackz.aisdv1.presentation.widget.input.DropdownTextField +import com.shifthackz.aisdv1.presentation.widget.item.LocalModelItem import com.shifthackz.aisdv1.presentation.widget.item.SettingsItem import org.koin.core.component.KoinComponent import org.koin.core.component.inject @@ -103,7 +96,8 @@ class ServerSetupScreen( onServerInstructionsItemClick = { launchUrl(linksProvider.setupInstructionsUrl) }, onOpenHordeWebSite = { launchUrl(linksProvider.hordeUrl) }, onOpenHordeSignUpWebSite = { launchUrl(linksProvider.hordeSignUpUrl) }, - onDownloadCardButtonClick = viewModel::downloadClickReducer, + onDownloadCardButtonClick = viewModel::localModelDownloadClickReducer, + onSelectLocalModel = viewModel::localModelSelect, onSetupButtonClick = viewModel::connectToServer, onDismissScreenDialog = viewModel::dismissScreenDialog, ) @@ -132,7 +126,8 @@ private fun ScreenContent( onServerInstructionsItemClick: () -> Unit = {}, onOpenHordeWebSite: () -> Unit = {}, onOpenHordeSignUpWebSite: () -> Unit = {}, - onDownloadCardButtonClick: () -> Unit = {}, + onDownloadCardButtonClick: (ServerSetupState.LocalModel) -> Unit = {}, + onSelectLocalModel: (ServerSetupState.LocalModel) -> Unit = {}, onSetupButtonClick: () -> Unit = {}, onDismissScreenDialog: () -> Unit = {}, ) { @@ -168,7 +163,9 @@ private fun ScreenContent( .padding(bottom = 16.dp), onClick = onSetupButtonClick, enabled = when (state.mode) { - ServerSetupState.Mode.LOCAL -> state.localModelDownloaded + ServerSetupState.Mode.LOCAL -> state.localModels.any { + it.downloaded && it.selected + } else -> true }, ) { @@ -224,6 +221,7 @@ private fun ScreenContent( ServerSetupState.Mode.LOCAL -> LocalDiffusionSetupTab( state = state, onDownloadCardButtonClick = onDownloadCardButtonClick, + onSelectLocalModel = onSelectLocalModel, ) } } @@ -453,7 +451,8 @@ private fun HordeAiSetupTab( private fun LocalDiffusionSetupTab( modifier: Modifier = Modifier, state: ServerSetupState, - onDownloadCardButtonClick: () -> Unit = {}, + onDownloadCardButtonClick: (ServerSetupState.LocalModel) -> Unit = {}, + onSelectLocalModel: (ServerSetupState.LocalModel) -> Unit = {}, ) { Column( modifier = modifier.padding(horizontal = 16.dp), @@ -468,83 +467,16 @@ private fun LocalDiffusionSetupTab( fontWeight = FontWeight.Bold, ) Text( - modifier = Modifier.padding(top = 16.dp), + modifier = Modifier.padding(top = 16.dp, bottom = 16.dp), text = stringResource(id = R.string.hint_local_diffusion_sub_title), style = MaterialTheme.typography.bodyMedium, ) - Column( - modifier = modifier - .padding(top = 24.dp) - .fillMaxWidth() - .clip(RoundedCornerShape(16.dp)) - .background(color = MaterialTheme.colorScheme.surfaceTint.copy(alpha = 0.8f)) - .defaultMinSize(minHeight = 50.dp), - ) { - Row( - modifier = Modifier.padding(vertical = 4.dp), - horizontalArrangement = Arrangement.Center, - ) { - val icon = when (state.downloadState) { - is DownloadState.Downloading -> Icons.Outlined.FileDownload - else -> { - if (state.localModelDownloaded) Icons.Outlined.FileDownloadDone - else Icons.Outlined.FileDownloadOff - } - } - Icon( - modifier = modifier - .padding(horizontal = 8.dp) - .size(48.dp), - imageVector = icon, - contentDescription = "Download state", - ) - Column( - modifier = Modifier.padding(start = 4.dp) - ) { - Text( - text = stringResource(id = R.string.model_local_diffusion), - ) - Text( - text = stringResource(id = R.string.model_local_diffusion_size), - ) - } - Spacer(modifier = Modifier.weight(1f)) - Button( - modifier = Modifier.padding(end = 8.dp), - onClick = onDownloadCardButtonClick, - ) { - Text( - text = stringResource(id = when (state.downloadState) { - is DownloadState.Downloading -> R.string.cancel - is DownloadState.Error -> R.string.retry - else -> { - if (state.localModelDownloaded) R.string.delete - else R.string.download - } - }), - color = LocalContentColor.current, - ) - } - } - when (state.downloadState) { - is DownloadState.Downloading -> { - LinearProgressIndicator( - modifier = Modifier - .padding(8.dp) - .fillMaxWidth(), - progress = state.downloadState.percent / 100f, - ) - } - is DownloadState.Error -> { - Text( - modifier = Modifier - .padding(horizontal = 8.dp) - .padding(bottom = 8.dp), - text = stringResource(id = R.string.error_download_fail), - ) - } - else -> Unit - } + state.localModels.forEach { localModel -> + LocalModelItem( + model = localModel, + onDownloadCardButtonClick = onDownloadCardButtonClick, + onSelect = onSelectLocalModel, + ) } Text( modifier = Modifier.padding(top = 16.dp), diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt index f042846d..afa7cff9 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt @@ -18,15 +18,17 @@ import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.domain.usecase.caching.DataPreLoaderUseCase import com.shifthackz.aisdv1.domain.usecase.connectivity.TestConnectivityUseCase import com.shifthackz.aisdv1.domain.usecase.connectivity.TestHordeApiKeyUseCase -import com.shifthackz.aisdv1.domain.usecase.downloadable.CheckDownloadedModelUseCase import com.shifthackz.aisdv1.domain.usecase.downloadable.DeleteModelUseCase import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCase +import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalAiModelsUseCase +import com.shifthackz.aisdv1.domain.usecase.downloadable.SelectLocalAiModelUseCase import com.shifthackz.aisdv1.domain.usecase.settings.GetConfigurationUseCase import com.shifthackz.aisdv1.domain.usecase.settings.SetServerConfigurationUseCase import com.shifthackz.aisdv1.presentation.features.SetupConnectEvent import com.shifthackz.aisdv1.presentation.features.SetupConnectFailure import com.shifthackz.aisdv1.presentation.features.SetupConnectSuccess import com.shifthackz.aisdv1.presentation.screen.setup.mappers.mapToUi +import com.shifthackz.aisdv1.presentation.screen.setup.mappers.withNewState import com.shifthackz.aisdv1.presentation.utils.Constants import io.reactivex.rxjava3.core.Observable import io.reactivex.rxjava3.core.Single @@ -43,9 +45,11 @@ class ServerSetupViewModel( private val testConnectivityUseCase: TestConnectivityUseCase, private val testHordeApiKeyUseCase: TestHordeApiKeyUseCase, private val setServerConfigurationUseCase: SetServerConfigurationUseCase, + private val selectLocalAiModelUseCase: SelectLocalAiModelUseCase, private val downloadModelUseCase: DownloadModelUseCase, private val deleteModelUseCase: DeleteModelUseCase, - private val checkDownloadedModelUseCase: CheckDownloadedModelUseCase, +// private val checkDownloadedModelUseCase: CheckDownloadedModelUseCase, + private val getLocalAiModelsUseCase: GetLocalAiModelsUseCase, private val dataPreLoaderUseCase: DataPreLoaderUseCase, private val schedulersProvider: SchedulersProvider, private val preferenceManager: PreferenceManager, @@ -60,11 +64,11 @@ class ServerSetupViewModel( init { !getConfigurationUseCase() - .zipWith(checkDownloadedModelUseCase(), ::Pair) + .zipWith(getLocalAiModelsUseCase(), ::Pair) .subscribeOnMainThread(schedulersProvider) - .subscribeBy(::errorLog) { (configuration, isDownloaded) -> + .subscribeBy(::errorLog) { (configuration, localModels) -> currentState - .copy(localModelDownloaded = isDownloaded) + .copy(localModels = localModels.mapToUi()) .withSource(configuration.source) .withDemoMode(configuration.demoMode) .withServerUrl(configuration.serverUrl) @@ -152,7 +156,9 @@ class ServerSetupViewModel( validation.isValid } } - ServerSetupState.Mode.LOCAL -> currentState.localModelDownloaded + ServerSetupState.Mode.LOCAL -> { + currentState.localModels.find { it.selected && it.downloaded } != null + } } private fun connectToAutomaticInstance() { @@ -173,6 +179,7 @@ class ServerSetupViewModel( source = currentState.mode.toSource(), hordeApiKey = currentState.hordeApiKey, authCredentials = credentials, + localModelId = currentState.localModels.find { it.selected }?.id ?: "", ) ) .doOnSubscribe { setScreenDialog(ServerSetupState.Dialog.Communicating) } @@ -195,6 +202,7 @@ class ServerSetupViewModel( source = currentState.originalMode.toSource(), hordeApiKey = currentState.originalHordeApiKey, authCredentials = currentState.credentialsDomain(true), + localModelId = currentState.localModels.find { it.selected }?.id ?: "", ), ).andThen(Single.just(Result.failure(t))) } @@ -221,6 +229,7 @@ class ServerSetupViewModel( source = ServerSource.HORDE, hordeApiKey = testApiKey, authCredentials = AuthorizationCredentials.None, + localModelId = currentState.localModels.find { it.selected }?.id ?: "", ), ) .andThen(testHordeApiKeyUseCase()) @@ -237,6 +246,7 @@ class ServerSetupViewModel( source = currentState.originalMode.toSource(), hordeApiKey = currentState.originalHordeApiKey, authCredentials = AuthorizationCredentials.None, + localModelId = currentState.localModels.find { it.selected }?.id ?: "", ) ).andThen(Single.just(Result.failure(t))) } @@ -261,6 +271,7 @@ class ServerSetupViewModel( source = ServerSource.LOCAL, hordeApiKey = Constants.HORDE_DEFAULT_API_KEY, authCredentials = AuthorizationCredentials.None, + localModelId = currentState.localModels.find { it.selected }?.id ?: "", ), ) .andThen(Single.just(Result.success(Unit))) @@ -278,49 +289,98 @@ class ServerSetupViewModel( } } - fun downloadClickReducer() = when { - currentState.downloadState is DownloadState.Downloading -> { - downloadDisposable?.dispose() - downloadDisposable = null - setState(currentState.copy(downloadState = DownloadState.Unknown)) + fun localModelSelect(localModel: ServerSetupState.LocalModel) { + if (currentState.localModels.any { it.downloadState is DownloadState.Downloading }) { + return } - currentState.localModelDownloaded -> { - setState( - currentState.copy( - downloadState = DownloadState.Unknown, - localModelDownloaded = false, + setState( + currentState.copy( + localModels = currentState.localModels.withNewState( + localModel.copy(selected = true), + ), + ), + ) + } + + fun localModelDownloadClickReducer(localModel: ServerSetupState.LocalModel) { + when { + localModel.downloadState is DownloadState.Downloading -> { + downloadDisposable?.dispose() + downloadDisposable = null + setState( + currentState.copy( + localModels = currentState.localModels.withNewState( + localModel.copy(downloadState = DownloadState.Unknown), + ), + ), ) - ) - !deleteModelUseCase() - .subscribeOnMainThread(schedulersProvider) - .subscribeBy(::errorLog) - } - else -> { - setState(currentState.copy(downloadState = DownloadState.Downloading())) - downloadDisposable?.dispose() - downloadDisposable = null - downloadDisposable = downloadModelUseCase() - .distinctUntilChanged() - .subscribeOnMainThread(schedulersProvider) - .subscribeBy( - onError = { t -> - val message = t.localizedMessage ?: "Error" - setState(currentState.copy(downloadState = DownloadState.Error(t))) - setScreenDialog(ServerSetupState.Dialog.Error(message.asUiText())) - }, - onNext = { downloadState -> - debugLog("DOWNLOAD STATE : $downloadState") - val newState = when (downloadState) { - is DownloadState.Complete -> currentState.copy( - downloadState = downloadState, - localModelDownloaded = true, - ) - else -> currentState.copy(downloadState = downloadState) - } - setState(newState) - }, + } + localModel.downloaded -> { + setState( + currentState.copy( + localModels = currentState.localModels.withNewState( + localModel.copy( + downloadState = DownloadState.Unknown, + downloaded = false, + ), + ), + ) + ) + !deleteModelUseCase(localModel.id) + .subscribeOnMainThread(schedulersProvider) + .subscribeBy(::errorLog) + } + else -> { + setState( + currentState.copy( + localModels = currentState.localModels.withNewState( + localModel.copy( + downloadState = DownloadState.Downloading(), + ), + ), + ), ) - .apply { addToDisposable() } + downloadDisposable?.dispose() + downloadDisposable = null + downloadDisposable = downloadModelUseCase(localModel.id) + .distinctUntilChanged() + .subscribeOnMainThread(schedulersProvider) + .subscribeBy( + onError = { t -> + val message = t.localizedMessage ?: "Error" + setState( + currentState.copy( + localModels = currentState.localModels.withNewState( + localModel.copy( + downloadState = DownloadState.Error(t), + ), + ), + ), + ) + setScreenDialog(ServerSetupState.Dialog.Error(message.asUiText())) + }, + onNext = { downloadState -> + debugLog("DOWNLOAD STATE : $downloadState") + val newState = when (downloadState) { + is DownloadState.Complete -> currentState.copy( + localModels = currentState.localModels.withNewState( + localModel.copy( + downloadState = downloadState, + downloaded = true, + ), + ), + ) + else -> currentState.copy( + localModels = currentState.localModels.withNewState( + localModel.copy(downloadState = downloadState), + ), + ) + } + setState(newState) + }, + ) + .apply { addToDisposable() } + } } } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/LocalModelMappers.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/LocalModelMappers.kt new file mode 100644 index 00000000..83207985 --- /dev/null +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/LocalModelMappers.kt @@ -0,0 +1,27 @@ +package com.shifthackz.aisdv1.presentation.screen.setup.mappers + +import com.shifthackz.aisdv1.domain.entity.LocalAiModel +import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupState + +fun List.mapToUi(): List = map(LocalAiModel::mapToUi) + +fun LocalAiModel.mapToUi(): ServerSetupState.LocalModel = with(this) { + ServerSetupState.LocalModel( + id = id, + name = name, + size = size, + downloaded = downloaded, + selected = selected, + ) +} + +fun List.withNewState( + model: ServerSetupState.LocalModel, +): List = + map { + if (it.id == model.id) model + else { + if (model.selected) it.copy(selected = false) + else it + } + } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/item/LocalModelItem.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/item/LocalModelItem.kt new file mode 100644 index 00000000..b0f44ae5 --- /dev/null +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/item/LocalModelItem.kt @@ -0,0 +1,120 @@ +package com.shifthackz.aisdv1.presentation.widget.item + +import androidx.compose.foundation.background +import androidx.compose.foundation.border +import androidx.compose.foundation.clickable +import androidx.compose.foundation.layout.Arrangement +import androidx.compose.foundation.layout.Column +import androidx.compose.foundation.layout.Row +import androidx.compose.foundation.layout.Spacer +import androidx.compose.foundation.layout.defaultMinSize +import androidx.compose.foundation.layout.fillMaxWidth +import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.layout.size +import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.outlined.FileDownload +import androidx.compose.material.icons.outlined.FileDownloadDone +import androidx.compose.material.icons.outlined.FileDownloadOff +import androidx.compose.material3.Button +import androidx.compose.material3.Icon +import androidx.compose.material3.LinearProgressIndicator +import androidx.compose.material3.LocalContentColor +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.Text +import androidx.compose.runtime.Composable +import androidx.compose.ui.Modifier +import androidx.compose.ui.draw.clip +import androidx.compose.ui.graphics.Color +import androidx.compose.ui.res.stringResource +import androidx.compose.ui.text.capitalize +import androidx.compose.ui.text.intl.Locale +import androidx.compose.ui.unit.dp +import com.shifthackz.aisdv1.domain.entity.DownloadState +import com.shifthackz.aisdv1.presentation.R +import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupState + +@Composable +fun LocalModelItem( + modifier: Modifier = Modifier, + model: ServerSetupState.LocalModel, + onDownloadCardButtonClick: (ServerSetupState.LocalModel) -> Unit = {}, + onSelect: (ServerSetupState.LocalModel) -> Unit = {} +) { + Column( + modifier = modifier + .padding(vertical = 8.dp) + .fillMaxWidth() + .clip(RoundedCornerShape(16.dp)) + .background(color = MaterialTheme.colorScheme.surfaceTint.copy(alpha = 0.8f)) + .defaultMinSize(minHeight = 50.dp) + .border( + width = 2.dp, + shape = RoundedCornerShape(16.dp), + color = if (model.selected) MaterialTheme.colorScheme.primary else Color.Transparent, + ) + .clickable { onSelect(model) }, + ) { + Row( + modifier = Modifier.padding(vertical = 4.dp), + horizontalArrangement = Arrangement.Center, + ) { + val icon = when (model.downloadState) { + is DownloadState.Downloading -> Icons.Outlined.FileDownload + else -> { + if (model.downloaded) Icons.Outlined.FileDownloadDone + else Icons.Outlined.FileDownloadOff + } + } + Icon( + modifier = modifier + .padding(horizontal = 8.dp) + .size(48.dp), + imageVector = icon, + contentDescription = "Download state", + ) + Column( + modifier = Modifier.padding(start = 4.dp) + ) { + Text(text = model.name) + Text(model.size) + } + Spacer(modifier = Modifier.weight(1f)) + Button( + modifier = Modifier.padding(end = 8.dp), + onClick = { onDownloadCardButtonClick(model) }, + ) { + Text( + text = stringResource(id = when (model.downloadState) { + is DownloadState.Downloading -> R.string.cancel + is DownloadState.Error -> R.string.retry + else -> { + if (model.downloaded) R.string.delete + else R.string.download + } + }), + color = LocalContentColor.current, + ) + } + } + when (model.downloadState) { + is DownloadState.Downloading -> { + LinearProgressIndicator( + modifier = Modifier + .padding(8.dp) + .fillMaxWidth(), + progress = model.downloadState.percent / 100f, + ) + } + is DownloadState.Error -> { + Text( + modifier = Modifier + .padding(horizontal = 8.dp) + .padding(bottom = 8.dp), + text = stringResource(id = R.string.error_download_fail), + ) + } + else -> Unit + } + } +} diff --git a/storage/schemas/com.shifthackz.aisdv1.storage.db.persistent.PersistentDatabase/3.json b/storage/schemas/com.shifthackz.aisdv1.storage.db.persistent.PersistentDatabase/3.json new file mode 100644 index 00000000..9760464a --- /dev/null +++ b/storage/schemas/com.shifthackz.aisdv1.storage.db.persistent.PersistentDatabase/3.json @@ -0,0 +1,171 @@ +{ + "formatVersion": 1, + "database": { + "version": 3, + "identityHash": "da19714c55ccc83480117f4fb1825f56", + "entities": [ + { + "tableName": "generation_results", + "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `image_base_64` TEXT NOT NULL, `original_image_base_64` TEXT NOT NULL, `created_at` INTEGER NOT NULL, `generation_type` TEXT NOT NULL, `prompt` TEXT NOT NULL, `negative_prompt` TEXT NOT NULL, `width` INTEGER NOT NULL, `height` INTEGER NOT NULL, `sampling_steps` INTEGER NOT NULL, `cfg_scale` REAL NOT NULL, `restore_faces` INTEGER NOT NULL, `sampler` TEXT NOT NULL, `seed` TEXT NOT NULL, `sub_seed` TEXT NOT NULL DEFAULT '', `sub_seed_strength` REAL NOT NULL DEFAULT 0.0, `denoising_strength` REAL NOT NULL DEFAULT 0.0)", + "fields": [ + { + "fieldPath": "id", + "columnName": "id", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "imageBase64", + "columnName": "image_base_64", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "originalImageBase64", + "columnName": "original_image_base_64", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "createdAt", + "columnName": "created_at", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "generationType", + "columnName": "generation_type", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "prompt", + "columnName": "prompt", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "negativePrompt", + "columnName": "negative_prompt", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "width", + "columnName": "width", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "height", + "columnName": "height", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "samplingSteps", + "columnName": "sampling_steps", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "cfgScale", + "columnName": "cfg_scale", + "affinity": "REAL", + "notNull": true + }, + { + "fieldPath": "restoreFaces", + "columnName": "restore_faces", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "sampler", + "columnName": "sampler", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "seed", + "columnName": "seed", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "subSeed", + "columnName": "sub_seed", + "affinity": "TEXT", + "notNull": true, + "defaultValue": "''" + }, + { + "fieldPath": "subSeedStrength", + "columnName": "sub_seed_strength", + "affinity": "REAL", + "notNull": true, + "defaultValue": "0.0" + }, + { + "fieldPath": "denoisingStrength", + "columnName": "denoising_strength", + "affinity": "REAL", + "notNull": true, + "defaultValue": "0.0" + } + ], + "primaryKey": { + "autoGenerate": true, + "columnNames": [ + "id" + ] + }, + "indices": [], + "foreignKeys": [] + }, + { + "tableName": "local_models", + "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` TEXT NOT NULL, `name` TEXT NOT NULL, `size` TEXT NOT NULL, `sources` TEXT NOT NULL, PRIMARY KEY(`id`))", + "fields": [ + { + "fieldPath": "id", + "columnName": "id", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "name", + "columnName": "name", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "size", + "columnName": "size", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "sources", + "columnName": "sources", + "affinity": "TEXT", + "notNull": true + } + ], + "primaryKey": { + "autoGenerate": false, + "columnNames": [ + "id" + ] + }, + "indices": [], + "foreignKeys": [] + } + ], + "views": [], + "setupQueries": [ + "CREATE TABLE IF NOT EXISTS room_master_table (id INTEGER PRIMARY KEY,identity_hash TEXT)", + "INSERT OR REPLACE INTO room_master_table (id,identity_hash) VALUES(42, 'da19714c55ccc83480117f4fb1825f56')" + ] + } +} \ No newline at end of file diff --git a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/PersistentDatabase.kt b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/PersistentDatabase.kt index 701a9ca2..8452b244 100644 --- a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/PersistentDatabase.kt +++ b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/PersistentDatabase.kt @@ -5,16 +5,21 @@ import androidx.room.Database import androidx.room.RoomDatabase import androidx.room.TypeConverters import com.shifthackz.aisdv1.storage.converters.DateConverters +import com.shifthackz.aisdv1.storage.converters.ListConverters +import com.shifthackz.aisdv1.storage.converters.MapConverters import com.shifthackz.aisdv1.storage.db.persistent.PersistentDatabase.Companion.DB_VERSION import com.shifthackz.aisdv1.storage.db.persistent.contract.GenerationResultContract import com.shifthackz.aisdv1.storage.db.persistent.dao.GenerationResultDao +import com.shifthackz.aisdv1.storage.db.persistent.dao.LocalModelDao import com.shifthackz.aisdv1.storage.db.persistent.entity.GenerationResultEntity +import com.shifthackz.aisdv1.storage.db.persistent.entity.LocalModelEntity @Database( version = DB_VERSION, exportSchema = true, entities = [ GenerationResultEntity::class, + LocalModelEntity::class, ], autoMigrations = [ /** @@ -24,14 +29,22 @@ import com.shifthackz.aisdv1.storage.db.persistent.entity.GenerationResultEntity * - [GenerationResultContract.DENOISING_STRENGTH] */ AutoMigration(from = 1, to = 2), + /** + * Added [LocalModelEntity]. + */ + AutoMigration(from = 2, to = 3), ], ) -@TypeConverters(DateConverters::class) +@TypeConverters( + DateConverters::class, + ListConverters::class, +) internal abstract class PersistentDatabase : RoomDatabase() { abstract fun generationResultDao(): GenerationResultDao + abstract fun localModelDao(): LocalModelDao companion object { const val DB_NAME = "ai_sd_v1_storage_db" - const val DB_VERSION = 2 + const val DB_VERSION = 3 } } diff --git a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/contract/LocalModelContract.kt b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/contract/LocalModelContract.kt new file mode 100644 index 00000000..03b3c03a --- /dev/null +++ b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/contract/LocalModelContract.kt @@ -0,0 +1,10 @@ +package com.shifthackz.aisdv1.storage.db.persistent.contract + +object LocalModelContract { + const val TABLE = "local_models" + + const val ID = "id" + const val NAME = "name" + const val SIZE = "size" + const val SOURCES = "sources" +} diff --git a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/dao/LocalModelDao.kt b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/dao/LocalModelDao.kt new file mode 100644 index 00000000..9336a544 --- /dev/null +++ b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/dao/LocalModelDao.kt @@ -0,0 +1,26 @@ +package com.shifthackz.aisdv1.storage.db.persistent.dao + +import androidx.room.Dao +import androidx.room.Insert +import androidx.room.OnConflictStrategy +import androidx.room.Query +import com.shifthackz.aisdv1.storage.db.persistent.contract.LocalModelContract +import com.shifthackz.aisdv1.storage.db.persistent.entity.LocalModelEntity +import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Single + +@Dao +interface LocalModelDao { + + @Query("SELECT * FROM ${LocalModelContract.TABLE}") + fun query(): Single> + + @Query("SELECT * FROM ${LocalModelContract.TABLE} WHERE ${LocalModelContract.ID} = :id LIMIT 1") + fun queryById(id: String): Single + + @Insert(onConflict = OnConflictStrategy.REPLACE) + fun insert(item: LocalModelEntity): Completable + + @Insert(onConflict = OnConflictStrategy.REPLACE) + fun insertList(items: List): Completable +} diff --git a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/entity/LocalModelEntity.kt b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/entity/LocalModelEntity.kt new file mode 100644 index 00000000..ab896641 --- /dev/null +++ b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/entity/LocalModelEntity.kt @@ -0,0 +1,19 @@ +package com.shifthackz.aisdv1.storage.db.persistent.entity + +import androidx.room.ColumnInfo +import androidx.room.Entity +import androidx.room.PrimaryKey +import com.shifthackz.aisdv1.storage.db.persistent.contract.LocalModelContract + +@Entity(tableName = LocalModelContract.TABLE) +data class LocalModelEntity( + @PrimaryKey(autoGenerate = false) + @ColumnInfo(name = LocalModelContract.ID) + val id: String, + @ColumnInfo(name = LocalModelContract.NAME) + val name: String, + @ColumnInfo(name = LocalModelContract.SIZE) + val size: String, + @ColumnInfo(name = LocalModelContract.SOURCES) + val sources: List, +) diff --git a/storage/src/main/java/com/shifthackz/aisdv1/storage/di/DatabaseModule.kt b/storage/src/main/java/com/shifthackz/aisdv1/storage/di/DatabaseModule.kt index 02fbc9a7..4727dc15 100755 --- a/storage/src/main/java/com/shifthackz/aisdv1/storage/di/DatabaseModule.kt +++ b/storage/src/main/java/com/shifthackz/aisdv1/storage/di/DatabaseModule.kt @@ -44,5 +44,6 @@ val databaseModule = module { //region PERSISTENT DB DAOs single { get().generationResultDao() } + single { get().localModelDao() } //endregion } From 16b665381c68070ff9c5b804a088bd0bfeccbdd6 Mon Sep 17 00:00:00 2001 From: ShiftHackZ Date: Sun, 31 Dec 2023 15:05:09 +0200 Subject: [PATCH 2/2] Multiple LD Models | Patch 2 --- app/build.gradle | 7 +- app/src/foss/AndroidManifest.xml | 27 ++++ app/src/main/AndroidManifest.xml | 30 +--- .../aisdv1/app/di/ProvidersModule.kt | 4 +- app/src/playstore/AndroidManifest.xml | 23 +++ .../core/common/appbuild/BuildInfoProvider.kt | 10 ++ .../aisdv1/core/common/appbuild/BuildType.kt | 13 ++ .../local/DownloadableModelLocalDataSource.kt | 36 ++++- .../aisdv1/domain/di/DomainModule.kt | 3 - .../aisdv1/domain/entity/LocalAiModel.kt | 11 +- .../downloadable/SelectLocalAiModelUseCase.kt | 7 - .../SelectLocalAiModelUseCaseImpl.kt | 10 -- .../ai/tokenizer/EnglishTextTokenizer.kt | 7 +- .../aisdv1/feature/diffusion/ai/unet/UNet.kt | 3 +- .../feature/diffusion/ai/vae/VaeDecoder.kt | 3 +- .../extensions/LocalDiffusionPaths.kt | 19 +++ .../activity/AiStableDiffusionActivity.kt | 16 ++ .../aisdv1/presentation/di/ViewModelModule.kt | 1 - .../screen/setup/ServerSetupContract.kt | 3 +- .../screen/setup/ServerSetupScreen.kt | 66 ++++++++- .../screen/setup/ServerSetupViewModel.kt | 17 ++- .../screen/setup/mappers/LocalModelMappers.kt | 3 + .../widget/item/LocalModelItem.kt | 138 ++++++++++++++++-- presentation/src/main/res/values/strings.xml | 8 + 24 files changed, 373 insertions(+), 92 deletions(-) create mode 100755 app/src/foss/AndroidManifest.xml create mode 100755 app/src/playstore/AndroidManifest.xml create mode 100644 core/common/src/main/java/com/shifthackz/aisdv1/core/common/appbuild/BuildType.kt delete mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/SelectLocalAiModelUseCase.kt delete mode 100644 domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/SelectLocalAiModelUseCaseImpl.kt create mode 100644 feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/extensions/LocalDiffusionPaths.kt diff --git a/app/build.gradle b/app/build.gradle index 2c179864..9b47473f 100755 --- a/app/build.gradle +++ b/app/build.gradle @@ -48,9 +48,14 @@ android { foss { dimension "type" applicationIdSuffix = ".foss" - resValue "string", "app_name", "SDAI" + resValue "string", "app_name", "SDAI FOSS" buildConfigField "String", "BUILD_FLAVOR_TYPE", "\"FOSS\"" } + playstore { + dimension "type" + resValue "string", "app_name", "SDAI" + buildConfigField "String", "BUILD_FLAVOR_TYPE", "\"GOOGLE_PLAY\"" + } } } diff --git a/app/src/foss/AndroidManifest.xml b/app/src/foss/AndroidManifest.xml new file mode 100755 index 00000000..ff165861 --- /dev/null +++ b/app/src/foss/AndroidManifest.xml @@ -0,0 +1,27 @@ + + + + + + + + + + + + + + + + + diff --git a/app/src/main/AndroidManifest.xml b/app/src/main/AndroidManifest.xml index d26dbbf3..ab1f079a 100755 --- a/app/src/main/AndroidManifest.xml +++ b/app/src/main/AndroidManifest.xml @@ -2,24 +2,6 @@ - - - - - - - - - - - + android:windowSoftInputMode="adjustResize" + tools:ignore="LockedOrientationActivity"> @@ -50,14 +33,5 @@ android:name="android.support.FILE_PROVIDER_PATHS" android:resource="@xml/file_provider_paths" /> - - - - - diff --git a/app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt b/app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt index 29af6b38..0239dced 100755 --- a/app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt +++ b/app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt @@ -2,6 +2,7 @@ package com.shifthackz.aisdv1.app.di import com.shifthackz.aisdv1.app.BuildConfig import com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider +import com.shifthackz.aisdv1.core.common.appbuild.BuildType import com.shifthackz.aisdv1.core.common.appbuild.BuildVersion import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor import com.shifthackz.aisdv1.core.common.links.LinksProvider @@ -74,12 +75,13 @@ val providersModule = module { override val isDebug: Boolean = BuildConfig.DEBUG override val buildNumber: Int = BuildConfig.VERSION_CODE override val version: BuildVersion = BuildVersion(BuildConfig.VERSION_NAME) + override val type: BuildType = BuildType.fromBuildConfig(BuildConfig.BUILD_FLAVOR_TYPE) override fun toString(): String = buildString { append("$version") if (BuildConfig.DEBUG) append("-dev") append(" ($buildNumber)") - append(" FOSS") + if (type == BuildType.FOSS) append(" FOSS") } } } diff --git a/app/src/playstore/AndroidManifest.xml b/app/src/playstore/AndroidManifest.xml new file mode 100755 index 00000000..c6895945 --- /dev/null +++ b/app/src/playstore/AndroidManifest.xml @@ -0,0 +1,23 @@ + + + + + + + + + + + + + + + diff --git a/core/common/src/main/java/com/shifthackz/aisdv1/core/common/appbuild/BuildInfoProvider.kt b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/appbuild/BuildInfoProvider.kt index ad85ba0f..cc97b370 100644 --- a/core/common/src/main/java/com/shifthackz/aisdv1/core/common/appbuild/BuildInfoProvider.kt +++ b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/appbuild/BuildInfoProvider.kt @@ -4,4 +4,14 @@ interface BuildInfoProvider { val isDebug: Boolean val buildNumber: Int val version: BuildVersion + val type: BuildType + + companion object { + val stub = object : BuildInfoProvider { + override val isDebug: Boolean = true + override val buildNumber: Int = 0 + override val version: BuildVersion = BuildVersion() + override val type: BuildType = BuildType.FOSS + } + } } diff --git a/core/common/src/main/java/com/shifthackz/aisdv1/core/common/appbuild/BuildType.kt b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/appbuild/BuildType.kt new file mode 100644 index 00000000..d431d0f8 --- /dev/null +++ b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/appbuild/BuildType.kt @@ -0,0 +1,13 @@ +package com.shifthackz.aisdv1.core.common.appbuild + +enum class BuildType { + FOSS, + PLAY; + + companion object { + fun fromBuildConfig(input: String) = when (input) { + "FOSS" -> FOSS + else -> PLAY + } + } +} diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt index 0ba9fdfc..9fdedf85 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt @@ -1,5 +1,7 @@ package com.shifthackz.aisdv1.data.local +import com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider +import com.shifthackz.aisdv1.core.common.appbuild.BuildType import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor import com.shifthackz.aisdv1.data.mappers.mapDomainToEntity import com.shifthackz.aisdv1.data.mappers.mapEntityToDomain @@ -17,14 +19,26 @@ internal class DownloadableModelLocalDataSource( private val fileProviderDescriptor: FileProviderDescriptor, private val dao: LocalModelDao, private val preferenceManager: PreferenceManager, + private val buildInfoProvider: BuildInfoProvider, ) : DownloadableModelDataSource.Local { override fun getAll(): Single> = dao.query() .map(List::mapEntityToDomain) + .map { models -> + buildList { + addAll(models) + if (buildInfoProvider.type == BuildType.FOSS) add(LocalAiModel.CUSTOM) + } + } .flatMap { models -> models.withLocalData() } - override fun getById(id: String) = dao.queryById(id) - .map(LocalModelEntity::mapEntityToDomain) - .flatMap { model -> model.withLocalData() } + override fun getById(id: String): Single { + val chain = if (id == LocalAiModel.CUSTOM.id) Single.just(LocalAiModel.CUSTOM) + else dao + .queryById(id) + .map(LocalModelEntity::mapEntityToDomain) + + return chain.flatMap { model -> model.withLocalData() } + } override fun getSelected(): Single = Single .just(preferenceManager.localModelId) @@ -35,13 +49,21 @@ internal class DownloadableModelLocalDataSource( preferenceManager.localModelId = id } - override fun save(list: List) = dao.insertList(list.mapDomainToEntity()) + override fun save(list: List) = list + .filter { it.id != LocalAiModel.CUSTOM.id } + .mapDomainToEntity() + .let(dao::insertList) override fun isDownloaded(id: String): Single = Single.create { emitter -> try { - val localModelDir = getLocalModelDirectory(id) - val files = (localModelDir.listFiles()?.filter { it.isDirectory }) ?: emptyList() - if (!emitter.isDisposed) emitter.onSuccess(localModelDir.exists() && files.size == 4) + if (id == LocalAiModel.CUSTOM.id) { + if (!emitter.isDisposed) emitter.onSuccess(true) + } else { + val localModelDir = getLocalModelDirectory(id) + val files = + (localModelDir.listFiles()?.filter { it.isDirectory }) ?: emptyList() + if (!emitter.isDisposed) emitter.onSuccess(localModelDir.exists() && files.size == 4) + } } catch (e: Exception) { if (!emitter.isDisposed) emitter.onSuccess(false) } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt index 96fe7d18..d653e22a 100755 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt @@ -24,8 +24,6 @@ import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCase import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalAiModelsUseCase import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalAiModelsUseCaseImpl -import com.shifthackz.aisdv1.domain.usecase.downloadable.SelectLocalAiModelUseCase -import com.shifthackz.aisdv1.domain.usecase.downloadable.SelectLocalAiModelUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.gallery.DeleteGalleryItemUseCase import com.shifthackz.aisdv1.domain.usecase.gallery.DeleteGalleryItemUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.gallery.GetAllGalleryUseCase @@ -93,7 +91,6 @@ internal val useCasesModule = module { factoryOf(::GetLocalAiModelsUseCaseImpl) bind GetLocalAiModelsUseCase::class factoryOf(::DownloadModelUseCaseImpl) bind DownloadModelUseCase::class factoryOf(::DeleteModelUseCaseImpl) bind DeleteModelUseCase::class - factoryOf(::SelectLocalAiModelUseCaseImpl) bind SelectLocalAiModelUseCase::class } internal val debugModule = module { diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalAiModel.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalAiModel.kt index 22c0e68a..734ad5ed 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalAiModel.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalAiModel.kt @@ -7,4 +7,13 @@ data class LocalAiModel( val sources: List, val downloaded: Boolean = false, val selected: Boolean = false, -) +) { + companion object { + val CUSTOM = LocalAiModel( + id = "CUSTOM", + name = "Custom", + size = "NaN", + sources = emptyList(), + ) + } +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/SelectLocalAiModelUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/SelectLocalAiModelUseCase.kt deleted file mode 100644 index 96a3dbbc..00000000 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/SelectLocalAiModelUseCase.kt +++ /dev/null @@ -1,7 +0,0 @@ -package com.shifthackz.aisdv1.domain.usecase.downloadable - -import io.reactivex.rxjava3.core.Completable - -interface SelectLocalAiModelUseCase { - operator fun invoke(id: String): Completable -} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/SelectLocalAiModelUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/SelectLocalAiModelUseCaseImpl.kt deleted file mode 100644 index 86b2b994..00000000 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/SelectLocalAiModelUseCaseImpl.kt +++ /dev/null @@ -1,10 +0,0 @@ -package com.shifthackz.aisdv1.domain.usecase.downloadable - -import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository - -internal class SelectLocalAiModelUseCaseImpl( - private val downloadableModelRepository: DownloadableModelRepository, -) : SelectLocalAiModelUseCase { - - override fun invoke(id: String) = downloadableModelRepository.select(id) -} diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/tokenizer/EnglishTextTokenizer.kt b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/tokenizer/EnglishTextTokenizer.kt index c8ddc672..9cd3c615 100644 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/tokenizer/EnglishTextTokenizer.kt +++ b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/tokenizer/EnglishTextTokenizer.kt @@ -17,6 +17,7 @@ import com.shifthackz.aisdv1.feature.diffusion.ai.extensions.halfCorner import com.shifthackz.aisdv1.feature.diffusion.ai.extensions.toArrays import com.shifthackz.aisdv1.feature.diffusion.environment.LocalModelIdProvider import com.shifthackz.aisdv1.feature.diffusion.environment.OrtEnvironmentProvider +import com.shifthackz.aisdv1.feature.diffusion.extensions.modelPathPrefix import java.io.BufferedReader import java.io.FileInputStream import java.io.InputStreamReader @@ -47,7 +48,7 @@ internal class EnglishTextTokenizer( val options = OrtSession.SessionOptions() options.addConfigEntry(ORT_KEY_MODEL_FORMAT, ORT) session = ortEnvironmentProvider.get().createSession( - "${fileProviderDescriptor.localModelDirPath}/${localModelIdProvider.get()}/${LocalDiffusionContract.TOKENIZER_MODEL}", + "${modelPathPrefix(fileProviderDescriptor, localModelIdProvider)}/${LocalDiffusionContract.TOKENIZER_MODEL}", options ) if (!isInitMap) { @@ -205,7 +206,7 @@ internal class EnglishTextTokenizer( private fun loadEncoder(): Map { val map: MutableMap = HashMap() try { - val path = "${fileProviderDescriptor.localModelDirPath}/${localModelIdProvider.get()}/${LocalDiffusionContract.TOKENIZER_VOCABULARY}" + val path = "${modelPathPrefix(fileProviderDescriptor, localModelIdProvider)}/${LocalDiffusionContract.TOKENIZER_VOCABULARY}" val jsonReader = JsonReader(InputStreamReader(FileInputStream(path))) jsonReader.beginObject() while (jsonReader.hasNext()) { @@ -231,7 +232,7 @@ internal class EnglishTextTokenizer( private fun loadBpeRanks(): Map, Int?> { val result: MutableMap, Int?> = HashMap() try { - val path = "${fileProviderDescriptor.localModelDirPath}/${localModelIdProvider.get()}/${LocalDiffusionContract.TOKENIZER_MERGES}" + val path = "${modelPathPrefix(fileProviderDescriptor, localModelIdProvider)}/${LocalDiffusionContract.TOKENIZER_MERGES}" val reader = BufferedReader(InputStreamReader(FileInputStream(path))) var line: String var startLine = 1 diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/unet/UNet.kt b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/unet/UNet.kt index 43d311a4..297a4ca9 100644 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/unet/UNet.kt +++ b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/unet/UNet.kt @@ -27,6 +27,7 @@ import com.shifthackz.aisdv1.feature.diffusion.environment.OrtEnvironmentProvide import com.shifthackz.aisdv1.feature.diffusion.entity.LocalDiffusionFlag import com.shifthackz.aisdv1.feature.diffusion.environment.DeviceNNAPIFlagProvider import com.shifthackz.aisdv1.feature.diffusion.environment.LocalModelIdProvider +import com.shifthackz.aisdv1.feature.diffusion.extensions.modelPathPrefix import java.nio.IntBuffer import java.util.EnumSet import java.util.Random @@ -65,7 +66,7 @@ internal class UNet( options.addNnapi(EnumSet.of(NNAPIFlags.CPU_DISABLED)) } session = ortEnvironmentProvider.get().createSession( - "${fileProviderDescriptor.localModelDirPath}/${localModelIdProvider.get()}/${LocalDiffusionContract.UNET_MODEL}", + "${modelPathPrefix(fileProviderDescriptor, localModelIdProvider)}/${LocalDiffusionContract.UNET_MODEL}", options ) } diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/vae/VaeDecoder.kt b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/vae/VaeDecoder.kt index 82624134..72726d0f 100644 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/vae/VaeDecoder.kt +++ b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/vae/VaeDecoder.kt @@ -13,6 +13,7 @@ import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.ORT_KEY_MO import com.shifthackz.aisdv1.feature.diffusion.environment.OrtEnvironmentProvider import com.shifthackz.aisdv1.feature.diffusion.entity.LocalDiffusionFlag import com.shifthackz.aisdv1.feature.diffusion.environment.LocalModelIdProvider +import com.shifthackz.aisdv1.feature.diffusion.extensions.modelPathPrefix import java.util.EnumSet import kotlin.math.roundToInt @@ -65,7 +66,7 @@ internal class VaeDecoder( options.addNnapi(EnumSet.of(NNAPIFlags.CPU_DISABLED)) } session = ortEnvironmentProvider.get().createSession( - "${fileProviderDescriptor.localModelDirPath}/${localModelIdProvider.get()}/${LocalDiffusionContract.VAE_MODEL}", + "${modelPathPrefix(fileProviderDescriptor, localModelIdProvider)}/${LocalDiffusionContract.VAE_MODEL}", options ) } diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/extensions/LocalDiffusionPaths.kt b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/extensions/LocalDiffusionPaths.kt new file mode 100644 index 00000000..d4faaddd --- /dev/null +++ b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/extensions/LocalDiffusionPaths.kt @@ -0,0 +1,19 @@ +package com.shifthackz.aisdv1.feature.diffusion.extensions + +import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor +import com.shifthackz.aisdv1.domain.entity.LocalAiModel +import com.shifthackz.aisdv1.feature.diffusion.environment.LocalModelIdProvider + +private const val PATH = "/storage/emulated/0/Download/SDAI/model" + +fun modelPathPrefix( + fileProviderDescriptor: FileProviderDescriptor, + localModelIdProvider: LocalModelIdProvider, +): String { + val modelId = localModelIdProvider.get(); + return if (modelId == LocalAiModel.CUSTOM.id) { + PATH + } else { + "${fileProviderDescriptor.localModelDirPath}/${modelId}" + } +} diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/activity/AiStableDiffusionActivity.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/activity/AiStableDiffusionActivity.kt index 9e34e84c..0e11a0dd 100755 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/activity/AiStableDiffusionActivity.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/activity/AiStableDiffusionActivity.kt @@ -1,9 +1,12 @@ package com.shifthackz.aisdv1.presentation.activity import android.Manifest +import android.content.Intent import android.content.pm.PackageManager import android.os.Build import android.os.Bundle +import android.provider.Settings.ACTION_MANAGE_ALL_FILES_ACCESS_PERMISSION +import android.widget.Toast import androidx.activity.ComponentActivity import androidx.activity.compose.setContent import androidx.activity.result.contract.ActivityResultContracts @@ -131,6 +134,7 @@ class AiStableDiffusionActivity : ComponentActivity(), ImagePickerFeature, FileS } }, launchUrl = ::openUrl, + launchManageStoragePermission = ::setupManageStoragePermission, ).Build() } @@ -230,6 +234,18 @@ class AiStableDiffusionActivity : ComponentActivity(), ImagePickerFeature, FileS } } + private fun setupManageStoragePermission() { + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.R) { + val intent = Intent(ACTION_MANAGE_ALL_FILES_ACCESS_PERMISSION) + startActivity(intent) + } else { + val hasPermission = requestStoragePermission() + if (hasPermission) { + Toast.makeText(this, "Already granted", Toast.LENGTH_LONG).show() + } + } + } + private fun requestNotificationPermission() { if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU && ActivityCompat.checkSelfPermission( this, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/di/ViewModelModule.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/di/ViewModelModule.kt index 2392fe0a..5e5e2c58 100755 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/di/ViewModelModule.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/di/ViewModelModule.kt @@ -44,7 +44,6 @@ val viewModelModule = module { testConnectivityUseCase = get(), testHordeApiKeyUseCase = get(), setServerConfigurationUseCase = get(), - selectLocalAiModelUseCase = get(), downloadModelUseCase = get(), deleteModelUseCase = get(), getLocalAiModelsUseCase = get(), diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupContract.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupContract.kt index 5fa9c274..6afacd75 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupContract.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupContract.kt @@ -33,8 +33,7 @@ data class ServerSetupState( val password: String = "", val originalPassword: String = "", val localModels: List = emptyList(), -// val localModelDownloaded: Boolean = false, -// val downloadState: DownloadState = DownloadState.Unknown, + val localCustomModel: Boolean = false, val passwordVisible: Boolean = false, val serverUrlValidationError: UiText? = null, val loginValidationError: UiText? = null, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreen.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreen.kt index 7a30645a..2eecaa5f 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreen.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreen.kt @@ -36,6 +36,7 @@ import androidx.compose.material3.Icon import androidx.compose.material3.IconButton import androidx.compose.material3.LocalContentColor import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.OutlinedButton import androidx.compose.material3.Scaffold import androidx.compose.material3.Switch import androidx.compose.material3.Text @@ -54,10 +55,13 @@ import androidx.compose.ui.tooling.preview.Preview import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.sp import androidx.lifecycle.compose.collectAsStateWithLifecycle +import com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider +import com.shifthackz.aisdv1.core.common.appbuild.BuildType import com.shifthackz.aisdv1.core.common.links.LinksProvider import com.shifthackz.aisdv1.core.model.asString import com.shifthackz.aisdv1.core.model.asUiText import com.shifthackz.aisdv1.core.ui.MviScreen +import com.shifthackz.aisdv1.domain.entity.LocalAiModel import com.shifthackz.aisdv1.presentation.R import com.shifthackz.aisdv1.presentation.utils.Constants import com.shifthackz.aisdv1.presentation.widget.dialog.ErrorDialog @@ -73,17 +77,21 @@ class ServerSetupScreen( private val onNavigateBack: () -> Unit = {}, private val onServerSetupComplete: () -> Unit = {}, private val launchUrl: (String) -> Unit = {}, + private val launchManageStoragePermission: () -> Unit = {}, ) : MviScreen(viewModel), KoinComponent { private val linksProvider: LinksProvider by inject() + private val buildInfoProvider: BuildInfoProvider by inject() @Composable override fun Content() { ScreenContent( modifier = Modifier.fillMaxSize(), state = viewModel.state.collectAsStateWithLifecycle().value, + buildInfoProvider = buildInfoProvider, demoModeUrl = linksProvider.demoModeUrl, onNavigateBack = onNavigateBack, + launchManageStoragePermission = launchManageStoragePermission, onServerModeUpdated = viewModel::updateServerMode, onServerUrlUpdated = viewModel::updateServerUrl, onAuthTypeSelected = viewModel::updateAuthType, @@ -98,6 +106,7 @@ class ServerSetupScreen( onOpenHordeSignUpWebSite = { launchUrl(linksProvider.hordeSignUpUrl) }, onDownloadCardButtonClick = viewModel::localModelDownloadClickReducer, onSelectLocalModel = viewModel::localModelSelect, + onAllowLocalCustomModel = viewModel::updateAllowLocalCustomModel, onSetupButtonClick = viewModel::connectToServer, onDismissScreenDialog = viewModel::dismissScreenDialog, ) @@ -112,8 +121,10 @@ class ServerSetupScreen( private fun ScreenContent( modifier: Modifier = Modifier, state: ServerSetupState, + buildInfoProvider: BuildInfoProvider = BuildInfoProvider.stub, demoModeUrl: String, onNavigateBack: () -> Unit = {}, + launchManageStoragePermission: () -> Unit = {}, onServerModeUpdated: (ServerSetupState.Mode) -> Unit = {}, onServerUrlUpdated: (String) -> Unit = {}, onAuthTypeSelected: (ServerSetupState.AuthType) -> Unit = {}, @@ -128,6 +139,7 @@ private fun ScreenContent( onOpenHordeSignUpWebSite: () -> Unit = {}, onDownloadCardButtonClick: (ServerSetupState.LocalModel) -> Unit = {}, onSelectLocalModel: (ServerSetupState.LocalModel) -> Unit = {}, + onAllowLocalCustomModel: (Boolean) -> Unit = {}, onSetupButtonClick: () -> Unit = {}, onDismissScreenDialog: () -> Unit = {}, ) { @@ -220,10 +232,14 @@ private fun ScreenContent( ) ServerSetupState.Mode.LOCAL -> LocalDiffusionSetupTab( state = state, + buildInfoProvider = buildInfoProvider, + launchManageStoragePermission = launchManageStoragePermission, onDownloadCardButtonClick = onDownloadCardButtonClick, onSelectLocalModel = onSelectLocalModel, + onAllowLocalCustomModel = onAllowLocalCustomModel, ) } + Spacer(modifier = Modifier.height(32.dp)) } }, ) @@ -451,8 +467,11 @@ private fun HordeAiSetupTab( private fun LocalDiffusionSetupTab( modifier: Modifier = Modifier, state: ServerSetupState, + buildInfoProvider: BuildInfoProvider = BuildInfoProvider.stub, + launchManageStoragePermission: () -> Unit = {}, onDownloadCardButtonClick: (ServerSetupState.LocalModel) -> Unit = {}, onSelectLocalModel: (ServerSetupState.LocalModel) -> Unit = {}, + onAllowLocalCustomModel: (Boolean) -> Unit = {}, ) { Column( modifier = modifier.padding(horizontal = 16.dp), @@ -471,13 +490,48 @@ private fun LocalDiffusionSetupTab( text = stringResource(id = R.string.hint_local_diffusion_sub_title), style = MaterialTheme.typography.bodyMedium, ) - state.localModels.forEach { localModel -> - LocalModelItem( - model = localModel, - onDownloadCardButtonClick = onDownloadCardButtonClick, - onSelect = onSelectLocalModel, + if (buildInfoProvider.type == BuildType.FOSS) { + Row( + verticalAlignment = Alignment.CenterVertically, + ) { + Switch( + checked = state.localCustomModel, + onCheckedChange = onAllowLocalCustomModel, + ) + Text( + modifier = Modifier.padding(start = 8.dp), + text = stringResource(id = R.string.model_local_custom_switch), + ) + } + } + if (state.localCustomModel && buildInfoProvider.type == BuildType.FOSS) { + Text( + modifier = Modifier.padding(vertical = 8.dp), + text = stringResource(id = R.string.model_local_permission_title), + style = MaterialTheme.typography.bodyMedium, ) + OutlinedButton( + modifier = Modifier.fillMaxSize().padding(vertical = 8.dp), + onClick = launchManageStoragePermission, + ) { + Text( + text = stringResource(id = R.string.model_local_permission_button), + color = LocalContentColor.current, + ) + } } + state.localModels + .filter { + val customPredicate = it.id == LocalAiModel.CUSTOM.id + if (state.localCustomModel) customPredicate else !customPredicate + } + .forEach { localModel -> + LocalModelItem( + model = localModel, + onDownloadCardButtonClick = onDownloadCardButtonClick, + onSelect = onSelectLocalModel, + ) + } Text( modifier = Modifier.padding(top = 16.dp), text = stringResource(id = R.string.hint_local_diffusion_warning), @@ -517,7 +571,6 @@ private fun ConfigurationModeButton( ServerSetupState.Mode.LOCAL -> Icons.Default.Android }, contentDescription = null, -// tint = MaterialTheme.colorScheme.onSecondaryContainer, ) Text( modifier = Modifier @@ -529,7 +582,6 @@ private fun ConfigurationModeButton( ServerSetupState.Mode.LOCAL -> R.string.srv_type_local }), fontSize = 14.sp, -// color = MaterialTheme.colorScheme.onSecondaryContainer, textAlign = TextAlign.Center, lineHeight = 15.sp, ) diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt index afa7cff9..bbdec680 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt @@ -11,6 +11,7 @@ import com.shifthackz.aisdv1.core.validation.url.UrlValidator import com.shifthackz.aisdv1.core.viewmodel.MviRxViewModel import com.shifthackz.aisdv1.domain.entity.Configuration import com.shifthackz.aisdv1.domain.entity.DownloadState +import com.shifthackz.aisdv1.domain.entity.LocalAiModel import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.feature.analytics.Analytics import com.shifthackz.aisdv1.domain.feature.auth.AuthorizationCredentials @@ -21,12 +22,12 @@ import com.shifthackz.aisdv1.domain.usecase.connectivity.TestHordeApiKeyUseCase import com.shifthackz.aisdv1.domain.usecase.downloadable.DeleteModelUseCase import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCase import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalAiModelsUseCase -import com.shifthackz.aisdv1.domain.usecase.downloadable.SelectLocalAiModelUseCase import com.shifthackz.aisdv1.domain.usecase.settings.GetConfigurationUseCase import com.shifthackz.aisdv1.domain.usecase.settings.SetServerConfigurationUseCase import com.shifthackz.aisdv1.presentation.features.SetupConnectEvent import com.shifthackz.aisdv1.presentation.features.SetupConnectFailure import com.shifthackz.aisdv1.presentation.features.SetupConnectSuccess +import com.shifthackz.aisdv1.presentation.screen.setup.mappers.mapLocalCustomModelSwitchState import com.shifthackz.aisdv1.presentation.screen.setup.mappers.mapToUi import com.shifthackz.aisdv1.presentation.screen.setup.mappers.withNewState import com.shifthackz.aisdv1.presentation.utils.Constants @@ -45,10 +46,8 @@ class ServerSetupViewModel( private val testConnectivityUseCase: TestConnectivityUseCase, private val testHordeApiKeyUseCase: TestHordeApiKeyUseCase, private val setServerConfigurationUseCase: SetServerConfigurationUseCase, - private val selectLocalAiModelUseCase: SelectLocalAiModelUseCase, private val downloadModelUseCase: DownloadModelUseCase, private val deleteModelUseCase: DeleteModelUseCase, -// private val checkDownloadedModelUseCase: CheckDownloadedModelUseCase, private val getLocalAiModelsUseCase: GetLocalAiModelsUseCase, private val dataPreLoaderUseCase: DataPreLoaderUseCase, private val schedulersProvider: SchedulersProvider, @@ -69,6 +68,7 @@ class ServerSetupViewModel( .subscribeBy(::errorLog) { (configuration, localModels) -> currentState .copy(localModels = localModels.mapToUi()) + .copy(localCustomModel = localModels.mapLocalCustomModelSwitchState()) .withSource(configuration.source) .withDemoMode(configuration.demoMode) .withServerUrl(configuration.serverUrl) @@ -115,6 +115,17 @@ class ServerSetupViewModel( .copy(hordeDefaultApiKey = value) .let(::setState) + fun updateAllowLocalCustomModel(value: Boolean) = currentState + .copy( + localCustomModel = value, + localModels = currentState.localModels.withNewState( + currentState.localModels.find { it.id == LocalAiModel.CUSTOM.id }!!.copy( + selected = value, + ), + ), + ) + .let(::setState) + fun connectToServer() { if (!validate()) return return when (currentState.mode) { diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/LocalModelMappers.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/LocalModelMappers.kt index 83207985..169624ac 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/LocalModelMappers.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/LocalModelMappers.kt @@ -5,6 +5,9 @@ import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupState fun List.mapToUi(): List = map(LocalAiModel::mapToUi) +fun List.mapLocalCustomModelSwitchState(): Boolean = + find { it.selected && it.id == LocalAiModel.CUSTOM.id } != null + fun LocalAiModel.mapToUi(): ServerSetupState.LocalModel = with(this) { ServerSetupState.LocalModel( id = id, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/item/LocalModelItem.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/item/LocalModelItem.kt index b0f44ae5..e7f2e507 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/item/LocalModelItem.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/item/LocalModelItem.kt @@ -9,6 +9,7 @@ import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.Spacer import androidx.compose.foundation.layout.defaultMinSize import androidx.compose.foundation.layout.fillMaxWidth +import androidx.compose.foundation.layout.height import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.size import androidx.compose.foundation.shape.RoundedCornerShape @@ -16,6 +17,7 @@ import androidx.compose.material.icons.Icons import androidx.compose.material.icons.outlined.FileDownload import androidx.compose.material.icons.outlined.FileDownloadDone import androidx.compose.material.icons.outlined.FileDownloadOff +import androidx.compose.material.icons.outlined.Landslide import androidx.compose.material3.Button import androidx.compose.material3.Icon import androidx.compose.material3.LinearProgressIndicator @@ -23,6 +25,7 @@ import androidx.compose.material3.LocalContentColor import androidx.compose.material3.MaterialTheme import androidx.compose.material3.Text import androidx.compose.runtime.Composable +import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier import androidx.compose.ui.draw.clip import androidx.compose.ui.graphics.Color @@ -30,7 +33,9 @@ import androidx.compose.ui.res.stringResource import androidx.compose.ui.text.capitalize import androidx.compose.ui.text.intl.Locale import androidx.compose.ui.unit.dp +import androidx.compose.ui.unit.times import com.shifthackz.aisdv1.domain.entity.DownloadState +import com.shifthackz.aisdv1.domain.entity.LocalAiModel import com.shifthackz.aisdv1.presentation.R import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupState @@ -58,12 +63,14 @@ fun LocalModelItem( Row( modifier = Modifier.padding(vertical = 4.dp), horizontalArrangement = Arrangement.Center, + verticalAlignment = Alignment.CenterVertically, ) { val icon = when (model.downloadState) { is DownloadState.Downloading -> Icons.Outlined.FileDownload - else -> { - if (model.downloaded) Icons.Outlined.FileDownloadDone - else Icons.Outlined.FileDownloadOff + else -> when { + model.id == LocalAiModel.CUSTOM.id -> Icons.Outlined.Landslide + model.downloaded -> Icons.Outlined.FileDownloadDone + else -> Icons.Outlined.FileDownloadOff } } Icon( @@ -77,23 +84,122 @@ fun LocalModelItem( modifier = Modifier.padding(start = 4.dp) ) { Text(text = model.name) - Text(model.size) + if (model.id != LocalAiModel.CUSTOM.id) { + Text(model.size) + } } Spacer(modifier = Modifier.weight(1f)) - Button( - modifier = Modifier.padding(end = 8.dp), - onClick = { onDownloadCardButtonClick(model) }, + if (model.id != LocalAiModel.CUSTOM.id) { + Button( + modifier = Modifier.padding(end = 8.dp), + onClick = { onDownloadCardButtonClick(model) }, + ) { + Text( + text = stringResource( + id = when (model.downloadState) { + is DownloadState.Downloading -> R.string.cancel + is DownloadState.Error -> R.string.retry + else -> { + if (model.downloaded) R.string.delete + else R.string.download + } + } + ), + color = LocalContentColor.current, + ) + } + } + } + if (model.id == LocalAiModel.CUSTOM.id) { + Column( + modifier = Modifier.padding(8.dp), ) { Text( - text = stringResource(id = when (model.downloadState) { - is DownloadState.Downloading -> R.string.cancel - is DownloadState.Error -> R.string.retry - else -> { - if (model.downloaded) R.string.delete - else R.string.download - } - }), - color = LocalContentColor.current, + text = stringResource(id = R.string.model_local_custom_title), + style = MaterialTheme.typography.bodyMedium, + ) + Spacer(modifier = Modifier.height(4.dp)) + Text( + text = stringResource(id = R.string.model_local_custom_sub_title), + style = MaterialTheme.typography.bodyMedium, + ) + Spacer(modifier = Modifier.height(4.dp)) + + fun folderModifier(treeNum: Int) = Modifier.padding(start = (treeNum - 1) * 12.dp) + val folderStyle = MaterialTheme.typography.bodySmall + Text( + modifier = folderModifier(1), + text = "Download", + style = folderStyle, + ) + Text( + modifier = folderModifier(2), + text = "SDAI", + style = folderStyle, + ) + Text( + modifier = folderModifier(3), + text = "model", + style = folderStyle, + ) + + Text( + modifier = folderModifier(4), + text = "text_encoder", + style = folderStyle, + ) + Text( + modifier = folderModifier(5), + text = "model.ort", + style = folderStyle, + ) + + Text( + modifier = folderModifier(4), + text = "tokenizer", + style = folderStyle, + ) + Text( + modifier = folderModifier(5), + text = "merges.txt", + style = folderStyle, + ) + Text( + modifier = folderModifier(5), + text = "special_tokens_map.json", + style = folderStyle, + ) + Text( + modifier = folderModifier(5), + text = "tokenizer_config.json", + style = folderStyle, + ) + Text( + modifier = folderModifier(5), + text = "tokenizer_config.json", + style = folderStyle, + ) + + Text( + modifier = folderModifier(4), + text = "unet", + style = folderStyle, + ) + Text( + modifier = folderModifier(5), + text = "model.ort", + style = folderStyle, + ) + + Text( + modifier = folderModifier(4), + text = "vae_decoder", + style = folderStyle, + ) + Text( + modifier = folderModifier(5), + text = "model.ort", + style = folderStyle, ) } } diff --git a/presentation/src/main/res/values/strings.xml b/presentation/src/main/res/values/strings.xml index 29eeebf0..f5287e7d 100755 --- a/presentation/src/main/res/values/strings.xml +++ b/presentation/src/main/res/values/strings.xml @@ -201,6 +201,14 @@ Local Diffusion ≈ 1.2 Gb + Load custom model + To be able to load custom model, you need to allow SDAI app manage storage permissions, because starting from Android 11 it is needed to access non-scoped storage files. + Setup permission + + To use local custom model, place it to local folder in your phone storage: Download/SDAi/model + The final folder structure should be: + + QA actions Insert bad Base64 in DB