| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538 |
- import { useEffect, useMemo, useState } from 'react';
- import { StyleSheet, View } from 'react-native';
- import { Image } from 'expo-image';
- import * as ImagePicker from 'expo-image-picker';
- import * as ImageManipulator from 'expo-image-manipulator';
- import jpeg from 'jpeg-js';
- import ParallaxScrollView from '@/components/parallax-scroll-view';
- import { ThemedText } from '@/components/themed-text';
- import { ThemedView } from '@/components/themed-view';
- import { IconSymbol } from '@/components/ui/icon-symbol';
- import { Colors, Fonts } from '@/constants/theme';
- import { useTranslation } from '@/localization/i18n';
- import { ThemedButton } from '@/components/themed-button';
- import { dbPromise, initCoreTables } from '@/services/db';
- import { useColorScheme } from '@/hooks/use-color-scheme';
/** Raw class labels for the PlantVillage model, indexed by model output position (bundled JSON asset). */
const classNames = require('@/assets/class_names.json') as string[];

/** One ranked model output: the class index, its display label, and its score. */
interface Prediction {
  index: number;
  label: string;
  score: number;
}

/** Row shape produced by the recent-scans query (`ml_inferences` LEFT JOIN `images`). */
interface InferenceLog {
  id: number;
  top_label: string | null;
  top_score: number | null;
  created_at: string | null;
  uri: string | null;
}
/**
 * Leaf-scan screen: pick a photo from the library, run the PlantVillage ONNX
 * classifier on it, show the top-3 predictions, and list the latest logged scans.
 *
 * The inference module is loaded with a dynamic import inside `handleRun`, so a
 * failure to load it is reported through the status pill rather than at mount time.
 */
