import { useEffect, useMemo, useState } from 'react'; import { StyleSheet, View } from 'react-native'; import { Image } from 'expo-image'; import * as ImagePicker from 'expo-image-picker'; import * as ImageManipulator from 'expo-image-manipulator'; import jpeg from 'jpeg-js'; import ParallaxScrollView from '@/components/parallax-scroll-view'; import { ThemedText } from '@/components/themed-text'; import { ThemedView } from '@/components/themed-view'; import { IconSymbol } from '@/components/ui/icon-symbol'; import { Colors, Fonts } from '@/constants/theme'; import { useTranslation } from '@/localization/i18n'; import { ThemedButton } from '@/components/themed-button'; import { dbPromise, initCoreTables } from '@/services/db'; import { useColorScheme } from '@/hooks/use-color-scheme';/** Class labels loaded from a bundled JSON asset; handleRun looks them up by predicted class index. */ const classNames = require('@/assets/class_names.json') as string[];/** One classifier result: the class index, a display label and its score. */ type Prediction = { index: number; label: string; score: number; };/** Row shape matching the columns selected by loadLogs (ml_inferences LEFT JOIN images); every column except id may be null. */ type InferenceLog = { id: number; top_label: string | null; top_score: number | null; created_at: string | null; uri: string | null; };/** Screen component: lets the user pick an image, runs the PlantVillage model on it, shows the top 3 predictions and lists the most recent logged inferences. */ export default function OnnxScreen() { const { t } = useTranslation(); const theme = useColorScheme() ?? 
'light'; const palette = Colors[theme];/* NOTE(review): the useState calls below carry no type arguments in this view, so useState(null) would infer a null-only state even though setImageUri later receives a string; confirm the intended state types (string | null, Prediction[], InferenceLog[]). */ const [status, setStatus] = useState(t('onnx.status.pick')); const [imageUri, setImageUri] = useState(null); const [predictions, setPredictions] = useState([]); const [logs, setLogs] = useState([]); const canRun = useMemo(() => Boolean(imageUri), [imageUri]);/* On mount: create the core tables, then load the recent logs. The `active` flag skips setLogs once the effect has been cleaned up. NOTE(review): the catch branch calls setStatus without checking `active`. */ useEffect(() => { let active = true; initCoreTables() .then(async () => { const rows = await loadLogs(); if (active) setLogs(rows); }) .catch((error) => { setStatus(`Error: ${String(error)}`); }); return () => { active = false; }; }, []);/** Opens the image library; unless the picker was canceled, stores the first asset's URI, clears old predictions and sets the ready status. NOTE(review): a rejection from launchImageLibraryAsync is not caught here. */ async function pickImage() { const result = await ImagePicker.launchImageLibraryAsync({ mediaTypes: getImageMediaTypes(), quality: 1, }); if (result.canceled) return; const asset = result.assets[0]; setImageUri(asset.uri); setPredictions([]); setStatus(t('onnx.status.ready')); }/** Runs inference on the selected image; does nothing when no image is selected. Steps: dynamically import the runner, convert the image to a tensor, run the model, apply softmax, keep the top 3 with humanized labels, store them in state, write a log row and reload the log list. On failure, a message containing 'install' or 'onnxruntime' maps to the nativeMissing status; anything else is shown as `Error: <message>`. */ async function handleRun() { if (!imageUri) return; setStatus(t('onnx.status.preprocessing')); try { const { runPlantVillageInference } = await import('@/services/onnx/runPlantVillage'); const inputTensor = await imageToTensor(imageUri); setStatus(t('onnx.status.running')); const scores = await runPlantVillageInference(inputTensor); const probabilities = softmax(scores); const top = topK(probabilities, 3).map((item) => ({ ...item, label: humanizeLabel(classNames[item.index] ?? `Class ${item.index}`), })); setPredictions(top); await logInference({ uri: imageUri, top, scores: probabilities, }); setLogs(await loadLogs()); setStatus(t('onnx.status.done')); } catch (error) { const message = String(error); if (message.includes('install') || message.includes('onnxruntime')) { setStatus(t('onnx.status.nativeMissing')); } else { setStatus(`Error: ${message}`); } } }/* NOTE(review): the markup in the return statement below appears to have lost its JSX element tags (it opens with a stray "}>"), so it is left exactly as found and needs to be restored from the original component. Also, `item.top_score ? ... : null` in the log list treats a stored score of 0 as missing and renders '--'. */ return ( }> {t('onnx.title')} {t('onnx.howBody')} {t('onnx.testTitle')} {status} {!imageUri ? ( Pick an image to run the model. ) : null} {imageUri ? ( ) : ( {t('onnx.status.pick')} )} {predictions.length > 0 ? 
( {t('onnx.topPredictions')} {predictions.map((pred) => { const percent = Number((pred.score * 100).toFixed(2)); return ( {pred.label} {percent}% ); })} ) : null} {logs.length > 0 ? ( Recent scans {logs.map((item) => { const score = item.top_score ? Math.round(item.top_score * 100) : null; return ( {item.uri ? ( ) : ( )} {item.top_label ?? t('onnx.topPredictions')} {item.created_at ? formatDate(item.created_at) : ''} {score !== null ? `${score}%` : '--'} ); })} ) : null} {t('onnx.howTitle')} {t('onnx.howBody')} ); }/** Returns the value passed as `mediaTypes` to the picker: ImagePicker.MediaType.Image if present, otherwise MediaType.Images, otherwise the literal ['images']. Presumably this tolerates differing expo-image-picker API versions; verify against the installed version. */ function getImageMediaTypes() { const mediaType = (ImagePicker as { MediaType?: { Image?: unknown; Images?: unknown } }) .MediaType; return mediaType?.Image ?? mediaType?.Images ?? ['images']; }/** Returns the k highest-scoring entries in descending score order. Each entry keeps its original index and gets the placeholder label `Class <index>`. */ function topK(scores: Float32Array, k: number): Prediction[] { const items = Array.from(scores).map((score, index) => ({ index, label: `Class ${index}`, score, })); items.sort((a, b) => b.score - a.score); return items.slice(0, k); }/** Softmax over the scores; the maximum is subtracted before exponentiating so the exponent is never positive. NOTE(review): Math.max(...scores) passes every element as a call argument, which is fine for a small class count but could exceed the engine's argument limit for very large arrays. */ function softmax(scores: Float32Array) { const max = Math.max(...scores); const exps = Array.from(scores).map((value) => Math.exp(value - max)); const sum = exps.reduce((acc, value) => acc + value, 0); return new Float32Array(exps.map((value) => value / sum)); }/** Resizes the image at `uri` to 224x224, re-encodes it as JPEG with base64 output, decodes it with jpeg-js and returns a Float32Array of length 3 * 224 * 224 with channel values divided by 255. Throws if no base64 data comes back or if the decoded size is not 224x224. */ async function imageToTensor(uri: string) { const resized = await ImageManipulator.manipulateAsync( uri, [{ resize: { width: 224, height: 224 } }], { base64: true, format: ImageManipulator.SaveFormat.JPEG, compress: 1, } ); if (!resized.base64) { throw new Error('Failed to read image data.'); } const bytes = base64ToUint8Array(resized.base64); const decoded = jpeg.decode(bytes, { useTArray: true }); const { data, width, height } = decoded; if (width !== 224 || height !== 224) { throw new Error(`Unexpected image size: ${width}x${height}`); } const size = width * height; const floatData = new Float32Array(1 * 3 * size);/* Planar output: all R values first, then all G, then all B. The decoded buffer is read with a stride of 4 bytes per pixel; the fourth byte (presumably alpha) is skipped. */ for (let i = 0; i < size; i += 1) { const pixelIndex = i * 4; const r = data[pixelIndex] / 255; const g = data[pixelIndex + 1] / 255; const b = data[pixelIndex 
+ 2] / 255; floatData[i] = r; floatData[size + i] = g; floatData[size * 2 + i] = b; } return floatData; }/** Decodes a base64 string into bytes. Characters outside the base64 alphabet and '=' are stripped first, then the input is consumed four characters at a time; a decoded value of 64 marks padding and suppresses the corresponding output byte. NOTE(review): when the cleaned length is not a multiple of 4, charAt returns '' past the end and decodeBase64Char maps '' to 0 rather than 64, so unpadded input yields extra trailing bytes. */ function base64ToUint8Array(base64: string) { const cleaned = base64.replace(/[^A-Za-z0-9+/=]/g, ''); const bytes: number[] = []; const chars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/'; let i = 0; while (i < cleaned.length) { const enc1 = decodeBase64Char(cleaned.charAt(i++), chars); const enc2 = decodeBase64Char(cleaned.charAt(i++), chars); const enc3 = decodeBase64Char(cleaned.charAt(i++), chars); const enc4 = decodeBase64Char(cleaned.charAt(i++), chars); const chr1 = (enc1 << 2) | (enc2 >> 4); const chr2 = ((enc2 & 15) << 4) | (enc3 >> 2); const chr3 = ((enc3 & 3) << 6) | enc4; bytes.push(chr1); if (enc3 !== 64) bytes.push(chr2); if (enc4 !== 64) bytes.push(chr3); } return new Uint8Array(bytes); }/** Maps one base64 character to its 6-bit value using the alphabet in `chars`. Returns 64 for '=' and for any character not found in `chars`. NOTE(review): an empty string is found at position 0 by indexOf, so '' yields 0, not 64. */ function decodeBase64Char(char: string, chars: string) { if (char === '=') return 64; const index = chars.indexOf(char); if (index === -1) return 64; return index; }/** Loads the 5 most recent ml_inferences rows (ordered by created_at descending) together with the uri of the joined images row, which is null when no image row matches. */ async function loadLogs() { const db = await dbPromise; const rows = await db.getAllAsync( `SELECT mi.id, mi.top_label, mi.top_score, mi.created_at, img.uri FROM ml_inferences mi LEFT JOIN images img ON img.id = mi.image_id ORDER BY mi.created_at DESC LIMIT 5;` ); return rows; }/** Persists one inference: inserts an images row for `params.uri` (recorded as 224x224), then an ml_inferences row that references it with the top label/score and a JSON dump of `params.top`. NOTE(review): `params.scores` is never read, so scores_json holds only the top entries; the two inserts are not wrapped in a transaction, so a failure of the second leaves the images row behind. */ async function logInference(params: { uri: string; top: Prediction[]; scores: Float32Array; }) { const db = await dbPromise; const now = new Date().toISOString(); const topHit = params.top[0]; const scoresJson = JSON.stringify( params.top.map((item) => ({ index: item.index, label: item.label, score: item.score })) ); const insert = await db.runAsync( 'INSERT INTO images (observation_id, uri, thumbnail_uri, width, height, created_at) VALUES (?, ?, ?, ?, ?, ?);', null, params.uri, null, 224, 224, now ); await db.runAsync( 'INSERT INTO ml_inferences (image_id, model_name, model_version, top_label, top_score, scores_json, created_at) VALUES (?, ?, ?, ?, ?, ?, ?);', insert.lastInsertRowId, 
'plantvillage_mnv3s_224', '1', topHit?.label ?? null, topHit?.score ?? null, scoresJson, now ); }/** Turns a raw class name into display text: runs of '_' or '/' become a space, whitespace is collapsed and trimmed, and the first character of each word is upper-cased. */ function humanizeLabel(label: string) { const cleaned = label .replace(/[_/]+/g, ' ') .replace(/\s+/g, ' ') .trim(); return cleaned .split(' ') .map((word) => (word ? word[0].toUpperCase() + word.slice(1) : word)) .join(' '); }/** Formats a stored timestamp string with Date#toLocaleString, returning the raw value if formatting throws. NOTE(review): an unparseable value normally produces the string "Invalid Date" instead of throwing, so the catch fallback is likely never taken; confirm and consider an explicit validity check. */ function formatDate(value: string) { try { return new Date(value).toLocaleString(); } catch { return value; } }/** Static style definitions for this screen, created once at module load via StyleSheet.create. */ const styles = StyleSheet.create({ headerImage: { width: '100%', height: '100%', }, hero: { borderRadius: 20, borderWidth: 1, flexDirection: 'row', gap: 14, padding: 16, alignItems: 'center', marginBottom: 16, }, heroIcon: { borderRadius: 999, padding: 10, }, heroText: { flex: 1, gap: 6, }, heroTitle: { fontFamily: Fonts.rounded, }, heroSubtitle: { color: '#516458', }, card: { backgroundColor: '#FFFFFF', borderRadius: 18, padding: 16, gap: 12, marginBottom: 16, borderWidth: 1, borderColor: '#E3DED6', }, cardHeader: { flexDirection: 'row', alignItems: 'center', gap: 12, flexWrap: 'wrap', }, statusPill: { backgroundColor: '#E6F0E2', borderRadius: 999, paddingHorizontal: 10, paddingVertical: 4, }, statusText: { fontSize: 12, color: '#3D5F4A', }, previewWrap: { borderRadius: 16, overflow: 'hidden', backgroundColor: '#F3F0EA', borderWidth: 1, borderColor: '#E1DBD1', }, preview: { height: 220, width: '100%', }, previewPlaceholder: { height: 200, alignItems: 'center', justifyContent: 'center', gap: 8, }, placeholderText: { color: '#6B736D', }, inlineHint: { color: '#6B736D', fontSize: 13, }, actionRow: { flexDirection: 'row', flexWrap: 'wrap', gap: 12, }, results: { gap: 12, }, logList: { gap: 10, }, logRow: { flexDirection: 'row', alignItems: 'center', gap: 12, padding: 10, borderRadius: 12, backgroundColor: '#F7F4EE', }, logThumb: { width: 48, height: 48, borderRadius: 10, backgroundColor: '#ECE7DE', alignItems: 'center', justifyContent: 'center', overflow: 'hidden', }, logImage: { width: '100%', height: '100%', }, logMeta: { flex: 1, gap: 
2, }, logLabel: { fontWeight: '600', }, logSub: { fontSize: 12, color: '#6B736D', }, logScore: { fontWeight: '600', color: '#3D5F4A', }, predRow: { gap: 6, }, predLabelRow: { flexDirection: 'row', justifyContent: 'space-between', alignItems: 'center', gap: 12, }, predLabel: { flex: 1, }, predPercent: { color: '#3D5F4A', }, barTrack: { height: 8, borderRadius: 999, backgroundColor: '#E7E1D7', overflow: 'hidden', }, barFill: { height: 8, borderRadius: 999, backgroundColor: '#4B7B57', }, });