Refactor routing parameters test; simplify system prompt to recognize the current time by default

Refactored the integration test to use parameterized tests for better readability and maintainability.
This commit is contained in:
sku 2024-09-19 16:07:25 +02:00
parent d0f6bb58b9
commit 479fc8e84d
2 changed files with 54 additions and 59 deletions

View File

@ -14,7 +14,7 @@ class RoutingParametersService(private val conversationCache: ConversationCache,
data class RoutingParameters(val start: String?, val destination: String?, val time: String?) {}
fun getRoutingParameters(requestId: String, prompt: String): RoutingParameters {
val request = chatRequest(prompt, generateSystemPrompt())
val request = chatRequest(generateSystemPrompt(), prompt)
val openAIResponse = openAIService.chat(request)
if (openAIResponse.choices.isEmpty()) {
return RoutingParameters("", "", "")
@ -29,29 +29,19 @@ class RoutingParametersService(private val conversationCache: ConversationCache,
private fun generateSystemPrompt(): String {
val zone = ZoneId.of("Europe/Berlin")
val now = LocalDateTime.now().atZone(zone)
val testTime = now.withHour(21).withMinute(0).withSecond(0).withNano(0)
val now = LocalDateTime.now().atZone(zone).format(DateTimeFormatter.ISO_OFFSET_DATE_TIME)
return """
You are a smart assistant that helps users with routing requests. Your goal is to extract three key parameters from each user query: the start location, destination location, and time. When a user provides a routing request, parse the following:
You are a smart assistant that extracts three key parameters from user queries: start location, destination, and time. Always return a structured JSON with these fields:
Start: The location where the journey begins (e.g., city, point of interest, address).
Destination: The location where the user wants to go (e.g., city, point of interest, address).
Time: When the user wants to travel (e.g., specific time, day, or general time frame).
Your task is to always output a structured JSON object with these three fields. If a parameter is missing or unclear, indicate it as null. Make sure to handle a variety of natural language inputs, including informal and conversational language.
Start: Where the journey begins.
Destination: Where the user wants to go.
Time: When the user wants to travel.
If start or destination are missing, set it to null. Handle informal language and adjust time to UTC.
Use the next AM/PM if applicable, and for "now" use $now.
If only a day is given, assume the current time on that day. If no time can be inferred, set it to now.
For example:
Input: How do I get from New York to Boston at 9 PM?
Output: {"start": "New York", "destination": "Boston", "time": "${testTime.format(DateTimeFormatter.ISO_OFFSET_DATE_TIME)}"}
Assume that the user means the next AM or PM time and adjust the time by adding 12 hours if necessary.
Always format the time as a UTC timestamp.
If the user asks for the current time (now, today, jetzt, heute), use ${now.format(DateTimeFormatter.ISO_OFFSET_DATE_TIME)}.
If only a day is given, assume the current time at that day.
If no time can be deducted, mark it as null.
If the input is ambiguous, do your best to infer the meaning and provide the most likely interpretation.
Be concise, clear, and precise in your extraction. Return a plain unformatted json (no markdown).
Return a plain unformatted json (no markdown).
Example output: {"start": "New York", "destination": "Boston", "time": "$now"}
"""
}

View File

@ -3,72 +3,77 @@ package de.hbt.routing.service
import de.hbt.routing.configuration.OpenAIRestTemplateConfig
import de.hbt.routing.openai.OpenAIService
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Disabled
import org.junit.jupiter.api.Test
import org.assertj.core.api.Assertions.within
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.Arguments
import org.junit.jupiter.params.provider.Arguments.of
import org.junit.jupiter.params.provider.MethodSource
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.boot.test.context.SpringBootTest
import java.time.Duration
import java.time.LocalDateTime
import java.time.OffsetDateTime
import java.time.ZoneId
import java.time.ZonedDateTime
import java.time.format.DateTimeFormatter
import java.time.temporal.ChronoUnit
import java.util.stream.Stream
@SpringBootTest(classes = [RoutingParametersService::class, ConversationCache::class, OpenAIService::class, OpenAIRestTemplateConfig::class])
@Disabled("manual")
//@Disabled("manual")
class RoutingParametersServiceIntegrationTest {
private val REQUEST_ID = "123"
private val ZONE = ZoneId.of("Europe/Berlin")
private val NOW = LocalDateTime.now().atZone(ZONE)
private val TOMORROW_AT_8 = NOW.plusDays(1).withHour(8).withMinute(0).withSecond(0).withNano(0)
@Autowired
var routingParametersService: RoutingParametersService? = null
@Autowired
var conversationCache: ConversationCache? = null
@Test
fun getRoutingParameters() {
@ParameterizedTest
@MethodSource("data")
fun getRoutingParameters(prompt: String, result: RoutingParametersService.RoutingParameters) {
//given
assertThat(routingParametersService).isNotNull
//when
val prompt = "Wie komme ich morgen um 8 vom Grüningweg zur Holmer Straße in Wedel?"
val routingParameters = routingParametersService?.getRoutingParameters(REQUEST_ID, prompt)
//then
val start = "Grüningweg"
val destination = "Holmer Straße, Wedel"
val time = TOMORROW_AT_8.format(DateTimeFormatter.ISO_OFFSET_DATE_TIME)
assertThat(routingParameters)
.isNotNull
.isEqualTo(RoutingParametersService.RoutingParameters(start = start, destination = destination, time = time))
assertThat(conversationCache?.getConversation(REQUEST_ID))
.isNotNull
.hasSize(1)
.satisfies(timeIsCloseToExpected(result.time!!))
.satisfies({ assertThat(it?.start).isEqualTo(result.start) })
.satisfies({ assertThat(it?.destination).isEqualTo(result.destination) })
}
@Test
fun getRoutingParametersWithNow() {
//given
assertThat(routingParametersService).isNotNull
private fun timeIsCloseToExpected(expectedTime: String): (input: RoutingParametersService.RoutingParameters?) -> Unit =
{
assertThat(it?.time).isNotEmpty
val timeActual = OffsetDateTime.parse(it?.time!!)
val timeExpected = OffsetDateTime.parse(expectedTime)
assertThat(timeActual).isCloseTo(timeExpected, within(30, ChronoUnit.SECONDS))
}
//when
val prompt = "Wie komme ich jetzt vom Grüningweg zur Holmer Straße in Wedel?"
val routingParameters = routingParametersService?.getRoutingParameters(REQUEST_ID, prompt)
companion object {
private val ZONE = ZoneId.of("Europe/Berlin")
private val NOW_TIME = LocalDateTime.now().atZone(ZONE)
private val NOW = NOW_TIME.format(DateTimeFormatter.ISO_OFFSET_DATE_TIME)
private val TOMORROW_AT_8 = NOW_TIME.plusDays(1).withHour(8).withMinute(0).withSecond(0).withNano(0)
.format(DateTimeFormatter.ISO_OFFSET_DATE_TIME)
//then
val time = NOW
assertThat(routingParameters?.time).isNotNull()
val parsedTime = routingParameters?.time?.let { ZonedDateTime.parse(it) }
val between = Duration.between(time, parsedTime)
assertThat(between.seconds).isLessThan(10)
@JvmStatic
fun data(): Stream<Arguments> = Stream.of(
of("Wie komme ich morgen um 8 vom Grüningweg zur Holmer Straße in Wedel?",
result("Grüningweg", "Holmer Straße, Wedel", TOMORROW_AT_8)),
of("Vom Hauptbahnhof zur Stadthausbrücke",
result("Hauptbahnhof", "Stadthausbrücke", NOW)),
of("Ich am in Ahrensburg. When can I get the next connection to Lüneburg?",
result("Ahrensburg", "Lüneburg", NOW)),
)
private fun result(
start: String,
destination: String,
time: String?
) = RoutingParametersService.RoutingParameters(start = start, destination = destination, time = time)
assertThat(conversationCache?.getConversation(REQUEST_ID))
.isNotNull
.hasSize(1)
}
}