export default function OnnxScreen() {
  const { t } = useTranslation();
  const theme = useColorScheme() ?? 'light';
  const palette = Colors[theme];
  const [status, setStatus] = useState(t('onnx.status.pick'));
  const [imageUri, setImageUri] = useState<string | null>(null);
  const [predictions, setPredictions] = useState<Prediction[]>([]);
  const [logs, setLogs] = useState<InferenceLog[]>([]);
  // True while preprocessing / inference / logging is in flight. Blocks a second run
  // (double tap) and image swaps that would pair the results with the wrong photo.
  const [isRunning, setIsRunning] = useState(false);
  const canRun = useMemo(() => Boolean(imageUri) && !isRunning, [imageUri, isRunning]);

  // Ensure the DB tables exist, then load the recent-scan list once on mount.
  useEffect(() => {
    let active = true;
    initCoreTables()
      .then(async () => {
        const rows = await loadLogs();
        if (active) setLogs(rows);
      })
      .catch((error: unknown) => {
        // Same unmount guard as the success path: don't touch state after unmount.
        if (active) setStatus(`Error: ${String(error)}`);
      });
    return () => {
      active = false;
    };
  }, []);

  /** Open the photo library and make the chosen image the current model input. */
  async function pickImage() {
    if (isRunning) return;
    try {
      const result = await ImagePicker.launchImageLibraryAsync({
        mediaTypes: getImageMediaTypes(),
        quality: 1,
      });
      if (result.canceled) return;
      const asset = result.assets[0];
      if (!asset) return;
      setImageUri(asset.uri);
      setPredictions([]);
      setStatus(t('onnx.status.ready'));
    } catch (error: unknown) {
      // Called from onPress: without this a picker rejection would be unhandled.
      setStatus(`Error: ${String(error)}`);
    }
  }

  /** Preprocess the current image, run the model, show the top-3 results, and log the run. */
  async function handleRun() {
    if (!imageUri || isRunning) return;
    setIsRunning(true);
    setStatus(t('onnx.status.preprocessing'));
    try {
      const { runPlantVillageInference } = await import('@/services/onnx/runPlantVillage');
      const inputTensor = await imageToTensor(imageUri);
      setStatus(t('onnx.status.running'));
      const scores = await runPlantVillageInference(inputTensor);
      const probabilities = softmax(scores);
      const top = topK(probabilities, 3).map((item) => ({
        ...item,
        label: humanizeLabel(classNames[item.index] ?? `Class ${item.index}`),
      }));
      setPredictions(top);
      await logInference({
        uri: imageUri,
        top,
        scores: probabilities,
      });
      setLogs(await loadLogs());
      setStatus(t('onnx.status.done'));
    } catch (error: unknown) {
      const message = String(error);
      // Heuristic: errors mentioning 'install' or 'onnxruntime' are reported as the
      // native ONNX runtime being unavailable; anything else is shown verbatim.
      if (message.includes('install') || message.includes('onnxruntime')) {
        setStatus(t('onnx.status.nativeMissing'));
      } else {
        setStatus(`Error: ${message}`);
      }
    } finally {
      setIsRunning(false);
    }
  }

  return (
    <ParallaxScrollView
      headerBackgroundColor={{ light: '#E9E4DB', dark: '#2A2A2A' }}
      headerImage={
        <Image
          source={require('@/assets/images/leafscan.jpg')}
          style={styles.headerImage}
          contentFit="cover"
        />
      }>
      <ThemedView style={[styles.hero, { backgroundColor: palette.surface, borderColor: palette.border }]}>
        <View style={[styles.heroIcon, { backgroundColor: palette.card }]}>
          <IconSymbol size={44} color={palette.tint} name="leaf.fill" />
        </View>
        <View style={styles.heroText}>
          <ThemedText type="title" style={styles.heroTitle}>
            {t('onnx.title')}
          </ThemedText>
          <ThemedText style={[styles.heroSubtitle, { color: palette.muted }]}>
            {t('onnx.howBody')}
          </ThemedText>
        </View>
      </ThemedView>
      <ThemedView style={[styles.card, { backgroundColor: palette.card, borderColor: palette.border }]}>
        <View style={styles.cardHeader}>
          <ThemedText type="subtitle">{t('onnx.testTitle')}</ThemedText>
          <View style={[styles.statusPill, { backgroundColor: palette.surface }]}>
            <ThemedText style={[styles.statusText, { color: palette.muted }]}>{status}</ThemedText>
          </View>
        </View>
        {!imageUri ? (
          // TODO(i18n): this hint is not routed through t() yet.
          <ThemedText style={[styles.inlineHint, { color: palette.muted }]}>
            Pick an image to run the model.
          </ThemedText>
        ) : null}
        <View
          style={[
            styles.previewWrap,
            { backgroundColor: palette.surface, borderColor: palette.border },
          ]}>
          {imageUri ? (
            <Image source={{ uri: imageUri }} style={styles.preview} contentFit="cover" />
          ) : (
            <View style={styles.previewPlaceholder}>
              <IconSymbol size={36} color={palette.muted} name="photo.on.rectangle.angled" />
              <ThemedText style={[styles.placeholderText, { color: palette.muted }]}>
                {t('onnx.status.pick')}
              </ThemedText>
            </View>
          )}
        </View>
        <View style={styles.actionRow}>
          <ThemedButton
            title={t('onnx.pickImage')}
            onPress={pickImage}
            variant="secondary"
            disabled={isRunning}
          />
          <ThemedButton title={t('onnx.runModel')} onPress={handleRun} disabled={!canRun} />
        </View>
      </ThemedView>
      {predictions.length > 0 ? (
        <ThemedView style={[styles.card, { backgroundColor: palette.card, borderColor: palette.border }]}>
          <ThemedText type="subtitle">{t('onnx.topPredictions')}</ThemedText>
          <View style={styles.results}>
            {predictions.map((pred) => {
              const percent = Number((pred.score * 100).toFixed(2));
              return (
                <View key={pred.index} style={styles.predRow}>
                  <View style={styles.predLabelRow}>
                    <ThemedText style={styles.predLabel}>{pred.label}</ThemedText>
                    {/* Theme-aware colours: the stylesheet values are light-theme only. */}
                    <ThemedText style={[styles.predPercent, { color: palette.tint }]}>
                      {percent}%
                    </ThemedText>
                  </View>
                  <View style={[styles.barTrack, { backgroundColor: palette.border }]}>
                    <View
                      style={[
                        styles.barFill,
                        { width: `${Math.min(percent, 100)}%`, backgroundColor: palette.tint },
                      ]}
                    />
                  </View>
                </View>
              );
            })}
          </View>
        </ThemedView>
      ) : null}
      {logs.length > 0 ? (
        <ThemedView style={[styles.card, { backgroundColor: palette.card, borderColor: palette.border }]}>
          {/* TODO(i18n): heading is not routed through t() yet. */}
          <ThemedText type="subtitle">Recent scans</ThemedText>
          <View style={styles.logList}>
            {logs.map((item) => {
              // `!= null` rather than truthiness so a stored score of 0 renders as 0%.
              const score = item.top_score != null ? Math.round(item.top_score * 100) : null;
              return (
                <View key={item.id} style={[styles.logRow, { backgroundColor: palette.surface }]}>
                  <View style={[styles.logThumb, { backgroundColor: palette.card }]}>
                    {item.uri ? (
                      <Image source={{ uri: item.uri }} style={styles.logImage} contentFit="cover" />
                    ) : (
                      <IconSymbol size={18} color={palette.muted} name="photo" />
                    )}
                  </View>
                  <View style={styles.logMeta}>
                    <ThemedText style={styles.logLabel}>
                      {item.top_label ?? t('onnx.topPredictions')}
                    </ThemedText>
                    <ThemedText style={[styles.logSub, { color: palette.muted }]}>
                      {item.created_at ? formatDate(item.created_at) : ''}
                    </ThemedText>
                  </View>
                  <ThemedText style={[styles.logScore, { color: palette.tint }]}>
                    {score !== null ? `${score}%` : '--'}
                  </ThemedText>
                </View>
              );
            })}
          </View>
        </ThemedView>
      ) : null}
      <ThemedView style={[styles.card, { backgroundColor: palette.card, borderColor: palette.border }]}>
        <ThemedText type="subtitle">{t('onnx.howTitle')}</ThemedText>
        <ThemedText>{t('onnx.howBody')}</ThemedText>
      </ThemedView>
    </ParallaxScrollView>
  );
}
/**
 * Resolve the `mediaTypes` option for `launchImageLibraryAsync`.
 *
 * Probes for a runtime `ImagePicker.MediaType` object exposing `Image` or `Images`.
 * If neither is present, it falls back to the string-array form `['images']`.
 * NOTE(review): in some expo-image-picker versions `MediaType` may be type-only
 * (absent at runtime), in which case the fallback is always used — confirm.
 */
