Model do wrong print response on java - ( ai.onnxruntime)
Guys what is wrong with my extract :
You: hi
hi
You: howare you
are you?
You: I am Kostya
Ko Kos Koyy I
You: Who are you ?
? you???
You: Where is Rome city ?
is Rome? ??
You: yes
yes
===============
public static String generate(String prompt, OrtEnvironment env, OrtSession encoderSession,
OrtSession decoderSession, SpProcessor sp) throws OrtException {
// Токенизируем входной текст с помощью DJL SentencePiece
int[] inputIdsInt = sp.encode(prompt);
long[] inputIds = Arrays.stream(inputIdsInt).mapToLong(i -> i).toArray();
long[] attentionMask = new long[inputIds.length];
Arrays.fill(attentionMask, 1L);
Map<String, OnnxTensor> encoderInputs = new HashMap<>();
long[][] inputIds2D = new long[][] { inputIds };
long[][] attentionMask2D = new long[][] { attentionMask };
encoderInputs.put("input_ids", OnnxTensor.createTensor(env, inputIds2D));
encoderInputs.put("attention_mask", OnnxTensor.createTensor(env, attentionMask2D));
OrtSession.Result encoderResult = encoderSession.run(encoderInputs);
float[][][] encoderHiddenState = (float[][][]) ((OnnxTensor) encoderResult.get("last_hidden_state").get()).getValue();
Map<String, OnnxTensor> decoderInputs = new HashMap<>();
LongBuffer buffer = LongBuffer.wrap(inputIds);
decoderInputs.put("input_ids", OnnxTensor.createTensor(env, inputIds2D));
decoderInputs.put("encoder_attention_mask", OnnxTensor.createTensor(env, attentionMask2D));
decoderInputs.put("encoder_hidden_states", OnnxTensor.createTensor(env, encoderHiddenState));
OrtSession.Result decoderResult = decoderSession.run(decoderInputs);
OnnxTensor decoderOutput = (OnnxTensor) decoderResult.get("logits").get();
long[] shape = decoderOutput.getInfo().getShape();
float[][][] logits = (float[][][]) decoderOutput.getValue();
int batchSize = (int) shape[0];
int sequenceLength = (int) shape[1];
int vocabSize = (int) shape[2];
int[] generatedTokenIds = new int[sequenceLength];
for (int i = 0; i < sequenceLength; i++) {
float[] tokenLogits = logits[0][i]; // logits for token i
int maxIndex = 0;
float maxValue = tokenLogits[0];
for (int j = 1; j < vocabSize; j++) {
if (tokenLogits[j] > maxValue) {
maxValue = tokenLogits[j];
maxIndex = j;
}
}
generatedTokenIds[i] = maxIndex;
}
return sp.decode(generatedTokenIds);
}
- T5-small with use_cache=True . Bad export onnx model
🚫 Past key/value caching
Without this, the model recalculates everything from scratch each time. This is inefficient and hinders autoregression.
🚫 Correct cross-attention
If the model was not exported with support for and , it may ignore context.
🚫 Positional embeddings
The ONNX model may not take into account the position of tokens if they are not explicitly specified, which leads to repetition of the same word.
LOL
What's going on
• The model accepts and in the encoder—that's correct.
• The decoder accepts and—that's also correct.
• But the results—"France??", "Spanish. German. I? you?", "why fox jumps…"—suggest that the decoder isn't taking context into account and is simply repeating high-frequency tokens.
This is typical behavior if the model was exported without support and without properly configured cross-attention.
decoder_with_past_model.onnx was tranned with other Tokenizer . LOL
public static String generate(String prompt, OrtEnvironment env, OrtSession encoderSession,
OrtSession decoderSession, SpProcessor sp) throws OrtException {
int[] inputIdsInt = sp.encode(prompt);
long[] inputIds = Arrays.stream(inputIdsInt).mapToLong(i -> i).toArray();
long[] attentionMask = new long[inputIds.length];
Arrays.fill(attentionMask, 1L);
Map<String, OnnxTensor> encoderInputs = new HashMap<>();
encoderInputs.put("input_ids",
OnnxTensor.createTensor(env, LongBuffer.wrap(inputIds), new long[] { 1, inputIds.length }));
encoderInputs.put("attention_mask",
OnnxTensor.createTensor(env, LongBuffer.wrap(attentionMask), new long[] { 1, attentionMask.length }));
OrtSession.Result encoderResult = encoderSession.run(encoderInputs);
OnnxTensor encoderHiddenStates = (OnnxTensor) encoderResult.get("last_hidden_state").get();
long encoderSeqLen = encoderHiddenStates.getInfo().getShape()[1];
List<Long> generated = new ArrayList<>();
generated.add(0L); // <pad> token
for (int step = 0; step < 50; step++) {
long[] decoderInputIds = new long[] { generated.get(generated.size() - 1) };
Map<String, OnnxTensor> decoderInputs = new HashMap<>();
decoderInputs.put("input_ids",
OnnxTensor.createTensor(env, LongBuffer.wrap(decoderInputIds), new long[] { 1, 1 }));
decoderInputs.put("encoder_hidden_states", encoderHiddenStates);
long[] decoderShape = new long[] { 1, 8, 1, 64 };
float[] decoderZeros = new float[1 * 8 * 1 * 64];
long[] encoderShape = new long[] { 1, 8, encoderSeqLen, 64 };
float[] encoderZeros = new float[(int) (1 * 8 * encoderSeqLen * 64)];
for (int i = 0; i < 6; i++) {
decoderInputs.put("past_key_values." + i + ".decoder.key",
OnnxTensor.createTensor(env, FloatBuffer.wrap(decoderZeros), decoderShape));
decoderInputs.put("past_key_values." + i + ".decoder.value",
OnnxTensor.createTensor(env, FloatBuffer.wrap(decoderZeros), decoderShape));
decoderInputs.put("past_key_values." + i + ".encoder.key",
OnnxTensor.createTensor(env, FloatBuffer.wrap(encoderZeros), encoderShape));
decoderInputs.put("past_key_values." + i + ".encoder.value",
OnnxTensor.createTensor(env, FloatBuffer.wrap(encoderZeros), encoderShape));
}
OrtSession.Result decoderResult = decoderSession.run(decoderInputs);
OnnxTensor decoderOutput = (OnnxTensor) decoderResult.get("logits").get();
float[][][] logits = (float[][][]) decoderOutput.getValue();
float[] lastLogits = logits[0][0];
// ====== Argmax ======
int nextToken = 0;
float maxLogit = lastLogits[0];
for (int i = 1; i < lastLogits.length; i++) {
if (lastLogits[i] > maxLogit) {
maxLogit = lastLogits[i];
nextToken = i;
}
}
generated.add((long) nextToken);
if (nextToken == 1)
break; // <eos> token
}
''int vocabSize = 32099; // T5-small стандартный размер
//int[] safeIds = generated.stream().mapToInt(i -> i.intValue()).filter(id -> id >= 0 && id < vocabSize).toArray();
int[] safeIds = generated.stream().mapToInt(i -> i.intValue()).toArray();
return sp.decode(safeIds);
}
You: hi
Exception in thread "main" ai.djl.engine.EngineException: Out of range: Invalid id: 32099
at ai.djl.sentencepiece.jni.SentencePieceLibrary.decode(Native Method)
at ai.djl.sentencepiece.SpProcessor.decode(SpProcessor.java:121)
at com.cbsinc.cms.jrack.service.t5_onnx.generate(t5_onnx.java:173)
at com.cbsinc.cms.jrack.service.t5_onnx.main(t5_onnx.java:90)
fired my time a lot . it is not good work
any solution to get valid response from the mode ?
I tried to use decoder_model_merged.onnx
Suck it was trained with other Tokenizer aslo
public static String generate_m(String prompt, OrtEnvironment env, OrtSession encoderSession,
OrtSession decoderSession, SpProcessor sp) throws OrtException {
// ====== Токенизация ======
int[] inputIdsInt = sp.encode(prompt);
long[] inputIds = Arrays.stream(inputIdsInt).mapToLong(i -> i).toArray();
long[] attentionMask = new long[inputIds.length];
Arrays.fill(attentionMask, 1L);
// ====== Энкодер ======
Map<String, OnnxTensor> encoderInputs = new HashMap<>();
encoderInputs.put("input_ids",
OnnxTensor.createTensor(env, LongBuffer.wrap(inputIds), new long[] { 1, inputIds.length }));
encoderInputs.put("attention_mask",
OnnxTensor.createTensor(env, LongBuffer.wrap(attentionMask), new long[] { 1, attentionMask.length }));
OrtSession.Result encoderResult = encoderSession.run(encoderInputs);
OnnxTensor encoderHiddenStates = (OnnxTensor) encoderResult.get("last_hidden_state").get();
long encoderSeqLen = encoderHiddenStates.getInfo().getShape()[1];
// ====== Декодер: авто-регрессия ======
List generated = new ArrayList<>();
generated.add(0L); // token
for (int step = 0; step < 50; step++) {
long[] decoderInputIds = new long[] { generated.get(generated.size() - 1) };
Map<String, OnnxTensor> decoderInputs = new HashMap<>();
decoderInputs.put("input_ids",
OnnxTensor.createTensor(env, LongBuffer.wrap(decoderInputIds), new long[] { 1, 1 }));
decoderInputs.put("encoder_hidden_states", encoderHiddenStates);
// ====== use_cache_branch: обязательно ======
decoderInputs.put("use_cache_branch", OnnxTensor.createTensor(env, new boolean[] { true }));
// ====== past_key_values ======
long[] decoderShape = new long[] { 1, 8, step == 0 ? 1 : step, 64 };
float[] decoderZeros = new float[(int) (decoderShape[0] * decoderShape[1] * decoderShape[2]
* decoderShape[3])];
long[] encoderShape = new long[] { 1, 8, encoderSeqLen, 64 };
float[] encoderZeros = new float[(int) (encoderShape[0] * encoderShape[1] * encoderShape[2]
* encoderShape[3])];
for (int i = 0; i < 6; i++) {
decoderInputs.put("past_key_values." + i + ".decoder.key",
OnnxTensor.createTensor(env, FloatBuffer.wrap(decoderZeros), decoderShape));
decoderInputs.put("past_key_values." + i + ".decoder.value",
OnnxTensor.createTensor(env, FloatBuffer.wrap(decoderZeros), decoderShape));
decoderInputs.put("past_key_values." + i + ".encoder.key",
OnnxTensor.createTensor(env, FloatBuffer.wrap(encoderZeros), encoderShape));
decoderInputs.put("past_key_values." + i + ".encoder.value",
OnnxTensor.createTensor(env, FloatBuffer.wrap(encoderZeros), encoderShape));
}
OrtSession.Result decoderResult = decoderSession.run(decoderInputs);
OnnxTensor decoderOutput = (OnnxTensor) decoderResult.get("logits").get();
float[][][] logits = (float[][][]) decoderOutput.getValue();
float[] lastLogits = logits[0][0];
// ====== Argmax ======
int nextToken = 0;
float maxLogit = lastLogits[0];
for (int i = 1; i < lastLogits.length; i++) {
if (lastLogits[i] > maxLogit) {
maxLogit = lastLogits[i];
nextToken = i;
}
}
generated.add((long) nextToken);
if (nextToken == 1)
break; // <eos> token
}
// ====== Фильтрация недопустимых токенов ======
int vocabSize = 32099;
int[] safeIds = generated.stream().mapToInt(i -> i.intValue()).filter(id -> id >= 0 && id < vocabSize)
.toArray();
return sp.decode(safeIds);
}
responded me
You: hi
.
You: