package invirt.data.mongodb

import com.mongodb.client.model.Indexes
import com.mongodb.kotlin.client.MongoCollection
import org.slf4j.LoggerFactory
import kotlin.reflect.KClass
import kotlin.reflect.full.findAnnotation
import kotlin.reflect.full.isSubclassOf
import kotlin.reflect.full.memberProperties

/**
 * Marks a property whose fields should get a regular (ascending or descending) index.
 *
 * @property order direction of the created index(es); ascending unless stated otherwise.
 * @property fields names of the fields to index, each one receiving its own single-field index.
 * When left empty, the name of the annotated property is used instead.
 */
@MustBeDocumented
@Retention(AnnotationRetention.RUNTIME)
@Target(AnnotationTarget.PROPERTY)
annotation class Indexed(
    val order: Order = Order.ASC,
    vararg val fields: String
) {

    /** Sort direction of an index. */
    enum class Order { ASC, DESC }
}

/**
 * Marks a property whose fields should be part of the collection's text index.
 *
 * @property fields names of the fields to add to the text index. When left empty,
 * the name of the annotated property is used instead.
 */
@MustBeDocumented
@Retention(AnnotationRetention.RUNTIME)
@Target(AnnotationTarget.PROPERTY)
annotation class TextIndexed(vararg val fields: String)

/**
 * Creates the indexes declared with [Indexed] and [TextIndexed] on the properties of this
 * collection's document class, and additionally those declared on [TimestampedEntity] when
 * the document class is a subclass of it.
 */
fun <T : MongoEntity> MongoCollection<T>.createIndexes() {
    val entityClass = documentClass.kotlin
    createPropertyIndexes(entityClass)

    // Kotlin does not inherit annotations, so an @Indexed placed on Timestamped.createdAt
    // would not be seen through the entity class: Timestamped is therefore processed explicitly.
    val isTimestamped = entityClass.isSubclassOf(TimestampedEntity::class)
    if (isTimestamped) {
        createPropertyIndexes(TimestampedEntity::class)
    }
}

/**
 * Creates on this collection the indexes declared through [Indexed] and [TextIndexed]
 * annotations on the member properties of [cls].
 *
 * - Each field of an [Indexed] annotation gets its own single-field index, ascending or
 *   descending according to the annotation's `order`.
 * - All [TextIndexed] fields found on [cls] are collected into a single compound text index.
 *
 * When an annotation lists no explicit fields, the name of the annotated property is used.
 *
 * @param cls the class whose member properties are scanned for index annotations
 */
private fun MongoCollection<*>.createPropertyIndexes(cls: KClass<*>) {
    val log = LoggerFactory.getLogger("invirt.data.mongodb")
    val textIndexedFields = mutableListOf<String>()

    cls.memberProperties.forEach { property ->

        // Indexed properties
        property.findAnnotation<Indexed>()?.let { annotation ->
            val fields = annotation.fields.toList().ifEmpty { listOf(property.name) }
            fields.forEach { field ->
                val index = when (annotation.order) {
                    Indexed.Order.ASC -> Indexes.ascending(field)
                    Indexed.Order.DESC -> Indexes.descending(field)
                }
                // Log before the call so a failing createIndex can be traced back to its field
                log.info("Creating ${annotation.order} index for ${cls.simpleName}.$field")
                createIndex(index)
            }
        }

        // Text indexed properties need collecting into one index as Mongo only supports one text index
        property.findAnnotation<TextIndexed>()?.let { annotation ->
            textIndexedFields.addAll(annotation.fields.toList().ifEmpty { listOf(property.name) })
        }
    }

    if (textIndexedFields.isNotEmpty()) {
        val qualifiedFields = textIndexedFields.joinToString(", ") { "${cls.simpleName}.$it" }
        // Same ordering as above: announce the index first, then attempt to create it
        log.info("Creating text index for $qualifiedFields")
        createIndex(Indexes.compoundIndex(textIndexedFields.map { Indexes.text(it) }))
    }
}