function getImageMediaTypes() {
  const { MediaType: pickerMediaType } = ImagePicker as {
    MediaType?: { Image?: unknown; Images?: unknown };
  };
  const fallback = ['images'];
  return pickerMediaType?.Image ?? pickerMediaType?.Images ?? fallback;
}
/**
 * Return the `k` highest-scoring entries of `scores`, highest first.
 *
 * Equal scores keep ascending index order because `Array.prototype.sort` is stable.
 * Each entry carries a placeholder label `Class <index>`; callers substitute a real name.
 */
function topK(scores: Float32Array, k: number): Prediction[] {
  const ranked: Prediction[] = Array.from(scores, (score, index) => ({
    index,
    label: `Class ${index}`,
    score,
  }));
  return ranked.sort((left, right) => right.score - left.score).slice(0, k);
}
/**
 * Numerically stable softmax: subtracts the maximum logit before exponentiating
 * so large logits do not overflow `Math.exp`.
 *
 * The maximum is found with a plain loop instead of `Math.max(...scores)`. Spreading
 * passes every element as a separate call argument, and engines throw RangeError
 * once that exceeds their argument/stack limit, so the spread form breaks on large vectors.
 *
 * @param scores Raw logits.
 * @returns Probabilities in the same order, summing to ~1. Empty input yields an empty array.
 */
function softmax(scores: Float32Array): Float32Array {
  let max = -Infinity;
  for (const value of scores) {
    if (value > max) max = value;
  }
  // Accumulate in float64 (same precision as the previous number[] implementation).
  const exps = new Float64Array(scores.length);
  let sum = 0;
  scores.forEach((value, i) => {
    const e = Math.exp(value - max);
    exps[i] = e;
    sum += e;
  });
  return Float32Array.from(exps, (e) => e / sum);
}
/**
 * Load the image at `uri`, resize it to 224x224, and convert it to a channel-planar
 * Float32Array of length 3 * 224 * 224 with values scaled from bytes to [0, 1].
 *
 * Layout: every first-channel value, then every second-channel value, then every
 * third-channel value.
 * NOTE(review): no mean/std normalisation is applied and the resize does not preserve
 * aspect ratio — confirm both match how the ONNX model was trained/exported.
 *
 * @throws Error when the manipulator returns no base64 payload or the decoded
 *   image is not 224x224.
 */
async function imageToTensor(uri: string) {
  const saveOptions = {
    base64: true,
    format: ImageManipulator.SaveFormat.JPEG,
    compress: 1,
  };
  const manipulated = await ImageManipulator.manipulateAsync(
    uri,
    [{ resize: { width: 224, height: 224 } }],
    saveOptions
  );
  const encoded = manipulated.base64;
  if (!encoded) {
    throw new Error('Failed to read image data.');
  }

  const { data, width, height } = jpeg.decode(base64ToUint8Array(encoded), { useTArray: true });
  if (width !== 224 || height !== 224) {
    throw new Error(`Unexpected image size: ${width}x${height}`);
  }

  const planeSize = width * height;
  const tensor = new Float32Array(1 * 3 * planeSize);
  // Decoded pixels are 4 bytes each; the first three are copied into separate planes
  // and the fourth byte is ignored.
  for (let px = 0, offset = 0; px < planeSize; px += 1, offset += 4) {
    tensor[px] = data[offset] / 255;
    tensor[planeSize + px] = data[offset + 1] / 255;
    tensor[2 * planeSize + px] = data[offset + 2] / 255;
  }
  return tensor;
}
/**
 * Decode a base64 string into raw bytes.
 *
 * - Characters outside the base64 alphabet (whitespace, line breaks, …) are ignored.
 * - URL-safe input (`-` / `_`) is accepted and treated as `+` / `/`.
 * - Padding is optional: a trailing group of 2 or 3 characters decodes to 1 or 2 bytes.
 *   A lone trailing character cannot encode a whole byte and is dropped.
 * - `=` terminates its 4-character group, so concatenated padded chunks decode correctly.
 */
function base64ToUint8Array(base64: string) {
  const chars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/';
  const cleaned = base64
    .replace(/-/g, '+')
    .replace(/_/g, '/')
    .replace(/[^A-Za-z0-9+/=]/g, '');
  // Each 4-character group yields at most 3 bytes; trimmed to the real length at the end.
  const out = new Uint8Array(Math.ceil(cleaned.length / 4) * 3);
  let length = 0;
  for (let i = 0; i < cleaned.length; i += 4) {
    const enc1 = decodeBase64Char(cleaned.charAt(i), chars);
    const enc2 = decodeBase64Char(cleaned.charAt(i + 1), chars);
    const enc3 = decodeBase64Char(cleaned.charAt(i + 2), chars);
    const enc4 = decodeBase64Char(cleaned.charAt(i + 3), chars);
    // Two real sextets are needed before a single byte can be formed.
    if (enc1 === 64 || enc2 === 64) continue;
    out[length++] = (enc1 << 2) | (enc2 >> 4);
    if (enc3 === 64) continue;
    out[length++] = ((enc2 & 15) << 4) | (enc3 >> 2);
    if (enc4 === 64) continue;
    out[length++] = ((enc3 & 3) << 6) | enc4;
  }
  return out.slice(0, length);
}

/**
 * Map one base64 character to its 6-bit value (0..63).
 *
 * Returns the sentinel 64 for padding `=`, for characters not in `chars`, and for the
 * empty string that `charAt` returns past the end of unpadded input.
 */
function decodeBase64Char(char: string, chars: string) {
  // `indexOf('')` is 0, so without this check a missing character would decode as 'A'
  // and unpadded input would grow spurious trailing bytes.
  if (char === '' || char === '=') return 64;
  const index = chars.indexOf(char);
  return index === -1 ? 64 : index;
}
/**
 * Fetch the five most recent inference rows, newest first, each joined with the URI
 * of its image (null when no matching `images` row exists, via the LEFT JOIN).
 */
async function loadLogs(): Promise<InferenceLog[]> {
  const recentInferencesSql = `SELECT mi.id, mi.top_label, mi.top_score, mi.created_at, img.uri
FROM ml_inferences mi
LEFT JOIN images img ON img.id = mi.image_id
ORDER BY mi.created_at DESC
LIMIT 5;`;
  const database = await dbPromise;
  return database.getAllAsync<InferenceLog>(recentInferencesSql);
}
/**
 * Persist one inference run: insert a row into `images` for the classified photo,
 * then a linked row into `ml_inferences` holding the top predictions.
 *
 * Both inserts run in a single transaction, so a failure in the second insert does
 * not leave an orphaned `images` row behind.
 *
 * @param params.uri    URI of the image that was classified (stored as-is).
 * @param params.top    Ranked predictions. The first becomes top_label/top_score, and all
 *                      of them (index, label, score) are serialised into scores_json.
 * @param params.scores Full probability vector. NOTE(review): currently not persisted —
 *                      only `top` is written to scores_json; confirm this is intended.
 */
async function logInference(params: {
  uri: string;
  top: Prediction[];
  scores: Float32Array;
}) {
  const db = await dbPromise;
  const now = new Date().toISOString();
  const topHit = params.top[0];
  const scoresJson = JSON.stringify(
    params.top.map(({ index, label, score }) => ({ index, label, score }))
  );
  await db.withTransactionAsync(async () => {
    // NOTE(review): width/height are hard-coded to 224 (the model input size); the image
    // at `params.uri` may have different dimensions.
    const insert = await db.runAsync(
      'INSERT INTO images (observation_id, uri, thumbnail_uri, width, height, created_at) VALUES (?, ?, ?, ?, ?, ?);',
      null,
      params.uri,
      null,
      224,
      224,
      now
    );
    await db.runAsync(
      'INSERT INTO ml_inferences (image_id, model_name, model_version, top_label, top_score, scores_json, created_at) VALUES (?, ?, ?, ?, ?, ?, ?);',
      insert.lastInsertRowId,
      'plantvillage_mnv3s_224',
      '1',
      topHit?.label ?? null,
      topHit?.score ?? null,
      scoresJson,
      now
    );
  });
}
/**
 * Turn a raw class label such as `Tomato___Late_blight` into display text
 * (`Tomato Late Blight`). Underscores, slashes and whitespace runs collapse to a
 * single space, the result is trimmed, and each word's first character is upper-cased.
 */
function humanizeLabel(label: string) {
  const normalized = label.replace(/[_/\s]+/g, ' ').trim();
  const words = normalized.split(' ');
  const titled = words.map((word) => word.charAt(0).toUpperCase() + word.slice(1));
  return titled.join(' ');
}
/**
 * Format a timestamp string for display using the device locale.
 *
 * Falls back to returning `value` unchanged when it cannot be parsed as a date.
 */
function formatDate(value: string) {
  const date = new Date(value);
  // `new Date('garbage')` does not throw; it produces an Invalid Date whose
  // toLocaleString() is the literal "Invalid Date". Detect that explicitly.
  if (Number.isNaN(date.getTime())) {
    return value;
  }
  try {
    return date.toLocaleString();
  } catch {
    // Defensive: locale formatting failures fall back to the raw string.
    return value;
  }
}
// Literal values shared by several entries of the stylesheet below. Several of these
// colours are overridden inline by the screen with values from the theme palette.
const PILL_RADIUS = 999;
const MUTED_TEXT = '#6B736D';
const ACCENT_TEXT = '#3D5F4A';
const FILL_PARENT = { width: '100%', height: '100%' } as const;

/** Static styles for the ONNX leaf-scan screen. */
const styles = StyleSheet.create({
  headerImage: { ...FILL_PARENT },

  // Hero banner
  hero: {
    flexDirection: 'row',
    alignItems: 'center',
    gap: 14,
    padding: 16,
    marginBottom: 16,
    borderRadius: 20,
    borderWidth: 1,
  },
  heroIcon: { borderRadius: PILL_RADIUS, padding: 10 },
  heroText: { flex: 1, gap: 6 },
  heroTitle: { fontFamily: Fonts.rounded },
  heroSubtitle: { color: '#516458' },

  // Generic card + header
  card: {
    gap: 12,
    padding: 16,
    marginBottom: 16,
    borderRadius: 18,
    borderWidth: 1,
    borderColor: '#E3DED6',
    backgroundColor: '#FFFFFF',
  },
  cardHeader: {
    flexDirection: 'row',
    flexWrap: 'wrap',
    alignItems: 'center',
    gap: 12,
  },
  statusPill: {
    paddingHorizontal: 10,
    paddingVertical: 4,
    borderRadius: PILL_RADIUS,
    backgroundColor: '#E6F0E2',
  },
  statusText: { fontSize: 12, color: ACCENT_TEXT },

  // Image preview
  previewWrap: {
    overflow: 'hidden',
    borderRadius: 16,
    borderWidth: 1,
    borderColor: '#E1DBD1',
    backgroundColor: '#F3F0EA',
  },
  preview: { width: '100%', height: 220 },
  previewPlaceholder: {
    height: 200,
    alignItems: 'center',
    justifyContent: 'center',
    gap: 8,
  },
  placeholderText: { color: MUTED_TEXT },
  inlineHint: { fontSize: 13, color: MUTED_TEXT },
  actionRow: { flexDirection: 'row', flexWrap: 'wrap', gap: 12 },

  // Lists
  results: { gap: 12 },
  logList: { gap: 10 },

  // Recent-scan rows
  logRow: {
    flexDirection: 'row',
    alignItems: 'center',
    gap: 12,
    padding: 10,
    borderRadius: 12,
    backgroundColor: '#F7F4EE',
  },
  logThumb: {
    width: 48,
    height: 48,
    alignItems: 'center',
    justifyContent: 'center',
    overflow: 'hidden',
    borderRadius: 10,
    backgroundColor: '#ECE7DE',
  },
  logImage: { ...FILL_PARENT },
  logMeta: { flex: 1, gap: 2 },
  logLabel: { fontWeight: '600' },
  logSub: { fontSize: 12, color: MUTED_TEXT },
  logScore: { fontWeight: '600', color: ACCENT_TEXT },

  // Prediction rows and score bars
  predRow: { gap: 6 },
  predLabelRow: {
    flexDirection: 'row',
    justifyContent: 'space-between',
    alignItems: 'center',
    gap: 12,
  },
  predLabel: { flex: 1 },
  predPercent: { color: ACCENT_TEXT },
  barTrack: {
    height: 8,
    overflow: 'hidden',
    borderRadius: PILL_RADIUS,
    backgroundColor: '#E7E1D7',
  },
  barFill: {
    height: 8,
    borderRadius: PILL_RADIUS,
    backgroundColor: '#4B7B57',
  },
});
|